天天看点

大数运算的乘法优化

关于大整数运算的手算算法,普通程序员只要锲而不舍地调试,最后总能做出来。这里着重讨论一下大数乘法的优化算法。

普通手算算法,相当于二项式乘法展开,它的时间复杂度是 $O(n^2)$。

$(ax+b)(cx+d) = ac\,x^2 + (ad+bc)\,x + bd$

这样依次计算ac, ad, bc, bd需要4次乘法。

因为 (ad+bc) = (a+b)(c+d) - ac - bd,所以只要计算 (a+b)(c+d)、ac、bd 三个乘积,就可以减少一次乘法(这就是 Karatsuba 算法,递归应用后时间复杂度约为 $O(n^{\log_2 3}) \approx O(n^{1.585})$),因此速度更快。但是这是一个递归算法,计算过程中会生成更多的临时变量。而这些临时变量都是大整数,需要仔细进行内存分配和回收,实际上还是很复杂的。这里给出示例代码,供大家参考。

大整数的数据类型说明:

#ifndef BIG_H
#define BIG_H


/**
 * Signed big integer.
 *
 * The magnitude lives in a separately allocated `imp` object that `data`
 * points to; `sign` carries the sign (0 = non-negative, non-zero = negative,
 * as used by neg(), toint() and cmp()).  A default-constructed `big` has
 * `data == 0`, i.e. it holds no value at all.
 */
class big {
public:
        /** Magnitude: a doubly linked list of limbs plus bookkeeping fields. */
        struct imp {
                /** One limb of the magnitude and its list links. */
                struct  list {
                        unsigned int val;
                        struct list *prev;
                        struct list *next;
                };
                int n;          // number of nodes in the list
                list *numbers;  // first node of the list
                list *rear;     // last node of the list
                int link;       // share count; the disabled inline destructor
                                // below frees the imp when it drops to 0

                // In-place arithmetic (defined outside this header).
                imp& operator+=(imp &x);
                imp& operator-=(imp &x);
                imp& operator*=(imp &x);

                // Division helpers (defined outside this header).
                imp& divas(imp &x);
                imp& div(imp &x, imp &r);

                // Magnitude comparison; big::cmp() treats the result as
                // negative / zero / positive.
                int cmp(imp& x);

                // Arithmetic returning a reference to the result object.
                imp& operator+(imp &x);
                imp& operator-(imp &x);
                imp& operator*(imp &x);

                // Storage management (defined outside this header).
                void del();
                void dup(imp &x);
                void free();
        };

        int sign;               // 0 = non-negative, non-zero = negative
        struct imp *data;       // magnitude, or 0 when no value is held

public:
        /** Creates a `big` that holds no value (`data == 0`). */
        big(){sign=0;data=0;}
        big(int a[], int n);
        big(int num);
        big(const big &x);
        /** Flips the sign, but only when a non-empty magnitude is present. */
        void neg() { if (data && data->n) sign = !sign; }
        // Earlier inline destructor, kept for reference:
//      ~big() { if(data) { if(--data->link==0) data->free(); }}
        ~big();
        big& operator=(const big &a);
        big& operator+=(const big& a);
        big& operator-=(const big& a);
        big& operator*=(const big& a);
        big& operator/=(const big& a);
        big& operator%=(const big& a);

        big operator+(const big& a);
        big operator-(const big& a);
        big operator*(const big& a);
        big operator/(const big& a);
        big operator%(const big& a);

        big div(const big& a, big &r);

        int tostring(char *s, int n);

        /**
         * Returns the value of the rear node only, negated when `sign` is
         * set; higher nodes are ignored.
         *
         * @return -1 when no value is held (`data == 0`; indistinguishable
         *         from a stored -1), 0 when the list is empty, otherwise the
         *         signed rear-node value.
         *
         * A non-positive node count is treated as "empty" so that every path
         * returns a value (the previous version fell off the end of the
         * function for `n < 0`).
         */
        int toint() const {
                if (data==0) return -1;
                if (data->n<=0) return 0;
                const unsigned int low = data->rear->val;
                // Negate in unsigned arithmetic first, exactly as the
                // expression `-val` on an unsigned operand does.
                return sign==0 ? static_cast<int>(low)
                               : static_cast<int>(0u - low);
        }

        /**
         * Three-way comparison with `a`.
         *
         * When both operands hold data: equal signs compare magnitudes
         * (result negated for negative values); different signs yield
         * `a.sign - sign`.  When either operand holds no data the result is
         * 0 only for self-comparison and -1 otherwise.
         */
        int cmp(const big &a) const {
                if (data && a.data){
                        if (sign == a.sign)  {
                                if(sign==0)return data->cmp(*a.data);
                                return - data->cmp(*a.data);
                        }
                        else return a.sign - sign;
                }else if(this ==&a) return 0;
                else return -1;
        }
        int operator>(const big &a) const {return cmp(a)>0;}
        int operator<(const big &a) const {return cmp(a)<0;}
        int operator==(const big &a) const {return cmp(a)==0;}
        int operator>=(const big &a) const {return cmp(a)>=0;}
        int operator<=(const big &a) const {return cmp(a)<=0;}
        /** Unequal unless same object; value-less operands always differ. */
        int operator!=(const big &a) const {
                if (this==&a)return 0;
                if (!data || !a.data) return 1;
                return cmp(a)!=0;
        }
        /** Returns a negated copy; the operand itself is left untouched. */
        big operator-() const {big b; b= *this; b.neg(); return b; }
};

#endif

           

使用上面优化的乘法示例代码:

#include <stdio.h>
#include <stdlib.h>
#include "xyz.cpp"


// Shorthand for the nested magnitude type used by the free functions below.
typedef big::imp imp;

// Shared scratch buffer that main() hands to big::tostring() before printing.
char buf[128];

/*
 * Rounds n up to a power of two.
 *
 * All bits below the highest set bit of n are filled in and one is added,
 * which gives the power of two just above n.  When that equals 2*n, n was
 * already a power of two and is returned unchanged.  populate(0) yields 1.
 * The addition is done in int, so n is expected to stay below 2^30.
 */
static int populate(int n)
{
        int filled = n | (n >> 1);
        for (int shift = 2; shift <= 16; shift *= 2)
                filled |= filled >> shift;

        const int above = filled + 1;
        return (above == n + n) ? n : above;
}

/*
 * Recursive multiplication of two magnitudes using the three-product
 * identity (ad+bc) = (a+b)(c+d) - ac - bd.
 *
 * l, r: operands.  Their node lists are cut in place while sub-products are
 *       computed and are re-joined before the function returns.
 * Returns an imp (by value) whose nodes are newly obtained; in the base case
 * it always has 3 or 4 nodes, even when the leading ones are zero.
 *
 * getnode(), freenode() and UNIT are defined outside this block.
 */
imp arith_mul(imp &l, imp &r)
{
        int nl, nr;
        imp im;
        big::imp::list *node;
        int debug=0;    // set to non-zero to trace the recursion on stdout

        nl = l.n; nr= r.n;

        if(debug)printf("l,r =%d, %d\n", nl, nr);
        // Base case: both operands have at most two nodes.
        if(nl<=2 && nr<=2) {
                int a, b, c, d;
                int ac, bd, adbc, x;

                // a/c come from the first node, b/d from the last node; a
                // one-node operand has first == last, so its high part is
                // forced to 0 below.
                a = l.numbers->val;
                b = l.rear->val;
                c = r.numbers->val;
                d = r.rear->val;
                if (l.n==1) a=0;
                if (r.n==1) c=0;
                // Three multiplications instead of four.
                // NOTE(review): all products are computed in int; this
                // assumes UNIT is small enough that (a+b)*(c+d) fits - confirm.
                ac = a*c;
                bd = b*d;
                x = (a+b)*(c+d);
                adbc = x - ac - bd;

                // The result list is built from the rear towards the front:
                // bd % UNIT becomes the last node ...
                node = getnode();
                node->val = bd % UNIT;
                node->next= node->prev = NULL;
                im.numbers=im.rear = node;
                im.link=1;

                // ... the carry of bd goes into the middle term, which is
                // prepended ...
                adbc +=  bd / UNIT;
                node =getnode();
                node->val = adbc % UNIT;
                node->prev= NULL;
                node->next = im.numbers;
                im.numbers->prev =node;
                im.numbers = node;

                // ... and the carry of the middle term goes into ac.
                ac += adbc / UNIT;
                node =getnode();
                node->val = ac % UNIT;
                node->prev= NULL;
                node->next = im.numbers;
                im.numbers->prev =node;
                im.numbers = node;
                im.n=3;

                // A fourth node is added only when ac overflows one node.
                ac /= UNIT;
                if (ac==0) return im;
                node =getnode();
                node->val = ac;
                node->prev= NULL;
                node->next = im.numbers;
                im.numbers->prev =node;
                im.numbers = node;
                im.n++;
                return im;
        }

        imp left, right;
        imp tail;

        // Order the operands so that `left` is the one with fewer nodes.
        // NOTE(review): these look like member-wise copies that share the
        // operands' nodes (no imp copy-assignment is visible here) - confirm.
        left = l;
        right = r;
        if (nl > nr) {
                left = r;
                right = l;
                nl = left.n;
                nr = right.n;
        }

        int half;
        int i;
        // Split size: half of the longer operand's node count rounded up to
        // a power of two.
        half = populate(nr)>>1 ;
        if (nl <= half ) {
                // The shorter operand fits in one half: only `right` is
                // split, and the result is left*c shifted by `half` nodes
                // plus left*d.
                imp c,d;
                imp lc, ld;
                // d = the last `half` nodes of right.
                d.rear = right.rear;
                d.numbers= d.rear;
                for(i=1; i<half; i++) {
                        d.numbers = d.numbers->prev;
                }
                d.n =half;

                // c = the remaining leading nodes; the list is cut between
                // c and d for the duration of the recursive calls.
                c = right;
                c.rear = d.numbers->prev;
                c.rear->next=NULL;
                d.numbers->prev = NULL;
                c.n-= half;

                lc= arith_mul(left, c);
                ld = arith_mul(left, d);
                if(debug)printf("1/2-mul(%d): lc %d, ld %d\n", half, lc.n, ld.n);

                // tail = the last `half` nodes of ld; they pass into the
                // result unchanged.
                tail.rear = ld.rear;
                tail.numbers = ld.rear;
                for(i=1; i<half; i++) {
                        tail.numbers= tail.numbers->prev;
                }


                // Detach tail from ld, leaving ld with its leading nodes.
                ld.rear = tail.numbers->prev;
                ld.rear->next = NULL;
                tail.numbers->prev= NULL;
                ld.n -= half;

                // Add the leading part of ld into lc, then append tail.
                lc += ld;
                lc.rear->next = tail.numbers;
                tail.numbers->prev = lc.rear;
                lc.rear = tail.rear;
                lc.n+= half;

                // Release what is left of ld (imp::del is defined elsewhere)
                // and re-join the two pieces of `right`.
                ld.del();
                c.rear->next = d.numbers;
                d.numbers->prev = c.rear;
                return lc;
        }
        else {
                // Both operands are split: left = a:b, right = c:d, where b
                // and d are the last `half` nodes of each list.
                imp a,b, c,d;

                d.rear = right.rear;
                d.numbers= d.rear;
                for(i=1; i<half; i++) {
                        d.numbers = d.numbers->prev;
                }
                d.n =half;

                c = right;
                c.rear = d.numbers->prev;
                c.rear->next=NULL;
                d.numbers->prev = NULL;
                c.n-= half;

                b.rear = left.rear;
                b.numbers= b.rear;
                for(i=1; i<half; i++) {
                        b.numbers = b.numbers->prev;
                }
                b.n =half;

                a = left;
                a.rear = b.numbers->prev;
                a.rear->next=NULL;
                b.numbers->prev = NULL;
                a.n-= half;


                // The three recursive products: ac, bd and (a+b)(c+d).
                imp x_ac= arith_mul(a, c);
                imp x_bd= arith_mul(b, d);
                if(debug)printf("mul(%d): ac %d, bd %d\n", half, x_ac.n, x_bd.n);
                imp &a_ab = a+b;
                imp &a_cd = c+d;
                imp x = arith_mul(a_ab, a_cd);
                if(debug)printf("mul(%d): x %d\n", half, x.n);
                // The two sums are only needed for the product above.
                a_ab.free();
                a_cd.free();

                // A temporary leading node with value 1 is put in front of x
                // for the two subtractions and taken off again afterwards.
                // NOTE(review): presumably this stops imp::operator-= from
                // shortening x or changing its node count - confirm.
                node = getnode();
                node->val=1;
                node->next= x.numbers;
                node->prev = NULL;
                x.numbers->prev = node;
                x.numbers= node;
                x.n++;

                // x = (a+b)(c+d) - ac - bd, i.e. the middle term ad+bc.
                x -= x_ac;
                x -= x_bd;

                x.numbers= x.numbers->next;
                x.numbers->prev= NULL;
                x.n--;
                freenode(node);


                // tail = the last `half` nodes of bd (final trailing nodes
                // of the result); the leading part of bd is added into x.
                tail.rear = x_bd.rear;
                tail.numbers = x_bd.rear;
                for(i=1; i<half; i++) {
                        tail.numbers= tail.numbers->prev;
                }
                x_bd.rear = tail.numbers->prev;
                x_bd.rear->next = NULL;
                tail.numbers->prev= NULL;
                x_bd.n -= half;

                x += x_bd;
                x_bd.del();

                // m = the last `half` nodes of x; they are linked in front
                // of tail, so tail now holds the last 2*half result nodes.
                imp m;
                m.rear = x.rear;
                m.numbers = x.rear;
                for(i=1; i<half; i++) {
                        m.numbers= m.numbers->prev;
                }
                x.rear = m.numbers->prev;
                x.rear->next = NULL;
                m.numbers->prev= NULL;
                x.n -= half;
                m.rear->next =tail.numbers;
                tail.numbers->prev = m.rear;
                tail.numbers = m.numbers;

                // The leading part of x is added into ac, then the 2*half
                // trailing nodes are appended.
                x_ac += x;
                x_ac.rear->next =tail.numbers;
                tail.numbers->prev = x_ac.rear;
                x_ac.rear = tail.rear;
                x_ac.n += half*2;
                x.del();


                // Re-join the operand lists that were cut above.
                a.rear->next = b.numbers;
                b.numbers->prev = a.rear;
                c.rear->next = d.numbers;
                d.numbers->prev = c.rear;
                return x_ac;
        }

        // Not reachable: both branches above return.
        printf("error\n");
        return im;
}

/*
 * Multiplies two magnitudes and returns the product as a separately obtained
 * imp with link set to 1.
 *
 * Operands with at most one node are delegated to imp::operator*; everything
 * else goes through arith_mul().  The result head is obtained from
 * gethead() only after the early-return check, so the short-operand path no
 * longer abandons an unused head.
 */
imp& mul(imp &l, imp &r)
{
        if(l.n<=1 || r.n<=1) return l*r;

        imp& im= *gethead();
        // Struct copy of the by-value result; this overwrites link, so it
        // is set again afterwards.
        im = arith_mul(l, r);
        im.link=1;
        return im;
}

/*
 * Demo driver: builds two multi-node values, multiplies them once with
 * big::operator* and once with mul(), and prints both products so they can
 * be compared.
 *
 * Every tostring() call is given the real size of `buf`; the last call used
 * to pass 16*128 for the 128-byte buffer.
 */
int main()
{
        big a;
        big b;
        big c;
        const int bufsize = static_cast<int>(sizeof buf);

        b=20000022;
        c = 12345678;
        b = b*b - 80000000;
        b*= 100000000;
        b*= 100000000;
        b+=99010101;
        b.tostring(buf,bufsize);
        printf("%s\n", buf);
        c *=c;
        c/=10000;
        c.tostring(buf,bufsize);
        printf("%s\n", buf);

        // Reference product via the class's own operator*.
        a = b*c;
        a.tostring(buf,bufsize);
        printf("%s\n", buf);


        // Same product via mul(): a's own magnitude is freed and replaced
        // by the imp that mul() returns.
        a=0;
        a.data->free();
        a.data = & mul(*c.data, *b.data);
        a.tostring(buf,bufsize);
        printf("%s\n", buf);


        return 0;
}

           

这里运算结果是:

400,0008,0000,0484,0000,0000,9901,0101
152,4157,6527
6,0966,4280,4068,5985,2303,9218,9070,0313,3749,9227
6,0966,4280,4068,5985,2303,9218,9070,0313,3749,9227

           

继续阅读