天天看點

Hdu 6769 In Search of Gold —— 上下界優化,樹形DP

This way

題意:

現在有一顆大小為n的樹,每條邊都有兩個權值:a,b現在讓你最多選k個邊的權值為a,其它邊的權值為b,使得最終這棵樹的直徑最短。問你最短是多少。

題解:

最大值最小的問題考慮二分。dp[i][j]表示到第i個點,它的子樹中用了j個a的最長長度最短是多少。

然後枚舉目前點的選了多少個a的同時枚舉子樹選了多少個a來進行轉移。

但是可以發現這個的時間複雜度是 O ( T ∗ N ∗ K 2 ∗ l o g 2 e 13 ) O(T*N*K^2*log^{2e13}) O(T∗N∗K2∗log2e13)的,那麼很明顯就超過了所給的時限。

這個時候用到了一個知識叫做樹上背包的上下界優化,也就是說,對于目前點,枚舉它已經有的值,然後對于兒子節點,枚舉他們的和不超過k的值。雖然看起來這無可厚非,但是實際上最終的複雜度少了一個k。我其實不是很懂,因為在比賽的時候我被A題卡住都還沒看過這道題目,但是我還是在這裡口胡一下:

假設目前的樹是一條鍊,那麼易證這種方法的時間複雜度是 O ( N K ) O(NK) O(NK)的,因為每個點隻會将兒子節點轉移到它,是以每個點最多會做k次。然後如果将這條鍊折了一半,最上面的點的複雜度變成 O ( k 2 / 2 ) O(k^2/2) O(k2/2)。但是與此同時,最下面的k個點的總和時間複雜度從 k 2 k^2 k2變成了 1 + 2 + . . . + k → O ( k 2 / 2 ) 1+2+...+k\rightarrow O(k^2/2) 1+2+...+k→O(k2/2),也就是少了 k 2 / 2 k^2/2 k2/2的時間複雜度(大緻)。那麼總的時間複雜度看起來還是不變的,那麼以此類推…總的時間複雜度為 O ( T ∗ N ∗ K ∗ l o g 2 e 13 ) → 2 e 5 ∗ 20 ∗ 50 O(T*N*K*log^{2e13})\rightarrow 2e5*20*50 O(T∗N∗K∗log2e13)→2e5∗20∗50大概就是 2 e 8 2e8 2e8的時間複雜度

事實上我也在經過了一次次嘗試之後過了這道題

Hdu 6769 In Search of Gold —— 上下界優化,樹形DP
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e4+5;
const ll inf=2e14;
struct node{
    int x;
    ll a,b;
};
vector<node>vec[N];
int n,k;
ll dp[N][25],tmp[25],mid;
int siz[N];
void dfs(int x,int fa){
    for(int i=0;i<=k;i++)
        dp[x][i]=0;
    siz[x]=0;
    for(node ne:vec[x]){
        if(ne.x==fa)continue;
        dfs(ne.x,x);
        for(int i=0;i<=k;i++)tmp[i]=inf;
        for(int i=0;i<=min(siz[x],k);i++){
            for(int j=0;j+i<=k&&j<=siz[ne.x];j++){
                if(dp[x][i]+dp[ne.x][j]+ne.a<=mid)
                    tmp[i+j+1]=min(tmp[i+j+1],max(dp[x][i],dp[ne.x][j]+ne.a));
                if(dp[x][i]+dp[ne.x][j]+ne.b<=mid)
                    tmp[i+j]=min(tmp[i+j],max(dp[x][i],dp[ne.x][j]+ne.b));
            }
        }
        siz[x]+=siz[ne.x]+1;
        for(int i=0;i<=k&&i<=siz[x];i++)
            dp[x][i]=tmp[i];
    }
}
int main()
{
    int t;
    scanf("%d",&t);
    while(t--){
        scanf("%d%d",&n,&k);
        for(int i=1;i<=n;i++)
            vec[i].clear();
        int x,y;
        ll a,b;
        ll l=0,r=0,ans;
        for(int i=1;i<n;i++){
            scanf("%d%d%lld%lld",&x,&y,&a,&b),vec[x].push_back({y,a,b}),vec[y].push_back({x,a,b});
            r+=max(a,b);
        }
        while(r>=l){
            mid=l+r>>1;
            dfs(1,0);
            if(dp[1][k]<=mid)
                r=mid-1,ans=mid;
            else
                l=mid+1;
        }
        printf("%lld\n",ans);
    }
    return 0;
}