diff options
author | Vlastimil Zíma <vlastimil.zima@nic.cz> | 2018-01-26 13:02:02 +0100 |
---|---|---|
committer | Vlastimil Zíma <vlastimil.zima@nic.cz> | 2018-01-26 14:39:13 +0100 |
commit | 7e2c019bb4d4e3593398314f84c267787a8a4d77 (patch) | |
tree | c1aa82d2f3751ad84134ad91b0fd234db301b908 /openid | |
parent | bf03ab4f440d76b59330ad423a8213a3b25fbceb (diff) | |
download | openid-7e2c019bb4d4e3593398314f84c267787a8a4d77.tar.gz |
Use testfixtures to replace CatchLogs
Diffstat (limited to 'openid')
-rw-r--r-- | openid/test/support.py | 49 | ||||
-rw-r--r-- | openid/test/test_association_response.py | 22 | ||||
-rw-r--r-- | openid/test/test_consumer.py | 103 | ||||
-rw-r--r-- | openid/test/test_kvform.py | 52 | ||||
-rw-r--r-- | openid/test/test_negotiation.py | 111 | ||||
-rw-r--r-- | openid/test/test_rpverify.py | 26 | ||||
-rw-r--r-- | openid/test/test_server.py | 84 | ||||
-rw-r--r-- | openid/test/test_verifydisco.py | 91 |
8 files changed, 246 insertions, 292 deletions
diff --git a/openid/test/support.py b/openid/test/support.py index c2d45ea..16f54c7 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -1,21 +1,6 @@ -import logging -from logging.handlers import BufferingHandler - from openid import message -class TestHandler(BufferingHandler): - def __init__(self, messages): - BufferingHandler.__init__(self, 0) - self.messages = messages - - def shouldFlush(self): - return False - - def emit(self, record): - self.messages.append(record) - - class OpenIDTestMixin(object): def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): if ns is None: @@ -33,37 +18,3 @@ class OpenIDTestMixin(object): actual = msg.getArg(ns, key) error_message = 'openid.%s unexpectedly present: %s' % (key, actual) self.assertIsNone(actual, error_message) - - -class CatchLogs(object): - def setUp(self): - self.messages = [] - root_logger = logging.getLogger() - self.old_log_level = root_logger.getEffectiveLevel() - root_logger.setLevel(logging.DEBUG) - - self.handler = TestHandler(self.messages) - formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") - self.handler.setFormatter(formatter) - root_logger.addHandler(self.handler) - - def tearDown(self): - root_logger = logging.getLogger() - root_logger.removeHandler(self.handler) - root_logger.setLevel(self.old_log_level) - - def failUnlessLogMatches(self, *prefixes): - """ - Check that the log messages contained in self.messages have - prefixes in *prefixes. Raise AssertionError if not, or if the - number of prefixes is different than the number of log - messages. - """ - messages = [r.getMessage() for r in self.messages] - assert len(prefixes) == len(messages), "Expected log prefixes %r, got %r" % (prefixes, messages) - - for prefix, msg in zip(prefixes, messages): - assert msg.startswith(prefix), "Expected log prefixes %r, got %r" % (prefixes, messages) - - def failUnlessLogEmpty(self): - self.failUnlessLogMatches() diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 9bac3e2..62b3175 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -5,12 +5,13 @@ this works for now. """ import unittest +from testfixtures import LogCapture + from openid.consumer.consumer import GenericConsumer, ProtocolError from openid.consumer.discover import OPENID_1_1_TYPE, OPENID_2_0_TYPE, OpenIDServiceEndpoint from openid.message import OPENID2_NS, OPENID_NS, Message from openid.server.server import DiffieHellmanSHA1ServerSession from openid.store import memstore -from openid.test.test_consumer import CatchLogs # Some values we can use for convenience (see mkAssocResponse) association_response_values = { @@ -33,9 +34,8 @@ def mkAssocResponse(*keys): return Message.fromOpenIDArgs(args) -class BaseAssocTest(CatchLogs, unittest.TestCase): +class BaseAssocTest(unittest.TestCase): def setUp(self): - CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.consumer = GenericConsumer(self.store) self.endpoint = OpenIDServiceEndpoint() @@ -175,8 +175,9 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): """ def test(self): - self._doTest(expected_session_type, session_type_value) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self._doTest(expected_session_type, session_type_value) + self.assertEqual(logbook.records, []) return test @@ -209,11 +210,12 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): # This one's different because it expects log messages def test_explicitNoEncryption(self): - self._doTest( - session_type_value='no-encryption', - expected_session_type='no-encryption', - ) - self.failUnlessLogMatches('OpenID server sent "no-encryption"') + with LogCapture() as logbook: + self._doTest( + session_type_value='no-encryption', + expected_session_type='no-encryption', + ) + logbook.check(('openid.consumer.consumer', 'WARNING', 'OpenID server sent "no-encryption" for OpenID 1.X')) test_dhSHA1 = mkTest( session_type_value='DH-SHA1', diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 169a423..a427dbf 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -3,6 +3,8 @@ import time import unittest import urlparse +from testfixtures import LogCapture, StringComparison + from openid import association, cryptutil, fetchers, kvform, oidutil from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, @@ -20,8 +22,6 @@ from openid.store.nonce import mkNonce, split as splitNonce from openid.yadis.discover import DiscoveryFailure from openid.yadis.manager import Discovery -from .support import CatchLogs - assocs = [ ('another 20-byte key.', 'Snarky'), ('\x00' * 20, 'Zeros'), @@ -212,13 +212,12 @@ consumer_url = 'http://consumer.example.com/' https_server_url = 'https://server.example.com/' -class TestSuccess(unittest.TestCase, CatchLogs): +class TestSuccess(unittest.TestCase): server_url = http_server_url user_url = 'http://www.example.com/user.html' delegate_url = 'http://consumer.example.com/user' def setUp(self): - CatchLogs.setUp(self) self.links = '<link rel="openid.server" href="%s" />' % ( self.server_url,) @@ -226,9 +225,6 @@ class TestSuccess(unittest.TestCase, CatchLogs): '<link rel="openid.delegate" href="%s" />') % ( self.server_url, self.delegate_url) - def tearDown(self): - CatchLogs.tearDown(self) - def test_nodelegate(self): _test_success(self.server_url, self.user_url, self.user_url, self.links) @@ -262,12 +258,10 @@ class TestConstruct(unittest.TestCase): self.assertRaises(TypeError, GenericConsumer) -class TestIdRes(unittest.TestCase, CatchLogs): +class TestIdRes(unittest.TestCase): consumer_class = GenericConsumer def setUp(self): - CatchLogs.setUp(self) - self.store = memstore.MemoryStore() self.consumer = self.consumer_class(self.store) self.return_to = "nonny" @@ -464,19 +458,18 @@ class TestComplete(TestIdRes): }) self.consumer.store = GoodAssocStore() - self.assertRaises(VerifiedError, self.consumer.complete, message, self.endpoint) - - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + with LogCapture() as logbook: + self.assertRaises(VerifiedError, self.consumer.complete, message, self.endpoint) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) -class TestCompleteMissingSig(unittest.TestCase, CatchLogs): +class TestCompleteMissingSig(unittest.TestCase): def setUp(self): self.store = GoodAssocStore() self.consumer = GenericConsumer(self.store) self.server_url = "http://idp.unittest/" - CatchLogs.setUp(self) claimed_id = 'bogus.claimed' @@ -498,9 +491,6 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): self.endpoint.claimed_id = claimed_id self.consumer._checkReturnTo = lambda unused1, unused2: True - def tearDown(self): - CatchLogs.tearDown(self) - def test_idResMissingNoSigs(self): def _vrfy(resp_msg, endpoint=None): return endpoint @@ -542,14 +532,10 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): self.fail("Non-successful response: %s" % (response,)) -class TestCheckAuthResponse(TestIdRes, CatchLogs): +class TestCheckAuthResponse(TestIdRes): def setUp(self): - CatchLogs.setUp(self) TestIdRes.setUp(self) - def tearDown(self): - CatchLogs.tearDown(self) - def _createAssoc(self): issued = time.time() lifetime = 1000 @@ -602,11 +588,10 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): 'is_valid': 'true', 'invalidate_handle': 'missing', }) - r = self.consumer._processCheckAuthResponse(response, self.server_url) + with LogCapture() as logbook: + r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) - self.failUnlessLogMatches( - 'Received "invalidate_handle"' - ) + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Received "invalidate_handle" from .*'))) def test_invalidateMissing_noStore(self): """invalidate_handle with a handle that is not present""" @@ -615,11 +600,11 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): 'invalidate_handle': 'missing', }) self.consumer.store = None - r = self.consumer._processCheckAuthResponse(response, self.server_url) + with LogCapture() as logbook: + r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) - self.failUnlessLogMatches( - 'Received "invalidate_handle"', - 'Unexpectedly got invalidate_handle without a store') + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Received "invalidate_handle" from .*')), + ('openid.consumer.consumer', 'ERROR', 'Unexpectedly got invalidate_handle without a store!')) def test_invalidatePresent(self): """invalidate_handle with a handle that exists @@ -813,51 +798,52 @@ class CheckAuthHappened(Exception): pass -class CheckNonceVerifyTest(TestIdRes, CatchLogs): +class CheckNonceVerifyTest(TestIdRes): def setUp(self): - CatchLogs.setUp(self) TestIdRes.setUp(self) self.consumer.openid1_nonce_query_arg_name = 'nonce' - def tearDown(self): - CatchLogs.tearDown(self) - def test_openid1Success(self): """use consumer-generated nonce""" nonce_value = mkNonce() self.return_to = 'http://rt.unittest/?nonce=%s' % (nonce_value,) self.response = Message.fromOpenIDArgs({'return_to': self.return_to}) self.response.setArg(BARE_NS, 'nonce', nonce_value) - self.consumer._idResCheckNonce(self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.consumer._idResCheckNonce(self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_openid1Missing(self): """use consumer-generated nonce""" self.response = Message.fromOpenIDArgs({}) - n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint) + with LogCapture() as logbook: + n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint) self.assertIsNone(n) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_consumerNonceOpenID2(self): """OpenID 2 does not use consumer-generated nonce""" self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) self.response = Message.fromOpenIDArgs( {'return_to': self.return_to, 'ns': OPENID2_NS}) - self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_serverNonce(self): """use server-generated nonce""" self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': mkNonce()}) - self.consumer._idResCheckNonce(self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.consumer._idResCheckNonce(self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" self.response = Message.fromOpenIDArgs( {'ns': OPENID1_NS, 'return_to': 'http://return.to/', 'response_nonce': mkNonce()}) - self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_badNonce(self): """remove the nonce from the store @@ -880,8 +866,9 @@ class CheckNonceVerifyTest(TestIdRes, CatchLogs): """When there is no store, checking the nonce succeeds""" self.consumer.store = None self.response = Message.fromOpenIDArgs({'response_nonce': mkNonce(), 'ns': OPENID2_NS}) - self.consumer._idResCheckNonce(self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.consumer._idResCheckNonce(self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_tamperedNonce(self): """Malformed nonce""" @@ -905,12 +892,11 @@ class CheckAuthDetectingConsumer(GenericConsumer): return True -class TestCheckAuthTriggered(TestIdRes, CatchLogs): +class TestCheckAuthTriggered(TestIdRes): consumer_class = CheckAuthDetectingConsumer def setUp(self): TestIdRes.setUp(self) - CatchLogs.setUp(self) self.disableDiscoveryVerification() def test_checkAuthTriggered(self): @@ -1156,11 +1142,10 @@ class BadArgCheckingConsumer(GenericConsumer): return None -class TestCheckAuth(unittest.TestCase, CatchLogs): +class TestCheckAuth(unittest.TestCase): consumer_class = GenericConsumer def setUp(self): - CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.consumer = self.consumer_class(self.store) @@ -1170,7 +1155,6 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): fetchers.setDefaultFetcher(self.fetcher) def tearDown(self): - CatchLogs.tearDown(self) fetchers.setDefaultFetcher(self._orig_fetcher, wrap_exceptions=False) def test_error(self): @@ -1178,10 +1162,12 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') query = {'openid.signed': 'stuff', 'openid.stuff': 'a value'} - r = self.consumer._checkAuth(Message.fromPostArgs(query), - http_server_url) + with LogCapture() as logbook: + r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) self.assertFalse(r) - self.assertTrue(self.messages) + logbook.check(('openid.consumer.consumer', 'INFO', 'Using OpenID check_authentication'), + ('openid.consumer.consumer', 'INFO', 'stuff'), + ('openid.consumer.consumer', 'ERROR', StringComparison('check_authentication failed: .*: 404'))) def test_bad_args(self): query = { @@ -1236,11 +1222,10 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): self.assertEqual(car.toPostArgs(), expected_args) -class TestFetchAssoc(unittest.TestCase, CatchLogs): +class TestFetchAssoc(unittest.TestCase): consumer_class = GenericConsumer def setUp(self): - CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.fetcher = MockFetcher() fetchers.setDefaultFetcher(self.fetcher) diff --git a/openid/test/test_kvform.py b/openid/test/test_kvform.py index 09532db..187629a 100644 --- a/openid/test/test_kvform.py +++ b/openid/test/test_kvform.py @@ -1,39 +1,23 @@ import unittest -from openid import kvform -from openid.test.support import CatchLogs - - -class KVBaseTest(unittest.TestCase, CatchLogs): - - def checkWarnings(self, num_warnings, msg=None): - full_msg = 'Invalid number of warnings {} != {}'.format(num_warnings, len(self.messages)) - if msg is not None: - full_msg = full_msg + ' ' + msg - self.assertEqual(num_warnings, len(self.messages), full_msg) +from testfixtures import LogCapture - def setUp(self): - CatchLogs.setUp(self) - - def tearDown(self): - CatchLogs.tearDown(self) +from openid import kvform -class KVDictTest(KVBaseTest): +class KVDictTest(unittest.TestCase): def runTest(self): for kv_data, result, expected_warnings in kvdict_cases: - # Clean captrured messages - del self.messages[:] - # Convert KVForm to dict - d = kvform.kvToDict(kv_data) + with LogCapture() as logbook: + d = kvform.kvToDict(kv_data) # make sure it parses to expected dict self.assertEqual(d, result) # Check to make sure we got the expected number of warnings - self.checkWarnings(expected_warnings, msg='kvToDict({!r})'.format(kv_data)) + self.assertEqual(len(logbook.records), expected_warnings) # Convert back to KVForm and round-trip back to dict to make # sure that *** dict -> kv -> dict is identity. *** @@ -42,7 +26,7 @@ class KVDictTest(KVBaseTest): self.assertEqual(d, d2) -class KVSeqTest(KVBaseTest): +class KVSeqTest(unittest.TestCase): def cleanSeq(self, seq): """Create a new sequence by stripping whitespace from start @@ -58,11 +42,9 @@ class KVSeqTest(KVBaseTest): def runTest(self): for kv_data, result, expected_warnings in kvseq_cases: - # Clean captrured messages - del self.messages[:] - # seq serializes to expected kvform - actual = kvform.seqToKV(kv_data) + with LogCapture() as logbook: + actual = kvform.seqToKV(kv_data) self.assertEqual(actual, result) self.assertIsInstance(actual, str) @@ -73,7 +55,8 @@ class KVSeqTest(KVBaseTest): clean_seq = self.cleanSeq(seq) self.assertEqual(seq, clean_seq) - self.checkWarnings(expected_warnings) + self.assertEqual(len(logbook.records), expected_warnings, + "Invalid warnings for {}: {}".format(kv_data, [r.getMessage() for r in logbook.records])) kvdict_cases = [ @@ -119,16 +102,16 @@ kvseq_cases = [ ([('openid', 'useful'), ('a', 'b')], 'openid:useful\na:b\n', 0), # Warnings about leading whitespace - ([(' openid', 'useful'), ('a', 'b')], ' openid:useful\na:b\n', 2), + ([(' openid', 'useful'), ('a', 'b')], ' openid:useful\na:b\n', 1), # Warnings about leading and trailing whitespace ([(' openid ', ' useful '), - (' a ', ' b ')], ' openid : useful \n a : b \n', 8), + (' a ', ' b ')], ' openid : useful \n a : b \n', 4), # warnings about leading and trailing whitespace, but not about # internal whitespace. ([(' open id ', ' use ful '), - (' a ', ' b ')], ' open id : use ful \n a : b \n', 8), + (' a ', ' b ')], ' open id : use ful \n a : b \n', 4), ([(u'foo', 'bar')], 'foo:bar\n', 0), ] @@ -150,10 +133,11 @@ class KVExcTest(unittest.TestCase): self.assertRaises(ValueError, kvform.seqToKV, kv_data) -class GeneralTest(KVBaseTest): +class GeneralTest(unittest.TestCase): kvform = '<None>' def test_convert(self): - result = kvform.seqToKV([(1, 1)]) + with LogCapture() as logbook: + result = kvform.seqToKV([(1, 1)]) self.assertEqual(result, '1:1\n') - self.checkWarnings(2) + self.assertEqual(len(logbook.records), 2) diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 71ff200..6c4cf1f 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,12 +1,12 @@ import unittest +from testfixtures import LogCapture, StringComparison + from openid import association from openid.consumer.consumer import GenericConsumer, ServerError from openid.consumer.discover import OPENID_2_0_TYPE, OpenIDServiceEndpoint from openid.message import OPENID1_NS, OPENID_NS, Message -from .support import CatchLogs - class ErrorRaisingConsumer(GenericConsumer): """ @@ -29,14 +29,13 @@ class ErrorRaisingConsumer(GenericConsumer): return m -class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): +class TestOpenID2SessionNegotiation(unittest.TestCase): """ Test the session type negotiation behavior of an OpenID 2 consumer. """ def setUp(self): - CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) self.endpoint = OpenIDServiceEndpoint() @@ -49,8 +48,10 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): server error or is otherwise undecipherable. """ self.consumer.return_messages = [Message(self.endpoint.preferredNamespace())] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testEmptyAssocType(self): """ @@ -64,11 +65,11 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'new-session-type') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Unsupported association type', - 'Server responded with unsupported association ' + - 'session but did not supply a fallback.') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + no_fallback_msg = 'Server responded with unsupported association session but did not supply a fallback.' + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', no_fallback_msg)) def testEmptySessionType(self): """ @@ -82,11 +83,11 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): # not set: msg.setArg(OPENID_NS, 'session_type', None) self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Unsupported association type', - 'Server responded with unsupported association ' + - 'session but did not supply a fallback.') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + no_fallback_msg = 'Server responded with unsupported association session but did not supply a fallback.' + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', no_fallback_msg)) def testNotAllowed(self): """ @@ -106,10 +107,11 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'not-allowed') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Unsupported association type', - 'Server sent unsupported session/association type:') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + unsupported_msg = StringComparison('Server sent unsupported session/association type: .*') + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', unsupported_msg)) def testUnsupportedWithRetry(self): """ @@ -126,9 +128,9 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - - self.failUnlessLogMatches('Unsupported association type') + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*'))) def testUnsupportedWithRetryAndFail(self): """ @@ -144,10 +146,11 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): self.consumer.return_messages = [msg, Message(self.endpoint.preferredNamespace())] - self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) - - self.failUnlessLogMatches('Unsupported association type', - 'Server %s refused' % (self.endpoint.server_url)) + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + refused_msg = StringComparison('Server %s refused its .*' % self.endpoint.server_url) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', refused_msg)) def testValid(self): """ @@ -158,23 +161,22 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) + self.assertEqual(logbook.records, []) -class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): +class TestOpenID1SessionNegotiation(unittest.TestCase): """ Tests for the OpenID 1 consumer association session behavior. See the docs for TestOpenID2SessionNegotiation. Notice that this class is not a subclass of the OpenID 2 tests. Instead, it uses - many of the same inputs but inspects the log messages. - See the calls to self.failUnlessLogMatches. Some of - these tests pass openid2-style messages to the openid 1 + many of the same inputs but inspects the log messages, see the LogCapture. + Some of these tests pass openid2-style messages to the openid 1 association processing logic to be sure it ignores the extra data. """ def setUp(self): - CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) self.endpoint = OpenIDServiceEndpoint() @@ -183,8 +185,10 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): def testBadResponse(self): self.consumer.return_messages = [Message(self.endpoint.preferredNamespace())] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testEmptyAssocType(self): msg = Message(self.endpoint.preferredNamespace()) @@ -194,9 +198,10 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'new-session-type') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testEmptySessionType(self): msg = Message(self.endpoint.preferredNamespace()) @@ -206,9 +211,10 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): # not set: msg.setArg(OPENID_NS, 'session_type', None) self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testNotAllowed(self): allowed_types = [] @@ -223,9 +229,10 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'not-allowed') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testUnsupportedWithRetry(self): msg = Message(self.endpoint.preferredNamespace()) @@ -238,20 +245,22 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testValid(self): assoc = association.Association( 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) + self.assertEqual(logbook.records, []) -class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): +class TestNegotiatorBehaviors(unittest.TestCase): def setUp(self): self.allowed_types = [ ('HMAC-SHA1', 'no-encryption'), diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index d12cc5b..04b693e 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -5,8 +5,9 @@ __all__ = ['TestBuildDiscoveryURL'] import unittest +from testfixtures import LogCapture, StringComparison + from openid.server import trustroot -from openid.test.support import CatchLogs from openid.yadis import services from openid.yadis.discover import DiscoveryFailure, DiscoveryResult @@ -190,13 +191,7 @@ class TestReturnToMatches(unittest.TestCase): self.assertFalse(trustroot.returnToMatches([r], 'http://example.com/xss_exploit')) -class TestVerifyReturnTo(unittest.TestCase, CatchLogs): - - def setUp(self): - CatchLogs.setUp(self) - - def tearDown(self): - CatchLogs.tearDown(self) +class TestVerifyReturnTo(unittest.TestCase): def test_bogusRealm(self): self.assertFalse(trustroot.verifyReturnTo('', 'http://example.com/')) @@ -209,8 +204,9 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): self.assertEqual(disco_url, 'http://www.example.com/') return [return_to] - self.assertTrue(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertTrue(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + self.assertEqual(logbook.records, []) def test_verifyFailWithDiscoveryCalled(self): realm = 'http://*.example.com/' @@ -220,8 +216,9 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): self.assertEqual(disco_url, 'http://www.example.com/') return ['http://something-else.invalid/'] - self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - self.failUnlessLogMatches("Failed to validate return_to") + with LogCapture() as logbook: + self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + logbook.check(('openid.server.trustroot', 'ERROR', StringComparison('Failed to validate return_to .*'))) def test_verifyFailIfDiscoveryRedirects(self): realm = 'http://*.example.com/' @@ -231,8 +228,9 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): raise trustroot.RealmVerificationRedirected( disco_url, "http://redirected.invalid") - self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - self.failUnlessLogMatches("Attempting to verify") + with LogCapture() as logbook: + self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + logbook.check(('openid.server.trustroot', 'ERROR', StringComparison('Attempting to verify .*'))) if __name__ == '__main__': diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 8fd8ac8..c61878b 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -5,12 +5,13 @@ import unittest from functools import partial from urlparse import urlparse +from testfixtures import LogCapture, StringComparison + from openid import association, cryptutil, oidutil from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default from openid.server import server from openid.store import memstore -from openid.test.support import CatchLogs # In general, if you edit or add tests here, try to move in the direction # of testing smaller units. For testing the external interfaces, we'll be @@ -1576,11 +1577,10 @@ class Counter(object): self.count += 1 -class TestServer(unittest.TestCase, CatchLogs): +class TestServer(unittest.TestCase): def setUp(self): self.store = memstore.MemoryStore() self.server = server.Server(self.store, "http://server.unittest/endpt") - CatchLogs.setUp(self) def test_dispatch(self): monkeycalled = Counter() @@ -1689,13 +1689,12 @@ class TestServer(unittest.TestCase, CatchLogs): self.assertTrue(response.fields.hasKey(OPENID_NS, "is_valid")) -class TestSignatory(unittest.TestCase, CatchLogs): +class TestSignatory(unittest.TestCase): def setUp(self): self.store = memstore.MemoryStore() self.signatory = server.Signatory(self.store) self._dumb_key = self.signatory._dumb_key self._normal_key = self.signatory._normal_key - CatchLogs.setUp(self) def test_sign(self): request = server.OpenIDRequest() @@ -1712,11 +1711,12 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'bar': 'notsigned', 'azu': 'alsosigned', }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_signDumb(self): request = server.OpenIDRequest() @@ -1729,14 +1729,15 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'azu': 'alsosigned', 'ns': OPENID2_NS, }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(assoc_handle) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertTrue(assoc) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,ns,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_signExpired(self): """Sign a response to a message with an expired handle (using invalidate_handle). @@ -1768,7 +1769,8 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'bar': 'notsigned', 'azu': 'alsosigned', }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(new_assoc_handle) @@ -1787,7 +1789,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): # make sure the new key is a dumb mode association self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) self.assertFalse(self.store.getAssociation(self._normal_key, new_assoc_handle)) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'INFO', StringComparison('requested .* key .* is expired .*'))) def test_signInvalidHandle(self): request = server.OpenIDRequest() @@ -1801,7 +1803,8 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'bar': 'notsigned', 'azu': 'alsosigned', }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(new_assoc_handle) @@ -1816,7 +1819,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): # make sure the new key is a dumb mode association self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) self.assertFalse(self.store.getAssociation(self._normal_key, new_assoc_handle)) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_verify(self): assoc_handle = '{vroom}{zoom}' @@ -1833,8 +1836,9 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco=', }) - verified = self.signatory.verify(assoc_handle, signed) - self.assertFalse(self.messages) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) + self.assertEqual(logbook.records, []) self.assertTrue(verified) def test_verifyBadSig(self): @@ -1852,8 +1856,9 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco='.encode('rot13'), }) - verified = self.signatory.verify(assoc_handle, signed) - self.assertFalse(self.messages) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) + self.assertEqual(logbook.records, []) self.assertFalse(verified) def test_verifyBadHandle(self): @@ -1864,9 +1869,10 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'openid.sig': "Ylu0KcIR7PvNegB/K41KpnRgJl0=", }) - verified = self.signatory.verify(assoc_handle, signed) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'ERROR', StringComparison('failed to get assoc with handle .*'))) def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" @@ -1882,33 +1888,38 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'openid.sig': "d71xlHtqnq98DonoSgoK/nD+QRM=", }) - verified = self.signatory.verify(assoc_handle, signed) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'ERROR', StringComparison('Error in verifying .*'))) def test_getAssoc(self): assoc_handle = self.makeAssoc(dumb=True) - assoc = self.signatory.getAssociation(assoc_handle, True) + with LogCapture() as logbook: + assoc = self.signatory.getAssociation(assoc_handle, True) self.assertTrue(assoc) self.assertEqual(assoc.handle, assoc_handle) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_getAssocExpired(self): assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) - assoc = self.signatory.getAssociation(assoc_handle, True) + with LogCapture() as logbook: + assoc = self.signatory.getAssociation(assoc_handle, True) self.assertFalse(assoc) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'INFO', StringComparison('requested .* key .* is expired .*'))) def test_getAssocInvalid(self): ah = 'no-such-handle' - self.assertIsNone(self.signatory.getAssociation(ah, dumb=False)) - self.assertFalse(self.messages) + with LogCapture() as logbook: + self.assertIsNone(self.signatory.getAssociation(ah, dumb=False)) + self.assertEqual(logbook.records, []) def test_getAssocDumbVsNormal(self): """getAssociation(dumb=False) cannot get a dumb assoc""" assoc_handle = self.makeAssoc(dumb=True) - self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=False)) - self.assertFalse(self.messages) + with LogCapture() as logbook: + self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=False)) + self.assertEqual(logbook.records, []) def test_getAssocNormalVsDumb(self): """getAssociation(dumb=True) cannot get a shared assoc @@ -1919,13 +1930,15 @@ class TestSignatory(unittest.TestCase, CatchLogs): MAC keys. """ assoc_handle = self.makeAssoc(dumb=False) - self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=True)) - self.assertFalse(self.messages) + with LogCapture() as logbook: + self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=True)) + self.assertEqual(logbook.records, []) def test_createAssociation(self): - assoc = self.signatory.createAssociation(dumb=False) + with LogCapture() as logbook: + assoc = self.signatory.createAssociation(dumb=False) self.assertTrue(self.signatory.getAssociation(assoc.handle, dumb=False)) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def makeAssoc(self, dumb, lifetime=60): assoc_handle = '{bling}' @@ -1945,10 +1958,11 @@ class TestSignatory(unittest.TestCase, CatchLogs): self.assertTrue(assoc) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertTrue(assoc) - self.signatory.invalidate(assoc_handle, dumb=True) + with LogCapture() as logbook: + self.signatory.invalidate(assoc_handle, dumb=True) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertFalse(assoc) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) if __name__ == '__main__': diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index c0055ef..ec69a62 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -1,5 +1,7 @@ import unittest +from testfixtures import LogCapture, StringComparison + from openid import message from openid.consumer import consumer, discover from openid.test.support import OpenIDTestMixin @@ -25,48 +27,51 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.claimed_id = 'bogus' msg = message.Message.fromOpenIDArgs({}) - self.failUnlessProtocolError( - 'Missing required field openid.identity', - self.consumer._verifyDiscoveryResults, msg, endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.failUnlessProtocolError('Missing required field openid.identity', + self.consumer._verifyDiscoveryResults, msg, endpoint) + self.assertEqual(logbook.records, []) def test_openID1NoEndpoint(self): msg = message.Message.fromOpenIDArgs({'identity': 'snakes on a plane'}) - self.assertRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2NoOPEndpointArg(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS}) - self.assertRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2LocalIDNoClaimed(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': 'Phone Home', 'identity': 'Jose Lius Borges'}) - self.failUnlessProtocolError( - 'openid.identity is present without', - self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.failUnlessProtocolError('openid.identity is present without', + self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2NoLocalIDClaimed(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': 'Phone Home', 'claimed_id': 'Manuel Noriega'}) - self.failUnlessProtocolError( - 'openid.claimed_id is present without', - self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.failUnlessProtocolError('openid.claimed_id is present without', + self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2NoIdentifiers(self): op_endpoint = 'Phone Home' msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': op_endpoint}) - result_endpoint = self.consumer._verifyDiscoveryResults(msg) + with LogCapture() as logbook: + result_endpoint = self.consumer._verifyDiscoveryResults(msg) self.assertTrue(result_endpoint.isOPIdentifier()) self.assertEqual(result_endpoint.server_url, op_endpoint) self.assertIsNone(result_endpoint.claimed_id) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openID2NoEndpointDoesDisco(self): op_endpoint = 'Phone Home' @@ -78,9 +83,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): 'identity': 'sour grapes', 'claimed_id': 'monkeysoft', 'op_endpoint': op_endpoint}) - result = self.consumer._verifyDiscoveryResults(msg) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg) self.assertEqual(result, sentinel) - self.failUnlessLogMatches('No pre-discovered') + logbook.check(('openid.consumer.consumer', 'INFO', 'No pre-discovered information supplied.')) def test_openID2MismatchedDoesDisco(self): mismatched = discover.OpenIDServiceEndpoint() @@ -96,10 +102,11 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): 'identity': 'sour grapes', 'claimed_id': 'monkeysoft', 'op_endpoint': op_endpoint}) - result = self.consumer._verifyDiscoveryResults(msg, mismatched) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, mismatched) self.assertEqual(result, sentinel) - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid2UsePreDiscovered(self): endpoint = discover.OpenIDServiceEndpoint() @@ -113,9 +120,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): 'identity': endpoint.local_id, 'claimed_id': endpoint.claimed_id, 'op_endpoint': endpoint.server_url}) - result = self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result, endpoint) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openid2UsePreDiscoveredWrongType(self): text = "verify failed" @@ -140,11 +148,12 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): 'claimed_id': endpoint.claimed_id, 'op_endpoint': endpoint.server_url}) - with self.assertRaisesRegexp(consumer.ProtocolError, text): - self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + with self.assertRaisesRegexp(consumer.ProtocolError, text): + self.consumer._verifyDiscoveryResults(msg, endpoint) - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid1UsePreDiscovered(self): endpoint = discover.OpenIDServiceEndpoint() @@ -156,9 +165,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): msg = message.Message.fromOpenIDArgs( {'ns': message.OPENID1_NS, 'identity': endpoint.local_id}) - result = self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result, endpoint) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openid1UsePreDiscoveredWrongType(self): class VerifiedError(Exception): @@ -179,10 +189,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): {'ns': message.OPENID1_NS, 'identity': endpoint.local_id}) - self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, msg, endpoint) - - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + with LogCapture() as logbook: + self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, msg, endpoint) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid2Fragment(self): claimed_id = "http://unittest.invalid/" @@ -198,15 +208,15 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): 'identity': endpoint.local_id, 'claimed_id': claimed_id_frag, 'op_endpoint': endpoint.server_url}) - result = self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result.local_id, endpoint.local_id) self.assertEqual(result.server_url, endpoint.server_url) self.assertEqual(result.type_uris, endpoint.type_uris) - self.assertEqual(result.claimed_id, claimed_id_frag) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openid1Fallback1_0(self): claimed_id = 'http://claimed.id/' @@ -248,10 +258,11 @@ class TestVerifyDiscoverySingle(TestIdRes): to_match.server_url = "http://localhost:8000/openidserver" to_match.claimed_id = "http://localhost:8000/id/id-jo" to_match.local_id = "http://localhost:8000/id/id-jo" - result = self.consumer._verifyDiscoverySingle(endpoint, to_match) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoverySingle(endpoint, to_match) # result should always be None, raises exception on failure. self.assertIsNone(result) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) if __name__ == '__main__': |