Skip to content

P4238 【模板】多项式乘法逆

P4238 【模板】多项式乘法逆

思路:设\(G(x)\)为所求,即\(F(x)\)在模\(x^n\)意义下的逆元,满足\(F(x)G(x)\equiv 1 (\mathrm{mod} \,x^n)\). 假设我们已经求出了\(F(x)\)在模\(x^{\lceil\frac{n}{2}\rceil}\) 意义下的逆元\(H(x)\). 那么

\[F(x)(G(x)-H(x))\equiv 0 (\mathrm{mod} \,x^{\left\lceil\frac{n}{2}\right\rceil})\]
\[G(x)-H(x)\equiv 0 (\mathrm{mod} \,x^{\left\lceil\frac{n}{2}\right\rceil})\]
\[(G(x)-H(x))^2\equiv 0 (\mathrm{mod} \,x^n)\]
\[G^2(x)-2G(x)H(x)+H^2(x)\equiv 0 (\mathrm{mod} \,x^n)\]
\[F(x)(G^2(x)-2G(x)H(x)+H^2(x))\equiv G(x)-2H(x)+F(x)H^2(x)\equiv 0 (\mathrm{mod} \,x^n)\]
\[\Longrightarrow G(x)\equiv H(x)(2-F(x)H(x)) (\mathrm{mod} \,x^n)\]

可以通过此算法递归处理逆元,多项式乘法用\(NTT\). 每一层的代价为\(O(n\log n)\),且问题规模减半,即\(T(n)=T(\lceil\frac{n}{2}\rceil)+O(n\log n)\),因此总时间复杂度为\(O(n\log n)\)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#include <bits/stdc++.h>
// Inclusive ascending loop: i runs from a up to and including b.
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
// Inclusive descending loop: i runs from b down to and including a.
// (Not used anywhere in the visible code.)
#define per(i,a,b) for(int i=(b);i>=(a);i--)
// Shorthand for the 64-bit integer type used for all modular arithmetic.
#define ll long long
using namespace std;
// N  : capacity of every global array below.
// P  : modulus that every product and sum is reduced by.
// G  : base handed to fpow() for the forward transform in ntt().
// IG : base for the inverse transform; 3*IG == P+1, so IG is the inverse of G mod P.
const ll N=400000,P=998244353,G=3,IG=(P+1)/G;
// n     : number of coefficients read by main().
// limit : current transform length, set by init().
// inv   : modular inverse of limit, set by init() and applied by the inverse ntt().
// m, type and a[] are not referenced anywhere in the visible code.
int n,limit,m,type,inv,a[N];
// A : input coefficients; B : resulting inverse coefficients; C : scratch buffer used by work().
ll A[N],B[N],C[N];
/**
 * Reads one decimal integer from stdin.
 *
 * Every non-digit character before the number is skipped; if any of the
 * skipped characters is '-', the result is negated. Reading stops at the
 * first non-digit after the digit run (that character is consumed).
 *
 * @return the parsed value, or 0 if end of input is reached before any digit.
 */
inline ll read(){
    ll value=0,sign=1;
    // getchar() returns an int so that EOF stays distinguishable from every
    // real byte; keeping it in an int also makes the isdigit() calls well-defined.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        if(ch=='-')sign=-1;
        ch=getchar();
    }
    // EOF is tested explicitly so an exhausted stream ends the loop instead of spinning.
    while(ch!=EOF&&isdigit(ch)){
        value=value*10+(ch-'0');
        ch=getchar();
    }
    return value*sign;
}
/**
 * Modular exponentiation by repeated squaring: computes a^b modulo P.
 *
 * @param a base; any 64-bit value, it is first normalised into [0, P).
 * @param b exponent; values <= 0 yield 1.
 * @return a^b mod P, in [0, P).
 */
inline ll fpow(ll a,ll b){
    // Normalise the base first: every later product then has both factors
    // below P, so it cannot overflow a signed 64-bit integer.
    a%=P;
    if(a<0)a+=P;
    ll ret=1;
    // `b>0` rather than `b`: right-shifting a negative value need not reach 0.
    while(b>0){
        if(b&1)ret=ret*a%P;
        a=a*a%P;
        b>>=1;
    }
    return ret;
}
int rev[N];
inline void init(int x){
    for(limit=1;limit<x;)limit<<=1;
    rep(i,0,limit)rev[i]=(rev[i>>1]>>1)|((i&1)?limit>>1:0);
    inv=fpow(limit,P-2);
}
/**
 * In-place number-theoretic transform of f[0..limit-1] modulo P.
 *
 * Relies on the globals `limit`, `rev` and `inv` as prepared by init().
 *
 * @param f    array with at least `limit` non-negative entries; overwritten.
 * @param flag true for the forward transform (root derived from G),
 *             false for the inverse (root derived from IG, followed by a
 *             multiplication of every entry by `inv`).
 */
void ntt(ll *f,bool flag){
    // Reorder into bit-reversed order; the i<rev[i] test swaps each pair once.
    for(int i=0;i<limit;i++){
        if(i<rev[i])std::swap(f[i],f[rev[i]]);
    }
    for(int len=2;len<=limit;len<<=1){
        const int half=len>>1;
        // Step between consecutive twiddle factors for blocks of this length.
        const ll wn=fpow(flag?G:IG,(P-1)/len);
        for(int start=0;start<limit;start+=len){
            ll w=1;
            for(int j=start;j<start+half;j++){
                const ll lo=f[j];
                const ll hi=w*f[j+half]%P;
                f[j]=(lo+hi)%P;
                // +P keeps the difference non-negative before the reduction.
                f[j+half]=(lo-hi+P)%P;
                w=w*wn%P;
            }
        }
    }
    if(!flag){
        // Scale exactly the `limit` transformed entries; f[limit] is not part
        // of the transform and must stay untouched.
        for(int i=0;i<limit;i++)f[i]=f[i]*inv%P;
    }
}
inline void work(int deg,ll *a,ll *b){
    if(deg==1){
        b[0]=fpow(a[0],P-2);
        return;
    }
    work((deg+1)>>1,a,b);
    init(deg<<1);
    rep(i,0,deg-1)C[i]=a[i];
    rep(i,deg,limit-1)C[i]=0;
    ntt(C,1);ntt(b,1);
    rep(i,0,limit-1){
        b[i]=(2-C[i]*b[i]%P+P)%P*b[i]%P;
    }
    ntt(b,0);
    rep(i,deg,limit-1)b[i]=0;
}
/**
 * Entry point: reads the coefficient count n and n coefficients from stdin,
 * computes the first n coefficients of the inverse series via work(), and
 * prints them on one line, each followed by a single space.
 */
int main(){
    n=read();
    int idx=0;
    while(idx<n){
        A[idx]=read();
        ++idx;
    }
    work(n,A,B);
    idx=0;
    while(idx<n){
        printf("%lld ",B[idx]);
        ++idx;
    }
    return 0;
}

Comments