天天看點

POJ 2442 - Sequence - [小頂堆][優先隊列]

題目連結:​​http://poj.org/problem?id=2442​​

Time Limit: 6000MS Memory Limit: 65536K

Description

Given m sequences, each contains n non-negative integer. Now we may select one number from each sequence to form a sequence with m integers. It's clear that we may get n ^ m this kind of sequences. Then we can calculate the sum of numbers in each sequence, and get n ^ m values. What we need is the smallest n sums. Could you help us?

Input

The first line is an integer T, which shows the number of test cases, and then T test cases follow. The first line of each case contains two integers m, n (0 < m <= 100, 0 < n <= 2000). The following m lines indicate the m sequence respectively. No integer in the sequence is greater than 10000.

Output

For each test case, print a line with the smallest n sums in increasing order, which is separated by a space.

Sample Input

1

2 3

1 2 3

2 2 3

Sample Output

3 3 4

題意:

給出 $m$ 個長度為 $n$ 的序列,在每一個序列中挑選一個數求和,可知有 $n^m$ 個結果,要求給出這些結果中前 $n$ 小的。

題解(參考《算法競賽進階指南》):

首先考慮當 $M=2$ 時的做法,記這兩個序列為 $a,b$,先對所有的序列進行升序排序,并且用兩個指針 $p,q$ 分别指向這兩個序列的頭部 $a[0],b[0]$,顯然此時就是第 $1$ 小的和 $S_1 = a[0] + b[0]$。

那麼,第 $2$ 小的和 $S_2 = \min (a[1] + b[0], a[0] + b[1])$,

如果第 $2$ 小和是 $S_2 = a[1] + b[0]$,那麼此時競争的第 $3$ 小的候選者,除了上面已經存在的 $a[0] + b[1]$ 外還應當再增加 $a[2] + b[0], a[1] + b[1]$,也就是序列一的指針 $ptr_1++$ 或者序列二的指針 $ptr_2++$。

不難想到,可以用一個小頂堆(或者STL的優先隊列)來維護所有候選者,不斷地扔進去新晉的候選者,再取出堆頂作為答案之一,再由該堆頂求出新的候選者插入堆中,再取出堆頂,反複如此……

候選方案産生方式如圖:

POJ 2442 - Sequence - [小頂堆][優先隊列]

不過,有一點需要注意,例如我們在上述例子中已得 $S_2 = a[1] + b[0]$,那麼入堆兩個新的候選者之後堆中有三個節點: $a[0] + b[1], a[2] + b[0], a[1] + b[1]$;此時如果 $S_3 = a[0] + b[1]$,會發現又一次産生了候選者 $a[1] + b[1]$,這樣就産生了重複,具體表現就是在上圖中,就是絕大部分候選方案從 $(0,0)$ 出發可以有多條路徑到達。重複方案多次入堆顯然是影響正确性的,是以要避免這種情況發生。

不妨增加限定,如果目前選中的方案所産生新的候選者是 $ptr_2++$,那麼以後都隻能 $ptr_2++$,不能再回到 $ptr_1++$。換句話說,$a[0]+b[0]$ 要走到任何候選方案 $a[i]+a[j]$,必須先移動 $ptr_1 = 0 \sim i$,再移動 $ptr_2 = 0 \sim j$,使得到達備選方案 $a[i]+a[j]$ 的路徑的唯一性,這樣即可避免産生同一個候選方案重複入隊的情況。

增加限定條件後的候選方案産生方式如圖(不難看出,已經變為了一棵樹):

POJ 2442 - Sequence - [小頂堆][優先隊列]

考慮到添加了該限定條件後,是否影響到算法的正确性:

考慮原算法在標明 $a[i] + b[j]$ 成為第 $k$ 小之後,原本會産生 $a[i+1] + b[j]$ 和 $a[i] + b[j+1]$ 兩個新的候選者。那麼增加限定條件後,是否會發生本來第 $k+1$ 小應當是 $a[i+1] + b[j]$ 但是現在卻沒有生成該候選方案的情況?

根據限定條件,可知 $a[i] + b[j]$ 該方案成為第 $k$ 小,必然是由于 $a[i] + b[j-1]$ 産生了它,往前依次類推必然是在某一時刻選擇了 $a[i] + b[0]$,而 $a[i] + b[0]$ 會産生兩個候選者 $a[i+1] + b[0]$ 和 $a[i] + b[1]$,由此可知選中方案 $a[i] + b[j]$,隻要 $j \ge 1$,那麼此時隊列裡必然曾經出現過 $a[i+1] + b[0]$。

是以,如果說標明 $a[i] + b[j]$ 成為第 $k$ 小同時堆中沒有 $a[i+1] + b[j]$,那麼此時堆中必然存在 $a[i+1] +b[0],a[i+1] +b[1], \cdots ,a[i+1] +b[j-1]$ 中的某一個。顯然,存在比 $a[i+1] + b[j]$ 還小的方案,肯定不會選到 $a[i+1] + b[j]$,是以增加該限定條件不影響算法正确性。

時間複雜度:

由于我們每次擷取并删除一個堆頂,最多往堆中插入兩個新的元素,是以每一次堆的大小最多增加 $1$,又因為最多隻有 $N$ 次出堆操作且堆初始為空,是以堆的規模最大為 $O(n)$。

是以每次push和pop都是 $O(\log n)$ 的複雜度,而最多做 $O(n)$ 次push和pop,是以 $O(n \log n)$ 就能求得兩個序列的前 $n$ 小的和。

而對于 $M>2$ 的情況,可以先求出第一個序列和第二個序列的前 $n$ 小的和,作為一個新序列再去和第三個序列求前 $n$ 小的和,以此類推總的時間複雜度為 $O(mn \log n)$。

AC代碼:

手寫二叉堆版本(500ms):

#include<cstdio>
#include<iostream>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
typedef pair<int,int> pii;

const int maxm=100+5;
const int maxn=2000+5;

int m,n;
vector<int> a[maxm];
struct Plan{
    pii ptr;
    bool last; //記錄本方案是否是由ptr2++得到的
    int sum;
    Plan(){}
    Plan(int i,int j,pii _ptr,bool _last)
    {
        ptr=_ptr;
        last=_last;
        sum=a[i][ptr.first]+a[j][ptr.second];
    }
}init;

struct Heap
{
    int sz;
    Plan heap[3*maxn];
    void up(int now)
    {
        while(now>1)
        {
            int par=now>>1;
            if(heap[now].sum<heap[par].sum) //子節點小于父節點,不滿足小頂堆性質
            {
                swap(heap[par],heap[now]);
                now=par;
            }
            else break;
        }
    }
    void push(const Plan &x) //插入權值為x的節點
    {
        heap[++sz]=x;
        up(sz);
    }
    inline Plan top(){return heap[1];}
    void down(int now)
    {
        while((now<<1)<=sz)
        {
            int nxt=now<<1;
            if(nxt+1<=sz && heap[nxt+1].sum<heap[nxt].sum) nxt++; //取左右子節點中較小的
            if(heap[now].sum>heap[nxt].sum) //子節點小于父節點,不滿足小頂堆性質
            {
                swap(heap[now],heap[nxt]);
                now=nxt;
            }
            else break;
        }
    }
    void pop() //移除堆頂
    {
        heap[1]=heap[sz--];
        down(1);
    }
    void del(int p) //删除存儲在數組下标為p位置的節點
    {
        heap[p]=heap[sz--];
        up(p), down(p);
    }
    inline void clr(){sz=0;}
}h;

int main()
{
    int T;
    cin>>T;
    while(T--)
    {
        scanf("%d%d",&m,&n);
        for(int i=1;i<=m;i++)
        {
            a[i].clear();
            for(int j=1,x;j<=n;j++)
            {
                scanf("%d",&x);
                a[i].push_back(x);
            }
            sort(a[i].begin(),a[i].end());
        }

        int i=1;
        for(int j=2;j<=m;j++,i^=1)
        {
            init=Plan(i,j,make_pair(0,0),0);
            h.clr();
            h.push(init);
            a[i^1].clear();
            while(h.sz)
            {
                Plan now=h.top(); h.pop();

                a[i^1].push_back(now.sum);
                if(a[i^1].size()>=n) break;

                if(!now.last && now.ptr.first<n-1)
                    h.push(Plan(i,j,make_pair(now.ptr.first+1,now.ptr.second),0));
                if(now.ptr.second<n-1)
                    h.push(Plan(i,j,make_pair(now.ptr.first,now.ptr.second+1),1));
            }
        }
        for(int k=0;k<a[i].size();k++) printf("%d ",a[i][k]);
        printf("\n");
    }
}      

優先隊列版本(563ms):

#include<cstdio>
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<algorithm>
using namespace std;
typedef pair<int,int> pii;

const int maxm=100+5;
const int maxn=2000+5;

int m,n;
vector<int> a[maxm];
struct Plan{
    pii ptr;
    bool last; //記錄本方案是否是由ptr2++得到的
    int sum;
    Plan(){}
    Plan(int i,int j,pii _ptr,bool _last)
    {
        ptr=_ptr;
        last=_last;
        sum=a[i][ptr.first]+a[j][ptr.second];
    }
    bool operator<(const Plan& oth)const{return sum>oth.sum;}
}init;

priority_queue<Plan> Q;

int main()
{
    int T;
    cin>>T;
    while(T--)
    {
        scanf("%d%d",&m,&n);
        for(int i=1;i<=m;i++)
        {
            a[i].clear();
            for(int j=1,x;j<=n;j++)
            {
                scanf("%d",&x);
                a[i].push_back(x);
            }
            sort(a[i].begin(),a[i].end());
        }

        int i=1;
        for(int j=2;j<=m;j++,i^=1)
        {
            init=Plan(i,j,make_pair(0,0),0);
            while(!Q.empty()) Q.pop();
            Q.push(init);
            a[i^1].clear();
            while(!Q.empty())
            {
                Plan now=Q.top(); Q.pop();

                a[i^1].push_back(now.sum);
                if(a[i^1].size()>=n) break;

                if(!now.last && now.ptr.first<n-1)
                    Q.push(Plan(i,j,make_pair(now.ptr.first+1,now.ptr.second),0));
                if(now.ptr.second<n-1)
                    Q.push(Plan(i,j,make_pair(now.ptr.first,now.ptr.second+1),1));
            }
        }
        for(int k=0;k<a[i].size();k++) printf("%d ",a[i][k]);
        printf("\n");
    }
}