天天看点

2019牛客多校赛 第七场 H Pair(数位dp)

2019牛客多校赛 第七场 H Pair(数位dp)
2019牛客多校赛 第七场 H Pair(数位dp)

大致题意:告诉你范围A和B,让你求在两个范围内,有多少对数字可以使得二者按位与大于C或者异或小于C。

一个比较常规的数位dp,然而比赛的时候由于自己复杂度计算错误,还写了好久的优化,最后发现不优化也能过。

我们令dp[len][x][y][lim1][lim2]表示在二进制下,当前长度为len的时候,第一个条件的状态为x,第二个条件状态为j,第一个数字的限制情况为lim1,第二个为lim2时候的方案数。这里,我们条件总共有3个状态:0表示可能满足,1表示已经满足,2表示已经不满足。这是因为在二进制下的大于与小于关系,如果从高位往地位走,在某一位确定大小之后,后面就不会改变结果,而数位dp的过程本身就是高位到低位走的。

然后,转移的话,直接枚举两个数字在第len位的取值0或者1,然后根据当前两个条件的状态判断两个数字的取值是否合法,以及条件状态的变化。最后,结束条件是长度为0的时候,如果两个条件之一满足,那么当前方案合法。

最后说一下可有可无的优化,如果到某一个长度的时候,发现已经有一个条件满足了,那么我们可以不用继续算下去,根据两个数字的限制情况可以直接计算方案数。另外,如果某个时刻,两个条件都已经不满足,那么直接返回方案数0。

最后的最后,由于数字取值不能是0,所以要减去所有包含0的答案。具体见代码:

#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#define N 2010
#define INF 0x3f3f3f3f3f3f3f3fll
#define eps 1e-5
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define ls (node<<1)
#define rs (node<<1|1)
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<"      :   "<<x<<endl
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

LL dp[32][3][3][2][2],pw[63],n1[32],n2[32];
int num1[32],num2[32],num3[32];

LL dfs(int len,int c1,int c2,bool lim1,bool lim2)
{
    if (c1==2&&c2==2) return 0;
    if (len==0) return c1==1||c2==1;
    if (c1==1||c2==1)
    {
        LL a=pw[len],b=pw[len];
        if (lim1) a=n1[len]+1;
        if (lim2) b=n2[len]+1;
        return a*b;
    }
    if (~dp[len][c1][c2][lim1][lim2]) return dp[len][c1][c2][lim1][lim2];
    int up1=lim1?num1[len]:1,up2=lim2?num2[len]:1; LL res=0;
    for(int i=0;i<=up1;i++)
        for(int j=0;j<=up2;j++)
        {
            int x=i&j,y=i^j,nc1=c1,nc2=c2;
            if ((!c1&&x<num3[len]||c1==2)&&(!c2&&y>num3[len]||c2==2)) continue;
            if (!c1) if (x<num3[len]) nc1=2; else if (x>num3[len]) nc1=1; else nc1=0;
            if (!c2) if (y>num3[len]) nc2=2; else if (y<num3[len]) nc2=1; else nc2=0;
            res+=dfs(len-1,nc1,nc2,lim1&&i==up1,lim2&&j==up2);
        }
    dp[len][c1][c2][lim1][lim2]=res;
    return res;
}

inline LL cal(LL A,LL B,LL C)
{
    memset(num1,0,sizeof(num1));
    memset(num2,0,sizeof(num2));
    memset(num3,0,sizeof(num3));
    int t1=0,t2=0,t3=0;
    while(A) num1[++t1]=A&1,A>>=1;
    while(B) num2[++t2]=B&1,B>>=1;
    while(C) num3[++t3]=C&1,C>>=1;
    int t=max(t1,max(t2,t3));
    for(int i=t;i;i--)
    {
        n1[i]=n2[i]=0;
        for(int j=i;j;j--)
        {
            n1[i]=n1[i]<<1|num1[j];
            n2[i]=n2[i]<<1|num2[j];
        }
    }
    return dfs(t,0,0,1,1);
}

void init()
{
    pw[0]=1;
    for(int i=1;i<31;i++)
        pw[i]=pw[i-1]*2;
}

int main(){
    int T;
    sc(T);
    init();
    while(T--){
        memset(dp,-1,sizeof(dp));
        int A,B,C; sccc(A,B,C);
        printf("%lld\n",cal(A,B,C)-min(A,C-1)-min(B,C-1)-1);
    }
    return 0;
}