"""Tests for openid.server. """ from __future__ import unicode_literals import unittest import warnings from functools import partial import six from cryptography.hazmat.primitives import hashes from mock import sentinel from six.moves.urllib.parse import parse_qs, parse_qsl, urlparse from testfixtures import LogCapture, ShouldWarn, StringComparison from openid import association, cryptutil, oidutil from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession from openid.dh import DiffieHellman from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default from openid.server import server from openid.server.server import DiffieHellmanSHA1ServerSession from openid.store import memstore # In general, if you edit or add tests here, try to move in the direction # of testing smaller units. For testing the external interfaces, we'll be # developing an implementation-agnostic testing suite. # for more, see /etc/ssh/moduli ALT_MODULUS = ('AMqt3ewWZ/xotfoV1TxOFTLdJFYaGi1HoSwBq+oeAHMfaSGqxAdCMR/fnmNLtxMb7hryQCYVVDiakQQl4ETojINZsBD1rSuA4pyxpbA' 'nsZ2eAab2Om9F5dftL/aioAhQUKfQzzB8PbUdJJA1WQe0QnwjqY3x64q+8rogm7ev/oan') ALT_GEN = 'BQ==' # Example values to be used in tests EXAMPLE_IDENTITY = 'http://id.example.cz/' EXAMPLE_CLAIMED_ID = 'http://claimed.example.cz/' def make_checkid_request(identity=EXAMPLE_IDENTITY, claimed_id=EXAMPLE_CLAIMED_ID, trust_root='http://realm.example.cz/', return_to='http://realm.example.cz/return_to/', op_endpoint='http://op.example.cz/', immediate=False, message=None): """Create a simple CheckIDRequest.""" message = message or Message(OPENID2_NS) return server.CheckIDRequest(identity=identity, claimed_id=claimed_id, trust_root=trust_root, return_to=return_to, op_endpoint=op_endpoint, immediate=immediate, message=message) class TestOpenIDRequest(unittest.TestCase): """Test OpenID request base class.""" def test_init_default_message(self): # Test empty OpenID 2.0 message is create if not provided. request = server.OpenIDRequest() self.assertTrue(request.message) self.assertEqual(request.message.getOpenIDNamespace(), OPENID2_NS) def test_namespace(self): # Test deprecated namespace property request = server.OpenIDRequest() warning_msg = ('The "namespace" attribute of OpenIDRequest objects is deprecated. Use ' '"message.getOpenIDNamespace()" instead') with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') self.assertEqual(request.namespace, OPENID2_NS) class TestProtocolError(unittest.TestCase): def test_browserWithReturnTo(self): return_to = "http://rp.unittest/consumer" # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = parse_qs(result_args) self.assertEqual(result_args, expected_args) def test_browserWithReturnTo_OpenID2_GET(self): return_to = "http://rp.unittest/consumer" # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.claimed_id': 'http://wagu.unittest/', 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.ns': [OPENID2_NS], 'openid.mode': ['error'], 'openid.error': ['plucky'], } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = parse_qs(result_args) self.assertEqual(result_args, expected_args) def test_browserWithReturnTo_OpenID2_POST(self): return_to = "http://rp.unittest/consumer" + ('x' * OPENID1_URL_LIMIT) # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.claimed_id': 'http://wagu.unittest/', 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) self.assertEqual(e.whichEncoding(), server.ENCODE_HTML_FORM) self.assertEqual(e.toFormMarkup(), e.toMessage().toFormMarkup(args.getArg(OPENID_NS, 'return_to'))) def test_browserWithReturnTo_OpenID1_exceeds_limit(self): return_to = "http://rp.unittest/consumer" + ('x' * OPENID1_URL_LIMIT) # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], } self.assertEqual(e.whichEncoding(), server.ENCODE_URL) rt_base, result_args = e.encodeToURL().split('?', 1) result_args = parse_qs(result_args) self.assertEqual(result_args, expected_args) def test_noReturnTo(self): # will be a ProtocolError raised by Decode or CheckIDRequest.answer args = Message.fromPostArgs({ 'openid.mode': 'zebradance', 'openid.identity': 'http://wagu.unittest/', }) e = server.ProtocolError(args, "waffles") self.assertFalse(e.hasReturnTo()) expected = """error:waffles mode:error """ self.assertEqual(e.encodeToKVForm(), expected) def test_noMessage(self): e = server.ProtocolError(None, "no moar pancakes") self.assertFalse(e.hasReturnTo()) self.assertIsNone(e.whichEncoding()) class TestDecode(unittest.TestCase): def setUp(self): self.claimed_id = 'http://de.legating.de.coder.unittest/' self.id_url = "http://decoder.am.unittest/" self.rt_url = "http://rp.unittest/foobot/?qux=zam" self.tr_url = "http://rp.unittest/" self.assoc_handle = "{assoc}{handle}" self.op_endpoint = 'http://endpoint.unittest/encode' self.store = memstore.MemoryStore() self.server = server.Server(self.store, self.op_endpoint) self.decode = self.server.decoder.decode self.decode = server.Decoder(self.server).decode def test_none(self): args = {} r = self.decode(args) self.assertIsNone(r) def test_irrelevant(self): args = { 'pony': 'spotted', 'sreg.mutant_power': 'decaffinator', } self.assertRaises(server.ProtocolError, self.decode, args) def test_bad(self): args = { 'openid.mode': 'twos-compliment', 'openid.pants': 'zippered', } self.assertRaises(server.ProtocolError, self.decode, args) def test_dictOfLists(self): args = { 'openid.mode': ['checkid_setup'], 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, } with six.assertRaisesRegex(self, TypeError, 'values'): self.decode(args) def test_checkidImmediate(self): args = { 'openid.mode': 'checkid_immediate', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, # should be ignored 'openid.some.extension': 'junk', } r = self.decode(args) self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_immediate") self.assertTrue(r.immediate) self.assertEqual(r.identity, self.id_url) self.assertEqual(r.trust_root, self.tr_url) self.assertEqual(r.return_to, self.rt_url) self.assertEqual(r.assoc_handle, self.assoc_handle) def test_checkidSetup(self): args = { 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, } r = self.decode(args) self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_setup") self.assertFalse(r.immediate) self.assertEqual(r.identity, self.id_url) self.assertEqual(r.trust_root, self.tr_url) self.assertEqual(r.return_to, self.rt_url) def test_checkidSetupOpenID2(self): args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.claimed_id': self.claimed_id, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, } r = self.decode(args) self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_setup") self.assertFalse(r.immediate) self.assertEqual(r.identity, self.id_url) self.assertEqual(r.claimed_id, self.claimed_id) self.assertEqual(r.trust_root, self.tr_url) self.assertEqual(r.return_to, self.rt_url) def test_checkidSetupNoClaimedIDOpenID2(self): args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoIdentityOpenID2(self): args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, } r = self.decode(args) self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_setup") self.assertFalse(r.immediate) self.assertIsNone(r.identity) self.assertEqual(r.trust_root, self.tr_url) self.assertEqual(r.return_to, self.rt_url) def test_checkidSetupEmptyIdentityOpenID2(self): args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, 'openid.identity': '', } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupEmptyClaimedIDOpenID2(self): args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, 'openid.claimed_id': '', } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoReturnOpenID1(self): """Make sure an OpenID 1 request cannot be decoded if it lacks a return_to. """ args = { 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.trust_root': self.tr_url, } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoReturnOpenID2(self): """Make sure an OpenID 2 request with no return_to can be decoded, and make sure a response to such a request raises NoReturnToError. """ args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.claimed_id': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.realm': self.tr_url, } self.assertIsInstance(self.decode(args), server.CheckIDRequest) req = self.decode(args) self.assertRaises(server.NoReturnToError, req.answer, False) self.assertRaises(server.NoReturnToError, req.encodeToURL, 'bogus') self.assertRaises(server.NoReturnToError, req.getCancelURL) def test_checkidSetupRealmRequiredOpenID2(self): """Make sure that an OpenID 2 request which lacks return_to cannot be decoded if it lacks a realm. Spec: This value (openid.realm) MUST be sent if openid.return_to is omitted. """ args = { 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupBadReturn(self): args = { 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': 'not a url', } with self.assertRaises(server.ProtocolError) as catch: self.decode(args) self.assertTrue(catch.exception.openid_message) def test_checkidSetupUntrustedReturn(self): args = { 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': 'http://not-the-return-place.unittest/', } with self.assertRaises(server.UntrustedReturnURL) as catch: self.decode(args) self.assertTrue(catch.exception.openid_message) def test_checkAuth(self): args = { 'openid.mode': 'check_authentication', 'openid.assoc_handle': '{dumb}{handle}', 'openid.sig': 'sigblob', 'openid.signed': 'identity,return_to,response_nonce,mode', 'openid.identity': 'signedval1', 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', } r = self.decode(args) self.assertIsInstance(r, server.CheckAuthRequest) self.assertEqual(r.mode, 'check_authentication') self.assertEqual(r.message.getArg(OPENID_NS, 'sig'), 'sigblob') def test_checkAuthMissingSignature(self): args = { 'openid.mode': 'check_authentication', 'openid.assoc_handle': '{dumb}{handle}', 'openid.signed': 'foo,bar,mode', 'openid.foo': 'signedval1', 'openid.bar': 'signedval2', 'openid.baz': 'unsigned', } self.assertRaises(server.ProtocolError, self.decode, args) def test_checkAuthAndInvalidate(self): args = { 'openid.mode': 'check_authentication', 'openid.assoc_handle': '{dumb}{handle}', 'openid.invalidate_handle': '[[SMART_handle]]', 'openid.sig': 'sigblob', 'openid.signed': 'identity,return_to,response_nonce,mode', 'openid.identity': 'signedval1', 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', } r = self.decode(args) self.assertIsInstance(r, server.CheckAuthRequest) self.assertEqual(r.invalidate_handle, '[[SMART_handle]]') def test_associateDH(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", } r = self.decode(args) self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "DH-SHA1") self.assertEqual(r.assoc_type, "HMAC-SHA1") self.assertTrue(r.session.consumer_pubkey) def test_associateDHMissingKey(self): """Trying DH assoc w/o public key""" args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', } # Using DH-SHA1 without supplying dh_consumer_public is an error. self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHpubKeyNotB64(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "donkeydonkeydonkey", } self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': ALT_MODULUS, 'openid.dh_gen': ALT_GEN, } r = self.decode(args) self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "DH-SHA1") self.assertEqual(r.assoc_type, "HMAC-SHA1") self.assertEqual(r.session.dh.parameters, (ALT_MODULUS, ALT_GEN)) self.assertTrue(r.session.consumer_pubkey) def test_associateDHCorruptModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', 'openid.dh_gen': 'gnocchi', } self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHMissingModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', } self.assertRaises(server.ProtocolError, self.decode, args) def test_associateWeirdSession(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'FLCL6', 'openid.dh_consumer_public': "YQ==\n", } self.assertRaises(server.ProtocolError, self.decode, args) def test_associatePlain(self): args = { 'openid.mode': 'associate', } r = self.decode(args) self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "no-encryption") self.assertEqual(r.assoc_type, "HMAC-SHA1") def test_nomode(self): args = { 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "my public keeey", } self.assertRaises(server.ProtocolError, self.decode, args) def test_invalidns(self): args = {'openid.ns': 'Tuesday', 'openid.mode': 'associate'} with six.assertRaisesRegex(self, server.ProtocolError, 'Tuesday') as catch: self.decode(args) self.assertTrue(catch.exception.openid_message) class TestEncode(unittest.TestCase): def setUp(self): self.encoder = server.Encoder() self.encode = self.encoder.encode self.op_endpoint = 'http://endpoint.unittest/encode' self.store = memstore.MemoryStore() self.server = server.Server(self.store, self.op_endpoint) def test_id_res_OpenID2_GET(self): """ Check that when an OpenID 2 response does not exceed the OpenID 1 message size, a GET response (i.e., redirect) is issued. """ request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, 'mode': 'id_res', 'identity': request.identity, 'claimed_id': request.identity, 'return_to': request.return_to, }) self.assertFalse(response.renderAsForm()) self.assertEqual(response.whichEncoding(), server.ENCODE_URL) webresponse = self.encode(response) self.assertIn('location', webresponse.headers) def test_id_res_OpenID2_POST(self): """ Check that when an OpenID 2 response exceeds the OpenID 1 message size, a POST response (i.e., an HTML form) is returned. """ request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, 'mode': 'id_res', 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, }) self.assertTrue(response.renderAsForm()) self.assertGreater(len(response.encodeToURL()), OPENID1_URL_LIMIT) self.assertEqual(response.whichEncoding(), server.ENCODE_HTML_FORM) webresponse = self.encode(response) self.assertIn(response.toFormMarkup(), webresponse.body) def test_toFormMarkup(self): request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, 'mode': 'id_res', 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, }) form_markup = response.toFormMarkup({'foo': 'bar'}) self.assertIn(' foo="bar"', form_markup) def test_toHTML(self): request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, 'mode': 'id_res', 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, }) html = response.toHTML() self.assertIn('', html) self.assertIn('', html) self.assertIn('