diff options
Diffstat (limited to 'extra/yassl/src/handshake.cpp')
-rw-r--r-- | extra/yassl/src/handshake.cpp | 101 |
1 files changed, 81 insertions, 20 deletions
diff --git a/extra/yassl/src/handshake.cpp b/extra/yassl/src/handshake.cpp index 25f36c4ea8c..c03d72ff2ef 100644 --- a/extra/yassl/src/handshake.cpp +++ b/extra/yassl/src/handshake.cpp @@ -40,9 +40,11 @@ namespace yaSSL { // Build a client hello message from cipher suites and compression method -void buildClientHello(SSL& ssl, ClientHello& hello, - CompressionMethod compression = no_compression) +void buildClientHello(SSL& ssl, ClientHello& hello) { + // store for pre master secret + ssl.useSecurity().use_connection().chVersion_ = hello.client_version_; + ssl.getCrypto().get_random().Fill(hello.random_, RAN_LEN); if (ssl.getSecurity().get_resuming()) { hello.id_len_ = ID_LEN; @@ -55,7 +57,6 @@ void buildClientHello(SSL& ssl, ClientHello& hello, memcpy(hello.cipher_suites_, ssl.getSecurity().get_parms().suites_, hello.suite_len_); hello.comp_len_ = 1; - hello.compression_methods_ = compression; hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + @@ -83,7 +84,7 @@ void buildServerHello(SSL& ssl, ServerHello& hello) hello.cipher_suite_[0] = ssl.getSecurity().get_parms().suite_[0]; hello.cipher_suite_[1] = ssl.getSecurity().get_parms().suite_[1]; - hello.compression_method_ = no_compression; + hello.compression_method_ = hello.compression_method_; hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + ID_LEN + sizeof(hello.id_len_) + SUITE_LEN + SIZEOF_ENUM); @@ -151,12 +152,18 @@ void buildHeaders(SSL& ssl, HandShakeHeader& hsHeader, // add handshake from buffer into md5 and sha hashes, exclude record header -void hashHandShake(SSL& ssl, const output_buffer& output) +void hashHandShake(SSL& ssl, const output_buffer& output, bool removeIV = false) { uint sz = output.get_size() - RECORD_HEADER; const opaque* buffer = output.get_buffer() + RECORD_HEADER; + if (removeIV) { // TLSv1_1 IV + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + sz -= blockSz; + buffer += blockSz; + } + ssl.useHashes().use_MD5().update(buffer, sz); ssl.useHashes().use_SHA().update(buffer, sz); } @@ -229,6 +236,18 @@ void decrypt_message(SSL& ssl, input_buffer& input, uint sz) ssl.useCrypto().use_cipher().decrypt(plain.get_buffer(), cipher, sz); memcpy(cipher, plain.get_buffer(), sz); ssl.useSecurity().use_parms().encrypt_size_ = sz; + + if (ssl.isTLSv1_1()) // IV + input.set_current(input.get_current() + + ssl.getCrypto().get_cipher().get_blockSize()); +} + + +// output operator for input_buffer +output_buffer& operator<<(output_buffer& output, const input_buffer& input) +{ + output.write(input.get_buffer(), input.get_size()); + return output; } @@ -239,9 +258,12 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output) uint finishedSz = ssl.isTLS() ? TLS_FINISHED_SZ : FINISHED_SZ; uint sz = RECORD_HEADER + HANDSHAKE_HEADER + finishedSz + digestSz; uint pad = 0; + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + if (ssl.getSecurity().get_parms().cipher_type_ == block) { + if (ssl.isTLSv1_1()) + sz += blockSz; // IV sz += 1; // pad byte - uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); pad = (sz - RECORD_HEADER) % blockSz; pad = blockSz - pad; sz += pad; @@ -252,14 +274,21 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output) buildHeaders(ssl, hsHeader, rlHeader, fin); rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac // and pad, hanshake doesn't + input_buffer iv; + if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){ + iv.allocate(blockSz); + ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz); + iv.add_size(blockSz); + } + uint ivSz = iv.get_size(); output.allocate(sz); - output << rlHeader << hsHeader << fin; + output << rlHeader << iv << hsHeader << fin; - hashHandShake(ssl, output); + hashHandShake(ssl, output, ssl.isTLSv1_1() ? true : false); opaque digest[SHA_LEN]; // max size if (ssl.isTLS()) - TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, - output.get_size() - RECORD_HEADER, handshake); + TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz, + output.get_size() - RECORD_HEADER - ivSz, handshake); else hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, output.get_size() - RECORD_HEADER, handshake); @@ -282,9 +311,12 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg) uint digestSz = ssl.getCrypto().get_digest().get_digestSize(); uint sz = RECORD_HEADER + msg.get_length() + digestSz; uint pad = 0; + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + if (ssl.getSecurity().get_parms().cipher_type_ == block) { + if (ssl.isTLSv1_1()) // IV + sz += blockSz; sz += 1; // pad byte - uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); pad = (sz - RECORD_HEADER) % blockSz; pad = blockSz - pad; sz += pad; @@ -294,13 +326,21 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg) buildHeader(ssl, rlHeader, msg); rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac // and pad, hanshake doesn't + input_buffer iv; + if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){ + iv.allocate(blockSz); + ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz); + iv.add_size(blockSz); + } + + uint ivSz = iv.get_size(); output.allocate(sz); - output << rlHeader << msg; + output << rlHeader << iv << msg; opaque digest[SHA_LEN]; // max size if (ssl.isTLS()) - TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, - output.get_size() - RECORD_HEADER, msg.get_type()); + TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz, + output.get_size() - RECORD_HEADER - ivSz, msg.get_type()); else hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, output.get_size() - RECORD_HEADER, msg.get_type()); @@ -456,6 +496,10 @@ void buildSHA_CertVerify(SSL& ssl, byte* digest) // some clients still send sslv2 client hello void ProcessOldClientHello(input_buffer& input, SSL& ssl) { + if (input.get_remaining() < 2) { + ssl.SetError(bad_input); + return; + } byte b0 = input[AUTO]; byte b1 = input[AUTO]; @@ -721,6 +765,7 @@ int DoProcessReply(SSL& ssl) // each message in record, can be more than 1 if not encrypted if (ssl.getSecurity().get_parms().pending_ == false) // cipher on decrypt_message(ssl, buffer, hdr.length_); + mySTL::auto_ptr<Message> msg(mf.CreateObject(hdr.type_)); if (!msg.get()) { ssl.SetError(factory_error); @@ -744,13 +789,13 @@ void processReply(SSL& ssl) if (DoProcessReply(ssl)) // didn't complete process - if (!ssl.getSocket().IsBlocking()) { - // keep trying now + if (!ssl.getSocket().IsNonBlocking()) { + // keep trying now, blocking ok while (!ssl.GetError()) if (DoProcessReply(ssl) == 0) break; } else - // user will have try again later + // user will have try again later, non blocking ssl.SetError(YasslError(SSL_ERROR_WANT_READ)); } @@ -761,7 +806,8 @@ void sendClientHello(SSL& ssl) ssl.verifyState(serverNull); if (ssl.GetError()) return; - ClientHello ch(ssl.getSecurity().get_connection().version_); + ClientHello ch(ssl.getSecurity().get_connection().version_, + ssl.getSecurity().get_connection().compression_); RecordLayerHeader rlHeader; HandShakeHeader hsHeader; output_buffer out; @@ -859,6 +905,7 @@ void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer) buildFinished(ssl, ssl.useHashes().use_verify(), client); // client } else { + if (!ssl.getSecurity().GetContext()->GetSessionCacheOff()) GetSessions().add(ssl); // store session if (side == client_end) buildFinished(ssl, ssl.useHashes().use_verify(), server); // server @@ -885,7 +932,20 @@ int sendData(SSL& ssl, const void* buffer, int sz) for (;;) { int len = min(sz - sent, MAX_RECORD_SIZE); output_buffer out; - const Data data(len, static_cast<const opaque*>(buffer) + sent); + input_buffer tmp; + + Data data; + + if (ssl.CompressionOn()) { + if (Compress(static_cast<const opaque*>(buffer) + sent, len, + tmp) == -1) { + ssl.SetError(compress_error); + return -1; + } + data.SetData(tmp.get_size(), tmp.get_buffer()); + } + else + data.SetData(len, static_cast<const opaque*>(buffer) + sent); buildMessage(ssl, out, data); ssl.Send(out.get_buffer(), out.get_size()); @@ -947,7 +1007,8 @@ void sendServerHello(SSL& ssl, BufferOutput buffer) ssl.verifyState(clientHelloComplete); if (ssl.GetError()) return; - ServerHello sh(ssl.getSecurity().get_connection().version_); + ServerHello sh(ssl.getSecurity().get_connection().version_, + ssl.getSecurity().get_connection().compression_); RecordLayerHeader rlHeader; HandShakeHeader hsHeader; mySTL::auto_ptr<output_buffer> out(NEW_YS output_buffer); |