diff options
Diffstat (limited to 'tests/twisted/servicetest.py')
-rw-r--r-- | tests/twisted/servicetest.py | 397 |
1 files changed, 194 insertions, 203 deletions
diff --git a/tests/twisted/servicetest.py b/tests/twisted/servicetest.py index b8485e1..bb00d7a 100644 --- a/tests/twisted/servicetest.py +++ b/tests/twisted/servicetest.py @@ -6,50 +6,20 @@ Infrastructure code for testing connection managers. from twisted.internet import glib2reactor from twisted.internet.protocol import Protocol, Factory, ClientFactory glib2reactor.install() +import sys import pprint -import traceback import unittest import dbus.glib from twisted.internet import reactor +import constants as cs + tp_name_prefix = 'org.freedesktop.Telepathy' tp_path_prefix = '/org/freedesktop/Telepathy' -class TryNextHandler(Exception): - pass - -def lazy(func): - def handler(event, data): - if func(event, data): - return True - else: - raise TryNextHandler() - handler.__name__ = func.__name__ - return handler - -def match(type, **kw): - def decorate(func): - def handler(event, data, *extra, **extra_kw): - if event.type != type: - return False - - for key, value in kw.iteritems(): - if not hasattr(event, key): - return False - - if getattr(event, key) != value: - return False - - return func(event, data, *extra, **extra_kw) - - handler.__name__ = func.__name__ - return handler - - return decorate - class Event: def __init__(self, type, **kw): self.__dict__.update(kw) @@ -68,118 +38,24 @@ def format_event(event): return ret -class EventTest: - """Somewhat odd event dispatcher for asynchronous tests. - - Callbacks are kept in a queue. Incoming events are passed to the first - callback. If the callback returns True, the callback is removed. If the - callback raises AssertionError, the test fails. If there are no more - callbacks, the test passes. The reactor is stopped when the test passes. - """ - - def __init__(self): - self.queue = [] - self.data = {'test': self} - self.timeout_delayed_call = reactor.callLater(5, self.timeout_cb) - #self.verbose = True - self.verbose = False - # ugh - self.stopping = False - - def timeout_cb(self): - print 'timed out waiting for events' - print self.queue[0] - self.fail() - - def fail(self): - # ugh; better way to stop the reactor and exit(1)? - import os - os._exit(1) - - def expect(self, f): - self.queue.append(f) - - def log(self, s): - if self.verbose: - print s - - def try_stop(self): - if self.stopping: - return True - - if not self.queue: - self.log('no handlers left; stopping') - self.stopping = True - reactor.stop() - return True - - return False - - def call_handlers(self, event): - self.log('trying %r' % self.queue[0]) - handler = self.queue.pop(0) - - try: - ret = handler(event, self.data) - if not ret: - self.queue.insert(0, handler) - except TryNextHandler, e: - if self.queue: - ret = self.call_handlers(event) - else: - ret = False - self.queue.insert(0, handler) - - return ret - - def handle_event(self, event): - if self.try_stop(): - return - - self.log('got event:') - self.log('- type: %s' % event.type) - map(self.log, format_event(event)) - - try: - ret = self.call_handlers(event) - except SystemExit, e: - if e.code: - print "Unsuccessful exit:", e - self.fail() - else: - self.queue[:] = [] - ret = True - except AssertionError, e: - print 'test failed:' - traceback.print_exc() - self.fail() - except (Exception, KeyboardInterrupt), e: - print 'error in handler:' - traceback.print_exc() - self.fail() - - if ret not in (True, False): - print ("warning: %s() returned something other than True or False" - % self.queue[0].__name__) - - if ret: - self.timeout_delayed_call.reset(5) - self.log('event handled') - else: - self.log('event not handled') - - self.log('') - self.try_stop() - class EventPattern: def __init__(self, type, **properties): self.type = type - self.predicate = lambda x: True + self.predicate = None if 'predicate' in properties: self.predicate = properties['predicate'] del properties['predicate'] self.properties = properties + def __repr__(self): + properties = dict(self.properties) + + if self.predicate is not None: + properties['predicate'] = self.predicate + + return '%s(%r, **%r)' % ( + self.__class__.__name__, self.type, properties) + def match(self, event): if event.type != self.type: return False @@ -191,7 +67,7 @@ class EventPattern: except AttributeError: return False - if self.predicate(event): + if self.predicate is None or self.predicate(event): return True return False @@ -208,7 +84,7 @@ class BaseEventQueue: def __init__(self, timeout=None): self.verbose = False - self.past_events = [] + self.forbidden_events = set() if timeout is None: self.timeout = 5 @@ -219,36 +95,50 @@ class BaseEventQueue: if self.verbose: print s - def flush_past_events(self): - self.past_events = [] - - def expect_racy(self, type, **kw): - pattern = EventPattern(type, **kw) + def log_event(self, event): + if self.verbose: + self.log('got event:') - for event in self.past_events: - if pattern.match(event): - self.log('past event handled') + if self.verbose: map(self.log, format_event(event)) - self.log('') - self.past_events.remove(event) - return event - return self.expect(type, **kw) + def forbid_events(self, patterns): + """ + Add patterns (an iterable of EventPattern) to the set of forbidden + events. If a forbidden event occurs during an expect or expect_many, + the test will fail. + """ + self.forbidden_events.update(set(patterns)) + + def unforbid_events(self, patterns): + """ + Remove 'patterns' (an iterable of EventPattern) from the set of + forbidden events. These must be the same EventPattern pointers that + were passed to forbid_events. + """ + self.forbidden_events.difference_update(set(patterns)) + + def _check_forbidden(self, event): + for e in self.forbidden_events: + if e.match(event): + print "forbidden event occurred:" + for x in format_event(event): + print x + assert False def expect(self, type, **kw): pattern = EventPattern(type, **kw) while True: event = self.wait() - self.log('got event:') - map(self.log, format_event(event)) + self.log_event(event) + self._check_forbidden(event) if pattern.match(event): self.log('handled') self.log('') return event - self.past_events.append(event) self.log('not handled') self.log('') @@ -256,18 +146,25 @@ class BaseEventQueue: ret = [None] * len(patterns) while None in ret: - event = self.wait() - self.log('got event:') - map(self.log, format_event(event)) + try: + event = self.wait() + except TimeoutError: + self.log('timeout') + self.log('still expecting:') + for i, pattern in enumerate(patterns): + if ret[i] is None: + self.log(' - %r' % pattern) + raise + self.log_event(event) + self._check_forbidden(event) for i, pattern in enumerate(patterns): - if pattern.match(event): + if ret[i] is None and pattern.match(event): self.log('handled') self.log('') ret[i] = event break else: - self.past_events.append(event) self.log('not handled') self.log('') @@ -277,8 +174,7 @@ class BaseEventQueue: pattern = EventPattern(type, **kw) event = self.wait() - self.log('got event:') - map(self.log, format_event(event)) + self.log_event(event) if pattern.match(event): self.log('handled') @@ -343,6 +239,16 @@ class EventQueueTest(unittest.TestCase): assert bar.type == 'bar' assert foo.type == 'foo' + def test_expect_many2(self): + # Test that events are only matched against patterns that haven't yet + # been matched. This tests a regression. + queue = TestEventQueue([Event('foo', x=1), Event('foo', x=2)]) + foo1, foo2 = queue.expect_many( + EventPattern('foo'), + EventPattern('foo')) + assert foo1.type == 'foo' and foo1.x == 1 + assert foo2.type == 'foo' and foo2.x == 2 + def test_timeout(self): queue = TestEventQueue([]) self.assertRaises(TimeoutError, queue.expect, 'foo') @@ -369,7 +275,10 @@ def unwrap(x): if isinstance(x, dict): return dict([(unwrap(k), unwrap(v)) for k, v in x.iteritems()]) - for t in [unicode, str, long, int, float, bool]: + if isinstance(x, dbus.Boolean): + return bool(x) + + for t in [unicode, str, long, int, float]: if isinstance(x, t): return t(x) @@ -384,7 +293,8 @@ def call_async(test, proxy, method, *args, **kw): value=unwrap(ret))) def error_func(err): - test.handle_event(Event('dbus-error', method=method, error=err)) + test.handle_event(Event('dbus-error', method=method, error=err, + name=err.get_dbus_name(), message=str(err))) method_proxy = getattr(proxy, method) kw.update({'reply_handler': reply_func, 'error_handler': error_func}) @@ -392,14 +302,20 @@ def call_async(test, proxy, method, *args, **kw): def sync_dbus(bus, q, conn): # Dummy D-Bus method call - call_async(q, conn, "InspectHandles", 1, []) - - event = q.expect('dbus-return', method='InspectHandles') + # This won't do the right thing unless the proxy has a unique name. + assert conn.object.bus_name.startswith(':') + root_object = bus.get_object(conn.object.bus_name, '/') + call_async( + q, dbus.Interface(root_object, 'org.freedesktop.DBus.Peer'), 'Ping') + q.expect('dbus-return', method='Ping') class ProxyWrapper: def __init__(self, object, default, others): self.object = object self.default_interface = dbus.Interface(object, default) + self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE) + self.TpProperties = \ + dbus.Interface(object, tp_name_prefix + '.Properties') self.interfaces = dict([ (name, dbus.Interface(object, iface)) for name, iface in others.iteritems()]) @@ -413,6 +329,33 @@ class ProxyWrapper: return getattr(self.default_interface, name) +def wrap_connection(conn): + return ProxyWrapper(conn, tp_name_prefix + '.Connection', + dict([ + (name, tp_name_prefix + '.Connection.Interface.' + name) + for name in ['Aliasing', 'Avatars', 'Capabilities', 'Contacts', + 'Presence', 'SimplePresence', 'Requests']] + + [('Peer', 'org.freedesktop.DBus.Peer'), + ('ContactCapabilities', cs.CONN_IFACE_CONTACT_CAPS), + ('ContactInfo', cs.CONN_IFACE_CONTACT_INFO), + ('Location', cs.CONN_IFACE_LOCATION), + ('Future', tp_name_prefix + '.Connection.FUTURE'), + ('MailNotification', cs.CONN_IFACE_MAIL_NOTIFICATION), + ])) + +def wrap_channel(chan, type_, extra=None): + interfaces = { + type_: tp_name_prefix + '.Channel.Type.' + type_, + 'Group': tp_name_prefix + '.Channel.Interface.Group', + } + + if extra: + interfaces.update(dict([ + (name, tp_name_prefix + '.Channel.Interface.' + name) + for name in extra])) + + return ProxyWrapper(chan, tp_name_prefix + '.Channel', interfaces) + def make_connection(bus, event_func, name, proto, params): cm = bus.get_object( tp_name_prefix + '.ConnectionManager.%s' % name, @@ -421,29 +364,7 @@ def make_connection(bus, event_func, name, proto, params): connection_name, connection_path = cm_iface.RequestConnection( proto, params) - conn = bus.get_object(connection_name, connection_path) - conn = ProxyWrapper(conn, tp_name_prefix + '.Connection', - dict([ - (name, tp_name_prefix + '.Connection.Interface.' + name) - for name in ['Aliasing', 'Avatars', 'Capabilities', 'Contacts', - 'Presence', 'SimplePresence', 'Requests']] + - [('Peer', 'org.freedesktop.DBus.Peer')])) - - bus.add_signal_receiver( - lambda *args, **kw: - event_func( - Event('dbus-signal', - path=unwrap(kw['path'])[len(tp_path_prefix):], - signal=kw['member'], args=map(unwrap, args), - interface=kw['interface'])), - None, # signal name - None, # interface - cm._named_service, - path_keyword='path', - member_keyword='member', - interface_keyword='interface', - byte_arrays=True - ) + conn = wrap_connection(bus.get_object(connection_name, connection_path)) return conn @@ -453,20 +374,12 @@ def make_channel_proxy(conn, path, iface): chan = dbus.Interface(chan, tp_name_prefix + '.' + iface) return chan -def load_event_handlers(): - path, _, _, _ = traceback.extract_stack()[0] - import compiler - import __main__ - ast = compiler.parseFile(path) - return [ - getattr(__main__, node.name) - for node in ast.node.asList() - if node.__class__ == compiler.ast.Function and - node.name.startswith('expect_')] - +# block_reading can be used if the test want to choose when we start to read +# data from the socket. class EventProtocol(Protocol): - def __init__(self, queue=None): + def __init__(self, queue=None, block_reading=False): self.queue = queue + self.block_reading = block_reading def dataReceived(self, data): if self.queue is not None: @@ -476,12 +389,24 @@ class EventProtocol(Protocol): def sendData(self, data): self.transport.write(data) + def connectionMade(self): + if self.block_reading: + self.transport.stopReading() + + def connectionLost(self, reason=None): + if self.queue is not None: + self.queue.handle_event(Event('socket-disconnected', protocol=self)) + class EventProtocolFactory(Factory): - def __init__(self, queue): + def __init__(self, queue, block_reading=False): self.queue = queue + self.block_reading = block_reading + + def _create_protocol(self): + return EventProtocol(self.queue, self.block_reading) def buildProtocol(self, addr): - proto = EventProtocol(self.queue) + proto = self._create_protocol() self.queue.handle_event(Event('socket-connected', protocol=proto)) return proto @@ -500,6 +425,72 @@ def watch_tube_signals(q, tube): path_keyword='path', member_keyword='member', byte_arrays=True) +def pretty(x): + return pprint.pformat(unwrap(x)) + +def assertEquals(expected, value): + if expected != value: + raise AssertionError( + "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value))) + +def assertNotEquals(expected, value): + if expected == value: + raise AssertionError( + "expected something other than:\n%s" % pretty(value)) + +def assertContains(element, value): + if element not in value: + raise AssertionError( + "expected:\n%s\nin:\n%s" % (pretty(element), pretty(value))) + +def assertDoesNotContain(element, value): + if element in value: + raise AssertionError( + "expected:\n%s\nnot in:\n%s" % (pretty(element), pretty(value))) + +def assertLength(length, value): + if len(value) != length: + raise AssertionError("expected: length %d, got length %d:\n%s" % ( + length, len(value), pretty(value))) + +def assertFlagsSet(flags, value): + masked = value & flags + if masked != flags: + raise AssertionError( + "expected flags %u, of which only %u are set in %u" % ( + flags, masked, value)) + +def assertFlagsUnset(flags, value): + masked = value & flags + if masked != 0: + raise AssertionError( + "expected none of flags %u, but %u are set in %u" % ( + flags, masked, value)) + +def install_colourer(): + def red(s): + return '\x1b[31m%s\x1b[0m' % s + + def green(s): + return '\x1b[32m%s\x1b[0m' % s + + patterns = { + 'handled': green, + 'not handled': red, + } + + class Colourer: + def __init__(self, fh, patterns): + self.fh = fh + self.patterns = patterns + + def write(self, s): + f = self.patterns.get(s, lambda x: x) + self.fh.write(f(s)) + + sys.stdout = Colourer(sys.stdout, patterns) + return sys.stdout + if __name__ == '__main__': unittest.main() |