题目链接:https://codeforces.com/contest/1906/problem/K

题目大意:在原数组找出两堆异或值相同的数值,问有多少种找法,可以为空,每个数字可以不在任何一堆,两堆有标号。

所有做法的基础

一个显然的事情,这道题目相当于求 $\prod\limits_{i=1}^n(1+2x^{a_i})$ ,这里的乘法是异或卷积。

为了快速计算这个乘积,有了很多种搞法。

题解做法

题解做法的优点就是比较自然。

显然这个可以分治 FWT ,但是 FWT 不同于 FFT ,分治了值域不变,不改变复杂度。

咋整,观察到如果分治区间是 $[l,r)$ ,那么实际上这个区间乘出来的非 $0$ 项只能落在 $[0,r-l),[l,r)$ ,直接拿出左区间的两个非 $0$ 区间和右区间的两个非 $0$ 区间互相乘一下就行了,这样能做到时间复杂度:$O(V\log^2 V)$ 。

代码:https://codeforces.com/contest/1906/submission/235539466

我的做法

我的做法和题解想法基本一致,唯一一点不同的是,对于 $(r-l)=2^l$ ,那么二进制下 $l$ 位都是一样的,因此对于二进制位剩下的位置,要么和 $[l,r)$ 里面的每个数字一样,代表异或了奇数次,要么全是 $0$ ,代表异或了偶数次。

所以我们不妨给每个数字的最高位填个 $1$ ,代表了这个数字异或次数的奇偶性即可。

时间复杂度仍然是:$O(V\log^2 V)$ 。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL mod=998244353;
const int N=1e5+5;
const int L=17;
const int SN=(1<<L);
void FWT(vector<int> &f,const int C[2][2],int len){
assert(f.size()==len);
for(int t=1;t<len;t<<=1){
for(int l=0;l<len;l+=t+t){
int r=l+t;
for(int i=0;i<t;i++){
int x=f[l+i],y=f[r+i];
f[l+i]=(1ll*C[0][0]*x+1ll*C[0][1]*y)%mod;
f[r+i]=(1ll*C[1][0]*x+1ll*C[1][1]*y)%mod;
}
}
}
}

const int Cxor[2][2]={{1,1},{1,mod-1}};
const int Ixor[2][2]={{(mod+1)/2,(mod+1)/2},{(mod+1)/2,mod-(mod+1)/2}};
int n,a[N*2];
LL fc[N],nfc[N],f2[N];
LL C(int x,int y){return fc[x]*nfc[y]%mod*nfc[x-y]%mod;}
vector<array<int,2>> solve(int dep,int l,int r){
if(l==r){
vector<array<int,2> > ans;
ans.push_back({0,0});
for(int i=0;i<=a[l];i++){
if(i&1)ans[0][0]=(ans[0][0]+C(a[l],i)*f2[i])%mod;
else ans[0][1]=(ans[0][1]+C(a[l],i)*f2[i])%mod;
}
return ans;
}
int mid=(l+r)>>1;
auto lans=solve(dep-1,l,mid);
auto rans=solve(dep-1,mid+1,r);
vector<int> lf;lf.resize(1<<(dep+1));
for(int i=0;i<(1<<(dep-1));i++){
auto [f0,f1]=lans[i];
lf[i]=f0;
lf[i+(1<<dep)]=f1;
}
vector<int> rf;rf.resize(1<<(dep+1));
for(int i=0;i<(1<<(dep-1));i++){
auto [f0,f1]=rans[i];
rf[i]=f0;
rf[i+(1<<(dep-1))+(1<<dep)]=f1;
}
FWT(lf,Cxor,(1<<(dep+1)));
FWT(rf,Cxor,(1<<(dep+1)));
for(int i=0;i<(1<<(dep+1));i++)lf[i]=1ll*lf[i]*rf[i]%mod;
FWT(lf,Ixor,(1<<(dep+1)));
vector<array<int,2> > ans;ans.resize(1<<dep);
for(int i=0;i<(1<<dep);i++){
ans[i]={lf[i],lf[i+(1<<dep)]};
}
return ans;
}
int main(){
scanf("%d",&n);
fc[0]=fc[1]=nfc[0]=nfc[1]=f2[0]=1;f2[1]=2;
for(int i=2;i<=n;i++)nfc[i]=(mod-mod/i)*nfc[mod%i]%mod;
for(int i=2;i<=n;i++)nfc[i]=nfc[i-1]*nfc[i]%mod,fc[i]=fc[i-1]*i%mod,f2[i]=f2[i-1]*2%mod;
for(int i=1;i<=n;i++){
int x;scanf("%d",&x);
a[x]++;
}
auto ans=solve(L,0,(1<<L)-1);
printf("%lld\n",(ans[0][0]+ans[0][1])%mod);
return 0;
}
我的做法改进

发现一个事情,矩阵为 $\begin{bmatrix} 1 & 1\ 1 & -1 \end{bmatrix}$ 的 FWT (即异或的 FWT)最终的结果其实可以写成这样:

利用这个式子,可以改进分治 FWT 。

不妨假设:$(r-l)=2^l$ ,显然左区间的值域在 $[0,2^l)$ (高位补了 $1$ 判断奇偶性),现在要扩充到 $[0,2^{l+1})$ ,扩充规则为给原来的每个下表在最高位(算前导 $0$ )下面塞个 $0$ (右区间塞 $1$) ,然后在剩下的位置补 $0$ 。

然后根据上面的式子可以利用变化前的点值,在线性时间得到变化后的点值。(具体见代码)

然后直接乘就行了,时间复杂度:$O(V\log{V})$ 。

这个做法相比较于下面的做法,或许不是最妙的,但是是最适用的,因为不依赖于系数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL mod=998244353;
const int N=1e5+5;
const int L=17;
const int SN=(1<<L);
void FWT(LL *f,const LL C[2][2],int len){
for(int t=1;t<len;t<<=1){
for(int l=0;l<len;l+=t+t){
int r=l+t;
for(int i=0;i<t;i++){
LL x=f[l+i],y=f[r+i];
f[l+i]=(C[0][0]*x+C[0][1]*y)%mod;
f[r+i]=(C[1][0]*x+C[1][1]*y)%mod;
}
}
}
}

const LL Cxor[2][2]={{1,1},{1,mod-1}};
const LL Ixor[2][2]={{(mod+1)/2,(mod+1)/2},{(mod+1)/2,mod-(mod+1)/2}};
int n,a[N*2];
LL fc[N],nfc[N],f2[N];
LL C(int x,int y){return fc[x]*nfc[y]%mod*nfc[x-y]%mod;}
LL f[2][N*2];
void solve(){
for(int i=0;i<(1<<L);i++){
for(int j=0;j<=a[i];j++){
if(j&1)f[1][i]=(f[1][i]+C(a[i],j)*f2[j])%mod;
else f[0][i]=(f[0][i]+C(a[i],j)*f2[j])%mod;
}
LL lf=f[0][i],rf=f[1][i];
f[0][i]=(lf+rf)%mod;
f[1][i]=(lf+mod-rf)%mod;
}
for(int t=1;t<=L;t++){
for(int i=0;i<(1<<L);i+=(1<<t)){
for(int l=i;l<i+(1<<(t-1));l++){
int r=l+(1<<(t-1));
LL l0=f[0][l],l1=f[1][l];
LL l00=l0,l01=l0;
LL l10=l1,l11=l1;

LL r0=f[0][r],r1=f[1][r];
LL r00=r0,r01=r1;
LL r10=r1,r11=r0;

f[0][l]=l00*r00%mod;
f[0][r]=l01*r01%mod;
f[1][l]=l10*r10%mod;
f[1][r]=l11*r11%mod;
}
}
}
FWT(f[0],Ixor,(1<<L));
}
int main(){
scanf("%d",&n);
fc[0]=fc[1]=nfc[0]=nfc[1]=f2[0]=1;f2[1]=2;
for(int i=2;i<=n;i++)nfc[i]=(mod-mod/i)*nfc[mod%i]%mod;
for(int i=2;i<=n;i++)nfc[i]=nfc[i-1]*nfc[i]%mod,fc[i]=fc[i-1]*i%mod,f2[i]=f2[i-1]*2%mod;
for(int i=1;i<=n;i++){
int x;scanf("%d",&x);
a[x]++;
}
solve();
printf("%lld\n",f[0][0]);
return 0;
}
深刻观察法

来自比赛 Announcement 评论区。

发现一个事情,矩阵为 $\begin{bmatrix} 1 & 1\ 1 & -1 \end{bmatrix}$ 的 FWT (即异或的 FWT)最终的结果其实可以写成这样:

同时又观察到,结果只能是 $-1$ 或者 $3$ ,那么这有什么用呢?

思考一下,FWT 和 FFT 有一个很重要的不同,就是 FWT 不需要扩展数组,因为下标值域不会扩展,所以 FWT 实际上能算出所有多项式的点值表达式直接乘起来然后再逆回去,而 FFT 是不行的(除非一开始就把所有的多项式算出充足的点值)。

但是算出所有多项式的点值表达式的时间开销仍然很大,第一种做法采用了分治 FWT 来加速这个过程,但是这里,我们直接观察出了 FWT 后的结果长啥样,那我们是不是可以不用 FWT ,直接算出结果呢?

显然,求 $\prod f_k[i]$ 只需要算出有多少个 $-1$ 或者有多少个 $3$ 就行了。

有了这个思路,就有很多种搞法了。

我使用的做法是:$n=even+odd$ ,那么只需要令 $a[x]=x的出现次数$ ,然后直接跑 FWT ,就可以知道每个位置的 $even-odd$ ,然后就可以直接算出来 $-1$ 和 $3$ 的个数了。

还有别的搞法,例如:SOS dp,但是因为感觉这个的转移式子和 FWT 没什么本质区别,就不再赘述了,放个这个做法的代码: https://codeforces.com/contest/1906/submission/235477273

时间复杂度:$O(V\log{V})$ 的,空间复杂度:$O(V)$,$V$ 是值域。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const LL mod=998244353;
const int N=1e5+5;
const int NN=2e5+5;
const int L=(1<<17);
void FWT(LL *f,const LL C[2][2],int len){
for(int t=1;t<len;t<<=1){
for(int l=0;l<len;l+=t+t){
int r=l+t;
for(int i=0;i<t;i++){
LL x=f[l+i],y=f[r+i];
f[l+i]=(C[0][0]*x+C[0][1]*y)%mod;
f[r+i]=(C[1][0]*x+C[1][1]*y)%mod;
}
}
}
}
int n;
LL f3[N],a[NN];
LL Cxor[2][2]={{1,1},{1,mod-1}};
LL Ixor[2][2]={{(mod+1)/2,(mod+1)/2},{(mod+1)/2,mod-(mod+1)/2}};
int main(){
scanf("%d",&n);
f3[0]=1;
for(int i=1;i<=n;i++){
int x;scanf("%d",&x);
a[x]++;
f3[i]=f3[i-1]*3%mod;
}
FWT(a,Cxor,L);
for(int i=0;i<L;i++){
int num=a[i];
if(num>n)num-=mod;
//n=f+z,num=z-f;
int z=(n+num)/2,f=(n-num)/2;
if(f&1)a[i]=(mod-f3[z])%mod;
else a[i]=f3[z];
}
FWT(a,Ixor,L);
printf("%lld\n",a[0]);
return 0;
}