Tree of Tree Time Limit: 1 Second Memory Limit: 32768 KB
You're given a tree with weights of each node, you need to find the maximum subtree of specified size of this tree.
Tree Definition
A tree is a connected graph which contains no cycles.
Input
There are several test cases in the input.
The first line of each case are two integers N(1 <= N <= 100), K(1 <= K <= N), where N is the number of nodes of this tree, and K is the subtree's size, followed by a line with N nonnegative integers, where the k-th integer indicates the weight of k-th node. The following N - 1 lines describe the tree, each line are two integers which means there is an edge between these two nodes. All indices above are zero-base and it is guaranteed that the description of the tree is correct.
Output
One line with a single integer for each case, which is the total weights of the maximum subtree.
Sample Input
3 1
10 20 30
0 1
0 2
3 2
10 20 30
0 1
0 2
Sample Output
30
40
Author: LIU, Yaoting
Source: ZOJ Monthly, May 2009
题意:
给你有n的点的树和要求的子树的点数,下一行跟的是n个点的权值,然后跟着n-1条边连着这些点,求该n个点的树的子树中权值和最大的是多少?(需要满足子树点为k个)
分析:
首先想子树要被选中那么它的根节点必定要选,然后该节点被选中,它下面有很多个子树,一些要被选中,一些不用被选中,那么就是一个背包问题了。
开一个dp[u][j]的二维数组u表示当前根节点,j表示这个树下面有多少个点。
而dp[u][1]=a[u]是显而易见的,那么外层必定是j从k枚举到1,再枚举有多少个点在这个子树上,多少个点不在这个子树上即可!
即dp方程为dp[u][j]=max(dp[u][j],dp[u][kk]+dp[v][j-kk])(kk是在u上的点,j-kk是不在u上的点,v是u的子树的根节点)
code:
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
int a[105];
struct s{
int to,next,w;
}hehe[420];
int p[105],eid,dp[105][105],vis[105],k;
void init(){
memset(p,-1,sizeof(p));
memset(dp,0,sizeof(dp));
memset(vis,0,sizeof(vis));
eid=0;
}
void ljb(int from,int to,int w){
hehe[eid].to=to;
hehe[eid].w=w;
hehe[eid].next=p[from];
p[from]=eid++;
}
void dfs(int u){
int v;
vis[u]=1;
dp[u][1]=a[u];
for(int i=p[u];i!=-1;i=hehe[i].next){
v=hehe[i].to;
if(vis[v])
continue;
dfs(v);
for(int j=k;j>=1;j--)
for(int kk=1;kk<=j;kk++)
dp[u][j]=max(dp[u][j],dp[u][kk]+dp[v][j-kk]);
}
}
int main()
{
int n,x,y;
while(~scanf("%d%d",&n,&k)){
init();
for(int i=0;i<n;i++)
scanf("%d",&a[i]);
for(int i=1;i<n;i++){
scanf("%d%d",&x,&y);
ljb(x,y,1);
ljb(y,x,1);
}
dfs(0);
int ans=0;
for(int i=0;i<n;i++)
if(dp[i][k]>ans)
ans=dp[i][k];
printf("%d\n",ans);
}
}