二叉搜尋樹 題解
題意
(~~~~) 給出一個排列,并允許将 ([l,r]) 自行重排,求在這之後将所有數依次加入一棵 BST 的最小總深度。
(~~~~) (1leq nleq 10^5,1leq r-l+1leq 200)
題解
(~~~~) 首先大喊三聲:“Reanap yyds!”
(~~~~) 首先有一種樸素的方法,即枚舉 ([l,r]) 内的排列,然後依次插入二叉搜尋樹。
(~~~~) 考慮一種插入二叉搜尋樹的辦法,首先用
set
記錄一下已經插入的所有數,若新加入的某個數 (w) 在 (x) 和 (y) 之間,則這個數要麼是 (x) 的右兒子,要麼是 (y) 的左兒子,同時其左兒子的值 (in) ([x+1,w-1]) ,右兒子的值 (in) ([w+1,y]) 。為了友善維護,我們預設為 (x) 的兒子,則每次加入 (w) 後 (dep_w leftarrow dep_x+1) ,同時 (dep_x leftarrow dep_x+1) 以保證之後加入的正确性,這樣可以對任何已知序列做到 (mathcal{O(nlog n)}) 求出答案。
(~~~~) 現在考慮重排的問題,在重排之前 ([1,l-1]) 已經插入,同時分割出了若幹區間。同時每個區間都是與其他區間互相獨立的。是以我們現在考慮怎麼在每個區間插入值使得目前及之後插入的代價最小。
(~~~~) 這裡我們使用代價提前計算的技巧,即對于一次在 ([x,y]) 之間插入 (w) ,則将 ([x,w-1]) 和 ([w+1,y]) 的深度提前 (+1) ,這樣的話我們可以用一個區間DP來計算每個區間。
(~~~~) 定義 (dp_{l,r}) ,表示在某個小區間的 ([l,r]) 插入後的最小代價。則枚舉轉移點可以得到 (dp_{l,r}=max_{i=l}^r dp_{l,i-1}+dp_{i+1,r}+r-l) ,同時記錄轉移點以還原重排後的序列再插入即可。
代碼
檢視代碼
#include <set>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;
int n;
set<int>S;
set<int>::iterator it;
int arr[100005];
int dep[100005];
vector <int> V[100005];
ll Solve()
{
S.clear();S.insert(0);
ll ret=0;
for(int i=1;i<=n;i++) dep[i]=0;dep[0]=0;
for(int i=1;i<=n;i++)
{
int x=arr[i];
it=prev(S.lower_bound(x));
dep[x]=dep[*it]+1;dep[*it]++;ret+=dep[x];S.insert(x);
}
return ret;
}
int tot=0,cnt=0;
int dp[205][205],P[205],from[205][205],ord[205],nxt[100005];
void Rev(int l,int r)
{
if(l>r) return;
if(l==r)
{
ord[++tot]=P[l];
return;
}
ord[++tot]=P[from[l][r]];
Rev(l,from[l][r]-1); Rev(from[l][r]+1,r);
}
void DP(int R)
{
memset(dp,0,sizeof(dp));
for(int len=1;len<=R;len++)
{
for(int l=1;l+len-1<=R;l++)
{
int r=l+len-1;dp[l][r]=1e9;
for(int x=l;x<=r;x++)
{
if(dp[l][x-1]+dp[x+1][r]+(P[r+1]-1)-P[l-1]<dp[l][r])
{
dp[l][r]=dp[l][x-1]+dp[x+1][r]+(P[r+1]-1)-P[l-1];
from[l][r]=x;
}
}
}
}
Rev(1,R);
}
bool Beg[100005];
int l,r;
void Pre()
{
Beg[0]=true;S.insert(0);
for(int i=1;i<l;i++)
{
int x=arr[i];
it=prev(S.lower_bound(x));
Beg[x]=true;S.insert(x);
}
for(int i=l;i<=r;i++)
{
int x=arr[i];
it=prev(S.lower_bound(x));
V[*it].push_back(x);it++;
if(it==S.end()) nxt[x]=n+1;
else nxt[x]=*it;
}
for(int i=0;i<=n;i++)
{
if(Beg[i]&&!V[i].empty())
{
sort(V[i].begin(),V[i].end());
P[0]=i;cnt=0;for(int j=0;j<(int)V[i].size();j++) P[++cnt]=V[i][j];
P[cnt+1]=nxt[V[i][0]];
DP(cnt);
}
}
}
int main() {
scanf("%d",&n);
for(int i=1;i<=n;i++) scanf("%d",&arr[i]);
scanf("%d %d",&l,&r);
Pre();
for(int i=l;i<=r;i++) arr[i]=ord[i-l+1];
printf("%lld",Solve());
return 0;
}