天天看點

動态規劃求機率期望和高斯消元求解方程組

算法課的project有一道很有意思的題目,是用動态規劃求機率期望,其中用到了高斯消元法,特此記錄一下。

題目:

小 Z 來到一個古墓去尋找寶藏。 古墓中有非常多的路口和岔路, 有些路口有陷阱, 小 Z 在每次經過路口 i 的陷阱的時候都要掉 A[i]點血, 而且陷阱是永久有效的(即小 Z 每到一次路口 i 就要掉 A[i]點血) 。 幸運的是, 有一些路口沒有陷阱。 可不幸的是, 小 Z 是個路癡, 他完全無法判斷他走過哪裡, 要去哪裡; 他隻能在每一個路口随機(等機率地) 走向某一條岔路到達下一個路口。 小 Z 現在在古墓的入口處(即路口 1) , 這裡沒有陷阱; 寶藏藏匿在路口 n, 那裡也沒有陷阱。

而你萬萬沒有想到的是, 你是這個古墓的守護者。 你知道這個古墓所有的構造(包括它的路口、 岔路和陷阱的情況) , 現在你需要計算出小 Z 能活着見到寶藏的機率。

現給出了該函數的接口:

double func(int n, int hp, vector<int>& damage, vector<int>& edges) {
}
           

其中參數 n 是路口數量; hp 是小 Z 的初始血量; 數組 damage 是 n 個資料,代表 n 個路口陷阱的傷害(無陷阱處為 0, 保證路口 1 和 n 處無陷阱) ; 資料edges 是 2*(岔路條數) 個資料, 每 2 個資料是一條邊, 邊都是雙向的。

舉個例子:

有個三角形路口 1, 2, 3。 小 Z 有 2 點血。 小 Z 在 1, 寶藏在 3, 陷阱在 2(傷害為 1) 。 那麼它的結果為 0.875。 小 Z 在血量降到 0 之前走到 3 就算成功, 是以它失敗的唯一路徑是 1-2-1-2, 每次尋路都是随機的, 三次都走錯的機率是 0.5^3 = 0.125, 那麼成功走到 3 的機率就是 0.875。

動态規劃求機率期望和高斯消元求解方程組

示例輸入:

3 2 3
0 1 0
1 2 
1 3
2 3
           

n=3,hp=2,有3條邊,damage[]={0,1,0},邊是1->2,1->3,2->3,edges[]={1,2,1,3,2,3}。

思路:

建立二維數組dp[hp+1][n+1],其中dp[i][j]表示剩下i滴血時到達j點的期望次數。列出狀态轉移方程:

動态規劃求機率期望和高斯消元求解方程組

i=hp,j=1的時候要加1,因為開始點是節點1,是以必然會經過1次節點1。不包括終點的意思是:一旦到達終點就結束了,是以不能考慮到了終點又折返。

當damage[j]>0的時候,若将下面層均視為常數,則可以很輕松的求出dp[i][j]。而當damage[j]=0時,在這一hp層的所有damage為0的路口的dp值是互相有一個關系(即方程)的,不能直接求出值。要首先求出這一層damage不為0的路口dp值,将它們視為常數,然後利用高斯消元法求解damage=0的路口dp的線性方程組,就可以求出這一hp層的所有dp值。這一層求出來以後,就可以繼續求解上一層。

最後,小Z能活着見到寶藏的機率是

動态規劃求機率期望和高斯消元求解方程組

即當hp>=1時,到達路口n的期望次數之和。

代碼:

//選擇列主元并進行消元
void upperTrangle(vector<vector<double>> &a,int n) {
	double tmp; //用于記錄消元時的因數
	for (int i = 1; i <= n; i++) {
		int r = i;
		for (int j = i + 1; j <= n; j++)
			if (fabs(a[j][i]) > fabs(a[r][i]))
				r = j;
		if (r != i)
			for (int j = i; j <= n + 1; j++)
				swap(a[i][j], a[r][j]);//與最大主元所在行交換
		for (int j = i + 1; j <= n; j++) {//消元
			tmp = a[j][i] / a[i][i];
			for (int k = i; k <= n + 1; k++)
				a[j][k] -= a[i][k] * tmp;
		}
	}
}
//高斯消元法(列選主元)
void Gauss(vector<vector<double>> &a, int n) {
	upperTrangle(a, n);//列選主元并消元成上三角

	for (int i = n; i >= 1; i--) {//回代求解
		for (int j = i + 1; j <= n; j++)
			a[i][n + 1] -= a[i][j] * a[j][n + 1];
		a[i][n + 1] /= a[i][i];
	}
}

vector<int> findAdjacent(vector<int> edges, int p) {//找p點的相鄰點
	vector<int> points;
	for (int i = 0; i < edges.size() / 2; ++i) {
		if (edges[2 * i] == p) {
			points.push_back(edges[2 * i + 1]);
		}
		else if (edges[2 * i + 1] == p) {
			points.push_back(edges[2 * i]);
		}
	}
	return points;
}

double func(int n, int hp, vector<int>& damage, vector<int>& edges) {
	vector<vector<double>> dp;
	for (int i = 0; i < hp + 1; ++i) {
		vector<double> tmp;
		for (int j = 0; j < n + 1; ++j) {
			tmp.push_back(0);
		}
		dp.push_back(tmp);
	}
	dp[hp][1] = 1;

	vector<int> adjacentCount;//鄰接點個數,下标是點辨別
	for(int i=0;i<=n;++i){
		if(i==0) adjacentCount.push_back(0);
		else adjacentCount.push_back(findAdjacent(edges,i).size());
	}

	for (int row = hp; row >= 1; --row) {
		//先計算damage不為0的點
		for (int col = 1; col <= n; ++col) {
			if (damage[col - 1] > 0) {
				for (int i : findAdjacent(edges, col)) {//不為終點的相鄰點
					if (i != n && row + damage[col - 1] <= hp) {
						dp[row][col] += dp[row + damage[col - 1]][i]/(double)adjacentCount[i];
					}
				}
			}
		}
	    
		//計算damage為0的點
		vector<int> zero;
		vector<vector<double>> matrix;//增廣矩陣的擴大

		for (int col = 1; col <= n; ++col) {
			if (damage[col - 1] == 0) {
				zero.push_back(col);
			}
		}

		for (int i = 0; i < zero.size() + 1; ++i) {//矩陣n+1行 n+2列 第1行第1列均為0 其餘部分為增廣矩陣
			vector<double> tmp;
			for (int j = 0; j < zero.size() + 2; ++j) {
				tmp.push_back(0);
			}
			matrix.push_back(tmp);
		}
 

		//填充增廣矩陣
		for (int i = 0;i<zero.size();++i){
			matrix[i + 1][i + 1] = -1;
			matrix[i+1][zero.size() + 1] = -dp[row][zero[i]];//常數項
			
			for (int k : findAdjacent(edges, zero[i])) {
				if (k != n) {

					if (damage[k-1] > 0) {//若damage>0,則為常數項
						matrix[i+1][zero.size() + 1] -= dp[row][k]/(double)adjacentCount[k];
					}
					else {//若damage=0,則為未知項
						for (int index = 0; index < zero.size(); ++index) {
							if (zero[index] == k) {
								matrix[i+1][index + 1] = 1/(double)adjacentCount[k];
								break;
							}
						}
					}
				}
			}
			
		}

		//高斯消元法求解
		Gauss(matrix, zero.size());

		//将解寫回dp中
		for (int i = 0; i < zero.size(); ++i) {
			dp[row][zero[i]] = matrix[i + 1][zero.size() + 1];
		}
	}


	double result = 0;
	for (int i = 1; i <= hp; ++i) {
		result += dp[i][n];
	}

	return result;
}
           

時間複雜度是O(hp*n^2),空間複雜度是O(hp*n)。

繼續閱讀