From c39bb5becd1e37530c4f8fea54b23d563620f3ef Mon Sep 17 00:00:00 2001 From: jacobkeeler Date: Tue, 25 Jul 2017 11:20:38 -0400 Subject: Implement reading and writing constructed payloads for StartService and EndService --- .../connection_handler/connection_handler_impl.h | 10 +- .../src/connection_handler_impl.cc | 18 +- .../test/connection_handler_impl_test.cc | 18 +- .../include/protocol_handler/session_observer.h | 9 +- .../test/protocol_handler/mock_session_observer.h | 6 + .../protocol_handler/protocol_handler_impl.h | 29 ++ .../include/protocol_handler/protocol_packet.h | 44 +++ .../protocol_handler/src/protocol_handler_impl.cc | 298 +++++++++++++++++++-- .../protocol_handler/src/protocol_packet.cc | 42 +++ .../test/protocol_handler_tm_test.cc | 8 +- 10 files changed, 440 insertions(+), 42 deletions(-) diff --git a/src/components/connection_handler/include/connection_handler/connection_handler_impl.h b/src/components/connection_handler/include/connection_handler/connection_handler_impl.h index cd8aec0ff3..a6651a6466 100644 --- a/src/components/connection_handler/include/connection_handler/connection_handler_impl.h +++ b/src/components/connection_handler/include/connection_handler/connection_handler_impl.h @@ -189,20 +189,26 @@ class ConnectionHandlerImpl const bool is_protected, uint32_t* hash_id); + // DEPRECATED + uint32_t OnSessionEndedCallback( + const transport_manager::ConnectionUID connection_handle, + const uint8_t session_id, + const uint32_t& hashCode, + const protocol_handler::ServiceType& service_type) OVERRIDE; /** * \brief Callback function used by ProtocolHandler * when Mobile Application initiates session ending. * \param connection_handle Connection identifier within which session exists * \param sessionId Identifier of the session to be ended * \param hashCode Hash used only in second version of SmartDeviceLink - * protocol. + * protocol. (Set to HASH_ID_WRONG if the hash is incorrect) * If not equal to hash assigned to session on start then operation fails. * \return uint32_t 0 if operation fails, session key otherwise */ uint32_t OnSessionEndedCallback( const transport_manager::ConnectionUID connection_handle, const uint8_t session_id, - const uint32_t& hashCode, + uint32_t* hashCode, const protocol_handler::ServiceType& service_type) OVERRIDE; /** diff --git a/src/components/connection_handler/src/connection_handler_impl.cc b/src/components/connection_handler/src/connection_handler_impl.cc index 0f6720f6a1..125a6cd769 100644 --- a/src/components/connection_handler/src/connection_handler_impl.cc +++ b/src/components/connection_handler/src/connection_handler_impl.cc @@ -378,11 +378,22 @@ void ConnectionHandlerImpl::OnMalformedMessageCallback( CloseConnection(connection_handle); } +// DEPRECATED uint32_t ConnectionHandlerImpl::OnSessionEndedCallback( const transport_manager::ConnectionUID connection_handle, const uint8_t session_id, const uint32_t& hashCode, const protocol_handler::ServiceType& service_type) { + uint32_t hashValue = hashCode; + return OnSessionEndedCallback( + connection_handle, session_id, &hashValue, service_type); +} + +uint32_t ConnectionHandlerImpl::OnSessionEndedCallback( + const transport_manager::ConnectionUID connection_handle, + const uint8_t session_id, + uint32_t* hashCode, + const protocol_handler::ServiceType& service_type) { LOG4CXX_AUTO_TRACE(logger_); connection_list_lock_.AcquireForReading(); @@ -402,12 +413,13 @@ uint32_t ConnectionHandlerImpl::OnSessionEndedCallback( "Session " << static_cast(session_id) << " to be removed"); // old version of protocol doesn't support hash - if (protocol_handler::HASH_ID_NOT_SUPPORTED != hashCode) { - if (protocol_handler::HASH_ID_WRONG == hashCode || - session_key != hashCode) { + if (protocol_handler::HASH_ID_NOT_SUPPORTED != *hashCode) { + if (protocol_handler::HASH_ID_WRONG == *hashCode || + session_key != *hashCode) { LOG4CXX_WARN(logger_, "Wrong hash_id for session " << static_cast(session_id)); + *hashCode = protocol_handler::HASH_ID_WRONG; return 0; } } diff --git a/src/components/connection_handler/test/connection_handler_impl_test.cc b/src/components/connection_handler/test/connection_handler_impl_test.cc index 6b5c2c89ad..7f4429500d 100644 --- a/src/components/connection_handler/test/connection_handler_impl_test.cc +++ b/src/components/connection_handler/test/connection_handler_impl_test.cc @@ -1030,9 +1030,10 @@ TEST_F(ConnectionHandlerTest, StartService_withServices) { TEST_F(ConnectionHandlerTest, ServiceStop_UnExistSession) { AddTestDeviceConnection(); - + uint32_t dummy_hash = 0u; const uint32_t end_session_result = - connection_handler_->OnSessionEndedCallback(uid_, 0u, 0u, kAudio); + connection_handler_->OnSessionEndedCallback( + uid_, 0u, &dummy_hash, kAudio); EXPECT_EQ(0u, end_session_result); CheckSessionExists(uid_, 0); } @@ -1040,9 +1041,10 @@ TEST_F(ConnectionHandlerTest, ServiceStop_UnExistSession) { TEST_F(ConnectionHandlerTest, ServiceStop_UnExistService) { AddTestDeviceConnection(); AddTestSession(); + uint32_t dummy_hash = 0u; const uint32_t end_session_result = connection_handler_->OnSessionEndedCallback( - uid_, start_session_id_, 0u, kAudio); + uid_, start_session_id_, &dummy_hash, kAudio); EXPECT_EQ(0u, end_session_result); CheckServiceExists(uid_, start_session_id_, kAudio, false); } @@ -1060,7 +1062,7 @@ TEST_F(ConnectionHandlerTest, ServiceStop) { const uint32_t end_session_result = connection_handler_->OnSessionEndedCallback( - uid_, start_session_id_, some_hash_id, kAudio); + uid_, start_session_id_, &some_hash_id, kAudio); EXPECT_EQ(connection_key_, end_session_result); CheckServiceExists(uid_, start_session_id_, kAudio, false); } @@ -1072,12 +1074,13 @@ TEST_F(ConnectionHandlerTest, SessionStop_CheckHash) { AddTestSession(); const uint32_t hash = connection_key_; - const uint32_t wrong_hash = hash + 1; + uint32_t wrong_hash = hash + 1; const uint32_t end_audio_wrong_hash = connection_handler_->OnSessionEndedCallback( - uid_, start_session_id_, wrong_hash, kRpc); + uid_, start_session_id_, &wrong_hash, kRpc); EXPECT_EQ(0u, end_audio_wrong_hash); + EXPECT_EQ(protocol_handler::HASH_ID_WRONG, wrong_hash); CheckSessionExists(uid_, start_session_id_); const uint32_t end_audio = connection_handler_->OnSessionEndedCallback( @@ -1092,13 +1095,14 @@ TEST_F(ConnectionHandlerTest, SessionStop_CheckSpecificHash) { for (uint32_t session = 0; session < 0xFF; ++session) { AddTestSession(); - const uint32_t wrong_hash = protocol_handler::HASH_ID_WRONG; + uint32_t wrong_hash = protocol_handler::HASH_ID_WRONG; const uint32_t hash = protocol_handler::HASH_ID_NOT_SUPPORTED; const uint32_t end_audio_wrong_hash = connection_handler_->OnSessionEndedCallback( uid_, start_session_id_, wrong_hash, kRpc); EXPECT_EQ(0u, end_audio_wrong_hash); + EXPECT_EQ(protocol_handler::HASH_ID_WRONG, wrong_hash); CheckSessionExists(uid_, start_session_id_); const uint32_t end_audio = connection_handler_->OnSessionEndedCallback( diff --git a/src/components/include/protocol_handler/session_observer.h b/src/components/include/protocol_handler/session_observer.h index 5e630c6c74..a5901baf0b 100644 --- a/src/components/include/protocol_handler/session_observer.h +++ b/src/components/include/protocol_handler/session_observer.h @@ -80,6 +80,13 @@ class SessionObserver { const bool is_protected, uint32_t* hash_id) = 0; + // DEPRECATED + virtual uint32_t OnSessionEndedCallback( + const transport_manager::ConnectionUID connection_handle, + const uint8_t sessionId, + const uint32_t& hashCode, + const protocol_handler::ServiceType& service_type) = 0; + /** * \brief Callback function used by ProtocolHandler * when Mobile Application initiates session ending. @@ -94,7 +101,7 @@ class SessionObserver { virtual uint32_t OnSessionEndedCallback( const transport_manager::ConnectionUID connection_handle, const uint8_t sessionId, - const uint32_t& hashCode, + uint32_t* hashCode, const protocol_handler::ServiceType& service_type) = 0; /** diff --git a/src/components/include/test/protocol_handler/mock_session_observer.h b/src/components/include/test/protocol_handler/mock_session_observer.h index c376cb85f5..0a86a29db5 100644 --- a/src/components/include/test/protocol_handler/mock_session_observer.h +++ b/src/components/include/test/protocol_handler/mock_session_observer.h @@ -59,6 +59,12 @@ class MockSessionObserver : public ::protocol_handler::SessionObserver { const uint8_t sessionId, const uint32_t& hashCode, const protocol_handler::ServiceType& service_type)); + MOCK_METHOD4( + OnSessionEndedCallback, + uint32_t(const transport_manager::ConnectionUID connection_handle, + const uint8_t sessionId, + uint32_t* hashCode, + const protocol_handler::ServiceType& service_type)); MOCK_METHOD1(OnApplicationFloodCallBack, void(const uint32_t& connection_key)); MOCK_METHOD1(OnMalformedMessageCallback, diff --git a/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h b/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h index b18ee07d4d..562540ff3f 100644 --- a/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h +++ b/src/components/protocol_handler/include/protocol_handler/protocol_handler_impl.h @@ -250,6 +250,14 @@ class ProtocolHandlerImpl uint8_t service_type, bool protection); + void SendStartSessionAck(ConnectionID connection_id, + uint8_t session_id, + uint8_t protocol_version, + uint32_t hash_code, + uint8_t service_type, + bool protection, + ProtocolPacket::ProtocolVersion& full_version); + const ProtocolHandlerSettings& get_settings() const OVERRIDE { return settings_; } @@ -266,6 +274,12 @@ class ProtocolHandlerImpl uint8_t protocol_version, uint8_t service_type); + void SendStartSessionNAck(ConnectionID connection_id, + uint8_t session_id, + uint8_t protocol_version, + uint8_t service_type, + std::vector& rejectedParams); + /** * \brief Sends acknowledgement of end session/service to mobile application * with session number for second version of protocol. @@ -294,6 +308,21 @@ class ProtocolHandlerImpl uint32_t session_id, uint8_t protocol_version, uint8_t service_type); + /** + * \brief Sends fail of ending session to mobile application (variant for + * Protocol v5) + * \param connection_id Identifier of connection within which + * session exists + * \param session_id ID of session ment to be ended + * \param protocol_version Version of protocol used for communication + * \param service_type Type of session: RPC or BULK Data. RPC by default + * \param rejected_params List of rejected params to send in payload + */ + void SendEndSessionNAck(ConnectionID connection_id, + uint32_t session_id, + uint8_t protocol_version, + uint8_t service_type, + std::vector& rejected_params); SessionObserver& get_session_observer() OVERRIDE; diff --git a/src/components/protocol_handler/include/protocol_handler/protocol_packet.h b/src/components/protocol_handler/include/protocol_handler/protocol_packet.h index 276c416d59..eae4a74025 100644 --- a/src/components/protocol_handler/include/protocol_handler/protocol_packet.h +++ b/src/components/protocol_handler/include/protocol_handler/protocol_packet.h @@ -63,6 +63,50 @@ class ProtocolPacket { uint32_t totalDataBytes; }; + class ProtocolVersion { + public: + ProtocolVersion(); + ProtocolVersion(uint8_t majorVersion, + uint8_t minorVersion, + uint8_t patchVersion); + ProtocolVersion(ProtocolVersion& other); + ProtocolVersion(std::string versionString); + uint8_t majorVersion; + uint8_t minorVersion; + uint8_t patchVersion; + static inline uint8_t cmp(const ProtocolVersion& version1, + const ProtocolVersion& version2) { + uint8_t diff = version1.majorVersion - version2.majorVersion; + if (diff == 0) { + diff = version1.minorVersion - version2.minorVersion; + if (diff == 0) { + diff = version1.minorVersion - version2.minorVersion; + } + } + return diff; + } + inline bool operator==(const ProtocolVersion& other) { + return ProtocolVersion::cmp(*this, other) == 0; + } + inline bool operator<(const ProtocolVersion& other) { + return ProtocolVersion::cmp(*this, other) < 0; + } + bool operator>(const ProtocolVersion& other) { + return ProtocolVersion::cmp(*this, other) > 0; + } + inline bool operator<=(const ProtocolVersion& other) { + return ProtocolVersion::cmp(*this, other) <= 0; + } + bool operator>=(const ProtocolVersion& other) { + return ProtocolVersion::cmp(*this, other) >= 0; + } + static inline ProtocolVersion* min(ProtocolVersion& version1, + ProtocolVersion& version2) { + return (version1 < version2) ? &version1 : &version2; + } + std::string to_string(); + }; + /** * \class ProtocolHeader * \brief Used for storing protocol header of a message. diff --git a/src/components/protocol_handler/src/protocol_handler_impl.cc b/src/components/protocol_handler/src/protocol_handler_impl.cc index 1fd49222f0..fce8cbfcea 100644 --- a/src/components/protocol_handler/src/protocol_handler_impl.cc +++ b/src/components/protocol_handler/src/protocol_handler_impl.cc @@ -33,6 +33,7 @@ #include "protocol_handler/protocol_handler_impl.h" #include #include // std::find +#include #include "connection_handler/connection_handler_impl.h" #include "protocol_handler/session_observer.h" @@ -57,6 +58,8 @@ std::string ConvertPacketDataToString(const uint8_t* data, const size_t kStackSize = 32768; +ProtocolPacket::ProtocolVersion defaultProtocolVersion(5, 0, 0); + ProtocolHandlerImpl::ProtocolHandlerImpl( const ProtocolHandlerSettings& settings, protocol_handler::SessionObserver& session_observer, @@ -184,17 +187,38 @@ void set_hash_id(uint32_t hash_id, protocol_handler::ProtocolPacket& packet) { void ProtocolHandlerImpl::SendStartSessionAck(ConnectionID connection_id, uint8_t session_id, - uint8_t, + uint8_t protocol_version, uint32_t hash_id, uint8_t service_type, bool protection) { LOG4CXX_AUTO_TRACE(logger_); + ProtocolPacket::ProtocolVersion* fullVersion = + new ProtocolPacket::ProtocolVersion(); + SendStartSessionAck(connection_id, + session_id, + protocol_version, + hash_id, + service_type, + protection, + *fullVersion); + delete fullVersion; +} + +void ProtocolHandlerImpl::SendStartSessionAck( + ConnectionID connection_id, + uint8_t session_id, + uint8_t protocol_version, + uint32_t hash_id, + uint8_t service_type, + bool protection, + ProtocolPacket::ProtocolVersion& full_version) { + LOG4CXX_AUTO_TRACE(logger_); - uint8_t protocolVersion = SupportedSDLProtocolVersion(); + uint8_t maxProtocolVersion = SupportedSDLProtocolVersion(); ProtocolFramePtr ptr( new protocol_handler::ProtocolPacket(connection_id, - protocolVersion, + maxProtocolVersion, protection, FRAME_TYPE_CONTROL, service_type, @@ -203,7 +227,38 @@ void ProtocolHandlerImpl::SendStartSessionAck(ConnectionID connection_id, 0u, message_counters_[session_id]++)); - set_hash_id(hash_id, *ptr); + // Cannot include a constructed payload if either side doesn't support it + if (maxProtocolVersion >= PROTOCOL_VERSION_5) { + ServiceType serviceTypeValue = ServiceTypeFromByte(service_type); + + BsonObject payloadObj; + bson_object_initialize_default(&payloadObj); + bson_object_put_int32(&payloadObj, "hashId", static_cast(hash_id)); + bson_object_put_int64( + &payloadObj, + "mtu", + static_cast( + protocol_header_validator_.max_payload_size_by_service_type( + serviceTypeValue))); + if (serviceTypeValue == kRpc) { + // Minimum protocol version supported by both + ProtocolPacket::ProtocolVersion* minVersion = + (full_version.majorVersion < PROTOCOL_VERSION_5) + ? &defaultProtocolVersion + : ProtocolPacket::ProtocolVersion::min(full_version, + defaultProtocolVersion); + char protocolVersionString[255]; + strncpy(protocolVersionString, (*minVersion).to_string().c_str(), 255); + bson_object_put_string( + &payloadObj, "protocolVersion", protocolVersionString); + } + uint8_t* payloadBytes = bson_object_to_bytes(&payloadObj); + ptr->set_data(payloadBytes, bson_object_size(&payloadObj)); + free(payloadBytes); + bson_object_deinitialize(&payloadObj); + } else { + set_hash_id(hash_id, *ptr); + } raw_ford_messages_to_mobile_.PostMessage( impl::RawFordMessageToMobile(ptr, false)); @@ -243,6 +298,52 @@ void ProtocolHandlerImpl::SendStartSessionNAck(ConnectionID connection_id, << static_cast(session_id)); } +void ProtocolHandlerImpl::SendStartSessionNAck( + ConnectionID connection_id, + uint8_t session_id, + uint8_t protocol_version, + uint8_t service_type, + std::vector& rejectedParams) { + LOG4CXX_AUTO_TRACE(logger_); + + ProtocolFramePtr ptr( + new protocol_handler::ProtocolPacket(connection_id, + protocol_version, + PROTECTION_OFF, + FRAME_TYPE_CONTROL, + service_type, + FRAME_DATA_START_SERVICE_NACK, + session_id, + 0u, + message_counters_[session_id]++)); + + if (rejectedParams.size() > 0) { + BsonObject payloadObj; + bson_object_initialize_default(&payloadObj); + BsonArray rejectedParamsArr; + bson_array_initialize(&rejectedParamsArr, rejectedParams.size()); + for (std::string param : rejectedParams) { + char paramPtr[255]; + strncpy(paramPtr, param.c_str(), 255); + bson_array_add_string(&rejectedParamsArr, paramPtr); + } + bson_object_put_array(&payloadObj, "rejectedParams", &rejectedParamsArr); + uint8_t* payloadBytes = bson_object_to_bytes(&payloadObj); + ptr->set_data(payloadBytes, bson_object_size(&payloadObj)); + free(payloadBytes); + bson_object_deinitialize(&payloadObj); + } + + raw_ford_messages_to_mobile_.PostMessage( + impl::RawFordMessageToMobile(ptr, false)); + + LOG4CXX_DEBUG(logger_, + "SendStartSessionNAck() for connection " + << connection_id << " for service_type " + << static_cast(service_type) << " session_id " + << static_cast(session_id)); +} + void ProtocolHandlerImpl::SendEndSessionNAck(ConnectionID connection_id, uint32_t session_id, uint8_t protocol_version, @@ -270,6 +371,51 @@ void ProtocolHandlerImpl::SendEndSessionNAck(ConnectionID connection_id, << static_cast(session_id)); } +void ProtocolHandlerImpl::SendEndSessionNAck( + ConnectionID connection_id, + uint32_t session_id, + uint8_t protocol_version, + uint8_t service_type, + std::vector& rejectedParams) { + LOG4CXX_AUTO_TRACE(logger_); + + ProtocolFramePtr ptr( + new protocol_handler::ProtocolPacket(connection_id, + protocol_version, + PROTECTION_OFF, + FRAME_TYPE_CONTROL, + service_type, + FRAME_DATA_END_SERVICE_NACK, + session_id, + 0u, + message_counters_[session_id]++)); + if (rejectedParams.size() > 0) { + BsonObject payloadObj; + bson_object_initialize_default(&payloadObj); + BsonArray rejectedParamsArr; + bson_array_initialize(&rejectedParamsArr, rejectedParams.size()); + for (std::string param : rejectedParams) { + char paramPtr[255]; + strncpy(paramPtr, param.c_str(), 255); + bson_array_add_string(&rejectedParamsArr, paramPtr); + } + bson_object_put_array(&payloadObj, "rejectedParams", &rejectedParamsArr); + uint8_t* payloadBytes = bson_object_to_bytes(&payloadObj); + ptr->set_data(payloadBytes, bson_object_size(&payloadObj)); + free(payloadBytes); + bson_object_deinitialize(&payloadObj); + } + + raw_ford_messages_to_mobile_.PostMessage( + impl::RawFordMessageToMobile(ptr, false)); + + LOG4CXX_DEBUG(logger_, + "SendEndSessionNAck() for connection " + << connection_id << " for service_type " + << static_cast(service_type) << " session_id " + << static_cast(session_id)); +} + SessionObserver& ProtocolHandlerImpl::get_session_observer() { return session_observer_; } @@ -421,7 +567,7 @@ void ProtocolHandlerImpl::SendMessageToMobileApp(const RawMessagePtr message, metric_observer_->StartMessageProcess(message_id, start_time); } #endif // TELEMETRY_MONITOR - const size_t max_frame_size = get_settings().maximum_payload_size(); + size_t max_frame_size = get_settings().maximum_payload_size(); size_t frame_size = MAXIMUM_FRAME_DATA_V2_SIZE; switch (message->protocol_version()) { case PROTOCOL_VERSION_3: @@ -430,6 +576,13 @@ void ProtocolHandlerImpl::SendMessageToMobileApp(const RawMessagePtr message, ? max_frame_size : MAXIMUM_FRAME_DATA_V2_SIZE; break; + case PROTOCOL_VERSION_5: + max_frame_size = + protocol_header_validator_.max_payload_size_by_service_type( + ServiceTypeFromByte(message->service_type())); + frame_size = max_frame_size > MAXIMUM_FRAME_DATA_V2_SIZE + ? max_frame_size + : MAXIMUM_FRAME_DATA_V2_SIZE; default: break; } @@ -932,10 +1085,18 @@ uint32_t get_hash_id(const ProtocolPacket& packet) { LOG4CXX_WARN(logger_, "Packet without hash data (data size less 4)"); return HASH_ID_WRONG; } - const uint32_t hash_be = *(reinterpret_cast(packet.data())); - const uint32_t hash_le = BE_TO_LE32(hash_be); - // null hash is wrong hash value - return hash_le == HASH_ID_NOT_SUPPORTED ? HASH_ID_WRONG : hash_le; + if (packet.protocol_version() >= PROTOCOL_VERSION_5) { + BsonObject obj = bson_object_from_bytes(packet.data()); + const uint32_t hash_id = (uint32_t)bson_object_get_int32(&obj, "hashId"); + bson_object_deinitialize(&obj); + return hash_id; + } else { + const uint32_t hash_be = *(reinterpret_cast(packet.data())); + const uint32_t hash_le = BE_TO_LE32(hash_be); + + // null hash is wrong hash value + return hash_le == HASH_ID_NOT_SUPPORTED ? HASH_ID_WRONG : hash_le; + } } RESULT_CODE ProtocolHandlerImpl::HandleControlMessageEndSession( @@ -943,12 +1104,14 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageEndSession( LOG4CXX_AUTO_TRACE(logger_); const uint8_t current_session_id = packet.session_id(); - const uint32_t hash_id = get_hash_id(packet); + uint32_t hash_id; + + hash_id = get_hash_id(packet); const ServiceType service_type = ServiceTypeFromByte(packet.service_type()); const ConnectionID connection_id = packet.connection_id(); const uint32_t session_key = session_observer_.OnSessionEndedCallback( - connection_id, current_session_id, hash_id, service_type); + connection_id, current_session_id, &hash_id, service_type); // TODO(EZamakhov): add clean up output queue (for removed service) if (session_key != 0) { @@ -961,10 +1124,22 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageEndSession( LOG4CXX_WARN(logger_, "Refused to end session " << static_cast(service_type) << " type."); - SendEndSessionNAck(connection_id, - current_session_id, - packet.protocol_version(), - service_type); + if (packet.protocol_version() >= PROTOCOL_VERSION_5) { + std::vector rejectedParams(0, std::string("")); + if (hash_id == protocol_handler::HASH_ID_WRONG) { + rejectedParams.push_back(std::string("hashId")); + } + SendEndSessionNAck(connection_id, + current_session_id, + packet.protocol_version(), + service_type, + rejectedParams); + } else { + SendEndSessionNAck(connection_id, + current_session_id, + packet.protocol_version(), + service_type); + } } return RESULT_OK; } @@ -1014,7 +1189,28 @@ class StartSessionHandler : public security_manager::SecurityManagerListener { , protocol_version_(protocol_version) , hash_id_(hash_id) , service_type_(service_type) - , force_protected_service_(force_protected_service) {} + , force_protected_service_(force_protected_service) + , full_version_() {} + StartSessionHandler(uint32_t connection_key, + ProtocolHandlerImpl* protocol_handler, + SessionObserver& session_observer, + ConnectionID connection_id, + int32_t session_id, + uint8_t protocol_version, + uint32_t hash_id, + ServiceType service_type, + const std::vector& force_protected_service, + ProtocolPacket::ProtocolVersion& full_version) + : connection_key_(connection_key) + , protocol_handler_(protocol_handler) + , session_observer_(session_observer) + , connection_id_(connection_id) + , session_id_(session_id) + , protocol_version_(protocol_version) + , hash_id_(hash_id) + , service_type_(service_type) + , force_protected_service_(force_protected_service) + , full_version_(full_version) {} bool OnHandshakeDone( const uint32_t connection_key, @@ -1067,6 +1263,7 @@ class StartSessionHandler : public security_manager::SecurityManagerListener { const uint32_t hash_id_; const ServiceType service_type_; const std::vector force_protected_service_; + const ProtocolPacket::ProtocolVersion full_version_; }; } // namespace #endif // ENABLE_SECURITY @@ -1129,7 +1326,26 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageStartSession( PROTECTION_OFF); return RESULT_OK; } - if (ssl_context->IsInitCompleted()) { + ProtocolPacket::ProtocolVersion* fullVersion; + std::vector rejectedParams(0, std::string("")); + if (packet.data_size() != 0) { + BsonObject obj = bson_object_from_bytes(packet.data()); + fullVersion = new ProtocolPacket::ProtocolVersion( + std::string(bson_object_get_string(&obj, "protocolVersion"))); + bson_object_deinitialize(&obj); + if (fullVersion->majorVersion < PROTOCOL_VERSION_5) { + rejectedParams.push_back(std::string("protocolVersion")); + } + } else { + fullVersion = new ProtocolPacket::ProtocolVersion(); + } + if (!rejectedParams.empty()) { + SendStartSessionNAck(connection_id, + packet.session_id(), + protocol_version, + packet.service_type(), + rejectedParams); + } else if (ssl_context->IsInitCompleted()) { // mark service as protected session_observer_.SetProtectionFlag(connection_key, service_type); // Start service as protected with current SSLContext @@ -1138,7 +1354,8 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageStartSession( packet.protocol_version(), hash_id, packet.service_type(), - PROTECTION_ON); + PROTECTION_ON, + *fullVersion); } else { security_manager_->AddListener( new StartSessionHandler(connection_key, @@ -1149,25 +1366,54 @@ RESULT_CODE ProtocolHandlerImpl::HandleControlMessageStartSession( packet.protocol_version(), hash_id, service_type, - get_settings().force_protected_service())); + get_settings().force_protected_service(), + *fullVersion)); if (!ssl_context->IsHandshakePending()) { // Start handshake process security_manager_->StartHandshake(connection_key); } } + delete fullVersion; LOG4CXX_DEBUG(logger_, "Protection establishing for connection " << connection_key << " is in progress"); return RESULT_OK; } #endif // ENABLE_SECURITY - // Start service without protection - SendStartSessionAck(connection_id, - session_id, - packet.protocol_version(), - hash_id, - packet.service_type(), - PROTECTION_OFF); + if (packet.data_size() != 0) { + BsonObject obj = bson_object_from_bytes(packet.data()); + ProtocolPacket::ProtocolVersion fullVersion( + bson_object_get_string(&obj, "protocolVersion")); + bson_object_deinitialize(&obj); + + if (fullVersion.majorVersion >= PROTOCOL_VERSION_5) { + // Start service without protection + SendStartSessionAck(connection_id, + session_id, + packet.protocol_version(), + hash_id, + packet.service_type(), + PROTECTION_OFF, + fullVersion); + } else { + std::vector rejectedParams(1, + std::string("protocolVersion")); + SendStartSessionNAck(connection_id, + packet.session_id(), + protocol_version, + packet.service_type(), + rejectedParams); + } + + } else { + // Start service without protection + SendStartSessionAck(connection_id, + session_id, + packet.protocol_version(), + hash_id, + packet.service_type(), + PROTECTION_OFF); + } return RESULT_OK; } diff --git a/src/components/protocol_handler/src/protocol_packet.cc b/src/components/protocol_handler/src/protocol_packet.cc index d4daffe1d7..0cf884f736 100644 --- a/src/components/protocol_handler/src/protocol_packet.cc +++ b/src/components/protocol_handler/src/protocol_packet.cc @@ -52,6 +52,48 @@ ProtocolPacket::ProtocolData::~ProtocolData() { delete[] data; } +ProtocolPacket::ProtocolVersion::ProtocolVersion() + : majorVersion(0), minorVersion(0), patchVersion(0) {} + +ProtocolPacket::ProtocolVersion::ProtocolVersion(uint8_t majorVersion, + uint8_t minorVersion, + uint8_t patchVersion) + : majorVersion(majorVersion) + , minorVersion(minorVersion) + , patchVersion(patchVersion) {} + +ProtocolPacket::ProtocolVersion::ProtocolVersion(ProtocolVersion& other) { + this->majorVersion = other.majorVersion; + this->minorVersion = other.minorVersion; + this->patchVersion = other.patchVersion; +} + +ProtocolPacket::ProtocolVersion::ProtocolVersion(std::string versionString) + : majorVersion(0), minorVersion(0), patchVersion(0) { + unsigned int majorInt, minorInt, patchInt; + int readElements = sscanf( + versionString.c_str(), "%u.%u.%u", &majorInt, &minorInt, &patchInt); + if (readElements != 3) { + LOG4CXX_WARN(logger_, + "Error while parsing version string: " << versionString); + } else { + majorVersion = static_cast(majorInt); + minorVersion = static_cast(minorInt); + patchVersion = static_cast(patchInt); + } +} + +std::string ProtocolPacket::ProtocolVersion::to_string() { + char versionString[255]; + snprintf(versionString, + 255, + "%u.%u.%u", + static_cast(majorVersion), + static_cast(minorVersion), + static_cast(patchVersion)); + return std::string(versionString); +} + ProtocolPacket::ProtocolHeader::ProtocolHeader() : version(0x00) , protection_flag(PROTECTION_OFF) diff --git a/src/components/protocol_handler/test/protocol_handler_tm_test.cc b/src/components/protocol_handler/test/protocol_handler_tm_test.cc index 308901e013..66226825ba 100644 --- a/src/components/protocol_handler/test/protocol_handler_tm_test.cc +++ b/src/components/protocol_handler/test/protocol_handler_tm_test.cc @@ -96,6 +96,7 @@ using connection_handler::DeviceHandle; using ::testing::Return; using ::testing::ReturnRefOfCopy; using ::testing::ReturnNull; +using ::testing::An; using ::testing::AnyOf; using ::testing::DoAll; using ::testing::_; @@ -500,12 +501,12 @@ TEST_F(ProtocolHandlerImplTest, EndSession_SessionObserverReject) { uint32_t times = 0; AddSession(waiter, times); - const ServiceType service = kRpc; // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, - OnSessionEndedCallback(connection_id, session_id, _, service)) + OnSessionEndedCallback( + connection_id, session_id, An(), service)) . // reject session start WillOnce( @@ -539,7 +540,8 @@ TEST_F(ProtocolHandlerImplTest, EndSession_Success) { // Expect ConnectionHandler check EXPECT_CALL(session_observer_mock, - OnSessionEndedCallback(connection_id, session_id, _, service)) + OnSessionEndedCallback( + connection_id, session_id, An(), service)) . // return sessions start success WillOnce(DoAll(NotifyTestAsyncWaiter(waiter), Return(connection_key))); -- cgit v1.2.1