diff options
author | (no author) <(no author)@57ff6487-cd31-0410-9ec3-f628ee90f5f0> | 2003-02-28 21:28:28 +0000 |
---|---|---|
committer | (no author) <(no author)@57ff6487-cd31-0410-9ec3-f628ee90f5f0> | 2003-02-28 21:28:28 +0000 |
commit | 85b77f5560b435cc0fe85eea5831b11f30e8f3e9 (patch) | |
tree | 320f097c8c47a9b9903ba443495d0d24d87d7737 | |
parent | 1f2afe4d61ac1d34a4b245ca8e7e9ad06697774d (diff) | |
download | cryptopp-85b77f5560b435cc0fe85eea5831b11f30e8f3e9.tar.gz |
This commit was manufactured by cvs2svn to create branch 'c50-fixes'.
git-svn-id: svn://svn.code.sf.net/p/cryptopp/code/branches/c50-fixes/c5@34 57ff6487-cd31-0410-9ec3-f628ee90f5f0
-rw-r--r-- | files.cpp | 186 | ||||
-rw-r--r-- | files.h | 95 | ||||
-rw-r--r-- | filters.cpp | 890 | ||||
-rw-r--r-- | filters.h | 681 | ||||
-rw-r--r-- | fltrimpl.h | 42 | ||||
-rw-r--r-- | integer.cpp | 3987 | ||||
-rw-r--r-- | modes.cpp | 266 | ||||
-rw-r--r-- | modes.h | 370 |
8 files changed, 6517 insertions, 0 deletions
diff --git a/files.cpp b/files.cpp new file mode 100644 index 0000000..2b42010 --- /dev/null +++ b/files.cpp @@ -0,0 +1,186 @@ +// files.cpp - written and placed in the public domain by Wei Dai + +#include "pch.h" +#include "files.h" + +NAMESPACE_BEGIN(CryptoPP) + +using namespace std; + +void Files_TestInstantiations() +{ + FileStore f0; + FileSource f1; + FileSink f2; +} + +void FileStore::StoreInitialize(const NameValuePairs ¶meters) +{ + const char *fileName; + if (parameters.GetValue("InputFileName", fileName)) + { + ios::openmode binary = parameters.GetValueWithDefault("InputBinaryMode", true) ? ios::binary : ios::openmode(0); + m_file.open(fileName, ios::in | binary); + if (!m_file) + throw OpenErr(fileName); + m_stream = &m_file; + } + else + { + m_stream = NULL; + parameters.GetValue("InputStreamPointer", m_stream); + } + m_waiting = false; +} + +unsigned long FileStore::MaxRetrievable() const +{ + if (!m_stream) + return 0; + + streampos current = m_stream->tellg(); + streampos end = m_stream->seekg(0, ios::end).tellg(); + m_stream->seekg(current); + return end-current; +} + +unsigned int FileStore::TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel, bool blocking) +{ + if (!m_stream) + { + transferBytes = 0; + return 0; + } + + unsigned long size=transferBytes; + transferBytes = 0; + + if (m_waiting) + goto output; + + while (size && m_stream->good()) + { + { + unsigned int spaceSize = 1024; + m_space = HelpCreatePutSpace(target, channel, 1, (unsigned int)STDMIN(size, (unsigned long)UINT_MAX), spaceSize); + + m_stream->read((char *)m_space, STDMIN(size, (unsigned long)spaceSize)); + } + m_len = m_stream->gcount(); + unsigned int blockedBytes; +output: + blockedBytes = target.ChannelPutModifiable2(channel, m_space, m_len, 0, blocking); + m_waiting = blockedBytes > 0; + if (m_waiting) + return blockedBytes; + size -= m_len; + transferBytes += m_len; + } + + if (!m_stream->good() && !m_stream->eof()) + throw 
ReadErr(); + + return 0; +} + +unsigned int FileStore::CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end, const std::string &channel, bool blocking) const +{ + if (!m_stream) + return 0; + + if (begin == 0 && end == 1) + { + int result = m_stream->peek(); + if (result == EOF) // GCC workaround: 2.95.2 doesn't have char_traits<char>::eof() + return 0; + else + { + unsigned int blockedBytes = target.ChannelPut(channel, byte(result), blocking); + begin += 1-blockedBytes; + return blockedBytes; + } + } + + // TODO: figure out what happens on cin + streampos current = m_stream->tellg(); + streampos endPosition = m_stream->seekg(0, ios::end).tellg(); + streampos newPosition = current + (streamoff)begin; + + if (newPosition >= endPosition) + { + m_stream->seekg(current); + return 0; // don't try to seek beyond the end of file + } + m_stream->seekg(newPosition); + unsigned long total = 0; + try + { + assert(!m_waiting); + unsigned long copyMax = end-begin; + unsigned int blockedBytes = const_cast<FileStore *>(this)->TransferTo2(target, copyMax, channel, blocking); + begin += copyMax; + if (blockedBytes) + { + const_cast<FileStore *>(this)->m_waiting = false; + return blockedBytes; + } + } + catch(...) + { + m_stream->clear(); + m_stream->seekg(current); + throw; + } + m_stream->clear(); + m_stream->seekg(current); + + return 0; +} + +void FileSink::IsolatedInitialize(const NameValuePairs ¶meters) +{ + const char *fileName; + if (parameters.GetValue("OutputFileName", fileName)) + { + ios::openmode binary = parameters.GetValueWithDefault("OutputBinaryMode", true) ? 
ios::binary : ios::openmode(0); + m_file.open(fileName, ios::out | ios::trunc | binary); + if (!m_file) + throw OpenErr(fileName); + m_stream = &m_file; + } + else + { + m_stream = NULL; + parameters.GetValue("OutputStreamPointer", m_stream); + } +} + +bool FileSink::IsolatedFlush(bool hardFlush, bool blocking) +{ + if (!m_stream) + throw Err("FileSink: output stream not opened"); + + m_stream->flush(); + if (!m_stream->good()) + throw WriteErr(); + + return false; +} + +unsigned int FileSink::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking) +{ + if (!m_stream) + throw Err("FileSink: output stream not opened"); + + m_stream->write((const char *)inString, length); + + if (messageEnd) + m_stream->flush(); + + if (!m_stream->good()) + throw WriteErr(); + + return 0; +} + +NAMESPACE_END @@ -0,0 +1,95 @@ +#ifndef CRYPTOPP_FILES_H +#define CRYPTOPP_FILES_H + +#include "cryptlib.h" +#include "filters.h" + +#include <iostream> +#include <fstream> + +NAMESPACE_BEGIN(CryptoPP) + +//! . 
+class FileStore : public Store, private FilterPutSpaceHelper +{ +public: + class Err : public Exception + { + public: + Err(const std::string &s) : Exception(IO_ERROR, s) {} + }; + class OpenErr : public Err {public: OpenErr(const std::string &filename) : Err("FileStore: error opening file for reading: " + filename) {}}; + class ReadErr : public Err {public: ReadErr() : Err("FileStore: error reading file") {}}; + + FileStore() : m_stream(NULL) {} + FileStore(std::istream &in) + {StoreInitialize(MakeParameters("InputStreamPointer", &in));} + FileStore(const char *filename) + {StoreInitialize(MakeParameters("InputFileName", filename));} + + std::istream* GetStream() {return m_stream;} + + unsigned long MaxRetrievable() const; + unsigned int TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel=NULL_CHANNEL, bool blocking=true); + unsigned int CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end=ULONG_MAX, const std::string &channel=NULL_CHANNEL, bool blocking=true) const; + +private: + void StoreInitialize(const NameValuePairs ¶meters); + + std::ifstream m_file; + std::istream *m_stream; + byte *m_space; + unsigned int m_len; + bool m_waiting; +}; + +//! . 
+class FileSource : public SourceTemplate<FileStore> +{ +public: + typedef FileStore::Err Err; + typedef FileStore::OpenErr OpenErr; + typedef FileStore::ReadErr ReadErr; + + FileSource(BufferedTransformation *attachment = NULL) + : SourceTemplate<FileStore>(attachment) {} + FileSource(std::istream &in, bool pumpAll, BufferedTransformation *attachment = NULL) + : SourceTemplate<FileStore>(attachment) {SourceInitialize(pumpAll, MakeParameters("InputStreamPointer", &in));} + FileSource(const char *filename, bool pumpAll, BufferedTransformation *attachment = NULL, bool binary=true) + : SourceTemplate<FileStore>(attachment) {SourceInitialize(pumpAll, MakeParameters("InputFileName", filename)("InputBinaryMode", binary));} + + std::istream* GetStream() {return m_store.GetStream();} +}; + +//! . +class FileSink : public Sink +{ +public: + class Err : public Exception + { + public: + Err(const std::string &s) : Exception(IO_ERROR, s) {} + }; + class OpenErr : public Err {public: OpenErr(const std::string &filename) : Err("FileSink: error opening file for writing: " + filename) {}}; + class WriteErr : public Err {public: WriteErr() : Err("FileSink: error writing file") {}}; + + FileSink() : m_stream(NULL) {} + FileSink(std::ostream &out) + {IsolatedInitialize(MakeParameters("OutputStreamPointer", &out));} + FileSink(const char *filename, bool binary=true) + {IsolatedInitialize(MakeParameters("OutputFileName", filename)("OutputBinaryMode", binary));} + + std::ostream* GetStream() {return m_stream;} + + void IsolatedInitialize(const NameValuePairs ¶meters); + unsigned int Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking); + bool IsolatedFlush(bool hardFlush, bool blocking); + +private: + std::ofstream m_file; + std::ostream *m_stream; +}; + +NAMESPACE_END + +#endif diff --git a/filters.cpp b/filters.cpp new file mode 100644 index 0000000..d2b08fe --- /dev/null +++ b/filters.cpp @@ -0,0 +1,890 @@ +// filters.cpp - written and placed in the public 
domain by Wei Dai + +#include "pch.h" +#include "filters.h" +#include "mqueue.h" +#include "fltrimpl.h" +#include "argnames.h" +#include <memory> +#include <functional> + +NAMESPACE_BEGIN(CryptoPP) + +Filter::Filter(BufferedTransformation *attachment) + : m_attachment(attachment), m_continueAt(0) +{ +} + +BufferedTransformation * Filter::NewDefaultAttachment() const +{ + return new MessageQueue; +} + +BufferedTransformation * Filter::AttachedTransformation() +{ + if (m_attachment.get() == NULL) + m_attachment.reset(NewDefaultAttachment()); + return m_attachment.get(); +} + +const BufferedTransformation *Filter::AttachedTransformation() const +{ + if (m_attachment.get() == NULL) + const_cast<Filter *>(this)->m_attachment.reset(NewDefaultAttachment()); + return m_attachment.get(); +} + +void Filter::Detach(BufferedTransformation *newOut) +{ + m_attachment.reset(newOut); + NotifyAttachmentChange(); +} + +void Filter::Insert(Filter *filter) +{ + filter->m_attachment.reset(m_attachment.release()); + m_attachment.reset(filter); + NotifyAttachmentChange(); +} + +unsigned int Filter::CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end, const std::string &channel, bool blocking) const +{ + return AttachedTransformation()->CopyRangeTo2(target, begin, end, channel, blocking); +} + +unsigned int Filter::TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel, bool blocking) +{ + return AttachedTransformation()->TransferTo2(target, transferBytes, channel, blocking); +} + +void Filter::Initialize(const NameValuePairs ¶meters, int propagation) +{ + m_continueAt = 0; + IsolatedInitialize(parameters); + PropagateInitialize(parameters, propagation); +} + +bool Filter::Flush(bool hardFlush, int propagation, bool blocking) +{ + switch (m_continueAt) + { + case 0: + if (IsolatedFlush(hardFlush, blocking)) + return true; + case 1: + if (OutputFlush(1, hardFlush, propagation, blocking)) + return true; + } + 
return false; +} + +bool Filter::MessageSeriesEnd(int propagation, bool blocking) +{ + switch (m_continueAt) + { + case 0: + if (IsolatedMessageSeriesEnd(blocking)) + return true; + case 1: + if (ShouldPropagateMessageSeriesEnd() && OutputMessageSeriesEnd(1, propagation, blocking)) + return true; + } + return false; +} + +void Filter::PropagateInitialize(const NameValuePairs ¶meters, int propagation, const std::string &channel) +{ + if (propagation) + AttachedTransformation()->ChannelInitialize(channel, parameters, propagation-1); +} + +unsigned int Filter::Output(int outputSite, const byte *inString, unsigned int length, int messageEnd, bool blocking, const std::string &channel) +{ + if (messageEnd) + messageEnd--; + unsigned int result = AttachedTransformation()->Put2(inString, length, messageEnd, blocking); + m_continueAt = result ? outputSite : 0; + return result; +} + +bool Filter::OutputFlush(int outputSite, bool hardFlush, int propagation, bool blocking, const std::string &channel) +{ + if (propagation && AttachedTransformation()->ChannelFlush(channel, hardFlush, propagation-1, blocking)) + { + m_continueAt = outputSite; + return true; + } + m_continueAt = 0; + return false; +} + +bool Filter::OutputMessageSeriesEnd(int outputSite, int propagation, bool blocking, const std::string &channel) +{ + if (propagation && AttachedTransformation()->ChannelMessageSeriesEnd(channel, propagation-1, blocking)) + { + m_continueAt = outputSite; + return true; + } + m_continueAt = 0; + return false; +} + +// ************************************************************* + +unsigned int MeterFilter::Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking) +{ + FILTER_BEGIN; + m_currentMessageBytes += length; + m_totalBytes += length; + + if (messageEnd) + { + m_currentMessageBytes = 0; + m_currentSeriesMessages++; + m_totalMessages++; + } + + FILTER_OUTPUT(1, begin, length, messageEnd); + FILTER_END_NO_MESSAGE_END; +} + +bool 
MeterFilter::IsolatedMessageSeriesEnd(bool blocking) +{ + m_currentMessageBytes = 0; + m_currentSeriesMessages = 0; + m_totalMessageSeries++; + return false; +} + +// ************************************************************* + +void FilterWithBufferedInput::BlockQueue::ResetQueue(unsigned int blockSize, unsigned int maxBlocks) +{ + m_buffer.New(blockSize * maxBlocks); + m_blockSize = blockSize; + m_maxBlocks = maxBlocks; + m_size = 0; + m_begin = m_buffer; +} + +byte *FilterWithBufferedInput::BlockQueue::GetBlock() +{ + if (m_size >= m_blockSize) + { + byte *ptr = m_begin; + if ((m_begin+=m_blockSize) == m_buffer.end()) + m_begin = m_buffer; + m_size -= m_blockSize; + return ptr; + } + else + return NULL; +} + +byte *FilterWithBufferedInput::BlockQueue::GetContigousBlocks(unsigned int &numberOfBytes) +{ + numberOfBytes = STDMIN(numberOfBytes, STDMIN((unsigned int)(m_buffer.end()-m_begin), m_size)); + byte *ptr = m_begin; + m_begin += numberOfBytes; + m_size -= numberOfBytes; + if (m_size == 0 || m_begin == m_buffer.end()) + m_begin = m_buffer; + return ptr; +} + +unsigned int FilterWithBufferedInput::BlockQueue::GetAll(byte *outString) +{ + unsigned int size = m_size; + unsigned int numberOfBytes = m_maxBlocks*m_blockSize; + const byte *ptr = GetContigousBlocks(numberOfBytes); + memcpy(outString, ptr, numberOfBytes); + memcpy(outString+numberOfBytes, m_begin, m_size); + m_size = 0; + return size; +} + +void FilterWithBufferedInput::BlockQueue::Put(const byte *inString, unsigned int length) +{ + assert(m_size + length <= m_buffer.size()); + byte *end = (m_size < (unsigned int)(m_buffer.end()-m_begin)) ? 
m_begin + m_size : m_begin + m_size - m_buffer.size(); + unsigned int len = STDMIN(length, (unsigned int)(m_buffer.end()-end)); + memcpy(end, inString, len); + if (len < length) + memcpy(m_buffer, inString+len, length-len); + m_size += length; +} + +FilterWithBufferedInput::FilterWithBufferedInput(BufferedTransformation *attachment) + : Filter(attachment) +{ +} + +FilterWithBufferedInput::FilterWithBufferedInput(unsigned int firstSize, unsigned int blockSize, unsigned int lastSize, BufferedTransformation *attachment) + : Filter(attachment), m_firstSize(firstSize), m_blockSize(blockSize), m_lastSize(lastSize) + , m_firstInputDone(false) +{ + if (m_firstSize < 0 || m_blockSize < 1 || m_lastSize < 0) + throw InvalidArgument("FilterWithBufferedInput: invalid buffer size"); + + m_queue.ResetQueue(1, m_firstSize); +} + +void FilterWithBufferedInput::IsolatedInitialize(const NameValuePairs ¶meters) +{ + InitializeDerivedAndReturnNewSizes(parameters, m_firstSize, m_blockSize, m_lastSize); + if (m_firstSize < 0 || m_blockSize < 1 || m_lastSize < 0) + throw InvalidArgument("FilterWithBufferedInput: invalid buffer size"); + m_queue.ResetQueue(1, m_firstSize); + m_firstInputDone = false; +} + +bool FilterWithBufferedInput::IsolatedFlush(bool hardFlush, bool blocking) +{ + if (!blocking) + throw BlockingInputOnly("FilterWithBufferedInput"); + + if (hardFlush) + ForceNextPut(); + FlushDerived(); + + return false; +} + +unsigned int FilterWithBufferedInput::PutMaybeModifiable(byte *inString, unsigned int length, int messageEnd, bool blocking, bool modifiable) +{ + if (!blocking) + throw BlockingInputOnly("FilterWithBufferedInput"); + + if (length != 0) + { + unsigned int newLength = m_queue.CurrentSize() + length; + + if (!m_firstInputDone && newLength >= m_firstSize) + { + unsigned int len = m_firstSize - m_queue.CurrentSize(); + m_queue.Put(inString, len); + FirstPut(m_queue.GetContigousBlocks(m_firstSize)); + assert(m_queue.CurrentSize() == 0); + 
m_queue.ResetQueue(m_blockSize, (2*m_blockSize+m_lastSize-2)/m_blockSize); + + inString += len; + newLength -= m_firstSize; + m_firstInputDone = true; + } + + if (m_firstInputDone) + { + if (m_blockSize == 1) + { + while (newLength > m_lastSize && m_queue.CurrentSize() > 0) + { + unsigned int len = newLength - m_lastSize; + byte *ptr = m_queue.GetContigousBlocks(len); + NextPutModifiable(ptr, len); + newLength -= len; + } + + if (newLength > m_lastSize) + { + unsigned int len = newLength - m_lastSize; + NextPutMaybeModifiable(inString, len, modifiable); + inString += len; + newLength -= len; + } + } + else + { + while (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() >= m_blockSize) + { + NextPutModifiable(m_queue.GetBlock(), m_blockSize); + newLength -= m_blockSize; + } + + if (newLength >= m_blockSize + m_lastSize && m_queue.CurrentSize() > 0) + { + assert(m_queue.CurrentSize() < m_blockSize); + unsigned int len = m_blockSize - m_queue.CurrentSize(); + m_queue.Put(inString, len); + inString += len; + NextPutModifiable(m_queue.GetBlock(), m_blockSize); + newLength -= m_blockSize; + } + + if (newLength >= m_blockSize + m_lastSize) + { + unsigned int len = RoundDownToMultipleOf(newLength - m_lastSize, m_blockSize); + NextPutMaybeModifiable(inString, len, modifiable); + inString += len; + newLength -= len; + } + } + } + + m_queue.Put(inString, newLength - m_queue.CurrentSize()); + } + + if (messageEnd) + { + if (!m_firstInputDone && m_firstSize==0) + FirstPut(NULL); + + SecByteBlock temp(m_queue.CurrentSize()); + m_queue.GetAll(temp); + LastPut(temp, temp.size()); + + m_firstInputDone = false; + m_queue.ResetQueue(1, m_firstSize); + + Output(1, NULL, 0, messageEnd, blocking); + } + return 0; +} + +void FilterWithBufferedInput::ForceNextPut() +{ + if (!m_firstInputDone) + return; + + if (m_blockSize > 1) + { + while (m_queue.CurrentSize() >= m_blockSize) + NextPutModifiable(m_queue.GetBlock(), m_blockSize); + } + else + { + unsigned int len; + while 
((len = m_queue.CurrentSize()) > 0) + NextPutModifiable(m_queue.GetContigousBlocks(len), len); + } +} + +void FilterWithBufferedInput::NextPutMultiple(const byte *inString, unsigned int length) +{ + assert(m_blockSize > 1); // m_blockSize = 1 should always override this function + while (length > 0) + { + assert(length >= m_blockSize); + NextPutSingle(inString); + inString += m_blockSize; + length -= m_blockSize; + } +} + +// ************************************************************* + +void Redirector::ChannelInitialize(const std::string &channel, const NameValuePairs ¶meters, int propagation) +{ + if (channel.empty()) + { + m_target = parameters.GetValueWithDefault("RedirectionTargetPointer", (BufferedTransformation*)NULL); + m_passSignal = parameters.GetValueWithDefault("PassSignal", true); + } + + if (m_target && m_passSignal) + m_target->ChannelInitialize(channel, parameters, propagation); +} + +// ************************************************************* + +ProxyFilter::ProxyFilter(BufferedTransformation *filter, unsigned int firstSize, unsigned int lastSize, BufferedTransformation *attachment) + : FilterWithBufferedInput(firstSize, 1, lastSize, attachment), m_filter(filter) +{ + if (m_filter.get()) + m_filter->Attach(new OutputProxy(*this, false)); +} + +bool ProxyFilter::IsolatedFlush(bool hardFlush, bool blocking) +{ + return m_filter.get() ? 
m_filter->Flush(hardFlush, -1, blocking) : false; +} + +void ProxyFilter::SetFilter(Filter *filter) +{ + m_filter.reset(filter); + if (filter) + { + OutputProxy *proxy; + std::auto_ptr<OutputProxy> temp(proxy = new OutputProxy(*this, false)); + m_filter->TransferAllTo(*proxy); + m_filter->Attach(temp.release()); + } +} + +void ProxyFilter::NextPutMultiple(const byte *s, unsigned int len) +{ + if (m_filter.get()) + m_filter->Put(s, len); +} + +// ************************************************************* + +unsigned int ArraySink::Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking) +{ + memcpy(m_buf+m_total, begin, STDMIN(length, SaturatingSubtract(m_size, m_total))); + m_total += length; + return 0; +} + +byte * ArraySink::CreatePutSpace(unsigned int &size) +{ + size = m_size - m_total; + return m_buf + m_total; +} + +void ArraySink::IsolatedInitialize(const NameValuePairs ¶meters) +{ + ByteArrayParameter array; + if (!parameters.GetValue(Name::OutputBuffer(), array)) + throw InvalidArgument("ArraySink: missing OutputBuffer argument"); + m_buf = array.begin(); + m_size = array.size(); + m_total = 0; +} + +unsigned int ArrayXorSink::Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking) +{ + xorbuf(m_buf+m_total, begin, STDMIN(length, SaturatingSubtract(m_size, m_total))); + m_total += length; + return 0; +} + +// ************************************************************* + +unsigned int StreamTransformationFilter::LastBlockSize(StreamTransformation &c, BlockPaddingScheme padding) +{ + if (c.MinLastBlockSize() > 0) + return c.MinLastBlockSize(); + else if (c.MandatoryBlockSize() > 1 && !c.IsForwardTransformation() && padding != NO_PADDING && padding != ZEROS_PADDING) + return c.MandatoryBlockSize(); + else + return 0; +} + +StreamTransformationFilter::StreamTransformationFilter(StreamTransformation &c, BufferedTransformation *attachment, BlockPaddingScheme padding) + : FilterWithBufferedInput(0, 
c.MandatoryBlockSize(), LastBlockSize(c, padding), attachment) + , m_cipher(c) +{ + assert(c.MinLastBlockSize() == 0 || c.MinLastBlockSize() > c.MandatoryBlockSize()); + + bool isBlockCipher = (c.MandatoryBlockSize() > 1 && c.MinLastBlockSize() == 0); + + if (padding == DEFAULT_PADDING) + { + if (isBlockCipher) + m_padding = PKCS_PADDING; + else + m_padding = NO_PADDING; + } + else + m_padding = padding; + + if (!isBlockCipher && (m_padding == PKCS_PADDING || m_padding == ONE_AND_ZEROS_PADDING)) + throw InvalidArgument("StreamTransformationFilter: PKCS_PADDING and ONE_AND_ZEROS_PADDING cannot be used with " + c.AlgorithmName()); +} + +void StreamTransformationFilter::FirstPut(const byte *inString) +{ + m_optimalBufferSize = m_cipher.OptimalBlockSize(); + m_optimalBufferSize = STDMAX(m_optimalBufferSize, RoundDownToMultipleOf(4096U, m_optimalBufferSize)); +} + +void StreamTransformationFilter::NextPutMultiple(const byte *inString, unsigned int length) +{ + if (!length) + return; + + unsigned int s = m_cipher.MandatoryBlockSize(); + + do + { + unsigned int len = m_optimalBufferSize; + byte *space = HelpCreatePutSpace(*AttachedTransformation(), NULL_CHANNEL, s, length, len); + if (len < length) + { + if (len == m_optimalBufferSize) + len -= m_cipher.GetOptimalBlockSizeUsed(); + len = RoundDownToMultipleOf(len, s); + } + else + len = length; + m_cipher.ProcessString(space, inString, len); + AttachedTransformation()->PutModifiable(space, len); + inString += len; + length -= len; + } + while (length > 0); +} + +void StreamTransformationFilter::NextPutModifiable(byte *inString, unsigned int length) +{ + m_cipher.ProcessString(inString, length); + AttachedTransformation()->PutModifiable(inString, length); +} + +void StreamTransformationFilter::LastPut(const byte *inString, unsigned int length) +{ + byte *space = NULL; + + switch (m_padding) + { + case NO_PADDING: + case ZEROS_PADDING: + if (length > 0) + { + unsigned int minLastBlockSize = m_cipher.MinLastBlockSize(); + 
bool isForwardTransformation = m_cipher.IsForwardTransformation(); + + if (isForwardTransformation && m_padding == ZEROS_PADDING && (minLastBlockSize == 0 || length < minLastBlockSize)) + { + // do padding + unsigned int blockSize = STDMAX(minLastBlockSize, m_cipher.MandatoryBlockSize()); + space = HelpCreatePutSpace(*AttachedTransformation(), NULL_CHANNEL, blockSize); + memcpy(space, inString, length); + memset(space + length, 0, blockSize - length); + m_cipher.ProcessLastBlock(space, space, blockSize); + AttachedTransformation()->Put(space, blockSize); + } + else + { + if (minLastBlockSize == 0) + { + if (isForwardTransformation) + throw InvalidDataFormat("StreamTransformationFilter: plaintext length is not a multiple of block size and NO_PADDING is specified"); + else + throw InvalidCiphertext("StreamTransformationFilter: ciphertext length is not a multiple of block size"); + } + + space = HelpCreatePutSpace(*AttachedTransformation(), NULL_CHANNEL, length, m_optimalBufferSize); + m_cipher.ProcessLastBlock(space, inString, length); + AttachedTransformation()->Put(space, length); + } + } + break; + + case PKCS_PADDING: + case ONE_AND_ZEROS_PADDING: + unsigned int s; + s = m_cipher.MandatoryBlockSize(); + assert(s > 1); + space = HelpCreatePutSpace(*AttachedTransformation(), NULL_CHANNEL, s, m_optimalBufferSize); + if (m_cipher.IsForwardTransformation()) + { + assert(length < s); + memcpy(space, inString, length); + if (m_padding == PKCS_PADDING) + { + assert(s < 256); + byte pad = s-length; + memset(space+length, pad, s-length); + } + else + { + space[length] = 1; + memset(space+length+1, 0, s-length-1); + } + m_cipher.ProcessData(space, space, s); + AttachedTransformation()->Put(space, s); + } + else + { + if (length != s) + throw InvalidCiphertext("StreamTransformationFilter: ciphertext length is not a multiple of block size"); + m_cipher.ProcessData(space, inString, s); + if (m_padding == PKCS_PADDING) + { + byte pad = space[s-1]; + if (pad < 1 || pad > s || 
std::find_if(space+s-pad, space+s, std::bind2nd(std::not_equal_to<byte>(), pad)) != space+s) + throw InvalidCiphertext("StreamTransformationFilter: invalid PKCS #7 block padding found"); + length = s-pad; + } + else + { + while (length > 1 && space[length-1] == '\0') + --length; + if (space[--length] != '\1') + throw InvalidCiphertext("StreamTransformationFilter: invalid ones-and-zeros padding found"); + } + AttachedTransformation()->Put(space, length); + } + break; + + default: + assert(false); + } +} + +// ************************************************************* + +void HashFilter::IsolatedInitialize(const NameValuePairs ¶meters) +{ + m_putMessage = parameters.GetValueWithDefault(Name::PutMessage(), false); + m_hashModule.Restart(); +} + +unsigned int HashFilter::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking) +{ + FILTER_BEGIN; + m_hashModule.Update(inString, length); + if (m_putMessage) + FILTER_OUTPUT(1, inString, length, 0); + if (messageEnd) + { + { + unsigned int size, digestSize = m_hashModule.DigestSize(); + m_space = HelpCreatePutSpace(*AttachedTransformation(), NULL_CHANNEL, digestSize, digestSize, size = digestSize); + m_hashModule.Final(m_space); + } + FILTER_OUTPUT(2, m_space, m_hashModule.DigestSize(), messageEnd); + } + FILTER_END_NO_MESSAGE_END; +} + +// ************************************************************* + +HashVerificationFilter::HashVerificationFilter(HashTransformation &hm, BufferedTransformation *attachment, word32 flags) + : FilterWithBufferedInput(attachment) + , m_hashModule(hm) +{ + IsolatedInitialize(MakeParameters(Name::HashVerificationFilterFlags(), flags)); +} + +void HashVerificationFilter::InitializeDerivedAndReturnNewSizes(const NameValuePairs ¶meters, unsigned int &firstSize, unsigned int &blockSize, unsigned int &lastSize) +{ + m_flags = parameters.GetValueWithDefault(Name::HashVerificationFilterFlags(), (word32)DEFAULT_FLAGS); + m_hashModule.Restart(); + unsigned int size = 
m_hashModule.DigestSize(); + m_verified = false; + firstSize = m_flags & HASH_AT_BEGIN ? size : 0; + blockSize = 1; + lastSize = m_flags & HASH_AT_BEGIN ? 0 : size; +} + +void HashVerificationFilter::FirstPut(const byte *inString) +{ + if (m_flags & HASH_AT_BEGIN) + { + m_expectedHash.New(m_hashModule.DigestSize()); + memcpy(m_expectedHash, inString, m_expectedHash.size()); + if (m_flags & PUT_HASH) + AttachedTransformation()->Put(inString, m_expectedHash.size()); + } +} + +void HashVerificationFilter::NextPutMultiple(const byte *inString, unsigned int length) +{ + m_hashModule.Update(inString, length); + if (m_flags & PUT_MESSAGE) + AttachedTransformation()->Put(inString, length); +} + +void HashVerificationFilter::LastPut(const byte *inString, unsigned int length) +{ + if (m_flags & HASH_AT_BEGIN) + { + assert(length == 0); + m_verified = m_hashModule.Verify(m_expectedHash); + } + else + { + m_verified = (length==m_hashModule.DigestSize() && m_hashModule.Verify(inString)); + if (m_flags & PUT_HASH) + AttachedTransformation()->Put(inString, length); + } + + if (m_flags & PUT_RESULT) + AttachedTransformation()->Put(m_verified); + + if ((m_flags & THROW_EXCEPTION) && !m_verified) + throw HashVerificationFailed(); +} + +// ************************************************************* + +void SignerFilter::IsolatedInitialize(const NameValuePairs ¶meters) +{ + m_putMessage = parameters.GetValueWithDefault(Name::PutMessage(), false); + m_messageAccumulator.reset(m_signer.NewSignatureAccumulator()); +} + +unsigned int SignerFilter::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking) +{ + FILTER_BEGIN; + m_messageAccumulator->Update(inString, length); + if (m_putMessage) + FILTER_OUTPUT(1, inString, length, 0); + if (messageEnd) + { + m_buf.New(m_signer.SignatureLength()); + m_signer.Sign(m_rng, m_messageAccumulator.release(), m_buf); + FILTER_OUTPUT(2, m_buf, m_buf.size(), messageEnd); + 
m_messageAccumulator.reset(m_signer.NewSignatureAccumulator()); + } + FILTER_END_NO_MESSAGE_END; +} + +SignatureVerificationFilter::SignatureVerificationFilter(const PK_Verifier &verifier, BufferedTransformation *attachment, word32 flags) + : FilterWithBufferedInput(attachment) + , m_verifier(verifier) +{ + IsolatedInitialize(MakeParameters(Name::SignatureVerificationFilterFlags(), flags)); +} + +void SignatureVerificationFilter::InitializeDerivedAndReturnNewSizes(const NameValuePairs ¶meters, unsigned int &firstSize, unsigned int &blockSize, unsigned int &lastSize) +{ + m_flags = parameters.GetValueWithDefault(Name::SignatureVerificationFilterFlags(), (word32)DEFAULT_FLAGS); + m_messageAccumulator.reset(m_verifier.NewVerificationAccumulator()); + unsigned int size = m_verifier.SignatureLength(); + m_verified = false; + firstSize = m_flags & SIGNATURE_AT_BEGIN ? size : 0; + blockSize = 1; + lastSize = m_flags & SIGNATURE_AT_BEGIN ? 0 : size; +} + +void SignatureVerificationFilter::FirstPut(const byte *inString) +{ + if (m_flags & SIGNATURE_AT_BEGIN) + { + if (m_verifier.SignatureUpfrontForVerification()) + m_verifier.InitializeVerificationAccumulator(*m_messageAccumulator, inString); + else + { + m_signature.New(m_verifier.SignatureLength()); + memcpy(m_signature, inString, m_signature.size()); + } + + if (m_flags & PUT_SIGNATURE) + AttachedTransformation()->Put(inString, m_signature.size()); + } + else + { + assert(!m_verifier.SignatureUpfrontForVerification()); + } +} + +void SignatureVerificationFilter::NextPutMultiple(const byte *inString, unsigned int length) +{ + m_messageAccumulator->Update(inString, length); + if (m_flags & PUT_MESSAGE) + AttachedTransformation()->Put(inString, length); +} + +void SignatureVerificationFilter::LastPut(const byte *inString, unsigned int length) +{ + if (m_flags & SIGNATURE_AT_BEGIN) + { + assert(length == 0); + m_verified = m_verifier.Verify(m_messageAccumulator.release(), m_signature); + } + else + { + m_verified = 
(length==m_verifier.SignatureLength() && m_verifier.Verify(m_messageAccumulator.release(), inString)); + if (m_flags & PUT_SIGNATURE) + AttachedTransformation()->Put(inString, length); + } + + if (m_flags & PUT_RESULT) + AttachedTransformation()->Put(m_verified); + + if ((m_flags & THROW_EXCEPTION) && !m_verified) + throw SignatureVerificationFailed(); +} + +// ************************************************************* + +unsigned int Source::PumpAll2(bool blocking) +{ + // TODO: switch length type + unsigned long i = UINT_MAX; + RETURN_IF_NONZERO(Pump2(i, blocking)); + unsigned int j = UINT_MAX; + return PumpMessages2(j, blocking); +} + +bool Store::GetNextMessage() +{ + if (!m_messageEnd && !AnyRetrievable()) + { + m_messageEnd=true; + return true; + } + else + return false; +} + +unsigned int Store::CopyMessagesTo(BufferedTransformation &target, unsigned int count, const std::string &channel) const +{ + if (m_messageEnd || count == 0) + return 0; + else + { + CopyTo(target, ULONG_MAX, channel); + if (GetAutoSignalPropagation()) + target.ChannelMessageEnd(channel, GetAutoSignalPropagation()-1); + return 1; + } +} + +void StringStore::StoreInitialize(const NameValuePairs ¶meters) +{ + ConstByteArrayParameter array; + if (!parameters.GetValue(Name::InputBuffer(), array)) + throw InvalidArgument("StringStore: missing InputBuffer argument"); + m_store = array.begin(); + m_length = array.size(); + m_count = 0; +} + +unsigned int StringStore::TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel, bool blocking) +{ + unsigned long position = 0; + unsigned int blockedBytes = CopyRangeTo2(target, position, transferBytes, channel, blocking); + m_count += position; + transferBytes = position; + return blockedBytes; +} + +unsigned int StringStore::CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end, const std::string &channel, bool blocking) const +{ + unsigned int i = (unsigned 
int)STDMIN((unsigned long)m_count+begin, (unsigned long)m_length); + unsigned int len = (unsigned int)STDMIN((unsigned long)m_length-i, end-begin); + unsigned int blockedBytes = target.ChannelPut2(channel, m_store+i, len, 0, blocking); + if (!blockedBytes) + begin += len; + return blockedBytes; +} + +unsigned int RandomNumberStore::TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel, bool blocking) +{ + if (!blocking) + throw NotImplemented("RandomNumberStore: nonblocking transfer is not implemented by this object"); + + unsigned long transferMax = transferBytes; + for (transferBytes = 0; transferBytes<transferMax && m_count < m_length; ++transferBytes, ++m_count) + target.ChannelPut(channel, m_rng.GenerateByte()); + return 0; +} + +unsigned int NullStore::CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end, const std::string &channel, bool blocking) const +{ + static const byte nullBytes[128] = {0}; + while (begin < end) + { + unsigned int len = STDMIN(end-begin, 128UL); + unsigned int blockedBytes = target.ChannelPut2(channel, nullBytes, len, 0, blocking); + if (blockedBytes) + return blockedBytes; + begin += len; + } + return 0; +} + +unsigned int NullStore::TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel, bool blocking) +{ + unsigned long begin = 0; + unsigned int blockedBytes = NullStore::CopyRangeTo2(target, begin, transferBytes, channel, blocking); + transferBytes = begin; + m_size -= begin; + return blockedBytes; +} + +NAMESPACE_END diff --git a/filters.h b/filters.h new file mode 100644 index 0000000..1b8965b --- /dev/null +++ b/filters.h @@ -0,0 +1,681 @@ +#ifndef CRYPTOPP_FILTERS_H +#define CRYPTOPP_FILTERS_H + +#include "simple.h" +#include "secblock.h" +#include "misc.h" +#include "smartptr.h" +#include "queue.h" +#include "algparam.h" + +NAMESPACE_BEGIN(CryptoPP) + +/// provides an implementation of 
BufferedTransformation's attachment interface +class Filter : public BufferedTransformation, public NotCopyable +{ +public: + Filter(BufferedTransformation *attachment); + + bool Attachable() {return true;} + BufferedTransformation *AttachedTransformation(); + const BufferedTransformation *AttachedTransformation() const; + void Detach(BufferedTransformation *newAttachment = NULL); + + unsigned int TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel=NULL_CHANNEL, bool blocking=true); + unsigned int CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end=ULONG_MAX, const std::string &channel=NULL_CHANNEL, bool blocking=true) const; + + void Initialize(const NameValuePairs &parameters=g_nullNameValuePairs, int propagation=-1); + bool Flush(bool hardFlush, int propagation=-1, bool blocking=true); + bool MessageSeriesEnd(int propagation=-1, bool blocking=true); + +protected: + virtual void NotifyAttachmentChange() {} + virtual BufferedTransformation * NewDefaultAttachment() const; + void Insert(Filter *nextFilter); // insert filter after this one + + virtual bool ShouldPropagateMessageEnd() const {return true;} + virtual bool ShouldPropagateMessageSeriesEnd() const {return true;} + + void PropagateInitialize(const NameValuePairs &parameters, int propagation, const std::string &channel=NULL_CHANNEL); + + unsigned int Output(int outputSite, const byte *inString, unsigned int length, int messageEnd, bool blocking, const std::string &channel=NULL_CHANNEL); + bool OutputMessageEnd(int outputSite, int propagation, bool blocking, const std::string &channel=NULL_CHANNEL); + bool OutputFlush(int outputSite, bool hardFlush, int propagation, bool blocking, const std::string &channel=NULL_CHANNEL); + bool OutputMessageSeriesEnd(int outputSite, int propagation, bool blocking, const std::string &channel=NULL_CHANNEL); + +private: + member_ptr<BufferedTransformation> m_attachment; + +protected: + unsigned int 
m_inputPosition; + int m_continueAt; +}; + +struct FilterPutSpaceHelper +{ + // desiredSize is how much to ask target, bufferSize is how much to allocate in m_tempSpace + byte *HelpCreatePutSpace(BufferedTransformation &target, const std::string &channel, unsigned int minSize, unsigned int desiredSize, unsigned int &bufferSize) + { + assert(desiredSize >= minSize && bufferSize >= minSize); + if (m_tempSpace.size() < minSize) + { + byte *result = target.ChannelCreatePutSpace(channel, desiredSize); + if (desiredSize >= minSize) + { + bufferSize = desiredSize; + return result; + } + m_tempSpace.New(bufferSize); + } + + bufferSize = m_tempSpace.size(); + return m_tempSpace.begin(); + } + byte *HelpCreatePutSpace(BufferedTransformation &target, const std::string &channel, unsigned int minSize) + {return HelpCreatePutSpace(target, channel, minSize, minSize, minSize);} + byte *HelpCreatePutSpace(BufferedTransformation &target, const std::string &channel, unsigned int minSize, unsigned int bufferSize) + {return HelpCreatePutSpace(target, channel, minSize, minSize, bufferSize);} + SecByteBlock m_tempSpace; +}; + +//! 
measure how many byte and messages pass through, also serves as valve +class MeterFilter : public Bufferless<Filter> +{ +public: + MeterFilter(BufferedTransformation *attachment=NULL, bool transparent=true) + : Bufferless<Filter>(attachment), m_transparent(transparent) {ResetMeter();} + + void SetTransparent(bool transparent) {m_transparent = transparent;} + void ResetMeter() {m_currentMessageBytes = m_totalBytes = m_currentSeriesMessages = m_totalMessages = m_totalMessageSeries = 0;} + + unsigned long GetCurrentMessageBytes() const {return m_currentMessageBytes;} + unsigned long GetTotalBytes() {return m_totalBytes;} + unsigned int GetCurrentSeriesMessages() {return m_currentSeriesMessages;} + unsigned int GetTotalMessages() {return m_totalMessages;} + unsigned int GetTotalMessageSeries() {return m_totalMessageSeries;} + + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking); + bool IsolatedMessageSeriesEnd(bool blocking); + +private: + bool ShouldPropagateMessageEnd() const {return m_transparent;} + bool ShouldPropagateMessageSeriesEnd() const {return m_transparent;} + + bool m_transparent; + unsigned long m_currentMessageBytes, m_totalBytes; + unsigned int m_currentSeriesMessages, m_totalMessages, m_totalMessageSeries; +}; + +//! . +class TransparentFilter : public MeterFilter +{ +public: + TransparentFilter(BufferedTransformation *attachment=NULL) : MeterFilter(attachment, true) {} +}; + +//! . +class OpaqueFilter : public MeterFilter +{ +public: + OpaqueFilter(BufferedTransformation *attachment=NULL) : MeterFilter(attachment, false) {} +}; + +/*! FilterWithBufferedInput divides up the input stream into + a first block, a number of middle blocks, and a last block. + First and last blocks are optional, and middle blocks may + be a stream instead (i.e. blockSize == 1). +*/ +class FilterWithBufferedInput : public Filter +{ +public: + FilterWithBufferedInput(BufferedTransformation *attachment); + //! 
firstSize and lastSize may be 0, blockSize must be at least 1 + FilterWithBufferedInput(unsigned int firstSize, unsigned int blockSize, unsigned int lastSize, BufferedTransformation *attachment); + + void IsolatedInitialize(const NameValuePairs &parameters); + unsigned int Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking) + { + return PutMaybeModifiable(const_cast<byte *>(inString), length, messageEnd, blocking, false); + } + unsigned int PutModifiable2(byte *inString, unsigned int length, int messageEnd, bool blocking) + { + return PutMaybeModifiable(inString, length, messageEnd, blocking, true); + } + /*! calls ForceNextPut() if hardFlush is true */ + bool IsolatedFlush(bool hardFlush, bool blocking); + + /*! The input buffer may contain more than blockSize bytes if lastSize != 0. + ForceNextPut() forces a call to NextPut() if this is the case. + */ + void ForceNextPut(); + +protected: + bool DidFirstPut() {return m_firstInputDone;} + + virtual void InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, unsigned int &firstSize, unsigned int &blockSize, unsigned int &lastSize) + {InitializeDerived(parameters);} + virtual void InitializeDerived(const NameValuePairs &parameters) {} + // FirstPut() is called if (firstSize != 0 and totalLength >= firstSize) + // or (firstSize == 0 and (totalLength > 0 or a MessageEnd() is received)) + virtual void FirstPut(const byte *inString) =0; + // NextPut() is called if totalLength >= firstSize+blockSize+lastSize + virtual void NextPutSingle(const byte *inString) {assert(false);} + // Same as NextPut() except length can be a multiple of blockSize + // Either NextPut() or NextPutMultiple() must be overriden + virtual void NextPutMultiple(const byte *inString, unsigned int length); + // Same as NextPutMultiple(), but inString can be modified + virtual void NextPutModifiable(byte *inString, unsigned int length) + {NextPutMultiple(inString, length);} + // LastPut() is always called + // if totalLength 
< firstSize then length == totalLength + // else if totalLength <= firstSize+lastSize then length == totalLength-firstSize + // else lastSize <= length < lastSize+blockSize + virtual void LastPut(const byte *inString, unsigned int length) =0; + virtual void FlushDerived() {} + +private: + unsigned int PutMaybeModifiable(byte *begin, unsigned int length, int messageEnd, bool blocking, bool modifiable); + void NextPutMaybeModifiable(byte *inString, unsigned int length, bool modifiable) + { + if (modifiable) NextPutModifiable(inString, length); + else NextPutMultiple(inString, length); + } + + // This function should no longer be used, put this here to cause a compiler error + // if someone tries to override NextPut(). + virtual int NextPut(const byte *inString, unsigned int length) {assert(false); return 0;} + + class BlockQueue + { + public: + void ResetQueue(unsigned int blockSize, unsigned int maxBlocks); + byte *GetBlock(); + byte *GetContigousBlocks(unsigned int &numberOfBytes); + unsigned int GetAll(byte *outString); + void Put(const byte *inString, unsigned int length); + unsigned int CurrentSize() const {return m_size;} + unsigned int MaxSize() const {return m_buffer.size();} + + private: + SecByteBlock m_buffer; + unsigned int m_blockSize, m_maxBlocks, m_size; + byte *m_begin; + }; + + unsigned int m_firstSize, m_blockSize, m_lastSize; + bool m_firstInputDone; + BlockQueue m_queue; +}; + +//! . 
+class FilterWithInputQueue : public Filter +{ +public: + FilterWithInputQueue(BufferedTransformation *attachment) : Filter(attachment) {} + unsigned int Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking) + { + if (!blocking) + throw BlockingInputOnly("FilterWithInputQueue"); + + m_inQueue.Put(inString, length); + if (messageEnd) + { + IsolatedMessageEnd(blocking); + Output(0, NULL, 0, messageEnd, blocking); + } + return 0; + } + +protected: + virtual bool IsolatedMessageEnd(bool blocking) =0; + void IsolatedInitialize(const NameValuePairs &parameters) {m_inQueue.Clear();} + + ByteQueue m_inQueue; +}; + +//! Filter Wrapper for StreamTransformation +class StreamTransformationFilter : public FilterWithBufferedInput, private FilterPutSpaceHelper +{ +public: + enum BlockPaddingScheme {NO_PADDING, ZEROS_PADDING, PKCS_PADDING, ONE_AND_ZEROS_PADDING, DEFAULT_PADDING}; + /*! DEFAULT_PADDING means PKCS_PADDING if c.MandatoryBlockSize() > 1 && c.MinLastBlockSize() == 0 (e.g. ECB or CBC mode), + otherwise NO_PADDING (OFB, CFB, CTR, CBC-CTS modes) */ + StreamTransformationFilter(StreamTransformation &c, BufferedTransformation *attachment = NULL, BlockPaddingScheme padding = DEFAULT_PADDING); + + void FirstPut(const byte *inString); + void NextPutMultiple(const byte *inString, unsigned int length); + void NextPutModifiable(byte *inString, unsigned int length); + void LastPut(const byte *inString, unsigned int length); +// byte * CreatePutSpace(unsigned int &size); + +protected: + static unsigned int LastBlockSize(StreamTransformation &c, BlockPaddingScheme padding); + + StreamTransformation &m_cipher; + BlockPaddingScheme m_padding; + unsigned int m_optimalBufferSize; +}; + +#ifdef CRYPTOPP_MAINTAIN_BACKWARDS_COMPATIBILITY +typedef StreamTransformationFilter StreamCipherFilter; +#endif + +//! 
Filter Wrapper for HashTransformation +class HashFilter : public Bufferless<Filter>, private FilterPutSpaceHelper +{ +public: + HashFilter(HashTransformation &hm, BufferedTransformation *attachment = NULL, bool putMessage=false) + : Bufferless<Filter>(attachment), m_hashModule(hm), m_putMessage(putMessage) {} + + void IsolatedInitialize(const NameValuePairs &parameters); + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking); + + byte * CreatePutSpace(unsigned int &size) {return m_hashModule.CreateUpdateSpace(size);} + +private: + HashTransformation &m_hashModule; + bool m_putMessage; + byte *m_space; +}; + +//! Filter Wrapper for HashTransformation +class HashVerificationFilter : public FilterWithBufferedInput +{ +public: + class HashVerificationFailed : public Exception + { + public: + HashVerificationFailed() + : Exception(DATA_INTEGRITY_CHECK_FAILED, "HashVerifier: message hash not valid") {} + }; + + enum Flags {HASH_AT_BEGIN=1, PUT_MESSAGE=2, PUT_HASH=4, PUT_RESULT=8, THROW_EXCEPTION=16, DEFAULT_FLAGS = HASH_AT_BEGIN | PUT_RESULT}; + HashVerificationFilter(HashTransformation &hm, BufferedTransformation *attachment = NULL, word32 flags = DEFAULT_FLAGS); + + bool GetLastResult() const {return m_verified;} + +protected: + void InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, unsigned int &firstSize, unsigned int &blockSize, unsigned int &lastSize); + void FirstPut(const byte *inString); + void NextPutMultiple(const byte *inString, unsigned int length); + void LastPut(const byte *inString, unsigned int length); + +private: + static inline unsigned int FirstSize(word32 flags, HashTransformation &hm) {return flags & HASH_AT_BEGIN ? hm.DigestSize() : 0;} + static inline unsigned int LastSize(word32 flags, HashTransformation &hm) {return flags & HASH_AT_BEGIN ? 
0 : hm.DigestSize();} + + HashTransformation &m_hashModule; + word32 m_flags; + SecByteBlock m_expectedHash; + bool m_verified; +}; + +typedef HashVerificationFilter HashVerifier; // for backwards compatibility + +//! Filter Wrapper for PK_Signer +class SignerFilter : public Unflushable<Filter> +{ +public: + SignerFilter(RandomNumberGenerator &rng, const PK_Signer &signer, BufferedTransformation *attachment = NULL, bool putMessage=false) + : Unflushable<Filter>(attachment), m_rng(rng), m_signer(signer), m_messageAccumulator(signer.NewSignatureAccumulator()), m_putMessage(putMessage) {} + + void IsolatedInitialize(const NameValuePairs &parameters); + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking); + +private: + RandomNumberGenerator &m_rng; + const PK_Signer &m_signer; + member_ptr<HashTransformation> m_messageAccumulator; + bool m_putMessage; + SecByteBlock m_buf; +}; + +//! Filter Wrapper for PK_Verifier +class SignatureVerificationFilter : public FilterWithBufferedInput +{ +public: + class SignatureVerificationFailed : public Exception + { + public: + SignatureVerificationFailed() + : Exception(DATA_INTEGRITY_CHECK_FAILED, "VerifierFilter: digital signature not valid") {} + }; + + enum Flags {SIGNATURE_AT_BEGIN=1, PUT_MESSAGE=2, PUT_SIGNATURE=4, PUT_RESULT=8, THROW_EXCEPTION=16, DEFAULT_FLAGS = SIGNATURE_AT_BEGIN | PUT_RESULT}; + SignatureVerificationFilter(const PK_Verifier &verifier, BufferedTransformation *attachment = NULL, word32 flags = DEFAULT_FLAGS); + + bool GetLastResult() const {return m_verified;} + +protected: + void InitializeDerivedAndReturnNewSizes(const NameValuePairs &parameters, unsigned int &firstSize, unsigned int &blockSize, unsigned int &lastSize); + void FirstPut(const byte *inString); + void NextPutMultiple(const byte *inString, unsigned int length); + void LastPut(const byte *inString, unsigned int length); + +private: + const PK_Verifier &m_verifier; + member_ptr<HashTransformation> 
m_messageAccumulator; + word32 m_flags; + SecByteBlock m_signature; + bool m_verified; +}; + +typedef SignatureVerificationFilter VerifierFilter; // for backwards compatibility + +//! Redirect input to another BufferedTransformation without owning it +class Redirector : public CustomSignalPropagation<Sink> +{ +public: + Redirector() : m_target(NULL), m_passSignal(true) {} + Redirector(BufferedTransformation &target, bool passSignal=true) : m_target(&target), m_passSignal(passSignal) {} + + void Redirect(BufferedTransformation &target) {m_target = &target;} + void StopRedirection() {m_target = NULL;} + bool GetPassSignal() const {return m_passSignal;} + void SetPassSignal(bool passSignal) {m_passSignal = passSignal;} + + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_target ? m_target->Put2(begin, length, m_passSignal ? messageEnd : 0, blocking) : 0;} + void Initialize(const NameValuePairs &parameters, int propagation) + {ChannelInitialize(NULL_CHANNEL, parameters, propagation);} + bool Flush(bool hardFlush, int propagation=-1, bool blocking=true) + {return m_target && m_passSignal ? m_target->Flush(hardFlush, propagation, blocking) : false;} + bool MessageSeriesEnd(int propagation=-1, bool blocking=true) + {return m_target && m_passSignal ? m_target->MessageSeriesEnd(propagation, blocking) : false;} + + void ChannelInitialize(const std::string &channel, const NameValuePairs &parameters=g_nullNameValuePairs, int propagation=-1); + unsigned int ChannelPut2(const std::string &channel, const byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_target ? m_target->ChannelPut2(channel, begin, length, m_passSignal ? messageEnd : 0, blocking) : 0;} + unsigned int ChannelPutModifiable2(const std::string &channel, byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_target ? m_target->ChannelPutModifiable2(channel, begin, length, m_passSignal ? 
messageEnd : 0, blocking) : 0;} + bool ChannelFlush(const std::string &channel, bool completeFlush, int propagation=-1, bool blocking=true) + {return m_target && m_passSignal ? m_target->ChannelFlush(channel, completeFlush, propagation, blocking) : false;} + bool ChannelMessageSeriesEnd(const std::string &channel, int propagation=-1, bool blocking=true) + {return m_target && m_passSignal ? m_target->ChannelMessageSeriesEnd(channel, propagation, blocking) : false;} + +private: + BufferedTransformation *m_target; + bool m_passSignal; +}; + +// Used By ProxyFilter +class OutputProxy : public CustomSignalPropagation<Sink> +{ +public: + OutputProxy(BufferedTransformation &owner, bool passSignal) : m_owner(owner), m_passSignal(passSignal) {} + + bool GetPassSignal() const {return m_passSignal;} + void SetPassSignal(bool passSignal) {m_passSignal = passSignal;} + + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_owner.AttachedTransformation()->Put2(begin, length, m_passSignal ? messageEnd : 0, blocking);} + unsigned int PutModifiable2(byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_owner.AttachedTransformation()->PutModifiable2(begin, length, m_passSignal ? messageEnd : 0, blocking);} + void Initialize(const NameValuePairs &parameters=g_nullNameValuePairs, int propagation=-1) + {if (m_passSignal) m_owner.AttachedTransformation()->Initialize(parameters, propagation);} + bool Flush(bool hardFlush, int propagation=-1, bool blocking=true) + {return m_passSignal ? m_owner.AttachedTransformation()->Flush(hardFlush, propagation, blocking) : false;} + bool MessageSeriesEnd(int propagation=-1, bool blocking=true) + {return m_passSignal ? 
m_owner.AttachedTransformation()->MessageSeriesEnd(propagation, blocking) : false;} + + unsigned int ChannelPut2(const std::string &channel, const byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_owner.AttachedTransformation()->ChannelPut2(channel, begin, length, m_passSignal ? messageEnd : 0, blocking);} + unsigned int ChannelPutModifiable2(const std::string &channel, byte *begin, unsigned int length, int messageEnd, bool blocking) + {return m_owner.AttachedTransformation()->ChannelPutModifiable2(channel, begin, length, m_passSignal ? messageEnd : 0, blocking);} + void ChannelInitialize(const std::string &channel, const NameValuePairs &parameters, int propagation=-1) + {if (m_passSignal) m_owner.AttachedTransformation()->ChannelInitialize(channel, parameters, propagation);} + bool ChannelFlush(const std::string &channel, bool completeFlush, int propagation=-1, bool blocking=true) + {return m_passSignal ? m_owner.AttachedTransformation()->ChannelFlush(channel, completeFlush, propagation, blocking) : false;} + bool ChannelMessageSeriesEnd(const std::string &channel, int propagation=-1, bool blocking=true) + {return m_passSignal ? m_owner.AttachedTransformation()->ChannelMessageSeriesEnd(channel, propagation, blocking) : false;} + +private: + BufferedTransformation &m_owner; + bool m_passSignal; +}; + +//! Base class for Filter classes that are proxies for a chain of other filters. +class ProxyFilter : public FilterWithBufferedInput +{ +public: + ProxyFilter(BufferedTransformation *filter, unsigned int firstSize, unsigned int lastSize, BufferedTransformation *attachment); + + bool IsolatedFlush(bool hardFlush, bool blocking); + + void SetFilter(Filter *filter); + void NextPutMultiple(const byte *s, unsigned int len); + +protected: + member_ptr<BufferedTransformation> m_filter; +}; + +//! 
simple proxy filter that doesn't modify the underlying filter's input or output +class SimpleProxyFilter : public ProxyFilter +{ +public: + SimpleProxyFilter(BufferedTransformation *filter, BufferedTransformation *attachment) + : ProxyFilter(filter, 0, 0, attachment) {} + + void FirstPut(const byte *) {} + void LastPut(const byte *, unsigned int) {m_filter->MessageEnd();} +}; + +//! proxy for the filter created by PK_Encryptor::CreateEncryptionFilter +/*! This class is here just to provide symmetry with VerifierFilter. */ +class PK_EncryptorFilter : public SimpleProxyFilter +{ +public: + PK_EncryptorFilter(RandomNumberGenerator &rng, const PK_Encryptor &encryptor, BufferedTransformation *attachment = NULL) + : SimpleProxyFilter(encryptor.CreateEncryptionFilter(rng), attachment) {} +}; + +//! proxy for the filter created by PK_Decryptor::CreateDecryptionFilter +/*! This class is here just to provide symmetry with SignerFilter. */ +class PK_DecryptorFilter : public SimpleProxyFilter +{ +public: + PK_DecryptorFilter(const PK_Decryptor &decryptor, BufferedTransformation *attachment = NULL) + : SimpleProxyFilter(decryptor.CreateDecryptionFilter(), attachment) {} +}; + +//! Append input to a string object +template <class T> +class StringSinkTemplate : public Bufferless<Sink> +{ +public: + // VC60 workaround: no T::char_type + typedef typename T::traits_type::char_type char_type; + + StringSinkTemplate(T &output) + : m_output(&output) {assert(sizeof(output[0])==1);} + + void IsolatedInitialize(const NameValuePairs &parameters) + {if (!parameters.GetValue("OutputStringPointer", m_output)) throw InvalidArgument("StringSink: OutputStringPointer not specified");} + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking) + { + m_output->append((const char_type *)begin, (const char_type *)begin+length); + return 0; + } + +private: + T *m_output; +}; + +//! Append input to an std::string +typedef StringSinkTemplate<std::string> StringSink; + +//! 
Copy input to a memory buffer +class ArraySink : public Bufferless<Sink> +{ +public: + ArraySink(const NameValuePairs &parameters = g_nullNameValuePairs) {IsolatedInitialize(parameters);} + ArraySink(byte *buf, unsigned int size) : m_buf(buf), m_size(size), m_total(0) {} + + unsigned int AvailableSize() {return m_size - STDMIN(m_total, (unsigned long)m_size);} + unsigned long TotalPutLength() {return m_total;} + + void IsolatedInitialize(const NameValuePairs &parameters); + byte * CreatePutSpace(unsigned int &size); + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking); + +protected: + byte *m_buf; + unsigned int m_size; + unsigned long m_total; +}; + +//! Xor input to a memory buffer +class ArrayXorSink : public ArraySink +{ +public: + ArrayXorSink(byte *buf, unsigned int size) + : ArraySink(buf, size) {} + + unsigned int Put2(const byte *begin, unsigned int length, int messageEnd, bool blocking); + byte * CreatePutSpace(unsigned int &size) {return BufferedTransformation::CreatePutSpace(size);} +}; + +//! . +class StringStore : public Store +{ +public: + StringStore(const char *string = NULL) + {StoreInitialize(MakeParameters("InputBuffer", ConstByteArrayParameter(string)));} + StringStore(const byte *string, unsigned int length) + {StoreInitialize(MakeParameters("InputBuffer", ConstByteArrayParameter(string, length)));} + template <class T> StringStore(const T &string) + {StoreInitialize(MakeParameters("InputBuffer", ConstByteArrayParameter(string)));} + + unsigned int TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel=NULL_CHANNEL, bool blocking=true); + unsigned int CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end=ULONG_MAX, const std::string &channel=NULL_CHANNEL, bool blocking=true) const; + +private: + void StoreInitialize(const NameValuePairs &parameters); + + const byte *m_store; + unsigned int m_length, m_count; +}; + +//! . 
+class RandomNumberStore : public Store +{ +public: + RandomNumberStore(RandomNumberGenerator &rng, unsigned long length) + : m_rng(rng), m_length(length), m_count(0) {} + + bool AnyRetrievable() const {return MaxRetrievable() != 0;} + unsigned long MaxRetrievable() const {return m_length-m_count;} + + unsigned int TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel=NULL_CHANNEL, bool blocking=true); + unsigned int CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end=ULONG_MAX, const std::string &channel=NULL_CHANNEL, bool blocking=true) const + { + throw NotImplemented("RandomNumberStore: CopyRangeTo2() is not supported by this store"); + } + +private: + void StoreInitialize(const NameValuePairs &parameters) {m_count = 0;} + + RandomNumberGenerator &m_rng; + const unsigned long m_length; + unsigned long m_count; +}; + +//! . +class NullStore : public Store +{ +public: + NullStore(unsigned long size = ULONG_MAX) : m_size(size) {} + void StoreInitialize(const NameValuePairs &parameters) {} + unsigned long MaxRetrievable() const {return m_size;} + unsigned int TransferTo2(BufferedTransformation &target, unsigned long &transferBytes, const std::string &channel=NULL_CHANNEL, bool blocking=true); + unsigned int CopyRangeTo2(BufferedTransformation &target, unsigned long &begin, unsigned long end=ULONG_MAX, const std::string &channel=NULL_CHANNEL, bool blocking=true) const; + +private: + unsigned long m_size; +}; + +//! 
A Filter that pumps data into its attachment as input +class Source : public InputRejecting<Filter> +{ +public: + Source(BufferedTransformation *attachment) + : InputRejecting<Filter>(attachment) {} + + unsigned long Pump(unsigned long pumpMax=ULONG_MAX) + {Pump2(pumpMax); return pumpMax;} + unsigned int PumpMessages(unsigned int count=UINT_MAX) + {PumpMessages2(count); return count;} + void PumpAll() + {PumpAll2();} + virtual unsigned int Pump2(unsigned long &byteCount, bool blocking=true) =0; + virtual unsigned int PumpMessages2(unsigned int &messageCount, bool blocking=true) =0; + virtual unsigned int PumpAll2(bool blocking=true); + virtual bool SourceExhausted() const =0; + +protected: + void SourceInitialize(bool pumpAll, const NameValuePairs &parameters) + { + IsolatedInitialize(parameters); + if (pumpAll) + PumpAll(); + } +}; + +//! Turn a Store into a Source +template <class T> +class SourceTemplate : public Source +{ +public: + SourceTemplate<T>(BufferedTransformation *attachment) + : Source(attachment) {} + SourceTemplate<T>(BufferedTransformation *attachment, T store) + : Source(attachment), m_store(store) {} + void IsolatedInitialize(const NameValuePairs &parameters) + {m_store.IsolatedInitialize(parameters);} + unsigned int Pump2(unsigned long &byteCount, bool blocking=true) + {return m_store.TransferTo2(*AttachedTransformation(), byteCount, NULL_CHANNEL, blocking);} + unsigned int PumpMessages2(unsigned int &messageCount, bool blocking=true) + {return m_store.TransferMessagesTo2(*AttachedTransformation(), messageCount, NULL_CHANNEL, blocking);} + unsigned int PumpAll2(bool blocking=true) + {return m_store.TransferAllTo2(*AttachedTransformation(), NULL_CHANNEL, blocking);} + bool SourceExhausted() const + {return !m_store.AnyRetrievable() && !m_store.AnyMessages();} + void SetAutoSignalPropagation(int propagation) + {m_store.SetAutoSignalPropagation(propagation);} + int GetAutoSignalPropagation() const + {return m_store.GetAutoSignalPropagation();} + 
+protected: + T m_store; +}; + +//! . +class StringSource : public SourceTemplate<StringStore> +{ +public: + StringSource(BufferedTransformation *attachment = NULL) + : SourceTemplate<StringStore>(attachment) {} + StringSource(const char *string, bool pumpAll, BufferedTransformation *attachment = NULL) + : SourceTemplate<StringStore>(attachment) {SourceInitialize(pumpAll, MakeParameters("InputBuffer", ConstByteArrayParameter(string)));} + StringSource(const byte *string, unsigned int length, bool pumpAll, BufferedTransformation *attachment = NULL) + : SourceTemplate<StringStore>(attachment) {SourceInitialize(pumpAll, MakeParameters("InputBuffer", ConstByteArrayParameter(string, length)));} + +#ifdef __MWERKS__ // CW60 workaround + StringSource(const std::string &string, bool pumpAll, BufferedTransformation *attachment = NULL) +#else + template <class T> StringSource(const T &string, bool pumpAll, BufferedTransformation *attachment = NULL) +#endif + : SourceTemplate<StringStore>(attachment) {SourceInitialize(pumpAll, MakeParameters("InputBuffer", ConstByteArrayParameter(string)));} +}; + +//! . 
+class RandomNumberSource : public SourceTemplate<RandomNumberStore> +{ +public: + RandomNumberSource(RandomNumberGenerator &rng, unsigned int length, bool pumpAll, BufferedTransformation *attachment = NULL) + : SourceTemplate<RandomNumberStore>(attachment, RandomNumberStore(rng, length)) {if (pumpAll) PumpAll();} +}; + +NAMESPACE_END + +#endif diff --git a/fltrimpl.h b/fltrimpl.h new file mode 100644 index 0000000..a35e68b --- /dev/null +++ b/fltrimpl.h @@ -0,0 +1,42 @@ +#ifndef CRYPTOPP_FLTRIMPL_H +#define CRYPTOPP_FLTRIMPL_H + +#define FILTER_BEGIN \ + switch (m_continueAt) \ + { \ + case 0: \ + m_inputPosition = 0; + +#define FILTER_END_NO_MESSAGE_END_NO_RETURN \ + break; \ + default: \ + assert(false); \ + } + +#define FILTER_END_NO_MESSAGE_END \ + FILTER_END_NO_MESSAGE_END_NO_RETURN \ + return 0; + +/* +#define FILTER_END \ + case -1: \ + if (messageEnd && Output(-1, NULL, 0, messageEnd, blocking)) \ + return 1; \ + FILTER_END_NO_MESSAGE_END +*/ + +#define FILTER_OUTPUT2(site, statement, output, length, messageEnd) \ + {\ + case site: \ + statement; \ + if (Output(site, output, length, messageEnd, blocking)) \ + return STDMAX(1U, (unsigned int)length-m_inputPosition);\ + } + +#define FILTER_OUTPUT(site, output, length, messageEnd) \ + FILTER_OUTPUT2(site, 0, output, length, messageEnd) + +#define FILTER_OUTPUT_BYTE(site, output) \ + FILTER_OUTPUT(site, &(const byte &)(byte)output, 1, 0) + +#endif diff --git a/integer.cpp b/integer.cpp new file mode 100644 index 0000000..35312f6 --- /dev/null +++ b/integer.cpp @@ -0,0 +1,3987 @@ +// integer.cpp - written and placed in the public domain by Wei Dai +// contains public domain code contributed by Alister Lee and Leonard Janke + +#include "pch.h" +#include "integer.h" +#include "modarith.h" +#include "nbtheory.h" +#include "asn.h" +#include "oids.h" +#include "words.h" +#include "algparam.h" +#include "pubkey.h" // for P1363_KDF2 +#include "sha.h" + +#include <iostream> + +#ifdef SSE2_INTRINSICS_AVAILABLE +#include 
<emmintrin.h> +#endif + +#include "algebra.cpp" +#include "eprecomp.cpp" + +NAMESPACE_BEGIN(CryptoPP) + +#ifdef SSE2_INTRINSICS_AVAILABLE +template <class T> +AllocatorBase<T>::pointer AlignedAllocator<T>::allocate(size_type n, const void *) +{ + if (n < 4) + return new T[n]; + else + return (T *)_mm_malloc(sizeof(T)*n, 16); + +} + +template <class T> +void AlignedAllocator<T>::deallocate(void *p, size_type n) +{ + memset(p, 0, n*sizeof(T)); + if (n < 4) + delete [] p; + else + _mm_free(p); +} + +template class AlignedAllocator<word>; +#endif + +#define MAKE_DWORD(lowWord, highWord) ((dword(highWord)<<WORD_BITS) | (lowWord)) + +static int Compare(const word *A, const word *B, unsigned int N) +{ + while (N--) + if (A[N] > B[N]) + return 1; + else if (A[N] < B[N]) + return -1; + + return 0; +} + +static word Increment(word *A, unsigned int N, word B=1) +{ + assert(N); + word t = A[0]; + A[0] = t+B; + if (A[0] >= t) + return 0; + for (unsigned i=1; i<N; i++) + if (++A[i]) + return 0; + return 1; +} + +static word Decrement(word *A, unsigned int N, word B=1) +{ + assert(N); + word t = A[0]; + A[0] = t-B; + if (A[0] <= t) + return 0; + for (unsigned i=1; i<N; i++) + if (A[i]--) + return 0; + return 1; +} + +static void TwosComplement(word *A, unsigned int N) +{ + Decrement(A, N); + for (unsigned i=0; i<N; i++) + A[i] = ~A[i]; +} + +static word LinearMultiply(word *C, const word *A, word B, unsigned int N) +{ + word carry=0; + for(unsigned i=0; i<N; i++) + { + dword p = (dword)A[i] * B + carry; + C[i] = LOW_WORD(p); + carry = HIGH_WORD(p); + } + return carry; +} + +static void AtomicInverseModPower2(word *C, word A0, word A1) +{ + assert(A0%2==1); + + dword A=MAKE_DWORD(A0, A1), R=A0%8; + + for (unsigned i=3; i<2*WORD_BITS; i*=2) + R = R*(2-R*A); + + assert(R*A==1); + + C[0] = LOW_WORD(R); + C[1] = HIGH_WORD(R); +} + +// ******************************************************** + +class Portable +{ +public: + static word Add(word *C, const word *A, const word *B, unsigned 
int N); + static word Subtract(word *C, const word *A, const word *B, unsigned int N); + + static inline void Multiply2(word *C, const word *A, const word *B); + static inline word Multiply2Add(word *C, const word *A, const word *B); + static void Multiply4(word *C, const word *A, const word *B); + static void Multiply8(word *C, const word *A, const word *B); + static inline unsigned int MultiplyRecursionLimit() {return 8;} + + static inline void Multiply2Bottom(word *C, const word *A, const word *B); + static void Multiply4Bottom(word *C, const word *A, const word *B); + static void Multiply8Bottom(word *C, const word *A, const word *B); + static inline unsigned int MultiplyBottomRecursionLimit() {return 8;} + + static void Square2(word *R, const word *A); + static void Square4(word *R, const word *A); + static void Square8(word *R, const word *A) {assert(false);} + static inline unsigned int SquareRecursionLimit() {return 4;} +}; + +word Portable::Add(word *C, const word *A, const word *B, unsigned int N) +{ + assert (N%2 == 0); + +#ifdef IS_LITTLE_ENDIAN + if (sizeof(dword) == sizeof(size_t)) // dword is only register size + { + dword carry = 0; + N >>= 1; + for (unsigned int i = 0; i < N; i++) + { + dword a = ((const dword *)A)[i] + carry; + dword c = a + ((const dword *)B)[i]; + ((dword *)C)[i] = c; + carry = (a < carry) | (c < a); + } + return (word)carry; + } + else +#endif + { + word carry = 0; + for (unsigned int i = 0; i < N; i+=2) + { + dword u = (dword) carry + A[i] + B[i]; + C[i] = LOW_WORD(u); + u = (dword) HIGH_WORD(u) + A[i+1] + B[i+1]; + C[i+1] = LOW_WORD(u); + carry = HIGH_WORD(u); + } + return carry; + } +} + +word Portable::Subtract(word *C, const word *A, const word *B, unsigned int N) +{ + assert (N%2 == 0); + +#ifdef IS_LITTLE_ENDIAN + if (sizeof(dword) == sizeof(size_t)) // dword is only register size + { + dword borrow = 0; + N >>= 1; + for (unsigned int i = 0; i < N; i++) + { + dword a = ((const dword *)A)[i]; + dword b = a - borrow; + 
dword c = b - ((const dword *)B)[i]; + ((dword *)C)[i] = c; + borrow = (b > a) | (c > b); + } + return (word)borrow; + } + else +#endif + { + word borrow=0; + for (unsigned i = 0; i < N; i+=2) + { + dword u = (dword) A[i] - B[i] - borrow; + C[i] = LOW_WORD(u); + u = (dword) A[i+1] - B[i+1] - (word)(0-HIGH_WORD(u)); + C[i+1] = LOW_WORD(u); + borrow = 0-HIGH_WORD(u); + } + return borrow; + } +} + +void Portable::Multiply2(word *C, const word *A, const word *B) +{ +/* + word s; + dword d; + + if (A1 >= A0) + if (B0 >= B1) + { + s = 0; + d = (dword)(A1-A0)*(B0-B1); + } + else + { + s = (A1-A0); + d = (dword)s*(word)(B0-B1); + } + else + if (B0 > B1) + { + s = (B0-B1); + d = (word)(A1-A0)*(dword)s; + } + else + { + s = 0; + d = (dword)(A0-A1)*(B1-B0); + } +*/ + // this segment is the branchless equivalent of above + word D[4] = {A[1]-A[0], A[0]-A[1], B[0]-B[1], B[1]-B[0]}; + unsigned int ai = A[1] < A[0]; + unsigned int bi = B[0] < B[1]; + unsigned int di = ai & bi; + dword d = (dword)D[di]*D[di+2]; + D[1] = D[3] = 0; + unsigned int si = ai + !bi; + word s = D[si]; + + dword A0B0 = (dword)A[0]*B[0]; + C[0] = LOW_WORD(A0B0); + + dword A1B1 = (dword)A[1]*B[1]; + dword t = (dword) HIGH_WORD(A0B0) + LOW_WORD(A0B0) + LOW_WORD(d) + LOW_WORD(A1B1); + C[1] = LOW_WORD(t); + + t = A1B1 + HIGH_WORD(t) + HIGH_WORD(A0B0) + HIGH_WORD(d) + HIGH_WORD(A1B1) - s; + C[2] = LOW_WORD(t); + C[3] = HIGH_WORD(t); +} + +inline void Portable::Multiply2Bottom(word *C, const word *A, const word *B) +{ +#ifdef IS_LITTLE_ENDIAN + if (sizeof(dword) == sizeof(size_t)) + { + dword a = *(const dword *)A, b = *(const dword *)B; + ((dword *)C)[0] = a*b; + } + else +#endif + { + dword t = (dword)A[0]*B[0]; + C[0] = LOW_WORD(t); + C[1] = HIGH_WORD(t) + A[0]*B[1] + A[1]*B[0]; + } +} + +word Portable::Multiply2Add(word *C, const word *A, const word *B) +{ + word D[4] = {A[1]-A[0], A[0]-A[1], B[0]-B[1], B[1]-B[0]}; + unsigned int ai = A[1] < A[0]; + unsigned int bi = B[0] < B[1]; + unsigned int di = ai & bi; + 
dword d = (dword)D[di]*D[di+2]; + D[1] = D[3] = 0; + unsigned int si = ai + !bi; + word s = D[si]; + + dword A0B0 = (dword)A[0]*B[0]; + dword t = A0B0 + C[0]; + C[0] = LOW_WORD(t); + + dword A1B1 = (dword)A[1]*B[1]; + t = (dword) HIGH_WORD(t) + LOW_WORD(A0B0) + LOW_WORD(d) + LOW_WORD(A1B1) + C[1]; + C[1] = LOW_WORD(t); + + t = (dword) HIGH_WORD(t) + LOW_WORD(A1B1) + HIGH_WORD(A0B0) + HIGH_WORD(d) + HIGH_WORD(A1B1) - s + C[2]; + C[2] = LOW_WORD(t); + + t = (dword) HIGH_WORD(t) + HIGH_WORD(A1B1) + C[3]; + C[3] = LOW_WORD(t); + return HIGH_WORD(t); +} + +#define MulAcc(x, y) \ + p = (dword)A[x] * B[y] + c; \ + c = LOW_WORD(p); \ + p = (dword)d + HIGH_WORD(p); \ + d = LOW_WORD(p); \ + e += HIGH_WORD(p); + +#define SaveMulAcc(s, x, y) \ + R[s] = c; \ + p = (dword)A[x] * B[y] + d; \ + c = LOW_WORD(p); \ + p = (dword)e + HIGH_WORD(p); \ + d = LOW_WORD(p); \ + e = HIGH_WORD(p); + +#define SquAcc(x, y) \ + q = (dword)A[x] * A[y]; \ + p = q + c; \ + c = LOW_WORD(p); \ + p = (dword)d + HIGH_WORD(p); \ + d = LOW_WORD(p); \ + e += HIGH_WORD(p); \ + p = q + c; \ + c = LOW_WORD(p); \ + p = (dword)d + HIGH_WORD(p); \ + d = LOW_WORD(p); \ + e += HIGH_WORD(p); + +#define SaveSquAcc(s, x, y) \ + R[s] = c; \ + q = (dword)A[x] * A[y]; \ + p = q + d; \ + c = LOW_WORD(p); \ + p = (dword)e + HIGH_WORD(p); \ + d = LOW_WORD(p); \ + e = HIGH_WORD(p); \ + p = q + c; \ + c = LOW_WORD(p); \ + p = (dword)d + HIGH_WORD(p); \ + d = LOW_WORD(p); \ + e += HIGH_WORD(p); + +void Portable::Multiply4(word *R, const word *A, const word *B) +{ + dword p; + word c, d, e; + + p = (dword)A[0] * B[0]; + R[0] = LOW_WORD(p); + c = HIGH_WORD(p); + d = e = 0; + + MulAcc(0, 1); + MulAcc(1, 0); + + SaveMulAcc(1, 2, 0); + MulAcc(1, 1); + MulAcc(0, 2); + + SaveMulAcc(2, 0, 3); + MulAcc(1, 2); + MulAcc(2, 1); + MulAcc(3, 0); + + SaveMulAcc(3, 3, 1); + MulAcc(2, 2); + MulAcc(1, 3); + + SaveMulAcc(4, 2, 3); + MulAcc(3, 2); + + R[5] = c; + p = (dword)A[3] * B[3] + d; + R[6] = LOW_WORD(p); + R[7] = e + HIGH_WORD(p); +} + 
+void Portable::Square2(word *R, const word *A) +{ + dword p, q; + word c, d, e; + + p = (dword)A[0] * A[0]; + R[0] = LOW_WORD(p); + c = HIGH_WORD(p); + d = e = 0; + + SquAcc(0, 1); + + R[1] = c; + p = (dword)A[1] * A[1] + d; + R[2] = LOW_WORD(p); + R[3] = e + HIGH_WORD(p); +} + +void Portable::Square4(word *R, const word *A) +{ + const word *B = A; + dword p, q; + word c, d, e; + + p = (dword)A[0] * A[0]; + R[0] = LOW_WORD(p); + c = HIGH_WORD(p); + d = e = 0; + + SquAcc(0, 1); + + SaveSquAcc(1, 2, 0); + MulAcc(1, 1); + + SaveSquAcc(2, 0, 3); + SquAcc(1, 2); + + SaveSquAcc(3, 3, 1); + MulAcc(2, 2); + + SaveSquAcc(4, 2, 3); + + R[5] = c; + p = (dword)A[3] * A[3] + d; + R[6] = LOW_WORD(p); + R[7] = e + HIGH_WORD(p); +} + +void Portable::Multiply8(word *R, const word *A, const word *B) +{ + dword p; + word c, d, e; + + p = (dword)A[0] * B[0]; + R[0] = LOW_WORD(p); + c = HIGH_WORD(p); + d = e = 0; + + MulAcc(0, 1); + MulAcc(1, 0); + + SaveMulAcc(1, 2, 0); + MulAcc(1, 1); + MulAcc(0, 2); + + SaveMulAcc(2, 0, 3); + MulAcc(1, 2); + MulAcc(2, 1); + MulAcc(3, 0); + + SaveMulAcc(3, 0, 4); + MulAcc(1, 3); + MulAcc(2, 2); + MulAcc(3, 1); + MulAcc(4, 0); + + SaveMulAcc(4, 0, 5); + MulAcc(1, 4); + MulAcc(2, 3); + MulAcc(3, 2); + MulAcc(4, 1); + MulAcc(5, 0); + + SaveMulAcc(5, 0, 6); + MulAcc(1, 5); + MulAcc(2, 4); + MulAcc(3, 3); + MulAcc(4, 2); + MulAcc(5, 1); + MulAcc(6, 0); + + SaveMulAcc(6, 0, 7); + MulAcc(1, 6); + MulAcc(2, 5); + MulAcc(3, 4); + MulAcc(4, 3); + MulAcc(5, 2); + MulAcc(6, 1); + MulAcc(7, 0); + + SaveMulAcc(7, 1, 7); + MulAcc(2, 6); + MulAcc(3, 5); + MulAcc(4, 4); + MulAcc(5, 3); + MulAcc(6, 2); + MulAcc(7, 1); + + SaveMulAcc(8, 2, 7); + MulAcc(3, 6); + MulAcc(4, 5); + MulAcc(5, 4); + MulAcc(6, 3); + MulAcc(7, 2); + + SaveMulAcc(9, 3, 7); + MulAcc(4, 6); + MulAcc(5, 5); + MulAcc(6, 4); + MulAcc(7, 3); + + SaveMulAcc(10, 4, 7); + MulAcc(5, 6); + MulAcc(6, 5); + MulAcc(7, 4); + + SaveMulAcc(11, 5, 7); + MulAcc(6, 6); + MulAcc(7, 5); + + SaveMulAcc(12, 6, 7); + 
MulAcc(7, 6); + + R[13] = c; + p = (dword)A[7] * B[7] + d; + R[14] = LOW_WORD(p); + R[15] = e + HIGH_WORD(p); +} + +void Portable::Multiply4Bottom(word *R, const word *A, const word *B) +{ + dword p; + word c, d, e; + + p = (dword)A[0] * B[0]; + R[0] = LOW_WORD(p); + c = HIGH_WORD(p); + d = e = 0; + + MulAcc(0, 1); + MulAcc(1, 0); + + SaveMulAcc(1, 2, 0); + MulAcc(1, 1); + MulAcc(0, 2); + + R[2] = c; + R[3] = d + A[0] * B[3] + A[1] * B[2] + A[2] * B[1] + A[3] * B[0]; +} + +void Portable::Multiply8Bottom(word *R, const word *A, const word *B) +{ + dword p; + word c, d, e; + + p = (dword)A[0] * B[0]; + R[0] = LOW_WORD(p); + c = HIGH_WORD(p); + d = e = 0; + + MulAcc(0, 1); + MulAcc(1, 0); + + SaveMulAcc(1, 2, 0); + MulAcc(1, 1); + MulAcc(0, 2); + + SaveMulAcc(2, 0, 3); + MulAcc(1, 2); + MulAcc(2, 1); + MulAcc(3, 0); + + SaveMulAcc(3, 0, 4); + MulAcc(1, 3); + MulAcc(2, 2); + MulAcc(3, 1); + MulAcc(4, 0); + + SaveMulAcc(4, 0, 5); + MulAcc(1, 4); + MulAcc(2, 3); + MulAcc(3, 2); + MulAcc(4, 1); + MulAcc(5, 0); + + SaveMulAcc(5, 0, 6); + MulAcc(1, 5); + MulAcc(2, 4); + MulAcc(3, 3); + MulAcc(4, 2); + MulAcc(5, 1); + MulAcc(6, 0); + + R[6] = c; + R[7] = d + A[0] * B[7] + A[1] * B[6] + A[2] * B[5] + A[3] * B[4] + + A[4] * B[3] + A[5] * B[2] + A[6] * B[1] + A[7] * B[0]; +} + +#undef MulAcc +#undef SaveMulAcc +#undef SquAcc +#undef SaveSquAcc + +// CodeWarrior defines _MSC_VER +#if defined(_MSC_VER) && !defined(__MWERKS__) && defined(_M_IX86) && (_M_IX86<=700) + +class PentiumOptimized : public Portable +{ +public: + static word __fastcall Add(word *C, const word *A, const word *B, unsigned int N); + static word __fastcall Subtract(word *C, const word *A, const word *B, unsigned int N); + static inline void Square4(word *R, const word *A) + { + // VC60 workaround: MSVC 6.0 has an optimization bug that makes + // (dword)A*B where either A or B has been cast to a dword before + // very expensive. Revisit this function when this + // bug is fixed. 
+ Multiply4(R, A, A); + } +}; + +typedef PentiumOptimized LowLevel; + +__declspec(naked) word __fastcall PentiumOptimized::Add(word *C, const word *A, const word *B, unsigned int N) +{ + __asm + { + push ebp + push ebx + push esi + push edi + + mov esi, [esp+24] ; N + mov ebx, [esp+20] ; B + + // now: ebx = B, ecx = C, edx = A, esi = N + + sub ecx, edx // hold the distance between C & A so we can add this to A to get C + xor eax, eax // clear eax + + sub eax, esi // eax is a negative index from end of B + lea ebx, [ebx+4*esi] // ebx is end of B + + sar eax, 1 // unit of eax is now dwords; this also clears the carry flag + jz loopend // if no dwords then nothing to do + +loopstart: + mov esi,[edx] // load lower word of A + mov ebp,[edx+4] // load higher word of A + + mov edi,[ebx+8*eax] // load lower word of B + lea edx,[edx+8] // advance A and C + + adc esi,edi // add lower words + mov edi,[ebx+8*eax+4] // load higher word of B + + adc ebp,edi // add higher words + inc eax // advance B + + mov [edx+ecx-8],esi // store lower word result + mov [edx+ecx-4],ebp // store higher word result + + jnz loopstart // loop until eax overflows and becomes zero + +loopend: + adc eax, 0 // store carry into eax (return result register) + pop edi + pop esi + pop ebx + pop ebp + ret 8 + } +} + +__declspec(naked) word __fastcall PentiumOptimized::Subtract(word *C, const word *A, const word *B, unsigned int N) +{ + __asm + { + push ebp + push ebx + push esi + push edi + + mov esi, [esp+24] ; N + mov ebx, [esp+20] ; B + + sub ecx, edx + xor eax, eax + + sub eax, esi + lea ebx, [ebx+4*esi] + + sar eax, 1 + jz loopend + +loopstart: + mov esi,[edx] + mov ebp,[edx+4] + + mov edi,[ebx+8*eax] + lea edx,[edx+8] + + sbb esi,edi + mov edi,[ebx+8*eax+4] + + sbb ebp,edi + inc eax + + mov [edx+ecx-8],esi + mov [edx+ecx-4],ebp + + jnz loopstart + +loopend: + adc eax, 0 + pop edi + pop esi + pop ebx + pop ebp + ret 8 + } +} + +#ifdef SSE2_INTRINSICS_AVAILABLE + +static bool GetSSE2Capability() +{ + 
word32 b; + + __asm + { + mov eax, 1 + cpuid + mov b, edx + } + + return (b & (1 << 26)) != 0; +} + +bool g_sse2DetectionDone = false, g_sse2Detected, g_sse2Enabled = true; + +static inline bool HasSSE2() +{ + if (g_sse2Enabled && !g_sse2DetectionDone) + { + g_sse2Detected = GetSSE2Capability(); + g_sse2DetectionDone = true; + } + return g_sse2Enabled && g_sse2Detected; +} + +class P4Optimized : public PentiumOptimized +{ +public: + static word __fastcall Add(word *C, const word *A, const word *B, unsigned int N); + static word __fastcall Subtract(word *C, const word *A, const word *B, unsigned int N); + static void Multiply4(word *C, const word *A, const word *B); + static void Multiply8(word *C, const word *A, const word *B); + static inline void Square4(word *R, const word *A) + { + Multiply4(R, A, A); + } + static void Multiply8Bottom(word *C, const word *A, const word *B); +}; + +static void __fastcall P4_Mul(__m128i *C, const __m128i *A, const __m128i *B) +{ + __m128i a3210 = _mm_load_si128(A); + __m128i b3210 = _mm_load_si128(B); + + __m128i sum; + + __m128i z = _mm_setzero_si128(); + __m128i a2b2_a0b0 = _mm_mul_epu32(a3210, b3210); + C[0] = a2b2_a0b0; + + __m128i a3120 = _mm_shuffle_epi32(a3210, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i b3021 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(3, 0, 2, 1)); + __m128i a1b0_a0b1 = _mm_mul_epu32(a3120, b3021); + __m128i a1b0 = _mm_unpackhi_epi32(a1b0_a0b1, z); + __m128i a0b1 = _mm_unpacklo_epi32(a1b0_a0b1, z); + C[1] = _mm_add_epi64(a1b0, a0b1); + + __m128i a31 = _mm_srli_epi64(a3210, 32); + __m128i b31 = _mm_srli_epi64(b3210, 32); + __m128i a3b3_a1b1 = _mm_mul_epu32(a31, b31); + C[6] = a3b3_a1b1; + + __m128i a1b1 = _mm_unpacklo_epi32(a3b3_a1b1, z); + __m128i b3012 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(3, 0, 1, 2)); + __m128i a2b0_a0b2 = _mm_mul_epu32(a3210, b3012); + __m128i a0b2 = _mm_unpacklo_epi32(a2b0_a0b2, z); + __m128i a2b0 = _mm_unpackhi_epi32(a2b0_a0b2, z); + sum = _mm_add_epi64(a1b1, a0b2); + C[2] = _mm_add_epi64(sum, 
a2b0); + + __m128i a2301 = _mm_shuffle_epi32(a3210, _MM_SHUFFLE(2, 3, 0, 1)); + __m128i b2103 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(2, 1, 0, 3)); + __m128i a3b0_a1b2 = _mm_mul_epu32(a2301, b3012); + __m128i a2b1_a0b3 = _mm_mul_epu32(a3210, b2103); + __m128i a3b0 = _mm_unpackhi_epi32(a3b0_a1b2, z); + __m128i a1b2 = _mm_unpacklo_epi32(a3b0_a1b2, z); + __m128i a2b1 = _mm_unpackhi_epi32(a2b1_a0b3, z); + __m128i a0b3 = _mm_unpacklo_epi32(a2b1_a0b3, z); + __m128i sum1 = _mm_add_epi64(a3b0, a1b2); + sum = _mm_add_epi64(a2b1, a0b3); + C[3] = _mm_add_epi64(sum, sum1); + + __m128i a3b1_a1b3 = _mm_mul_epu32(a2301, b2103); + __m128i a2b2 = _mm_unpackhi_epi32(a2b2_a0b0, z); + __m128i a3b1 = _mm_unpackhi_epi32(a3b1_a1b3, z); + __m128i a1b3 = _mm_unpacklo_epi32(a3b1_a1b3, z); + sum = _mm_add_epi64(a2b2, a3b1); + C[4] = _mm_add_epi64(sum, a1b3); + + __m128i a1302 = _mm_shuffle_epi32(a3210, _MM_SHUFFLE(1, 3, 0, 2)); + __m128i b1203 = _mm_shuffle_epi32(b3210, _MM_SHUFFLE(1, 2, 0, 3)); + __m128i a3b2_a2b3 = _mm_mul_epu32(a1302, b1203); + __m128i a3b2 = _mm_unpackhi_epi32(a3b2_a2b3, z); + __m128i a2b3 = _mm_unpacklo_epi32(a3b2_a2b3, z); + C[5] = _mm_add_epi64(a3b2, a2b3); +} + +void P4Optimized::Multiply4(word *C, const word *A, const word *B) +{ + __m128i temp[7]; + const word *w = (word *)temp; + const __m64 *mw = (__m64 *)w; + + P4_Mul(temp, (__m128i *)A, (__m128i *)B); + + C[0] = w[0]; + + __m64 s1, s2; + + __m64 w1 = _m_from_int(w[1]); + __m64 w4 = mw[2]; + __m64 w6 = mw[3]; + __m64 w8 = mw[4]; + __m64 w10 = mw[5]; + __m64 w12 = mw[6]; + __m64 w14 = mw[7]; + __m64 w16 = mw[8]; + __m64 w18 = mw[9]; + __m64 w20 = mw[10]; + __m64 w22 = mw[11]; + __m64 w26 = _m_from_int(w[26]); + + s1 = _mm_add_si64(w1, w4); + C[1] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w6, w8); + s1 = _mm_add_si64(s1, s2); + C[2] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w10, w12); + s1 = _mm_add_si64(s1, s2); + C[3] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 
= _mm_add_si64(w14, w16); + s1 = _mm_add_si64(s1, s2); + C[4] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w18, w20); + s1 = _mm_add_si64(s1, s2); + C[5] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w22, w26); + s1 = _mm_add_si64(s1, s2); + C[6] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + C[7] = _m_to_int(s1) + w[27]; + _mm_empty(); +} + +void P4Optimized::Multiply8(word *C, const word *A, const word *B) +{ + __m128i temp[28]; + const word *w = (word *)temp; + const __m64 *mw = (__m64 *)w; + const word *x = (word *)temp+7*4; + const __m64 *mx = (__m64 *)x; + const word *y = (word *)temp+7*4*2; + const __m64 *my = (__m64 *)y; + const word *z = (word *)temp+7*4*3; + const __m64 *mz = (__m64 *)z; + + P4_Mul(temp, (__m128i *)A, (__m128i *)B); + + P4_Mul(temp+7, (__m128i *)A+1, (__m128i *)B); + + P4_Mul(temp+14, (__m128i *)A, (__m128i *)B+1); + + P4_Mul(temp+21, (__m128i *)A+1, (__m128i *)B+1); + + C[0] = w[0]; + + __m64 s1, s2, s3, s4; + + __m64 w1 = _m_from_int(w[1]); + __m64 w4 = mw[2]; + __m64 w6 = mw[3]; + __m64 w8 = mw[4]; + __m64 w10 = mw[5]; + __m64 w12 = mw[6]; + __m64 w14 = mw[7]; + __m64 w16 = mw[8]; + __m64 w18 = mw[9]; + __m64 w20 = mw[10]; + __m64 w22 = mw[11]; + __m64 w26 = _m_from_int(w[26]); + __m64 w27 = _m_from_int(w[27]); + + __m64 x0 = _m_from_int(x[0]); + __m64 x1 = _m_from_int(x[1]); + __m64 x4 = mx[2]; + __m64 x6 = mx[3]; + __m64 x8 = mx[4]; + __m64 x10 = mx[5]; + __m64 x12 = mx[6]; + __m64 x14 = mx[7]; + __m64 x16 = mx[8]; + __m64 x18 = mx[9]; + __m64 x20 = mx[10]; + __m64 x22 = mx[11]; + __m64 x26 = _m_from_int(x[26]); + __m64 x27 = _m_from_int(x[27]); + + __m64 y0 = _m_from_int(y[0]); + __m64 y1 = _m_from_int(y[1]); + __m64 y4 = my[2]; + __m64 y6 = my[3]; + __m64 y8 = my[4]; + __m64 y10 = my[5]; + __m64 y12 = my[6]; + __m64 y14 = my[7]; + __m64 y16 = my[8]; + __m64 y18 = my[9]; + __m64 y20 = my[10]; + __m64 y22 = my[11]; + __m64 y26 = _m_from_int(y[26]); + __m64 y27 = _m_from_int(y[27]); + + __m64 z0 
= _m_from_int(z[0]); + __m64 z1 = _m_from_int(z[1]); + __m64 z4 = mz[2]; + __m64 z6 = mz[3]; + __m64 z8 = mz[4]; + __m64 z10 = mz[5]; + __m64 z12 = mz[6]; + __m64 z14 = mz[7]; + __m64 z16 = mz[8]; + __m64 z18 = mz[9]; + __m64 z20 = mz[10]; + __m64 z22 = mz[11]; + __m64 z26 = _m_from_int(z[26]); + + s1 = _mm_add_si64(w1, w4); + C[1] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w6, w8); + s1 = _mm_add_si64(s1, s2); + C[2] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w10, w12); + s1 = _mm_add_si64(s1, s2); + C[3] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x0, y0); + s2 = _mm_add_si64(w14, w16); + s1 = _mm_add_si64(s1, s3); + s1 = _mm_add_si64(s1, s2); + C[4] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x1, y1); + s4 = _mm_add_si64(x4, y4); + s1 = _mm_add_si64(s1, w18); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, w20); + s1 = _mm_add_si64(s1, s3); + C[5] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x6, y6); + s4 = _mm_add_si64(x8, y8); + s1 = _mm_add_si64(s1, w22); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, w26); + s1 = _mm_add_si64(s1, s3); + C[6] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x10, y10); + s4 = _mm_add_si64(x12, y12); + s1 = _mm_add_si64(s1, w27); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, s3); + C[7] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x14, y14); + s4 = _mm_add_si64(x16, y16); + s1 = _mm_add_si64(s1, z0); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, s3); + C[8] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x18, y18); + s4 = _mm_add_si64(x20, y20); + s1 = _mm_add_si64(s1, z1); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, z4); + s1 = _mm_add_si64(s1, s3); + C[9] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x22, y22); + s4 = _mm_add_si64(x26, y26); + s1 = _mm_add_si64(s1, z6); + s3 = _mm_add_si64(s3, s4); + s1 = 
_mm_add_si64(s1, z8); + s1 = _mm_add_si64(s1, s3); + C[10] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x27, y27); + s1 = _mm_add_si64(s1, z10); + s1 = _mm_add_si64(s1, z12); + s1 = _mm_add_si64(s1, s3); + C[11] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(z14, z16); + s1 = _mm_add_si64(s1, s3); + C[12] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(z18, z20); + s1 = _mm_add_si64(s1, s3); + C[13] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(z22, z26); + s1 = _mm_add_si64(s1, s3); + C[14] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + C[15] = z[27] + _m_to_int(s1); + _mm_empty(); +} + +void P4Optimized::Multiply8Bottom(word *C, const word *A, const word *B) +{ + __m128i temp[21]; + const word *w = (word *)temp; + const __m64 *mw = (__m64 *)w; + const word *x = (word *)temp+7*4; + const __m64 *mx = (__m64 *)x; + const word *y = (word *)temp+7*4*2; + const __m64 *my = (__m64 *)y; + + P4_Mul(temp, (__m128i *)A, (__m128i *)B); + + P4_Mul(temp+7, (__m128i *)A+1, (__m128i *)B); + + P4_Mul(temp+14, (__m128i *)A, (__m128i *)B+1); + + C[0] = w[0]; + + __m64 s1, s2, s3, s4; + + __m64 w1 = _m_from_int(w[1]); + __m64 w4 = mw[2]; + __m64 w6 = mw[3]; + __m64 w8 = mw[4]; + __m64 w10 = mw[5]; + __m64 w12 = mw[6]; + __m64 w14 = mw[7]; + __m64 w16 = mw[8]; + __m64 w18 = mw[9]; + __m64 w20 = mw[10]; + __m64 w22 = mw[11]; + __m64 w26 = _m_from_int(w[26]); + + __m64 x0 = _m_from_int(x[0]); + __m64 x1 = _m_from_int(x[1]); + __m64 x4 = mx[2]; + __m64 x6 = mx[3]; + __m64 x8 = mx[4]; + + __m64 y0 = _m_from_int(y[0]); + __m64 y1 = _m_from_int(y[1]); + __m64 y4 = my[2]; + __m64 y6 = my[3]; + __m64 y8 = my[4]; + + s1 = _mm_add_si64(w1, w4); + C[1] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w6, w8); + s1 = _mm_add_si64(s1, s2); + C[2] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s2 = _mm_add_si64(w10, w12); + s1 = _mm_add_si64(s1, s2); + C[3] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = 
_mm_add_si64(x0, y0); + s2 = _mm_add_si64(w14, w16); + s1 = _mm_add_si64(s1, s3); + s1 = _mm_add_si64(s1, s2); + C[4] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x1, y1); + s4 = _mm_add_si64(x4, y4); + s1 = _mm_add_si64(s1, w18); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, w20); + s1 = _mm_add_si64(s1, s3); + C[5] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + s3 = _mm_add_si64(x6, y6); + s4 = _mm_add_si64(x8, y8); + s1 = _mm_add_si64(s1, w22); + s3 = _mm_add_si64(s3, s4); + s1 = _mm_add_si64(s1, w26); + s1 = _mm_add_si64(s1, s3); + C[6] = _m_to_int(s1); + s1 = _m_psrlqi(s1, 32); + + C[7] = _m_to_int(s1) + w[27] + x[10] + y[10] + x[12] + y[12]; + _mm_empty(); +} + +__declspec(naked) word __fastcall P4Optimized::Add(word *C, const word *A, const word *B, unsigned int N) +{ + __asm + { + sub esp, 16 + xor eax, eax + mov [esp], edi + mov [esp+4], esi + mov [esp+8], ebx + mov [esp+12], ebp + + mov ebx, [esp+20] // B + mov esi, [esp+24] // N + + // now: ebx = B, ecx = C, edx = A, esi = N + + neg esi + jz loopend // if no dwords then nothing to do + + mov edi, [edx] + mov ebp, [ebx] + +loopstart: + add edi, eax + jc carry1 + + xor eax, eax + +carry1continue: + add edi, ebp + mov ebp, 1 + mov [ecx], edi + mov edi, [edx+4] + cmovc eax, ebp + mov ebp, [ebx+4] + lea ebx, [ebx+8] + add edi, eax + jc carry2 + + xor eax, eax + +carry2continue: + add edi, ebp + mov ebp, 1 + cmovc eax, ebp + mov [ecx+4], edi + add ecx, 8 + mov edi, [edx+8] + add edx, 8 + add esi, 2 + mov ebp, [ebx] + jnz loopstart + +loopend: + mov edi, [esp] + mov esi, [esp+4] + mov ebx, [esp+8] + mov ebp, [esp+12] + add esp, 16 + ret 8 + +carry1: + mov eax, 1 + jmp carry1continue + +carry2: + mov eax, 1 + jmp carry2continue + } +} + +__declspec(naked) word __fastcall P4Optimized::Subtract(word *C, const word *A, const word *B, unsigned int N) +{ + __asm + { + sub esp, 16 + xor eax, eax + mov [esp], edi + mov [esp+4], esi + mov [esp+8], ebx + mov [esp+12], ebp + + mov ebx, [esp+20] 
// B + mov esi, [esp+24] // N + + // now: ebx = B, ecx = C, edx = A, esi = N + + neg esi + jz loopend // if no dwords then nothing to do + + mov edi, [edx] + mov ebp, [ebx] + +loopstart: + sub edi, eax + jc carry1 + + xor eax, eax + +carry1continue: + sub edi, ebp + mov ebp, 1 + mov [ecx], edi + mov edi, [edx+4] + cmovc eax, ebp + mov ebp, [ebx+4] + lea ebx, [ebx+8] + sub edi, eax + jc carry2 + + xor eax, eax + +carry2continue: + sub edi, ebp + mov ebp, 1 + cmovc eax, ebp + mov [ecx+4], edi + add ecx, 8 + mov edi, [edx+8] + add edx, 8 + add esi, 2 + mov ebp, [ebx] + jnz loopstart + +loopend: + mov edi, [esp] + mov esi, [esp+4] + mov ebx, [esp+8] + mov ebp, [esp+12] + add esp, 16 + ret 8 + +carry1: + mov eax, 1 + jmp carry1continue + +carry2: + mov eax, 1 + jmp carry2continue + } +} + +#endif // #ifdef SSE2_INTRINSICS_AVAILABLE + +#elif defined(__GNUC__) && defined(__i386__) + +class PentiumOptimized : public Portable +{ +public: +#ifndef __pic__ // -fpic uses up a register, leaving too few for the asm code + static word Add(word *C, const word *A, const word *B, unsigned int N); + static word Subtract(word *C, const word *A, const word *B, unsigned int N); +#endif + static void Square4(word *R, const word *A); + static void Multiply4(word *C, const word *A, const word *B); + static void Multiply8(word *C, const word *A, const word *B); +}; + +typedef PentiumOptimized LowLevel; + +// Add and Subtract assembly code originally contributed by Alister Lee + +#ifndef __pic__ +__attribute__((regparm(3))) word PentiumOptimized::Add(word *C, const word *A, const word *B, unsigned int N) +{ + assert (N%2 == 0); + + register word carry, temp; + + __asm__ __volatile__( + "push %%ebp;" + "sub %3, %2;" + "xor %0, %0;" + "sub %4, %0;" + "lea (%1,%4,4), %1;" + "sar $1, %0;" + "jz 1f;" + + "0:;" + "mov 0(%3), %4;" + "mov 4(%3), %%ebp;" + "mov (%1,%0,8), %5;" + "lea 8(%3), %3;" + "adc %5, %4;" + "mov 4(%1,%0,8), %5;" + "adc %5, %%ebp;" + "inc %0;" + "mov %4, -8(%3, %2);" + "mov 
%%ebp, -4(%3, %2);" + "jnz 0b;" + + "1:;" + "adc $0, %0;" + "pop %%ebp;" + + : "=aSD" (carry), "+r" (B), "+r" (C), "+r" (A), "+r" (N), "=r" (temp) + : : "cc", "memory"); + + return carry; +} + +__attribute__((regparm(3))) word PentiumOptimized::Subtract(word *C, const word *A, const word *B, unsigned int N) +{ + assert (N%2 == 0); + + register word carry, temp; + + __asm__ __volatile__( + "push %%ebp;" + "sub %3, %2;" + "xor %0, %0;" + "sub %4, %0;" + "lea (%1,%4,4), %1;" + "sar $1, %0;" + "jz 1f;" + + "0:;" + "mov 0(%3), %4;" + "mov 4(%3), %%ebp;" + "mov (%1,%0,8), %5;" + "lea 8(%3), %3;" + "sbb %5, %4;" + "mov 4(%1,%0,8), %5;" + "sbb %5, %%ebp;" + "inc %0;" + "mov %4, -8(%3, %2);" + "mov %%ebp, -4(%3, %2);" + "jnz 0b;" + + "1:;" + "adc $0, %0;" + "pop %%ebp;" + + : "=aSD" (carry), "+r" (B), "+r" (C), "+r" (A), "+r" (N), "=r" (temp) + : : "cc", "memory"); + + return carry; +} +#endif // __pic__ + +// Comba square and multiply assembly code originally contributed by Leonard Janke + +#define SqrStartup \ + "push %%ebp\n\t" \ + "push %%esi\n\t" \ + "push %%ebx\n\t" \ + "xor %%ebp, %%ebp\n\t" \ + "xor %%ebx, %%ebx\n\t" \ + "xor %%ecx, %%ecx\n\t" + +#define SqrShiftCarry \ + "mov %%ebx, %%ebp\n\t" \ + "mov %%ecx, %%ebx\n\t" \ + "xor %%ecx, %%ecx\n\t" + +#define SqrAccumulate(i,j) \ + "mov 4*"#j"(%%esi), %%eax\n\t" \ + "mull 4*"#i"(%%esi)\n\t" \ + "add %%eax, %%ebp\n\t" \ + "adc %%edx, %%ebx\n\t" \ + "adc %%ch, %%cl\n\t" \ + "add %%eax, %%ebp\n\t" \ + "adc %%edx, %%ebx\n\t" \ + "adc %%ch, %%cl\n\t" + +#define SqrAccumulateCentre(i) \ + "mov 4*"#i"(%%esi), %%eax\n\t" \ + "mull 4*"#i"(%%esi)\n\t" \ + "add %%eax, %%ebp\n\t" \ + "adc %%edx, %%ebx\n\t" \ + "adc %%ch, %%cl\n\t" + +#define SqrStoreDigit(X) \ + "mov %%ebp, 4*"#X"(%%edi)\n\t" \ + +#define SqrLastDiagonal(digits) \ + "mov 4*("#digits"-1)(%%esi), %%eax\n\t" \ + "mull 4*("#digits"-1)(%%esi)\n\t" \ + "add %%eax, %%ebp\n\t" \ + "adc %%edx, %%ebx\n\t" \ + "mov %%ebp, 4*(2*"#digits"-2)(%%edi)\n\t" \ + "mov %%ebx, 
4*(2*"#digits"-1)(%%edi)\n\t" + +#define SqrCleanup \ + "pop %%ebx\n\t" \ + "pop %%esi\n\t" \ + "pop %%ebp\n\t" + +void PentiumOptimized::Square4(word* Y, const word* X) +{ + __asm__ __volatile__( + SqrStartup + + SqrAccumulateCentre(0) + SqrStoreDigit(0) + SqrShiftCarry + + SqrAccumulate(1,0) + SqrStoreDigit(1) + SqrShiftCarry + + SqrAccumulate(2,0) + SqrAccumulateCentre(1) + SqrStoreDigit(2) + SqrShiftCarry + + SqrAccumulate(3,0) + SqrAccumulate(2,1) + SqrStoreDigit(3) + SqrShiftCarry + + SqrAccumulate(3,1) + SqrAccumulateCentre(2) + SqrStoreDigit(4) + SqrShiftCarry + + SqrAccumulate(3,2) + SqrStoreDigit(5) + SqrShiftCarry + + SqrLastDiagonal(4) + + SqrCleanup + + : + : "D" (Y), "S" (X) + : "eax", "ecx", "edx", "ebp", "memory" + ); +} + +#define MulStartup \ + "push %%ebp\n\t" \ + "push %%esi\n\t" \ + "push %%ebx\n\t" \ + "push %%edi\n\t" \ + "mov %%eax, %%ebx \n\t" \ + "xor %%ebp, %%ebp\n\t" \ + "xor %%edi, %%edi\n\t" \ + "xor %%ecx, %%ecx\n\t" + +#define MulShiftCarry \ + "mov %%edx, %%ebp\n\t" \ + "mov %%ecx, %%edi\n\t" \ + "xor %%ecx, %%ecx\n\t" + +#define MulAccumulate(i,j) \ + "mov 4*"#j"(%%ebx), %%eax\n\t" \ + "mull 4*"#i"(%%esi)\n\t" \ + "add %%eax, %%ebp\n\t" \ + "adc %%edx, %%edi\n\t" \ + "adc %%ch, %%cl\n\t" + +#define MulStoreDigit(X) \ + "mov %%edi, %%edx \n\t" \ + "mov (%%esp), %%edi \n\t" \ + "mov %%ebp, 4*"#X"(%%edi)\n\t" \ + "mov %%edi, (%%esp)\n\t" + +#define MulLastDiagonal(digits) \ + "mov 4*("#digits"-1)(%%ebx), %%eax\n\t" \ + "mull 4*("#digits"-1)(%%esi)\n\t" \ + "add %%eax, %%ebp\n\t" \ + "adc %%edi, %%edx\n\t" \ + "mov (%%esp), %%edi\n\t" \ + "mov %%ebp, 4*(2*"#digits"-2)(%%edi)\n\t" \ + "mov %%edx, 4*(2*"#digits"-1)(%%edi)\n\t" + +#define MulCleanup \ + "pop %%edi\n\t" \ + "pop %%ebx\n\t" \ + "pop %%esi\n\t" \ + "pop %%ebp\n\t" + +void PentiumOptimized::Multiply4(word* Z, const word* X, const word* Y) +{ + __asm__ __volatile__( + MulStartup + MulAccumulate(0,0) + MulStoreDigit(0) + MulShiftCarry + + MulAccumulate(1,0) + MulAccumulate(0,1) 
+ MulStoreDigit(1) + MulShiftCarry + + MulAccumulate(2,0) + MulAccumulate(1,1) + MulAccumulate(0,2) + MulStoreDigit(2) + MulShiftCarry + + MulAccumulate(3,0) + MulAccumulate(2,1) + MulAccumulate(1,2) + MulAccumulate(0,3) + MulStoreDigit(3) + MulShiftCarry + + MulAccumulate(3,1) + MulAccumulate(2,2) + MulAccumulate(1,3) + MulStoreDigit(4) + MulShiftCarry + + MulAccumulate(3,2) + MulAccumulate(2,3) + MulStoreDigit(5) + MulShiftCarry + + MulLastDiagonal(4) + + MulCleanup + + : + : "D" (Z), "S" (X), "a" (Y) + : "%ecx", "%edx", "memory" + ); +} + +void PentiumOptimized::Multiply8(word* Z, const word* X, const word* Y) +{ + __asm__ __volatile__( + MulStartup + MulAccumulate(0,0) + MulStoreDigit(0) + MulShiftCarry + + MulAccumulate(1,0) + MulAccumulate(0,1) + MulStoreDigit(1) + MulShiftCarry + + MulAccumulate(2,0) + MulAccumulate(1,1) + MulAccumulate(0,2) + MulStoreDigit(2) + MulShiftCarry + + MulAccumulate(3,0) + MulAccumulate(2,1) + MulAccumulate(1,2) + MulAccumulate(0,3) + MulStoreDigit(3) + MulShiftCarry + + MulAccumulate(4,0) + MulAccumulate(3,1) + MulAccumulate(2,2) + MulAccumulate(1,3) + MulAccumulate(0,4) + MulStoreDigit(4) + MulShiftCarry + + MulAccumulate(5,0) + MulAccumulate(4,1) + MulAccumulate(3,2) + MulAccumulate(2,3) + MulAccumulate(1,4) + MulAccumulate(0,5) + MulStoreDigit(5) + MulShiftCarry + + MulAccumulate(6,0) + MulAccumulate(5,1) + MulAccumulate(4,2) + MulAccumulate(3,3) + MulAccumulate(2,4) + MulAccumulate(1,5) + MulAccumulate(0,6) + MulStoreDigit(6) + MulShiftCarry + + MulAccumulate(7,0) + MulAccumulate(6,1) + MulAccumulate(5,2) + MulAccumulate(4,3) + MulAccumulate(3,4) + MulAccumulate(2,5) + MulAccumulate(1,6) + MulAccumulate(0,7) + MulStoreDigit(7) + MulShiftCarry + + MulAccumulate(7,1) + MulAccumulate(6,2) + MulAccumulate(5,3) + MulAccumulate(4,4) + MulAccumulate(3,5) + MulAccumulate(2,6) + MulAccumulate(1,7) + MulStoreDigit(8) + MulShiftCarry + + MulAccumulate(7,2) + MulAccumulate(6,3) + MulAccumulate(5,4) + MulAccumulate(4,5) + 
MulAccumulate(3,6) + MulAccumulate(2,7) + MulStoreDigit(9) + MulShiftCarry + + MulAccumulate(7,3) + MulAccumulate(6,4) + MulAccumulate(5,5) + MulAccumulate(4,6) + MulAccumulate(3,7) + MulStoreDigit(10) + MulShiftCarry + + MulAccumulate(7,4) + MulAccumulate(6,5) + MulAccumulate(5,6) + MulAccumulate(4,7) + MulStoreDigit(11) + MulShiftCarry + + MulAccumulate(7,5) + MulAccumulate(6,6) + MulAccumulate(5,7) + MulStoreDigit(12) + MulShiftCarry + + MulAccumulate(7,6) + MulAccumulate(6,7) + MulStoreDigit(13) + MulShiftCarry + + MulLastDiagonal(8) + + MulCleanup + + : + : "D" (Z), "S" (X), "a" (Y) + : "%ecx", "%edx", "memory" + ); +} + +#elif defined(__GNUC__) && defined(__alpha__) + +class AlphaOptimized : public Portable +{ +public: + static inline void Multiply2(word *C, const word *A, const word *B); + static inline word Multiply2Add(word *C, const word *A, const word *B); + static inline void Multiply4(word *C, const word *A, const word *B); + static inline unsigned int MultiplyRecursionLimit() {return 4;} + + static inline void Multiply4Bottom(word *C, const word *A, const word *B); + static inline unsigned int MultiplyBottomRecursionLimit() {return 4;} + + static inline void Square4(word *R, const word *A) + { + Multiply4(R, A, A); + } +}; + +typedef AlphaOptimized LowLevel; + +inline void AlphaOptimized::Multiply2(word *C, const word *A, const word *B) +{ + register dword c, a = *(const dword *)A, b = *(const dword *)B; + ((dword *)C)[0] = a*b; + __asm__("umulh %1,%2,%0" : "=r" (c) : "r" (a), "r" (b)); + ((dword *)C)[1] = c; +} + +inline word AlphaOptimized::Multiply2Add(word *C, const word *A, const word *B) +{ + register dword c, d, e, a = *(const dword *)A, b = *(const dword *)B; + c = ((dword *)C)[0]; + d = a*b + c; + __asm__("umulh %1,%2,%0" : "=r" (e) : "r" (a), "r" (b)); + ((dword *)C)[0] = d; + d = (d < c); + c = ((dword *)C)[1] + d; + d = (c < d); + c += e; + ((dword *)C)[1] = c; + d |= (c < e); + return d; +} + +inline void AlphaOptimized::Multiply4(word 
*R, const word *A, const word *B) +{ + Multiply2(R, A, B); + Multiply2(R+4, A+2, B+2); + word carry = Multiply2Add(R+2, A+0, B+2); + carry += Multiply2Add(R+2, A+2, B+0); + Increment(R+6, 2, carry); +} + +static inline void Multiply2BottomAdd(word *C, const word *A, const word *B) +{ + register dword a = *(const dword *)A, b = *(const dword *)B; + ((dword *)C)[0] = a*b + ((dword *)C)[0]; +} + +inline void AlphaOptimized::Multiply4Bottom(word *R, const word *A, const word *B) +{ + Multiply2(R, A, B); + Multiply2BottomAdd(R+2, A+0, B+2); + Multiply2BottomAdd(R+2, A+2, B+0); +} + +#else // no processor specific code available + +typedef Portable LowLevel; + +#endif + +// ******************************************************** + +#define A0 A +#define A1 (A+N2) +#define B0 B +#define B1 (B+N2) + +#define T0 T +#define T1 (T+N2) +#define T2 (T+N) +#define T3 (T+N+N2) + +#define R0 R +#define R1 (R+N2) +#define R2 (R+N) +#define R3 (R+N+N2) + +//VC60 workaround: compiler bug triggered without the extra dummy parameters + +// R[2*N] - result = A*B +// T[2*N] - temporary work space +// A[N] --- multiplier +// B[N] --- multiplicant + +template <class P> +void DoRecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL); + +template <class P> +inline void RecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL) +{ + assert(N>=2 && N%2==0); + + if (P::MultiplyRecursionLimit() >= 8 && N==8) + P::Multiply8(R, A, B); + else if (P::MultiplyRecursionLimit() >= 4 && N==4) + P::Multiply4(R, A, B); + else if (N==2) + P::Multiply2(R, A, B); + else + DoRecursiveMultiply<P>(R, T, A, B, N, NULL); // VC60 workaround: needs this NULL +} + +template <class P> +void DoRecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy) +{ + const unsigned int N2 = N/2; + int carry; + + int aComp = Compare(A0, A1, N2); + int bComp = Compare(B0, B1, N2); + + switch (2*aComp + 
aComp + bComp) + { + case -4: + P::Subtract(R0, A1, A0, N2); + P::Subtract(R1, B0, B1, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + P::Subtract(T1, T1, R0, N2); + carry = -1; + break; + case -2: + P::Subtract(R0, A1, A0, N2); + P::Subtract(R1, B0, B1, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + carry = 0; + break; + case 2: + P::Subtract(R0, A0, A1, N2); + P::Subtract(R1, B1, B0, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + carry = 0; + break; + case 4: + P::Subtract(R0, A1, A0, N2); + P::Subtract(R1, B0, B1, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + P::Subtract(T1, T1, R1, N2); + carry = -1; + break; + default: + SetWords(T0, 0, N); + carry = 0; + } + + RecursiveMultiply<P>(R0, T2, A0, B0, N2); + RecursiveMultiply<P>(R2, T2, A1, B1, N2); + + // now T[01] holds (A1-A0)*(B0-B1), R[01] holds A0*B0, R[23] holds A1*B1 + + carry += P::Add(T0, T0, R0, N); + carry += P::Add(T0, T0, R2, N); + carry += P::Add(R1, R1, T0, N); + + assert (carry >= 0 && carry <= 2); + Increment(R3, N2, carry); +} + +// R[2*N] - result = A*A +// T[2*N] - temporary work space +// A[N] --- number to be squared + +template <class P> +void DoRecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy=NULL); + +template <class P> +inline void RecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy=NULL) +{ + assert(N && N%2==0); + if (P::SquareRecursionLimit() >= 8 && N==8) + P::Square8(R, A); + // the chain must be exclusive: without this else, N==8 fell through to DoRecursiveSquare after Square8 had already produced R + else if (P::SquareRecursionLimit() >= 4 && N==4) + P::Square4(R, A); + else if (N==2) + P::Square2(R, A); + else + DoRecursiveSquare<P>(R, T, A, N, NULL); // VC60 workaround: needs this NULL +} + +template <class P> +void DoRecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy) +{ + const unsigned int N2 = N/2; + + RecursiveSquare<P>(R0, T2, A0, N2); + RecursiveSquare<P>(R2, T2, A1, N2); + RecursiveMultiply<P>(T0, T2, A0, A1, N2); + + word carry = P::Add(R1, R1, T0, N); + carry += P::Add(R1, R1, T0, N); + 
Increment(R3, N2, carry); +} + +// R[N] - bottom half of A*B +// T[N] - temporary work space +// A[N] - multiplier +// B[N] - multiplicant + +template <class P> +void DoRecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL); + +template <class P> +inline void RecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL) +{ + assert(N>=2 && N%2==0); + if (P::MultiplyBottomRecursionLimit() >= 8 && N==8) + P::Multiply8Bottom(R, A, B); + else if (P::MultiplyBottomRecursionLimit() >= 4 && N==4) + P::Multiply4Bottom(R, A, B); + else if (N==2) + P::Multiply2Bottom(R, A, B); + else + DoRecursiveMultiplyBottom<P>(R, T, A, B, N, NULL); +} + +template <class P> +void DoRecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy) +{ + const unsigned int N2 = N/2; + + RecursiveMultiply<P>(R, T, A0, B0, N2); + RecursiveMultiplyBottom<P>(T0, T1, A1, B0, N2); + P::Add(R1, R1, T0, N2); + RecursiveMultiplyBottom<P>(T0, T1, A0, B1, N2); + P::Add(R1, R1, T0, N2); +} + +// R[N] --- upper half of A*B +// T[2*N] - temporary work space +// L[N] --- lower half of A*B +// A[N] --- multiplier +// B[N] --- multiplicant + +template <class P> +void RecursiveMultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N, const P *dummy=NULL) +{ + assert(N>=2 && N%2==0); + + if (N==4) + { + P::Multiply4(T, A, B); + ((dword *)R)[0] = ((dword *)T)[2]; + ((dword *)R)[1] = ((dword *)T)[3]; + } + else if (N==2) + { + P::Multiply2(T, A, B); + ((dword *)R)[0] = ((dword *)T)[1]; + } + else + { + const unsigned int N2 = N/2; + int carry; + + int aComp = Compare(A0, A1, N2); + int bComp = Compare(B0, B1, N2); + + switch (2*aComp + aComp + bComp) + { + case -4: + P::Subtract(R0, A1, A0, N2); + P::Subtract(R1, B0, B1, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + P::Subtract(T1, T1, R0, N2); + carry = -1; + break; + case -2: + 
P::Subtract(R0, A1, A0, N2); + P::Subtract(R1, B0, B1, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + carry = 0; + break; + case 2: + P::Subtract(R0, A0, A1, N2); + P::Subtract(R1, B1, B0, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + carry = 0; + break; + case 4: + P::Subtract(R0, A1, A0, N2); + P::Subtract(R1, B0, B1, N2); + RecursiveMultiply<P>(T0, T2, R0, R1, N2); + P::Subtract(T1, T1, R1, N2); + carry = -1; + break; + default: + SetWords(T0, 0, N); + carry = 0; + } + + RecursiveMultiply<P>(T2, R0, A1, B1, N2); + + // now T[01] holds (A1-A0)*(B0-B1), T[23] holds A1*B1 + + word c2 = P::Subtract(R0, L+N2, L, N2); + c2 += P::Subtract(R0, R0, T0, N2); + word t = (Compare(R0, T2, N2) == -1); + + carry += t; + carry += Increment(R0, N2, c2+t); + carry += P::Add(R0, R0, T1, N2); + carry += P::Add(R0, R0, T3, N2); + assert (carry >= 0 && carry <= 2); + + CopyWords(R1, T3, N2); + Increment(R1, N2, carry); + } +} + +inline word Add(word *C, const word *A, const word *B, unsigned int N) +{ + return LowLevel::Add(C, A, B, N); +} + +inline word Subtract(word *C, const word *A, const word *B, unsigned int N) +{ + return LowLevel::Subtract(C, A, B, N); +} + +inline void Multiply(word *R, word *T, const word *A, const word *B, unsigned int N) +{ +#ifdef SSE2_INTRINSICS_AVAILABLE + if (HasSSE2()) + RecursiveMultiply<P4Optimized>(R, T, A, B, N); + else +#endif + RecursiveMultiply<LowLevel>(R, T, A, B, N); +} + +inline void Square(word *R, word *T, const word *A, unsigned int N) +{ +#ifdef SSE2_INTRINSICS_AVAILABLE + if (HasSSE2()) + RecursiveSquare<P4Optimized>(R, T, A, N); + else +#endif + RecursiveSquare<LowLevel>(R, T, A, N); +} + +inline void MultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N) +{ +#ifdef SSE2_INTRINSICS_AVAILABLE + if (HasSSE2()) + RecursiveMultiplyBottom<P4Optimized>(R, T, A, B, N); + else +#endif + RecursiveMultiplyBottom<LowLevel>(R, T, A, B, N); +} + +inline void MultiplyTop(word *R, word *T, const word *L, const word 
*A, const word *B, unsigned int N) +{ +#ifdef SSE2_INTRINSICS_AVAILABLE + if (HasSSE2()) + RecursiveMultiplyTop<P4Optimized>(R, T, L, A, B, N); + else +#endif + RecursiveMultiplyTop<LowLevel>(R, T, L, A, B, N); +} + +// R[NA+NB] - result = A*B +// T[NA+NB] - temporary work space +// A[NA] ---- multiplier +// B[NB] ---- multiplicant + +void AsymmetricMultiply(word *R, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB) +{ + if (NA == NB) + { + if (A == B) + Square(R, T, A, NA); + else + Multiply(R, T, A, B, NA); + + return; + } + + if (NA > NB) + { + std::swap(A, B); + std::swap(NA, NB); + } + + assert(NB % NA == 0); + assert((NB/NA)%2 == 0); // NB is an even multiple of NA + + if (NA==2 && !A[1]) + { + switch (A[0]) + { + case 0: + SetWords(R, 0, NB+2); + return; + case 1: + CopyWords(R, B, NB); + R[NB] = R[NB+1] = 0; + return; + default: + R[NB] = LinearMultiply(R, B, A[0], NB); + R[NB+1] = 0; + return; + } + } + + Multiply(R, T, A, B, NA); + CopyWords(T+2*NA, R+NA, NA); + + unsigned i; + + for (i=2*NA; i<NB; i+=2*NA) + Multiply(T+NA+i, T, A, B+i, NA); + for (i=NA; i<NB; i+=2*NA) + Multiply(R+i, T, A, B+i, NA); + + if (Add(R+NA, R+NA, T+2*NA, NB-NA)) + Increment(R+NB, NA); +} + +// R[N] ----- result = A inverse mod 2**(WORD_BITS*N) +// T[3*N/2] - temporary work space +// A[N] ----- an odd number as input + +void RecursiveInverseModPower2(word *R, word *T, const word *A, unsigned int N) +{ + if (N==2) + AtomicInverseModPower2(R, A[0], A[1]); + else + { + const unsigned int N2 = N/2; + RecursiveInverseModPower2(R0, T0, A0, N2); + T0[0] = 1; + SetWords(T0+1, 0, N2-1); + MultiplyTop(R1, T1, T0, R0, A0, N2); + MultiplyBottom(T0, T1, R0, A1, N2); + Add(T0, R1, T0, N2); + TwosComplement(T0, N2); + MultiplyBottom(R1, T1, R0, T0, N2); + } +} + +// R[N] --- result = X/(2**(WORD_BITS*N)) mod M +// T[3*N] - temporary work space +// X[2*N] - number to be reduced +// M[N] --- modulus +// U[N] --- multiplicative inverse of M mod 2**(WORD_BITS*N) + +void 
MontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, unsigned int N) +{ + MultiplyBottom(R, T, X, U, N); + MultiplyTop(T, T+N, X, R, M, N); + word borrow = Subtract(T, X+N, T, N); + // defend against timing attack by doing this Add even when not needed + word carry = Add(T+N, T, M, N); + assert(carry || !borrow); + CopyWords(R, T + (borrow ? N : 0), N); +} + +// R[N] --- result = X/(2**(WORD_BITS*N/2)) mod M +// T[2*N] - temporary work space +// X[2*N] - number to be reduced +// M[N] --- modulus +// U[N/2] - multiplicative inverse of M mod 2**(WORD_BITS*N/2) +// V[N] --- 2**(WORD_BITS*3*N/2) mod M + +void HalfMontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, const word *V, unsigned int N) +{ + assert(N%2==0 && N>=4); + +#define M0 M +#define M1 (M+N2) +#define V0 V +#define V1 (V+N2) + +#define X0 X +#define X1 (X+N2) +#define X2 (X+N) +#define X3 (X+N+N2) + + const unsigned int N2 = N/2; + Multiply(T0, T2, V0, X3, N2); + int c2 = Add(T0, T0, X0, N); + MultiplyBottom(T3, T2, T0, U, N2); + MultiplyTop(T2, R, T0, T3, M0, N2); + c2 -= Subtract(T2, T1, T2, N2); + Multiply(T0, R, T3, M1, N2); + c2 -= Subtract(T0, T2, T0, N2); + int c3 = -(int)Subtract(T1, X2, T1, N2); + Multiply(R0, T2, V1, X3, N2); + c3 += Add(R, R, T, N); + + if (c2>0) + c3 += Increment(R1, N2); + else if (c2<0) + c3 -= Decrement(R1, N2, -c2); + + assert(c3>=-1 && c3<=1); + if (c3>0) + Subtract(R, R, M, N); + else if (c3<0) + Add(R, R, M, N); + +#undef M0 +#undef M1 +#undef V0 +#undef V1 + +#undef X0 +#undef X1 +#undef X2 +#undef X3 +} + +#undef A0 +#undef A1 +#undef B0 +#undef B1 + +#undef T0 +#undef T1 +#undef T2 +#undef T3 + +#undef R0 +#undef R1 +#undef R2 +#undef R3 + +// do a 3 word by 2 word divide, returns quotient and leaves remainder in A +static word SubatomicDivide(word *A, word B0, word B1) +{ + // assert {A[2],A[1]} < {B1,B0}, so quotient can fit in a word + assert(A[2] < B1 || (A[2]==B1 && A[1] < B0)); + + dword p, u; + word Q; + 
+ // estimate the quotient: do a 2 word by 1 word divide + if (B1+1 == 0) + Q = A[2]; + else + Q = word(MAKE_DWORD(A[1], A[2]) / (B1+1)); + + // now subtract Q*B from A + p = (dword) B0*Q; + u = (dword) A[0] - LOW_WORD(p); + A[0] = LOW_WORD(u); + u = (dword) A[1] - HIGH_WORD(p) - (word)(0-HIGH_WORD(u)) - (dword)B1*Q; + A[1] = LOW_WORD(u); + A[2] += HIGH_WORD(u); + + // Q <= actual quotient, so fix it + while (A[2] || A[1] > B1 || (A[1]==B1 && A[0]>=B0)) + { + u = (dword) A[0] - B0; + A[0] = LOW_WORD(u); + u = (dword) A[1] - B1 - (word)(0-HIGH_WORD(u)); + A[1] = LOW_WORD(u); + A[2] += HIGH_WORD(u); + Q++; + assert(Q); // shouldn't overflow + } + + return Q; +} + +// do a 4 word by 2 word divide, returns 2 word quotient in Q0 and Q1 +static inline void AtomicDivide(word *Q, const word *A, const word *B) +{ + if (!B[0] && !B[1]) // if divisor is 0, we assume divisor==2**(2*WORD_BITS) + { + Q[0] = A[2]; + Q[1] = A[3]; + } + else + { + word T[4]; + T[0] = A[0]; T[1] = A[1]; T[2] = A[2]; T[3] = A[3]; + Q[1] = SubatomicDivide(T+1, B[0], B[1]); + Q[0] = SubatomicDivide(T, B[0], B[1]); + +#ifndef NDEBUG + // multiply quotient and divisor and add remainder, make sure it equals dividend + assert(!T[2] && !T[3] && (T[1] < B[1] || (T[1]==B[1] && T[0]<B[0]))); + word P[4]; + LowLevel::Multiply2(P, Q, B); + Add(P, P, T, 4); + assert(memcmp(P, A, 4*WORD_SIZE)==0); +#endif + } +} + +// for use by Divide(), corrects the underestimated quotient {Q1,Q0} +static void CorrectQuotientEstimate(word *R, word *T, word *Q, const word *B, unsigned int N) +{ + assert(N && N%2==0); + + if (Q[1]) + { + T[N] = T[N+1] = 0; + unsigned i; + for (i=0; i<N; i+=4) + LowLevel::Multiply2(T+i, Q, B+i); + for (i=2; i<N; i+=4) + if (LowLevel::Multiply2Add(T+i, Q, B+i)) + T[i+5] += (++T[i+4]==0); + } + else + { + T[N] = LinearMultiply(T, B, Q[0], N); + T[N+1] = 0; + } + + word borrow = Subtract(R, R, T, N+2); + assert(!borrow && !R[N+1]); + + while (R[N] || Compare(R, B, N) >= 0) + { + R[N] -= Subtract(R, R, 
B, N); + Q[1] += (++Q[0]==0); + assert(Q[0] || Q[1]); // no overflow + } +} + +// R[NB] -------- remainder = A%B +// Q[NA-NB+2] --- quotient = A/B +// T[NA+2*NB+4] - temp work space +// A[NA] -------- dividend +// B[NB] -------- divisor + +void Divide(word *R, word *Q, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB) +{ + assert(NA && NB && NA%2==0 && NB%2==0); + assert(B[NB-1] || B[NB-2]); + assert(NB <= NA); + + // set up temporary work space + word *const TA=T; + word *const TB=T+NA+2; + word *const TP=T+NA+2+NB; + + // copy B into TB and normalize it so that TB has highest bit set to 1 + unsigned shiftWords = (B[NB-1]==0); + TB[0] = TB[NB-1] = 0; + CopyWords(TB+shiftWords, B, NB-shiftWords); + unsigned shiftBits = WORD_BITS - BitPrecision(TB[NB-1]); + assert(shiftBits < WORD_BITS); + ShiftWordsLeftByBits(TB, NB, shiftBits); + + // copy A into TA and normalize it + TA[0] = TA[NA] = TA[NA+1] = 0; + CopyWords(TA+shiftWords, A, NA); + ShiftWordsLeftByBits(TA, NA+2, shiftBits); + + if (TA[NA+1]==0 && TA[NA] <= 1) + { + Q[NA-NB+1] = Q[NA-NB] = 0; + while (TA[NA] || Compare(TA+NA-NB, TB, NB) >= 0) + { + TA[NA] -= Subtract(TA+NA-NB, TA+NA-NB, TB, NB); + ++Q[NA-NB]; + } + } + else + { + NA+=2; + assert(Compare(TA+NA-NB, TB, NB) < 0); + } + + word BT[2]; + BT[0] = TB[NB-2] + 1; + BT[1] = TB[NB-1] + (BT[0]==0); + + // start reducing TA mod TB, 2 words at a time + for (unsigned i=NA-2; i>=NB; i-=2) + { + AtomicDivide(Q+i-NB, TA+i-2, BT); + CorrectQuotientEstimate(TA+i-NB, TP, Q+i-NB, TB, NB); + } + + // copy TA into R, and denormalize it + CopyWords(R, TA+shiftWords, NB); + ShiftWordsRightByBits(R, NB, shiftBits); +} + +static inline unsigned int EvenWordCount(const word *X, unsigned int N) +{ + while (N && X[N-2]==0 && X[N-1]==0) + N-=2; + return N; +} + +// return k +// R[N] --- result = A^(-1) * 2^k mod M +// T[4*N] - temporary work space +// A[NA] -- number to take inverse of +// M[N] --- modulus + +unsigned int AlmostInverse(word *R, word *T, 
const word *A, unsigned int NA, const word *M, unsigned int N) +{ + assert(NA<=N && N && N%2==0); + + word *b = T; + word *c = T+N; + word *f = T+2*N; + word *g = T+3*N; + unsigned int bcLen=2, fgLen=EvenWordCount(M, N); + unsigned int k=0, s=0; + + SetWords(T, 0, 3*N); + b[0]=1; + CopyWords(f, A, NA); + CopyWords(g, M, N); + + while (1) + { + word t=f[0]; + while (!t) + { + if (EvenWordCount(f, fgLen)==0) + { + SetWords(R, 0, N); + return 0; + } + + ShiftWordsRightByWords(f, fgLen, 1); + if (c[bcLen-1]) bcLen+=2; + assert(bcLen <= N); + ShiftWordsLeftByWords(c, bcLen, 1); + k+=WORD_BITS; + t=f[0]; + } + + unsigned int i=0; + while (t%2 == 0) + { + t>>=1; + i++; + } + k+=i; + + if (t==1 && f[1]==0 && EvenWordCount(f, fgLen)==2) + { + if (s%2==0) + CopyWords(R, b, N); + else + Subtract(R, M, b, N); + return k; + } + + ShiftWordsRightByBits(f, fgLen, i); + t=ShiftWordsLeftByBits(c, bcLen, i); + if (t) + { + c[bcLen] = t; + bcLen+=2; + assert(bcLen <= N); + } + + if (f[fgLen-2]==0 && g[fgLen-2]==0 && f[fgLen-1]==0 && g[fgLen-1]==0) + fgLen-=2; + + if (Compare(f, g, fgLen)==-1) + { + std::swap(f, g); + std::swap(b, c); + s++; + } + + Subtract(f, f, g, fgLen); + + if (Add(b, b, c, bcLen)) + { + b[bcLen] = 1; + bcLen+=2; + assert(bcLen <= N); + } + } +} + +// R[N] - result = A/(2^k) mod M +// A[N] - input +// M[N] - modulus + +void DivideByPower2Mod(word *R, const word *A, unsigned int k, const word *M, unsigned int N) +{ + CopyWords(R, A, N); + + while (k--) + { + if (R[0]%2==0) + ShiftWordsRightByBits(R, N, 1); + else + { + word carry = Add(R, R, M, N); + ShiftWordsRightByBits(R, N, 1); + R[N-1] += carry<<(WORD_BITS-1); + } + } +} + +// R[N] - result = A*(2^k) mod M +// A[N] - input +// M[N] - modulus + +void MultiplyByPower2Mod(word *R, const word *A, unsigned int k, const word *M, unsigned int N) +{ + CopyWords(R, A, N); + + while (k--) + if (ShiftWordsLeftByBits(R, N, 1) || Compare(R, M, N)>=0) + Subtract(R, R, M, N); +} + +// 
****************************************************************** + +static const unsigned int RoundupSizeTable[] = {2, 2, 2, 4, 4, 8, 8, 8, 8}; + +static inline unsigned int RoundupSize(unsigned int n) +{ + if (n<=8) + return RoundupSizeTable[n]; + else if (n<=16) + return 16; + else if (n<=32) + return 32; + else if (n<=64) + return 64; + else return 1U << BitPrecision(n-1); +} + +Integer::Integer() + : reg(2), sign(POSITIVE) +{ + reg[0] = reg[1] = 0; +} + +Integer::Integer(const Integer& t) + : reg(RoundupSize(t.WordCount())), sign(t.sign) +{ + CopyWords(reg, t.reg, reg.size()); +} + +Integer::Integer(signed long value) + : reg(2) +{ + if (value >= 0) + sign = POSITIVE; + else + { + sign = NEGATIVE; + value = -value; + } + reg[0] = word(value); + reg[1] = word(SafeRightShift<WORD_BITS, unsigned long>(value)); +} + +bool Integer::IsConvertableToLong() const +{ + if (ByteCount() > sizeof(long)) + return false; + + unsigned long value = reg[0]; + value += SafeLeftShift<WORD_BITS, unsigned long>(reg[1]); + + if (sign==POSITIVE) + return (signed long)value >= 0; + else + return -(signed long)value < 0; +} + +signed long Integer::ConvertToLong() const +{ + assert(IsConvertableToLong()); + + unsigned long value = reg[0]; + value += SafeLeftShift<WORD_BITS, unsigned long>(reg[1]); + return sign==POSITIVE ? 
value : -(signed long)value; +} + +Integer::Integer(BufferedTransformation &encodedInteger, unsigned int byteCount, Signedness s) +{ + Decode(encodedInteger, byteCount, s); +} + +Integer::Integer(const byte *encodedInteger, unsigned int byteCount, Signedness s) +{ + Decode(encodedInteger, byteCount, s); +} + +Integer::Integer(BufferedTransformation &bt) +{ + BERDecode(bt); +} + +Integer::Integer(RandomNumberGenerator &rng, unsigned int bitcount) +{ + Randomize(rng, bitcount); +} + +Integer::Integer(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod) +{ + if (!Randomize(rng, min, max, rnType, equiv, mod)) + throw Integer::RandomNumberNotFound(); +} + +Integer Integer::Power2(unsigned int e) +{ + Integer r((word)0, BitsToWords(e+1)); + r.SetBit(e); + return r; +} + +const Integer &Integer::Zero() +{ + static const Integer zero; + return zero; +} + +const Integer &Integer::One() +{ + static const Integer one(1,2); + return one; +} + +const Integer &Integer::Two() +{ + static const Integer two(2,2); + return two; +} + +bool Integer::operator!() const +{ + return IsNegative() ? 
false : (reg[0]==0 && WordCount()==0); +} + +Integer& Integer::operator=(const Integer& t) +{ + if (this != &t) + { + reg.New(RoundupSize(t.WordCount())); + CopyWords(reg, t.reg, reg.size()); + sign = t.sign; + } + return *this; +} + +bool Integer::GetBit(unsigned int n) const +{ + if (n/WORD_BITS >= reg.size()) + return 0; + else + return bool((reg[n/WORD_BITS] >> (n % WORD_BITS)) & 1); +} + +void Integer::SetBit(unsigned int n, bool value) +{ + if (value) + { + reg.CleanGrow(RoundupSize(BitsToWords(n+1))); + reg[n/WORD_BITS] |= (word(1) << (n%WORD_BITS)); + } + else + { + if (n/WORD_BITS < reg.size()) + reg[n/WORD_BITS] &= ~(word(1) << (n%WORD_BITS)); + } +} + +byte Integer::GetByte(unsigned int n) const +{ + if (n/WORD_SIZE >= reg.size()) + return 0; + else + return byte(reg[n/WORD_SIZE] >> ((n%WORD_SIZE)*8)); +} + +void Integer::SetByte(unsigned int n, byte value) +{ + reg.CleanGrow(RoundupSize(BytesToWords(n+1))); + reg[n/WORD_SIZE] &= ~(word(0xff) << 8*(n%WORD_SIZE)); + reg[n/WORD_SIZE] |= (word(value) << 8*(n%WORD_SIZE)); +} + +unsigned long Integer::GetBits(unsigned int i, unsigned int n) const +{ + assert(n <= sizeof(unsigned long)*8); + unsigned long v = 0; + for (unsigned int j=0; j<n; j++) + v |= GetBit(i+j) << j; + return v; +} + +Integer Integer::operator-() const +{ + Integer result(*this); + result.Negate(); + return result; +} + +Integer Integer::AbsoluteValue() const +{ + Integer result(*this); + result.sign = POSITIVE; + return result; +} + +void Integer::swap(Integer &a) +{ + reg.swap(a.reg); + std::swap(sign, a.sign); +} + +Integer::Integer(word value, unsigned int length) + : reg(RoundupSize(length)), sign(POSITIVE) +{ + reg[0] = value; + SetWords(reg+1, 0, reg.size()-1); +} + +template <class T> +static Integer StringToInteger(const T *str) +{ + word radix; +#if (defined(__GNUC__) && __GNUC__ <= 3) // GCC workaround + // std::char_traits doesn't exist in GCC 2.x + // std::char_traits<wchar_t>::length() not defined in GCC 3.2 + unsigned int 
length; + for (length = 0; str[length] != 0; length++) {} +#else + unsigned int length = std::char_traits<T>::length(str); +#endif + + Integer v; + + if (length == 0) + return v; + + switch (str[length-1]) + { + case 'h': + case 'H': + radix=16; + break; + case 'o': + case 'O': + radix=8; + break; + case 'b': + case 'B': + radix=2; + break; + default: + radix=10; + } + + if (length > 2 && str[0] == '0' && str[1] == 'x') + radix = 16; + + for (unsigned i=0; i<length; i++) + { + word digit; + + if (str[i] >= '0' && str[i] <= '9') + digit = str[i] - '0'; + else if (str[i] >= 'A' && str[i] <= 'F') + digit = str[i] - 'A' + 10; + else if (str[i] >= 'a' && str[i] <= 'f') + digit = str[i] - 'a' + 10; + else + digit = radix; + + if (digit < radix) + { + v *= radix; + v += digit; + } + } + + if (str[0] == '-') + v.Negate(); + + return v; +} + +Integer::Integer(const char *str) + : reg(2), sign(POSITIVE) +{ + *this = StringToInteger(str); +} + +Integer::Integer(const wchar_t *str) + : reg(2), sign(POSITIVE) +{ + *this = StringToInteger(str); +} + +unsigned int Integer::WordCount() const +{ + return CountWords(reg, reg.size()); +} + +unsigned int Integer::ByteCount() const +{ + unsigned wordCount = WordCount(); + if (wordCount) + return (wordCount-1)*WORD_SIZE + BytePrecision(reg[wordCount-1]); + else + return 0; +} + +unsigned int Integer::BitCount() const +{ + unsigned wordCount = WordCount(); + if (wordCount) + return (wordCount-1)*WORD_BITS + BitPrecision(reg[wordCount-1]); + else + return 0; +} + +void Integer::Decode(const byte *input, unsigned int inputLen, Signedness s) +{ + StringStore store(input, inputLen); + Decode(store, inputLen, s); +} + +void Integer::Decode(BufferedTransformation &bt, unsigned int inputLen, Signedness s) +{ + assert(bt.MaxRetrievable() >= inputLen); + + byte b; + bt.Peek(b); + sign = ((s==SIGNED) && (b & 0x80)) ? NEGATIVE : POSITIVE; + + while (inputLen>0 && (sign==POSITIVE ? 
b==0 : b==0xff)) + { + bt.Skip(1); + inputLen--; + bt.Peek(b); + } + + reg.CleanNew(RoundupSize(BytesToWords(inputLen))); + + for (unsigned int i=inputLen; i > 0; i--) + { + bt.Get(b); + // cast to word before shifting, as SetByte does, so the shift is done in unsigned arithmetic rather than in promoted int + reg[(i-1)/WORD_SIZE] |= word(b) << ((i-1)%WORD_SIZE)*8; + } + + if (sign == NEGATIVE) + { + for (unsigned i=inputLen; i<reg.size()*WORD_SIZE; i++) + reg[i/WORD_SIZE] |= word(0xff) << (i%WORD_SIZE)*8; + TwosComplement(reg, reg.size()); + } +} + +unsigned int Integer::MinEncodedSize(Signedness signedness) const +{ + unsigned int outputLen = STDMAX(1U, ByteCount()); + if (signedness == UNSIGNED) + return outputLen; + if (NotNegative() && (GetByte(outputLen-1) & 0x80)) + outputLen++; + if (IsNegative() && *this < -Power2(outputLen*8-1)) + outputLen++; + return outputLen; +} + +unsigned int Integer::Encode(byte *output, unsigned int outputLen, Signedness signedness) const +{ + ArraySink sink(output, outputLen); + return Encode(sink, outputLen, signedness); +} + +unsigned int Integer::Encode(BufferedTransformation &bt, unsigned int outputLen, Signedness signedness) const +{ + if (signedness == UNSIGNED || NotNegative()) + { + for (unsigned int i=outputLen; i > 0; i--) + bt.Put(GetByte(i-1)); + } + else + { + // take two's complement of *this + Integer temp = Integer::Power2(8*STDMAX(ByteCount(), outputLen)) + *this; + for (unsigned i=0; i<outputLen; i++) + bt.Put(temp.GetByte(outputLen-i-1)); + } + return outputLen; +} + +void Integer::DEREncode(BufferedTransformation &bt) const +{ + DERGeneralEncoder enc(bt, INTEGER); + Encode(enc, MinEncodedSize(SIGNED), SIGNED); + enc.MessageEnd(); +} + +void Integer::BERDecode(const byte *input, unsigned int len) +{ + StringStore store(input, len); + BERDecode(store); +} + +void Integer::BERDecode(BufferedTransformation &bt) +{ + BERGeneralDecoder dec(bt, INTEGER); + if (!dec.IsDefiniteLength() || dec.MaxRetrievable() < dec.RemainingLength()) + BERDecodeError(); + Decode(dec, dec.RemainingLength(), SIGNED); + dec.MessageEnd(); +} + +void 
Integer::DEREncodeAsOctetString(BufferedTransformation &bt, unsigned int length) const +{ + DERGeneralEncoder enc(bt, OCTET_STRING); + Encode(enc, length); + enc.MessageEnd(); +} + +void Integer::BERDecodeAsOctetString(BufferedTransformation &bt, unsigned int length) +{ + BERGeneralDecoder dec(bt, OCTET_STRING); + if (!dec.IsDefiniteLength() || dec.RemainingLength() != length) + BERDecodeError(); + Decode(dec, length); + dec.MessageEnd(); +} + +unsigned int Integer::OpenPGPEncode(byte *output, unsigned int len) const +{ + ArraySink sink(output, len); + return OpenPGPEncode(sink); +} + +unsigned int Integer::OpenPGPEncode(BufferedTransformation &bt) const +{ + word16 bitCount = BitCount(); + bt.PutWord16(bitCount); + return 2 + Encode(bt, BitsToBytes(bitCount)); +} + +void Integer::OpenPGPDecode(const byte *input, unsigned int len) +{ + StringStore store(input, len); + OpenPGPDecode(store); +} + +void Integer::OpenPGPDecode(BufferedTransformation &bt) +{ + word16 bitCount; + if (bt.GetWord16(bitCount) != 2 || bt.MaxRetrievable() < BitsToBytes(bitCount)) + throw OpenPGPDecodeErr(); + Decode(bt, BitsToBytes(bitCount)); +} + +void Integer::Randomize(RandomNumberGenerator &rng, unsigned int nbits) +{ + const unsigned int nbytes = nbits/8 + 1; + SecByteBlock buf(nbytes); + rng.GenerateBlock(buf, nbytes); + if (nbytes) + buf[0] = (byte)Crop(buf[0], nbits % 8); + Decode(buf, nbytes, UNSIGNED); +} + +void Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max) +{ + if (min > max) + throw InvalidArgument("Integer: Min must be no greater than Max"); + + Integer range = max - min; + const unsigned int nbits = range.BitCount(); + + do + { + Randomize(rng, nbits); + } + while (*this > range); + + *this += min; +} + +bool Integer::Randomize(RandomNumberGenerator &rng, const Integer &min, const Integer &max, RandomNumberType rnType, const Integer &equiv, const Integer &mod) +{ + return GenerateRandomNoThrow(rng, MakeParameters("Min", min)("Max", 
max)("RandomNumberType", rnType)("EquivalentTo", equiv)("Mod", mod)); +} + +class KDF2_RNG : public RandomNumberGenerator +{ +public: + KDF2_RNG(const byte *seed, unsigned int seedSize) + : m_counter(0), m_counterAndSeed(seedSize + 4) + { + memcpy(m_counterAndSeed + 4, seed, seedSize); + } + + byte GenerateByte() + { + byte b; + GenerateBlock(&b, 1); + return b; + } + + void GenerateBlock(byte *output, unsigned int size) + { + UnalignedPutWord(BIG_ENDIAN_ORDER, m_counterAndSeed, m_counter); + ++m_counter; + P1363_KDF2<SHA1>::DeriveKey(output, size, m_counterAndSeed, m_counterAndSeed.size()); + } + +private: + word32 m_counter; + SecByteBlock m_counterAndSeed; +}; + +bool Integer::GenerateRandomNoThrow(RandomNumberGenerator &i_rng, const NameValuePairs ¶ms) +{ + Integer min = params.GetValueWithDefault("Min", Integer::Zero()); + Integer max; + if (!params.GetValue("Max", max)) + { + int bitLength; + if (params.GetIntValue("BitLength", bitLength)) + max = Integer::Power2(bitLength); + else + throw InvalidArgument("Integer: missing Max argument"); + } + if (min > max) + throw InvalidArgument("Integer: Min must be no greater than Max"); + + Integer equiv = params.GetValueWithDefault("EquivalentTo", Integer::Zero()); + Integer mod = params.GetValueWithDefault("Mod", Integer::One()); + + if (equiv.IsNegative() || equiv >= mod) + throw InvalidArgument("Integer: invalid EquivalentTo and/or Mod argument"); + + Integer::RandomNumberType rnType = params.GetValueWithDefault("RandomNumberType", Integer::ANY); + + member_ptr<KDF2_RNG> kdf2Rng; + ConstByteArrayParameter seed; + if (params.GetValue("Seed", seed)) + { + ByteQueue bq; + DERSequenceEncoder seq(bq); + min.DEREncode(seq); + max.DEREncode(seq); + equiv.DEREncode(seq); + mod.DEREncode(seq); + DEREncodeUnsigned(seq, rnType); + DEREncodeOctetString(seq, seed.begin(), seed.size()); + seq.MessageEnd(); + + SecByteBlock finalSeed(bq.MaxRetrievable()); + bq.Get(finalSeed, finalSeed.size()); + kdf2Rng.reset(new 
KDF2_RNG(finalSeed.begin(), finalSeed.size())); + } + RandomNumberGenerator &rng = kdf2Rng.get() ? (RandomNumberGenerator &)*kdf2Rng : i_rng; + + switch (rnType) + { + case ANY: + if (mod == One()) + Randomize(rng, min, max); + else + { + Integer min1 = min + (equiv-min)%mod; + if (max < min1) + return false; + Randomize(rng, Zero(), (max - min1) / mod); + *this *= mod; + *this += min1; + } + return true; + + case PRIME: + { + const PrimeSelector *pSelector = params.GetValueWithDefault("PointerToPrimeSelector", (const PrimeSelector *)NULL); + + int i; + i = 0; + while (1) + { + if (++i==16) + { + // check if there are any suitable primes in [min, max] + Integer first = min; + if (FirstPrime(first, max, equiv, mod, pSelector)) + { + // if there is only one suitable prime, we're done + *this = first; + if (!FirstPrime(first, max, equiv, mod, pSelector)) + return true; + } + else + return false; + } + + Randomize(rng, min, max); + if (FirstPrime(*this, STDMIN(*this+mod*PrimeSearchInterval(max), max), equiv, mod, pSelector)) + return true; + } + } + + default: + throw InvalidArgument("Integer: invalid RandomNumberType argument"); + } +} + +std::istream& operator>>(std::istream& in, Integer &a) +{ + char c; + unsigned int length = 0; + SecBlock<char> str(length + 16); + + std::ws(in); + + do + { + in.read(&c, 1); + str[length++] = c; + if (length >= str.size()) + str.Grow(length + 16); + } + while (in && (c=='-' || c=='x' || (c>='0' && c<='9') || (c>='a' && c<='f') || (c>='A' && c<='F') || c=='h' || c=='H' || c=='o' || c=='O' || c==',' || c=='.')); + + if (in.gcount()) + in.putback(c); + str[length-1] = '\0'; + a = Integer(str); + + return in; +} + +std::ostream& operator<<(std::ostream& out, const Integer &a) +{ + // Get relevant conversion specifications from ostream. + long f = out.flags() & std::ios::basefield; // Get base digits. 
+ int base, block; + char suffix; + switch(f) + { + case std::ios::oct : + base = 8; + block = 8; + suffix = 'o'; + break; + case std::ios::hex : + base = 16; + block = 4; + suffix = 'h'; + break; + default : + base = 10; + block = 3; + suffix = '.'; + } + + SecBlock<char> s(a.BitCount() / (BitPrecision(base)-1) + 1); + Integer temp1=a, temp2; + unsigned i=0; + const char vec[]="0123456789ABCDEF"; + + if (a.IsNegative()) + { + out << '-'; + temp1.Negate(); + } + + if (!a) + out << '0'; + + while (!!temp1) + { + word digit; + Integer::Divide(digit, temp2, temp1, base); + s[i++]=vec[digit]; + temp1=temp2; + } + + while (i--) + { + out << s[i]; +// if (i && !(i%block)) +// out << ","; + } + return out << suffix; +} + +Integer& Integer::operator++() +{ + if (NotNegative()) + { + if (Increment(reg, reg.size())) + { + reg.CleanGrow(2*reg.size()); + reg[reg.size()/2]=1; + } + } + else + { + word borrow = Decrement(reg, reg.size()); + assert(!borrow); + if (WordCount()==0) + *this = Zero(); + } + return *this; +} + +Integer& Integer::operator--() +{ + if (IsNegative()) + { + if (Increment(reg, reg.size())) + { + reg.CleanGrow(2*reg.size()); + reg[reg.size()/2]=1; + } + } + else + { + if (Decrement(reg, reg.size())) + *this = -One(); + } + return *this; +} + +void PositiveAdd(Integer &sum, const Integer &a, const Integer& b) +{ + word carry; + if (a.reg.size() == b.reg.size()) + carry = Add(sum.reg, a.reg, b.reg, a.reg.size()); + else if (a.reg.size() > b.reg.size()) + { + carry = Add(sum.reg, a.reg, b.reg, b.reg.size()); + CopyWords(sum.reg+b.reg.size(), a.reg+b.reg.size(), a.reg.size()-b.reg.size()); + carry = Increment(sum.reg+b.reg.size(), a.reg.size()-b.reg.size(), carry); + } + else + { + carry = Add(sum.reg, a.reg, b.reg, a.reg.size()); + CopyWords(sum.reg+a.reg.size(), b.reg+a.reg.size(), b.reg.size()-a.reg.size()); + carry = Increment(sum.reg+a.reg.size(), b.reg.size()-a.reg.size(), carry); + } + + if (carry) + { + sum.reg.CleanGrow(2*sum.reg.size()); + 
sum.reg[sum.reg.size()/2] = 1; + } + sum.sign = Integer::POSITIVE; +} + +void PositiveSubtract(Integer &diff, const Integer &a, const Integer& b) +{ + unsigned aSize = a.WordCount(); + aSize += aSize%2; + unsigned bSize = b.WordCount(); + bSize += bSize%2; + + if (aSize == bSize) + { + if (Compare(a.reg, b.reg, aSize) >= 0) + { + Subtract(diff.reg, a.reg, b.reg, aSize); + diff.sign = Integer::POSITIVE; + } + else + { + Subtract(diff.reg, b.reg, a.reg, aSize); + diff.sign = Integer::NEGATIVE; + } + } + else if (aSize > bSize) + { + word borrow = Subtract(diff.reg, a.reg, b.reg, bSize); + CopyWords(diff.reg+bSize, a.reg+bSize, aSize-bSize); + borrow = Decrement(diff.reg+bSize, aSize-bSize, borrow); + assert(!borrow); + diff.sign = Integer::POSITIVE; + } + else + { + word borrow = Subtract(diff.reg, b.reg, a.reg, aSize); + CopyWords(diff.reg+aSize, b.reg+aSize, bSize-aSize); + borrow = Decrement(diff.reg+aSize, bSize-aSize, borrow); + assert(!borrow); + diff.sign = Integer::NEGATIVE; + } +} + +Integer Integer::Plus(const Integer& b) const +{ + Integer sum((word)0, STDMAX(reg.size(), b.reg.size())); + if (NotNegative()) + { + if (b.NotNegative()) + PositiveAdd(sum, *this, b); + else + PositiveSubtract(sum, *this, b); + } + else + { + if (b.NotNegative()) + PositiveSubtract(sum, b, *this); + else + { + PositiveAdd(sum, *this, b); + sum.sign = Integer::NEGATIVE; + } + } + return sum; +} + +Integer& Integer::operator+=(const Integer& t) +{ + reg.CleanGrow(t.reg.size()); + if (NotNegative()) + { + if (t.NotNegative()) + PositiveAdd(*this, *this, t); + else + PositiveSubtract(*this, *this, t); + } + else + { + if (t.NotNegative()) + PositiveSubtract(*this, t, *this); + else + { + PositiveAdd(*this, *this, t); + sign = Integer::NEGATIVE; + } + } + return *this; +} + +Integer Integer::Minus(const Integer& b) const +{ + Integer diff((word)0, STDMAX(reg.size(), b.reg.size())); + if (NotNegative()) + { + if (b.NotNegative()) + PositiveSubtract(diff, *this, b); + else + 
PositiveAdd(diff, *this, b); + } + else + { + if (b.NotNegative()) + { + PositiveAdd(diff, *this, b); + diff.sign = Integer::NEGATIVE; + } + else + PositiveSubtract(diff, b, *this); + } + return diff; +} + +Integer& Integer::operator-=(const Integer& t) +{ + reg.CleanGrow(t.reg.size()); + if (NotNegative()) + { + if (t.NotNegative()) + PositiveSubtract(*this, *this, t); + else + PositiveAdd(*this, *this, t); + } + else + { + if (t.NotNegative()) + { + PositiveAdd(*this, *this, t); + sign = Integer::NEGATIVE; + } + else + PositiveSubtract(*this, t, *this); + } + return *this; +} + +Integer& Integer::operator<<=(unsigned int n) +{ + const unsigned int wordCount = WordCount(); + const unsigned int shiftWords = n / WORD_BITS; + const unsigned int shiftBits = n % WORD_BITS; + + reg.CleanGrow(RoundupSize(wordCount+BitsToWords(n))); + ShiftWordsLeftByWords(reg, wordCount + shiftWords, shiftWords); + ShiftWordsLeftByBits(reg+shiftWords, wordCount+BitsToWords(shiftBits), shiftBits); + return *this; +} + +Integer& Integer::operator>>=(unsigned int n) +{ + const unsigned int wordCount = WordCount(); + const unsigned int shiftWords = n / WORD_BITS; + const unsigned int shiftBits = n % WORD_BITS; + + ShiftWordsRightByWords(reg, wordCount, shiftWords); + if (wordCount > shiftWords) + ShiftWordsRightByBits(reg, wordCount-shiftWords, shiftBits); + if (IsNegative() && WordCount()==0) // avoid -0 + *this = Zero(); + return *this; +} + +void PositiveMultiply(Integer &product, const Integer &a, const Integer &b) +{ + unsigned aSize = RoundupSize(a.WordCount()); + unsigned bSize = RoundupSize(b.WordCount()); + + product.reg.CleanNew(RoundupSize(aSize+bSize)); + product.sign = Integer::POSITIVE; + + SecAlignedWordBlock workspace(aSize + bSize); + AsymmetricMultiply(product.reg, workspace, a.reg, aSize, b.reg, bSize); +} + +void Multiply(Integer &product, const Integer &a, const Integer &b) +{ + PositiveMultiply(product, a, b); + + if (a.NotNegative() != b.NotNegative()) + 
product.Negate(); +} + +Integer Integer::Times(const Integer &b) const +{ + Integer product; + Multiply(product, *this, b); + return product; +} + +/* +void PositiveDivide(Integer &remainder, Integer "ient, + const Integer ÷nd, const Integer &divisor) +{ + remainder.reg.CleanNew(divisor.reg.size()); + remainder.sign = Integer::POSITIVE; + quotient.reg.New(0); + quotient.sign = Integer::POSITIVE; + unsigned i=dividend.BitCount(); + while (i--) + { + word overflow = ShiftWordsLeftByBits(remainder.reg, remainder.reg.size(), 1); + remainder.reg[0] |= dividend[i]; + if (overflow || remainder >= divisor) + { + Subtract(remainder.reg, remainder.reg, divisor.reg, remainder.reg.size()); + quotient.SetBit(i); + } + } +} +*/ + +void PositiveDivide(Integer &remainder, Integer "ient, + const Integer &a, const Integer &b) +{ + unsigned aSize = a.WordCount(); + unsigned bSize = b.WordCount(); + + if (!bSize) + throw Integer::DivideByZero(); + + if (a.PositiveCompare(b) == -1) + { + remainder = a; + remainder.sign = Integer::POSITIVE; + quotient = Integer::Zero(); + return; + } + + aSize += aSize%2; // round up to next even number + bSize += bSize%2; + + remainder.reg.CleanNew(RoundupSize(bSize)); + remainder.sign = Integer::POSITIVE; + quotient.reg.CleanNew(RoundupSize(aSize-bSize+2)); + quotient.sign = Integer::POSITIVE; + + SecAlignedWordBlock T(aSize+2*bSize+4); + Divide(remainder.reg, quotient.reg, T, a.reg, aSize, b.reg, bSize); +} + +void Integer::Divide(Integer &remainder, Integer "ient, const Integer ÷nd, const Integer &divisor) +{ + PositiveDivide(remainder, quotient, dividend, divisor); + + if (dividend.IsNegative()) + { + quotient.Negate(); + if (remainder.NotZero()) + { + --quotient; + remainder = divisor.AbsoluteValue() - remainder; + } + } + + if (divisor.IsNegative()) + quotient.Negate(); +} + +void Integer::DivideByPowerOf2(Integer &r, Integer &q, const Integer &a, unsigned int n) +{ + q = a; + q >>= n; + + const unsigned int wordCount = BitsToWords(n); + if 
(wordCount <= a.WordCount()) + { + r.reg.resize(RoundupSize(wordCount)); + CopyWords(r.reg, a.reg, wordCount); + SetWords(r.reg+wordCount, 0, r.reg.size()-wordCount); + if (n % WORD_BITS != 0) + r.reg[wordCount-1] %= (1 << (n % WORD_BITS)); + } + else + { + r.reg.resize(RoundupSize(a.WordCount())); + CopyWords(r.reg, a.reg, r.reg.size()); + } + r.sign = POSITIVE; + + if (a.IsNegative() && r.NotZero()) + { + --q; + r = Power2(n) - r; + } +} + +Integer Integer::DividedBy(const Integer &b) const +{ + Integer remainder, quotient; + Integer::Divide(remainder, quotient, *this, b); + return quotient; +} + +Integer Integer::Modulo(const Integer &b) const +{ + Integer remainder, quotient; + Integer::Divide(remainder, quotient, *this, b); + return remainder; +} + +void Integer::Divide(word &remainder, Integer "ient, const Integer ÷nd, word divisor) +{ + if (!divisor) + throw Integer::DivideByZero(); + + assert(divisor); + + if ((divisor & (divisor-1)) == 0) // divisor is a power of 2 + { + quotient = dividend >> (BitPrecision(divisor)-1); + remainder = dividend.reg[0] & (divisor-1); + return; + } + + unsigned int i = dividend.WordCount(); + quotient.reg.CleanNew(RoundupSize(i)); + remainder = 0; + while (i--) + { + quotient.reg[i] = word(MAKE_DWORD(dividend.reg[i], remainder) / divisor); + remainder = word(MAKE_DWORD(dividend.reg[i], remainder) % divisor); + } + + if (dividend.NotNegative()) + quotient.sign = POSITIVE; + else + { + quotient.sign = NEGATIVE; + if (remainder) + { + --quotient; + remainder = divisor - remainder; + } + } +} + +Integer Integer::DividedBy(word b) const +{ + word remainder; + Integer quotient; + Integer::Divide(remainder, quotient, *this, b); + return quotient; +} + +word Integer::Modulo(word divisor) const +{ + if (!divisor) + throw Integer::DivideByZero(); + + assert(divisor); + + word remainder; + + if ((divisor & (divisor-1)) == 0) // divisor is a power of 2 + remainder = reg[0] & (divisor-1); + else + { + unsigned int i = WordCount(); + + if 
(divisor <= 5) + { + dword sum=0; + while (i--) + sum += reg[i]; + remainder = word(sum%divisor); + } + else + { + remainder = 0; + while (i--) + remainder = word(MAKE_DWORD(reg[i], remainder) % divisor); + } + } + + if (IsNegative() && remainder) + remainder = divisor - remainder; + + return remainder; +} + +void Integer::Negate() +{ + if (!!(*this)) // don't flip sign if *this==0 + sign = Sign(1-sign); +} + +int Integer::PositiveCompare(const Integer& t) const +{ + unsigned size = WordCount(), tSize = t.WordCount(); + + if (size == tSize) + return CryptoPP::Compare(reg, t.reg, size); + else + return size > tSize ? 1 : -1; +} + +int Integer::Compare(const Integer& t) const +{ + if (NotNegative()) + { + if (t.NotNegative()) + return PositiveCompare(t); + else + return 1; + } + else + { + if (t.NotNegative()) + return -1; + else + return -PositiveCompare(t); + } +} + +Integer Integer::SquareRoot() const +{ + if (!IsPositive()) + return Zero(); + + // overestimate square root + Integer x, y = Power2((BitCount()+1)/2); + assert(y*y >= *this); + + do + { + x = y; + y = (x + *this/x) >> 1; + } while (y<x); + + return x; +} + +bool Integer::IsSquare() const +{ + Integer r = SquareRoot(); + return *this == r.Squared(); +} + +bool Integer::IsUnit() const +{ + return (WordCount() == 1) && (reg[0] == 1); +} + +Integer Integer::MultiplicativeInverse() const +{ + return IsUnit() ? 
*this : Zero(); +} + +Integer a_times_b_mod_c(const Integer &x, const Integer& y, const Integer& m) +{ + return x*y%m; +} + +Integer a_exp_b_mod_c(const Integer &x, const Integer& e, const Integer& m) +{ + ModularArithmetic mr(m); + return mr.Exponentiate(x, e); +} + +Integer Integer::Gcd(const Integer &a, const Integer &b) +{ + return EuclideanDomainOf<Integer>().Gcd(a, b); +} + +Integer Integer::InverseMod(const Integer &m) const +{ + assert(m.NotNegative()); + + if (IsNegative() || *this>=m) + return (*this%m).InverseMod(m); + + if (m.IsEven()) + { + if (!m || IsEven()) + return Zero(); // no inverse + if (*this == One()) + return One(); + + Integer u = m.InverseMod(*this); + return !u ? Zero() : (m*(*this-u)+1)/(*this); + } + + SecBlock<word> T(m.reg.size() * 4); + Integer r((word)0, m.reg.size()); + unsigned k = AlmostInverse(r.reg, T, reg, reg.size(), m.reg, m.reg.size()); + DivideByPower2Mod(r.reg, r.reg, k, m.reg, m.reg.size()); + return r; +} + +word Integer::InverseMod(const word mod) const +{ + word g0 = mod, g1 = *this % mod; + word v0 = 0, v1 = 1; + word y; + + while (g1) + { + if (g1 == 1) + return v1; + y = g0 / g1; + g0 = g0 % g1; + v0 += y * v1; + + if (!g0) + break; + if (g0 == 1) + return mod-v0; + y = g1 / g0; + g1 = g1 % g0; + v1 += y * v0; + } + return 0; +} + +// ******************************************************** + +ModularArithmetic::ModularArithmetic(BufferedTransformation &bt) +{ + BERSequenceDecoder seq(bt); + OID oid(seq); + if (oid != ASN1::prime_field()) + BERDecodeError(); + modulus.BERDecode(seq); + seq.MessageEnd(); + result.reg.resize(modulus.reg.size()); +} + +void ModularArithmetic::DEREncode(BufferedTransformation &bt) const +{ + DERSequenceEncoder seq(bt); + ASN1::prime_field().DEREncode(seq); + modulus.DEREncode(seq); + seq.MessageEnd(); +} + +void ModularArithmetic::DEREncodeElement(BufferedTransformation &out, const Element &a) const +{ + a.DEREncodeAsOctetString(out, MaxElementByteLength()); +} + +void 
ModularArithmetic::BERDecodeElement(BufferedTransformation &in, Element &a) const +{ + a.BERDecodeAsOctetString(in, MaxElementByteLength()); +} + +const Integer& ModularArithmetic::Half(const Integer &a) const +{ + if (a.reg.size()==modulus.reg.size()) + { + CryptoPP::DivideByPower2Mod(result.reg.begin(), a.reg, 1, modulus.reg, a.reg.size()); + return result; + } + else + return result1 = (a.IsEven() ? (a >> 1) : ((a+modulus) >> 1)); +} + +const Integer& ModularArithmetic::Add(const Integer &a, const Integer &b) const +{ + if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) + { + if (CryptoPP::Add(result.reg.begin(), a.reg, b.reg, a.reg.size()) + || Compare(result.reg, modulus.reg, a.reg.size()) >= 0) + { + CryptoPP::Subtract(result.reg.begin(), result.reg, modulus.reg, a.reg.size()); + } + return result; + } + else + { + result1 = a+b; + if (result1 >= modulus) + result1 -= modulus; + return result1; + } +} + +Integer& ModularArithmetic::Accumulate(Integer &a, const Integer &b) const +{ + if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) + { + if (CryptoPP::Add(a.reg, a.reg, b.reg, a.reg.size()) + || Compare(a.reg, modulus.reg, a.reg.size()) >= 0) + { + CryptoPP::Subtract(a.reg, a.reg, modulus.reg, a.reg.size()); + } + } + else + { + a+=b; + if (a>=modulus) + a-=modulus; + } + + return a; +} + +const Integer& ModularArithmetic::Subtract(const Integer &a, const Integer &b) const +{ + if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) + { + if (CryptoPP::Subtract(result.reg.begin(), a.reg, b.reg, a.reg.size())) + CryptoPP::Add(result.reg.begin(), result.reg, modulus.reg, a.reg.size()); + return result; + } + else + { + result1 = a-b; + if (result1.IsNegative()) + result1 += modulus; + return result1; + } +} + +Integer& ModularArithmetic::Reduce(Integer &a, const Integer &b) const +{ + if (a.reg.size()==modulus.reg.size() && b.reg.size()==modulus.reg.size()) + { + if (CryptoPP::Subtract(a.reg, 
a.reg, b.reg, a.reg.size())) + CryptoPP::Add(a.reg, a.reg, modulus.reg, a.reg.size()); + } + else + { + a-=b; + if (a.IsNegative()) + a+=modulus; + } + + return a; +} + +const Integer& ModularArithmetic::Inverse(const Integer &a) const +{ + if (!a) + return a; + + CopyWords(result.reg.begin(), modulus.reg, modulus.reg.size()); + if (CryptoPP::Subtract(result.reg.begin(), result.reg, a.reg, a.reg.size())) + Decrement(result.reg.begin()+a.reg.size(), 1, modulus.reg.size()-a.reg.size()); + + return result; +} + +Integer ModularArithmetic::CascadeExponentiate(const Integer &x, const Integer &e1, const Integer &y, const Integer &e2) const +{ + if (modulus.IsOdd()) + { + MontgomeryRepresentation dr(modulus); + return dr.ConvertOut(dr.CascadeExponentiate(dr.ConvertIn(x), e1, dr.ConvertIn(y), e2)); + } + else + return AbstractRing<Integer>::CascadeExponentiate(x, e1, y, e2); +} + +void ModularArithmetic::SimultaneousExponentiate(Integer *results, const Integer &base, const Integer *exponents, unsigned int exponentsCount) const +{ + if (modulus.IsOdd()) + { + MontgomeryRepresentation dr(modulus); + dr.SimultaneousExponentiate(results, dr.ConvertIn(base), exponents, exponentsCount); + for (unsigned int i=0; i<exponentsCount; i++) + results[i] = dr.ConvertOut(results[i]); + } + else + AbstractRing<Integer>::SimultaneousExponentiate(results, base, exponents, exponentsCount); +} + +MontgomeryRepresentation::MontgomeryRepresentation(const Integer &m) // modulus must be odd + : ModularArithmetic(m), + u((word)0, modulus.reg.size()), + workspace(5*modulus.reg.size()) +{ + if (!modulus.IsOdd()) + throw InvalidArgument("MontgomeryRepresentation: Montgomery representation requires an odd modulus"); + + RecursiveInverseModPower2(u.reg, workspace, modulus.reg, modulus.reg.size()); +} + +const Integer& MontgomeryRepresentation::Multiply(const Integer &a, const Integer &b) const +{ + word *const T = workspace.begin(); + word *const R = result.reg.begin(); + const unsigned int N = 
modulus.reg.size(); + assert(a.reg.size()<=N && b.reg.size()<=N); + + AsymmetricMultiply(T, T+2*N, a.reg, a.reg.size(), b.reg, b.reg.size()); + SetWords(T+a.reg.size()+b.reg.size(), 0, 2*N-a.reg.size()-b.reg.size()); + MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); + return result; +} + +const Integer& MontgomeryRepresentation::Square(const Integer &a) const +{ + word *const T = workspace.begin(); + word *const R = result.reg.begin(); + const unsigned int N = modulus.reg.size(); + assert(a.reg.size()<=N); + + CryptoPP::Square(T, T+2*N, a.reg, a.reg.size()); + SetWords(T+2*a.reg.size(), 0, 2*N-2*a.reg.size()); + MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); + return result; +} + +Integer MontgomeryRepresentation::ConvertOut(const Integer &a) const +{ + word *const T = workspace.begin(); + word *const R = result.reg.begin(); + const unsigned int N = modulus.reg.size(); + assert(a.reg.size()<=N); + + CopyWords(T, a.reg, a.reg.size()); + SetWords(T+a.reg.size(), 0, 2*N-a.reg.size()); + MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); + return result; +} + +const Integer& MontgomeryRepresentation::MultiplicativeInverse(const Integer &a) const +{ +// return (EuclideanMultiplicativeInverse(a, modulus)<<(2*WORD_BITS*modulus.reg.size()))%modulus; + word *const T = workspace.begin(); + word *const R = result.reg.begin(); + const unsigned int N = modulus.reg.size(); + assert(a.reg.size()<=N); + + CopyWords(T, a.reg, a.reg.size()); + SetWords(T+a.reg.size(), 0, 2*N-a.reg.size()); + MontgomeryReduce(R, T+2*N, T, modulus.reg, u.reg, N); + unsigned k = AlmostInverse(R, T, R, N, modulus.reg, N); + +// cout << "k=" << k << " N*32=" << 32*N << endl; + + if (k>N*WORD_BITS) + DivideByPower2Mod(R, R, k-N*WORD_BITS, modulus.reg, N); + else + MultiplyByPower2Mod(R, R, N*WORD_BITS-k, modulus.reg, N); + + return result; +} + +template class AbstractRing<Integer>; + +NAMESPACE_END diff --git a/modes.cpp b/modes.cpp new file mode 100644 index 0000000..70c2323 --- 
/dev/null +++ b/modes.cpp @@ -0,0 +1,266 @@ +// modes.cpp - written and placed in the public domain by Wei Dai + +#include "pch.h" +#include "modes.h" + +#include "des.h" + +#include "strciphr.cpp" + +NAMESPACE_BEGIN(CryptoPP) + +void Modes_TestInstantiations() +{ + CFB_Mode<DES>::Encryption m0; + CFB_Mode<DES>::Decryption m1; + OFB_Mode<DES>::Encryption m2; + CTR_Mode<DES>::Encryption m3; + ECB_Mode<DES>::Encryption m4; + CBC_Mode<DES>::Encryption m5; +} + +// explicit instantiations for Darwin gcc-932.1 +template class CFB_CipherTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, SymmetricCipher> >; +template class CFB_EncryptionTemplate<>; +template class CFB_DecryptionTemplate<>; +template class AdditiveCipherTemplate<>; +template class CFB_CipherTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> >; +template class CFB_EncryptionTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> >; +template class CFB_DecryptionTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> >; +template class AdditiveCipherTemplate<AbstractPolicyHolder<AdditiveCipherAbstractPolicy, OFB_ModePolicy> >; +template class AdditiveCipherTemplate<AbstractPolicyHolder<AdditiveCipherAbstractPolicy, CTR_ModePolicy> >; + +void CipherModeBase::SetKey(const byte *key, unsigned int length, const NameValuePairs ¶ms) +{ + UncheckedSetKey(params, key, length); // the underlying cipher will check the key length +} + +void CipherModeBase::GetNextIV(byte *IV) +{ + if (!IsForwardTransformation()) + throw NotImplemented("CipherModeBase: GetNextIV() must be called on an encryption object"); + + m_cipher->ProcessBlock(m_register); + memcpy(IV, m_register, BlockSize()); +} + +void CipherModeBase::SetIV(const byte *iv) +{ + if (iv) + Resynchronize(iv); + else if (IsResynchronizable()) + { + if (!CanUseStructuredIVs()) + throw InvalidArgument("CipherModeBase: this cipher mode cannot use a null IV"); + + // use all zeros as default IV + SecByteBlock 
iv(BlockSize()); + memset(iv, 0, iv.size()); + Resynchronize(iv); + } +} + +void CTR_ModePolicy::SeekToIteration(dword iterationCount) +{ + int carry=0; + for (int i=BlockSize()-1; i>=0; i--) + { + unsigned int sum = m_register[i] + byte(iterationCount) + carry; + m_counterArray[i] = (byte) sum; + carry = sum >> 8; + iterationCount >>= 8; + } +} + +static inline void IncrementCounterByOne(byte *inout, unsigned int s) +{ + for (int i=s-1, carry=1; i>=0 && carry; i--) + carry = !++inout[i]; +} + +static inline void IncrementCounterByOne(byte *output, const byte *input, unsigned int s) +{ + for (int i=s-1, carry=1; i>=0; i--) + carry = !(output[i] = input[i]+carry) && carry; +} + +inline void CTR_ModePolicy::ProcessMultipleBlocks(byte *output, const byte *input, unsigned int n) +{ + unsigned int s = BlockSize(), j = 0; + for (unsigned int i=1; i<n; i++, j+=s) + IncrementCounterByOne(m_counterArray + j + s, m_counterArray + j, s); + m_cipher->ProcessAndXorMultipleBlocks(m_counterArray, input, output, n); + IncrementCounterByOne(m_counterArray, m_counterArray + s*(n-1), s); +} + +void CTR_ModePolicy::OperateKeystream(KeystreamOperation operation, byte *output, const byte *input, unsigned int iterationCount) +{ + unsigned int maxBlocks = m_cipher->OptimalNumberOfParallelBlocks(); + if (maxBlocks == 1) + { + unsigned int sizeIncrement = BlockSize(); + while (iterationCount) + { + m_cipher->ProcessAndXorBlock(m_counterArray, input, output); + IncrementCounterByOne(m_counterArray, sizeIncrement); + output += sizeIncrement; + input += sizeIncrement; + iterationCount -= 1; + } + } + else + { + unsigned int sizeIncrement = maxBlocks * BlockSize(); + while (iterationCount >= maxBlocks) + { + ProcessMultipleBlocks(output, input, maxBlocks); + output += sizeIncrement; + input += sizeIncrement; + iterationCount -= maxBlocks; + } + if (iterationCount > 0) + ProcessMultipleBlocks(output, input, iterationCount); + } +} + +void CTR_ModePolicy::CipherResynchronize(byte 
*keystreamBuffer, const byte *iv) +{ + unsigned int s = BlockSize(); + memcpy(m_register, iv, s); + m_counterArray.New(s * m_cipher->OptimalNumberOfParallelBlocks()); + memcpy(m_counterArray, iv, s); +} + +void BlockOrientedCipherModeBase::UncheckedSetKey(const NameValuePairs ¶ms, const byte *key, unsigned int length) +{ + m_cipher->SetKey(key, length, params); + ResizeBuffers(); + const byte *iv = params.GetValueWithDefault(Name::IV(), (const byte *)NULL); + SetIV(iv); +} + +void BlockOrientedCipherModeBase::ProcessData(byte *outString, const byte *inString, unsigned int length) +{ + unsigned int s = BlockSize(); + assert(length % s == 0); + unsigned int alignment = m_cipher->BlockAlignment(); + bool inputAlignmentOk = !RequireAlignedInput() || IsAlignedOn(inString, alignment); + + if (IsAlignedOn(outString, alignment)) + { + if (inputAlignmentOk) + ProcessBlocks(outString, inString, length / s); + else + { + memcpy(outString, inString, length); + ProcessBlocks(outString, outString, length / s); + } + } + else + { + while (length) + { + if (inputAlignmentOk) + ProcessBlocks(m_buffer, inString, 1); + else + { + memcpy(m_buffer, inString, s); + ProcessBlocks(m_buffer, m_buffer, 1); + } + memcpy(outString, m_buffer, s); + inString += s; + outString += s; + length -= s; + } + } +} + +void CBC_Encryption::ProcessBlocks(byte *outString, const byte *inString, unsigned int numberOfBlocks) +{ + unsigned int blockSize = BlockSize(); + while (numberOfBlocks--) + { + xorbuf(m_register, inString, blockSize); + m_cipher->ProcessBlock(m_register); + memcpy(outString, m_register, blockSize); + inString += blockSize; + outString += blockSize; + } +} + +void CBC_CTS_Encryption::ProcessLastBlock(byte *outString, const byte *inString, unsigned int length) +{ + if (length <= BlockSize()) + { + if (!m_stolenIV) + throw InvalidArgument("CBC_Encryption: message is too short for ciphertext stealing"); + + // steal from IV + memcpy(outString, m_register, length); + outString = m_stolenIV; 
+ } + else + { + // steal from next to last block + xorbuf(m_register, inString, BlockSize()); + m_cipher->ProcessBlock(m_register); + inString += BlockSize(); + length -= BlockSize(); + memcpy(outString+BlockSize(), m_register, length); + } + + // output last full ciphertext block + xorbuf(m_register, inString, length); + m_cipher->ProcessBlock(m_register); + memcpy(outString, m_register, BlockSize()); +} + +void CBC_Decryption::ProcessBlocks(byte *outString, const byte *inString, unsigned int numberOfBlocks) +{ + unsigned int blockSize = BlockSize(); + while (numberOfBlocks--) + { + memcpy(m_temp, inString, blockSize); + m_cipher->ProcessBlock(m_temp, outString); + xorbuf(outString, m_register, blockSize); + m_register.swap(m_temp); + inString += blockSize; + outString += blockSize; + } +} + +void CBC_CTS_Decryption::ProcessLastBlock(byte *outString, const byte *inString, unsigned int length) +{ + const byte *pn, *pn1; + bool stealIV = length <= BlockSize(); + + if (stealIV) + { + pn = inString; + pn1 = m_register; + } + else + { + pn = inString + BlockSize(); + pn1 = inString; + length -= BlockSize(); + } + + // decrypt last partial plaintext block + memcpy(m_temp, pn1, BlockSize()); + m_cipher->ProcessBlock(m_temp); + xorbuf(m_temp, pn, length); + + if (stealIV) + memcpy(outString, m_temp, length); + else + { + memcpy(outString+BlockSize(), m_temp, length); + // decrypt next to last plaintext block + memcpy(m_temp, pn, length); + m_cipher->ProcessBlock(m_temp); + xorbuf(outString, m_temp, m_register, BlockSize()); + } +} + +NAMESPACE_END @@ -0,0 +1,370 @@ +#ifndef CRYPTOPP_MODES_H +#define CRYPTOPP_MODES_H + +/*! \file +*/ + +#include "cryptlib.h" +#include "secblock.h" +#include "misc.h" +#include "strciphr.h" +#include "argnames.h" +#include "algparam.h" + +NAMESPACE_BEGIN(CryptoPP) + +//! Cipher mode documentation. See NIST SP 800-38A for definitions of these modes. + +/*! 
Each class derived from this one defines two types, Encryption and Decryption, + both of which implement the SymmetricCipher interface. + For each mode there are two classes, one of which is a template class, + and the other one has a name that ends in "_ExternalCipher". + The "external cipher" mode objects hold a reference to the underlying block cipher, + instead of holding an instance of it. The reference must be passed in to the constructor. + For the "cipher holder" classes, the CIPHER template parameter should be a class + derived from BlockCipherDocumentation, for example DES or AES. +*/ +struct CipherModeDocumentation : public SymmetricCipherDocumentation +{ +}; + +class CipherModeBase : public SymmetricCipher +{ +public: + unsigned int MinKeyLength() const {return m_cipher->MinKeyLength();} + unsigned int MaxKeyLength() const {return m_cipher->MaxKeyLength();} + unsigned int DefaultKeyLength() const {return m_cipher->DefaultKeyLength();} + unsigned int GetValidKeyLength(unsigned int n) const {return m_cipher->GetValidKeyLength(n);} + bool IsValidKeyLength(unsigned int n) const {return m_cipher->IsValidKeyLength(n);} + + void SetKey(const byte *key, unsigned int length, const NameValuePairs ¶ms = g_nullNameValuePairs); + + unsigned int OptimalDataAlignment() const {return BlockSize();} + + unsigned int IVSize() const {return BlockSize();} + void GetNextIV(byte *IV); + virtual IV_Requirement IVRequirement() const =0; + +protected: + inline unsigned int BlockSize() const {assert(m_register.size() > 0); return m_register.size();} + void SetIV(const byte *iv); + virtual void SetFeedbackSize(unsigned int feedbackSize) + { + if (!(feedbackSize == 0 || feedbackSize == BlockSize())) + throw InvalidArgument("CipherModeBase: feedback size cannot be specified for this cipher mode"); + } + virtual void ResizeBuffers() + { + m_register.New(m_cipher->BlockSize()); + } + virtual void UncheckedSetKey(const NameValuePairs ¶ms, const byte *key, unsigned int length) =0; + + 
BlockCipher *m_cipher; + SecByteBlock m_register; +}; + +template <class POLICY_INTERFACE> +class ModePolicyCommonTemplate : public CipherModeBase, public POLICY_INTERFACE +{ + unsigned int GetAlignment() const {return m_cipher->BlockAlignment();} + void CipherSetKey(const NameValuePairs ¶ms, const byte *key, unsigned int length) + { + m_cipher->SetKey(key, length, params); + ResizeBuffers(); + int feedbackSize = params.GetIntValueWithDefault(Name::FeedbackSize(), 0); + SetFeedbackSize(feedbackSize); + const byte *iv = params.GetValueWithDefault(Name::IV(), (const byte *)NULL); + SetIV(iv); + } +}; + +class CFB_ModePolicy : public ModePolicyCommonTemplate<CFB_CipherAbstractPolicy> +{ +public: + IV_Requirement IVRequirement() const {return RANDOM_IV;} + +protected: + unsigned int GetBytesPerIteration() const {return m_feedbackSize;} + byte * GetRegisterBegin() {return m_register + BlockSize() - m_feedbackSize;} + void TransformRegister() + { + m_cipher->ProcessBlock(m_register, m_temp); + memmove(m_register, m_register+m_feedbackSize, BlockSize()-m_feedbackSize); + memcpy(m_register+BlockSize()-m_feedbackSize, m_temp, m_feedbackSize); + } + void CipherResynchronize(const byte *iv) + { + memcpy(m_register, iv, BlockSize()); + TransformRegister(); + } + void SetFeedbackSize(unsigned int feedbackSize) + { + if (feedbackSize > BlockSize()) + throw InvalidArgument("CFB_Mode: invalid feedback size"); + m_feedbackSize = feedbackSize ? 
feedbackSize : BlockSize(); + } + void ResizeBuffers() + { + CipherModeBase::ResizeBuffers(); + m_temp.New(BlockSize()); + } + + SecByteBlock m_temp; + unsigned int m_feedbackSize; +}; + +class OFB_ModePolicy : public ModePolicyCommonTemplate<AdditiveCipherAbstractPolicy> +{ + unsigned int GetBytesPerIteration() const {return BlockSize();} + unsigned int GetIterationsToBuffer() const {return 1;} + void WriteKeystream(byte *keystreamBuffer, unsigned int iterationCount) + { + assert(iterationCount == 1); + m_cipher->ProcessBlock(keystreamBuffer); + } + void CipherResynchronize(byte *keystreamBuffer, const byte *iv) + { + memcpy(keystreamBuffer, iv, BlockSize()); + } + bool IsRandomAccess() const {return false;} + IV_Requirement IVRequirement() const {return STRUCTURED_IV;} +}; + +class CTR_ModePolicy : public ModePolicyCommonTemplate<AdditiveCipherAbstractPolicy> +{ + unsigned int GetBytesPerIteration() const {return BlockSize();} + unsigned int GetIterationsToBuffer() const {return m_cipher->OptimalNumberOfParallelBlocks();} + void WriteKeystream(byte *buffer, unsigned int iterationCount) + {OperateKeystream(WRITE_KEYSTREAM, buffer, NULL, iterationCount);} + bool CanOperateKeystream() const {return true;} + void OperateKeystream(KeystreamOperation operation, byte *output, const byte *input, unsigned int iterationCount); + void CipherResynchronize(byte *keystreamBuffer, const byte *iv); + bool IsRandomAccess() const {return true;} + void SeekToIteration(dword iterationCount); + IV_Requirement IVRequirement() const {return STRUCTURED_IV;} + + inline void ProcessMultipleBlocks(byte *output, const byte *input, unsigned int n); + + SecByteBlock m_counterArray; +}; + +class BlockOrientedCipherModeBase : public CipherModeBase +{ +public: + void UncheckedSetKey(const NameValuePairs ¶ms, const byte *key, unsigned int length); + unsigned int MandatoryBlockSize() const {return BlockSize();} + bool IsRandomAccess() const {return false;} + bool IsSelfInverting() const {return 
false;} + bool IsForwardTransformation() const {return m_cipher->IsForwardTransformation();} + void Resynchronize(const byte *iv) {memcpy(m_register, iv, BlockSize());} + void ProcessData(byte *outString, const byte *inString, unsigned int length); + +protected: + bool RequireAlignedInput() const {return true;} + virtual void ProcessBlocks(byte *outString, const byte *inString, unsigned int numberOfBlocks) =0; + void ResizeBuffers() + { + CipherModeBase::ResizeBuffers(); + m_buffer.New(BlockSize()); + } + + SecByteBlock m_buffer; +}; + +class ECB_OneWay : public BlockOrientedCipherModeBase +{ +public: + IV_Requirement IVRequirement() const {return NOT_RESYNCHRONIZABLE;} + unsigned int OptimalBlockSize() const {return BlockSize() * m_cipher->OptimalNumberOfParallelBlocks();} + void ProcessBlocks(byte *outString, const byte *inString, unsigned int numberOfBlocks) + {m_cipher->ProcessAndXorMultipleBlocks(inString, NULL, outString, numberOfBlocks);} +}; + +class CBC_ModeBase : public BlockOrientedCipherModeBase +{ +public: + IV_Requirement IVRequirement() const {return UNPREDICTABLE_RANDOM_IV;} + bool RequireAlignedInput() const {return false;} + unsigned int MinLastBlockSize() const {return 0;} +}; + +class CBC_Encryption : public CBC_ModeBase +{ +public: + void ProcessBlocks(byte *outString, const byte *inString, unsigned int numberOfBlocks); +}; + +class CBC_CTS_Encryption : public CBC_Encryption +{ +public: + void SetStolenIV(byte *iv) {m_stolenIV = iv;} + unsigned int MinLastBlockSize() const {return BlockSize()+1;} + void ProcessLastBlock(byte *outString, const byte *inString, unsigned int length); + +protected: + void UncheckedSetKey(const NameValuePairs ¶ms, const byte *key, unsigned int length) + { + CBC_Encryption::UncheckedSetKey(params, key, length); + m_stolenIV = params.GetValueWithDefault(Name::StolenIV(), (byte *)NULL); + } + + byte *m_stolenIV; +}; + +class CBC_Decryption : public CBC_ModeBase +{ +public: + void ProcessBlocks(byte *outString, const 
byte *inString, unsigned int numberOfBlocks); + +protected: + void ResizeBuffers() + { + BlockOrientedCipherModeBase::ResizeBuffers(); + m_temp.New(BlockSize()); + } + SecByteBlock m_temp; +}; + +class CBC_CTS_Decryption : public CBC_Decryption +{ +public: + unsigned int MinLastBlockSize() const {return BlockSize()+1;} + void ProcessLastBlock(byte *outString, const byte *inString, unsigned int length); +}; + +//! . +template <class CIPHER, class BASE> +class CipherModeFinalTemplate_CipherHolder : public ObjectHolder<CIPHER>, public BASE +{ +public: + CipherModeFinalTemplate_CipherHolder() + { + m_cipher = &m_object; + ResizeBuffers(); + } + CipherModeFinalTemplate_CipherHolder(const byte *key, unsigned int length) + { + m_cipher = &m_object; + SetKey(key, length); + } + CipherModeFinalTemplate_CipherHolder(const byte *key, unsigned int length, const byte *iv, int feedbackSize = 0) + { + m_cipher = &m_object; + SetKey(key, length, MakeParameters("IV", iv)("FeedbackSize", feedbackSize)); + } +}; + +//! . +template <class BASE> +class CipherModeFinalTemplate_ExternalCipher : public BASE +{ +public: + CipherModeFinalTemplate_ExternalCipher(BlockCipher &cipher, const byte *iv = NULL, int feedbackSize = 0) + { + m_cipher = &cipher; + ResizeBuffers(); + SetFeedbackSize(feedbackSize); + SetIV(iv); + } +}; + +//! CFB mode +template <class CIPHER> +struct CFB_Mode : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, ConcretePolicyHolder<Empty, CFB_EncryptionTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> > > > Encryption; + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, ConcretePolicyHolder<Empty, CFB_DecryptionTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> > > > Decryption; +}; + +//! 
CFB mode, external cipher +struct CFB_Mode_ExternalCipher : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_ExternalCipher<ConcretePolicyHolder<Empty, CFB_EncryptionTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> > > > Encryption; + typedef CipherModeFinalTemplate_ExternalCipher<ConcretePolicyHolder<Empty, CFB_DecryptionTemplate<AbstractPolicyHolder<CFB_CipherAbstractPolicy, CFB_ModePolicy> > > > Decryption; +}; + +//! OFB mode +template <class CIPHER> +struct OFB_Mode : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, ConcretePolicyHolder<Empty, AdditiveCipherTemplate<AbstractPolicyHolder<AdditiveCipherAbstractPolicy, OFB_ModePolicy> > > > Encryption; + typedef Encryption Decryption; +}; + +//! OFB mode, external cipher +struct OFB_Mode_ExternalCipher : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_ExternalCipher<ConcretePolicyHolder<Empty, AdditiveCipherTemplate<AbstractPolicyHolder<AdditiveCipherAbstractPolicy, OFB_ModePolicy> > > > Encryption; + typedef Encryption Decryption; +}; + +//! CTR mode +template <class CIPHER> +struct CTR_Mode : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, ConcretePolicyHolder<Empty, AdditiveCipherTemplate<AbstractPolicyHolder<AdditiveCipherAbstractPolicy, CTR_ModePolicy> > > > Encryption; + typedef Encryption Decryption; +}; + +//! CTR mode, external cipher +struct CTR_Mode_ExternalCipher : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_ExternalCipher<ConcretePolicyHolder<Empty, AdditiveCipherTemplate<AbstractPolicyHolder<AdditiveCipherAbstractPolicy, CTR_ModePolicy> > > > Encryption; + typedef Encryption Decryption; +}; + +//! 
ECB mode +template <class CIPHER> +struct ECB_Mode : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, ECB_OneWay> Encryption; + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Decryption, ECB_OneWay> Decryption; +}; + +//! ECB mode, external cipher +struct ECB_Mode_ExternalCipher : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_ExternalCipher<ECB_OneWay> Encryption; + typedef Encryption Decryption; +}; + +//! CBC mode +template <class CIPHER> +struct CBC_Mode : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, CBC_Encryption> Encryption; + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Decryption, CBC_Decryption> Decryption; +}; + +//! CBC mode, external cipher +struct CBC_Mode_ExternalCipher : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_ExternalCipher<CBC_Encryption> Encryption; + typedef CipherModeFinalTemplate_ExternalCipher<CBC_Decryption> Decryption; +}; + +//! CBC mode with ciphertext stealing +template <class CIPHER> +struct CBC_CTS_Mode : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Encryption, CBC_CTS_Encryption> Encryption; + typedef CipherModeFinalTemplate_CipherHolder<CPP_TYPENAME CIPHER::Decryption, CBC_CTS_Decryption> Decryption; +}; + +//! 
CBC mode with ciphertext stealing, external cipher +struct CBC_CTS_Mode_ExternalCipher : public CipherModeDocumentation +{ + typedef CipherModeFinalTemplate_ExternalCipher<CBC_CTS_Encryption> Encryption; + typedef CipherModeFinalTemplate_ExternalCipher<CBC_CTS_Decryption> Decryption; +}; + +#ifdef CRYPTOPP_MAINTAIN_BACKWARDS_COMPATIBILITY +typedef CFB_Mode_ExternalCipher::Encryption CFBEncryption; +typedef CFB_Mode_ExternalCipher::Decryption CFBDecryption; +typedef OFB_Mode_ExternalCipher::Encryption OFB; +typedef CTR_Mode_ExternalCipher::Encryption CounterMode; +#endif + +NAMESPACE_END + +#endif |