题目链接:https://atcoder.jp/contests/arc190/tasks/arc190_c

题目大意:一个 $n*m$ 的网格,从左上走到右下且只能向下或者向右走的路径的权值为路径点权乘积,每次会从当前点往四个方向走一格并且修改当前位置的权值,输出每次移动后的所有路径权值和。

做法

显然会考虑根号做法,不妨假设 $n\le m$ ,我们可以发现,设 $l[i][j]$ 表示 $(1,1)\to (i,j)$ 的所有路径权值和,$r[i][j]$ 表示 $(i,j)\to (n,m)$ 的所有路径权值和。

设当前位置为 $(nx,ny)$ ,然后正确的维护 $l[1\sim n][1\sim ny],r[1\sim n][ny+1\sim m]$ ,可以发现每次移动的修改量是 $O(n)$ ,而且显然根据这个信息可以在 $O(n)$ 的时间得到答案,所以时间复杂度就是 $O(nm+\min(n,m)q)$ 。

至于我怎么想到的?想到是根号复杂度后感觉这个做法就比较自然了,大概。

和官方题解复杂度一致。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
typedef vector<LL> VL;
const LL mod = 998244353;
const int N = 2e5 + 5;
const int B = 5e2 + 5;
int n, m, q, nx, ny, type = 0;
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
int be[300];
VL a[B], ll[B], rr[B];
void update(int col, int dir){
if(col == m + 1) return ;
if(!dir){
for(int i = 1; i <= n; i++){
ll[i][col] = (ll[i][col - 1] + ll[i - 1][col]) * a[i][col] % mod;
}
}
else{
for(int i = n; i >= 1; i--){
rr[i][col] = (rr[i][col + 1] + rr[i + 1][col]) * a[i][col] % mod;
}
}
}
void printans(){
LL ans = 0ll;
for(int i = 1; i <= n; i++) ans = (ans + ll[i][ny] * rr[i][ny + 1]) % mod;
cout << ans << "\n";
}
int main(){
be['U'] = 0; be['R'] = 1; be['D'] = 2; be['L'] = 3;
cin >> n >> m;
if(n > m){
type = 1;
swap(n, m);
}
for(int i = 0; i <= n + 1; i++){
a[i].resize(m + 2);
ll[i].resize(m + 2);
rr[i].resize(m + 2);
}
ll[1][0] = rr[n][m + 1] = 1ll;
if(!type){
for(int i = 1; i <= n; i++){
for(int j = 1; j <= m; j++) cin >> a[i][j];
}
}
else{
for(int j = 1; j <= m; j++){
for(int i = 1; i <= n; i++) cin >> a[i][j];
}
}
cin >> q >> nx >> ny;
if(type) swap(nx, ny);
for(int i = 1; i <= ny; i++) update(i, 0);
for(int i = m; i > ny; i--) update(i, 1);
// for(int i = 1; i <= n; i++){
// for(int j = 1; j <= m; j++) cout << a[i][j] << " ";
// cout << "\n";
// }
// printans();
for(int i = 1; i <= q; i++){
char st[10]; LL tmp;
cin >> st >> tmp;
int t = be[st[0]];
if(type == 1) t = 3 - t;
nx += dx[t], ny += dy[t];
a[nx][ny] = tmp;
assert(nx >= 1 && ny >= 1 && nx <= n && ny <= m);
update(ny, 0); update(ny + 1, 1);
printans();
}
return 0;
}