author     weidai <weidai@57ff6487-cd31-0410-9ec3-f628ee90f5f0>  2010-07-24 05:55:22 +0000
committer  weidai <weidai@57ff6487-cd31-0410-9ec3-f628ee90f5f0>  2010-07-24 05:55:22 +0000
commit     8532f317b3440154b421b1e8b8b004ead28f847e (patch)
tree       9fa57aeee5c779a3c9b4f88006050d81ff68e6ef /rijndael.cpp
parent     5e47408d6c3c40f0aafaa2b32a2ae0889f9fc089 (diff)
download   cryptopp-8532f317b3440154b421b1e8b8b004ead28f847e.tar.gz
add support for AES-NI and CLMUL instruction sets in AES and GMAC/GCM
git-svn-id: svn://svn.code.sf.net/p/cryptopp/code/trunk/c5@508 57ff6487-cd31-0410-9ec3-f628ee90f5f0
Diffstat (limited to 'rijndael.cpp')
-rw-r--r--  rijndael.cpp | 358
1 file changed, 319 insertions(+), 39 deletions(-)
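
[Editor's note] The run-time dispatch added in this commit relies on HasAESNI() from cpu.h. For reference, a minimal sketch of the CPUID test behind such a predicate (hypothetical helper name; assuming GCC/Clang's <cpuid.h>): AES-NI is reported in CPUID leaf 1, ECX bit 25, and CLMUL (PCLMULQDQ) in ECX bit 1.

#include <cpuid.h>   // GCC/Clang; MSVC exposes __cpuid via <intrin.h>

// Hypothetical helper: true when both AES-NI and PCLMULQDQ are available.
static bool cpu_has_aesni_and_clmul()
{
    unsigned int eax, ebx, ecx, edx;
    if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx))
        return false;                   // CPUID leaf 1 not supported
    return (ecx & (1u << 25)) != 0      // bit 25: AES-NI
        && (ecx & (1u <<  1)) != 0;     // bit 1: PCLMULQDQ (carry-less multiply)
}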
diff --git a/rijndael.cpp b/rijndael.cpp
index a39b65d..fbc7dcc 100644
--- a/rijndael.cpp
+++ b/rijndael.cpp
@@ -5,6 +5,10 @@
// use "cl /EP /P /DCRYPTOPP_GENERATE_X64_MASM rijndael.cpp" to generate MASM code
/*
+July 2010: Added support for AES-NI instructions via compiler intrinsics.
+*/
+
+/*
Feb 2009: The x86/x64 assembly code was rewritten by Wei Dai to do counter mode
caching, which was invented by Hongjun Wu and popularized by Daniel J. Bernstein
and Peter Schwabe in their paper "New AES software speed records". The round
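
[Editor's note] The caching idea referenced in this comment exploits the structure of CTR inputs: consecutive blocks differ only in the low counter bytes, so early-round work that depends only on the fixed bytes can be computed once per batch. A minimal sketch of the access pattern only (aes_encrypt_block is a hypothetical single-block helper, not the library's API):

#include <stddef.h>

void aes_encrypt_block(unsigned char *out, const unsigned char *in,
                       const unsigned int *rk);   // hypothetical helper

// Consecutive CTR inputs share bytes 0..11 (nonce and high counter bytes), so
// first-round table lookups driven only by those bytes can be hoisted out of
// the loop -- that is the caching trick.
void ctr_keystream(unsigned char *out, unsigned char ctr[16],
                   size_t nblocks, const unsigned int *rk)
{
    for (size_t i = 0; i < nblocks; i++)
    {
        aes_encrypt_block(out + 16*i, ctr, rk);           // hypothetical
        for (int j = 15; j >= 12 && ++ctr[j] == 0; j--)   // bump low 32 bits
            ;
    }
}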
@@ -69,6 +73,10 @@ being unloaded from L1 cache, until that round is finished.
#include "misc.h"
#include "cpu.h"
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+#include <wmmintrin.h>
+#endif
+
NAMESPACE_BEGIN(CryptoPP)
#ifdef CRYPTOPP_ALLOW_UNALIGNED_DATA_ACCESS
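
[Editor's note] <wmmintrin.h> declares the AES-NI intrinsics used throughout this patch (GCC and Clang additionally require -maes for this translation unit). Each AESENC instruction performs a complete middle round; the final round uses AESENCLAST, which skips MixColumns. A sketch:

#include <wmmintrin.h>

// One middle round: ShiftRows, SubBytes, MixColumns, AddRoundKey in a single
// instruction.
static inline __m128i aes_middle_round(__m128i state, __m128i rk)
{
    return _mm_aesenc_si128(state, rk);
}

// The last round omits MixColumns, hence the separate instruction.
static inline __m128i aes_final_round(__m128i state, __m128i rk)
{
    return _mm_aesenclast_si128(state, rk);
}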
@@ -198,20 +206,83 @@ void Rijndael::Base::UncheckedSetKey(const byte *userKey, unsigned int keylen, c
m_rounds = keylen/4 + 6;
m_key.New(4*(m_rounds+1));
- word32 temp, *rk = m_key;
- const word32 *rc = rcon;
+ word32 *rk = m_key;
+
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE && (!defined(_MSC_VER) || _MSC_VER >= 1600 || CRYPTOPP_BOOL_X86)
+ // MSVC 2008 SP1 generates bad code for _mm_extract_epi32() when compiling for X64
+ if (HasAESNI())
+ {
+ static const word32 rcLE[] = {
+ 0x01, 0x02, 0x04, 0x08,
+ 0x10, 0x20, 0x40, 0x80,
+ 0x1B, 0x36, /* for 128-bit blocks, Rijndael never uses more than 10 rcon values */
+ };
+ const word32 *rc = rcLE;
+
+ __m128i temp = _mm_loadu_si128((__m128i *)(userKey+keylen-16));
+ memcpy(rk, userKey, keylen);
+
+ while (true)
+ {
+ rk[keylen/4] = rk[0] ^ _mm_extract_epi32(_mm_aeskeygenassist_si128(temp, 0), 3) ^ *(rc++);
+ rk[keylen/4+1] = rk[1] ^ rk[keylen/4];
+ rk[keylen/4+2] = rk[2] ^ rk[keylen/4+1];
+ rk[keylen/4+3] = rk[3] ^ rk[keylen/4+2];
+
+ if (rk + keylen/4 + 4 == m_key.end())
+ break;
+
+ if (keylen == 24)
+ {
+ rk[10] = rk[ 4] ^ rk[ 9];
+ rk[11] = rk[ 5] ^ rk[10];
+ temp = _mm_insert_epi32(temp, rk[11], 3);
+ }
+ else if (keylen == 32)
+ {
+ temp = _mm_insert_epi32(temp, rk[11], 3);
+ rk[12] = rk[ 4] ^ _mm_extract_epi32(_mm_aeskeygenassist_si128(temp, 0), 2);
+ rk[13] = rk[ 5] ^ rk[12];
+ rk[14] = rk[ 6] ^ rk[13];
+ rk[15] = rk[ 7] ^ rk[14];
+ temp = _mm_insert_epi32(temp, rk[15], 3);
+ }
+ else
+ temp = _mm_insert_epi32(temp, rk[7], 3);
+
+ rk += keylen/4;
+ }
+
+ if (!IsForwardTransformation())
+ {
+ rk = m_key;
+ unsigned int i, j;
+
+ std::swap(*(__m128i *)(rk), *(__m128i *)(rk+4*m_rounds));
+
+ for (i = 4, j = 4*m_rounds-4; i < j; i += 4, j -= 4)
+ {
+ temp = _mm_aesimc_si128(*(__m128i *)(rk+i));
+ *(__m128i *)(rk+i) = _mm_aesimc_si128(*(__m128i *)(rk+j));
+ *(__m128i *)(rk+j) = temp;
+ }
+
+ *(__m128i *)(rk+i) = _mm_aesimc_si128(*(__m128i *)(rk+i));
+ }
+
+ return;
+ }
+#endif
GetUserKey(BIG_ENDIAN_ORDER, rk, keylen/4, userKey, keylen);
+ const word32 *rc = rcon;
+ word32 temp;
while (true)
{
temp = rk[keylen/4-1];
- rk[keylen/4] = rk[0] ^
- (word32(Se[GETBYTE(temp, 2)]) << 24) ^
- (word32(Se[GETBYTE(temp, 1)]) << 16) ^
- (word32(Se[GETBYTE(temp, 0)]) << 8) ^
- Se[GETBYTE(temp, 3)] ^
- *(rc++);
+ word32 x = (word32(Se[GETBYTE(temp, 2)]) << 24) ^ (word32(Se[GETBYTE(temp, 1)]) << 16) ^ (word32(Se[GETBYTE(temp, 0)]) << 8) ^ Se[GETBYTE(temp, 3)];
+ rk[keylen/4] = rk[0] ^ x ^ *(rc++);
rk[keylen/4+1] = rk[1] ^ rk[keylen/4];
rk[keylen/4+2] = rk[2] ^ rk[keylen/4+1];
rk[keylen/4+3] = rk[3] ^ rk[keylen/4+2];
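
[Editor's note] Note why the code above passes 0 to _mm_aeskeygenassist_si128 and applies rcLE by hand: the intrinsic's round constant must be a compile-time immediate, which cannot vary inside a runtime loop over key lengths. With a constant of 0, dword 3 of the result is exactly RotWord(SubWord(w3)), the term the scalar schedule computes, and the rcon XOR is done separately. For comparison, the widely used AES-128-only pattern bakes the constant into each step (a sketch following Intel's white-paper layout, not this library's code):

#include <wmmintrin.h>

// One AES-128 schedule step: given round key N and AESKEYGENASSIST of it with
// round constant RCON, produce round key N+1.
static inline __m128i aes128_expand_step(__m128i key, __m128i keygened)
{
    keygened = _mm_shuffle_epi32(keygened, _MM_SHUFFLE(3,3,3,3)); // broadcast the w3 term
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
    key = _mm_xor_si128(key, _mm_slli_si128(key, 4));   // prefix-XOR of the 4 words
    return _mm_xor_si128(key, keygened);
}
// e.g.: rk1 = aes128_expand_step(rk0, _mm_aeskeygenassist_si128(rk0, 0x01));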
@@ -227,11 +298,7 @@ void Rijndael::Base::UncheckedSetKey(const byte *userKey, unsigned int keylen, c
else if (keylen == 32)
{
temp = rk[11];
- rk[12] = rk[ 4] ^
- (word32(Se[GETBYTE(temp, 3)]) << 24) ^
- (word32(Se[GETBYTE(temp, 2)]) << 16) ^
- (word32(Se[GETBYTE(temp, 1)]) << 8) ^
- Se[GETBYTE(temp, 0)];
+ rk[12] = rk[ 4] ^ (word32(Se[GETBYTE(temp, 3)]) << 24) ^ (word32(Se[GETBYTE(temp, 2)]) << 16) ^ (word32(Se[GETBYTE(temp, 1)]) << 8) ^ Se[GETBYTE(temp, 0)];
rk[13] = rk[ 5] ^ rk[12];
rk[14] = rk[ 6] ^ rk[13];
rk[15] = rk[ 7] ^ rk[14];
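
[Editor's note] For 256-bit keys the schedule applies the S-box a second time without rotation or round constant; that is the rk[12] assignment above. In FIPS-197 terms the two primitives look as follows, assuming the big-endian word packing this scalar path uses (a sketch, not the library's code; Se is the forward S-box):

static inline word32 RotWord(word32 w) { return (w << 8) | (w >> 24); }

static inline word32 SubWord(word32 w)
{
    return (word32(Se[(w >> 24) & 0xff]) << 24) ^
           (word32(Se[(w >> 16) & 0xff]) << 16) ^
           (word32(Se[(w >>  8) & 0xff]) <<  8) ^
            word32(Se[w & 0xff]);
}
// generic step: rk[i] = rk[i-Nk] ^ SubWord(RotWord(rk[i-1])) ^ rcon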
@@ -239,10 +306,15 @@ void Rijndael::Base::UncheckedSetKey(const byte *userKey, unsigned int keylen, c
rk += keylen/4;
}
+ rk = m_key;
+
if (IsForwardTransformation())
{
if (!s_TeFilled)
FillEncTable();
+
+ ConditionalByteReverse(BIG_ENDIAN_ORDER, rk, rk, 16);
+ ConditionalByteReverse(BIG_ENDIAN_ORDER, rk + m_rounds*4, rk + m_rounds*4, 16);
}
else
{
@@ -250,35 +322,37 @@ void Rijndael::Base::UncheckedSetKey(const byte *userKey, unsigned int keylen, c
FillDecTable();
unsigned int i, j;
- rk = m_key;
-
- /* invert the order of the round keys: */
- for (i = 0, j = 4*m_rounds; i < j; i += 4, j -= 4) {
- temp = rk[i ]; rk[i ] = rk[j ]; rk[j ] = temp;
- temp = rk[i + 1]; rk[i + 1] = rk[j + 1]; rk[j + 1] = temp;
- temp = rk[i + 2]; rk[i + 2] = rk[j + 2]; rk[j + 2] = temp;
- temp = rk[i + 3]; rk[i + 3] = rk[j + 3]; rk[j + 3] = temp;
- }
-#define InverseMixColumn(x) x = TL_M(Td, 0, Se[GETBYTE(x, 3)]) ^ TL_M(Td, 1, Se[GETBYTE(x, 2)]) ^ TL_M(Td, 2, Se[GETBYTE(x, 1)]) ^ TL_M(Td, 3, Se[GETBYTE(x, 0)])
+#define InverseMixColumn(x) TL_M(Td, 0, Se[GETBYTE(x, 3)]) ^ TL_M(Td, 1, Se[GETBYTE(x, 2)]) ^ TL_M(Td, 2, Se[GETBYTE(x, 1)]) ^ TL_M(Td, 3, Se[GETBYTE(x, 0)])
- /* apply the inverse MixColumn transform to all round keys but the first and the last: */
- for (i = 1; i < m_rounds; i++) {
- rk += 4;
- InverseMixColumn(rk[0]);
- InverseMixColumn(rk[1]);
- InverseMixColumn(rk[2]);
- InverseMixColumn(rk[3]);
+ for (i = 4, j = 4*m_rounds-4; i < j; i += 4, j -= 4)
+ {
+ temp = InverseMixColumn(rk[i ]); rk[i ] = InverseMixColumn(rk[j ]); rk[j ] = temp;
+ temp = InverseMixColumn(rk[i + 1]); rk[i + 1] = InverseMixColumn(rk[j + 1]); rk[j + 1] = temp;
+ temp = InverseMixColumn(rk[i + 2]); rk[i + 2] = InverseMixColumn(rk[j + 2]); rk[j + 2] = temp;
+ temp = InverseMixColumn(rk[i + 3]); rk[i + 3] = InverseMixColumn(rk[j + 3]); rk[j + 3] = temp;
}
+
+ rk[i+0] = InverseMixColumn(rk[i+0]);
+ rk[i+1] = InverseMixColumn(rk[i+1]);
+ rk[i+2] = InverseMixColumn(rk[i+2]);
+ rk[i+3] = InverseMixColumn(rk[i+3]);
+
+ temp = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[0]); rk[0] = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[4*m_rounds+0]); rk[4*m_rounds+0] = temp;
+ temp = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[1]); rk[1] = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[4*m_rounds+1]); rk[4*m_rounds+1] = temp;
+ temp = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[2]); rk[2] = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[4*m_rounds+2]); rk[4*m_rounds+2] = temp;
+ temp = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[3]); rk[3] = ConditionalByteReverse(BIG_ENDIAN_ORDER, rk[4*m_rounds+3]); rk[4*m_rounds+3] = temp;
}
- ConditionalByteReverse(BIG_ENDIAN_ORDER, m_key.begin(), m_key.begin(), 16);
- ConditionalByteReverse(BIG_ENDIAN_ORDER, m_key + m_rounds*4, m_key + m_rounds*4, 16);
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+ if (HasAESNI())
+ ConditionalByteReverse(BIG_ENDIAN_ORDER, rk+4, rk+4, (m_rounds-1)*16);
+#endif
}
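
[Editor's note] The _mm_aesimc_si128 calls above implement the "equivalent inverse cipher" of FIPS-197: AESDEC expects decryption subkeys that are the encryption subkeys in reverse order, with InvMixColumns applied to all but the outermost two. Done out of place rather than with the in-place swap used above, the transform is (a sketch):

#include <wmmintrin.h>

// Build AESDEC subkeys from AESENC subkeys.
static void make_dec_subkeys(__m128i *dk, const __m128i *ek, unsigned int rounds)
{
    dk[0] = ek[rounds];                            // last encryption key first
    for (unsigned int i = 1; i < rounds; i++)
        dk[i] = _mm_aesimc_si128(ek[rounds - i]);  // InvMixColumns the middle keys
    dk[rounds] = ek[0];                            // first encryption key last
}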
void Rijndael::Enc::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock, byte *outBlock) const
{
-#if CRYPTOPP_BOOL_SSE2_ASM_AVAILABLE || defined(CRYPTOPP_X64_MASM_AVAILABLE)
+#if CRYPTOPP_BOOL_SSE2_ASM_AVAILABLE || defined(CRYPTOPP_X64_MASM_AVAILABLE) || CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
if (HasSSE2())
{
Rijndael::Enc::AdvancedProcessBlocks(inBlock, xorBlock, outBlock, 16, 0);
@@ -354,6 +428,14 @@ void Rijndael::Enc::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock
void Rijndael::Dec::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock, byte *outBlock) const
{
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+ if (HasAESNI())
+ {
+ Rijndael::Dec::AdvancedProcessBlocks(inBlock, xorBlock, outBlock, 16, 0);
+ return;
+ }
+#endif
+
typedef BlockGetAndPut<word32, NativeByteOrder> Block;
word32 s0, s1, s2, s3, t0, t1, t2, t3;
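
[Editor's note] With this change both single-block entry points forward to AdvancedProcessBlocks when AES-NI is present. Nothing changes for callers of the public API; a usage sketch (all-zero key and ciphertext, for brevity only):

#include "aes.h"
#include "modes.h"

void ecb_decrypt_demo()
{
    using namespace CryptoPP;
    byte key[AES::DEFAULT_KEYLENGTH] = {0};   // demo values only
    byte ct[64] = {0}, pt[64];
    ECB_Mode<AES>::Decryption dec(key, sizeof(key));
    dec.ProcessData(pt, ct, sizeof(ct));      // 4 blocks: eligible for the parallel path
}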
@@ -913,14 +995,200 @@ static inline bool AliasedWithTable(const byte *begin, const byte *end)
return (s0 < t1 || s1 <= t1) || (s0 >= t0 || s1 > t0);
}
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+
+inline void AESNI_Enc_Block(__m128i &block, const __m128i *subkeys, unsigned int rounds)
+{
+ block = _mm_xor_si128(block, subkeys[0]);
+ for (unsigned int i=1; i<rounds-1; i+=2)
+ {
+ block = _mm_aesenc_si128(block, subkeys[i]);
+ block = _mm_aesenc_si128(block, subkeys[i+1]);
+ }
+ block = _mm_aesenc_si128(block, subkeys[rounds-1]);
+ block = _mm_aesenclast_si128(block, subkeys[rounds]);
+}
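
[Editor's note] The loop above is unrolled two rounds at a time; rounds is always even (10, 12, or 14), so the rounds-1 AESENC steps split into pairs plus the single step before AESENCLAST. A straight-line equivalent, for reference:

inline void AESNI_Enc_Block_Simple(__m128i &block, const __m128i *subkeys, unsigned int rounds)
{
    block = _mm_xor_si128(block, subkeys[0]);
    for (unsigned int i = 1; i < rounds; i++)      // rounds-1 middle rounds
        block = _mm_aesenc_si128(block, subkeys[i]);
    block = _mm_aesenclast_si128(block, subkeys[rounds]);  // no MixColumns
}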
+
+inline void AESNI_Enc_4_Blocks(__m128i &block0, __m128i &block1, __m128i &block2, __m128i &block3, const __m128i *subkeys, unsigned int rounds)
+{
+ __m128i rk = subkeys[0];
+ block0 = _mm_xor_si128(block0, rk);
+ block1 = _mm_xor_si128(block1, rk);
+ block2 = _mm_xor_si128(block2, rk);
+ block3 = _mm_xor_si128(block3, rk);
+ for (unsigned int i=1; i<rounds; i++)
+ {
+ rk = subkeys[i];
+ block0 = _mm_aesenc_si128(block0, rk);
+ block1 = _mm_aesenc_si128(block1, rk);
+ block2 = _mm_aesenc_si128(block2, rk);
+ block3 = _mm_aesenc_si128(block3, rk);
+ }
+ rk = subkeys[rounds];
+ block0 = _mm_aesenclast_si128(block0, rk);
+ block1 = _mm_aesenclast_si128(block1, rk);
+ block2 = _mm_aesenclast_si128(block2, rk);
+ block3 = _mm_aesenclast_si128(block3, rk);
+}
+
+inline void AESNI_Dec_Block(__m128i &block, const __m128i *subkeys, unsigned int rounds)
+{
+ block = _mm_xor_si128(block, subkeys[0]);
+ for (unsigned int i=1; i<rounds-1; i+=2)
+ {
+ block = _mm_aesdec_si128(block, subkeys[i]);
+ block = _mm_aesdec_si128(block, subkeys[i+1]);
+ }
+ block = _mm_aesdec_si128(block, subkeys[rounds-1]);
+ block = _mm_aesdeclast_si128(block, subkeys[rounds]);
+}
+
+inline void AESNI_Dec_4_Blocks(__m128i &block0, __m128i &block1, __m128i &block2, __m128i &block3, const __m128i *subkeys, unsigned int rounds)
+{
+ __m128i rk = subkeys[0];
+ block0 = _mm_xor_si128(block0, rk);
+ block1 = _mm_xor_si128(block1, rk);
+ block2 = _mm_xor_si128(block2, rk);
+ block3 = _mm_xor_si128(block3, rk);
+ for (unsigned int i=1; i<rounds; i++)
+ {
+ rk = subkeys[i];
+ block0 = _mm_aesdec_si128(block0, rk);
+ block1 = _mm_aesdec_si128(block1, rk);
+ block2 = _mm_aesdec_si128(block2, rk);
+ block3 = _mm_aesdec_si128(block3, rk);
+ }
+ rk = subkeys[rounds];
+ block0 = _mm_aesdeclast_si128(block0, rk);
+ block1 = _mm_aesdeclast_si128(block1, rk);
+ block2 = _mm_aesdeclast_si128(block2, rk);
+ block3 = _mm_aesdeclast_si128(block3, rk);
+}
+
+static CRYPTOPP_ALIGN_DATA(16) const word32 s_one[] = {0, 0, 0, 1<<24};
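
[Editor's note] s_one is the per-block counter increment for the parallel path. Stored as little-endian dwords {0, 0, 0, 1<<24}, its in-memory bytes are all zero except byte 15, so _mm_add_epi32 bumps the low-order byte of a big-endian counter block. A carry out of that 32-bit lane is discarded, so the caller is responsible for splitting batches before byte 15 wraps; the single-block path below makes the same assumption with its plain inBlocks[15]++. A sketch (counterBlock is a hypothetical 16-byte buffer):

// Byte layout of s_one in memory:
//   00 00 00 00  00 00 00 00  00 00 00 00  00 00 00 01
__m128i ctr  = _mm_loadu_si128((const __m128i *)counterBlock);
__m128i next = _mm_add_epi32(ctr, *(const __m128i *)s_one);  // byte 15 += 1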
+
+template <typename F1, typename F4>
+inline size_t AESNI_AdvancedProcessBlocks(F1 func1, F4 func4, const __m128i *subkeys, unsigned int rounds, const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags)
+{
+ size_t blockSize = 16;
+ size_t inIncrement = (flags & (BlockTransformation::BT_InBlockIsCounter|BlockTransformation::BT_DontIncrementInOutPointers)) ? 0 : blockSize;
+ size_t xorIncrement = xorBlocks ? blockSize : 0;
+ size_t outIncrement = (flags & BlockTransformation::BT_DontIncrementInOutPointers) ? 0 : blockSize;
+
+ if (flags & BlockTransformation::BT_ReverseDirection)
+ {
+ assert(length % blockSize == 0);
+ inBlocks += length - blockSize;
+ xorBlocks += length - blockSize;
+ outBlocks += length - blockSize;
+ inIncrement = 0-inIncrement;
+ xorIncrement = 0-xorIncrement;
+ outIncrement = 0-outIncrement;
+ }
+
+ if (flags & BlockTransformation::BT_AllowParallel)
+ {
+ while (length >= 4*blockSize)
+ {
+ __m128i block0 = _mm_loadu_si128((const __m128i *)inBlocks), block1, block2, block3;
+ if (flags & BlockTransformation::BT_InBlockIsCounter)
+ {
+ const __m128i be1 = *(const __m128i *)s_one;
+ block1 = _mm_add_epi32(block0, be1);
+ block2 = _mm_add_epi32(block1, be1);
+ block3 = _mm_add_epi32(block2, be1);
+ _mm_storeu_si128((__m128i *)inBlocks, _mm_add_epi32(block3, be1));
+ }
+ else
+ {
+ inBlocks += inIncrement;
+ block1 = _mm_loadu_si128((const __m128i *)inBlocks);
+ inBlocks += inIncrement;
+ block2 = _mm_loadu_si128((const __m128i *)inBlocks);
+ inBlocks += inIncrement;
+ block3 = _mm_loadu_si128((const __m128i *)inBlocks);
+ inBlocks += inIncrement;
+ }
+
+ if (flags & BlockTransformation::BT_XorInput)
+ {
+ block0 = _mm_xor_si128(block0, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ block1 = _mm_xor_si128(block1, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ block2 = _mm_xor_si128(block2, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ block3 = _mm_xor_si128(block3, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ }
+
+ func4(block0, block1, block2, block3, subkeys, rounds);
+
+ if (xorBlocks && !(flags & BlockTransformation::BT_XorInput))
+ {
+ block0 = _mm_xor_si128(block0, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ block1 = _mm_xor_si128(block1, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ block2 = _mm_xor_si128(block2, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ block3 = _mm_xor_si128(block3, _mm_loadu_si128((const __m128i *)xorBlocks));
+ xorBlocks += xorIncrement;
+ }
+
+ _mm_storeu_si128((__m128i *)outBlocks, block0);
+ outBlocks += outIncrement;
+ _mm_storeu_si128((__m128i *)outBlocks, block1);
+ outBlocks += outIncrement;
+ _mm_storeu_si128((__m128i *)outBlocks, block2);
+ outBlocks += outIncrement;
+ _mm_storeu_si128((__m128i *)outBlocks, block3);
+ outBlocks += outIncrement;
+
+ length -= 4*blockSize;
+ }
+ }
+
+ while (length >= blockSize)
+ {
+ __m128i block = _mm_loadu_si128((const __m128i *)inBlocks);
+
+ if (flags & BlockTransformation::BT_XorInput)
+ block = _mm_xor_si128(block, _mm_loadu_si128((const __m128i *)xorBlocks));
+
+ if (flags & BlockTransformation::BT_InBlockIsCounter)
+ const_cast<byte *>(inBlocks)[15]++;
+
+ func1(block, subkeys, rounds);
+
+ if (xorBlocks && !(flags & BlockTransformation::BT_XorInput))
+ block = _mm_xor_si128(block, _mm_loadu_si128((const __m128i *)xorBlocks));
+
+ _mm_storeu_si128((__m128i *)outBlocks, block);
+
+ inBlocks += inIncrement;
+ outBlocks += outIncrement;
+ xorBlocks += xorIncrement;
+ length -= blockSize;
+ }
+
+ return length;
+}
+#endif
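
[Editor's note] AESNI_AdvancedProcessBlocks returns the unprocessed tail (length need not be a multiple of 16 unless BT_ReverseDirection is set, which asserts it). Callers normally reach it through a cipher mode, but the block-level entry point is public; a minimal sketch of a direct call with placeholder data:

#include <assert.h>
#include "aes.h"

void direct_blocks_demo()
{
    using namespace CryptoPP;
    byte key[16] = {0}, in[64] = {0}, out[64];
    AES::Encryption aes(key, sizeof(key));
    // No xor input, four independent blocks; with AES-NI this lands in the
    // 4-block kernel above. The return value is the leftover byte count.
    size_t leftover = aes.AdvancedProcessBlocks(in, NULL, out, sizeof(in),
        BlockTransformation::BT_AllowParallel);
    assert(leftover == 0);
}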
+
size_t Rijndael::Enc::AdvancedProcessBlocks(const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) const
{
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+ if (HasAESNI())
+ return AESNI_AdvancedProcessBlocks(AESNI_Enc_Block, AESNI_Enc_4_Blocks, (const __m128i *)m_key.begin(), m_rounds, inBlocks, xorBlocks, outBlocks, length, flags);
+#endif
+
#if CRYPTOPP_BOOL_SSE2_ASM_AVAILABLE || defined(CRYPTOPP_X64_MASM_AVAILABLE)
- if (length < BLOCKSIZE)
- return length;
-
if (HasSSE2())
{
+ if (length < BLOCKSIZE)
+ return length;
+
struct Locals
{
word32 subkeys[4*12], workspace[8];
@@ -966,15 +1234,27 @@ size_t Rijndael::Enc::AdvancedProcessBlocks(const byte *inBlocks, const byte *xo
locals.keysBegin = (12-keysToCopy)*16;
Rijndael_Enc_AdvancedProcessBlocks(&locals, m_key);
- return length%16;
+ return length % BLOCKSIZE;
}
- else
#endif
- return BlockTransformation::AdvancedProcessBlocks(inBlocks, xorBlocks, outBlocks, length, flags);
+
+ return BlockTransformation::AdvancedProcessBlocks(inBlocks, xorBlocks, outBlocks, length, flags);
}
#endif
+#if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+
+size_t Rijndael::Dec::AdvancedProcessBlocks(const byte *inBlocks, const byte *xorBlocks, byte *outBlocks, size_t length, word32 flags) const
+{
+ if (HasAESNI())
+ return AESNI_AdvancedProcessBlocks(AESNI_Dec_Block, AESNI_Dec_4_Blocks, (const __m128i *)m_key.begin(), m_rounds, inBlocks, xorBlocks, outBlocks, length, flags);
+
+ return BlockTransformation::AdvancedProcessBlocks(inBlocks, xorBlocks, outBlocks, length, flags);
+}
+
+#endif // #if CRYPTOPP_BOOL_AESNI_INTRINSICS_AVAILABLE
+
NAMESPACE_END
#endif
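
[Editor's note] The new counter handling (BT_InBlockIsCounter plus s_one) is what CTR mode exercises end to end. A usage sketch through the public API, with placeholder key and IV:

#include "aes.h"
#include "modes.h"

void ctr_demo()
{
    using namespace CryptoPP;
    byte key[16] = {0}, iv[16] = {0};       // demo values only
    byte in[256] = {0}, out[256];
    CTR_Mode<AES>::Encryption enc;
    enc.SetKeyWithIV(key, sizeof(key), iv);
    enc.ProcessData(out, in, sizeof(in));   // drives the counter path above
}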