问题:假设有n个已经排序的关键字{1,2,…,n},它们被搜索的概率分别为{p1,p2,…,pn};其它搜索值被这些关键字分割成n+1个区间,落入各区间的概率分别为{q0,q1,…,qn}。q0代表搜索值小于1的概率,qn代表搜索值大于n的概率,qi(i=1,2,…,n-1)代表搜索值大于i且小于i+1的概率。构建一棵二叉搜索树,使得搜索的期望代价最小。
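分析:令a[i][j]代表由关键字i到j构成的最优二叉搜索树的期望代价,w[i][j]代表该子树中所有关键字及其间隔的概率之和。于是我们可以得到以下递推式:

$$w[i][j] = \begin{cases} q_{i-1}, & j = i-1 \\ w[i][j-1] + p_j + q_j, & i \le j \end{cases}$$

$$a[i][j] = \begin{cases} q_{i-1}, & j = i-1 \\ \min\limits_{i \le k \le j} \left\{ a[i][k-1] + a[k+1][j] + w[i][j] \right\}, & i \le j \end{cases}$$

其中取得最小值的k即为子树i->j的根root[i][j]。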
这个算法的复杂度为O(n^3)。
Knuth已经证明:对于所有的1 <= i < j <= n,总存在最优子树的根满足root[i][j-1] <= root[i][j] <= root[i+1][j]。所以当i<j时,求min时只需让k在root[i][j-1] <= k <= root[i+1][j]的范围内取值即可,此时算法的时间复杂度降为O(n^2)。
代码:
#include <iostream>
#define MAXLENGTH 100
#define MAX 100
#define LEFT false
#define RIGHT true
using namespace std;
/**
 * @brief Build the optimal binary search tree tables by dynamic programming.
 *
 * For every key range i..j (1 <= i <= j <= n) the function stores the minimum
 * expected search cost in a[i][j] and the key chosen as the subtree root in
 * root[i][j]. Empty ranges are stored as a[i + 1][i] = q[i]. Three nested loops
 * over the range length, the start index and the candidate root give O(n^3) time.
 *
 * The weights do not have to be normalized probabilities: the minimum is seeded
 * from the first candidate root rather than from a fixed sentinel, so arbitrarily
 * large weights (e.g. raw access counts) are handled correctly.
 *
 * @param p p[1..n] are the weights of the inner nodes (keys, assumed sorted)
 * @param q q[0..n] are the weights of the outer nodes (gaps between the keys)
 * @param n the number of inner nodes; must satisfy 0 <= n <= MAXLENGTH - 2,
 *          otherwise the function returns without touching a or root
 * @param a output: a[i][j] is the minimum cost of the tree over keys i..j
 * @param root output: root[i][j] is the root key of the tree over keys i..j
 */
void OptimalBinarySearchTree(double p[], double q[], int n, double a[][MAXLENGTH], int root[][MAXLENGTH])
{
    // Row n + 1 is written below (a[n + 1][n]), so n + 1 must be a valid row index.
    if (n < 0 || n > MAXLENGTH - 2)
    {
        return;
    }

    // w[i][j] is the total weight of keys i..j plus the gaps q[i - 1]..q[j].
    double w[MAXLENGTH][MAXLENGTH];

    // Initial states: an empty range i+1..i consists of the single gap q[i].
    for (int i = 0; i <= n; ++i)
    {
        a[i + 1][i] = q[i];
        w[i + 1][i] = q[i];
    }

    for (int len = 1; len <= n; ++len)
    {
        for (int i = 1; i <= n - len + 1; ++i)
        {
            const int j = i + len - 1;
            w[i][j] = w[i][j - 1] + p[j] + q[j];

            // Seed the minimum with the first candidate root (k = i) so the result
            // never depends on a magic "infinity" value.
            a[i][j] = a[i][i - 1] + a[i + 1][j] + w[i][j];
            root[i][j] = i;

            // Restricting k to root[i][j-1] <= k <= root[i+1][j] (Knuth) would make the
            // whole algorithm O(n^2); that bound only holds for i < j, i.e. not for len == 1.
            for (int k = i + 1; k <= j; ++k)
            {
                const double temp = a[i][k - 1] + a[k + 1][j] + w[i][j];
                // Strict comparison keeps the smallest k among equally good roots.
                if (temp < a[i][j])
                {
                    a[i][j] = temp;
                    root[i][j] = k;
                }
            }
        }
    }
}
/**
 * @brief Recursively print the structure of the tree stored in the root matrix.
 *
 * One line is written to standard output per node. Inner nodes are printed as
 * "p<key>"; an empty key range (start > end) is printed as the outer node
 * "q<start - 1>".
 *
 * @param root the root matrix; root[i][j] is the root key of the tree over keys i..j
 * @param start the first key of the (sub)tree to print
 * @param end the last key of the (sub)tree to print
 * @param k the parent key, or -1 when printing the root of the whole tree
 * @param LeftOrRight LEFT if this subtree is the left child of k, RIGHT otherwise;
 *                    not consulted when k == -1
 */
void PrintTree(int root[][MAXLENGTH], int start, int end, int k, bool LeftOrRight)
{
    if (k != -1)
    {
        const char *childLabel = (LeftOrRight == LEFT) ? "the left child of p" : "the right child of p";

        // An empty key range is a leaf: report the outer node and stop descending.
        if (start > end)
        {
            std::cout << childLabel << k << " is q" << start - 1 << std::endl;
            return;
        }
        std::cout << childLabel << k << " is p" << root[start][end] << std::endl;
    }
    else
    {
        std::cout << "the root is p" << root[start][end] << std::endl;
    }

    // The node just printed becomes the parent of both of its subtrees.
    const int subtreeRoot = root[start][end];
    PrintTree(root, start, subtreeRoot - 1, subtreeRoot, LEFT);
    PrintTree(root, subtreeRoot + 1, end, subtreeRoot, RIGHT);
}
/**
 * @brief Read test cases from standard input and print the optimal tree for each.
 *
 * Each test case is: n, then p[1..n], then q[0..n]. Input ends at n == 0 or at
 * end of input. For every case the minimum expected cost a[1][n] is printed,
 * followed by the tree structure.
 *
 * @return 0 on normal termination, 1 if n is out of range or the input is malformed
 */
int main()
{
    double p[MAXLENGTH];
    double q[MAXLENGTH];
    double a[MAXLENGTH][MAXLENGTH];
    int root[MAXLENGTH][MAXLENGTH];
    int n = 0;
    // Test the stream state as well as the value: at end of input a failed
    // extraction may leave n unchanged, which would otherwise loop forever.
    while (cin >> n && n != 0)
    {
        // The table builder writes row n + 1, so n + 1 must be a valid index of
        // the MAXLENGTH-sized arrays; a negative n would index a[1][n] out of bounds.
        if (n < 0 || n > MAXLENGTH - 2)
        {
            cerr << "n must be between 1 and " << MAXLENGTH - 2 << endl;
            return 1;
        }
        for (int i = 1; i <= n; ++i)
        {
            cin >> p[i];
        }
        for (int i = 0; i <= n; ++i)
        {
            cin >> q[i];
        }
        // Do not compute on partially read weights.
        if (!cin)
        {
            cerr << "incomplete or malformed input" << endl;
            return 1;
        }
        OptimalBinarySearchTree(p, q, n, a, root);
        cout << a[1][n] << endl;
        // k == -1 marks the root of the whole tree; the last argument is not used in that case.
        PrintTree(root, 1, n, -1, true);
    }
    return 0;
}