Codeforces 891E Lust

Description

给出一个长度为 n 的数组 a ,有一个计数器 res ,接下来进行 k 次操作,每次操作等概率的选择 [1,n] 中的某个数 i ,然后 res 加上 jiaj ,然后 ai=ai1 ,求 res 的期望值模 109+7

Data Constraint

2n5000   0ai,k109

Solution

首先通过归纳可以得到题目让我们求的是

E[Πni=1aiΠni=1(aibi)]
其中 ni=1bi = k
那也就是让我们求
E[Πni=1(aibi)]

将选择 [1,n] 中的每个数的概率看成 1 ,最后答案再乘上 1nk 即可。
接下来我们考虑生成函数 EGF
考虑 ai 它的生成函数 Fi(x)
Fi(x)=j0(aij)xjj!

Fi(x)=j0aixjj!j1xj(j1)!

Fi(x)=j0aixjj!j0xj+1j!

Fi(x)=aiexxex

Fi(x)=ex(aix)

那总的答案的生成函数 F(x) = Πni=1Fi(x) = enxΠni=1(aix)
那答案就是让我们求 [xk]n!nkF(x) ,也就是 [xk]n!nkenxΠni=1(aix)
O(n2) 暴力卷积求出 Πni=1(aix) ,或者打个分治 NTT O(n log2 n) 也行,记得到的多项式为 G(x)

那答案就是求 [xk]n!nkenxG(x)
考虑到 G(x) 只有 n+1 项,于是答案可以写成

Ans=n!nki=0n[xi]G(x)[xki]enx

Ans=n!nki=0n[xi]G(x)nki(ki)!

Ans=1nki=0n[xi]G(x)nkiki

随便算算就好,时间复杂度 O(n2) O(n log2 n)

Code

#include
#include
#include
#include

#define fo(i,j,l) for(int i=j;i<=l;++i)
#define fd(i,j,l) for(int i=j;i>=l;--i)

using namespace std;
typedef long long ll;
const ll N=34e4,mo=998244353,ZD=262144;

ll a[N],f[N],g[N],aa[N],bb[N],cc[N];
ll jc[N],w[N];
int n,bits[N];
ll k;

inline ll ksm(ll o,ll t)
{
    ll y=1;
    for(;t;t>>=1,o=o*o%mo)
    if(t&1)y=y*o%mo;
    return y;
}

inline int read()
{
    int o=0; char ch=' ';
    for(;ch<'0'||ch>'9';ch=getchar());
    for(;ch>='0'&&ch<='9';ch=getchar())o=o*10+ch-48;
    return o;
}

inline ll mod(ll o)
{return o<0?o+mo:(o>=mo?o-mo:o);}

void prepare()
{
    w[0]=1; w[1]=ksm(3,(mo-1)/ZD);
    fo(i,2,ZD)w[i]=w[i-1]*w[1]%mo;
}

inline void dft(ll *c,int mm,int sig)
{
    ll ww,v;
    fo(i,1,mm-1)if(bits[i]for(int m=2;m<=mm;m<<=1){
        int half=m>>1,U=ZD/m;
        fo(i,0,half-1){
            ww=sig==1?w[U*i]:w[ZD-U*i];
            for(int j=i;jm){
                v=c[j+half]*ww%mo;
                c[j+half]=(c[j]-v+mo)%mo;
                c[j]=(c[j]+v)%mo;
            }
        }
    }
    if(sig==-1){
        ll ny=ksm(mm,mo-2);
        fo(i,0,mm-1)c[i]=c[i]*ny%mo;
    }
}

void divi(int l,int r)
{
    if(l+100>=r){
        g[0]=1;
        fo(i,1,r-l+1)g[i]=0;
        fo(i,l,r){
            fd(j,i-l+1,1)g[j]=(g[j]*a[i]-g[j-1]+mo)%mo;
            g[0]=g[0]*a[i]%mo; 
        }
        fo(i,l,r)f[i]=g[i-l+1];
        return;
    }
    int mid=l+r>>1;
    divi(l,mid); divi(mid+1,r);
    int ss=0,mm=1;
    while(mm<=r-l+2)mm<<=1,++ss;
    fo(i,l,mid)aa[i-l+1]=f[i];
    aa[0]=1;
    fo(i,l,mid)aa[0]=aa[0]*a[i]%mo;
    fo(i,mid-l+2,mm)aa[i]=0;
    fo(i,mid+1,r)bb[i-mid]=f[i];
    bb[0]=1;
    fo(i,mid+1,r)bb[0]=bb[0]*a[i]%mo;
    fo(i,r-mid+1,mm)bb[i]=0;
    fo(i,0,mm-1)bits[i]=(bits[i>>1]>>1)|((i&1)<1);
    dft(aa,mm,1); 
    dft(bb,mm,1);
    fo(i,0,mm-1)aa[i]=aa[i]*bb[i]%mo;
    dft(aa,mm,-1);
    fo(i,l,r)f[i]=aa[i-l+1];
}

int main()
{
    cin>>n>>k;
    ll lj=ksm(n,(mo-2)*k);
    fo(i,1,n)a[i]=read();
    prepare();
    divi(1,n); f[0]=1;
    fo(i,1,n)f[0]=f[0]*a[i]%mo;
    ll dq=0,ans=0;
    jc[0]=k;
    fo(i,1,n)jc[i]=jc[i-1]*(k-i)%mo;
    if(n1);
    fd(i,n,0){
        if(i==k)dq=1;else dq=dq*n%mo;
        ll dd=f[i]*dq%mo;
        if(i!=0)dd=dd*jc[i-1]%mo;
        ans=(ans+dd)%mo;
    }
    ans=ans*lj%mo;
    ll js=1;
    fo(i,1,n)js=js*a[i]%mo;
    ans=(js-ans+mo)%mo;
    cout<

你可能感兴趣的:(生成函数,快速数论变换NTT)