summaryrefslogtreecommitdiff
path: root/openid
diff options
context:
space:
mode:
authorVlastimil Zíma <vlastimil.zima@nic.cz>2018-02-02 18:38:00 +0100
committerVlastimil Zíma <vlastimil.zima@nic.cz>2018-02-07 14:29:16 +0100
commit9f9050b012a71a9eb5ed4e2996672ff15d07cda1 (patch)
tree25b5e589d415e3a56b54ee586c6e26c99c50c2b9 /openid
parent546f09f740b74d8aa474149c33b639bd2c9bc626 (diff)
downloadopenid-9f9050b012a71a9eb5ed4e2996672ff15d07cda1.tar.gz
Refactor server request message
* Move `message` instance variable into base `OpenIDRequest`. * Deprecate `namespace` property for all requests. * Fix up request constructors.
Diffstat (limited to 'openid')
-rw-r--r--openid/server/server.py177
-rw-r--r--openid/test/test_server.py171
-rw-r--r--openid/test/test_sreg.py4
3 files changed, 147 insertions, 205 deletions
diff --git a/openid/server/server.py b/openid/server/server.py
index 8d45bc8..dfe4444 100644
--- a/openid/server/server.py
+++ b/openid/server/server.py
@@ -150,9 +150,26 @@ class OpenIDRequest(object):
@cvar mode: the C{X{openid.mode}} of this request.
@type mode: str
+
+ @ivar message: Original request message.
+ @type message: Message
"""
mode = None
+ def __init__(self, message=None):
+ if message is not None:
+ self.message = message
+ else:
+ # If no message is defined, create an empty one.
+ self.message = Message(OPENID2_NS)
+
+ @property
+ def namespace(self):
+ """Return request namespace."""
+ msg = 'The "namespace" attribute of {} objects is deprecated. Use "message.getOpenIDNamespace()" instead'
+ warnings.warn(msg.format(type(self).__name__), DeprecationWarning, stacklevel=2)
+ return self.message.getOpenIDNamespace()
+
class CheckAuthRequest(OpenIDRequest):
"""A request to verify the validity of a previous response.
@@ -176,7 +193,7 @@ class CheckAuthRequest(OpenIDRequest):
required_fields = ["identity", "return_to", "response_nonce"]
- def __init__(self, assoc_handle, signed, invalidate_handle=None):
+ def __init__(self, assoc_handle, signed, invalidate_handle=None, message=None):
"""Construct me.
These parameters are assigned directly as class attributes, see
@@ -186,10 +203,10 @@ class CheckAuthRequest(OpenIDRequest):
@type signed: L{Message}
@type invalidate_handle: str
"""
+ super(CheckAuthRequest, self).__init__(message=message)
self.assoc_handle = assoc_handle
self.signed = signed
self.invalidate_handle = invalidate_handle
- self.namespace = OPENID2_NS
@classmethod
def fromMessage(klass, message, op_endpoint=UNUSED):
@@ -200,28 +217,22 @@ class CheckAuthRequest(OpenIDRequest):
@returntype: L{CheckAuthRequest}
"""
- self = klass.__new__(klass)
- self.message = message
- self.namespace = message.getOpenIDNamespace()
- self.assoc_handle = message.getArg(OPENID_NS, 'assoc_handle')
- self.sig = message.getArg(OPENID_NS, 'sig')
-
- if (self.assoc_handle is None or
- self.sig is None):
+ assoc_handle = message.getArg(OPENID_NS, 'assoc_handle')
+ sig = message.getArg(OPENID_NS, 'sig')
+ invalidate_handle = message.getArg(OPENID_NS, 'invalidate_handle')
+ if (assoc_handle is None or sig is None):
fmt = "%s request missing required parameter from message %s"
- raise ProtocolError(
- message, text=fmt % (self.mode, message))
-
- self.invalidate_handle = message.getArg(OPENID_NS, 'invalidate_handle')
+ raise ProtocolError(message, text=fmt % (klass.mode, message))
- self.signed = message.copy()
+ signed = message.copy()
# openid.mode is currently check_authentication because
# that's the mode of this request. But the signature
# was made on something with a different openid.mode.
# http://article.gmane.org/gmane.comp.web.openid.general/537
- if self.signed.hasKey(OPENID_NS, "mode"):
- self.signed.setArg(OPENID_NS, "mode", "id_res")
+ if signed.hasKey(OPENID_NS, "mode"):
+ signed.setArg(OPENID_NS, "mode", "id_res")
+ self = klass(assoc_handle, signed, invalidate_handle, message)
return self
def answer(self, signatory):
@@ -257,9 +268,8 @@ class CheckAuthRequest(OpenIDRequest):
ih = " invalidate? %r" % (self.invalidate_handle,)
else:
ih = ""
- s = "<%s handle: %r sig: %r: signed: %r%s>" % (
- self.__class__.__name__, self.assoc_handle,
- self.sig, self.signed, ih)
+ sig = self.message.getArg(OPENID_NS, 'sig')
+ s = "<%s handle: %r sig: %r: signed: %r%s>" % (self.__class__.__name__, self.assoc_handle, sig, self.signed, ih)
return s
@@ -397,16 +407,15 @@ class AssociateRequest(OpenIDRequest):
'DH-SHA256': DiffieHellmanSHA256ServerSession,
}
- def __init__(self, session, assoc_type):
+ def __init__(self, session, assoc_type, message=None):
"""Construct me.
The session is assigned directly as a class attribute. See my
L{class documentation<AssociateRequest>} for its description.
"""
- super(AssociateRequest, self).__init__()
+ super(AssociateRequest, self).__init__(message=message)
self.session = session
self.assoc_type = assoc_type
- self.namespace = OPENID2_NS
@classmethod
def fromMessage(klass, message, op_endpoint=UNUSED):
@@ -455,9 +464,7 @@ class AssociateRequest(OpenIDRequest):
fmt = 'Session type %s does not support association type %s'
raise ProtocolError(message, fmt % (session_type, assoc_type))
- self = klass(session, assoc_type)
- self.message = message
- self.namespace = message.getOpenIDNamespace()
+ self = klass(session, assoc_type, message=message)
return self
def answer(self, assoc):
@@ -527,7 +534,7 @@ class CheckIDRequest(OpenIDRequest):
@ivar claimed_id: The claimed identifier. Not present in OpenID 1.x
messages.
- @type claimed_id: str
+ @type claimed_id: str or None
@ivar trust_root: "Are you Frank?" asks the checkid request. "Who wants
to know?" C{trust_root}, that's who. This URL identifies the party
@@ -546,7 +553,7 @@ class CheckIDRequest(OpenIDRequest):
"""
def __init__(self, identity, return_to, trust_root=None, immediate=False,
- assoc_handle=None, op_endpoint=None, claimed_id=None):
+ assoc_handle=None, op_endpoint=None, claimed_id=None, message=None):
"""Construct me.
These parameters are assigned directly as class attributes, see
@@ -554,13 +561,33 @@ class CheckIDRequest(OpenIDRequest):
@raises MalformedReturnURL: When the C{return_to} URL is not a URL.
"""
+ super(CheckIDRequest, self).__init__(message=message)
self.assoc_handle = assoc_handle
+
+ # Check the identifier validity. In case of error, create protocol error from the message in the argument.
+ if self.message.isOpenID1():
+ if identity is None:
+ s = "OpenID 1 message did not contain openid.identity"
+ raise ProtocolError(message, text=s)
+ else:
+ if identity and not claimed_id:
+ s = ("OpenID 2.0 message contained openid.identity but not "
+ "claimed_id")
+ raise ProtocolError(message, text=s)
+ elif claimed_id and not identity:
+ s = ("OpenID 2.0 message contained openid.claimed_id but not "
+ "identity")
+ raise ProtocolError(message, text=s)
+
self.identity = identity
- self.claimed_id = claimed_id or identity
+ self.claimed_id = claimed_id
self.return_to = return_to
self.trust_root = trust_root or return_to
+
+ if self.message.isOpenID2() and op_endpoint is None:
+ raise ValueError("CheckIDRequest requires op_endpoint argument for OpenID 2.0 requests.")
self.op_endpoint = op_endpoint
- assert self.op_endpoint is not None
+
if immediate:
self.immediate = True
self.mode = "checkid_immediate"
@@ -568,18 +595,22 @@ class CheckIDRequest(OpenIDRequest):
self.immediate = False
self.mode = "checkid_setup"
+ # Using TrustRoot.parse here is a bit misleading, as we're not
+ # parsing return_to as a trust root at all. However, valid URLs
+ # are valid trust roots, so we can use this to get an idea if it
+ # is a valid URL. Not all trust roots are valid return_to URLs,
+ # however (particularly ones with wildcards), so this is still a
+ # little sketchy.
if self.return_to is not None and not TrustRoot.parse(self.return_to):
- raise MalformedReturnURL(None, self.return_to)
- if not self.trustRootValid():
- raise UntrustedReturnURL(None, self.return_to, self.trust_root)
- self.message = None
+ raise MalformedReturnURL(message, self.return_to)
- @property
- def namespace(self):
- warnings.warn('The "namespace" attribute of CheckIDRequest objects '
- 'is deprecated. Use "message.getOpenIDNamespace()" '
- 'instead', DeprecationWarning, stacklevel=2)
- return self.message.getOpenIDNamespace()
+ # I first thought that checking to see if the return_to is within
+ # the trust_root is premature here, a logic-not-decoding thing. But
+ # it was argued that this is really part of data validation. A
+ # request with an invalid trust_root/return_to is broken regardless of
+ # application, right?
+ if not self.trustRootValid():
+ raise UntrustedReturnURL(message, self.return_to, self.trust_root)
@classmethod
def fromMessage(klass, message, op_endpoint):
@@ -602,38 +633,17 @@ class CheckIDRequest(OpenIDRequest):
@returntype: L{CheckIDRequest}
"""
- self = klass.__new__(klass)
- self.message = message
- self.op_endpoint = op_endpoint
mode = message.getArg(OPENID_NS, 'mode')
- if mode == "checkid_immediate":
- self.immediate = True
- self.mode = "checkid_immediate"
- else:
- self.immediate = False
- self.mode = "checkid_setup"
+ assert mode in ('checkid_immediate', 'checkid_setup')
+ immediate = bool(mode == 'checkid_immediate')
- self.return_to = message.getArg(OPENID_NS, 'return_to')
- if message.isOpenID1() and not self.return_to:
+ return_to = message.getArg(OPENID_NS, 'return_to')
+ if message.isOpenID1() and not return_to:
fmt = "Missing required field 'return_to' from %r"
raise ProtocolError(message, text=fmt % (message,))
- self.identity = message.getArg(OPENID_NS, 'identity')
- self.claimed_id = message.getArg(OPENID_NS, 'claimed_id')
- if message.isOpenID1():
- if self.identity is None:
- s = "OpenID 1 message did not contain openid.identity"
- raise ProtocolError(message, text=s)
- else:
- if self.identity and not self.claimed_id:
- s = ("OpenID 2.0 message contained openid.identity but not "
- "claimed_id")
- raise ProtocolError(message, text=s)
- elif self.claimed_id and not self.identity:
- s = ("OpenID 2.0 message contained openid.claimed_id but not "
- "identity")
- raise ProtocolError(message, text=s)
-
+ identity = message.getArg(OPENID_NS, 'identity')
+ claimed_id = message.getArg(OPENID_NS, 'claimed_id')
# There's a case for making self.trust_root be a TrustRoot
# here. But if TrustRoot isn't currently part of the "public" API,
# I'm not sure it's worth doing.
@@ -646,32 +656,15 @@ class CheckIDRequest(OpenIDRequest):
# Using 'or' here is slightly different than sending a default
# argument to getArg, as it will treat no value and an empty
# string as equivalent.
- self.trust_root = (message.getArg(OPENID_NS, trust_root_param) or self.return_to)
+ trust_root = (message.getArg(OPENID_NS, trust_root_param) or return_to)
- if not message.isOpenID1():
- if self.return_to is self.trust_root is None:
- raise ProtocolError(message, "openid.realm required when " +
- "openid.return_to absent")
+ if not message.isOpenID1() and (return_to is trust_root is None):
+ raise ProtocolError(message, "openid.realm required when openid.return_to absent")
- self.assoc_handle = message.getArg(OPENID_NS, 'assoc_handle')
-
- # Using TrustRoot.parse here is a bit misleading, as we're not
- # parsing return_to as a trust root at all. However, valid URLs
- # are valid trust roots, so we can use this to get an idea if it
- # is a valid URL. Not all trust roots are valid return_to URLs,
- # however (particularly ones with wildcards), so this is still a
- # little sketchy.
- if self.return_to is not None and not TrustRoot.parse(self.return_to):
- raise MalformedReturnURL(message, self.return_to)
-
- # I first thought that checking to see if the return_to is within
- # the trust_root is premature here, a logic-not-decoding thing. But
- # it was argued that this is really part of data validation. A
- # request with an invalid trust_root/return_to is broken regardless of
- # application, right?
- if not self.trustRootValid():
- raise UntrustedReturnURL(message, self.return_to, self.trust_root)
+ assoc_handle = message.getArg(OPENID_NS, 'assoc_handle')
+ self = klass(identity, return_to, trust_root=trust_root, immediate=immediate, assoc_handle=assoc_handle,
+ op_endpoint=op_endpoint, claimed_id=claimed_id, message=message)
return self
def idSelect(self):
@@ -773,8 +766,6 @@ class CheckIDRequest(OpenIDRequest):
@raises NoReturnError: when I do not have a return_to.
"""
- assert self.message is not None
-
if not self.return_to:
raise NoReturnToError
@@ -974,7 +965,7 @@ class OpenIDResponse(object):
@type request: L{OpenIDRequest}
"""
self.request = request
- self.fields = Message(request.namespace)
+ self.fields = Message(request.message.getOpenIDNamespace())
def __str__(self):
return "%s for %s: %s" % (
diff --git a/openid/test/test_server.py b/openid/test/test_server.py
index a3296fc..d4a1f14 100644
--- a/openid/test/test_server.py
+++ b/openid/test/test_server.py
@@ -1,10 +1,12 @@
"""Tests for openid.server.
"""
import unittest
+import warnings
from functools import partial
from urlparse import parse_qs, parse_qsl, urlparse
-from testfixtures import LogCapture, StringComparison
+from mock import sentinel
+from testfixtures import LogCapture, ShouldWarn, StringComparison
from openid import association, cryptutil, oidutil
from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession
@@ -25,6 +27,39 @@ ALT_MODULUS = int('1423261515703355186607439952816216983770573549498844689430217
ALT_GEN = 5
+# Example values to be used in tests
+EXAMPLE_IDENTITY = 'http://id.example.cz/'
+EXAMPLE_CLAIMED_ID = 'http://claimed.example.cz/'
+
+
+def make_checkid_request(identity=EXAMPLE_IDENTITY, claimed_id=EXAMPLE_CLAIMED_ID,
+ trust_root='http://realm.example.cz/', return_to='http://realm.example.cz/return_to/',
+ op_endpoint='http://op.example.cz/', immediate=False, message=None):
+ """Create a simple CheckIDRequest."""
+ message = message or Message(OPENID2_NS)
+ return server.CheckIDRequest(identity=identity, claimed_id=claimed_id, trust_root=trust_root, return_to=return_to,
+ op_endpoint=op_endpoint, immediate=immediate, message=message)
+
+
+class TestOpenIDRequest(unittest.TestCase):
+ """Test OpenID request base class."""
+
+ def test_init_default_message(self):
+ # Test empty OpenID 2.0 message is create if not provided.
+ request = server.OpenIDRequest()
+ self.assertTrue(request.message)
+ self.assertEqual(request.message.getOpenIDNamespace(), OPENID2_NS)
+
+ def test_namespace(self):
+ # Test deprecated namespace property
+ request = server.OpenIDRequest()
+ warning_msg = ('The "namespace" attribute of OpenIDRequest objects is deprecated. Use '
+ '"message.getOpenIDNamespace()" instead')
+ with ShouldWarn(DeprecationWarning(warning_msg)):
+ warnings.simplefilter('always')
+ self.assertEqual(request.namespace, OPENID2_NS)
+
+
class TestProtocolError(unittest.TestCase):
def test_browserWithReturnTo(self):
return_to = "http://rp.unittest/consumer"
@@ -328,7 +363,7 @@ class TestDecode(unittest.TestCase):
r = self.decode(args)
self.assertIsInstance(r, server.CheckAuthRequest)
self.assertEqual(r.mode, 'check_authentication')
- self.assertEqual(r.sig, 'sigblob')
+ self.assertEqual(r.message.getArg(OPENID_NS, 'sig'), 'sigblob')
def test_checkAuthMissingSignature(self):
args = {
@@ -488,14 +523,7 @@ class TestEncode(unittest.TestCase):
OpenID 1 message size, a GET response (i.e., redirect) is
issued.
"""
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'ns': OPENID2_NS,
@@ -516,14 +544,7 @@ class TestEncode(unittest.TestCase):
message size, a POST response (i.e., an HTML form) is
returned.
"""
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'ns': OPENID2_NS,
@@ -540,14 +561,7 @@ class TestEncode(unittest.TestCase):
self.assertIn(response.toFormMarkup(), webresponse.body)
def test_toFormMarkup(self):
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'ns': OPENID2_NS,
@@ -561,14 +575,7 @@ class TestEncode(unittest.TestCase):
self.assertIn(' foo="bar"', form_markup)
def test_toHTML(self):
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'ns': OPENID2_NS,
@@ -582,7 +589,7 @@ class TestEncode(unittest.TestCase):
self.assertIn('</html>', html)
self.assertIn('<body onload=', html)
self.assertIn('<form', html)
- self.assertIn('http://bombom.unittest/', html)
+ self.assertIn(EXAMPLE_IDENTITY, html)
def test_id_res_OpenID1_exceeds_limit(self):
"""
@@ -591,14 +598,7 @@ class TestEncode(unittest.TestCase):
shouldn't be permitted by the library, but this test is in
place to preserve the status quo for OpenID 1.
"""
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'mode': 'id_res',
@@ -613,14 +613,7 @@ class TestEncode(unittest.TestCase):
self.assertEqual(webresponse.headers['location'], response.encodeToURL())
def test_id_res(self):
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'mode': 'id_res',
@@ -640,14 +633,7 @@ class TestEncode(unittest.TestCase):
self.assertEqual(q2, expected)
def test_cancel(self):
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'mode': 'cancel',
@@ -657,14 +643,7 @@ class TestEncode(unittest.TestCase):
self.assertIn('location', webresponse.headers)
def test_cancelToForm(self):
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'mode': 'cancel',
@@ -729,14 +708,7 @@ class TestSigningEncode(unittest.TestCase):
self._normal_key = server.Signatory._normal_key
self.store = memstore.MemoryStore()
self.server = server.Server(self.store, "http://signing.unittest/enc")
- self.request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- self.request.message = Message(OPENID2_NS)
+ self.request = make_checkid_request()
self.response = server.OpenIDResponse(self.request)
self.response.fields = Message.fromOpenIDArgs({
'mode': 'id_res',
@@ -780,14 +752,7 @@ class TestSigningEncode(unittest.TestCase):
self.assertRaises(ValueError, self.encode, self.response)
def test_cancel(self):
- request = server.CheckIDRequest(
- identity='http://bombom.unittest/',
- trust_root='http://burr.unittest/',
- return_to='http://burr.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- request.message = Message(OPENID2_NS)
+ request = make_checkid_request()
response = server.OpenIDResponse(request)
response.fields.setArg(OPENID_NS, 'mode', 'cancel')
webresponse = self.encode(response)
@@ -821,14 +786,12 @@ class TestCheckID(unittest.TestCase):
self.op_endpoint = 'http://endpoint.unittest/'
self.store = memstore.MemoryStore()
self.server = server.Server(self.store, self.op_endpoint)
- self.request = server.CheckIDRequest(
- identity='http://bambam.unittest/',
- trust_root='http://bar.unittest/',
- return_to='http://bar.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- self.request.message = Message(OPENID2_NS)
+ self.request = make_checkid_request(op_endpoint=self.op_endpoint)
+
+ def test_openid2_requires_provider(self):
+ with self.assertRaisesRegexp(ValueError, 'CheckIDRequest requires op_endpoint'):
+ server.CheckIDRequest(sentinel.identity, sentinel.return_to, claimed_id=sentinel.claimed_id,
+ message=Message(OPENID2_NS))
def test_trustRootInvalid(self):
self.request.trust_root = "http://foo.unittest/17"
@@ -852,6 +815,7 @@ class TestCheckID(unittest.TestCase):
def test_trustRootValidNoReturnTo(self):
request = server.CheckIDRequest(
identity='http://bambam.unittest/',
+ claimed_id=EXAMPLE_CLAIMED_ID,
trust_root='http://bar.unittest/',
return_to=None,
immediate=False,
@@ -925,7 +889,7 @@ class TestCheckID(unittest.TestCase):
"""
answer = self.request.answer(True)
self.assertEqual(answer.request, self.request)
- self._expectAnswer(answer, self.request.identity)
+ self._expectAnswer(answer, self.request.identity, EXAMPLE_CLAIMED_ID)
def test_answerAllowDelegatedIdentity(self):
self.request.claimed_id = 'http://delegating.unittest/'
@@ -937,7 +901,7 @@ class TestCheckID(unittest.TestCase):
# This time with the identity argument explicitly passed in to
# answer()
self.request.claimed_id = 'http://delegating.unittest/'
- answer = self.request.answer(True, identity='http://bambam.unittest/')
+ answer = self.request.answer(True, identity=EXAMPLE_IDENTITY)
self._expectAnswer(answer, self.request.identity,
self.request.claimed_id)
@@ -1071,7 +1035,7 @@ class TestCheckID(unittest.TestCase):
self.request.trust_root = None
answer = self.request.answer(True)
self.assertEqual(answer.request, self.request)
- self._expectAnswer(answer, self.request.identity)
+ self._expectAnswer(answer, self.request.identity, EXAMPLE_CLAIMED_ID)
def test_fromMessageWithoutTrustRoot(self):
msg = Message(OPENID2_NS)
@@ -1115,6 +1079,7 @@ class TestCheckID(unittest.TestCase):
"""
identity = 'http://bambam.unittest/'
reqmessage = Message.fromOpenIDArgs({
+ 'mode': 'checkid_setup',
'identity': identity,
'trust_root': 'http://bar.unittest/',
'return_to': 'http://bar.unittest/999',
@@ -1214,14 +1179,7 @@ class TestCheckIDExtension(unittest.TestCase):
self.op_endpoint = 'http://endpoint.unittest/ext'
self.store = memstore.MemoryStore()
self.server = server.Server(self.store, self.op_endpoint)
- self.request = server.CheckIDRequest(
- identity='http://bambam.unittest/',
- trust_root='http://bar.unittest/',
- return_to='http://bar.unittest/999',
- immediate=False,
- op_endpoint=self.server.op_endpoint,
- )
- self.request.message = Message(OPENID2_NS)
+ self.request = make_checkid_request(op_endpoint=self.op_endpoint)
self.response = server.OpenIDResponse(self.request)
self.response.fields.setArg(OPENID_NS, 'mode', 'id_res')
self.response.fields.setArg(OPENID_NS, 'blue', 'star')
@@ -1586,9 +1544,8 @@ class TestServer(unittest.TestCase):
r = server.OpenIDResponse(request)
return r
self.server.openid_monkeymode = monkeyDo
- request = server.OpenIDRequest()
+ request = server.OpenIDRequest(Message(OPENID1_NS))
request.mode = "monkeymode"
- request.namespace = OPENID1_NS
self.server.handleRequest(request)
self.assertEqual(monkeycalled.count, 1)
@@ -1693,14 +1650,13 @@ class TestSignatory(unittest.TestCase):
self._normal_key = self.signatory._normal_key
def test_sign(self):
- request = server.OpenIDRequest()
+ request = server.OpenIDRequest(Message(OPENID1_NS))
assoc_handle = '{assoc}{lookatme}'
self.store.storeAssociation(
self._normal_key,
association.Association.fromExpiresIn(60, assoc_handle,
'sekrit', 'HMAC-SHA1'))
request.assoc_handle = assoc_handle
- request.namespace = OPENID1_NS
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'foo': 'amsigned',
@@ -1717,7 +1673,6 @@ class TestSignatory(unittest.TestCase):
def test_signDumb(self):
request = server.OpenIDRequest()
request.assoc_handle = None
- request.namespace = OPENID2_NS
response = server.OpenIDResponse(request)
response.fields = Message.fromOpenIDArgs({
'foo': 'amsigned',
@@ -1750,7 +1705,6 @@ class TestSignatory(unittest.TestCase):
Relying Party included with the original request.
"""
request = server.OpenIDRequest()
- request.namespace = OPENID2_NS
assoc_handle = '{assoc}{lookatme}'
self.store.storeAssociation(
self._normal_key,
@@ -1789,7 +1743,6 @@ class TestSignatory(unittest.TestCase):
def test_signInvalidHandle(self):
request = server.OpenIDRequest()
- request.namespace = OPENID2_NS
assoc_handle = '{bogus-assoc}{notvalid}'
request.assoc_handle = assoc_handle
diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py
index f358976..80fd442 100644
--- a/openid/test/test_sreg.py
+++ b/openid/test/test_sreg.py
@@ -439,9 +439,7 @@ class SendFieldsTest(unittest.TestCase):
req_msg = Message()
req_msg.updateArgs(sreg.ns_uri, sreg_req.getExtensionArgs())
- req = OpenIDRequest()
- req.message = req_msg
- req.namespace = req_msg.getOpenIDNamespace()
+ req = OpenIDRequest(req_msg)
# -> send checkid_* request