Advertisement

(每日一题)2016 北京ICPC网络赛G hihocoder 1388 (中国剩余定理 + NTT)

阅读量:

整理的算法模板合集: ACM模板

点我看算法全家桶系列!!!

实际上是一个全新的精炼模板整合计划


Weblink

https://vjudge.net/problem/HihoCoder-1388

Problem

给定数组 A 和数组 B,求:

min \left\{ \sum_{i=0}^{n-1}(A_i-B_{(i+k) \space mod \space n})^2\ \big | \ k=0,1\dots n-1 \right\}

Solution

将式子展开:

\begin{aligned} & min \left\{ \sum_{i=0}^{n-1}(A_i-B_{(i+k) \mod n})^2\ \big | \ k=0,1\dots n-1 \right\}& \\ &= \sum_{i=0}^{n-1}A_i^2+(B_{(i+k) \mod n})^2+2\times A_i\times B_{(i+k) \mod n}& \\ &= \sum_{i=0}^{n-1}A_i^2+ \sum_{i=0}^{n-1}B_i^2+ \sum_{i=0}^{n-1}2\times A_i\times B_{(i+k) \mod n} \end{aligned}

显然前面两个都是定值,我们只需要求 \displaystyle2\times \sum_{i=0}^{n-1} A_i\times B_{(i+k) \mod n} 的最小值即可。

经典将 B 数组翻转,然后倍长,这样就是卷积的形式了,我们就可以直接卷了。

然后因为数据较大,FFT 精度不够,并且题目中还没有给模数,可以找两个大模数直接NTT,然后 CRT 合并即可。

模数我选的 10^9 ,中间记得用龟速乘,不然会爆 long long

Code

(原题的OJ炸了,反正样例过了就是过了(doge))

复制代码
    #include <bits/stdc++.h>
    
    using namespace std;
    typedef long long ll;
    
    const int N = 1000007, G = 3;
    const ll mod1 = 1004535809, mod2 = 998244353;
    
    ll M;
    int n, m;
    ll a[N], b[N];
    ll f1[N], f2[N], g1[N], g2[N];
    int limit, L;
    int RR[N];
      
    ll mul(ll a, ll b, ll mod)
    {
    	ll res = 0;
    	while(b) {
    		if(b & 1) res = res + a % mod;
    		a = a + a % mod;
    		b >>= 1;
    	}
    	return res;
    }
    
    ll qpow(ll a, ll b, ll mod)
    {
    	ll res = 1;
    	while(b) {
    		if(b & 1) res = res * a % mod;
    		a = a * a % mod;
    		b >>= 1;
    	}
    	return res;
    }
    
    ll inv(ll x, ll mod)
    {
    	return qpow(x, mod - 2, mod);
    }
    
    ll exgcd(ll a, ll b, ll &x, ll &y)
    {
    	if(b == 0) {
    		x = 1, y = 0;
    		return a;
    	}
    	ll d = exgcd(b, a % b, x, y);
    	ll z = x;
    	x = y;
    	y = z - y * (a / b);
    	return d;
    }
    
    ll mm[5], aa[5];
      
    ll CRT(int n, ll *a, ll *mo)
    {
    	M = mo[1] * mo[0];
    ll res = 0;
    for(ll i = 0; i < n; ++ i) {
        ll Mi = M / mo[i];
        ll ti, y;
        ll d = exgcd(Mi, mo[i], ti, y);
        ti = (ti % mo[i] + mo[i]) % mo[i];
        res = (res + mul(mul(a[i], ti, M), Mi, M)) % M;
    }
    return (res % M + M) % M;
    } 
    
    void NTT(ll *A, int type, ll mod)
    {
    	for(int i = 0; i < limit; ++ i)
    		if(i < RR[i])
    			swap(A[i], A[RR[i]]); 
    
    	for(int mid = 1; mid < limit; mid <<= 1) {
    		ll wn = qpow(G, (mod - 1) / (mid * 2), mod);
    		if(type == -1)
    			wn = qpow(wn, mod - 2, mod);
    		for(int len = mid << 1, pos = 0; pos < limit; pos += len) {
    			ll w = 1;
    			for(int k = 0; k < mid; ++ k, w = (w * wn) % mod) {
    				ll x = A[pos + k], y = w * A[pos + mid + k] % mod;
    				A[pos + k] = (x + y) % mod;
    				A[pos + k + mid] = (x - y + mod) % mod;
    			}
    		}
    	}
    	if(type == -1) {
    		ll limit_inv = inv(limit, mod);
    		for(int i = 0; i < limit; ++ i)
    			A[i] = (A[i] * limit_inv) % mod;
    	}
    }
     
    
    void solve()
    {
    	scanf("%d", &n);
    	for(int i = 0; i < n; ++ i) {
    		scanf("%lld", &a[i]);
    	}
    	for(int i = 0; i < n; ++ i) {
    		scanf("%lld", &b[i]);
    	}
    
    for (int i = 0; i < n / 2; ++ i) {
        swap (b[i], b[n - i - 1]);
    }
    for (int i = n; i < 2 * n; ++ i) {
        b[i] = b[i - n];
    }
    
    	limit = 1, L = 0;
    	while(limit < 2 * n) L ++ , limit <<= 1;
    
    	for(int i = 0; i < limit; ++ i) {
    		RR[i] = (RR[i >> 1] >> 1) | ((i & 1) << (L - 1));
    	}
    	for(int i = 0; i < n; ++ i) {
    		f1[i] = f2[i] = a[i];
    	}
    	for(int i = n; i <= limit; ++ i) {
    		f1[i] = f2[i] = 0;
    	}
    	for(int i = 0; i < 2 * n; ++ i) {
    		g1[i] = g2[i] = b[i];
    	}
    	for(int i = 2 * n; i <= limit; ++ i) {
    		g1[i] = g2[i] = 0;
    	}
    
    	NTT(f1, 1, mod1);
    	NTT(g1, 1, mod1);
    	NTT(f2, 1, mod2);
    	NTT(g2, 1, mod2);
    
    	for(int i = 0; i < limit; ++ i) {
    		f1[i] = f1[i] * g1[i] % mod1;
    		f2[i] = f2[i] * g2[i] % mod2;
    	}
    
    	NTT(f1, -1, mod1);
    	NTT(f2, -1, mod2);
    
    	ll ans = 0;
    	for(int i = n - 1; i <= 2 * n - 2; ++ i) {
    		mm[0] = mod1, mm[1] = mod2;
    		aa[0] = f1[i], aa[1] = f2[i];
    		ll res = CRT(2, aa, mm);
    		ans = max(ans, res);
    	}
    	ll res = 0;
    	for(int i = 0; i < n; ++ i) {
    		res += 1ll * a[i] * a[i];
    		res += 1ll * b[i] * b[i];
    	} 
    	printf("%lld\n", res - 2 * ans);
    	return ;
    }
    
    int main()
    {
    	int t;
    	M = 1ll * mod1 * mod2;
    	scanf("%d", &t);
    	while(t -- ) {
    		solve();
    	}
    }
    
    
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
      
    

全部评论 (0)

还没有任何评论哟~