summaryrefslogtreecommitdiff
path: root/distbuild
diff options
context:
space:
mode:
authorRichard Ipsum <richard.ipsum@codethink.co.uk>2015-05-11 16:31:47 +0100
committerRichard Ipsum <richard.ipsum@codethink.co.uk>2015-05-19 09:19:56 +0000
commit5dd1f23b77a4d1937fc309efa73d208278ab8de4 (patch)
tree6b07ef61f25ee317d6605ff9c3c8fe5294587fa0 /distbuild
parentddef6ab1ae5c4d54e651c9be6653a50e2a44c04b (diff)
downloadmorph-5dd1f23b77a4d1937fc309efa73d208278ab8de4.tar.gz
Use protocol to validate incoming requests
Change-Id: I16680439b131e63d30eeff91814a1af643af6246
Diffstat (limited to 'distbuild')
-rw-r--r--distbuild/initiator_connection.py46
-rw-r--r--distbuild/protocol.py26
2 files changed, 49 insertions, 23 deletions
diff --git a/distbuild/initiator_connection.py b/distbuild/initiator_connection.py
index b3e17e98..d48ad214 100644
--- a/distbuild/initiator_connection.py
+++ b/distbuild/initiator_connection.py
@@ -19,6 +19,11 @@ import logging
import distbuild
+PROTOCOL_VERSION_MISMATCH_RESPONSE = (
+ 'Protocol version mismatch between server and initiator: '
+ 'distbuild network uses distbuild protocol version %s, '
+ 'but client uses version %s.'
+)
class InitiatorDisconnect(object):
@@ -50,7 +55,7 @@ class InitiatorConnection(distbuild.StateMachine):
state machines, and vice versa.
'''
-
+
_idgen = distbuild.IdentifierGenerator('InitiatorConnection')
_route_map = distbuild.RouteMap()
@@ -122,25 +127,28 @@ class InitiatorConnection(distbuild.StateMachine):
'build-cancel': self._handle_build_cancel,
'build-status': self._handle_build_status,
}
- try:
- if event.msg.get('protocol_version') == distbuild.protocol.VERSION:
- msg_handler[event.msg['type']](event)
- else:
- response = (
- 'Protocol version mismatch between server & initiator: '
- 'distbuild network uses distbuild protocol version %s, '
- 'but client uses version %s.' %
- (distbuild.protocol.VERSION,
- event.msg.get('protocol_version')))
- self._refuse_build_request(event.msg, response)
- except (KeyError, ValueError) as ex:
- response = (
- 'Invalid build-request message. Check you are using a '
- 'supported version of Morph. This distbuild network uses '
- 'protocol version %i.' % distbuild.protocol.VERSION)
+
+ protocol_version = event.msg.get('protocol_version')
+ msg_type = event.msg.get('type')
+
+ if (protocol_version == distbuild.protocol.VERSION
+ and msg_type in msg_handler
+ and distbuild.protocol.is_valid_message(event.msg)):
+ try:
+ msg_handler[msg_type](event)
+ except Exception:
+ logging.exception('Error handling msg: %s', event.msg)
+ else:
+ response = 'Bad request'
+
+ if (protocol_version is not None
+ and protocol_version != distbuild.protocol.VERSION):
+ # Provide hint to possible cause of bad request
+ response += ('\n' + PROTOCOL_VERSION_MISMATCH_RESPONSE %
+ (distbuild.protocol.VERSION, protocol_version))
+
+ logging.info('Invalid message from initiator: %s', event.msg)
self._refuse_build_request(event.msg, response)
- logging.info('Invalid message from initiator: %s: exception %r',
- event.msg, ex)
def _refuse_build_request(self, build_request_message, reason):
'''Send an error message back to the initiator.
diff --git a/distbuild/protocol.py b/distbuild/protocol.py
index 9aab6a6d..44552ae1 100644
--- a/distbuild/protocol.py
+++ b/distbuild/protocol.py
@@ -129,13 +129,13 @@ _optional_fields = {
}
-def message(message_type, **kwargs):
- known_types = _required_fields.keys()
- assert message_type in known_types
-
+def _validate(message_type, **kwargs):
required_fields = _required_fields[message_type]
optional_fields = _optional_fields.get(message_type, [])
+ known_types = _required_fields.keys()
+ assert message_type in known_types
+
for name in required_fields:
assert name in kwargs, 'field %s is required' % name
@@ -143,7 +143,25 @@ def message(message_type, **kwargs):
assert (name in required_fields or name in optional_fields), \
'field %s is not allowed' % name
+def message(message_type, **kwargs):
+ _validate(message_type, **kwargs)
+
msg = dict(kwargs)
msg['type'] = message_type
return msg
+def is_valid_message(msg):
+
+ if 'type' not in msg:
+ return False
+
+ msg_type = msg['type']
+ del msg['type']
+
+ try:
+ _validate(msg_type, **msg)
+ return True
+ except AssertionError:
+ return False
+ finally:
+ msg['type'] = msg_type