题目链接:https://atcoder.jp/contests/arc190/tasks/arc190_d
题目大意:给一个 $n*n$ 的矩阵和素数值域 $p$ (接下来所有运算都在 $\mod{p}$ 意义下进行),然后矩阵中 $0$ 的位置可以是 $1\sim p-1$ 的任何一个数字,所以有 $(p-1)^{\mathrm{cnt}}$ 种不同的矩阵,$\mathrm{cnt}$ 表示矩阵中 $0$ 的个数,对于每个矩阵,其对答案的贡献为其的 $p$ 次方,求所有矩阵的贡献和(答案是一个矩阵)。
做法
设每个 $0$ 的位置是 $x_{1},…,x_{\mathrm{cnt}}$ ,那么答案中每个位置的答案就是关于 $x_{1},…,x_{\mathrm{cnt}}$ 的多元多项式,这启示我们去思考 $1^{k}+2^{k}+…+(p-1)^{k}\mod{p}$ 的值,然后打表可以发现只有当 $p-1$ 整除 $k$ 时为 $-1$ ,其余都为 $0$ 。这个的证明可以用那个经典的递推式证明,这里懒得展开了。
在知道这个后,我们就知道我们只关心多元多项式中每个未知数指数为 $0$ 或 $p-1$ 的项,然后就可以讨论了。
- 当 $p=2$ 的时候,可以知道等价于把 $0$ 设置成 $1$ 。
 
- 当 $p≠2$ 的时候,可以知道所有关心的项至多涉及到一个未知数,这个未知数的指数为 $p-1$ ,这个时候接着讨论。如果这个未知数在乘积中存在相邻位置,那么可以知道这个未知数一定在 $(i,i)$ ,且已知数一定在开头和结尾;否则可以知道 $p=3$ ,且未知数落在开头和结尾,中间是已知数。
 
讨论完毕。
时间复杂度:$O(n^3\log{n})$
空间复杂度:$O(n^3)$
和官方题解复杂度一致。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
   | #include<bits/stdc++.h> using namespace std; typedef long long LL; const int N = 1e2 + 5; int n; LL mod; struct node{     LL a[N][N];     node(){         memset(a, 0, sizeof(a));     }     LL* operator[](int x){return a[x];} }; void upd(LL &x, LL y){x = (x + y) % mod;} node operator * (node x, node y){     node z;     for(int i = 1; i <= n; i++){         for(int j = 1; j <= n; j++){             for(int k = 1; k <= n; k++){                 upd(z[i][j], x[i][k] * y[k][j]);             }         }     }     return z; } node ksm(node x, LL k){     node ans;     for(int i = 1; i <= n; i++) ans[i][i] = 1;     while(k){         if(k & 1) ans = ans * x;         x = x * x;         k >>= 1;     }     return ans; } LL ksm(LL x, LL k){     LL ans = 1ll;     while(k){         if(k & 1) ans = ans * x % mod;         x = x * x % mod;         k >>= 1;     }     return ans; } int main(){     cin.sync_with_stdio(false);     cin.tie(0);     cin >> n >> mod;     int cnt = 0;     node x;     for(int i = 1; i <= n; i++){         for(int j = 1; j <= n; j++){             cin >> x[i][j];             if(x[i][j] == 0 && mod == 2){                 x[i][j] = 1;             }             if(x[i][j] == 0) cnt++;         }     }     node ans = ksm(x, mod);                              for(int i = 1; i <= n; i++){         if(x[i][i] == 0){             for(int j = 1; j <= n; j++){                 upd(ans[j][i], x[j][i]);                 upd(ans[i][j], x[i][j]);             }         }     }     if(mod == 3){         for(int i = 1; i <= n; i++){             for(int j = 1; j <= n; j++){                 if(x[i][j] == 0) upd(ans[i][j], x[j][i]);             }         }     }     LL val = ksm(mod - 1, cnt);     for(int i = 1; i <= n; i++){         for(int j = 1; j <= n; j++) cout << ans[i][j] * val % mod << " ";         cout << "\n";     }     return 0; }
   |