summaryrefslogtreecommitdiff
path: root/tests/twisted/servicetest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/twisted/servicetest.py')
-rw-r--r--tests/twisted/servicetest.py397
1 files changed, 194 insertions, 203 deletions
diff --git a/tests/twisted/servicetest.py b/tests/twisted/servicetest.py
index b8485e1..bb00d7a 100644
--- a/tests/twisted/servicetest.py
+++ b/tests/twisted/servicetest.py
@@ -6,50 +6,20 @@ Infrastructure code for testing connection managers.
from twisted.internet import glib2reactor
from twisted.internet.protocol import Protocol, Factory, ClientFactory
glib2reactor.install()
+import sys
import pprint
-import traceback
import unittest
import dbus.glib
from twisted.internet import reactor
+import constants as cs
+
tp_name_prefix = 'org.freedesktop.Telepathy'
tp_path_prefix = '/org/freedesktop/Telepathy'
-class TryNextHandler(Exception):
- pass
-
-def lazy(func):
- def handler(event, data):
- if func(event, data):
- return True
- else:
- raise TryNextHandler()
- handler.__name__ = func.__name__
- return handler
-
-def match(type, **kw):
- def decorate(func):
- def handler(event, data, *extra, **extra_kw):
- if event.type != type:
- return False
-
- for key, value in kw.iteritems():
- if not hasattr(event, key):
- return False
-
- if getattr(event, key) != value:
- return False
-
- return func(event, data, *extra, **extra_kw)
-
- handler.__name__ = func.__name__
- return handler
-
- return decorate
-
class Event:
def __init__(self, type, **kw):
self.__dict__.update(kw)
@@ -68,118 +38,24 @@ def format_event(event):
return ret
-class EventTest:
- """Somewhat odd event dispatcher for asynchronous tests.
-
- Callbacks are kept in a queue. Incoming events are passed to the first
- callback. If the callback returns True, the callback is removed. If the
- callback raises AssertionError, the test fails. If there are no more
- callbacks, the test passes. The reactor is stopped when the test passes.
- """
-
- def __init__(self):
- self.queue = []
- self.data = {'test': self}
- self.timeout_delayed_call = reactor.callLater(5, self.timeout_cb)
- #self.verbose = True
- self.verbose = False
- # ugh
- self.stopping = False
-
- def timeout_cb(self):
- print 'timed out waiting for events'
- print self.queue[0]
- self.fail()
-
- def fail(self):
- # ugh; better way to stop the reactor and exit(1)?
- import os
- os._exit(1)
-
- def expect(self, f):
- self.queue.append(f)
-
- def log(self, s):
- if self.verbose:
- print s
-
- def try_stop(self):
- if self.stopping:
- return True
-
- if not self.queue:
- self.log('no handlers left; stopping')
- self.stopping = True
- reactor.stop()
- return True
-
- return False
-
- def call_handlers(self, event):
- self.log('trying %r' % self.queue[0])
- handler = self.queue.pop(0)
-
- try:
- ret = handler(event, self.data)
- if not ret:
- self.queue.insert(0, handler)
- except TryNextHandler, e:
- if self.queue:
- ret = self.call_handlers(event)
- else:
- ret = False
- self.queue.insert(0, handler)
-
- return ret
-
- def handle_event(self, event):
- if self.try_stop():
- return
-
- self.log('got event:')
- self.log('- type: %s' % event.type)
- map(self.log, format_event(event))
-
- try:
- ret = self.call_handlers(event)
- except SystemExit, e:
- if e.code:
- print "Unsuccessful exit:", e
- self.fail()
- else:
- self.queue[:] = []
- ret = True
- except AssertionError, e:
- print 'test failed:'
- traceback.print_exc()
- self.fail()
- except (Exception, KeyboardInterrupt), e:
- print 'error in handler:'
- traceback.print_exc()
- self.fail()
-
- if ret not in (True, False):
- print ("warning: %s() returned something other than True or False"
- % self.queue[0].__name__)
-
- if ret:
- self.timeout_delayed_call.reset(5)
- self.log('event handled')
- else:
- self.log('event not handled')
-
- self.log('')
- self.try_stop()
-
class EventPattern:
def __init__(self, type, **properties):
self.type = type
- self.predicate = lambda x: True
+ self.predicate = None
if 'predicate' in properties:
self.predicate = properties['predicate']
del properties['predicate']
self.properties = properties
+ def __repr__(self):
+ properties = dict(self.properties)
+
+ if self.predicate is not None:
+ properties['predicate'] = self.predicate
+
+ return '%s(%r, **%r)' % (
+ self.__class__.__name__, self.type, properties)
+
def match(self, event):
if event.type != self.type:
return False
@@ -191,7 +67,7 @@ class EventPattern:
except AttributeError:
return False
- if self.predicate(event):
+ if self.predicate is None or self.predicate(event):
return True
return False
@@ -208,7 +84,7 @@ class BaseEventQueue:
def __init__(self, timeout=None):
self.verbose = False
- self.past_events = []
+ self.forbidden_events = set()
if timeout is None:
self.timeout = 5
@@ -219,36 +95,50 @@ class BaseEventQueue:
if self.verbose:
print s
- def flush_past_events(self):
- self.past_events = []
-
- def expect_racy(self, type, **kw):
- pattern = EventPattern(type, **kw)
+ def log_event(self, event):
+ if self.verbose:
+ self.log('got event:')
- for event in self.past_events:
- if pattern.match(event):
- self.log('past event handled')
+ if self.verbose:
map(self.log, format_event(event))
- self.log('')
- self.past_events.remove(event)
- return event
- return self.expect(type, **kw)
+ def forbid_events(self, patterns):
+ """
+ Add patterns (an iterable of EventPattern) to the set of forbidden
+ events. If a forbidden event occurs during an expect or expect_many,
+ the test will fail.
+ """
+ self.forbidden_events.update(set(patterns))
+
+ def unforbid_events(self, patterns):
+ """
+ Remove 'patterns' (an iterable of EventPattern) from the set of
+ forbidden events. These must be the same EventPattern pointers that
+ were passed to forbid_events.
+ """
+ self.forbidden_events.difference_update(set(patterns))
+
+ def _check_forbidden(self, event):
+ for e in self.forbidden_events:
+ if e.match(event):
+ print "forbidden event occurred:"
+ for x in format_event(event):
+ print x
+ assert False
def expect(self, type, **kw):
pattern = EventPattern(type, **kw)
while True:
event = self.wait()
- self.log('got event:')
- map(self.log, format_event(event))
+ self.log_event(event)
+ self._check_forbidden(event)
if pattern.match(event):
self.log('handled')
self.log('')
return event
- self.past_events.append(event)
self.log('not handled')
self.log('')
@@ -256,18 +146,25 @@ class BaseEventQueue:
ret = [None] * len(patterns)
while None in ret:
- event = self.wait()
- self.log('got event:')
- map(self.log, format_event(event))
+ try:
+ event = self.wait()
+ except TimeoutError:
+ self.log('timeout')
+ self.log('still expecting:')
+ for i, pattern in enumerate(patterns):
+ if ret[i] is None:
+ self.log(' - %r' % pattern)
+ raise
+ self.log_event(event)
+ self._check_forbidden(event)
for i, pattern in enumerate(patterns):
- if pattern.match(event):
+ if ret[i] is None and pattern.match(event):
self.log('handled')
self.log('')
ret[i] = event
break
else:
- self.past_events.append(event)
self.log('not handled')
self.log('')
@@ -277,8 +174,7 @@ class BaseEventQueue:
pattern = EventPattern(type, **kw)
event = self.wait()
- self.log('got event:')
- map(self.log, format_event(event))
+ self.log_event(event)
if pattern.match(event):
self.log('handled')
@@ -343,6 +239,16 @@ class EventQueueTest(unittest.TestCase):
assert bar.type == 'bar'
assert foo.type == 'foo'
+ def test_expect_many2(self):
+ # Test that events are only matched against patterns that haven't yet
+ # been matched. This tests a regression.
+ queue = TestEventQueue([Event('foo', x=1), Event('foo', x=2)])
+ foo1, foo2 = queue.expect_many(
+ EventPattern('foo'),
+ EventPattern('foo'))
+ assert foo1.type == 'foo' and foo1.x == 1
+ assert foo2.type == 'foo' and foo2.x == 2
+
def test_timeout(self):
queue = TestEventQueue([])
self.assertRaises(TimeoutError, queue.expect, 'foo')
@@ -369,7 +275,10 @@ def unwrap(x):
if isinstance(x, dict):
return dict([(unwrap(k), unwrap(v)) for k, v in x.iteritems()])
- for t in [unicode, str, long, int, float, bool]:
+ if isinstance(x, dbus.Boolean):
+ return bool(x)
+
+ for t in [unicode, str, long, int, float]:
if isinstance(x, t):
return t(x)
@@ -384,7 +293,8 @@ def call_async(test, proxy, method, *args, **kw):
value=unwrap(ret)))
def error_func(err):
- test.handle_event(Event('dbus-error', method=method, error=err))
+ test.handle_event(Event('dbus-error', method=method, error=err,
+ name=err.get_dbus_name(), message=str(err)))
method_proxy = getattr(proxy, method)
kw.update({'reply_handler': reply_func, 'error_handler': error_func})
@@ -392,14 +302,20 @@ def call_async(test, proxy, method, *args, **kw):
def sync_dbus(bus, q, conn):
# Dummy D-Bus method call
- call_async(q, conn, "InspectHandles", 1, [])
-
- event = q.expect('dbus-return', method='InspectHandles')
+ # This won't do the right thing unless the proxy has a unique name.
+ assert conn.object.bus_name.startswith(':')
+ root_object = bus.get_object(conn.object.bus_name, '/')
+ call_async(
+ q, dbus.Interface(root_object, 'org.freedesktop.DBus.Peer'), 'Ping')
+ q.expect('dbus-return', method='Ping')
class ProxyWrapper:
def __init__(self, object, default, others):
self.object = object
self.default_interface = dbus.Interface(object, default)
+ self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE)
+ self.TpProperties = \
+ dbus.Interface(object, tp_name_prefix + '.Properties')
self.interfaces = dict([
(name, dbus.Interface(object, iface))
for name, iface in others.iteritems()])
@@ -413,6 +329,33 @@ class ProxyWrapper:
return getattr(self.default_interface, name)
+def wrap_connection(conn):
+ return ProxyWrapper(conn, tp_name_prefix + '.Connection',
+ dict([
+ (name, tp_name_prefix + '.Connection.Interface.' + name)
+ for name in ['Aliasing', 'Avatars', 'Capabilities', 'Contacts',
+ 'Presence', 'SimplePresence', 'Requests']] +
+ [('Peer', 'org.freedesktop.DBus.Peer'),
+ ('ContactCapabilities', cs.CONN_IFACE_CONTACT_CAPS),
+ ('ContactInfo', cs.CONN_IFACE_CONTACT_INFO),
+ ('Location', cs.CONN_IFACE_LOCATION),
+ ('Future', tp_name_prefix + '.Connection.FUTURE'),
+ ('MailNotification', cs.CONN_IFACE_MAIL_NOTIFICATION),
+ ]))
+
+def wrap_channel(chan, type_, extra=None):
+ interfaces = {
+ type_: tp_name_prefix + '.Channel.Type.' + type_,
+ 'Group': tp_name_prefix + '.Channel.Interface.Group',
+ }
+
+ if extra:
+ interfaces.update(dict([
+ (name, tp_name_prefix + '.Channel.Interface.' + name)
+ for name in extra]))
+
+ return ProxyWrapper(chan, tp_name_prefix + '.Channel', interfaces)
+
def make_connection(bus, event_func, name, proto, params):
cm = bus.get_object(
tp_name_prefix + '.ConnectionManager.%s' % name,
@@ -421,29 +364,7 @@ def make_connection(bus, event_func, name, proto, params):
connection_name, connection_path = cm_iface.RequestConnection(
proto, params)
- conn = bus.get_object(connection_name, connection_path)
- conn = ProxyWrapper(conn, tp_name_prefix + '.Connection',
- dict([
- (name, tp_name_prefix + '.Connection.Interface.' + name)
- for name in ['Aliasing', 'Avatars', 'Capabilities', 'Contacts',
- 'Presence', 'SimplePresence', 'Requests']] +
- [('Peer', 'org.freedesktop.DBus.Peer')]))
-
- bus.add_signal_receiver(
- lambda *args, **kw:
- event_func(
- Event('dbus-signal',
- path=unwrap(kw['path'])[len(tp_path_prefix):],
- signal=kw['member'], args=map(unwrap, args),
- interface=kw['interface'])),
- None, # signal name
- None, # interface
- cm._named_service,
- path_keyword='path',
- member_keyword='member',
- interface_keyword='interface',
- byte_arrays=True
- )
+ conn = wrap_connection(bus.get_object(connection_name, connection_path))
return conn
@@ -453,20 +374,12 @@ def make_channel_proxy(conn, path, iface):
chan = dbus.Interface(chan, tp_name_prefix + '.' + iface)
return chan
-def load_event_handlers():
- path, _, _, _ = traceback.extract_stack()[0]
- import compiler
- import __main__
- ast = compiler.parseFile(path)
- return [
- getattr(__main__, node.name)
- for node in ast.node.asList()
- if node.__class__ == compiler.ast.Function and
- node.name.startswith('expect_')]
-
+# block_reading can be used if the test want to choose when we start to read
+# data from the socket.
class EventProtocol(Protocol):
- def __init__(self, queue=None):
+ def __init__(self, queue=None, block_reading=False):
self.queue = queue
+ self.block_reading = block_reading
def dataReceived(self, data):
if self.queue is not None:
@@ -476,12 +389,24 @@ class EventProtocol(Protocol):
def sendData(self, data):
self.transport.write(data)
+ def connectionMade(self):
+ if self.block_reading:
+ self.transport.stopReading()
+
+ def connectionLost(self, reason=None):
+ if self.queue is not None:
+ self.queue.handle_event(Event('socket-disconnected', protocol=self))
+
class EventProtocolFactory(Factory):
- def __init__(self, queue):
+ def __init__(self, queue, block_reading=False):
self.queue = queue
+ self.block_reading = block_reading
+
+ def _create_protocol(self):
+ return EventProtocol(self.queue, self.block_reading)
def buildProtocol(self, addr):
- proto = EventProtocol(self.queue)
+ proto = self._create_protocol()
self.queue.handle_event(Event('socket-connected', protocol=proto))
return proto
@@ -500,6 +425,72 @@ def watch_tube_signals(q, tube):
path_keyword='path', member_keyword='member',
byte_arrays=True)
+def pretty(x):
+ return pprint.pformat(unwrap(x))
+
+def assertEquals(expected, value):
+ if expected != value:
+ raise AssertionError(
+ "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value)))
+
+def assertNotEquals(expected, value):
+ if expected == value:
+ raise AssertionError(
+ "expected something other than:\n%s" % pretty(value))
+
+def assertContains(element, value):
+ if element not in value:
+ raise AssertionError(
+ "expected:\n%s\nin:\n%s" % (pretty(element), pretty(value)))
+
+def assertDoesNotContain(element, value):
+ if element in value:
+ raise AssertionError(
+ "expected:\n%s\nnot in:\n%s" % (pretty(element), pretty(value)))
+
+def assertLength(length, value):
+ if len(value) != length:
+ raise AssertionError("expected: length %d, got length %d:\n%s" % (
+ length, len(value), pretty(value)))
+
+def assertFlagsSet(flags, value):
+ masked = value & flags
+ if masked != flags:
+ raise AssertionError(
+ "expected flags %u, of which only %u are set in %u" % (
+ flags, masked, value))
+
+def assertFlagsUnset(flags, value):
+ masked = value & flags
+ if masked != 0:
+ raise AssertionError(
+ "expected none of flags %u, but %u are set in %u" % (
+ flags, masked, value))
+
+def install_colourer():
+ def red(s):
+ return '\x1b[31m%s\x1b[0m' % s
+
+ def green(s):
+ return '\x1b[32m%s\x1b[0m' % s
+
+ patterns = {
+ 'handled': green,
+ 'not handled': red,
+ }
+
+ class Colourer:
+ def __init__(self, fh, patterns):
+ self.fh = fh
+ self.patterns = patterns
+
+ def write(self, s):
+ f = self.patterns.get(s, lambda x: x)
+ self.fh.write(f(s))
+
+ sys.stdout = Colourer(sys.stdout, patterns)
+ return sys.stdout
+
if __name__ == '__main__':
unittest.main()