天天看点

矩阵加速

矩阵基础

在上面的blog里,我提到了矩阵可以用来优化dp等问题的时间复杂度,这篇blog便来详细说一下。

比如不管是递推还是dp入门都用到了的斐波那契数列,他就可以用矩阵优化加快。

\(fib_1 = 1\)

\(fib_2 = 1\)

\(fib_i = fib_{i-1} + fib_{i-2}\)

求\(fib_n,n\leq 10^9\)

这个明显数据很大,连数组都开不下,我们来考虑矩阵加速。

我们从状态转移方程入手:

\(fib_i = fib_{i-1} + fib_{i-2}\),这说明\(fib_i\)这个地方是由\(1\)个\(fib_{i-1}\)和\(1\)个\(fib_{i-2}\)推过来的,所以

\(\begin{bmatrix}fib_i\\fib_i-1\end{bmatrix} = \begin{bmatrix}1&1\\?&?\end{bmatrix} \times \begin{bmatrix}fib_{i-1}\\fib_{i-2}\end{bmatrix}\)

把这个乘开,则\(fib_{i-1} \times 1 + fib_{i-2} \times 1 = fib_{i - 1} + fib_{i - 2}\),不就正好是\(fib_i\)了吗?

然后左边还有一个\(fib_{i-1}\)没有配好,我们把它配好。

\(fib_{i-1}\)肯定是由一个\(fib_{i-1}\)组成的,所以

\(\begin{bmatrix}fib_i\\fib_{i-1}\end{bmatrix} = \begin{bmatrix}1&1\\1&?\end{bmatrix} \times \begin{bmatrix}fib_{i-1}\\fib_{i-2}\end{bmatrix}\)

剩下的就是0了

\(\begin{bmatrix}fib_i\\fib_{i-1}\end{bmatrix} = \begin{bmatrix}1&1\\1&0\end{bmatrix} \times \begin{bmatrix}fib_{i-1}\\fib_{i-2}\end{bmatrix}\)

然后如果我已经算出来\(\begin{bmatrix}fib_i\\fib_i-1\end{bmatrix}\)了,那我再乘一个\(\begin{bmatrix}1&1\\1&0\end{bmatrix}\),不就算出来\(\begin{bmatrix}fib_{i +1}\\fib_i\end{bmatrix}\)了?这就是矩阵加速。

#include <bits/stdc++.h>
using namespace std;
#define LL long long
LL n, m;
struct node {
    LL r, c, jz[10][10];
    node operator * (const node& rhs) const{
        node ans;
        ans.r = r, ans.c = rhs.c;
        for (int i = 0; i <= 9; i ++)
            for (int j = 0; j <= 9; j ++)
                ans.jz[i][j] = 0;
        for (int i = 1; i <= r; i ++)
            for (int j = 1; j <= rhs.c; j ++)
                for (int k = 1; k <= c; k ++)
                    ans.jz[i][j] = (ans.jz[i][j] + jz[i][k] * rhs.jz[k][j] % m) % m;
        return ans;
    }
}A, B, C;
void prepare (node &ans){
    for (int i = 0; i <= 9; i ++)
        for (int j = 0; j <= 9; j ++)
            ans.jz[i][j] = 0;
}
node qkpow (node x, LL y){
    node ans;
    prepare (ans);
    ans.r = ans.c = 2;
    for (int i = 1; i <= ans.r; i ++)
        ans.jz[i][i] = 1;
    while (y > 0){
        if (y % 2 == 1)
            ans = ans * x;
        x = x * x;
        y /= 2;
    }
    return ans;
}
int main (){
    scanf ("%lld %lld", &n, &m);
    B.r = 2, B.c = 2;
    B.jz[1][1] = 0, B.jz[1][2] = B.jz[2][1] = B.jz[2][2] = 1;
    A.r = 1, A.c = 2;
    A.jz[1][1] = A.jz[1][2] = 1;
    A = A * qkpow (B, n - 2);
    printf ("%lld\n", A.jz[1][2]);
    return 0;
}