Modular Exponentiation for high numbers in C++

Disclaimer: this page is a translation of a popular StackOverflow question and is provided under the CC BY-SA 4.0 license. If you use it, you must also follow the CC BY-SA license, cite the original URL and author information, and attribute it to the original authors (not me): StackOverflow original: http://stackoverflow.com/questions/2207006/

Date: 2020-08-27 22:32:09  Source: igfitidea

Tags: c++, modulo, integer-overflow, exponentiation

Question by Axel Magnuson

So I've been working recently on an implementation of the Miller-Rabin primality test. I am limiting it to the range of 32-bit numbers, because this is a just-for-fun project I am doing to familiarize myself with C++, and I don't want to have to work with anything 64-bit for a while. An added bonus is that the algorithm is deterministic for all 32-bit numbers, so I can significantly increase efficiency because I know exactly which witnesses to test.

For small numbers, the algorithm works exceptionally well. However, part of the process relies on modular exponentiation, that is, (num ^ pow) % mod. For example:

3 ^ 2 % 5 = 
9 % 5 = 
4

Here is the code I have been using for this modular exponentiation:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned test;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = (test * num) % mod;
        num = (num * num) % mod;
    }

    return test;

}

As you might have already guessed, problems arise when the arguments are all exceptionally large numbers. For example, if I want to test the number 673109 for primality, I will at one point have to find:

(2 ^ 168277) % 673109

Now, 2 ^ 168277 is an exceptionally large number, and somewhere in the process test overflows, which produces an incorrect result.

On the other hand, arguments such as

4000111222 ^ 3 % 1608

also evaluate incorrectly, for much the same reason.

Does anyone have suggestions for doing modular exponentiation in a way that prevents this overflow, or that manipulates it to produce the correct result? (The way I see it, overflow is just another form of modulo, that is, num % (UINT_MAX + 1).)
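
To make that last remark concrete: C++ defines unsigned overflow as wrapping modulo 2^N for an N-bit type, so an overflowed product is still correct modulo 2^32, just not modulo 673109. The helper names below are hypothetical, and the sketch assumes a 32-bit unsigned int (so std::uint32_t multiplication is not promoted to a wider signed type):

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helpers contrasting a wrapped 32-bit square with the exact one.
std::uint32_t square_wrapped(std::uint32_t x)
{
    return x * x;  // unsigned overflow is well defined: wraps modulo 2^32
}

std::uint64_t square_exact(std::uint32_t x)
{
    // (2^32 - 1)^2 < 2^64, so the exact square always fits in 64 bits.
    return static_cast<std::uint64_t>(x) * x;
}

// 673108 is congruent to -1 mod 673109, so its true square is 1 mod 673109.
void demo()
{
    const std::uint32_t m = 673109;
    const std::uint32_t x = 673108;
    assert(square_exact(x) % m == 1);    // exact product: correct residue
    assert(square_wrapped(x) % m != 1);  // wrapped product: wrong residue
}
```

Wrapping discards a multiple of 2^32, and since 673109 does not divide 2^32, that discarded amount changes the residue modulo 673109.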

Accepted answer by Steve Jessop

Exponentiation by squaring still "works" for modular exponentiation. Your problem isn't that 2 ^ 168277 is an exceptionally large number; it's that one of your intermediate results is a fairly large number (bigger than 2^32), because 673109 is bigger than 2^16.

So I think the following will do. It's possible I've missed a detail, but the basic idea works, and this is how "real" crypto code might do large modular exponentiation (although not with 32- and 64-bit numbers, but with bignums that never have to get bigger than 2 * log(modulus)):

  • Start with exponentiation by squaring, as you have.
  • Perform the actual squaring in a 64-bit unsigned integer.
  • Reduce modulo 673109 at each step to get back within the 32-bit range, as you do.
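
A minimal sketch of those three steps, assuming a C++11 compiler with std::uint64_t available. It keeps the question's mod_pow name and parameter order but switches to fixed-width types; it is an illustration, not the answerer's own code:

```cpp
#include <cassert>
#include <cstdint>

// Exponentiation by squaring with 64-bit intermediates. Assumes
// 0 < mod < 2^32: every reduced value is below 2^32, so the product
// of two of them always fits in a std::uint64_t.
std::uint32_t mod_pow(std::uint32_t num, std::uint32_t pow, std::uint32_t mod)
{
    assert(mod != 0 && "modulus must be nonzero");
    std::uint64_t result = 1 % mod;  // 0 when mod == 1
    std::uint64_t base = num % mod;  // reduce the base up front
    while (pow != 0) {
        if (pow & 1)
            result = (result * base) % mod;  // product < 2^64
        base = (base * base) % mod;          // square in 64 bits, then reduce
        pow >>= 1;
    }
    return static_cast<std::uint32_t>(result);  // result < mod, so no loss
}
```

With this version, mod_pow(2, 168277, 673109) no longer overflows, and mod_pow(4000111222u, 3, 1608) returns 736 because the base is reduced before the loop.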

Obviously that's a bit awkward if your C++ implementation doesn't have a 64-bit integer type, although you can always fake one.
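
If no 64-bit type is available, one alternative to faking one (a different technique from the one the answer describes) is to build the modular product out of doublings and additions, reducing after every step so that no intermediate value reaches 2 * mod. The function names are hypothetical, and the sketch assumes mod <= 2^31 so that those sums fit in 32 bits:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: (a * b) % mod using only 32-bit unsigned arithmetic,
// via double-and-add ("Russian peasant") multiplication. Assumes
// 0 < mod <= 2^31, so a sum of two values below mod cannot wrap.
std::uint32_t mul_mod32(std::uint32_t a, std::uint32_t b, std::uint32_t mod)
{
    assert(mod != 0 && mod <= 0x80000000u);
    a %= mod;
    b %= mod;
    std::uint32_t result = 0;
    while (b != 0) {
        if (b & 1) {
            result += a;  // < 2 * mod, fits in 32 bits
            if (result >= mod) result -= mod;
        }
        a += a;           // doubles a, also < 2 * mod
        if (a >= mod) a -= mod;
        b >>= 1;
    }
    return result;
}

// Exponentiation by squaring built on mul_mod32 instead of a 64-bit product.
std::uint32_t mod_pow32(std::uint32_t num, std::uint32_t pow, std::uint32_t mod)
{
    assert(mod != 0 && mod <= 0x80000000u);
    std::uint32_t result = 1 % mod;
    num %= mod;
    while (pow != 0) {
        if (pow & 1)
            result = mul_mod32(result, num, mod);
        num = mul_mod32(num, num, mod);
        pow >>= 1;
    }
    return result;
}
```

The trade-off is speed: each modular multiplication now costs up to 32 additions instead of a single hardware multiply.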

There's an example on slide 22 here: http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf, although it uses very small numbers (less than 2^16), so it may not illustrate anything you don't already know.

Your other example, 4000111222 ^ 3 % 1608, would work in your current code if you just reduce 4000111222 modulo 1608 before you start. 1608 is small enough that you can safely multiply any two mod-1608 numbers in a 32-bit int.
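
A sketch of that fix applied to the question's loop: the only real change is reducing num first. The name mod_pow_small is hypothetical, and the sketch assumes mod <= 65536, so any two reduced values multiply to at most 65535^2 < 2^32:

```cpp
#include <cassert>
#include <cstdint>

// The question's loop, with the base reduced before starting. Safe in
// 32 bits only while (mod - 1)^2 < 2^32, i.e. mod <= 65536.
std::uint32_t mod_pow_small(std::uint32_t num, std::uint32_t pow, std::uint32_t mod)
{
    assert(mod != 0 && mod <= 65536u);
    num %= mod;  // 4000111222 becomes 574 when mod == 1608
    std::uint32_t test = 1 % mod;
    for (; pow; pow >>= 1) {
        if (pow & 1)
            test = (test * num) % mod;  // both factors < mod
        num = (num * num) % mod;
    }
    return test;
}
```

For 673109 this is not enough, because 673108^2 already exceeds 2^32; that case needs the wider intermediate described above.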

Answer by clinux

I recently wrote something for this for RSA in C++, though it's a bit messy.

#include "BigInteger.h"
#include <iostream>
#include <sstream>
#include <stack>

BigInteger::BigInteger() {
    digits.push_back(0);
    negative = false;
}

BigInteger::~BigInteger() {
}

void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sum_n_carry = 0;
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size()) {
        n = b.digits.size();
    }
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size()) {
            a_digit = a.digits[i];
        }
        if (i < (int)b.digits.size()) {
            b_digit = b.digits[i];
        }
        sum_n_carry += a_digit + b_digit;
        c.digits[i] = (sum_n_carry & 0xFFFF);
        sum_n_carry >>= 16;
    }
    if (sum_n_carry != 0) {
        putCarryInfront(c, sum_n_carry);
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl;
}

void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) {
    int sub_n_borrow = 0;
    int n = a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    c.digits.resize(n);
    for (int i = 0; i < n; ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        sub_n_borrow += a_digit - b_digit;
        if (sub_n_borrow >= 0) {
            c.digits[i] = sub_n_borrow;
            sub_n_borrow = 0;
        } else {
            c.digits[i] = 0x10000 + sub_n_borrow;
            sub_n_borrow = -1;
        }
    }
    while (c.digits.size() > 1 && c.digits.back() == 0) {
        c.digits.pop_back();
    }
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl;
}

int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) {
    int n = (int)a.digits.size();
    if (n < (int)b.digits.size())
        n = (int)b.digits.size();
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == ";
    for (int i = n-1; i >= 0; --i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = 0;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        if (i < (int)b.digits.size())
            b_digit = b.digits[i];
        if (a_digit < b_digit) {
            //std::cout << "-1" << std::endl;
            return -1;
        } else if (a_digit > b_digit) {
            //std::cout << "+1" << std::endl;
            return +1;
        }
    }
    //std::cout << "0" << std::endl;
    return 0;
}

void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) {
    unsigned int mult_n_carry = 0;
    c.digits.clear();
    c.digits.resize(a.digits.size());
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        unsigned short a_digit = 0;
        unsigned short b_digit = b;
        if (i < (int)a.digits.size())
            a_digit = a.digits[i];
        mult_n_carry += (unsigned int)a_digit * b_digit;  // unsigned product avoids signed int overflow
        c.digits[i] = (mult_n_carry & 0xFFFF);
        mult_n_carry >>= 16;
    }
    if (mult_n_carry != 0) {
        putCarryInfront(c, mult_n_carry);
    }
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl;
}

void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) {
    b.digits.resize(a.digits.size() + times);
    for (int i = 0; i < times; ++i) {
        b.digits[i] = 0;
    }
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i + times] = a.digits[i];
    }
}

void BigInteger::shiftRight(BigInteger& a) {
    //std::cout << "shr " << a.toString() << " == ";
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        a.digits[i] >>= 1;
        if (i+1 < (int)a.digits.size()) {
            if ((a.digits[i+1] & 0x1) != 0) {
                a.digits[i] |= 0x8000;
            }
        }
    }
    //std::cout << a.toString() << std::endl;
}

void BigInteger::shiftLeft(BigInteger& a) {
    bool lastBit = false;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        bool bit = (a.digits[i] & 0x8000) != 0;
        a.digits[i] <<= 1;
        if (lastBit)
            a.digits[i] |= 1;
        lastBit = bit;
    }
    if (lastBit) {
        a.digits.push_back(1);
    }
}

void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) {
    BigInteger b;
    b.negative = a.negative;
    b.digits.resize(a.digits.size() + 1);
    b.digits[a.digits.size()] = carry;
    for (int i = 0; i < (int)a.digits.size(); ++i) {
        b.digits[i] = a.digits[i];
    }
    a.digits.swap(b.digits);
}

void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) {
    c.digits.clear();
    c.digits.push_back(0);
    BigInteger two("2");
    BigInteger e = b;
    BigInteger f("1");
    BigInteger g = a;
    BigInteger one("1");
    while (cmpWithoutSign(g, e) >= 0) {
        shiftLeft(e);
        shiftLeft(f);
    }
    shiftRight(e);
    shiftRight(f);
    while (cmpWithoutSign(g, b) >= 0) {
        g -= e;
        c += f;
        while (cmpWithoutSign(g, e) < 0) {
            shiftRight(e);
            shiftRight(f);
        }
    }
    e = c;
    e *= b;
    f = a;
    f -= e;
    d = f;
}

BigInteger::BigInteger(const BigInteger& other) {
    digits = other.digits;
    negative = other.negative;
}

BigInteger::BigInteger(const char* other) {
    digits.push_back(0);
    negative = false;
    BigInteger ten;
    ten.digits[0] = 10;
    const char* c = other;
    bool make_negative = false;
    if (*c == '-') {
        make_negative = true;
        ++c;
    }
    while (*c != 0) {
        BigInteger digit;
        digit.digits[0] = *c - '0';
        *this *= ten;
        *this += digit;
        ++c;
    }
    negative = make_negative;
}

bool BigInteger::isOdd() const {
    return (digits[0] & 0x1) != 0;
}

BigInteger& BigInteger::operator=(const BigInteger& other) {
    if (this == &other) // handle self assignment
        return *this;
    digits = other.digits;
    negative = other.negative;
    return *this;
}

BigInteger& BigInteger::operator+=(const BigInteger& other) {
    BigInteger result;
    if (negative) {
        if (other.negative) {
            result.negative = true;
            addWithoutSign(result, *this, other);
        } else {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = false;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = true;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        }
    } else {
        if (other.negative) {
            int a = cmpWithoutSign(*this, other);
            if (a < 0) {
                result.negative = true;
                subWithoutSign(result, other, *this);
            } else if (a > 0) {
                result.negative = false;
                subWithoutSign(result, *this, other);
            } else {
                result.negative = false;
                result.digits.clear();
                result.digits.push_back(0);
            }
        } else {
            result.negative = false;
            addWithoutSign(result, *this, other);
        }
    }
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator-=(const BigInteger& other) {
    BigInteger neg_other = other;
    neg_other.negative = !neg_other.negative;
    return *this += neg_other;
}

BigInteger& BigInteger::operator*=(const BigInteger& other) {
    BigInteger result;
    for (int i = 0; i < (int)digits.size(); ++i) {
        BigInteger mult;
        multByDigitWithoutSign(mult, other, digits[i]);
        BigInteger shift;
        shiftLeftByBase(shift, mult, i);
        BigInteger add;
        addWithoutSign(add, result, shift);
        result = add;
    }
    if (negative != other.negative) {
        result.negative = true;
    } else {
        result.negative = false;
    }
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl;
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator/=(const BigInteger& other) {
    BigInteger result, tmp;
    divideWithoutSign(result, tmp, *this, other);
    result.negative = (negative != other.negative);
    negative = result.negative;
    digits.swap(result.digits);
    return *this;
}

BigInteger& BigInteger::operator%=(const BigInteger& other) {
    BigInteger c, d;
    divideWithoutSign(c, d, *this, other);
    *this = d;
    return *this;
}

bool BigInteger::operator>(const BigInteger& other) const {
    if (negative) {
        if (other.negative) {
            return cmpWithoutSign(*this, other) < 0;
        } else {
            return false;
        }
    } else {
        if (other.negative) {
            return true;
        } else {
            return cmpWithoutSign(*this, other) > 0;
        }
    }
}

BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) {
    BigInteger zero("0");
    BigInteger one("1");
    BigInteger e = exponent;
    BigInteger base = *this;
    *this = one;
    while (cmpWithoutSign(e, zero) != 0) {
        //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl;
        if (e.isOdd()) {
            *this *= base;
            *this %= modulus;
        }
        shiftRight(e);
        base *= BigInteger(base);
        base %= modulus;
    }
    return *this;
}

std::string BigInteger::toString() const {
    std::ostringstream os;
    if (negative)
        os << "-";
    BigInteger tmp = *this;
    BigInteger zero("0");
    BigInteger ten("10");
    tmp.negative = false;
    std::stack<char> s;
    while (cmpWithoutSign(tmp, zero) != 0) {
        BigInteger tmp2, tmp3;
        divideWithoutSign(tmp2, tmp3, tmp, ten);
        s.push((char)(tmp3.digits[0] + '0'));
        tmp = tmp2;
    }
    while (!s.empty()) {
        os << s.top();
        s.pop();
    }
    /*
    for (int i = digits.size()-1; i >= 0; --i) {
        os << digits[i];
        if (i != 0) {
            os << ",";
        }
    }
    */
    return os.str();
}

And an example usage.

BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344");

// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344
a.powAssignUnderMod(b, c);

It's fast too, and supports an unlimited number of digits.

Answer by dirkgently

Two things:

  • Are you using the appropriate data type? In other words, does UINT_MAX allow you to have 673109 as an argument?

No, not for the intermediate values: your code fails because at one point num = 2^16, and computing num * num then overflows. Use a bigger data type to hold this intermediate value.

  • How about reducing modulo mod at every possible overflow opportunity, such as:

    test = ((test % mod) * (num % mod)) % mod;

Edit:

#include <cstdio>

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod)
{
    unsigned long long test;
    unsigned long long n = num;
    for(test = 1; pow; pow >>= 1)
    {
        if (pow & 1)
            test = ((test % mod) * (n % mod)) % mod;
        n = ((n % mod) * (n % mod)) % mod;
    }

    return test; /* note this is potentially lossy */
}

int main(int argc, char* argv[])
{

    /* (2 ^ 168277) % 673109 */
    printf("%u\n", mod_pow(2, 168277, 673109));
    return 0;
}

Answer by ShowLove

    package playTime;

    public class play {

        public static long count = 0; 
        public static long binSlots = 10; 
        public static long y = 645; 
        public static long finalValue = 1; 
        public static long x = 11; 

        public static void main(String[] args){

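            // Bits of the exponent 644 (= 2^2 + 2^7 + 2^9), least significant first,
            // so BME computes 11^644 mod 645 (x = 11, y = 645).
            int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1};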

            x = BME(x, count, binArray); 

            System.out.print("\nfinal value:"+finalValue);

        }

        public static long BME(long x, long count, int[] binArray){

            if(count == binSlots){
                return finalValue; 
            }

            if(binArray[(int) count] == 1){
                finalValue = finalValue*x%y; 
            }

            x = (x*x)%y; 
            System.out.print("Array("+binArray[(int) count]+") "
                            +"x("+x+")" +" finalVal("+              finalValue + ")\n");

            count++; 


            return BME(x, count,binArray); 
        }

    }

Answer by abkds

LL is for long long int.

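typedef long long LL;
const LL P = 673109;  // the modulus, a global here; this value matches the sample run

LL power_mod(LL a, LL k) {
    if (k == 0)
        return 1;
    LL temp = power_mod(a, k/2);  // a^(k/2) mod P, computed recursively
    LL res;

    res = ( ( temp % P ) * (temp % P) ) % P;
    if (k % 2 == 1)
        res = ((a % P) * (res % P)) % P;
    return res;
}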

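Use the recursive function above to compute the modular exponentiation. It builds the result bottom up, and because every factor is reduced modulo P before multiplying, each product stays below P^2, which fits in a long long for a 32-bit modulus, so it does not overflow.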

A sample run with a = 2 and k = 168277 (P = 673109) outputs 518358, which is correct, and the function runs in O(log k) time.

Answer by Alexander Poluektov

You could use the following identity:

(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)
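
As an illustration, applying the identity at each multiplication in the question's second example keeps every operand below 1608:

```latex
\begin{aligned}
4000111222   &\equiv 574 \pmod{1608} \\
4000111222^2 &\equiv 574^2 = 329476 \equiv 1444 \pmod{1608} \\
4000111222^3 &\equiv 1444 \cdot 574 = 828856 \equiv 736 \pmod{1608}
\end{aligned}
```

Note that the identity bounds each factor by m but not the product itself, which can still reach (m - 1)^2; for m = 673109 that product needs more than 32 bits.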

Try using it in a straightforward way, then improve incrementally.

    if (pow & 1)
        test = ((test % mod) * (num % mod)) % mod;
    num = ((num % mod) * (num % mod)) % mod;