From 2bca8928ee8333a9e8823f1776c92285ade384bd Mon Sep 17 00:00:00 2001 From: weidai Date: Sun, 15 Apr 2007 23:00:27 +0000 Subject: MMX/SSE2 optimizations git-svn-id: svn://svn.code.sf.net/p/cryptopp/code/trunk/c5@287 57ff6487-cd31-0410-9ec3-f628ee90f5f0 --- rijndael.cpp | 576 ++++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 369 insertions(+), 207 deletions(-) (limited to 'rijndael.cpp') diff --git a/rijndael.cpp b/rijndael.cpp index 2a1a19e..4a8572f 100644 --- a/rijndael.cpp +++ b/rijndael.cpp @@ -51,10 +51,7 @@ being unloaded from L1 cache, until that round is finished. #include "rijndael.h" #include "misc.h" - -#ifdef CRYPTOPP_L1_CACHE_ALIGN_NOT_AVAILABLE -#pragma message("Don't know how to align data on L1 cache boundary. Defense against AES timing attack may be affected.") -#endif +#include "cpu.h" NAMESPACE_BEGIN(CryptoPP) @@ -122,25 +119,25 @@ void Rijndael::Base::UncheckedSetKey(const byte *userKey, unsigned int keylen, c for (i = 1; i < m_rounds; i++) { rk += 4; rk[0] = - Td0[Se[GETBYTE(rk[0], 3)]] ^ - Td1[Se[GETBYTE(rk[0], 2)]] ^ - Td2[Se[GETBYTE(rk[0], 1)]] ^ - Td3[Se[GETBYTE(rk[0], 0)]]; + Td[0*256+Se[GETBYTE(rk[0], 3)]] ^ + Td[1*256+Se[GETBYTE(rk[0], 2)]] ^ + Td[2*256+Se[GETBYTE(rk[0], 1)]] ^ + Td[3*256+Se[GETBYTE(rk[0], 0)]]; rk[1] = - Td0[Se[GETBYTE(rk[1], 3)]] ^ - Td1[Se[GETBYTE(rk[1], 2)]] ^ - Td2[Se[GETBYTE(rk[1], 1)]] ^ - Td3[Se[GETBYTE(rk[1], 0)]]; + Td[0*256+Se[GETBYTE(rk[1], 3)]] ^ + Td[1*256+Se[GETBYTE(rk[1], 2)]] ^ + Td[2*256+Se[GETBYTE(rk[1], 1)]] ^ + Td[3*256+Se[GETBYTE(rk[1], 0)]]; rk[2] = - Td0[Se[GETBYTE(rk[2], 3)]] ^ - Td1[Se[GETBYTE(rk[2], 2)]] ^ - Td2[Se[GETBYTE(rk[2], 1)]] ^ - Td3[Se[GETBYTE(rk[2], 0)]]; + Td[0*256+Se[GETBYTE(rk[2], 3)]] ^ + Td[1*256+Se[GETBYTE(rk[2], 2)]] ^ + Td[2*256+Se[GETBYTE(rk[2], 1)]] ^ + Td[3*256+Se[GETBYTE(rk[2], 0)]]; rk[3] = - Td0[Se[GETBYTE(rk[3], 3)]] ^ - Td1[Se[GETBYTE(rk[3], 2)]] ^ - Td2[Se[GETBYTE(rk[3], 1)]] ^ - Td3[Se[GETBYTE(rk[3], 0)]]; + Td[0*256+Se[GETBYTE(rk[3], 3)]] ^ + Td[1*256+Se[GETBYTE(rk[3], 2)]] ^ + Td[2*256+Se[GETBYTE(rk[3], 1)]] ^ + Td[3*256+Se[GETBYTE(rk[3], 0)]]; } } @@ -148,15 +145,245 @@ void Rijndael::Base::UncheckedSetKey(const byte *userKey, unsigned int keylen, c ConditionalByteReverse(BIG_ENDIAN_ORDER, m_key + m_rounds*4, m_key + m_rounds*4, 16); } -const static unsigned int s_lineSizeDiv4 = CRYPTOPP_L1_CACHE_LINE_SIZE/4; -#ifdef IS_BIG_ENDIAN -const static unsigned int s_i3=3, s_i2=2, s_i1=1, s_i0=0; -#else -const static unsigned int s_i3=0, s_i2=1, s_i1=2, s_i0=3; -#endif +#pragma warning(disable: 4731) // frame pointer register 'ebp' modified by inline assembly code void Rijndael::Enc::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock, byte *outBlock) const { +#ifdef CRYPTOPP_X86_ASM_AVAILABLE + if (HasMMX()) + { + const word32 *k = m_key; + const word32 *kLoopEnd = k + m_rounds*4; +#ifdef __GNUC__ + word32 t0, t1, t2, t3; + __asm__ __volatile__ + ( + ".intel_syntax noprefix;" + AS1( push ebx) + AS1( push ebp) + AS2( mov ebp, eax) + AS2( movd mm5, ecx) +#else + AS2( mov edx, g_cacheLineSize) + AS2( mov edi, inBlock) + AS2( mov esi, k) + AS2( movd mm5, kLoopEnd) + AS1( push ebp) + AS2( lea ebp, Te) +#endif + AS2( mov eax, [esi+0*4]) // s0 + AS2( xor eax, [edi+0*4]) + AS2( movd mm0, eax) + AS2( mov ebx, [esi+1*4]) + AS2( xor ebx, [edi+1*4]) + AS2( movd mm1, ebx) + AS2( and ebx, eax) + AS2( mov eax, [esi+2*4]) + AS2( xor eax, [edi+2*4]) + AS2( movd mm2, eax) + AS2( and ebx, eax) + AS2( mov ecx, [esi+3*4]) + AS2( xor ecx, [edi+3*4]) + AS2( and ebx, ecx) + + // read 
Te0 into L1 cache. this code could be simplifed by using lfence, but that is an SSE2 instruction + AS2( and ebx, 0) + AS2( mov edi, ebx) // make index depend on previous loads to simulate lfence + ASL(2) + AS2( and ebx, [ebp+edi]) + AS2( add edi, edx) + AS2( and ebx, [ebp+edi]) + AS2( add edi, edx) + AS2( and ebx, [ebp+edi]) + AS2( add edi, edx) + AS2( and ebx, [ebp+edi]) + AS2( add edi, edx) + AS2( cmp edi, 1024) + ASJ( jl, 2, b) + AS2( and ebx, [ebp+1020]) + AS2( movd mm6, ebx) + AS2( pxor mm2, mm6) + AS2( pxor mm1, mm6) + AS2( pxor mm0, mm6) + AS2( xor ecx, ebx) + + AS2( mov edi, [esi+4*4]) // t0 + AS2( mov eax, [esi+5*4]) + AS2( mov ebx, [esi+6*4]) + AS2( mov edx, [esi+7*4]) + AS2( add esi, 8*4) + AS2( movd mm4, esi) + +#define QUARTER_ROUND(t, a, b, c, d) \ + AS2(movzx esi, t##l)\ + AS2(d, [ebp+0*1024+4*esi])\ + AS2(movzx esi, t##h)\ + AS2(c, [ebp+1*1024+4*esi])\ + AS2(shr e##t##x, 16)\ + AS2(movzx esi, t##l)\ + AS2(b, [ebp+2*1024+4*esi])\ + AS2(movzx esi, t##h)\ + AS2(a, [ebp+3*1024+4*esi]) + +#define s0 xor edi +#define s1 xor eax +#define s2 xor ebx +#define s3 xor ecx +#define t0 xor edi +#define t1 xor eax +#define t2 xor ebx +#define t3 xor edx + + QUARTER_ROUND(c, t0, t1, t2, t3) + AS2( movd ecx, mm2) + QUARTER_ROUND(c, t3, t0, t1, t2) + AS2( movd ecx, mm1) + QUARTER_ROUND(c, t2, t3, t0, t1) + AS2( movd ecx, mm0) + QUARTER_ROUND(c, t1, t2, t3, t0) + AS2( movd mm2, ebx) + AS2( movd mm1, eax) + AS2( movd mm0, edi) +#undef QUARTER_ROUND + + AS2( movd esi, mm4) + + ASL(0) + AS2( mov edi, [esi+0*4]) + AS2( mov eax, [esi+1*4]) + AS2( mov ebx, [esi+2*4]) + AS2( mov ecx, [esi+3*4]) + +#define QUARTER_ROUND(t, a, b, c, d) \ + AS2(movzx esi, t##l)\ + AS2(a, [ebp+3*1024+4*esi])\ + AS2(movzx esi, t##h)\ + AS2(b, [ebp+2*1024+4*esi])\ + AS2(shr e##t##x, 16)\ + AS2(movzx esi, t##l)\ + AS2(c, [ebp+1*1024+4*esi])\ + AS2(movzx esi, t##h)\ + AS2(d, [ebp+0*1024+4*esi]) + + QUARTER_ROUND(d, s0, s1, s2, s3) + AS2( movd edx, mm2) + QUARTER_ROUND(d, s3, s0, s1, s2) + AS2( movd edx, mm1) + QUARTER_ROUND(d, s2, s3, s0, s1) + AS2( movd edx, mm0) + QUARTER_ROUND(d, s1, s2, s3, s0) + AS2( movd esi, mm4) + AS2( movd mm2, ebx) + AS2( movd mm1, eax) + AS2( movd mm0, edi) + + AS2( mov edi, [esi+4*4]) + AS2( mov eax, [esi+5*4]) + AS2( mov ebx, [esi+6*4]) + AS2( mov edx, [esi+7*4]) + + QUARTER_ROUND(c, t0, t1, t2, t3) + AS2( movd ecx, mm2) + QUARTER_ROUND(c, t3, t0, t1, t2) + AS2( movd ecx, mm1) + QUARTER_ROUND(c, t2, t3, t0, t1) + AS2( movd ecx, mm0) + QUARTER_ROUND(c, t1, t2, t3, t0) + AS2( movd mm2, ebx) + AS2( movd mm1, eax) + AS2( movd mm0, edi) + + AS2( movd esi, mm4) + AS2( movd edi, mm5) + AS2( add esi, 8*4) + AS2( movd mm4, esi) + AS2( cmp edi, esi) + ASJ( jne, 0, b) + +#undef QUARTER_ROUND +#undef s0 +#undef s1 +#undef s2 +#undef s3 +#undef t0 +#undef t1 +#undef t2 +#undef t3 + + AS2( mov eax, [edi+0*4]) + AS2( mov ecx, [edi+1*4]) + AS2( mov esi, [edi+2*4]) + AS2( mov edi, [edi+3*4]) + +#define QUARTER_ROUND(a, b, c, d) \ + AS2( movzx ebx, dl)\ + AS2( movzx ebx, BYTE PTR [ebp+1+4*ebx])\ + AS2( shl ebx, 3*8)\ + AS2( xor a, ebx)\ + AS2( movzx ebx, dh)\ + AS2( movzx ebx, BYTE PTR [ebp+1+4*ebx])\ + AS2( shl ebx, 2*8)\ + AS2( xor b, ebx)\ + AS2( shr edx, 16)\ + AS2( movzx ebx, dl)\ + AS2( shr edx, 8)\ + AS2( movzx ebx, BYTE PTR [ebp+1+4*ebx])\ + AS2( shl ebx, 1*8)\ + AS2( xor c, ebx)\ + AS2( movzx ebx, BYTE PTR [ebp+1+4*edx])\ + AS2( xor d, ebx) + + QUARTER_ROUND(eax, ecx, esi, edi) + AS2( movd edx, mm2) + QUARTER_ROUND(edi, eax, ecx, esi) + AS2( movd edx, mm1) + QUARTER_ROUND(esi, edi, eax, ecx) + AS2( 
movd edx, mm0) + QUARTER_ROUND(ecx, esi, edi, eax) + +#undef QUARTER_ROUND + + AS1( pop ebp) + AS1( emms) + +#ifdef __GNUC__ + AS1( pop ebx) + ".att_syntax prefix;" + : "=a" (t0), "=c" (t1), "=S" (t2), "=D" (t3) + : "a" (Te), "D" (inBlock), "S" (k), "c" (kLoopEnd), "d" (g_cacheLineSize) + : "memory", "cc" + ); + + if (xorBlock) + { + t0 ^= ((const word32 *)xorBlock)[0]; + t1 ^= ((const word32 *)xorBlock)[1]; + t2 ^= ((const word32 *)xorBlock)[2]; + t3 ^= ((const word32 *)xorBlock)[3]; + } + ((word32 *)outBlock)[0] = t0; + ((word32 *)outBlock)[1] = t1; + ((word32 *)outBlock)[2] = t2; + ((word32 *)outBlock)[3] = t3; +#else + AS2( mov ebx, xorBlock) + AS2( test ebx, ebx) + ASJ( jz, 1, f) + AS2( xor eax, [ebx+0*4]) + AS2( xor ecx, [ebx+1*4]) + AS2( xor esi, [ebx+2*4]) + AS2( xor edi, [ebx+3*4]) + ASL(1) + AS2( mov ebx, outBlock) + AS2( mov [ebx+0*4], eax) + AS2( mov [ebx+1*4], ecx) + AS2( mov [ebx+2*4], esi) + AS2( mov [ebx+3*4], edi) +#endif + } + else +#endif // #ifdef CRYPTOPP_X86_ASM_AVAILABLE + { word32 s0, s1, s2, s3, t0, t1, t2, t3; const word32 *rk = m_key; @@ -171,95 +398,68 @@ void Rijndael::Enc::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock rk += 8; // timing attack countermeasure. see comments at top for more details + const int cacheLineSize = GetCacheLineSize(); unsigned int i; word32 u = 0; - for (i=0; i>= 8;\ + b ^= rotrFixed(Te[byte(t)], 16); t >>= 8;\ + c ^= rotrFixed(Te[byte(t)], 8); t >>= 8;\ + d ^= Te[t]; +#else +#define QUARTER_ROUND(t, a, b, c, d) \ + d ^= Te[byte(t)]; t >>= 8;\ + c ^= rotrFixed(Te[byte(t)], 8); t >>= 8;\ + b ^= rotrFixed(Te[byte(t)], 16); t >>= 8;\ + a ^= rotrFixed(Te[t], 24); +#endif + + QUARTER_ROUND(s3, t0, t1, t2, t3) + QUARTER_ROUND(s2, t3, t0, t1, t2) + QUARTER_ROUND(s1, t2, t3, t0, t1) + QUARTER_ROUND(s0, t1, t2, t3, t0) +#undef QUARTER_ROUND // Nr - 2 full rounds: unsigned int r = m_rounds/2 - 1; do { - s0 = - Te0[GETBYTE(t0, 3)] ^ - Te1[GETBYTE(t1, 2)] ^ - Te2[GETBYTE(t2, 1)] ^ - Te3[GETBYTE(t3, 0)] ^ - rk[0]; - s1 = - Te0[GETBYTE(t1, 3)] ^ - Te1[GETBYTE(t2, 2)] ^ - Te2[GETBYTE(t3, 1)] ^ - Te3[GETBYTE(t0, 0)] ^ - rk[1]; - s2 = - Te0[GETBYTE(t2, 3)] ^ - Te1[GETBYTE(t3, 2)] ^ - Te2[GETBYTE(t0, 1)] ^ - Te3[GETBYTE(t1, 0)] ^ - rk[2]; - s3 = - Te0[GETBYTE(t3, 3)] ^ - Te1[GETBYTE(t0, 2)] ^ - Te2[GETBYTE(t1, 1)] ^ - Te3[GETBYTE(t2, 0)] ^ - rk[3]; - - t0 = - Te0[GETBYTE(s0, 3)] ^ - Te1[GETBYTE(s1, 2)] ^ - Te2[GETBYTE(s2, 1)] ^ - Te3[GETBYTE(s3, 0)] ^ - rk[4]; - t1 = - Te0[GETBYTE(s1, 3)] ^ - Te1[GETBYTE(s2, 2)] ^ - Te2[GETBYTE(s3, 1)] ^ - Te3[GETBYTE(s0, 0)] ^ - rk[5]; - t2 = - Te0[GETBYTE(s2, 3)] ^ - Te1[GETBYTE(s3, 2)] ^ - Te2[GETBYTE(s0, 1)] ^ - Te3[GETBYTE(s1, 0)] ^ - rk[6]; - t3 = - Te0[GETBYTE(s3, 3)] ^ - Te1[GETBYTE(s0, 2)] ^ - Te2[GETBYTE(s1, 1)] ^ - Te3[GETBYTE(s2, 0)] ^ - rk[7]; +#define QUARTER_ROUND(t, a, b, c, d) \ + a ^= Te[3*256+byte(t)]; t >>= 8;\ + b ^= Te[2*256+byte(t)]; t >>= 8;\ + c ^= Te[1*256+byte(t)]; t >>= 8;\ + d ^= Te[t]; + + s0 = rk[0]; s1 = rk[1]; s2 = rk[2]; s3 = rk[3]; + + QUARTER_ROUND(t3, s0, s1, s2, s3) + QUARTER_ROUND(t2, s3, s0, s1, s2) + QUARTER_ROUND(t1, s2, s3, s0, s1) + QUARTER_ROUND(t0, s1, s2, s3, s0) + + t0 = rk[4]; t1 = rk[5]; t2 = rk[6]; t3 = rk[7]; + + QUARTER_ROUND(s3, t0, t1, t2, t3) + QUARTER_ROUND(s2, t3, t0, t1, t2) + QUARTER_ROUND(s1, t2, t3, t0, t1) + QUARTER_ROUND(s0, t1, t2, t3, t0) +#undef QUARTER_ROUND rk += 8; } while (--r); // timing attack countermeasure. 
see comments at top for more details u = 0; - for (i=0; i>= 8;\ + tempBlock[b] = Se[byte(t)]; t >>= 8;\ + tempBlock[c] = Se[byte(t)]; t >>= 8;\ + tempBlock[d] = Se[t]; + + QUARTER_ROUND(t2, 15, 2, 5, 8) + QUARTER_ROUND(t1, 11, 14, 1, 4) + QUARTER_ROUND(t0, 7, 10, 13, 0) + QUARTER_ROUND(t3, 3, 6, 9, 12) +#undef QUARTER_ROUND if (xbw) { @@ -299,12 +493,13 @@ void Rijndael::Enc::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock obw[2] = tbw[2] ^ rk[2]; obw[3] = tbw[3] ^ rk[3]; } + } } void Rijndael::Dec::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock, byte *outBlock) const { word32 s0, s1, s2, s3, t0, t1, t2, t3; - const word32 *rk = m_key; + const word32 *rk = m_key; s0 = ((const word32 *)inBlock)[0] ^ rk[0]; s1 = ((const word32 *)inBlock)[1] ^ rk[1]; @@ -317,95 +512,68 @@ void Rijndael::Dec::ProcessAndXorBlock(const byte *inBlock, const byte *xorBlock rk += 8; // timing attack countermeasure. see comments at top for more details + const int cacheLineSize = GetCacheLineSize(); unsigned int i; word32 u = 0; - for (i=0; i>= 8;\ + b ^= rotrFixed(Td[byte(t)], 16); t >>= 8;\ + c ^= rotrFixed(Td[byte(t)], 8); t >>= 8;\ + d ^= Td[t]; +#else +#define QUARTER_ROUND(t, a, b, c, d) \ + d ^= Td[byte(t)]; t >>= 8;\ + c ^= rotrFixed(Td[byte(t)], 8); t >>= 8;\ + b ^= rotrFixed(Td[byte(t)], 16); t >>= 8;\ + a ^= rotrFixed(Td[t], 24); +#endif + + QUARTER_ROUND(s3, t2, t1, t0, t3) + QUARTER_ROUND(s2, t1, t0, t3, t2) + QUARTER_ROUND(s1, t0, t3, t2, t1) + QUARTER_ROUND(s0, t3, t2, t1, t0) +#undef QUARTER_ROUND // Nr - 2 full rounds: unsigned int r = m_rounds/2 - 1; do { - s0 = - Td0[GETBYTE(t0, 3)] ^ - Td1[GETBYTE(t3, 2)] ^ - Td2[GETBYTE(t2, 1)] ^ - Td3[GETBYTE(t1, 0)] ^ - rk[0]; - s1 = - Td0[GETBYTE(t1, 3)] ^ - Td1[GETBYTE(t0, 2)] ^ - Td2[GETBYTE(t3, 1)] ^ - Td3[GETBYTE(t2, 0)] ^ - rk[1]; - s2 = - Td0[GETBYTE(t2, 3)] ^ - Td1[GETBYTE(t1, 2)] ^ - Td2[GETBYTE(t0, 1)] ^ - Td3[GETBYTE(t3, 0)] ^ - rk[2]; - s3 = - Td0[GETBYTE(t3, 3)] ^ - Td1[GETBYTE(t2, 2)] ^ - Td2[GETBYTE(t1, 1)] ^ - Td3[GETBYTE(t0, 0)] ^ - rk[3]; - - t0 = - Td0[GETBYTE(s0, 3)] ^ - Td1[GETBYTE(s3, 2)] ^ - Td2[GETBYTE(s2, 1)] ^ - Td3[GETBYTE(s1, 0)] ^ - rk[4]; - t1 = - Td0[GETBYTE(s1, 3)] ^ - Td1[GETBYTE(s0, 2)] ^ - Td2[GETBYTE(s3, 1)] ^ - Td3[GETBYTE(s2, 0)] ^ - rk[5]; - t2 = - Td0[GETBYTE(s2, 3)] ^ - Td1[GETBYTE(s1, 2)] ^ - Td2[GETBYTE(s0, 1)] ^ - Td3[GETBYTE(s3, 0)] ^ - rk[6]; - t3 = - Td0[GETBYTE(s3, 3)] ^ - Td1[GETBYTE(s2, 2)] ^ - Td2[GETBYTE(s1, 1)] ^ - Td3[GETBYTE(s0, 0)] ^ - rk[7]; +#define QUARTER_ROUND(t, a, b, c, d) \ + a ^= Td[3*256+byte(t)]; t >>= 8;\ + b ^= Td[2*256+byte(t)]; t >>= 8;\ + c ^= Td[1*256+byte(t)]; t >>= 8;\ + d ^= Td[t]; + + s0 = rk[0]; s1 = rk[1]; s2 = rk[2]; s3 = rk[3]; + + QUARTER_ROUND(t3, s2, s1, s0, s3) + QUARTER_ROUND(t2, s1, s0, s3, s2) + QUARTER_ROUND(t1, s0, s3, s2, s1) + QUARTER_ROUND(t0, s3, s2, s1, s0) + + t0 = rk[4]; t1 = rk[5]; t2 = rk[6]; t3 = rk[7]; + + QUARTER_ROUND(s3, t2, t1, t0, t3) + QUARTER_ROUND(s2, t1, t0, t3, t2) + QUARTER_ROUND(s1, t0, t3, t2, t1) + QUARTER_ROUND(s0, t3, t2, t1, t0) +#undef QUARTER_ROUND rk += 8; } while (--r); // timing attack countermeasure. see comments at top for more details u = 0; - for (i=0; i>= 8;\ + tempBlock[b] = Sd[byte(t)]; t >>= 8;\ + tempBlock[c] = Sd[byte(t)]; t >>= 8;\ + tempBlock[d] = Sd[t]; + + QUARTER_ROUND(t2, 7, 2, 13, 8) + QUARTER_ROUND(t1, 3, 14, 9, 4) + QUARTER_ROUND(t0, 15, 10, 5, 0) + QUARTER_ROUND(t3, 11, 6, 1, 12) +#undef QUARTER_ROUND if (xbw) { -- cgit v1.2.1
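
The timing-attack countermeasure above (in both the MMX path and the C fallback) works by touching one word in every cache line of the encryption table before any key-dependent index is formed, so the later lookups hit L1 and their latency does not depend on secret data. Below is a minimal C++ sketch of that preload step, not the patch's own code: "table" and "lineSize" stand in for the patch's fused Te table and g_cacheLineSize/GetCacheLineSize(), and the loop mirrors the assembly sequence (and ebx,0 / and ebx,[ebp+edi] / add edi,edx / cmp edi,1024 / and ebx,[ebp+1020]).

#include <cstdint>

typedef std::uint32_t word32;
typedef unsigned char byte;

// Read one word per cache line of the first 1 KB of the table.  The
// accumulator starts at zero and is only ANDed with table words, so it is
// always zero; folding it back into the cipher state therefore changes
// nothing, but it gives every load a consumer so the reads really happen
// and the touched lines stay in L1 for the key-dependent lookups.
inline word32 PreloadTableLines(const word32 *table, int lineSize)
{
    word32 u = 0;                                         // 0 & x == 0
    for (int i = 0; i < 1024; i += lineSize)              // one load per line
        u &= *(const word32 *)((const byte *)table + i);
    u &= table[255];                                      // same word as [ebp+1020]
    return u;                                             // always zero
}

// Caller, mirroring the pxor mm0..mm2 / xor ecx,ebx step in the assembly
// (s0..s3, Te and GetCacheLineSize are the patch's names):
//   word32 u = PreloadTableLines(Te, GetCacheLineSize());
//   s0 |= u; s1 |= u; s2 |= u; s3 |= u;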
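
The C fallback also replaces the four separate Te0..Te3 lookup tables with a single fused table indexed as Te[k*256 + x], and each column update is expressed through the QUARTER_ROUND macro. For readability, here is one expansion of the inner-round macro written as a plain function; it assumes the same fused 4x256-word table and the little-endian byte extraction used in the patch.

#include <cstdint>

typedef std::uint32_t word32;

// One inner-round quarter step: split t into bytes from low to high and
// XOR the corresponding sub-table entries into four different state words.
inline void QuarterRound(const word32 *Te, word32 t,
                         word32 &a, word32 &b, word32 &c, word32 &d)
{
    a ^= Te[3*256 + (t & 0xff)]; t >>= 8;   // low byte -> sub-table 3
    b ^= Te[2*256 + (t & 0xff)]; t >>= 8;
    c ^= Te[1*256 + (t & 0xff)]; t >>= 8;
    d ^= Te[t];                             // top byte -> sub-table 0
}

// One full round of the new inner loop is four such calls with the output
// words rotated, exactly as in the patch's s -> t half:
//   t0 = rk[4]; t1 = rk[5]; t2 = rk[6]; t3 = rk[7];
//   QuarterRound(Te, s3, t0, t1, t2, t3);
//   QuarterRound(Te, s2, t3, t0, t1, t2);
//   QuarterRound(Te, s1, t2, t3, t0, t1);
//   QuarterRound(Te, s0, t1, t2, t3, t0);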
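
The last round has no column mixing, so the patch finishes with plain S-box lookups written into a 16-byte scratch block at ShiftRows-permuted positions, then XORs in the last round key (and the xorBlock, when one is given). The sketch below follows the tempBlock form of QUARTER_ROUND used above; the assignment for offset a is an assumption completed by symmetry with the b, c, d lines and the call pattern, and Se is the patch's forward S-box byte table.

#include <cstdint>

typedef std::uint32_t word32;
typedef unsigned char byte;

// Final-round quarter step: each byte of t goes through the S-box and lands
// at a ShiftRows-permuted byte offset of the scratch block.  a, b, c, d are
// the offsets passed at the call sites, e.g. (15, 2, 5, 8) for t2 in the
// encryptor.
inline void FinalQuarterRound(const byte *Se, word32 t, byte *tempBlock,
                              int a, int b, int c, int d)
{
    tempBlock[a] = Se[t & 0xff]; t >>= 8;   // assumed by symmetry with b/c/d
    tempBlock[b] = Se[t & 0xff]; t >>= 8;
    tempBlock[c] = Se[t & 0xff]; t >>= 8;
    tempBlock[d] = Se[t];
}

// The scratch words are then combined with the last round key, as in the
// patch:  obw[0] = tbw[0] ^ rk[0];  ...  obw[3] = tbw[3] ^ rk[3];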