【学习笔记】[ABC245Ex] Product Modulo 2

有点难

发现可以用中国剩余定理对 M = p α M=p^{\alpha} M=pα求出答案,然后再乘起来。对于每个 A i A_i Ai通过中国剩余定理合并出来的答案也是 唯一 的。

发现对于一个 A i A_i Ai,恰好包含 t t t p p p的方案数为 p α − t − 1 ( p − 1 ) p^{\alpha-t-1}(p-1) pαt1(p1)。那么我们知道了 t t t的总和就可以直接算方案数了。

t t t表示满足 p t ∣ N p^t|N ptN的最大整数,根据同余的性质我们知道 gcd ⁡ ( ∏ i = 1 n A i , p α ) = gcd ⁡ ( N , p α ) = p t \gcd(\prod_{i=1}^nA_i,p^{\alpha})=\gcd(N,p^{\alpha})=p^t gcd(i=1nAi,pα)=gcd(N,pα)=pt

p t p^{t} pt除掉过后,此时 A i A_i Ai N N N都和 M M M互质,根据简约剩余系的理论,此时 A 1 ∼ n − 1 A_{1\sim n-1} A1n1可以取 [ 1 , M ] [1,M] [1,M]中与 M M M互质的任意数,那么 A n A_n An也对应唯一的 [ 1 , M ] [1,M] [1,M]中与 M M M互质的一个数。

A i ∈ [ 1 , p a i ] A_i\in [1,p^{a_i}] Ai[1,pai],发现 M ∣ p a n M|p^{a_n} Mpan,所以 A n A_n An的方案数就是 p a n M \frac{p^{a_n}}{M} Mpan。发现 A 1 ∼ n − 1 A_{1\sim n-1} A1n1只要都和 p p p互质就好了,方案数 ∏ i = 1 n − 1 ( p a i − 1 ( p − 1 ) ) × p a n M = ∏ i = 1 n p a i × ( p − 1 p ) n − 1 × 1 M \prod_{i=1}^{n-1}(p^{a_i-1}(p-1))\times \frac{p^{a_n}}{M}=\prod_{i=1}^{n}p^{a_i}\times (\frac{p-1}{p})^{n-1}\times \frac{1}{M} i=1n1(pai1(p1))×Mpan=i=1npai×(pp1)n1×M1

可以直接计算。对于 N = 0 N=0 N=0的情况简单容斥即可。

#include
#define ll long long
#define pb push_back
#define fi first
#define se second
#define db double
using namespace std;
const int mod=998244353;
ll n,N,M,res=1;
ll fpow(ll x,ll y=mod-2){
    if(y<0)return 0;
    x%=mod;ll z(1);
    for(;y;y>>=1){
        if(y&1)z=z*x%mod;
        x=x*x%mod;
    }return z;
}
ll binom(ll x,ll y){
    ll res=1;
    for(int i=1;i<=y;i++)res=res*fpow(i)%mod;
    for(int i=0;i<y;i++)res=res*((x-i)%mod)%mod;
    return res;
}
void add(ll &x,ll y){
    x=(x+y)%mod;
}
ll solve(ll p,ll N,int k){
    if(N){
        int t=0;
        while(N%p==0)N/=p,t++;
        return binom(n+t-1,t)*fpow(p,k*n-t-n+1-(k-t))%mod*fpow(p-1,n-1)%mod;
    }
    else{
        ll res=fpow(p,k*n);
        for(int t=0;t<k;t++){
            add(res,-fpow(p,k-t-1)*((p-1)%mod)%mod*binom(n+t-1,t)%mod*fpow(p,k*n-t-n+1-(k-t))%mod*fpow(p-1,n-1)%mod);
        }
        return res;
    }
}
signed main(){
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);
    cin>>n>>N>>M;
    for(ll i=2;i<=M/i;i++){
        if(M%i==0){
            ll P=1,k=0;
            while(M%i==0)M/=i,P*=i,k++;
            ll n2=N%P;
            res=res*solve(i,n2,k)%mod;
        }
    }
    if(M>1){
        res=res*solve(M,N%M,1)%mod;
    }
    cout<<(res+mod)%mod;
}

你可能感兴趣的:(学习,笔记,算法)