/* * Copyright 2017 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Defines types that are necessary for Ring-Learning with Errors (rlwe) // encryption. #ifndef RLWE_MONTGOMERY_H_ #define RLWE_MONTGOMERY_H_ #include #include #include #include #include #include "absl/numeric/int128.h" #include "absl/strings/str_cat.h" #include "bits_util.h" #include "constants.h" #include "int256.h" #include "prng/prng.h" #include "serialization.pb.h" #include "status_macros.h" #include "statusor.h" namespace rlwe { namespace internal { // Struct to capture the "bigger int" type. template struct BigInt; // Specialization for uint8, uint16, uint32, uint64, and uint128. template <> struct BigInt { typedef Uint16 value_type; }; template <> struct BigInt { typedef Uint32 value_type; }; template <> struct BigInt { typedef Uint64 value_type; }; template <> struct BigInt { typedef absl::uint128 value_type; }; template <> struct BigInt { typedef uint256 value_type; }; } // namespace internal // The parameters necessary for a Montgomery integer. Note that the template // parameters ensure that T is an unsigned integral of at least 8 bits. template struct MontgomeryIntParams { // Expose Int and its greater type. BigInt is required in order to multiply // two Int and ensure that no overflow occurs. // // Thread safe. using Int = T; using BigInt = typename internal::BigInt::value_type; static const size_t bitsize_int = sizeof(Int) * 8; // Factory function to create MontgomeryIntParams. static rlwe::StatusOr> Create( Int modulus); // The value R to be used in Montgomery multiplication. R will be selected as // 2^bitsize(Int) and hence automatically verifies R > modulus. const BigInt r = static_cast(1) << bitsize_int; // The modulus over which these modular operations are being performed. const Int modulus; // The modulus over which these modular operations are being performed, cast // as a BigInt. const BigInt modulus_bigint; // The number of bits in the modulus. const unsigned int log_modulus; // The value of R taken modulo the modulus. const Int r_mod_modulus; const Int r_mod_modulus_barrett; // = (r_mod_modulus << bitsize_int) / modulus, to // speed up multiplication by r_mod_modulus. // The values below, inv_modulus and inv_r, satisfy the formula: // R * inv_r - modulus * inv_modulus = 1 // Note that inv_modulus is not literally the multiplicative inverse of // modulus modulo R. const Int inv_modulus; const Int inv_r; // needed to translate from Montgomery to normal representation const Int inv_r_barrett; // = (inv_r << bitsize_int) / modulus, to speed up // multiplication by inv_r. // The numerator used in the Barrett reduction. const BigInt barrett_numerator; // = 2^(sizeof(Int)*8) / modulus inline const Int Zero() const { return 0; } inline const Int One() const { return 1; } inline const Int Two() const { return 2; } // Functions to perform Barrett reduction. For more details, see // https://en.wikipedia.org/wiki/Barrett_reduction. // This function should remains in the header file to avoid performance // regressions. inline Int BarrettReduce(Int input) const { Int out = static_cast((this->barrett_numerator * input) >> bitsize_int); out = input - (out * this->modulus); // The steps above produce an integer that is in the range [0, 2N). // We now reduce to the range [0, N). return (out >= this->modulus) ? out - this->modulus : out; } // Computes the serialized byte length of an integer. inline unsigned int SerializedSize() const { return (log_modulus + 7) / 8; } // Check whether (1 << log_n) fits into the underlying Int type. static bool DoesLogNFit(Uint64 log_n) { return (log_n < bitsize_int - 1); } private: MontgomeryIntParams(Int mod) : modulus(mod), modulus_bigint(static_cast(this->modulus)), log_modulus(internal::BitLength(this->modulus)), r_mod_modulus(static_cast(this->r % this->modulus_bigint)), r_mod_modulus_barrett(static_cast( (static_cast(r_mod_modulus) << bitsize_int) / modulus)), inv_modulus(static_cast( std::get<1>(MontgomeryIntParams::Inverses(modulus_bigint, r)))), inv_r(std::get<0>(MontgomeryIntParams::Inverses(modulus_bigint, r))), inv_r_barrett(static_cast( (static_cast(inv_r) << bitsize_int) / modulus)), barrett_numerator(this->r / this->modulus_bigint) {} // Computes the Montgomery inverse coefficients for r and modulus using // the Extended Euclidean Algorithm. // // modulus must be odd. // Returns a tuple of (inv_r, inv_modulus) such that: // r * inv_r - modulus * inv_modulus = 1 static std::tuple Inverses(BigInt modulus_bigint, BigInt r); }; // Stores an integer in Montgomery representation. The goal of this // class is to provide a static type that differentiates between integers // in Montgomery representation and standard integers. Once a Montgomery // integer is created, it has the * and + operators necessary to be treated // as another integer. // The underlying integer type T must be unsigned and must not be bool. // This class is thread safe. template class ABSL_MUST_USE_RESULT MontgomeryInt { public: // Expose Int and its greater type. BigInt is required in order to multiply // two Int and ensure that no overflow occurs. This should also be used by // external classes. using Int = T; using BigInt = typename internal::BigInt::value_type; // Expose the parameter type. using Params = MontgomeryIntParams; // Static factory that converts a non-Montgomery representation integer, the // underlying integer type, into a Montgomery representation integer. Does not // take ownership of params. i.e., import "a". static rlwe::StatusOr ImportInt(Int n, const Params* params); // Static functions to create a MontgomeryInt of 0 and 1. static MontgomeryInt ImportZero(const Params* params); static MontgomeryInt ImportOne(const Params* params); // Import a random integer using entropy from specified prng. Does not take // ownership of params or prng. template static rlwe::StatusOr ImportRandom(Prng* prng, const Params* params) { // In order to generate unbiased randomness, we uniformly and randomly // sample integers in [0, 2^params->log_modulus) until the generated integer // is less than the modulus (i.e., we perform rejection sampling). RLWE_ASSIGN_OR_RETURN(Int random_int, GenerateRandomInt(params->log_modulus, prng)); while (random_int >= params->modulus) { RLWE_ASSIGN_OR_RETURN(random_int, GenerateRandomInt(params->log_modulus, prng)); } return MontgomeryInt(random_int); } static BigInt DivAndTruncate(BigInt dividend, BigInt divisor); // No default constructor. MontgomeryInt() = delete; // Default copy constructor. MontgomeryInt(const MontgomeryInt& that) = default; MontgomeryInt& operator=(const MontgomeryInt& that) = default; // Convert a Montgomery representation integer back to the underlying integer. // i.e., export "a". inline Int ExportInt(const Params* params) const { Int out = static_cast((static_cast(params->inv_r_barrett) * n_) >> params->bitsize_int); out = n_ * params->inv_r - out * params->modulus; // The steps above produce an integer that is in the range [0, 2N). // We now reduce to the range [0, N). out -= (out >= params->modulus) ? params->modulus : 0; return out; } // Returns the least significant 64 bits of n. static Uint64 ExportUInt64(Int n) { return static_cast(n); } // Serialization. rlwe::StatusOr Serialize(const Params* params) const; static rlwe::StatusOr SerializeVector( const std::vector& coeffs, const Params* params); // Deserialization. static rlwe::StatusOr Deserialize(absl::string_view payload, const Params* params); static rlwe::StatusOr> DeserializeVector( int num_coeffs, absl::string_view serialized, const Params* params); // Modular multiplication. // Perform a multiply followed by a modular reduction. Produces an output // in the range [0, N). // Taken from Hacker's Delight chapter on Montgomery multiplication. MontgomeryInt Mul(const MontgomeryInt& that, const Params* params) const { MontgomeryInt out(*this); return out.MulInPlace(that, params); } // This function should remains in the header file to avoid performance // regressions. MontgomeryInt& MulInPlace(const MontgomeryInt& that, const Params* params) { // This function computes the product of the two numbers (a and b), which // will equal a * R * b * R in Montgomery representation. It then performs // the reduction, creating a * b * R (mod N). Int u = static_cast(n_ * that.n_) * params->inv_modulus; BigInt t = static_cast(n_) * that.n_ + params->modulus_bigint * u; Int t_msb = static_cast(t >> Params::bitsize_int); // The steps above produce an integer that is in the range [0, 2N). // We now reduce to the range [0, N). n_ = (t_msb >= params->modulus) ? t_msb - params->modulus : t_msb; return *this; } // Modular multiplication by a constant with precomputation. // MontgomeryInt are stored under the form n * R, where R = 1 << bitsize_int. // When multiplying by a constant c, one can precompute // constant = (c * R^(-1) mod modulus) // constant_barrett = (constant << bitsize_int) / modulus // and perform modular reduction using Barrett reduction instead of // Montgomery multiplication. // The funciton GetConstant() returns a tuple (constant, constant_barrett): // constant = ExportInt(params), // and // constant_barrett = (constant << bitsize_int) / modulus std::tuple GetConstant(const Params* params) const; MontgomeryInt MulConstant(const Int& constant, const Int& constant_barrett, const Params* params) const { MontgomeryInt out(*this); return out.MulConstantInPlace(constant, constant_barrett, params); } // This function should remains in the header file to avoid performance // regressions. MontgomeryInt& MulConstantInPlace(const Int& constant, const Int& constant_barrett, const Params* params) { Int out = static_cast((static_cast(constant_barrett) * n_) >> Params::bitsize_int); n_ = n_ * constant - out * params->modulus; // The steps above produce an integer that is in the range [0, 2N). // We now reduce to the range [0, N). n_ -= (n_ >= params->modulus) ? params->modulus : 0; return *this; } // Montgomery addition. MontgomeryInt Add(const MontgomeryInt& that, const Params* params) const { MontgomeryInt out(*this); return out.AddInPlace(that, params); } // This function should remains in the header file to avoid performance // regressions. MontgomeryInt& AddInPlace(const MontgomeryInt& that, const Params* params) { // We can use Barrett reduction because n_ <= modulus < Max(Int)/2. n_ = params->BarrettReduce(n_ + that.n_); return *this; } // Modular negation. MontgomeryInt Negate(const Params* params) const { return MontgomeryInt(params->modulus - n_); } // This function should remains in the header file to avoid performance // regressions. MontgomeryInt& NegateInPlace(const Params* params) { n_ = params->modulus - n_; return *this; } // Modular subtraction. MontgomeryInt Sub(const MontgomeryInt& that, const Params* params) const { MontgomeryInt out(*this); return out.SubInPlace(that, params); } // This function should remains in the header file to avoid performance // regressions. MontgomeryInt& SubInPlace(const MontgomeryInt& that, const Params* params) { // We can use Barrett reduction because n_ <= modulus < Max(Int)/2. n_ = params->BarrettReduce(n_ + (params->modulus - that.n_)); return *this; } // Batch operations (and in-place variants) over vectors of MontgomeryInt. // We define two versions of the batch operations: // - the first input is a vector and the second is a MontgomeryInt scalar: // in that case, the scalar will be added/subtracted/multiplied with each // element of the first input. // - both inputs are vectors of same length: in that case, the operations // will be performed component wise. // These batch operations may fail if the input vectors are not of the same // size. // Batch addition of two vectors. static rlwe::StatusOr> BatchAdd( const std::vector& in1, const std::vector& in2, const Params* params); static absl::Status BatchAddInPlace(std::vector* in1, const std::vector& in2, const Params* params); // Batch addition of one vector with a scalar. static rlwe::StatusOr> BatchAdd( const std::vector& in1, const MontgomeryInt& in2, const Params* params); static absl::Status BatchAddInPlace(std::vector* in1, const MontgomeryInt& in2, const Params* params); // Batch subtraction of two vectors. static rlwe::StatusOr> BatchSub( const std::vector& in1, const std::vector& in2, const Params* params); static absl::Status BatchSubInPlace(std::vector* in1, const std::vector& in2, const Params* params); // Batch subtraction of one vector with a scalar. static rlwe::StatusOr> BatchSub( const std::vector& in1, const MontgomeryInt& in2, const Params* params); static absl::Status BatchSubInPlace(std::vector* in1, const MontgomeryInt& in2, const Params* params); // Batch multiplication of two vectors. static rlwe::StatusOr> BatchMul( const std::vector& in1, const std::vector& in2, const Params* params); static absl::Status BatchMulInPlace(std::vector* in1, const std::vector& in2, const Params* params); // Batch multiplication of two vectors, where the second vector is a constant. static rlwe::StatusOr> BatchMulConstant( const std::vector& in1, const std::vector& in2_constant, const std::vector& in2_constant_barrett, const Params* params); static absl::Status BatchMulConstantInPlace( std::vector* in1, const std::vector& in2_constant, const std::vector& in2_constant_barrett, const Params* params); // Batch multiplication of a vector with a scalar. static rlwe::StatusOr> BatchMul( const std::vector& in1, const MontgomeryInt& in2, const Params* params); static absl::Status BatchMulInPlace(std::vector* in1, const MontgomeryInt& in2, const Params* params); // Batch multiplication of a vector with a constant scalar. static rlwe::StatusOr> BatchMulConstant( const std::vector& in1, const Int& constant, const Int& constant_barrett, const Params* params); static absl::Status BatchMulConstantInPlace(std::vector* in1, const Int& constant, const Int& constant_barrett, const Params* params); // Equality. bool operator==(const MontgomeryInt& that) const { return (n_ == that.n_); } bool operator!=(const MontgomeryInt& that) const { return !(*this == that); } // Modular exponentiation. MontgomeryInt ModExp(Int exponent, const Params* params) const; // Inverse. MontgomeryInt MultiplicativeInverse(const Params* params) const; private: template static rlwe::StatusOr GenerateRandomInt(int log_modulus, Prng* prng) { // Generate a random Int. As the modulus is always smaller than max(Int), // there will be no issues with overflow. int max_bits_per_step = std::min((int)Params::bitsize_int, (int)64); auto bits_required = log_modulus; Int rand = 0; while (bits_required > 0) { Int rand_bits = 0; if (bits_required <= 8) { // Generate 8 bits of randomness. RLWE_ASSIGN_OR_RETURN(rand_bits, prng->Rand8()); // Extract bits_required bits and add them to rand. Int needed_bits = rand_bits & ((static_cast(1) << bits_required) - 1); rand = (rand << bits_required) + needed_bits; break; } else { // Generate 64 bits of randomness. RLWE_ASSIGN_OR_RETURN(rand_bits, prng->Rand64()); // Extract min(64, bits in Int, bits_required) bits and add them to rand int bits_to_extract = std::min(bits_required, max_bits_per_step); Int needed_bits = rand_bits & ((static_cast(1) << bits_to_extract) - 1); rand = (rand << bits_to_extract) + needed_bits; bits_required -= bits_to_extract; } } return rand; } explicit MontgomeryInt(Int n) : n_(n) {} Int n_; }; } // namespace rlwe #endif // RLWE_MONTGOMERY_H_