多项式取模与任意模数多项式卷积(MTT)

“一棵大树,枝叶繁茂,根却因为种种原因坏死,这棵树也只会盖着一头枯枝烂叶。”——毛啸

观前提醒

前置知识:

本文可能存在大量的 ,请确保数学公式渲染完成后再进行阅读,否则会降低阅读体验。


MTT

我们已经学过了两种可以在 时间内求得多项式卷积的方法—— ,其各有各的优点,也各有各的缺点。

  1. 使用复数实现,由于 C++ 浮点数的运算问题,在进行较大数值的运算时可能会损失精度造成答案错误。且无法进行多项式取模。

  2. 使用原根实现,在模运算下进行运算,由于是整型运算,所以不存在精度缺失的问题,但运算值域被限制在模数以内。

那考虑以下问题:

题目简介

题目名称:任意模数多项式乘法
题目来源:

评测链接:https://www.luogu.com.cn/problem/P4245

给定一个 次多项式 和一个 次多项式 ,求 取模的结果。(不保证 是质数)

数据范围:

容易发现,当系数卡满的时候,系数值域会达到 之多,甚至已经超过了 long long 的存储范围。

虽然可以使用 __int128_t 或者高精度,但那样的实现极其麻烦,而且对原根的限制也很大。

对于 ,精度缺失在如此大的值域中是不可避免的事;对于 ,受模数限制,也不可能做如此大的卷积,更何况是任意模数,逆元可能会直接失效。

这时候便需要任意模数多项式卷积算法()的救场了。

有广为流传的两种写法,拆分系数 和三模数 ,笔者会一一进行详解。当然,其中减少变换次数的优化也包括在内。

关于算法的名字由来

这两种多项式卷积变式算法最先出现在国家集训队论文 ,由雅礼中学的毛啸提出,所以取其名字与算法意义得到了该名字,可以译为“快速毛啸数论变换”。


三模数 NTT

这一种可能更好理解一些。

算法如其名,利用三个能够使用 的模数进行卷积运算,会得到三个模意义下不同的多项式,根据我们的同余知识,我们会得到一个三元同余方程,可以使用中国剩余定理来还原出原序列,再对其进行最后的取模运算。

当然,一般而言,这三个模数都是 级别的(至少乘起来要是 以上级别的),一般选作:

这三个质数的原根都是 ,方便计算。

当我们计算得到答案得到三个不同的多项式 ,设真正的第 项系数为 ,则需要求解一个同余方程:

但如果我们直接使用 进行合并的时候,其运算依然会在 long long 中溢出,所以我们考虑能不能中途取模减小运算规格。

考虑 的计算方式,将两个同余方程进行合并并弹回原方程组直到只剩下一个方程从而得到同余方程解。

考虑将同余方程写成普通方程的形式,得到:

然后再逆转回同余方程得到:

所以我们先得到 合并的解设为 ,再将 进行同余方程求解,合并:

从而得到三组模数的通解公式:

当然,这个答案满足轮换性。

AC Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// ----- Eternally question-----
// Problem: P4245 【模板】任意模数多项式乘法
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4245
// Memory Limit: 500 MB
// Time Limit: 2000 ms
// Written by: Eternity
// Time: 2023-01-12 21:04:36
// ----- Endless solution-------

#include<bits/stdc++.h>
#define re register
typedef long long ll;
template<class T>
inline void read(T &x)
{
x=0;
char ch=getchar(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
if(t) x=-x;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1){ read(x),read(x1...); }
template<class T>
inline void write(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
template<>
inline void write(bool x){ putchar(x?'1':'0'); }
template<>
inline void write(char c){ putchar(c); }
template<>
inline void write(char *s){ while(*s!='\0') putchar(*s++); }
template<>
inline void write(const char *s){ while(*s!='\0') putchar(*s++); }
template<class T,class ...T1>
inline void write(T x,T1 ...x1){ write(x),write(x1...); }
template<class T>
inline bool checkMax(T &x,T y){ return x<y?x=y,1:0; }
template<class T>
inline bool checkMin(T &x,T y){ return x>y?x=y,1:0; }
const int MAXN=4e5+10;
const ll P[]={0,469762049,998244353,1004535809};
const ll Rt=3;
int N,M;
ll Mod,F[MAXN],G[MAXN];
int Rev[MAXN],Tot,Bit;
inline ll qPow(ll a,ll b,ll p)
{
ll res=1;a%=p;
while(b)
{
if(b&1) res=res*a%p;
a=a*a%p;b>>=1;
}
return res;
}
inline void NTT(ll a[],int n,ll p,int inv)
{
for(int i=0;i<n;++i)
if(i<Rev[i]) std::swap(a[i],a[Rev[i]]);
ll invG=qPow(Rt,p-2,p);
for(int mid=1;mid<n;mid<<=1)
{
ll w1=qPow(inv==1?Rt:invG,(p-1)/(mid<<1),p);
for(int i=0;i<n;i+=mid*2)
{
ll wk=1;
for(int j=0;j<mid;++j,wk=wk*w1%p)
{
ll x=a[i+j],y=a[i+j+mid]*wk%p;
a[i+j]=(x+y)%p,a[i+j+mid]=(x-y+p)%p;
}
}
}
if(inv==-1)
{
ll iv=qPow(n,p-2,p);
for(int i=0;i<n;++i) a[i]=a[i]*iv%p;
}
}
ll a[MAXN],b[MAXN];
inline void Mul(ll f[],ll g[],int n,ll p,ll ans[])
{
std::memcpy(a,f,n*sizeof(ll));
std::memcpy(b,g,n*sizeof(ll));
NTT(a,n,p,1),NTT(b,n,p,1);
for(int i=0;i<n;++i) a[i]=a[i]*b[i]%p;
NTT(a,n,p,-1);
std::memcpy(ans,a,n*sizeof(ll));
}
ll ans[4][MAXN];
signed main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(N,M,Mod);
for(int i=0;i<=N;++i) read(F[i]);
for(int i=0;i<=M;++i) read(G[i]);
while((1<<Bit)<=N+M) ++Bit;
Tot=1<<Bit;
for(int i=0;i<Tot;++i) Rev[i]=(Rev[i>>1]>>1)|((i&1)<<(Bit-1));
for(int k=1;k<=3;++k) Mul(F,G,Tot,P[k],ans[k]);
for(int i=0;i<=N+M;++i)
{
ll x=ans[1][i]+((ans[2][i]-ans[1][i]+P[2])*qPow(P[1],P[2]-2,P[2])%P[2])*P[1];
ll s=(x%Mod+(ans[3][i]-x%P[3]+P[3])*qPow(P[1]*P[2],P[3]-2,P[3])%P[3]*P[1]%Mod*P[2]%Mod)%Mod;
write(s,' ');
}
return 0;
}
/*

*/

好写,好背,但是常数可能会比较大。注意常数优化,还要记得处处取模(例如快速幂的预取模),否则要挂惨(指只剩下 )。


拆解系数 FFT

对于需要计算的两个多项式 。令其系数分别为: ,设定一个阈值为 ,并记录: ,就会得到:

这样的话,我们就可以通过分别计算 四次卷积后就可以得到原来的多项式卷积。当我们的阈值设定为 的话,则四次项都是 级别的,值域范围在 内,并不会造成过大的精度缺失。

对于洛谷的模板题而言,卡了精度使得 double 并不能过,所以需要使用 long double ,但 C++ 自带的三角函数是 double 类型的,所以需要使用 std::sinstd::cos 来解决。

还是要注意取模,否则可能会运算溢出。

AC Code
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// ----- Eternally question-----
// Problem: P4245 【模板】任意模数多项式乘法
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4245
// Memory Limit: 500 MB
// Time Limit: 2000 ms
// Written by: Eternity
// Time: 2023-01-13 07:56:27
// ----- Endless solution-------

#include<bits/stdc++.h>
#define re register
typedef long long ll;
typedef long double ld;
template<class T>
inline void read(T &x)
{
x=0;
char ch=getchar(),t=0;
while(ch<'0'||ch>'9') t|=ch=='-',ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
if(t) x=-x;
}
template<class T,class ...T1>
inline void read(T &x,T1 &...x1){ read(x),read(x1...); }
template<class T>
inline void write(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) write(x/10);
putchar(x%10+'0');
}
template<>
inline void write(bool x){ putchar(x?'1':'0'); }
template<>
inline void write(char c){ putchar(c); }
template<>
inline void write(char *s){ while(*s!='\0') putchar(*s++); }
template<>
inline void write(const char *s){ while(*s!='\0') putchar(*s++); }
template<class T,class ...T1>
inline void write(T x,T1 ...x1){ write(x),write(x1...); }
template<class T>
inline bool checkMax(T &x,T y){ return x<y?x=y,1:0; }
template<class T>
inline bool checkMin(T &x,T y){ return x>y?x=y,1:0; }
const int MAXN=4e5+10;
const ld pi=std::acos(-1);
const int Base=1<<15;
int N,M,Mod;
struct Complex
{
ld x,y;
Complex operator+(const Complex &a) const
{ return {x+a.x,y+a.y}; }
Complex operator-(const Complex &a) const
{ return {x-a.x,y-a.y}; }
Complex operator*(const Complex &a) const
{ return {x*a.x-y*a.y,x*a.y+y*a.x}; }
}A[MAXN],B[MAXN],C[MAXN],A1[MAXN],A2[MAXN],B1[MAXN],B2[MAXN];
ll F[MAXN],G[MAXN],ans[MAXN];
int Rev[MAXN],Tot,Bit;
inline void FFT(Complex a[],int n,int inv)
{
for(int i=0;i<n;++i)
if(i<Rev[i]) std::swap(a[i],a[Rev[i]]);
for(int mid=1;mid<n;mid<<=1)
{
auto w1=Complex({std::cos(pi/mid),inv*std::sin(pi/mid)});
for(int i=0;i<n;i+=mid*2)
{
auto wk=Complex({1,0});
for(int j=0;j<mid;++j,wk=wk*w1)
{
auto x=a[i+j],y=a[i+j+mid]*wk;
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
if(inv==-1) for(int i=0;i<n;++i) a[i].x/=n;
}
inline void MTT(ll f[],ll g[],ll ans[],int n,int p)
{
for(int i=0;i<n;++i)
{
A1[i].x=f[i]/Base,A2[i].x=f[i]%Base;
B1[i].x=g[i]/Base,B2[i].x=g[i]%Base;
}
FFT(A1,n,1),FFT(A2,n,1),FFT(B1,n,1),FFT(B2,n,1);
for(int i=0;i<n;++i)
{
A[i]=A1[i]*B1[i];
B[i]=A1[i]*B2[i]+A2[i]*B1[i];
C[i]=A2[i]*B2[i];
}
FFT(A,n,-1),FFT(B,n,-1),FFT(C,n,-1);
for(int i=0;i<n;++i)
{
ll av=(ll)(A[i].x+0.5)%p,bv=(ll)(B[i].x+0.5)%p,cv=(ll)(C[i].x+0.5)%p;
ans[i]=((av*Base%p*Base)%p+(bv*Base)%p+cv)%p;
}
}
int main()
{
// freopen(".in","r",stdin);
// freopen(".out","w",stdout);
read(N,M,Mod);
for(int i=0;i<=N;++i) read(F[i]),F[i]=(F[i]+Mod)%Mod;
for(int i=0;i<=M;++i) read(G[i]),G[i]=(G[i]+Mod)%Mod;
while((1<<Bit)<=N+M) ++Bit;
Tot=1<<Bit;
for(int i=0;i<Tot;++i) Rev[i]=(Rev[i>>1]>>1)|((i&1)<<(Bit-1));
MTT(F,G,ans,Tot,Mod);
for(int i=0;i<=N+M;++i) write(ans[i],' ');
return 0;
}
/*

*/

快出甚至三倍,但是消费空间是 的四倍,且精度缺失的问题不可避免,值域如果再大一点就可能会挂掉。

当然,根据 的性质,其变换次数可以从 次优化到 次甚至是 次,最优解法是 次,详见毛啸论文,一般用不到。

个人认为 次足以。


模板参考

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
struct Complex
{
ld x,y;
Complex operator+(const Complex &a) const
{ return {x+a.x,y+a.y}; }
Complex operator-(const Complex &a) const
{ return {x-a.x,y-a.y}; }
Complex operator*(const Complex &a) const
{ return {x*a.x-y*a.y,x*a.y+y*a.x}; }
}A[MAXN],B[MAXN],C[MAXN],A1[MAXN],A2[MAXN],B1[MAXN],B2[MAXN];
ll F[MAXN],G[MAXN],ans[MAXN];
int Rev[MAXN],Tot,Bit;
inline void FFT(Complex a[],int n,int inv)
{
for(int i=0;i<n;++i)
if(i<Rev[i]) std::swap(a[i],a[Rev[i]]);
for(int mid=1;mid<n;mid<<=1)
{
auto w1=Complex({std::cos(pi/mid),inv*std::sin(pi/mid)});
for(int i=0;i<n;i+=mid*2)
{
auto wk=Complex({1,0});
for(int j=0;j<mid;++j,wk=wk*w1)
{
auto x=a[i+j],y=a[i+j+mid]*wk;
a[i+j]=x+y,a[i+j+mid]=x-y;
}
}
}
if(inv==-1) for(int i=0;i<n;++i) a[i].x/=n;
}
inline void MTT(ll f[],ll g[],ll ans[],int n,int p)
{
for(int i=0;i<n;++i)
{
A1[i].x=f[i]/Base,A2[i].x=f[i]%Base;
B1[i].x=g[i]/Base,B2[i].x=g[i]%Base;
}
FFT(A1,n,1),FFT(A2,n,1),FFT(B1,n,1),FFT(B2,n,1);
for(int i=0;i<n;++i)
{
A[i]=A1[i]*B1[i];
B[i]=A1[i]*B2[i]+A2[i]*B1[i];
C[i]=A2[i]*B2[i];
}
FFT(A,n,-1),FFT(B,n,-1),FFT(C,n,-1);
for(int i=0;i<n;++i)
{
ll av=(ll)(A[i].x+0.5)%p,bv=(ll)(B[i].x+0.5)%p,cv=(ll)(C[i].x+0.5)%p;
ans[i]=((av*Base%p*Base)%p+(bv*Base)%p+cv)%p;
}
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
inline void NTT(ll a[],int n,ll p,int inv)
{
for(int i=0;i<n;++i)
if(i<Rev[i]) std::swap(a[i],a[Rev[i]]);
ll invG=qPow(Rt,p-2,p);
for(int mid=1;mid<n;mid<<=1)
{
ll w1=qPow(inv==1?Rt:invG,(p-1)/(mid<<1),p);
for(int i=0;i<n;i+=mid*2)
{
ll wk=1;
for(int j=0;j<mid;++j,wk=wk*w1%p)
{
ll x=a[i+j],y=a[i+j+mid]*wk%p;
a[i+j]=(x+y)%p,a[i+j+mid]=(x-y+p)%p;
}
}
}
if(inv==-1)
{
ll iv=qPow(n,p-2,p);
for(int i=0;i<n;++i) a[i]=a[i]*iv%p;
}
}
ll a[MAXN],b[MAXN];
inline void Mul(ll f[],ll g[],int n,ll p,ll ans[])
{
std::memcpy(a,f,n*sizeof(ll));
std::memcpy(b,g,n*sizeof(ll));
NTT(a,n,p,1),NTT(b,n,p,1);
for(int i=0;i<n;++i) a[i]=a[i]*b[i]%p;
NTT(a,n,p,-1);
std::memcpy(ans,a,n*sizeof(ll));
}
ll ans[4][MAXN];
signed main()
{
for(int k=1;k<=3;++k) Mul(F,G,Tot,P[k],ans[k]);
for(int i=0;i<=N+M;++i)
{
ll x=ans[1][i]+((ans[2][i]-ans[1][i]+P[2])*qPow(P[1],P[2]-2,P[2])%P[2])*P[1];
ll s=(x%Mod+(ans[3][i]-x%P[3]+P[3])*qPow(P[1]*P[2],P[3]-2,P[3])%P[3]*P[1]%Mod*P[2]%Mod)%Mod;
write(s,' ');
}
return 0;
}