天天看點

HDU-5242 Game (貪心&&樹鍊剖分&&線段樹)

題目:

http://acm.hdu.edu.cn/showproblem.php?pid=5242

題意:

給出一顆以1節點為根的樹,每個節點有各自的價值,有m次從根節點出發向下走到葉子節點的機會,每次會得到所有經過節點的權值,每個節點隻有在第一次經過時有價值,求m次之後能夠獲得的最大權值。

思路:

典型的樹鍊剖分題,隻要找到所有重鍊的權值然後貪心找前m個的和就行了。

具體解法就是先dfs找到所有葉子節點從根走下來得到的總權值,排序之後将有重複路徑的節點權值減去,這就是一個找重鍊的過程,對于每個節點,他的所有兒子中能夠找到最大權值鍊的那一條就是和這個節點在同一條重鍊上的,其他兒子節點作為其他重鍊的新起點,最後結構造出了一個包含重鍊的樹。每個節點都一定且僅在一條重鍊中,并且每條重鍊都包含一個葉子節點,是以隻要找到權值最大的m條重鍊就是最大總權值了。

一開始并沒有想到寫樹鍊剖分,用線段樹寫的。

前面的處理一樣,找到每個葉子節點的總權值,然後以葉子節點建樹。在dfs找葉子節點總權值時回溯處理出來每個節點在這個線段樹中包含的區間。然後線段樹求m次最大值就行了。每次query之後從該葉子節點向上搜尋找到他的所有父節點對這些父節點線上段樹中包含的葉子節點都進行區間更新,就是都減去這個節點的權值,保證已經更新過的節點就不更新了,是以整個更新的複雜度為O(n*logn)。找最大的m個節點的複雜度是O(m*logn)。

是以樹鍊剖分的複雜度是O(2*n),線段樹葉子節點是O(n*logn+m*logn),實際表現差不多。

代碼:

樹鍊剖分:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define MOD 1000000007
#define EPS 1e-6
#define N 112345
using namespace std;
struct node
{
    long long val,id;
    friend bool operator < (node a, node b)
    {
        return a.val > b.val;
    }
}p[N];
long long n,m,res,flag,tot;
vector<long long>zi[N];
long long fa[N],val[N],ans[N];
bool vis[N];
void init()
{
    for(int i=0;i<=n;i++)zi[i].clear();
    memset(vis,0,sizeof(vis));
    res=0;tot=0;
}
void dfs(long long now, long long vall)
{
    int num=zi[now].size();
    if(num==0)
        p[tot].id=now, p[tot++].val=val[now]+vall;
    else
        for(int i=0;i<num;i++)
            dfs(zi[now][i],vall+val[now]);
}
long long dfs1(long long now)
{
    if(vis[now])return 0;
    vis[now]=1;
    return val[now]+dfs1(fa[now]);
}
int main()
{
    long long i,j,k,kk,cas,T,t,x,y,z;
    scanf("%I64d",&T);
    cas=0;
    while(T--)
    {
        scanf("%I64d%I64d",&n,&m);
        init();
        for(i=1;i<=n;i++)scanf("%I64d",&val[i]);
        for(i=1;i<n;i++)
        {
            scanf("%I64d%I64d",&x,&y);
            zi[x].push_back(y);
            fa[y]=x;
        }
        dfs(1,0);
        sort(p,p+tot);
        for(i=0;i<tot;i++)
            ans[i]=dfs1(p[i].id);
        sort(ans,ans+tot);
        for(i=tot-1,j=0;i>=0&&j<m;j++,i--)res+=ans[i];
        printf("Case #%I64d: %I64d\n",++cas,res);
    }
    return 0;
}
           

線段樹:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define MOD 1000000007
#define EPS 1e-6
#define N 112345
using namespace std;
struct node
{
    long long sum,side;
}sum[N<<2],ttt;
long long n,res,flag,tot;
long long a[N],b[N],hehe[N],xixi[N];
#define root 1 , tot , 1
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
vector<long long>zi[N];
long long fa[N];
bool vis[N];
long long val[N];
long long add[N<<2];
void pushUp(long long rt)
{
    if(sum[rt<<1].sum>=sum[rt<<1|1].sum)
    {
        sum[rt]=sum[rt<<1];
    }
    else
    {
        sum[rt]=sum[rt<<1|1];
    }
}
void pushDown(long long l,long long r,long long rt)
{
    if(add[rt])
    {
        long long m = (l+r)>>1;
        add[rt<<1] += add[rt];
        add[rt<<1|1] += add[rt];
        sum[rt<<1].sum += add[rt];
        sum[rt<<1|1].sum += add[rt];
        add[rt] = 0;
    }
}
void update(long long l,long long r,long long rt,long long ql,long long qr,long long val)
{
    if(l>qr||ql>r)return;
    if(l>=ql&&r<=qr)
    {
        sum[rt].sum += val;
        add[rt] += val;
        return;
    }
    pushDown(l,r,rt);
    long long m = (l+r)>>1;
    if(ql<=m)update(lson,ql,qr,val);
    if(qr>m)update(rson,ql,qr,val);
    pushUp(rt);
}
void build(long long l,long long r,long long rt)
{
    add[rt]=0;
    if(l == r)
    {
        sum[rt].sum=hehe[res++];
        sum[rt].side=xixi[res-1];
        return;
    }
    long long m = (l+r)>>1;
    build(lson);
    build(rson);
    pushUp(rt);
}
node query(long long l,long long r,long long rt,long long ql,long long qr)
{
    if(l>qr||ql>r)
        return ttt;
    if(l>=ql&&r<=qr)
        return sum[rt];
    pushDown(l,r,rt);
    long long m = l+r>>1;
    node x=query(l,m,rt<<1,ql,qr);
    node y=query(m+1,r,rt<<1|1,ql,qr);
    return x.sum>=y.sum?x:y;
}
void init()
{
    for(int i=0;i<=n;i++)zi[i].clear();
    memset(vis,0,sizeof(vis));
    ttt.sum=-1;
}
void dfs(int now,long long vall)
{
    long long len=zi[now].size();
    if(len==0)
    {
        hehe[tot++]=val[now]+vall;
        xixi[tot-1]=now;
        a[now]=b[now]=tot-1;
        return ;
    }
    for(int i=0;i<len;i++)
        dfs(zi[now][i],vall+val[now]);
    long long x=INF,y=-1;
    for(int i=0;i<len;i++)
    {
        x=min(x,a[zi[now][i]]);
        y=max(y,b[zi[now][i]]);
    }
    a[now]=x;b[now]=y;
}
void dfs1(long long now)
{
    if(vis[now])return;
    vis[now]=true;
    update(root,a[now]+1,b[now]+1,-val[now]);
    if(now==1)return;
    dfs1(fa[now]);
}
int main()
{
    long long i,j,k,kk,cas,T,t,x,y,z;
    scanf("%I64d",&T);
    cas=0;
    while(T--)
    {
        long long m;
        scanf("%I64d%I64d",&n,&m);
        init();
        for(i=1;i<=n;i++)scanf("%I64d",&val[i]);
        for(i=1;i<n;i++)
        {
            scanf("%I64d%I64d",&x,&y);
            zi[x].push_back(y);
            fa[y]=x;
        }
        tot=0;
        dfs(1,0);
        res=0;
        build(root);
        long long res=0;
        for(i=0;i<m;i++)
        {
            node t=query(root,1,tot);
            res+=t.sum;
            dfs1(t.side);
        }
        printf("Case #%I64d: %I64d\n",++cas,res);
    }
    return 0;
}
           

繼續閱讀