(每日一题)2016 北京ICPC网络赛G hihocoder 1388 (中国剩余定理 + NTT)
整理的算法模板合集: ACM模板
实际上是一个全新的精炼模板整合计划
Weblink
https://vjudge.net/problem/HihoCoder-1388
Problem
给定数组 A 和数组 B,求:
min \left\{ \sum_{i=0}^{n-1}(A_i-B_{(i+k) \space mod \space n})^2\ \big | \ k=0,1\dots n-1 \right\}
Solution
将式子展开:
\begin{aligned} & min \left\{ \sum_{i=0}^{n-1}(A_i-B_{(i+k) \mod n})^2\ \big | \ k=0,1\dots n-1 \right\}& \\ &= \sum_{i=0}^{n-1}A_i^2+(B_{(i+k) \mod n})^2+2\times A_i\times B_{(i+k) \mod n}& \\ &= \sum_{i=0}^{n-1}A_i^2+ \sum_{i=0}^{n-1}B_i^2+ \sum_{i=0}^{n-1}2\times A_i\times B_{(i+k) \mod n} \end{aligned}
显然前面两个都是定值,我们只需要求 \displaystyle2\times \sum_{i=0}^{n-1} A_i\times B_{(i+k) \mod n} 的最小值即可。
经典将 B 数组翻转,然后倍长,这样就是卷积的形式了,我们就可以直接卷了。
然后因为数据较大,FFT 精度不够,并且题目中还没有给模数,可以找两个大模数直接NTT,然后 CRT 合并即可。
模数我选的 10^9 ,中间记得用龟速乘,不然会爆 long long。
Code
(原题的OJ炸了,反正样例过了就是过了(doge))
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1000007, G = 3;
const ll mod1 = 1004535809, mod2 = 998244353;
ll M;
int n, m;
ll a[N], b[N];
ll f1[N], f2[N], g1[N], g2[N];
int limit, L;
int RR[N];
ll mul(ll a, ll b, ll mod)
{
ll res = 0;
while(b) {
if(b & 1) res = res + a % mod;
a = a + a % mod;
b >>= 1;
}
return res;
}
ll qpow(ll a, ll b, ll mod)
{
ll res = 1;
while(b) {
if(b & 1) res = res * a % mod;
a = a * a % mod;
b >>= 1;
}
return res;
}
ll inv(ll x, ll mod)
{
return qpow(x, mod - 2, mod);
}
ll exgcd(ll a, ll b, ll &x, ll &y)
{
if(b == 0) {
x = 1, y = 0;
return a;
}
ll d = exgcd(b, a % b, x, y);
ll z = x;
x = y;
y = z - y * (a / b);
return d;
}
ll mm[5], aa[5];
ll CRT(int n, ll *a, ll *mo)
{
M = mo[1] * mo[0];
ll res = 0;
for(ll i = 0; i < n; ++ i) {
ll Mi = M / mo[i];
ll ti, y;
ll d = exgcd(Mi, mo[i], ti, y);
ti = (ti % mo[i] + mo[i]) % mo[i];
res = (res + mul(mul(a[i], ti, M), Mi, M)) % M;
}
return (res % M + M) % M;
}
void NTT(ll *A, int type, ll mod)
{
for(int i = 0; i < limit; ++ i)
if(i < RR[i])
swap(A[i], A[RR[i]]);
for(int mid = 1; mid < limit; mid <<= 1) {
ll wn = qpow(G, (mod - 1) / (mid * 2), mod);
if(type == -1)
wn = qpow(wn, mod - 2, mod);
for(int len = mid << 1, pos = 0; pos < limit; pos += len) {
ll w = 1;
for(int k = 0; k < mid; ++ k, w = (w * wn) % mod) {
ll x = A[pos + k], y = w * A[pos + mid + k] % mod;
A[pos + k] = (x + y) % mod;
A[pos + k + mid] = (x - y + mod) % mod;
}
}
}
if(type == -1) {
ll limit_inv = inv(limit, mod);
for(int i = 0; i < limit; ++ i)
A[i] = (A[i] * limit_inv) % mod;
}
}
void solve()
{
scanf("%d", &n);
for(int i = 0; i < n; ++ i) {
scanf("%lld", &a[i]);
}
for(int i = 0; i < n; ++ i) {
scanf("%lld", &b[i]);
}
for (int i = 0; i < n / 2; ++ i) {
swap (b[i], b[n - i - 1]);
}
for (int i = n; i < 2 * n; ++ i) {
b[i] = b[i - n];
}
limit = 1, L = 0;
while(limit < 2 * n) L ++ , limit <<= 1;
for(int i = 0; i < limit; ++ i) {
RR[i] = (RR[i >> 1] >> 1) | ((i & 1) << (L - 1));
}
for(int i = 0; i < n; ++ i) {
f1[i] = f2[i] = a[i];
}
for(int i = n; i <= limit; ++ i) {
f1[i] = f2[i] = 0;
}
for(int i = 0; i < 2 * n; ++ i) {
g1[i] = g2[i] = b[i];
}
for(int i = 2 * n; i <= limit; ++ i) {
g1[i] = g2[i] = 0;
}
NTT(f1, 1, mod1);
NTT(g1, 1, mod1);
NTT(f2, 1, mod2);
NTT(g2, 1, mod2);
for(int i = 0; i < limit; ++ i) {
f1[i] = f1[i] * g1[i] % mod1;
f2[i] = f2[i] * g2[i] % mod2;
}
NTT(f1, -1, mod1);
NTT(f2, -1, mod2);
ll ans = 0;
for(int i = n - 1; i <= 2 * n - 2; ++ i) {
mm[0] = mod1, mm[1] = mod2;
aa[0] = f1[i], aa[1] = f2[i];
ll res = CRT(2, aa, mm);
ans = max(ans, res);
}
ll res = 0;
for(int i = 0; i < n; ++ i) {
res += 1ll * a[i] * a[i];
res += 1ll * b[i] * b[i];
}
printf("%lld\n", res - 2 * ans);
return ;
}
int main()
{
int t;
M = 1ll * mod1 * mod2;
scanf("%d", &t);
while(t -- ) {
solve();
}
}
