Basics of Factorization and Combinatorics

// Precomputed modular combinatorics: inverses of 1..n, factorials and their
// inverses, powers of 2 and of 1/2, and (only when n <= SMALL_N) full tables of
// A(i, j) = i! / (i-j)! and C(i, j) = i! / (j! (i-j)!).
//
// Usage:
//      Math ma;
//      ma.prepare(n, MOD);
//      use ma.getA(n, k), ma.getC(n, k), ma.fac[i], ma.ifac[i], ma.twop[i], ma.iv_twop[i]
//
// Preconditions: n >= 0 and MOD is a prime greater than n. The inverse
// recurrence for r[] / ifac[] needs every i in 1..n to be invertible mod MOD,
// and iv_twop[] needs MOD to be odd.
//
// NOTE: Math owns raw buffers and has no destructor or copy constructor; a copy
// shares the same buffers, so do not call prepare() again on an object that
// has been copied.
struct Math 
{
    int n;        // largest index covered by the tables; -1 until prepare()
    int MOD;      // modulus used for every table
    int SMALL_N;  // the A and C tables are only built when n <= SMALL_N

    ll *r;              // r[i] = inverse of i mod MOD, for 1 <= i <= n
    ll *fac, *ifac;     // i! and its inverse mod MOD, for 0 <= i <= n
    ll **A, **C;        // (n+1) x (n+1) tables; entries with j > i are 0
    ll *twop, *iv_twop; // 2^i and 2^{-i} mod MOD, for 0 <= i <= n

    Math() {
        n = -1;
        MOD = 1;
        SMALL_N = 3030;
        // Null pointers mark "not allocated" so release() and the getters are safe.
        r = fac = ifac = twop = iv_twop = 0;
        A = C = 0;
    }

    // Builds every table for indices 0..n_ modulo MOD_.
    // Tables left over from an earlier prepare() call are freed first.
    void prepare(int n_, int MOD_) {
        release();
        n = n_;
        MOD = MOD_;

        // Inverses of 1..n in O(n): r[i] = -(MOD / i) * r[MOD % i] (mod MOD).
        // Requires MOD % i != 0 for all i <= n (MOD prime, n < MOD).
        r = new ll[n + 1]();
        if (n >= 1)
            r[1] = 1;
        for (int i = 2; i <= n; ++i)
            r[i] = (MOD - (MOD / i) * r[MOD % i] % MOD) % MOD;

        // Factorials and their inverses.
        fac = new ll[n + 1];
        ifac = new ll[n + 1];
        fac[0] = ifac[0] = 1;
        for (int i = 1; i <= n; ++i) {
            fac[i] = fac[i - 1] * i % MOD;
            ifac[i] = ifac[i - 1] * r[i] % MOD;
        }

        // Powers of 2 and of its inverse. (MOD + 1) / 2 is 2^{-1} only for odd
        // MOD; the 1LL keeps MOD + 1 from overflowing int when MOD == INT_MAX.
        const ll inv2 = (MOD + 1LL) / 2;
        twop = new ll[n + 1];
        iv_twop = new ll[n + 1];
        twop[0] = iv_twop[0] = 1;
        for (int i = 1; i <= n; ++i) {
            twop[i] = twop[i - 1] * 2 % MOD;
            iv_twop[i] = iv_twop[i - 1] * inv2 % MOD;
        }

        if (n <= SMALL_N) {
            // A[i][j] = i * (i-1) * ... * (i-j+1); rows are zero-initialized.
            A = new ll*[n + 1];
            for (int i = 0; i <= n; ++i) {
                A[i] = new ll[n + 1]();
                A[i][0] = 1;
                for (int j = 1; j <= i; ++j)
                    A[i][j] = A[i][j - 1] * (i - j + 1) % MOD;
            }

            // Pascal's rule: C[i][j] = C[i-1][j] + C[i-1][j-1].
            C = new ll*[n + 1];
            for (int i = 0; i <= n; ++i) {
                C[i] = new ll[n + 1]();
                C[i][0] = C[i][i] = 1;
                for (int j = 1; j < i; ++j)
                    C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % MOD;
            }
        }
    }

    // Ordered selections of k_ items from n_: n_! / (n_ - k_)! mod MOD.
    // Returns 0 when k_ > n_, when either argument is negative, or when n_ is
    // outside the prepared range (including before prepare() was called).
    ll getA(int n_, int k_) {
        if (k_ < 0 || n_ < 0 || k_ > n_ || n_ > n)
            return 0;

        if (A != 0)
            return A[n_][k_];

        return fac[n_] * ifac[n_ - k_] % MOD;
    }

    // Binomial coefficient C(n_, k_) mod MOD.
    // Returns 0 when k_ > n_, when either argument is negative, or when n_ is
    // outside the prepared range (including before prepare() was called).
    ll getC(int n_, int k_) {
        if (k_ < 0 || n_ < 0 || k_ > n_ || n_ > n)
            return 0;

        if (C != 0)
            return C[n_][k_];

        return fac[n_] * ifac[n_ - k_] % MOD * ifac[k_] % MOD;
    }

private:
    // Frees every table allocated by a previous prepare(); no-op before the first.
    // Relies on n still holding the size used for that previous call.
    void release() {
        delete[] r;
        delete[] fac;
        delete[] ifac;
        delete[] twop;
        delete[] iv_twop;
        if (A != 0) {
            for (int i = 0; i <= n; ++i)
                delete[] A[i];
            delete[] A;
        }
        if (C != 0) {
            for (int i = 0; i <= n; ++i)
                delete[] C[i];
            delete[] C;
        }
        r = fac = ifac = twop = iv_twop = 0;
        A = C = 0;
    }

};

Leave a Reply