diff options
Diffstat (limited to 'extra/yassl/src/handshake.cpp')
-rw-r--r-- | extra/yassl/src/handshake.cpp | 47 |
1 files changed, 32 insertions, 15 deletions
diff --git a/extra/yassl/src/handshake.cpp b/extra/yassl/src/handshake.cpp index 08fae4ac17d..c1ee61d043e 100644 --- a/extra/yassl/src/handshake.cpp +++ b/extra/yassl/src/handshake.cpp @@ -50,7 +50,7 @@ void buildClientHello(SSL& ssl, ClientHello& hello) hello.suite_len_ = ssl.getSecurity().get_parms().suites_size_; memcpy(hello.cipher_suites_, ssl.getSecurity().get_parms().suites_, hello.suite_len_); - hello.comp_len_ = 1; + hello.comp_len_ = 1; hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + @@ -528,8 +528,9 @@ void ProcessOldClientHello(input_buffer& input, SSL& ssl) input.read(len, sizeof(len)); uint16 randomLen; ato16(len, randomLen); + if (ch.suite_len_ > MAX_SUITE_SZ || sessionLen > ID_LEN || - randomLen > RAN_LEN) { + randomLen > RAN_LEN) { ssl.SetError(bad_input); return; } @@ -707,7 +708,7 @@ int DoProcessReply(SSL& ssl) { // wait for input if blocking if (!ssl.useSocket().wait()) { - ssl.SetError(receive_error); + ssl.SetError(receive_error); return 0; } uint ready = ssl.getSocket().get_ready(); @@ -750,8 +751,8 @@ int DoProcessReply(SSL& ssl) if (static_cast<uint>(RECORD_HEADER) > buffer.get_remaining()) needHdr = true; else { - buffer >> hdr; - ssl.verifyState(hdr); + buffer >> hdr; + ssl.verifyState(hdr); } // make sure we have enough input in buffer to process this record @@ -789,9 +790,8 @@ int DoProcessReply(SSL& ssl) void processReply(SSL& ssl) { if (ssl.GetError()) return; - - if (DoProcessReply(ssl)) - { + + if (DoProcessReply(ssl)) { // didn't complete process if (!ssl.getSocket().IsNonBlocking()) { // keep trying now, blocking ok @@ -857,6 +857,7 @@ void sendServerKeyExchange(SSL& ssl, BufferOutput buffer) if (ssl.GetError()) return; ServerKeyExchange sk(ssl); sk.build(ssl); + if (ssl.GetError()) return; RecordLayerHeader rlHeader; HandShakeHeader hsHeader; @@ -875,8 +876,7 @@ void sendServerKeyExchange(SSL& ssl, BufferOutput buffer) // send change cipher void sendChangeCipher(SSL& ssl, BufferOutput buffer) { - if (ssl.getSecurity().get_parms().entity_ == server_end) - { + if (ssl.getSecurity().get_parms().entity_ == server_end) { if (ssl.getSecurity().get_resuming()) ssl.verifyState(clientKeyExchangeComplete); else @@ -913,7 +913,7 @@ void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer) } else { if (!ssl.getSecurity().GetContext()->GetSessionCacheOff()) - GetSessions().add(ssl); // store session + GetSessions().add(ssl); // store session if (side == client_end) buildFinished(ssl, ssl.useHashes().use_verify(), server); // server } @@ -929,12 +929,22 @@ void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer) // send data int sendData(SSL& ssl, const void* buffer, int sz) { + int sent = 0; + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_READ)) ssl.SetError(no_error); + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_WRITE)) { + ssl.SetError(no_error); + ssl.SendWriteBuffered(); + if (!ssl.GetError()) { + // advance sent to prvevious sent + plain size just sent + sent = ssl.useBuffers().prevSent + ssl.useBuffers().plainSz; + } + } + ssl.verfiyHandShakeComplete(); if (ssl.GetError()) return -1; - int sent = 0; for (;;) { int len = min(sz - sent, MAX_RECORD_SIZE); @@ -943,6 +953,8 @@ int sendData(SSL& ssl, const void* buffer, int sz) Data data; + if (sent == sz) break; + if (ssl.CompressionOn()) { if (Compress(static_cast<const opaque*>(buffer) + sent, len, tmp) == -1) { @@ -957,9 +969,14 @@ int sendData(SSL& ssl, const void* buffer, int sz) buildMessage(ssl, out, data); ssl.Send(out.get_buffer(), out.get_size()); - if (ssl.GetError()) return -1; + if (ssl.GetError()) { + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_WRITE)) { + ssl.useBuffers().plainSz = len; + ssl.useBuffers().prevSent = sent; + } + return -1; + } sent += len; - if (sent == sz) break; } ssl.useLog().ShowData(sent, true); return sent; @@ -992,7 +1009,7 @@ int receiveData(SSL& ssl, Data& data, bool peek) if (peek) ssl.PeekData(data); else - ssl.fillData(data); + ssl.fillData(data); ssl.useLog().ShowData(data.get_length()); if (ssl.GetError()) return -1; |