diff options
Diffstat (limited to 'tests/twisted/servicetest.py')
-rw-r--r-- | tests/twisted/servicetest.py | 364 |
1 files changed, 267 insertions, 97 deletions
diff --git a/tests/twisted/servicetest.py b/tests/twisted/servicetest.py index 4dc604f3..821240b0 100644 --- a/tests/twisted/servicetest.py +++ b/tests/twisted/servicetest.py @@ -1,5 +1,5 @@ # Copyright (C) 2009 Nokia Corporation -# Copyright (C) 2009 Collabora Ltd. +# Copyright (C) 2009-2013 Collabora Ltd. # # This library is free software; you can redistribute it and/or # modify it under the terms of the GNU Lesser General Public @@ -17,30 +17,55 @@ # 02110-1301 USA """ -Infrastructure code for testing Mission Control +Infrastructure code for testing Telepathy services. """ from twisted.internet import glib2reactor from twisted.internet.protocol import Protocol, Factory, ClientFactory glib2reactor.install() import sys +import time +import os import pprint import unittest import dbus import dbus.lowlevel -import dbus.glib +from dbus.mainloop.glib import DBusGMainLoop +DBusGMainLoop(set_as_default=True) from twisted.internet import reactor -tp_name_prefix = 'org.freedesktop.Telepathy' -tp_path_prefix = '/org/freedesktop/Telepathy' +import constants as cs -class Event: +tp_name_prefix = cs.PREFIX +tp_path_prefix = cs.PATH_PREFIX + +class DictionarySupersetOf (object): + """Utility class for expecting "a dictionary with at least these keys".""" + def __init__(self, dictionary): + self._dictionary = dictionary + def __repr__(self): + return "DictionarySupersetOf(%s)" % self._dictionary + def __eq__(self, other): + """would like to just do: + return set(other.items()).issuperset(self._dictionary.items()) + but it turns out that this doesn't work if you have another dict + nested in the values of your dicts""" + try: + for k,v in self._dictionary.items(): + if k not in other or other[k] != v: + return False + return True + except TypeError: # other is not iterable + return False + +class Event(object): def __init__(self, type, **kw): self.__dict__.update(kw) self.type = type + (self.subqueue, self.subtype) = type.split ("-", 1) def __str__(self): return '\n'.join([ str(type(self)) ] + format_event(self)) @@ -48,7 +73,7 @@ class Event: def format_event(event): ret = ['- type %s' % event.type] - for key in dir(event): + for key in sorted(dir(event)): if key != 'type' and not key.startswith('_'): ret.append('- %s: %s' % ( key, pprint.pformat(getattr(event, key)))) @@ -61,16 +86,17 @@ def format_event(event): class EventPattern: def __init__(self, type, **properties): self.type = type - self.predicate = lambda x: True + self.predicate = None if 'predicate' in properties: self.predicate = properties['predicate'] del properties['predicate'] self.properties = properties + (self.subqueue, self.subtype) = type.split ("-", 1) def __repr__(self): properties = dict(self.properties) - if self.predicate: + if self.predicate is not None: properties['predicate'] = self.predicate return '%s(%r, **%r)' % ( @@ -87,7 +113,7 @@ class EventPattern: except AttributeError: return False - if self.predicate(event): + if self.predicate is None or self.predicate(event): return True return False @@ -112,8 +138,8 @@ class BaseEventQueue: def __init__(self, timeout=None): self.verbose = False - self.past_events = [] self.forbidden_events = set() + self.event_queues = {} if timeout is None: self.timeout = 5 @@ -124,28 +150,14 @@ class BaseEventQueue: if self.verbose: print s - def log_event(self, event): - if self.verbose: - self.log('got event:') - - if self.verbose: - map(self.log, format_event(event)) - - def flush_past_events(self): - self.past_events = [] + def log_queues(self, queues): + self.log ("Waiting for event on: %s" % ", ".join(queues)) - def expect_racy(self, type, **kw): - pattern = EventPattern(type, **kw) - - for event in self.past_events: - if pattern.match(event): - self.log('past event handled') - map(self.log, format_event(event)) - self.log('') - self.past_events.remove(event) - return event + def log_event(self, event): + self.log('got event:') - return self.expect(type, **kw) + if self.verbose: + map(self.log, format_event(event)) def forbid_events(self, patterns): """ @@ -163,34 +175,86 @@ class BaseEventQueue: """ self.forbidden_events.difference_update(set(patterns)) + def unforbid_all(self): + """ + Remove all patterns from the set of forbidden events. + """ + self.forbidden_events.clear() + def _check_forbidden(self, event): for e in self.forbidden_events: if e.match(event): raise ForbiddenEventOccurred(event) def expect(self, type, **kw): + """ + Waits for an event matching the supplied pattern to occur, and returns + it. For example, to await a D-Bus signal with particular arguments: + + e = q.expect('dbus-signal', signal='Badgers', args=["foo", 42]) + """ pattern = EventPattern(type, **kw) + t = time.time() while True: - event = self.wait() - self.log_event(event) + try: + event = self.wait([pattern.subqueue]) + except TimeoutError: + self.log('timeout') + self.log('still expecting:') + self.log(' - %r' % pattern) + raise + self._check_forbidden(event) if pattern.match(event): - self.log('handled') + self.log('handled, took %0.3f ms' + % ((time.time() - t) * 1000.0) ) self.log('') return event - self.past_events.append(event) self.log('not handled') self.log('') def expect_many(self, *patterns): + """ + Waits for events matching all of the supplied EventPattern instances to + return, and returns a list of events in the same order as the patterns + they matched. After a pattern is successfully matched, it is not + considered for future events; if more than one unsatisfied pattern + matches an event, the first "wins". + + Note that the expected events may occur in any order. If you're + expecting a series of events in a particular order, use repeated calls + to expect() instead. + + This method is useful when you're awaiting a number of events which may + happen in any order. For instance, in telepathy-gabble, calling a D-Bus + method often causes a value to be returned immediately, as well as a + query to be sent to the server. Since these events may reach the test + in either order, the following is incorrect and will fail if the IQ + happens to reach the test first: + + ret = q.expect('dbus-return', method='Foo') + query = q.expect('stream-iq', query_ns=ns.FOO) + + The following would be correct: + + ret, query = q.expect_many( + EventPattern('dbus-return', method='Foo'), + EventPattern('stream-iq', query_ns=ns.FOO), + ) + """ ret = [None] * len(patterns) + t = time.time() while None in ret: try: - event = self.wait() + queues = set() + for i, pattern in enumerate(patterns): + if ret[i] is None: + queues.add(pattern.subqueue) + event = self.wait(queues) except TimeoutError: self.log('timeout') self.log('still expecting:') @@ -198,17 +262,16 @@ class BaseEventQueue: if ret[i] is None: self.log(' - %r' % pattern) raise - self.log_event(event) self._check_forbidden(event) for i, pattern in enumerate(patterns): if ret[i] is None and pattern.match(event): - self.log('handled') + self.log('handled, took %0.3f ms' + % ((time.time() - t) * 1000.0) ) self.log('') ret[i] = event break else: - self.past_events.append(event) self.log('not handled') self.log('') @@ -217,8 +280,7 @@ class BaseEventQueue: def demand(self, type, **kw): pattern = EventPattern(type, **kw) - event = self.wait() - self.log_event(event) + event = self.wait([pattern.subqueue]) if pattern.match(event): self.log('handled') @@ -228,19 +290,39 @@ class BaseEventQueue: self.log('not handled') raise RuntimeError('expected %r, got %r' % (pattern, event)) + def queues_available(self, queues): + if queues == None: + return self.event_queues.keys() + else: + available = self.event_queues.keys() + return filter(lambda x: x in available, queues) + + + def pop_next(self, queue): + events = self.event_queues[queue] + e = events.pop(0) + if not events: + self.event_queues.pop (queue) + return e + + def append(self, event): + self.log ("Adding to queue") + self.log_event (event) + self.event_queues[event.subqueue] = \ + self.event_queues.get(event.subqueue, []) + [event] + class IteratingEventQueue(BaseEventQueue): """Event queue that works by iterating the Twisted reactor.""" def __init__(self, timeout=None): BaseEventQueue.__init__(self, timeout) - self.events = [] self._dbus_method_impls = [] self._buses = [] # a message filter which will claim we handled everything self._dbus_dev_null = \ lambda bus, message: dbus.lowlevel.HANDLER_RESULT_HANDLED - def wait(self): + def wait(self, queues=None): stop = [False] def later(): @@ -248,21 +330,21 @@ class IteratingEventQueue(BaseEventQueue): delayed_call = reactor.callLater(self.timeout, later) - while (not self.events) and (not stop[0]): - reactor.iterate(0.1) + self.log_queues(queues) + + qa = self.queues_available(queues) + while not qa and (not stop[0]): + reactor.iterate(0.01) + qa = self.queues_available(queues) - if self.events: + if qa: delayed_call.cancel() - return self.events.pop(0) + e = self.pop_next (qa[0]) + self.log_event (e) + return e else: raise TimeoutError - def append(self, event): - self.events.append(event) - - # compatibility - handle_event = append - def add_dbus_method_impl(self, cb, bus=None, **kwargs): if bus is None: bus = self._buses[0] @@ -387,50 +469,74 @@ class IteratingEventQueue(BaseEventQueue): class TestEventQueue(BaseEventQueue): def __init__(self, events): BaseEventQueue.__init__(self) - self.events = events + for e in events: + self.append (e) - def wait(self): - if self.events: - return self.events.pop(0) + def wait(self, queues = None): + qa = self.queues_available(queues) + + if qa: + return self.pop_next (qa[0]) else: raise TimeoutError class EventQueueTest(unittest.TestCase): def test_expect(self): - queue = TestEventQueue([Event('foo'), Event('bar')]) - assert queue.expect('foo').type == 'foo' - assert queue.expect('bar').type == 'bar' + queue = TestEventQueue([Event('test-foo'), Event('test-bar')]) + assert queue.expect('test-foo').type == 'test-foo' + assert queue.expect('test-bar').type == 'test-bar' def test_expect_many(self): - queue = TestEventQueue([Event('foo'), Event('bar')]) + queue = TestEventQueue([Event('test-foo'), + Event('test-bar')]) bar, foo = queue.expect_many( - EventPattern('bar'), - EventPattern('foo')) - assert bar.type == 'bar' - assert foo.type == 'foo' + EventPattern('test-bar'), + EventPattern('test-foo')) + assert bar.type == 'test-bar' + assert foo.type == 'test-foo' def test_expect_many2(self): # Test that events are only matched against patterns that haven't yet # been matched. This tests a regression. - queue = TestEventQueue([Event('foo', x=1), Event('foo', x=2)]) + queue = TestEventQueue([Event('test-foo', x=1), Event('test-foo', x=2)]) foo1, foo2 = queue.expect_many( - EventPattern('foo'), - EventPattern('foo')) - assert foo1.type == 'foo' and foo1.x == 1 - assert foo2.type == 'foo' and foo2.x == 2 + EventPattern('test-foo'), + EventPattern('test-foo')) + assert foo1.type == 'test-foo' and foo1.x == 1 + assert foo2.type == 'test-foo' and foo2.x == 2 + + def test_expect_queueing(self): + queue = TestEventQueue([Event('foo-test', x=1), + Event('foo-test', x=2)]) + + queue.append(Event('bar-test', x=1)) + queue.append(Event('bar-test', x=2)) + + queue.append(Event('baz-test', x=1)) + queue.append(Event('baz-test', x=2)) + + for x in xrange(1,2): + e = queue.expect ('baz-test') + assertEquals (x, e.x) + + e = queue.expect ('bar-test') + assertEquals (x, e.x) + + e = queue.expect ('foo-test') + assertEquals (x, e.x) def test_timeout(self): queue = TestEventQueue([]) - self.assertRaises(TimeoutError, queue.expect, 'foo') + self.assertRaises(TimeoutError, queue.expect, 'test-foo') def test_demand(self): - queue = TestEventQueue([Event('foo'), Event('bar')]) - foo = queue.demand('foo') - assert foo.type == 'foo' + queue = TestEventQueue([Event('test-foo'), Event('test-bar')]) + foo = queue.demand('test-foo') + assert foo.type == 'test-foo' def test_demand_fail(self): - queue = TestEventQueue([Event('foo'), Event('bar')]) - self.assertRaises(RuntimeError, queue.demand, 'bar') + queue = TestEventQueue([Event('test-foo'), Event('test-bar')]) + self.assertRaises(RuntimeError, queue.demand, 'test-bar') def unwrap(x): """Hack to unwrap D-Bus values, so that they're easier to read when @@ -459,11 +565,11 @@ def call_async(test, proxy, method, *args, **kw): resulting method return/error.""" def reply_func(*ret): - test.handle_event(Event('dbus-return', method=method, + test.append(Event('dbus-return', method=method, value=unwrap(ret))) def error_func(err): - test.handle_event(Event('dbus-error', method=method, error=err, + test.append(Event('dbus-error', method=method, error=err, name=err.get_dbus_name(), message=str(err))) method_proxy = getattr(proxy, method) @@ -481,7 +587,7 @@ def sync_dbus(bus, q, proxy): q.expect('dbus-error', method='DummySyncDBus') class ProxyWrapper: - def __init__(self, object, default, others): + def __init__(self, object, default, others={}): self.object = object self.default_interface = dbus.Interface(object, default) self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE) @@ -500,6 +606,41 @@ class ProxyWrapper: return getattr(self.default_interface, name) +class ConnWrapper(ProxyWrapper): + def inspect_contact_sync(self, handle): + return self.inspect_contacts_sync([handle])[0] + + def inspect_contacts_sync(self, handles): + h2asv = self.Contacts.GetContactAttributes(handles, [], True) + ret = [] + for h in handles: + ret.append(h2asv[h][cs.ATTR_CONTACT_ID]) + return ret + + def get_contact_handle_sync(self, identifier): + return self.Contacts.GetContactByID(identifier, [])[0] + + def get_contact_handles_sync(self, ids): + return [self.get_contact_handle_sync(i) for i in ids] + +def wrap_connection(conn): + return ConnWrapper(conn, tp_name_prefix + '.Connection', + dict([ + (name, tp_name_prefix + '.Connection.Interface.' + name) + for name in ['Aliasing', 'Avatars', 'Capabilities', 'Contacts', + 'SimplePresence', 'Requests']] + + [('Peer', 'org.freedesktop.DBus.Peer'), + ('ContactCapabilities', cs.CONN_IFACE_CONTACT_CAPS), + ('ContactInfo', cs.CONN_IFACE_CONTACT_INFO), + ('Location', cs.CONN_IFACE_LOCATION), + ('Future', tp_name_prefix + '.Connection.FUTURE'), + ('MailNotification', cs.CONN_IFACE_MAIL_NOTIFICATION), + ('ContactList', cs.CONN_IFACE_CONTACT_LIST), + ('ContactGroups', cs.CONN_IFACE_CONTACT_GROUPS), + ('PowerSaving', cs.CONN_IFACE_POWER_SAVING), + ('Addressing', cs.CONN_IFACE_ADDRESSING), + ])) + def wrap_channel(chan, type_, extra=None): interfaces = { type_: tp_name_prefix + '.Channel.Type.' + type_, @@ -513,14 +654,26 @@ def wrap_channel(chan, type_, extra=None): return ProxyWrapper(chan, tp_name_prefix + '.Channel', interfaces) + +def wrap_content(chan, extra=None): + interfaces = { } + + if extra: + interfaces.update(dict([ + (name, tp_name_prefix + '.Call1.Content.Interface.' + name) + for name in extra])) + + return ProxyWrapper(chan, tp_name_prefix + '.Call1.Content', interfaces) + def make_connection(bus, event_func, name, proto, params): cm = bus.get_object( tp_name_prefix + '.ConnectionManager.%s' % name, - tp_path_prefix + '/ConnectionManager/%s' % name) + tp_path_prefix + '/ConnectionManager/%s' % name, + introspect=False) cm_iface = dbus.Interface(cm, tp_name_prefix + '.ConnectionManager') connection_name, connection_path = cm_iface.RequestConnection( - proto, params) + proto, dbus.Dictionary(params, signature='sv')) conn = wrap_connection(bus.get_object(connection_name, connection_path)) return conn @@ -540,7 +693,7 @@ class EventProtocol(Protocol): def dataReceived(self, data): if self.queue is not None: - self.queue.handle_event(Event('socket-data', protocol=self, + self.queue.append(Event('socket-data', protocol=self, data=data)) def sendData(self, data): @@ -552,7 +705,7 @@ class EventProtocol(Protocol): def connectionLost(self, reason=None): if self.queue is not None: - self.queue.handle_event(Event('socket-disconnected', protocol=self)) + self.queue.append(Event('socket-disconnected', protocol=self)) class EventProtocolFactory(Factory): def __init__(self, queue, block_reading=False): @@ -564,7 +717,7 @@ class EventProtocolFactory(Factory): def buildProtocol(self, addr): proto = self._create_protocol() - self.queue.handle_event(Event('socket-connected', protocol=proto)) + self.queue.append(Event('socket-connected', protocol=proto)) return proto class EventProtocolClientFactory(EventProtocolFactory, ClientFactory): @@ -572,7 +725,7 @@ class EventProtocolClientFactory(EventProtocolFactory, ClientFactory): def watch_tube_signals(q, tube): def got_signal_cb(*args, **kwargs): - q.handle_event(Event('tube-signal', + q.append(Event('tube-signal', path=kwargs['path'], signal=kwargs['member'], args=map(unwrap, args), @@ -590,6 +743,15 @@ def assertEquals(expected, value): raise AssertionError( "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value))) +def assertSameSets(expected, value): + exp_set = set(expected) + val_set = set(value) + + if exp_set != val_set: + raise AssertionError( + "expected contents:\n%s\ngot:\n%s" % ( + pretty(exp_set), pretty(val_set))) + def assertNotEquals(expected, value): if expected == value: raise AssertionError( @@ -624,15 +786,11 @@ def assertFlagsUnset(flags, value): "expected none of flags %u, but %u are set in %u" % ( flags, masked, value)) -def assertSameSets(expected, value): - exp_set = set(expected) - val_set = set(value) - - if exp_set != val_set: +def assertDBusError(name, error): + if error.get_dbus_name() != name: raise AssertionError( - "expected contents:\n%s\ngot:\n%s" % ( - pretty(exp_set), pretty(val_set))) - + "expected DBus error named:\n %s\ngot:\n %s\n(with message: %s)" + % (name, error.get_dbus_name(), error.message)) def install_colourer(): def red(s): @@ -652,14 +810,26 @@ def install_colourer(): self.patterns = patterns def write(self, s): - f = self.patterns.get(s, lambda x: x) - self.fh.write(f(s)) + for p, f in self.patterns.items(): + if s.startswith(p): + self.fh.write(f(p) + s[len(p):]) + return + + self.fh.write(s) sys.stdout = Colourer(sys.stdout, patterns) return sys.stdout +# this is just to shut up unittest. +class DummyStream(object): + def write(self, s): + if 'CHECK_TWISTED_VERBOSE' in os.environ: + print s, + def flush(self): + pass if __name__ == '__main__': - unittest.main() - + stream = DummyStream() + runner = unittest.TextTestRunner(stream=stream) + unittest.main(testRunner=runner) |