高维前缀和与子集求和(SOS-DP)

“自以为历尽沧桑,其实刚蹒跚学步;自以为掌握了竞争的秘密,其实远没有竞争的资格。”——《三体》

高维前缀和

一般而言,前缀和是我们处理问题的常见技巧,比如能使我们在 的时间内得到区间和,也常用在类似于 (树状数组)之类的数据结构里,在数论中也很常见,比如说整数分块。

高维前缀和一般用于快速莫比乌斯变换()和子集求和()。

前缀和一般写法

对于 维的前缀和,我们可以这样表示:

参考代码
1
for(int i=1;i<=n;++i) sum[i]=sum[i-1]+a[i]

这是常见的形式,实现线性。

维前缀需要用到一点点容斥的技巧:

参考代码
1
2
3
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
s[i][j]=s[i-1][j]+s[i][j-1]-s[i-1][j-1]+a[i][j];

实现复杂度为

再往高维走,对于 维,则容斥会变得有些许麻烦。

参考代码
1
2
3
4
for(int i=1;i<=n;i++)
for(int j=1;j<=m;j++)
for(int k=1;k<=p;k++)
s[i][j][k]=s[i-1][j][k]+s[i][j-1][k]+s[i][j][k-1]-s[i-1][j-1][k]-s[i-1][j][k-1]-s[i][j-1][j-1]+s[i-1][j-1][k-1]+a[i][j][k];

容易发现,设第 维的大小为 ,一共 维,则容斥的复杂度为 ,而整个 维前缀和的时间复杂度为 ,如果各维同阶,则时间复杂度为 ,效果显然,极其不优秀。

前缀和另一写法

有一种更为高效的写法:

维不变,我们考虑 维的写法:

参考代码
1
2
3
4
5
6
7
std::memcpy(s,a,sizeof(a));
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j)
s[i][j]+=s[i-1][j];
for(int i=1;i<=n;++i)
for(int j=1;j<=m;++j)
s[i][j]+=s[i][j-1];

看起来十分冗长,但如果是对于 维而言:

参考代码
1
2
3
4
5
6
7
8
9
10
11
12
13
std::memcpy(s,a,sizeof(a));
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
for(int k=1;k<=n;++k)
s[i][j][k]+=s[i-1][j][k];
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
for(int k=1;k<=n;++k)
s[i][j][k]+=s[i][j-1][k];
for(int i=1;i<=n;++i)
for(int j=1;j<=n;++j)
for(int k=1;k<=n;++k)
s[i][j][k]+=s[i][j][k-1];

容易发现,比容斥要更好理解,时间复杂度也相对更优一些。

维前缀和的时间复杂度为 ,在 的时候优于上一种写法。

这种写法启发我们在求解高维问题的时候应当一维一维地处理。


子集求和

称为 ,简称为 ,也可以认为是基于 思想的高维前缀和计算。

一般用于对下标进行限制的问题求解。

二维子集求和

考虑下面这个问题:

个物品,如果每一个物品都有取或不取两种状态,用 表示,那么选取的情况就会组成集合,集合又会组成我们的方案集合,为: ,那我们用二进制位表示,则是四个数 ,如果当前我们要求某一个集合及其子集的权值和,可以定义为 表示选取情况为 的子集权值和,则:

那么, 就是我们想要的答案,而这个形式也类似于我们的前缀和处理,所以考虑这样做来解决子集求和的问题。

多维子集求和

若让我们求 的话,直接暴力枚举子集的时间复杂度会高达 ,这是不可接受的。考虑如何优化。

同样的,我们并不枚举每一个集合的子集,而是通过枚举子集来得到答案,因为任何布尔集合都可以使用一个二进制数来表示,所以我们考虑一个状态压缩的过程,在一个更为优秀的时间内的得到答案。

设一共有 个物品,则:

状态压缩实现参考
1
2
3
for(int i=0;i<n;++i)
for(int s=0;s<(1<<n);++s)
if(s&(1<<i)) Dp[s]+=Dp[s^(1<<i)];

这样的话,枚举子集的工作就可以在 的时间内解决。

解释一下上面的代码:第一层 for 枚举物品(由于状压众所周知的原因所以从 开始标号),第二层 for 枚举集合(一共有 个集合),if 判断当前集合 s 是否包含了第 i 个物品(用位运算解决),如果包含了,则去掉 i 的集合 s(表示为 s^(1<<i))即是 s 的子集。

例题

Or Plus Max

题目简介

题目名称:

题目来源:

评测链接:https://atcoder.jp/contests/arc100/tasks/arc100_c

形式化题意:给定一个长度为 的序列 ,对于 ,求出 ,满足

数据范围:

显然 的范围以及 的提示就是告诉我们这道题要进行状态压缩,因为是对于所有 都要求答案,而显然对于 的条件,容易发现,我们可以进行递推得到答案,因为状压的过程也会使 递进。

记录 表示选取集合为 包含的最大值和次大值,就可以从 x^(1<<i) 转移到 x

注意 ,不能选的时候把 选重了。

这是求高维前缀最大值的经典例题。

AC Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// ----- Eternally question-----
// Problem: [ARC100E] Or Plus Max
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/AT_arc100_c
// Memory Limit: 1 MB
// Time Limit: 2000 ms
// Written by: Eternity
// Time: 2023-01-05 16:27:12
// ----- Endless solution-------

#include<bits/stdc++.h>
#define re register
typedef long long ll;
template<class T>
inline void read(T &x)
{
x=0;
char ch=getchar(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
if(t) x=-x;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1){ read(x),read(x1...); }
template<class T>
inline void write(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
template<>
inline void write(bool x){ putchar(x?'1':'0'); }
template<>
inline void write(char c){ putchar(c); }
template<>
inline void write(char *s){ while(*s!='\0') putchar(*s++); }
template<>
inline void write(const char *s){ while(*s!='\0') putchar(*s++); }
template<class T,class ...T1>
inline void write(T x,T1 ...x1){ write(x),write(x1...); }
template<class T>
inline bool checkMax(T &x,T y){ return x<y?x=y,1:0; }
template<class T>
inline bool checkMin(T &x,T y){ return x>y?x=y,1:0; }
const int MAXN=19,MAXS=(1<<18)+10;
const int INF=0x3f3f3f3f;
int N,a[MAXS];
struct Node
{
int fir,sec;
Node(){ fir=sec=-INF; }
}Dp[MAXS];
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(N);
for(int i=0;i<(1<<N);++i) read(a[i]),Dp[i].fir=a[i],Dp[i].sec=-INF;
for(int i=0;i<N;++i)
for(int s=0;s<(1<<N);++s) if(s>>i&1)
{
int lst=s^(1<<i);
Node tmp;
if(Dp[s].fir>Dp[lst].fir)
{
tmp.fir=Dp[s].fir;
tmp.sec=std::max(Dp[s].sec,Dp[lst].fir);
}
else
{
tmp.fir=Dp[lst].fir;
tmp.sec=std::max(Dp[s].fir,Dp[lst].sec);
}
Dp[s]=tmp;
}
int ans=-INF;
for(int k=1;k<(1<<N);++k)
{
checkMax(ans,Dp[k].fir+Dp[k].sec);
write(ans,'\n');
}
return 0;
}
/*

*/

Compatible Numbers

题目简介

题目名称:

题目来源:

评测链接:https://codeforces.com/problemset/problem/165/E

形式化题意:给定一个长度为 的序列 ,对于 求出一个 使得 ,如果无解输出 -1。带

数据范围:

考虑 的性质,与 相容的 一定是 的子集,这是显然的。那考虑高维前缀和的过程,用 表示 的子集中在 里出现过的数(任一),按照状压形式转移,把求和变为赋值即可。

AC Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// ----- Eternally question-----
// Problem: Compatible Numbers
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/CF165E
// Memory Limit: 250 MB
// Time Limit: 4000 ms
// Written by: Eternity
// Time: 2023-01-05 19:43:04
// ----- Endless solution-------

#include<bits/stdc++.h>
#define re register
typedef long long ll;
template<class T>
inline void read(T &x)
{
x=0;
char ch=getchar(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
if(t) x=-x;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1){ read(x),read(x1...); }
template<class T>
inline void write(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
template<>
inline void write(bool x){ putchar(x?'1':'0'); }
template<>
inline void write(char c){ putchar(c); }
template<>
inline void write(char *s){ while(*s!='\0') putchar(*s++); }
template<>
inline void write(const char *s){ while(*s!='\0') putchar(*s++); }
template<class T,class ...T1>
inline void write(T x,T1 ...x1){ write(x),write(x1...); }
template<class T>
inline bool checkMax(T &x,T y){ return x<y?x=y,1:0; }
template<class T>
inline bool checkMin(T &x,T y){ return x>y?x=y,1:0; }
const int MAXS=(1<<22)+10,MAXN=1e6+10,S=22;
int N,a[MAXN];
int Dp[MAXS];
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(N);
std::memset(Dp,-1,sizeof(Dp));
for(int i=1;i<=N;++i) read(a[i]),Dp[a[i]]=a[i];
for(int s=0;s<(1<<S);++s)
{
if(Dp[s]!=-1) continue;
for(int i=0;i<S;++i) if(s>>i&1)
if(Dp[s^(1<<i)]!=-1){ Dp[s]=Dp[s^(1<<i)];break; }
}
for(int i=1;i<=N;++i) write(Dp[((1<<S)-1)&(~a[i])],' ');
return 0;
}
/*

*/