大致题意:告诉你范围A和B,让你求在两个范围内,有多少对数字可以使得二者按位与大于C或者异或小于C。
一个比较常规的数位dp,然而比赛的时候由于自己复杂度计算错误,还写了好久的优化,最后发现不优化也能过。
我们令dp[len][x][y][lim1][lim2]表示在二进制下,当前长度为len的时候,第一个条件的状态为x,第二个条件状态为j,第一个数字的限制情况为lim1,第二个为lim2时候的方案数。这里,我们条件总共有3个状态:0表示可能满足,1表示已经满足,2表示已经不满足。这是因为在二进制下的大于与小于关系,如果从高位往地位走,在某一位确定大小之后,后面就不会改变结果,而数位dp的过程本身就是高位到低位走的。
然后,转移的话,直接枚举两个数字在第len位的取值0或者1,然后根据当前两个条件的状态判断两个数字的取值是否合法,以及条件状态的变化。最后,结束条件是长度为0的时候,如果两个条件之一满足,那么当前方案合法。
最后说一下可有可无的优化,如果到某一个长度的时候,发现已经有一个条件满足了,那么我们可以不用继续算下去,根据两个数字的限制情况可以直接计算方案数。另外,如果某个时刻,两个条件都已经不满足,那么直接返回方案数0。
最后的最后,由于数字取值不能是0,所以要减去所有包含0的答案。具体见代码:
#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <vector>
#define N 2010
#define INF 0x3f3f3f3f3f3f3f3fll
#define eps 1e-5
#define pi 3.141592653589793
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define ls (node<<1)
#define rs (node<<1|1)
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
LL dp[32][3][3][2][2],pw[63],n1[32],n2[32];
int num1[32],num2[32],num3[32];
LL dfs(int len,int c1,int c2,bool lim1,bool lim2)
{
if (c1==2&&c2==2) return 0;
if (len==0) return c1==1||c2==1;
if (c1==1||c2==1)
{
LL a=pw[len],b=pw[len];
if (lim1) a=n1[len]+1;
if (lim2) b=n2[len]+1;
return a*b;
}
if (~dp[len][c1][c2][lim1][lim2]) return dp[len][c1][c2][lim1][lim2];
int up1=lim1?num1[len]:1,up2=lim2?num2[len]:1; LL res=0;
for(int i=0;i<=up1;i++)
for(int j=0;j<=up2;j++)
{
int x=i&j,y=i^j,nc1=c1,nc2=c2;
if ((!c1&&x<num3[len]||c1==2)&&(!c2&&y>num3[len]||c2==2)) continue;
if (!c1) if (x<num3[len]) nc1=2; else if (x>num3[len]) nc1=1; else nc1=0;
if (!c2) if (y>num3[len]) nc2=2; else if (y<num3[len]) nc2=1; else nc2=0;
res+=dfs(len-1,nc1,nc2,lim1&&i==up1,lim2&&j==up2);
}
dp[len][c1][c2][lim1][lim2]=res;
return res;
}
inline LL cal(LL A,LL B,LL C)
{
memset(num1,0,sizeof(num1));
memset(num2,0,sizeof(num2));
memset(num3,0,sizeof(num3));
int t1=0,t2=0,t3=0;
while(A) num1[++t1]=A&1,A>>=1;
while(B) num2[++t2]=B&1,B>>=1;
while(C) num3[++t3]=C&1,C>>=1;
int t=max(t1,max(t2,t3));
for(int i=t;i;i--)
{
n1[i]=n2[i]=0;
for(int j=i;j;j--)
{
n1[i]=n1[i]<<1|num1[j];
n2[i]=n2[i]<<1|num2[j];
}
}
return dfs(t,0,0,1,1);
}
void init()
{
pw[0]=1;
for(int i=1;i<31;i++)
pw[i]=pw[i-1]*2;
}
int main(){
int T;
sc(T);
init();
while(T--){
memset(dp,-1,sizeof(dp));
int A,B,C; sccc(A,B,C);
printf("%lld\n",cal(A,B,C)-min(A,C-1)-min(B,C-1)-1);
}
return 0;
}