天天看点

bzoj 2956: 模积和 (反演)

2956: 模积和

Time Limit: 10 Sec   Memory Limit: 128 MB

Submit: 1276   Solved: 574

[ Submit][ Status][ Discuss]

Description

 求∑∑((n mod i)*(m mod j))其中1<=i<=n,1<=j<=m,i≠j。

  

Input

第一行两个数n,m。

Output

  一个整数表示答案mod 19940417的值

Sample Input

3 4

Sample Output

1

样例说明

  答案为(3 mod 1)*(4 mod 2)+(3 mod 1) * (4 mod 3)+(3 mod 1) * (4 mod 4) + (3 mod 2) * (4 mod 1) + (3 mod 2) * (4 mod 3) + (3 mod 2) * (4 mod 4) + (3 mod 3) * (4 mod 1) + (3 mod 3) * (4 mod 2) + (3 mod 3) * (4 mod 4) = 1

数据规模和约定

  对于100%的数据n,m<=10^9。

HINT

Source

中国国家队清华集训 2012-2013 第一天

[ Submit][ Status][ Discuss]

题解:数论+乘法逆元

刚开始没看到i!=j 这个条件,所以直接将式子化成了sigma(i=1..n)(n mod i)*sigma(i=1..m)(m mod i)

就可以把两部分分开计算,那么这道题就变成了CQOI余数之和

sigma (i=1..n) n mod i

=sigma(i=1..n) n-(floor(n/i)*i)

因为floor(n/i)的取值是一段一段的,所以可以在O(sqrt(n))的时间内出解。

然后我们考虑从答案中除去不符合条件的,即sigma(i=1..min(n,m) (n mod i)*(m mod i)

=sigma(i=1..min(n,m))(n-floor(n/i)*i)*(m-floor(m/i)*i)

=sigma(i=1..min(n,m))n*m-m*floor(n/i)*i-n*floor(m/i)*i-floor(n/i)*floor(m/i)*i*i

可以在O(sqrt(n)+sqrt(m))的时间内出解

需要用到平方和公式sum=n*(n+1)*(2n+1)/6

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdio>
#define p 19940417
#define LL long long
using namespace std;
LL n,m,inv1;
LL quickpow(LL num,int x)  
{  
    LL base=num%p; LL ans=1;  
    while (x) {  
        if (x&1) ans=ans*base%p;  
        x>>=1;  
        base=base*base%p;  
    }  
    return ans;  
}  
LL calc(LL n,LL k)
{
	LL i,j; LL ans=0;
	for (i=1,j=0;i<=k;i=j+1) {
		if (n/i!=0) j=min(n/(n/i),k);
		else j=k;
		ans+=((j-i+1)*(i+j)/2)%p*(n/i)%p;
		ans%=p;
	}
	return ans;
}
void exgcd(LL a,LL b,LL &x,LL &y)    
{    
    if (!b) {    
        x=1; y=0; return;    
    }    
    exgcd(b,a%b,x,y);    
    LL t=y;    
    y=x-(a/b)*y;    
    x=t;    
}    
LL inv(LL a,LL b)    
{    
    LL x,y;    
    exgcd(a,b,x,y);    
    return x;    
}    
LL sum(LL n1)
{
    return (LL)n1*(n1+1)%p*(2*n1+1)%p*inv1%p;
    //return n1*(n1+1)*(2*n1+1)/6;
}
LL calc1(LL k)
{
	LL i,j; LL ans=n*m%p*k%p;
	for (i=1,j=0;i<=k;i=j+1) {
		j=min(n/(n/i),m/(m/i));
		j=min(j,k);
		LL t=m*(n/i)+n*(m/i); t=(t%p+p)%p;
		t=((j-i+1)*(i+j)/2)%p*t;
		LL t1=sum(j)-sum(i-1); t1=(t1%p+p)%p;
		ans=ans+((n/i)*(m/i)%p*t1%p-t+p)%p;
		ans=(ans%p+p)%p;
	}
	return ans;
}
int main()
{
	freopen("a.in","r",stdin);
	freopen("my.out","w",stdout);
	scanf("%d%d",&n,&m);
	inv1=inv(6,p);
	LL t1=calc(n,n); t1=((LL)n*n-t1)%p;
	LL t2=calc(m,m); t2=((LL)m*m-t2)%p;
	LL t3=calc1(min(n,m));
	//cout<<t1*t2%p<<" "<<t3<<endl;
	LL ans=t1*t2%p-t3; ans=(ans%p+p)%p;
	printf("%I64d\n",ans);
}