跳转至

矩阵快速幂

约 342 个字 55 行代码 预计阅读时间 2 分钟

矩阵快速幂

给出底矩阵和指数,快速求出幂矩阵

时间复杂度

O(t^3*log n)

把对整数的操作移植到矩阵上来

常见多项式抽象

f(n) = a * f(n-1) + b * f(n-2) +c

\begin{pmatrix}a&b&1\\1&0&0\\0&0&1\end{pmatrix}* \begin{pmatrix}f(n-1)\\f(n-2)\\c\end{pmatrix}= \begin{pmatrix}f(n)\\f(n-1)\\c\end{pmatrix}

f(n) = f(n-1) + c^n

\begin{pmatrix}1&c\\0&c\end{pmatrix}* \begin{pmatrix}f(n-1)\\c^{n-1}\end{pmatrix}= \begin{pmatrix}f(n)\\c^{n}\end{pmatrix}

f(n)= f(n-1) + n^3

n^3= (n - 1 + 1)^3 ,二项式展开得系数

\begin{pmatrix}1&C(0,3)&C(1,3)&C(2,3)&C(3,3)\\ 0&C(0,3)&C(1,3)&C(2,3)&C(3,3)\\ 0&0&C(0,2)&C(1,2)&C(2,2)\\ 0&0&0&C(0,1)&C(1,1)\\ 0&0&0&0&C(0,0)\end{pmatrix}* \begin{pmatrix}f(n-1)\\(n-1)^3\\(n-1)^2\\(n-1)^1\\(n-1)^0\end{pmatrix}= \begin{pmatrix}f(n)\\n^3\\n^2\\n^1\\n^0\end{pmatrix}

一个经典题

给定图,问A到B恰好经过k步的走法总数

当i能一步到达j时,邻接矩阵第i行第j列设为1,邻接矩阵的k次幂,第i行第j列表示i到j恰好经过k步的走法总数

正确性可以观察一下这个位置由谁贡献来

板子

不要用 STL 里的matrix,比较慢

注意稠密矩阵和稀疏矩阵的乘法处理要不一样

const int N = 4;
struct matrix
{
    ll m[N][N];
    void init()//单位矩阵
    {
        memset(m,0,sizeof(m));
        for(int i=0;i<N;++i)
            m[i][i]=1;
    }
    matrix operator*(matrix y)
    {
        matrix c;
        memset(c.m,0,sizeof(c.m));
        //稠密矩阵
        for(int i=0;i<N;++i)
        {
            for(int j=0;j<N;++j)
            {
                for(int k=0;k<N;++k)
                    c.m[i][j] = (c.m[i][j] + m[i][k] * y.m[k][j] % mod) % mod;
            }
        }
        //稀疏矩阵(指矩阵中很多数为0)
        for(int i=0;i<N;++i)
        {
            for(int j=0;j<N;++j)
            {
                if(!m[i][j])
                    continue;
                for(int k=0;k<N;++k)
                    c.m[i][k] = (c.m[i][k] + m[i][j] * y.m[j][k] % mod) % mod;
            }
        }
        return c;
    }
    matrix operator^(ll exponent)
    {
        matrix result;
        result.init();
        for(;exponent;exponent>>=1)
        {
            if(exponent&1)
                result = result*(*this);
            (*this) = (*this)*(*this);
        }
        return result;
    }
};
matrix base{1,1,1,2,
            0,1,1,2,
            0,1,0,0,
            0,1,0,1};
base=base^(n-2);
ll ans=(((base.m[0][0]*2%mod+base.m[0][1])%mod+base.m[0][2])%mod+base.m[0][3])%mod;