summaryrefslogtreecommitdiff
path: root/extra/yassl/src/handshake.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'extra/yassl/src/handshake.cpp')
-rw-r--r--extra/yassl/src/handshake.cpp101
1 files changed, 81 insertions, 20 deletions
diff --git a/extra/yassl/src/handshake.cpp b/extra/yassl/src/handshake.cpp
index 25f36c4ea8c..c03d72ff2ef 100644
--- a/extra/yassl/src/handshake.cpp
+++ b/extra/yassl/src/handshake.cpp
@@ -40,9 +40,11 @@ namespace yaSSL {
// Build a client hello message from cipher suites and compression method
-void buildClientHello(SSL& ssl, ClientHello& hello,
- CompressionMethod compression = no_compression)
+void buildClientHello(SSL& ssl, ClientHello& hello)
{
+ // store for pre master secret
+ ssl.useSecurity().use_connection().chVersion_ = hello.client_version_;
+
ssl.getCrypto().get_random().Fill(hello.random_, RAN_LEN);
if (ssl.getSecurity().get_resuming()) {
hello.id_len_ = ID_LEN;
@@ -55,7 +57,6 @@ void buildClientHello(SSL& ssl, ClientHello& hello,
memcpy(hello.cipher_suites_, ssl.getSecurity().get_parms().suites_,
hello.suite_len_);
hello.comp_len_ = 1;
- hello.compression_methods_ = compression;
hello.set_length(sizeof(ProtocolVersion) +
RAN_LEN +
@@ -83,7 +84,7 @@ void buildServerHello(SSL& ssl, ServerHello& hello)
hello.cipher_suite_[0] = ssl.getSecurity().get_parms().suite_[0];
hello.cipher_suite_[1] = ssl.getSecurity().get_parms().suite_[1];
- hello.compression_method_ = no_compression;
+ hello.compression_method_ = hello.compression_method_;
hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + ID_LEN +
sizeof(hello.id_len_) + SUITE_LEN + SIZEOF_ENUM);
@@ -151,12 +152,18 @@ void buildHeaders(SSL& ssl, HandShakeHeader& hsHeader,
// add handshake from buffer into md5 and sha hashes, exclude record header
-void hashHandShake(SSL& ssl, const output_buffer& output)
+void hashHandShake(SSL& ssl, const output_buffer& output, bool removeIV = false)
{
uint sz = output.get_size() - RECORD_HEADER;
const opaque* buffer = output.get_buffer() + RECORD_HEADER;
+ if (removeIV) { // TLSv1_1 IV
+ uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
+ sz -= blockSz;
+ buffer += blockSz;
+ }
+
ssl.useHashes().use_MD5().update(buffer, sz);
ssl.useHashes().use_SHA().update(buffer, sz);
}
@@ -229,6 +236,18 @@ void decrypt_message(SSL& ssl, input_buffer& input, uint sz)
ssl.useCrypto().use_cipher().decrypt(plain.get_buffer(), cipher, sz);
memcpy(cipher, plain.get_buffer(), sz);
ssl.useSecurity().use_parms().encrypt_size_ = sz;
+
+ if (ssl.isTLSv1_1()) // IV
+ input.set_current(input.get_current() +
+ ssl.getCrypto().get_cipher().get_blockSize());
+}
+
+
+// output operator for input_buffer
+output_buffer& operator<<(output_buffer& output, const input_buffer& input)
+{
+ output.write(input.get_buffer(), input.get_size());
+ return output;
}
@@ -239,9 +258,12 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output)
uint finishedSz = ssl.isTLS() ? TLS_FINISHED_SZ : FINISHED_SZ;
uint sz = RECORD_HEADER + HANDSHAKE_HEADER + finishedSz + digestSz;
uint pad = 0;
+ uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
+
if (ssl.getSecurity().get_parms().cipher_type_ == block) {
+ if (ssl.isTLSv1_1())
+ sz += blockSz; // IV
sz += 1; // pad byte
- uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
pad = (sz - RECORD_HEADER) % blockSz;
pad = blockSz - pad;
sz += pad;
@@ -252,14 +274,21 @@ void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output)
buildHeaders(ssl, hsHeader, rlHeader, fin);
rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac
// and pad, hanshake doesn't
+ input_buffer iv;
+ if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){
+ iv.allocate(blockSz);
+ ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz);
+ iv.add_size(blockSz);
+ }
+ uint ivSz = iv.get_size();
output.allocate(sz);
- output << rlHeader << hsHeader << fin;
+ output << rlHeader << iv << hsHeader << fin;
- hashHandShake(ssl, output);
+ hashHandShake(ssl, output, ssl.isTLSv1_1() ? true : false);
opaque digest[SHA_LEN]; // max size
if (ssl.isTLS())
- TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER,
- output.get_size() - RECORD_HEADER, handshake);
+ TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz,
+ output.get_size() - RECORD_HEADER - ivSz, handshake);
else
hmac(ssl, digest, output.get_buffer() + RECORD_HEADER,
output.get_size() - RECORD_HEADER, handshake);
@@ -282,9 +311,12 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg)
uint digestSz = ssl.getCrypto().get_digest().get_digestSize();
uint sz = RECORD_HEADER + msg.get_length() + digestSz;
uint pad = 0;
+ uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
+
if (ssl.getSecurity().get_parms().cipher_type_ == block) {
+ if (ssl.isTLSv1_1()) // IV
+ sz += blockSz;
sz += 1; // pad byte
- uint blockSz = ssl.getCrypto().get_cipher().get_blockSize();
pad = (sz - RECORD_HEADER) % blockSz;
pad = blockSz - pad;
sz += pad;
@@ -294,13 +326,21 @@ void buildMessage(SSL& ssl, output_buffer& output, const Message& msg)
buildHeader(ssl, rlHeader, msg);
rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac
// and pad, hanshake doesn't
+ input_buffer iv;
+ if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){
+ iv.allocate(blockSz);
+ ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz);
+ iv.add_size(blockSz);
+ }
+
+ uint ivSz = iv.get_size();
output.allocate(sz);
- output << rlHeader << msg;
+ output << rlHeader << iv << msg;
opaque digest[SHA_LEN]; // max size
if (ssl.isTLS())
- TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER,
- output.get_size() - RECORD_HEADER, msg.get_type());
+ TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz,
+ output.get_size() - RECORD_HEADER - ivSz, msg.get_type());
else
hmac(ssl, digest, output.get_buffer() + RECORD_HEADER,
output.get_size() - RECORD_HEADER, msg.get_type());
@@ -456,6 +496,10 @@ void buildSHA_CertVerify(SSL& ssl, byte* digest)
// some clients still send sslv2 client hello
void ProcessOldClientHello(input_buffer& input, SSL& ssl)
{
+ if (input.get_remaining() < 2) {
+ ssl.SetError(bad_input);
+ return;
+ }
byte b0 = input[AUTO];
byte b1 = input[AUTO];
@@ -721,6 +765,7 @@ int DoProcessReply(SSL& ssl)
// each message in record, can be more than 1 if not encrypted
if (ssl.getSecurity().get_parms().pending_ == false) // cipher on
decrypt_message(ssl, buffer, hdr.length_);
+
mySTL::auto_ptr<Message> msg(mf.CreateObject(hdr.type_));
if (!msg.get()) {
ssl.SetError(factory_error);
@@ -744,13 +789,13 @@ void processReply(SSL& ssl)
if (DoProcessReply(ssl))
// didn't complete process
- if (!ssl.getSocket().IsBlocking()) {
- // keep trying now
+ if (!ssl.getSocket().IsNonBlocking()) {
+ // keep trying now, blocking ok
while (!ssl.GetError())
if (DoProcessReply(ssl) == 0) break;
}
else
- // user will have try again later
+ // user will have try again later, non blocking
ssl.SetError(YasslError(SSL_ERROR_WANT_READ));
}
@@ -761,7 +806,8 @@ void sendClientHello(SSL& ssl)
ssl.verifyState(serverNull);
if (ssl.GetError()) return;
- ClientHello ch(ssl.getSecurity().get_connection().version_);
+ ClientHello ch(ssl.getSecurity().get_connection().version_,
+ ssl.getSecurity().get_connection().compression_);
RecordLayerHeader rlHeader;
HandShakeHeader hsHeader;
output_buffer out;
@@ -859,6 +905,7 @@ void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer)
buildFinished(ssl, ssl.useHashes().use_verify(), client); // client
}
else {
+ if (!ssl.getSecurity().GetContext()->GetSessionCacheOff())
GetSessions().add(ssl); // store session
if (side == client_end)
buildFinished(ssl, ssl.useHashes().use_verify(), server); // server
@@ -885,7 +932,20 @@ int sendData(SSL& ssl, const void* buffer, int sz)
for (;;) {
int len = min(sz - sent, MAX_RECORD_SIZE);
output_buffer out;
- const Data data(len, static_cast<const opaque*>(buffer) + sent);
+ input_buffer tmp;
+
+ Data data;
+
+ if (ssl.CompressionOn()) {
+ if (Compress(static_cast<const opaque*>(buffer) + sent, len,
+ tmp) == -1) {
+ ssl.SetError(compress_error);
+ return -1;
+ }
+ data.SetData(tmp.get_size(), tmp.get_buffer());
+ }
+ else
+ data.SetData(len, static_cast<const opaque*>(buffer) + sent);
buildMessage(ssl, out, data);
ssl.Send(out.get_buffer(), out.get_size());
@@ -947,7 +1007,8 @@ void sendServerHello(SSL& ssl, BufferOutput buffer)
ssl.verifyState(clientHelloComplete);
if (ssl.GetError()) return;
- ServerHello sh(ssl.getSecurity().get_connection().version_);
+ ServerHello sh(ssl.getSecurity().get_connection().version_,
+ ssl.getSecurity().get_connection().compression_);
RecordLayerHeader rlHeader;
HandShakeHeader hsHeader;
mySTL::auto_ptr<output_buffer> out(NEW_YS output_buffer);