Advertisement

[HDU 5362]Just A String

阅读量:

一、题目

题目描述

设字符集合的大小为m(记作\mathcal{S}),从该集合中随机选取n个字符构造一个字符串s = s_1s_2\cdots s_n \in \mathcal{S}^n。定义函数f(s)表示在s的所有可能子串中满足存在重排后成为回文的数量,则问题转化为求f(s)在所有可能s\in\mathcal{S}^n下的期望值\mathbb{E}[f(s)]。根据题意的要求,在计算完上述期望值后需将其结果乘以m^n并对结果取模于\bmod (10^9+7)

数据范围

1\leq n,m\leq 2000

二、解法

0x01 dp

其实原问题要求的是所有子串的个数,我们考虑每种子串的贡献。

有一个动态规划(DP)问题相对容易想到,在这个问题中我们定义状态 dp[i][j] 表示从选i个字符中偶数相消后剩余j个字符的情况的数量。其转移方程如上所述:

dp[i][j] = dp[i-1][j+1] \times (j+1) + dp[i-1][j-1] \times (m-j+1)

这一转移方程分析了新加入的字符会对当前状态产生怎样的影响,并且整个过程的时间复杂度为 O(n^2) 时间复杂度。
在完成这一DP表之后,在枚举所有可能的i值时需要考虑子类情况:即当i与奇偶性一致时(即i\&1=0/1),统计这些情况的数量,并将它们在整个字符串中的起始位置数量与剩余可选位置数量相乘以获得最终结果。

有一个性质,就是dpi,j同奇同偶,可以利用这一点来卡常qwq。

复制代码
    #include <cstdio>
    #include <cstring>
    const int MOD = 1e9+7;
    const int MAXN = 2005;
    int read()
    {
    int num=0,flag=1;char c;
    while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
    while(c>='0'&&c<='9')num=(num<<3)+(num<<1)+(c^48),c=getchar();
    return num*flag;
    }
    int T,n,m,ans,pw[MAXN],dp[MAXN][MAXN];
    int main()
    {
    	T=read();
    	while(T--)
    	{
    		n=read();m=read();
    		pw[0]=1;ans=0;
    		for(int i=1;i<=n;i++)
    			pw[i]=1ll*pw[i-1]*m%MOD;
    		dp[0][0]=1;
    		for(int i=1;i<=n;i++)
    		{
    			int j=0;
    			if(i&1) j=1;
    			for(;j<=i;j+=2)
    			{
    				if(j>=1) dp[i][j]=(1ll*dp[i-1][j-1]*(m-j+1)+1ll*dp[i-1][j+1]*(j+1))%MOD;
    				else dp[i][j]=1ll*dp[i-1][j+1]*(j+1)%MOD;
    			}
    		}
    		for(int i=1;i<=n;i++)
    			ans=(ans+1ll*dp[i][i&1]*(n-i+1)%MOD*pw[n-i]%MOD)%MOD;
    		printf("%d\n",ans);
    	}
    }

0x02 母函数

针对子串长度为偶数的情况进行分析,在这种情况下字符只能选取偶数个数量。我们可以将其闭合形式表示为:
\frac{e^x + e^{-x}}{2}
接下来对其进行二项式展开:
2^{-m}\sum_{i=0}^m C_{m}^i \times e^{(m-2i)x}
对于所有偶数长度的子串,在计算其系数时可采用以下方法:
2^{-m}\sum_{j=0}^m C_{m}^j \times ((m-2j)x)^i
当处理奇数长度的子串时,则必须强制选择一个字符作为奇位置,并将其闭合形式表示为:
\frac{e^x - e^{-x}}{2}
单独处理该字符后,则需计算以下表达式:
m \cdot 2^{-m}\sum_{j=0}^{m-1} C_{m-1}^j \times ((2j + 2 - m)^i - (2j - m)^i)
计算上述结果后即可按照动态规划的方式进行贡献计算,最终的时间复杂度仍保持在O(nm)级别,并且常系数会略高于动态规划方法。

复制代码
    #include <cstdio>
    #define int long long
    const int MOD = 1e9+7;
    const int MAXN = 2005;
    int read()
    {
    int num=0,flag=1;char c;
    while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1;
    while(c>='0'&&c<='9')num=(num<<3)+(num<<1)+(c^48),c=getchar();
    return num*flag;
    }
    int T,n,m,ans,tmp,pw[MAXN],fac[MAXN],inv[MAXN];
    void init(int n)
    {
    	fac[0]=inv[0]=inv[1]=1;
    	for(int i=2;i<=n;i++) inv[i]=inv[MOD%i]*(MOD-MOD/i)%MOD;
    	for(int i=1;i<=n;i++) inv[i]=inv[i]*inv[i-1]%MOD;
    	for(int i=1;i<=n;i++) fac[i]=fac[i-1]*i%MOD;
    }
    int C(int n,int m)
    {
    	return fac[n]*inv[m]%MOD*inv[n-m]%MOD;
    }
    int qkpow(int a,int b)
    {
    	int res=1;
    	while(b>0)
    	{
    		if(b&1) res=res*a%MOD;
    		a=a*a%MOD;
    		b>>=1; 
    	}
    	return res;
    }
    signed main()
    {
    	init(2000);
    	T=read();
    	while(T--)
    	{
    		n=read();m=read();
    		ans=0;
    		pw[0]=1;
    		for(int i=1;i<=n;i++)
    			pw[i]=1ll*pw[i-1]*m%MOD;
    		for(int i=1;i<=n;i++)
    		{
    			tmp=0;
    			if((i&1)==0)
    			{
    				for(int j=0;j<=m;j++)
    					tmp=(tmp+C(m,j)*qkpow(m-2*j,i)%MOD)%MOD;
    			}
    			else
    			{
    				for(int j=0;j<m;j++)
    					tmp=(tmp+C(m-1,j)*(qkpow(m-2*j,i)-qkpow(m-2*j-2,i))%MOD)%MOD;
    				tmp=tmp*m%MOD;
    			}
    			tmp=tmp*qkpow(qkpow(2,m),MOD-2)%MOD;
    			ans=(ans+tmp*(n-i+1)%MOD*pw[n-i]%MOD)%MOD;
    		}
    		printf("%lld\n",(ans%MOD+MOD)%MOD);
    	}
    }

全部评论 (0)

还没有任何评论哟~