天天看點

wikioi-天梯-普及一等-區間dp-1166:矩陣取數遊戲

題目描述 Description

【問題描述】

帥帥經常跟同學玩一個矩陣取數遊戲:對于一個給定的n*m 的矩陣,矩陣中的每個元素aij均

為非負整數。遊戲規則如下:

1. 每次取數時須從每行各取走一個元素,共n個。m次後取完矩陣所有元素;

2. 每次取走的各個元素隻能是該元素所在行的行首或行尾;

3. 每次取數都有一個得分值,為每行取數的得分之和,每行取數的得分= 被取走的元素值*2i,

其中i 表示第i 次取數(從1 開始編号);

4. 遊戲結束總得分為m次取數得分之和。

帥帥想請你幫忙寫一個程式,對于任意矩陣,可以求出取數後的最大得分。

輸入描述 Input Description

第1行為兩個用空格隔開的整數n和m。

第2~n+1 行為n*m矩陣,其中每行有m個用單個空格隔開的非負整數。

輸出描述 Output Description

輸出 僅包含1 行,為一個整數,即輸入矩陣取數後的最大得分。

樣例輸入 Sample Input

2 3

1 2 3

3 4 2

樣例輸出 Sample Output

82

資料範圍及提示 Data Size & Hint

樣例解釋

第 1 次:第1 行取行首元素,第2 行取行尾元素,本次得分為1*21+2*21=6

第2 次:兩行均取行首元素,本次得分為2*22+3*22=20

第3 次:得分為3*23+4*23=56。總得分為6+20+56=82

【限制】

60%的資料滿足:1<=n, m<=30, 答案不超過1016

100%的資料滿足:1<=n, m<=80, 0<=aij<=1000

類型:dp 難度:2

題意:給出n*m的矩陣,需要取m次,每次依次取每一行的行首或行尾的數,然後将這n個數的和乘以2^i,i為次數,1<=i<=m,将每次的計算值累加,求怎樣取,能獲得最大值

分析:首先,可以看出每一行的取法互相之間是獨立的,是以問題變成m個數,怎樣取能獲得最大值,然後重複進行n次。這個問題也是區間dp,可以看出進行到某個狀态時,該狀态可以由之前兩個狀态轉換而來,即去掉行首或行尾。

用dp[i][j]表示目前段為[i,j]時,已經處理的數[1,i-1],[j+1,m]所能獲得的最大值。

遞推方程:dp[i][j] = max(dp[i-1][j]+a[i-1]*2^k, dp[i][j+1]+a[j+1]*2^k)

其中,a[i]表示本行的第i個數,k為目前的次數,k = m-(j-i)+1

初始值dp[0][m-1] = 0,最後求dp[i][i] + a[i]*2^m 的最大值,0<=i<m

ps:本次的結果已經超過long long 範圍,還需要用高精度表示,我的思路是,用一個資料表示每個數的二進制,最後再将結果轉化為十進制,因為每一步都要乘以2^i,是以用二進制數組的話,隻需進行數組的移位操作即可。

還有一點注意:由于我用的二進制資料需要事先用memset清零,但是發現将數組作為實參(比如函數參數為int*)傳入函數,memset不能達到效果,原因是memset不知道函數參數(這個int*指針)指向的空間有多大,是以無法初始化,還是要在數組聲明的作用域來進行memset才能達到效果。

代碼:

#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;

int a[90][90],n,m;
int dp[90][90][150];

void lmov(int s,int l,int d[])
{    
    for(; s; l++,s>>=1)
        if(s&1) d[l] = 1;
}

void madd(int a[],int b[])
{
    for(int i=0; i<140; i++)
    {
        b[i] += a[i];
        for(int j=i; b[j]>1; j++)
        {
            b[j] = 0;
            b[j+1]++;
        }
    }
}

void mmax(int a[],int b[])
{
    for(int i=140; i>=0; i--)
    {
        if(a[i]>b[i]) return;
        else if(a[i]<b[i]) break;
    }
    for(int i=0; i<150; i++)
        a[i] = b[i];
}

void bintodec(int a[],int b[])
{
    int pt[150];
    memset(pt,0,sizeof(pt));
    pt[0] = 1;
    for(int i=0; i<140; i++)
    {
        if(a[i])
        {
            for(int j=0; j<140; j++)
                b[j] += pt[j];
            for(int j=0; j<140; j++)
                if(b[j]>9)
                {
                    b[j+1] += b[j]/10;
                    b[j] %= 10;
                }
        }
        for(int j=0; j<140; j++)
            pt[j] *= 2;
        for(int j=0; j<140; j++)
        {
            if(pt[j]>9)
            {
                pt[j+1] += pt[j]/10;
                pt[j] %= 10;
            }
        }
    }
    
    bool f = 0;
    for(int i=140; i>=0; i--)
    {
        if(!f && b[i]>0) f = 1;
        if(f) cout<<b[i];
    }
    if(!f) cout<<"0";
    cout<<endl;
}

int main()
{
    cin>>n>>m;
    for(int i=0; i<n; i++)
        for(int j=0; j<m; j++)
            cin>>a[i][j];

    int ans[150];
    memset(ans,0,sizeof(ans));
    for(int i=0; i<n; i++)
    {
        int ta[150];
        memset(dp,0,sizeof(dp));
        memset(ta,0,sizeof(ta));
        if(m==1)
        {
            lmov(a[i][0],1,ta);
        }
        for(int l=m-2; l>=0; l--)
        {
            for(int j=0; j+l<m; j++)
            {
                int tmp[150];
                if(j>0)
                {
                    memset(tmp,0,sizeof(tmp));
                    lmov(a[i][j-1],m-l-1,tmp);
                    madd(dp[j-1][j+l],tmp);
                    mmax(dp[j][j+l],tmp);
                }
                if(j+l<m-1)
                {
                    memset(tmp,0,sizeof(tmp));
                    lmov(a[i][j+l+1],m-l-1,tmp);
                    madd(dp[j][j+l+1],tmp);
                    mmax(dp[j][j+l],tmp);
                }
                if(l==0)
                {
                    memset(tmp,0,sizeof(tmp));
                    lmov(a[i][j],m,tmp);
                    madd(dp[j][j],tmp);
                    mmax(ta,tmp);
                }
            }
        }
        madd(ta,ans);
    }
    int ret[150];
    memset(ret,0,sizeof(ret));
    bintodec(ans,ret);
}