天天看點

SPOJ COT3

我不會說我是為了學博弈才去做這題的……

題意:

給定一棵N個點的樹,1号點為根,每個節點是白色或者黑色。雙方輪流操作,每次選擇一個白色節點,将從這個點到根的路徑上的點全部染成黑色。問先手是否必勝,以及第一步可選節點有哪些。N<=100000。

分析:

首先是博弈方面的分析。令SG[x]為,隻考慮以x為根的子樹時的SG值。令g[x]為,隻考慮以x為根的子樹時,所有後繼局面的SG值的集合。那麼SG[x]=mex{g[x]}。

我們考慮怎麼計算g[x]。假設x的兒子為v1,v2,...,vk,令sum[x]=SG[v1] xor SG[v2] xor .. xor SG[vk]。考慮兩種情況:

1、x為黑色。不難發現以x的每個兒子為根的子樹是互相獨立的。假設這一步選擇了vi子樹的某一個節點,那麼轉移到的局面的SG值就是sum[x] xor SG[vi] xor (在g[vi]中的某個值)。那麼我們隻需将每個g[vi]整體xor上sum[x] xor SG[vi]再合并到g[x]即可。

2、x為白色。這時候我們多了一種選擇,即選擇x點。可以發現,選擇x點之後x點變成黑色,所有子樹仍然獨立,而轉移到的局面的SG值就是sum[x]。如果此時不選擇x而是選擇x子樹裡的某個白色節點,那麼x一樣會被染成黑色,所有子樹依然獨立。是以x為白色時隻是要向g[x]中多插入一個值sum[x]。

這樣我們就有一個自底向上的DP了。樸素的複雜度是O(N^2)的。

接下來再考慮第一步可選的節點。我們要考慮選擇哪些節點之後整個局面的SG值會變成0。假設我們選擇了x點,那麼從x到根的路徑都會被染黑,将原來的樹分成了一堆森林。我們令up[x]為,不考慮以x為根的子樹,将從x到根的路徑染黑,剩下的子樹的SG值的xor和。那麼up[x]=up[fa[x]] xor sum[fa[x]] xor sg[x],其中fa[x]為x的父親節點編号。那麼如果點x初始顔色為白色且up[x] xor sum[x]=0,那麼這個點就是第一步可選的節點。這一步是O(N)的。

剩下的就是優化求SG了。我們需要一個可以快速整體xor并合并的資料結構。整體xor可以用二進制Trie打标記實作,至于合并,用啟發式合并是O(Nlog^2N)的,而用線段樹合并的方法可以做到O(NlogN)。不過還需要注意各種常數的問題……比如不要用指針,Trie的節點不用記大小,隻要記是否滿了……

做這題的時候先去膜拜了主席的題解……然後又去膜拜了主席冬令營的講課……最後還去膜拜了翺犇的代碼……然後幾乎是照着抄了一遍……

代碼:(SPOJ上排到了倒數第三……)

//SPOJ11414; COT3; Game Theory + Trie Merging
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <ctime>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ld;
#define pair(x, y) make_pair(x, y)
#define runtime() ((double)clock() / CLOCKS_PER_SEC)

#define N 100000
#define LOG 17
struct edge {
	int next, node;
} e[N << 1 | 1];
int head[N + 1], tot = 0;

inline void addedge(int a, int b) {
	e[++tot].next = head[a];
	head[a] = tot, e[tot].node = b;
}

#define SIZE 2000000
struct Node {
	int l, r;
	bool full;
	int d;
} tree[SIZE + 1];
#define l(x) tree[x].l
#define r(x) tree[x].r
#define d(x) tree[x].d
#define full(x) tree[x].full
int root[N + 1], tcnt = 0;
int n, col[N + 1], sg[N + 1], sum[N + 1], up[N + 1];
bool v[N + 1];

inline int newnode() {
	return ++tcnt;
}

inline void update(int x) {
	full(x) = full(l(x)) && full(r(x));
}

inline void push(int x) {
	if (d(x)) {
		if (l(x)) d(l(x)) ^= d(x) >> 1;
		if (r(x)) d(r(x)) ^= d(x) >> 1;
		if (d(x) & 1) swap(l(x), r(x));
		d(x) = 0;
	}
}

int merge(int l, int r) {
	if (!l || full(r)) return r;
	if (!r || full(l)) return l;
	push(l), push(r);
	int ret = newnode();
	l(ret) = merge(l(l), l(r));
	r(ret) = merge(r(l), r(r));
	update(ret);
	return ret;
}

inline int rev(int x) {
	int r = 0;
	for (int i = LOG; i > 0; --i)
		if (x >> i - 1 & 1) r += 1 << LOG - i;
	return r;
}

void insert(int x, int v, int p) {
	push(x);
	if (v >> p - 1 & 1) {
		if (!r(x)) r(x) = newnode();
		if (p != 1) insert(r(x), v, p - 1);
		else full(r(x)) = true;
	} else {
		if (!l(x)) l(x) = newnode();
		if (p != 1) insert(l(x), v, p - 1);
		else full(l(x)) = true;
	}
	update(x);
}

int mex(int x) {
	int r = 0;
	for (int i = LOG; x; --i) {
		push(x);
		if (full(l(x))) r += 1 << i - 1, x = r(x);
		else x = l(x);
	}
	return r;
}

void calc(int x) {
	v[x] = true;
	int xorsum = 0;
	for (int i = head[x]; i; i = e[i].next) {
		int node = e[i].node;
		if (v[node]) continue;
		calc(node);
		v[node] = false;
		xorsum ^= sg[node];
	}
	for (int i = head[x]; i; i = e[i].next) {
		int node = e[i].node;
		if (v[node]) continue;
		d(root[node]) ^= rev(xorsum ^ sg[node]);
		root[x] = merge(root[x], root[node]);
	}
	if (!col[x]) insert(root[x], xorsum, LOG);
	sg[x] = mex(root[x]);
	sum[x] = xorsum;
}

int ans[N + 1], cnt = 0;

void find(int x) {
	v[x] = true;
	if ((up[x] ^ sum[x]) == 0 && col[x] == 0) ans[++cnt] = x;
	for (int i = head[x]; i; i = e[i].next) {
		int node = e[i].node;
		if (v[node]) continue;
		up[node] = up[x] ^ sum[x] ^ sg[node];
		find(node);
	}
}

int main(int argc, char* argv[]) {
#ifdef KANARI
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
#endif
	
	scanf("%d", &n);
	for (int i = 1; i <= n; ++i) scanf("%d", col + i);
	for (int i = 1; i < n; ++i) {
		static int x, y;
		scanf("%d%d", &x, &y);
		addedge(x, y), addedge(y, x);
	}
	for (int i = 1; i <= n; ++i) root[i] = newnode();
	calc(1);
	for (int i = 1; i <= n; ++i) v[i] = false;
	find(1);
	
	if (cnt == 0) printf("-1\n");
	else {
		sort(ans + 1, ans + cnt + 1);
		for (int i = 1; i <= cnt; ++i) printf("%d\n", ans[i]);
	}
	
//	cerr << runtime() << endl;
//	for (int i = 1; i <= n; ++i) printf("%d ", sg[i]);
	
	fclose(stdin);
	fclose(stdout);
	return 0;
}
           

順便貼一個指針的,感覺長得更好看:

//
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <climits>
#include <cmath>
#include <utility>
#include <set>
#include <map>
#include <queue>
#include <ios>
#include <iomanip>
#include <ctime>
#include <numeric>
#include <functional>
#include <fstream>
#include <sstream>
#include <string>
#include <vector>
#include <bitset>
#include <cstdarg>
using namespace std;

typedef long long ll;
typedef unsigned long long ull;
typedef unsigned int uint;
typedef long double ld;
#define pair(x, y) make_pair(x, y)
#define runtime() ((double)clock() / CLOCKS_PER_SEC)

inline int read() {
	static int r;
	static char c;
	r = 0, c = getchar();
	while (c < '0' || c > '9') c = getchar();
	while (c >= '0' && c <= '9') r = r * 10 + (c - '0'), c = getchar();
	return r;
}

template <typename T>
inline void print(T *a, int n) {
	for (int i = 1; i < n; ++i) cout << a[i] << " ";
	cout << a[n] << endl;
}
#define PRINT(__l, __r, __begin, __end) { for (int __i = __begin; __i != __end; ++__i) cout << __l __i __r << " "; cout << endl; }

#define N 100000
#define LOG 17
struct edge {
	int next, node;
} e[N << 1 | 1];
int head[N + 1], tot = 0;

inline void addedge(int a, int b) {
	e[++tot].next = head[a];
	head[a] = tot, e[tot].node = b;
}

struct Node {
	Node *l, *r;
	bool full;
	int d;
	Node() { l = r = NULL, full = false, d = 0; }
} *root[N + 1];
int n, col[N + 1], sg[N + 1], sum[N + 1], up[N + 1];
bool v[N + 1];

inline void update(Node *x) {
	if (x->l && x->r) x->full = x->l->full && x->r->full;
	else x->full = false;
}

inline void applyDelta(Node *x, int v) {
	x->d ^= v;
}

inline void push(Node *x) {
	if (x->d) {
		if (x->l) applyDelta(x->l, x->d >> 1);
		if (x->r) applyDelta(x->r, x->d >> 1);
		if (x->d & 1) swap(x->l, x->r);
		x->d = 0;
	}
}

Node* merge(Node *l, Node *r) {
	if (l == NULL || (r != NULL && r->full)) return r;
	if (r == NULL || (l != NULL && l->full)) return l;
	push(l), push(r);
	Node *ret = new Node();
	ret->l = merge(l->l, r->l);
	ret->r = merge(l->r, r->r);
	update(ret);
	return ret;
}

inline int rev(int x) {
	int r = 0;
	for (int i = LOG; i > 0; --i)
		if (x >> i - 1 & 1) r += 1 << LOG - i;
	return r;
}

void insert(Node *x, int v, int p) {
	push(x);
	if (v >> p - 1 & 1) {
		if (x->r == NULL) x->r = new Node();
		if (p != 1) insert(x->r, v, p - 1);
		else x->r->full = true;
	} else {
		if (x->l == NULL) x->l = new Node();
		if (p != 1) insert(x->l, v, p - 1);
		else x->l->full = true;
	}
	update(x);
}

int mex(Node *x) {
	int r = 0;
	for (int i = LOG; x != NULL; --i) {
		push(x);
		if (x->l && x->l->full) r += 1 << i - 1, x = x->r;
		else x = x->l;
	}
	return r;
}

void calc(int x) {
	v[x] = true;
	int xorsum = 0;
	for (int i = head[x]; i; i = e[i].next) {
		int node = e[i].node;
		if (v[node]) continue;
		calc(node);
		v[node] = false;
		xorsum ^= sg[node];
	}
	for (int i = head[x]; i; i = e[i].next) {
		int node = e[i].node;
		if (v[node]) continue;
		applyDelta(root[node], rev(xorsum ^ sg[node]));
		root[x] = merge(root[x], root[node]);
	}
	if (!col[x]) insert(root[x], xorsum, LOG);
	sg[x] = mex(root[x]);
	sum[x] = xorsum;
}

int ans[N + 1], cnt = 0;

void find(int x) {
	v[x] = true;
	if ((up[x] ^ sum[x]) == 0 && col[x] == 0) ans[++cnt] = x;
	for (int i = head[x]; i; i = e[i].next) {
		int node = e[i].node;
		if (v[node]) continue;
		up[node] = up[x] ^ sum[x] ^ sg[node];
		find(node);
	}
}

int main(int argc, char* argv[]) {
#ifdef KANARI
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
#endif
	
	scanf("%d", &n);
	for (int i = 1; i <= n; ++i) scanf("%d", col + i);
	for (int i = 1; i < n; ++i) {
		static int x, y;
		scanf("%d%d", &x, &y);
		addedge(x, y), addedge(y, x);
	}
	for (int i = 1; i <= n; ++i) root[i] = new Node();
	calc(1);
	for (int i = 1; i <= n; ++i) v[i] = false;
	find(1);
	
	if (cnt == 0) printf("-1\n");
	else {
		sort(ans + 1, ans + cnt + 1);
		for (int i = 1; i <= cnt; ++i) printf("%d\n", ans[i]);
	}
	
//	for (int i = 1; i <= n; ++i) printf("%d ", sg[i]);
	
	fclose(stdin);
	fclose(stdout);
	return 0;
}
           

繼續閱讀