/* * Copyright 2018 Google LLC. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef RLWE_INT256_H_ #define RLWE_INT256_H_ #include "absl/numeric/int128.h" #include "integral_types.h" namespace rlwe { struct uint256_pod; // An unsigned 256-bit integer type. Thread-compatible. class uint256 { public: constexpr uint256(); constexpr uint256(absl::uint128 top, absl::uint128 bottom); // Implicit type conversion is allowed so these behave like familiar int types #ifndef SWIG constexpr uint256(int bottom); constexpr uint256(Uint32 bottom); #endif constexpr uint256(Uint8 bottom); constexpr uint256(unsigned long bottom); constexpr uint256(unsigned long long bottom); constexpr uint256(absl::uint128 bottom); constexpr uint256(const uint256_pod &val); // Conversion operators to other arithmetic types constexpr explicit operator bool() const; constexpr explicit operator char() const; constexpr explicit operator signed char() const; constexpr explicit operator unsigned char() const; constexpr explicit operator char16_t() const; constexpr explicit operator char32_t() const; constexpr explicit operator short() const; constexpr explicit operator unsigned short() const; constexpr explicit operator int() const; constexpr explicit operator unsigned int() const; constexpr explicit operator long() const; constexpr explicit operator unsigned long() const; constexpr explicit operator long long() const; constexpr explicit operator unsigned long long() const; constexpr explicit operator absl::int128() const; constexpr explicit operator absl::uint128() const; explicit operator float() const; explicit operator double() const; explicit operator long double() const; // Trivial copy constructor, assignment operator and destructor. void Initialize(absl::uint128 top, absl::uint128 bottom); // Arithmetic operators. uint256& operator+=(const uint256& b); uint256& operator-=(const uint256& b); uint256& operator*=(const uint256& b); // Long division/modulo for uint256. uint256& operator/=(const uint256& b); uint256& operator%=(const uint256& b); uint256 operator++(int); uint256 operator--(int); uint256& operator<<=(int); uint256& operator>>=(int); uint256& operator&=(const uint256& b); uint256& operator|=(const uint256& b); uint256& operator^=(const uint256& b); uint256& operator++(); uint256& operator--(); friend absl::uint128 Uint256Low128(const uint256& v); friend absl::uint128 Uint256High128(const uint256& v); // We add "std::" to avoid including all of port.h. friend std::ostream& operator<<(std::ostream& o, const uint256& b); private: static void DivModImpl(uint256 dividend, uint256 divisor, uint256* quotient_ret, uint256* remainder_ret); // Little-endian memory order optimizations can benefit from // having lo_ first, hi_ last. // See util/endian/endian.h and Load256/Store256 for storing a uint256. // Adding any new members will cause sizeof(uint256) tests to fail. absl::uint128 lo_; absl::uint128 hi_; // Uint256Max() // // Returns the highest value for a 256-bit unsigned integer. friend constexpr uint256 Uint256Max(); // Not implemented, just declared for catching automatic type conversions. uint256(Uint16); uint256(float v); uint256(double v); }; constexpr uint256 Uint256Max() { return uint256((std::numeric_limits::max)(), (std::numeric_limits::max)()); } // This is a POD form of uint256 which can be used for static variables which // need to be operated on as uint256. struct uint256_pod { // Note: The ordering of fields is different than 'class uint256' but the // same as its 2-arg constructor. This enables more obvious initialization // of static instances, which is the primary reason for this struct in the // first place. This does not seem to defeat any optimizations wrt // operations involving this struct. absl::uint128 hi; absl::uint128 lo; }; extern const uint256_pod kuint256max; // allow uint256 to be logged extern std::ostream& operator<<(std::ostream& o, const uint256& b); // Methods to access low and high pieces of 256-bit value. // Defined externally from uint256 to facilitate conversion // to native 256-bit types when compilers support them. inline absl::uint128 Uint256Low128(const uint256& v) { return v.lo_; } inline absl::uint128 Uint256High128(const uint256& v) { return v.hi_; } // -------------------------------------------------------------------------- // Implementation details follow // -------------------------------------------------------------------------- inline bool operator==(const uint256& lhs, const uint256& rhs) { return (Uint256Low128(lhs) == Uint256Low128(rhs) && Uint256High128(lhs) == Uint256High128(rhs)); } inline bool operator!=(const uint256& lhs, const uint256& rhs) { return !(lhs == rhs); } inline constexpr uint256::uint256() : lo_(0), hi_(0) {} inline constexpr uint256::uint256(absl::uint128 top, absl::uint128 bottom) : lo_(bottom), hi_(top) {} inline constexpr uint256::uint256(const uint256_pod& v) : lo_(v.lo), hi_(v.hi) {} inline constexpr uint256::uint256(absl::uint128 bottom) : lo_(bottom), hi_(0) {} #ifndef SWIG inline constexpr uint256::uint256(int bottom) : lo_(bottom), hi_((bottom < 0) ? -1 : 0) {} inline constexpr uint256::uint256(Uint32 bottom) : lo_(bottom), hi_(0) {} #endif inline constexpr uint256::uint256(Uint8 bottom) : lo_(bottom), hi_(0) {} inline constexpr uint256::uint256(unsigned long bottom) : lo_(bottom), hi_(0) {} inline constexpr uint256::uint256(unsigned long long bottom) : lo_(bottom), hi_(0) {} inline void uint256::Initialize(absl::uint128 top, absl::uint128 bottom) { hi_ = top; lo_ = bottom; } // Conversion operators to integer types. constexpr uint256::operator bool() const { return lo_ || hi_; } constexpr uint256::operator char() const { return static_cast(lo_); } constexpr uint256::operator signed char() const { return static_cast(lo_); } constexpr uint256::operator unsigned char() const { return static_cast(lo_); } constexpr uint256::operator char16_t() const { return static_cast(lo_); } constexpr uint256::operator char32_t() const { return static_cast(lo_); } constexpr uint256::operator short() const { return static_cast(lo_); } constexpr uint256::operator unsigned short() const { return static_cast(lo_); } constexpr uint256::operator int() const { return static_cast(lo_); } constexpr uint256::operator unsigned int() const { return static_cast(lo_); } constexpr uint256::operator long() const { return static_cast(lo_); } constexpr uint256::operator unsigned long() const { return static_cast(lo_); } constexpr uint256::operator long long() const { return static_cast(lo_); } constexpr uint256::operator unsigned long long() const { return static_cast(lo_); } constexpr uint256::operator absl::uint128() const { return lo_; } constexpr uint256::operator absl::int128() const { return static_cast(lo_); } // Conversion operators to floating point types. inline uint256::operator float() const { return static_cast(lo_) + std::ldexp(static_cast(hi_), 128); } inline uint256::operator double() const { return static_cast(lo_) + std::ldexp(static_cast(hi_), 128); } inline uint256::operator long double() const { return static_cast(lo_) + std::ldexp(static_cast(hi_), 128); } // Comparison operators. #define CMP256(op) \ inline bool operator op(const uint256& lhs, const uint256& rhs) { \ return (Uint256High128(lhs) == Uint256High128(rhs)) \ ? (Uint256Low128(lhs) op Uint256Low128(rhs)) \ : (Uint256High128(lhs) op Uint256High128(rhs)); \ } CMP256(<) CMP256(>) CMP256(>=) CMP256(<=) #undef CMP256 // Unary operators inline uint256 operator-(const uint256& val) { const absl::uint128 hi_flip = ~Uint256High128(val); const absl::uint128 lo_flip = ~Uint256Low128(val); const absl::uint128 lo_add = lo_flip + 1; if (lo_add < lo_flip) { return uint256(hi_flip + 1, lo_add); } return uint256(hi_flip, lo_add); } inline bool operator!(const uint256& val) { return !Uint256High128(val) && !Uint256Low128(val); } // Logical operators. inline uint256 operator~(const uint256& val) { return uint256(~Uint256High128(val), ~Uint256Low128(val)); } #define LOGIC256(op) \ inline uint256 operator op(const uint256& lhs, const uint256& rhs) { \ return uint256(Uint256High128(lhs) op Uint256High128(rhs), \ Uint256Low128(lhs) op Uint256Low128(rhs)); \ } LOGIC256(|) LOGIC256(&) LOGIC256(^) #undef LOGIC256 #define LOGICASSIGN256(op) \ inline uint256& uint256::operator op(const uint256& b) { \ hi_ op b.hi_; \ lo_ op b.lo_; \ return *this; \ } LOGICASSIGN256(|=) LOGICASSIGN256(&=) LOGICASSIGN256(^=) #undef LOGICASSIGN256 // Shift operators. inline uint256 operator<<(const uint256& val, int amount) { uint256 out(val); out <<= amount; return out; } inline uint256 operator>>(const uint256& val, int amount) { uint256 out(val); out >>= amount; return out; } inline uint256& uint256::operator<<=(int amount) { // uint128 shifts of >= 128 are undefined, so we will need some special-casing if (amount < 128) { if (amount != 0) { hi_ = (hi_ << amount) | (lo_ >> (128 - amount)); lo_ = lo_ << amount; } } else if (amount < 256) { hi_ = lo_ << (amount - 128); lo_ = 0; } else { hi_ = 0; lo_ = 0; } return *this; } inline uint256& uint256::operator>>=(int amount) { // uint128 shifts of >= 128 are undefined, so we will need some special-casing if (amount < 128) { if (amount != 0) { lo_ = (lo_ >> amount) | (hi_ << (128 - amount)); hi_ = hi_ >> amount; } } else if (amount < 256) { lo_ = hi_ >> (amount - 128); hi_ = 0; } else { lo_ = 0; hi_ = 0; } return *this; } inline uint256 operator+(const uint256& lhs, const uint256& rhs) { return uint256(lhs) += rhs; } inline uint256 operator-(const uint256& lhs, const uint256& rhs) { return uint256(lhs) -= rhs; } inline uint256 operator*(const uint256& lhs, const uint256& rhs) { return uint256(lhs) *= rhs; } inline uint256 operator/(const uint256& lhs, const uint256& rhs) { return uint256(lhs) /= rhs; } inline uint256 operator%(const uint256& lhs, const uint256& rhs) { return uint256(lhs) %= rhs; } inline uint256& uint256::operator+=(const uint256& b) { hi_ += b.hi_; absl::uint128 lolo = lo_ + b.lo_; if (lolo < lo_) ++hi_; lo_ = lolo; return *this; } inline uint256& uint256::operator-=(const uint256& b) { hi_ -= b.hi_; if (b.lo_ > lo_) --hi_; lo_ -= b.lo_; return *this; } inline uint256& uint256::operator*=(const uint256& b) { // Computes the product c = a * b modulo 2^256. // // We have that // a = [a.hi_ || a.lo_] and b = [b.hi_ || b.lo_] // where hi_, lo_ are 128-bit numbers. Further, we have that // a.lo_ = [a64 || a00] and b.lo_ = [b64 || b00] // where a64, a00, b64, b00 are 64-bit numbers. // // The product c = (a * b mod 2^256) is equal to // (a.hi_ * b.lo_ + a64 * b64 + b.hi_ * a.lo_ mod 2^128) * 2^128 + // (a64 * b00 + a00 * b64) * 2^64 + // (a00 * b00) // // The first and last lines can be computed without worrying about the // carries, and then we add the two elements from the second line. absl::uint128 a64 = absl::Uint128High64(lo_); absl::uint128 a00 = absl::Uint128Low64(lo_); absl::uint128 b64 = absl::Uint128High64(b.lo_); absl::uint128 b00 = absl::Uint128Low64(b.lo_); // Compute the high order and low order part of c (safe to ignore carry bits). this->hi_ = hi_ * b.lo_ + a64 * b64 + lo_ * b.hi_; this->lo_ = a00 * b00; // add middle term and capture carry uint256 middle_term = uint256(a64 * b00) + uint256(a00 * b64); *this += middle_term << 64; return *this; } inline uint256 uint256::operator++(int) { uint256 tmp(*this); lo_++; if (lo_ == 0) hi_++; // If there was a wrap around, increase the high word. return tmp; } inline uint256 uint256::operator--(int) { uint256 tmp(*this); if (lo_ == 0) hi_--; // If it wraps around, decrease the high word. lo_--; return tmp; } inline uint256& uint256::operator++() { lo_++; if (lo_ == 0) hi_++; // If there was a wrap around, increase the high word. return *this; } inline uint256& uint256::operator--() { if (lo_ == 0) hi_--; // If it wraps around, decrease the high word. lo_--; return *this; } } // namespace rlwe // Specialized numeric_limits for uint256. namespace std { template <> class numeric_limits { public: static constexpr bool is_specialized = true; static constexpr bool is_signed = false; static constexpr bool is_integer = true; static constexpr bool is_exact = true; static constexpr bool has_infinity = false; static constexpr bool has_quiet_NaN = false; static constexpr bool has_signaling_NaN = false; static constexpr float_denorm_style has_denorm = denorm_absent; static constexpr bool has_denorm_loss = false; static constexpr float_round_style round_style = round_toward_zero; static constexpr bool is_iec559 = false; static constexpr bool is_bounded = true; static constexpr bool is_modulo = true; static constexpr int digits = 256; static constexpr int digits10 = 77; static constexpr int max_digits10 = 0; static constexpr int radix = 2; static constexpr int min_exponent = 0; static constexpr int min_exponent10 = 0; static constexpr int max_exponent = 0; static constexpr int max_exponent10 = 0; static constexpr bool traps = numeric_limits::traps; static constexpr bool tinyness_before = false; static constexpr rlwe::uint256(min)() { return 0; } static constexpr rlwe::uint256 lowest() { return 0; } static constexpr rlwe::uint256(max)() { return rlwe::Uint256Max(); } static constexpr rlwe::uint256 epsilon() { return 0; } static constexpr rlwe::uint256 round_error() { return 0; } static constexpr rlwe::uint256 infinity() { return 0; } static constexpr rlwe::uint256 quiet_NaN() { return 0; } static constexpr rlwe::uint256 signaling_NaN() { return 0; } static constexpr rlwe::uint256 denorm_min() { return 0; } }; } // namespace std #endif // RLWE_INT256_H_