Advertisement

ACM-ICPC 2018 青岛赛区网络预赛 B. Red Black Tree (LCA、二分)

阅读量:

传送门:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemId=5807

题意:

给出一棵树,根节点为1。每条边有一个距离,树上有m个点为红色的点,其余为黑色,每个点的权值为到其最近红色祖先的距离。有q次询问,每次给出一个点集,问在树上涂红一个点后,点集中所有点的最大值的最小值是多少。

思路:

预处理每个点到根的距离cost、到最近红色祖先的距离cost_t和ST表。

对于每次询问,按cost_t从大到小排序,在0~cost_t[0]范围内二分答案,对所有大于答案的点求它们的公共祖先(利用ST表可以O(1)求两点的公共祖先),将其涂红,之后计算每个大于答案的点的新权值是否小于答案。

AC代码:

复制代码
 #include<iostream>

    
 #include<cstdio>
    
 #include<cstring>
    
 #include<string>
    
 #include<cstdlib>
    
 #include<utility>
    
 #include<algorithm>
    
 #include<utility>
    
 #include<queue>
    
 #include<vector>
    
 #include<set>
    
 #include<stack>
    
 #include<cmath>
    
 #include<map>
    
 #include<ctime>
    
 #include<functional>
    
 #include<bitset>
    
 #define P pair<int,int>
    
 #define ll long long
    
 #define ull unsigned long long
    
 #define lson id*2,l,mid
    
 #define rson id*2+1,mid+1,r
    
 #define ls id*2
    
 #define rs (id*2+1)
    
 #define Mod(a,b) a<b?a:a%b+b
    
 #define cl0(a) memset(a,0,sizeof(a))
    
 #define cl1(a) memset(a,-1,sizeof(a))
    
 using namespace std;
    
  
    
 const ll M = 1e9 + 7;
    
 const ll INF = 1e15;
    
 const int N = 4010;
    
 const double _e = 10e-6;
    
 const int maxn = 1e5 + 10;
    
 const int dx[4] = { 0,0,1,-1 }, dy[4] = { 1,-1,0,0 };
    
 const int _dx[8] = { -1,-1,-1,0,0,1,1,1 }, _dy[8] = { -1,0,1,-1,1,-1,0,1 };
    
  
    
 int x, y;
    
  
    
 int t, n, m, q, cc;
    
 bool red[maxn];
    
  
    
 struct node
    
 {
    
 	int nex, to;
    
 	ll c;
    
 }e[maxn * 2];
    
  
    
 int beg[2 * maxn], e_max;
    
 int pos[2 * maxn], T[2 * maxn], tot, rmq[2 * maxn];
    
 ll cost[maxn], cost_t[maxn];
    
 struct ST
    
 {
    
 	int mm[2 * maxn];
    
 	int dp[2 * maxn][20];
    
 	void init(int n)
    
 	{
    
 		mm[0] = -1;
    
 		for (int i = 1; i <= n; i++) {
    
 			mm[i] = ((i&(i - 1)) == 0) ? mm[i - 1] + 1 : mm[i - 1];
    
 			dp[i][0] = i;
    
 		}
    
 		for (int j = 1; j <= mm[n]; j++)
    
 			for (int i = 1; i + (1 << j) - 1 <= n; i++)
    
 				dp[i][j] = rmq[dp[i][j - 1]] < rmq[dp[i + (1 << (j - 1))][j - 1]] ? dp[i][j - 1] : dp[i + (1 << (j - 1))][j - 1];
    
 	}
    
 	int query(int a, int b)
    
 	{
    
 		if (a > b)swap(a, b);
    
 		int k = mm[b - a + 1];
    
 		return rmq[dp[a][k]] <= rmq[dp[b - (1 << k) + 1][k]] ? dp[a][k] : dp[b - (1 << k) + 1][k];
    
 	}
    
 }st;
    
 void init()
    
 {
    
 	memset(beg, -1, sizeof beg);
    
 	e_max = 0; tot = 1;
    
 }
    
 void add_edge(int s, int t, int cost)
    
 {
    
 	e[e_max].to = t; e[e_max].c = cost;
    
 	e[e_max].nex = beg[s]; beg[s] = e_max++;
    
 }
    
 void dfs(int u, int pre, int dis, ll sum, ll sum1)
    
 {
    
 	if (red[u])sum1 = 0;
    
 	cost[u] = sum; cost_t[u] = sum1;
    
 	pos[u] = tot; T[tot] = u; rmq[tot++] = dis;
    
 	for (int i = beg[u]; ~i; i = e[i].nex) {
    
 		int v = e[i].to;
    
 		if (v == pre) continue;
    
 		dfs(v, u, dis + 1, sum + e[i].c, sum1 + e[i].c);
    
 		T[tot] = u;
    
 		rmq[tot++] = dis;
    
 	}
    
 }
    
 void lca_init()
    
 {
    
 	dfs(1, -1, 1, 0, 0);
    
 	st.init(tot - 1);
    
 }
    
 int query(int s, int t)
    
 {
    
 	return T[st.query(pos[s], pos[t])];
    
 }
    
  
    
 int a[maxn];
    
 int nn;
    
  
    
 bool cmp(int a, int b)
    
 {
    
 	return cost_t[a] > cost_t[b];
    
 }
    
  
    
 bool check(ll val)
    
 {
    
 	if (cost_t[a[0]] <= val)
    
 		return true;
    
 	int lca = a[0];
    
 	for (int i = 1; i < nn; i++) {
    
 		if (cost_t[a[i]] <= val)
    
 			break;
    
 		lca = query(lca, a[i]);
    
 	}
    
 	for (int i = 0; i < nn; i++) {
    
 		if (cost_t[a[i]] <= val)
    
 			return true;
    
 		if (cost[a[i]] - cost[lca] > val)
    
 			return false;
    
 	}
    
 	return true;
    
 }
    
  
    
 int main()
    
 {
    
 	scanf("%d", &t);
    
 	while (t--) {
    
 		init();
    
 		memset(red, false, sizeof(red));
    
 		scanf("%d%d%d", &n, &m, &q);
    
 		while (m--) {
    
 			scanf("%d", &x);
    
 			red[x] = true;
    
 		}
    
 		for (int i = 1; i < n; i++) {
    
 			scanf("%d%d%d", &x, &y, &cc);
    
 			add_edge(x, y, cc); add_edge(y, x, cc);
    
 		}
    
 		lca_init();
    
 		while (q--) {
    
 			scanf("%d", &nn);
    
 			for (int i = 0; i < nn; i++)
    
 				scanf("%d", &a[i]);
    
 			sort(a, a + nn, cmp);
    
 			ll l = 0, r = cost_t[a[0]];
    
 			while (l < r) {
    
 				ll mid = (l + r) / 2;
    
 				if (check(mid))
    
 					r = mid;
    
 				else
    
 					l = mid + 1;
    
 			}
    
 			printf("%lld\n", l);
    
 		}
    
 	}
    
 	return 0;
    
 }

全部评论 (0)

还没有任何评论哟~