diff options
author | Vlastimil Zíma <vlastimil.zima@nic.cz> | 2017-11-24 15:07:29 +0100 |
---|---|---|
committer | Vlastimil Zíma <vlastimil.zima@nic.cz> | 2017-11-29 08:38:30 +0100 |
commit | 8f0ff0d27771514d16a415b8ac76d18ea0809f38 (patch) | |
tree | 06b56cd9666f85aec459e9264e4c724f86c70f0d | |
parent | f58d7cee3e9f4bff9854dc10ffcd105fb3bc6619 (diff) | |
download | openid-8f0ff0d27771514d16a415b8ac76d18ea0809f38.tar.gz |
Pepify and add flake8
100 files changed, 1954 insertions, 1853 deletions
@@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +# Ignore E123 - enforce hang-closing instead +ignore = E123,W503 +max-complexity = 22 diff --git a/.travis.yml b/.travis.yml index 14636d8..fe11957 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,9 @@ language: python python: - 2.7 -before_install: pip install Django pycrypto lxml isort +before_install: pip install Django pycrypto lxml isort flake8 install: python setup.py install script: - make check-isort + - make check-flake8 - make test @@ -1,4 +1,4 @@ -.PHONY: test coverage isort check-isort +.PHONY: test coverage isort check-all check-isort check-flake8 test: python admin/runtests @@ -12,5 +12,10 @@ coverage: isort: isort --recursive . +check-all: check-isort check-flake8 + check-isort: isort --check-only --diff --recursive . + +check-flake8: + flake8 --format=pylint . diff --git a/admin/builddiscover.py b/admin/builddiscover.py index 011ab88..ef4ede9 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -29,6 +29,7 @@ manifest_header = """\ """ + def buildDiscover(base_url, out_dir): """Convert all files in a directory to apache mod_asis files in another directory.""" @@ -63,6 +64,7 @@ def buildDiscover(base_url, out_dir): manifest_file.write(chunk) manifest_file.close() + if __name__ == '__main__': import sys buildDiscover(*sys.argv[1:]) diff --git a/admin/gettlds.py b/admin/gettlds.py index f473224..b2a7c92 100644 --- a/admin/gettlds.py +++ b/admin/gettlds.py @@ -21,7 +21,7 @@ langs = { 'ruby': ("%w'", "", " ", "", "'"), - } +} lang = sys.argv[1] prefix, line_prefix, separator, line_suffix, suffix = langs[lang] diff --git a/admin/runtests b/admin/runtests index b2a3a79..db7a647 100755 --- a/admin/runtests +++ b/admin/runtests @@ -1,11 +1,14 @@ #!/usr/bin/env python -import os.path, sys, warnings +import os.path +import sys +import warnings test_modules = [ 'cryptutil', 'oidutil', 'dh', - ] +] + def fixpath(): try: @@ -17,10 +20,11 @@ def fixpath(): print "putting %s in sys.path" % (parent,) sys.path.insert(0, parent) + def otherTests(): failed = [] for module_name in test_modules: - print 'Testing %s...' % (module_name,) , + print 'Testing %s...' % (module_name,), sys.stdout.flush() module_name = 'openid.test.' + module_name try: @@ -31,17 +35,15 @@ def otherTests(): else: try: test_mod.test() - except (SystemExit, KeyboardInterrupt): - raise - except: + except Exception: sys.excepthook(*sys.exc_info()) failed.append(module_name) else: print 'Succeeded.' - return failed + def pyunitTests(): import unittest pyunit_module_names = [ @@ -63,16 +65,16 @@ def pyunitTests(): 'pape_draft5', 'rpverify', 'extension', - ] + ] pyunit_modules = [ __import__('openid.test.test_%s' % (name,), {}, {}, ['unused']) for name in pyunit_module_names - ] + ] try: from openid.test import test_examples - except ImportError, e: + except ImportError as e: if 'twill' in str(e): warnings.warn("Could not import twill; skipping test_examples.") else: @@ -98,7 +100,7 @@ def pyunitTests(): 'test_urinorm', 'test_yadis_discover', 'trustroot', - ] + ] loader = unittest.TestLoader() s = unittest.TestSuite() @@ -110,18 +112,17 @@ def pyunitTests(): m = __import__('openid.test.%s' % (name,), {}, {}, ['unused']) try: s.addTest(m.pyUnitTests()) - except AttributeError, ex: + except AttributeError as ex: # because the AttributeError doesn't actually say which # object it was. print "Error loading tests from %s:" % (name,) raise - runner = unittest.TextTestRunner() # verbosity=2) + runner = unittest.TextTestRunner() # verbosity=2) return runner.run(s) - def splitDir(d, count): # in python2.4 and above, it's easier to spell this as # d.rsplit(os.sep, count) @@ -130,7 +131,6 @@ def splitDir(d, count): return d - def _import_djopenid(): """Import djopenid from examples/ @@ -153,7 +153,6 @@ def _import_djopenid(): sys.modules['djopenid'] = djopenid - def django_tests(): """Runs tests from examples/djopenid. @@ -167,11 +166,12 @@ def django_tests(): try: import django.test.simple - except ImportError, e: + except ImportError as e: warnings.warn("django.test.simple not found; " "django examples not tested.") return 0 - import djopenid.server.models, djopenid.consumer.models + import djopenid.server.models + import djopenid.consumer.models print "Testing Django examples:" # These tests do get put in to a pyunit test suite, so we could run them @@ -180,12 +180,14 @@ def django_tests(): return django.test.simple.run_tests([djopenid.server.models, djopenid.consumer.models]) + try: bool except NameError: def bool(x): return not not x + def main(): fixpath() other_failed = otherTests() @@ -200,5 +202,6 @@ def main(): (django_failures > 0)) return failed + if __name__ == '__main__': sys.exit(main() and 1 or 0) diff --git a/contrib/associate b/contrib/associate index 4cb05c3..76fe5b0 100755 --- a/contrib/associate +++ b/contrib/associate @@ -10,6 +10,7 @@ from openid.consumer.discover import OpenIDServiceEndpoint from datetime import datetime + def verboseAssociation(assoc): """A more verbose representation of an Association. """ @@ -24,6 +25,7 @@ def verboseAssociation(assoc): """ return fmt % d + def main(): if not sys.argv[1:]: print "Usage: %s ENDPOINT_URL..." % (sys.argv[0],) @@ -43,5 +45,6 @@ def main(): else: print " ...no association." + if __name__ == '__main__': main() diff --git a/contrib/openid-parse b/contrib/openid-parse index 21ab18d..ac2c5df 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -9,12 +9,16 @@ Requires the 'xsel' program to get the contents of the clipboard. from pprint import pformat from urlparse import urlsplit, urlunsplit -import cgi, re, subprocess, sys +import cgi +import re +import subprocess +import sys from openid import message OPENID_SORT_ORDER = ['mode', 'identity', 'claimed_id'] + class NoQuery(Exception): def __init__(self, url): self.url = url @@ -42,7 +46,7 @@ def main(): for url in urls: try: queries.append(queryFromURL(url)) - except NoQuery, err: + except NoQuery as err: errors.append(err) queries.extend(queriesFromLogs(source)) @@ -73,7 +77,7 @@ def openidFromQuery(query): try: msg = message.Message.fromPostArgs(unlistify(query)) s = formatOpenIDMessage(msg) - except Exception, err: + except Exception as err: # XXX - side effect. sys.stderr.write(str(err)) s = pformat(query) @@ -103,8 +107,7 @@ def formatOpenIDMessage(msg): except KeyError: pass - values = values.items() - values.sort() + values = sorted(values.items()) for k, v in values: ns_output.append(" %s = %s" % (k, v)) @@ -124,6 +127,7 @@ def queriesFromLogs(s): return [(match.group(1), cgi.parse_qs(match.group(2))) for match in qre.finditer(s)] + def queriesFromPostdata(s): # This looks for query data in a line that starts POSTDATA=. # Tamperdata outputs such lines. If there's a 'Host=' in that block, @@ -133,16 +137,18 @@ def queriesFromPostdata(s): return [(match.group('host') or 'POSTDATA', cgi.parse_qs(match.group('query'))) for match in qre.finditer(s)] + def find_urls(s): # Regular expression borrowed from urlscan # by Daniel Burrows <dburrows@debian.org>, GPL. - urlinternalpattern=r'[{}a-zA-Z/\-_0-9%?&.=:;+,#~]' - urltrailingpattern=r'[{}a-zA-Z/\-_0-9%&=+#]' + urlinternalpattern = r'[{}a-zA-Z/\-_0-9%?&.=:;+,#~]' + urltrailingpattern = r'[{}a-zA-Z/\-_0-9%&=+#]' httpurlpattern = r'(?:https?://' + urlinternalpattern + r'*' + urltrailingpattern + r')' # Used to guess that blah.blah.blah.TLD is a URL. - tlds=['biz', 'com', 'edu', 'info', 'org'] - guessedurlpattern=r'(?:[a-zA-Z0-9_\-%]+(?:\.[a-zA-Z0-9_\-%]+)*\.(?:' + '|'.join(tlds) + '))' - urlre = re.compile(r'(?:<(?:URL:)?)?(' + httpurlpattern + '|' + guessedurlpattern + '|(?:mailto:[a-zA-Z0-9\-_]*@[0-9a-zA-Z_\-.]*[0-9a-zA-Z_\-]))>?') + tlds = ['biz', 'com', 'edu', 'info', 'org'] + guessedurlpattern = r'(?:[a-zA-Z0-9_\-%]+(?:\.[a-zA-Z0-9_\-%]+)*\.(?:' + '|'.join(tlds) + '))' + urlre = re.compile(r'(?:<(?:URL:)?)?(' + httpurlpattern + '|' + guessedurlpattern + + '|(?:mailto:[a-zA-Z0-9\-_]*@[0-9a-zA-Z_\-.]*[0-9a-zA-Z_\-]))>?') return [match.group(1) for match in urlre.finditer(s)] diff --git a/contrib/upgrade-store-1.1-to-2.0 b/contrib/upgrade-store-1.1-to-2.0 index 1f587c3..1907ce3 100644 --- a/contrib/upgrade-store-1.1-to-2.0 +++ b/contrib/upgrade-store-1.1-to-2.0 @@ -23,15 +23,16 @@ from optparse import OptionParser def askForPassword(): return getpass.getpass("DB Password: ") -def askForConfirmation(dbname,tablename): + +def askForConfirmation(dbname, tablename): print """The table %s from the database %s will be dropped, and - an empty table with the new nonce table schema will replace it."""%( - tablename, dbname) + an empty table with the new nonce table schema will replace it.""" % (tablename, dbname) return raw_input("Continue? ").lower().strip().startswith('y') + def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url VARCHAR, @@ -39,13 +40,14 @@ def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): salt CHAR(40), UNIQUE(server_url, timestamp, salt) ); - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() + def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url BLOB, @@ -54,13 +56,14 @@ def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): PRIMARY KEY (server_url(255), timestamp, salt) ) TYPE=InnoDB; - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() + def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url VARCHAR(2047), @@ -68,11 +71,12 @@ def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): salt CHAR(40), PRIMARY KEY (server_url, timestamp, salt) ); - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() db_conn.commit() + def main(argv=None): parser = OptionParser() parser.add_option("-u", "--user", dest="username", @@ -106,7 +110,7 @@ def main(argv=None): return 1 try: db_conn = sqlite.connect(options.sqlite_db_name) - except Exception, e: + except Exception as e: print "Could not connect to SQLite database:", str(e) return 1 @@ -125,11 +129,11 @@ def main(argv=None): return 1 try: - db_conn = psycopg.connect(database = options.postgres_db_name, - user = options.username, - host = options.db_host, - password = password) - except Exception, e: + db_conn = psycopg.connect(database=options.postgres_db_name, + user=options.username, + host=options.db_host, + password=password) + except Exception as e: print "Could not connect to PostgreSQL database:", str(e) return 1 @@ -150,7 +154,7 @@ def main(argv=None): try: db_conn = MySQLdb.connect(options.db_host, options.username, password, options.mysql_db_name) - except Exception, e: + except Exception as e: print "Could not connect to MySQL database:", str(e) return 1 diff --git a/examples/consumer.py b/examples/consumer.py index 1d448de..908130a 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -35,6 +35,7 @@ For more information, see the README in the root of the library distribution.""") sys.exit(1) else: + del openid from openid.consumer import consumer from openid.cryptutil import randomString from openid.extensions import pape, sreg @@ -45,13 +46,14 @@ else: # Used with an OpenID provider affiliate program. OPENID_PROVIDER_NAME = 'MyOpenID' -OPENID_PROVIDER_URL ='https://www.myopenid.com/affiliate_signup?affiliate_id=39' +OPENID_PROVIDER_URL = 'https://www.myopenid.com/affiliate_signup?affiliate_id=39' class OpenIDHTTPServer(HTTPServer): """http server that contains a reference to an OpenID consumer and knows its base URL. """ + def __init__(self, store, *args, **kwargs): HTTPServer.__init__(self, *args, **kwargs) self.sessions = {} @@ -63,6 +65,7 @@ class OpenIDHTTPServer(HTTPServer): else: self.base_url = 'http://%s/' % (self.server_name,) + class OpenIDRequestHandler(BaseHTTPRequestHandler): """Request handler that knows how to verify an OpenID identity.""" SESSION_COOKIE_NAME = 'pyoidconsexsid' @@ -145,9 +148,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): else: self.notFound() - except (KeyboardInterrupt, SystemExit): - raise - except: + except Exception: self.send_response(500) self.send_header('Content-type', 'text/html') self.setSessionCookie() @@ -170,10 +171,10 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): use_pape = 'use_pape' in self.query use_stateless = 'use_stateless' in self.query - oidconsumer = self.getConsumer(stateless = use_stateless) + oidconsumer = self.getConsumer(stateless=use_stateless) try: request = oidconsumer.begin(openid_url) - except consumer.DiscoveryFailure, exc: + except consumer.DiscoveryFailure as exc: fetch_error_string = 'Error in discovery: %s' % ( cgi.escape(str(exc[0]))) self.render(fetch_error_string, @@ -207,7 +208,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): else: form_html = request.htmlMarkup( trust_root, return_to, - form_tag_attrs={'id':'openid_message'}, + form_tag_attrs={'id': 'openid_message'}, immediate=immediate) self.wfile.write(form_html) @@ -230,7 +231,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): # us. Status is a code indicating the response type. info is # either None or a string containing more information about # the return type. - url = 'http://'+self.headers.get('Host')+self.path + url = 'http://' + self.headers.get('Host') + self.path info = oidconsumer.complete(self.query, url) sreg_resp = None @@ -300,8 +301,7 @@ class OpenIDRequestHandler(BaseHTTPRequestHandler): self.wfile.write( '<div class="alert">No registration data was returned</div>') else: - sreg_list = sreg_data.items() - sreg_list.sort() + sreg_list = sorted(sreg_data.items()) self.wfile.write( '<h2>Registration Data</h2>' '<table class="sreg">' @@ -443,14 +443,17 @@ Content-type: text/html; charset=UTF-8 <input type="submit" value="Verify" /><br /> <input type="checkbox" name="immediate" id="immediate" /><label for="immediate">Use immediate mode</label> <input type="checkbox" name="use_sreg" id="use_sreg" /><label for="use_sreg">Request registration data</label> - <input type="checkbox" name="use_pape" id="use_pape" /><label for="use_pape">Request phishing-resistent auth policy (PAPE)</label> - <input type="checkbox" name="use_stateless" id="use_stateless" /><label for="use_stateless">Use stateless mode</label> + <input type="checkbox" name="use_pape" id="use_pape" /> + <label for="use_pape">Request phishing-resistent auth policy (PAPE)</label> + <input type="checkbox" name="use_stateless" id="use_stateless" /> + <label for="use_stateless">Use stateless mode</label> </form> </div> </body> </html> ''' % (quoteattr(self.buildURL('verify')), quoteattr(form_contents))) + def main(host, port, data_path, weak_ssl=False): # Instantiate OpenID consumer store and OpenID consumer. If you # were connecting to a database, you would create the database @@ -470,6 +473,7 @@ def main(host, port, data_path, weak_ssl=False): print server.base_url server.serve_forever() + if __name__ == '__main__': parser = optparse.OptionParser('Usage:\n %prog [options]') parser.add_option( diff --git a/examples/discover b/examples/discover index 9b74e8a..e2ede67 100644 --- a/examples/discover +++ b/examples/discover @@ -2,10 +2,11 @@ from openid.consumer.discover import discover, DiscoveryFailure from openid.fetchers import HTTPFetchingError -names = [["server_url", "Server URL "], - ["local_id", "Local ID "], +names = [["server_url", "Server URL "], + ["local_id", "Local ID "], ["canonicalID", "Canonical ID"], - ] + ] + def show_services(user_input, normalized, services): print " Claimed identifier:", normalized @@ -28,6 +29,7 @@ def show_services(user_input, normalized, services): print " No OpenID services found" print + if __name__ == "__main__": import sys @@ -36,10 +38,10 @@ if __name__ == "__main__": print "Running discovery on", user_input try: normalized, services = discover(user_input) - except DiscoveryFailure, why: + except DiscoveryFailure as why: print "Discovery failed:", why print - except HTTPFetchingError, why: + except HTTPFetchingError as why: print "HTTP request failed:", why print else: diff --git a/examples/djopenid/consumer/models.py b/examples/djopenid/consumer/models.py index 71a8362..b194906 100644 --- a/examples/djopenid/consumer/models.py +++ b/examples/djopenid/consumer/models.py @@ -1,3 +1 @@ -from django.db import models - -# Create your models here. +"""Required module for Django application.""" diff --git a/examples/djopenid/consumer/urls.py b/examples/djopenid/consumer/urls.py index d55e056..7190093 100644 --- a/examples/djopenid/consumer/urls.py +++ b/examples/djopenid/consumer/urls.py @@ -1,5 +1,4 @@ - -from django.conf.urls.defaults import * +from django.conf.urls.defaults import patterns urlpatterns = patterns( 'djopenid.consumer.views', diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index 1f4dd94..bbc0ff8 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -1,5 +1,3 @@ - -from django import http from django.http import HttpResponseRedirect from django.views.generic.simple import direct_to_template @@ -7,7 +5,7 @@ from openid.consumer import consumer from openid.consumer.discover import DiscoveryFailure from openid.extensions import ax, pape, sreg from openid.server.trustroot import RP_RETURN_TO_URL_TYPE -from openid.yadis.constants import YADIS_CONTENT_TYPE, YADIS_HEADER_NAME +from openid.yadis.constants import YADIS_HEADER_NAME from .. import util @@ -15,12 +13,13 @@ PAPE_POLICIES = [ 'AUTH_PHISHING_RESISTANT', 'AUTH_MULTI_FACTOR', 'AUTH_MULTI_FACTOR_PHYSICAL', - ] +] # List of (name, uri) for use in generating the request form. POLICY_PAIRS = [(p, getattr(pape, p)) for p in PAPE_POLICIES] + def getOpenIDStore(): """ Return an OpenID store object fit for the currently-chosen @@ -28,21 +27,24 @@ def getOpenIDStore(): """ return util.getOpenIDStore('/tmp/djopenid_c_store', 'c_') + def getConsumer(request): """ Get a Consumer object to perform OpenID authentication. """ return consumer.Consumer(request.session, getOpenIDStore()) + def renderIndexPage(request, **template_args): template_args['consumer_url'] = util.getViewURL(request, startOpenID) template_args['pape_policies'] = POLICY_PAIRS - response = direct_to_template( + response = direct_to_template( request, 'consumer/index.html', template_args) response[YADIS_HEADER_NAME] = util.getViewURL(request, rpXRDS) return response + def startOpenID(request): """ Start the OpenID authentication process. Renders an @@ -67,7 +69,7 @@ def startOpenID(request): try: auth_request = c.begin(openid_url) - except DiscoveryFailure, e: + except DiscoveryFailure as e: # Some other protocol-level failure occurred. error = "OpenID discovery error: %s" % (str(e),) @@ -133,6 +135,7 @@ def startOpenID(request): return renderIndexPage(request) + def finishOpenID(request): """ Finish the OpenID authentication process. Invoke the OpenID @@ -173,7 +176,7 @@ def finishOpenID(request): 'http://schema.openid.net/namePerson'), 'web': ax_response.get( 'http://schema.openid.net/contact/web/default'), - } + } # Get a PAPE response object if response information was # included in the OpenID response. @@ -197,7 +200,7 @@ def finishOpenID(request): 'sreg': sreg_response and sreg_response.items(), 'ax': ax_items.items(), 'pape': pape_response} - } + } result = results[response.status] @@ -210,6 +213,7 @@ def finishOpenID(request): return renderIndexPage(request, **result) + def rpXRDS(request): """ Return a relying party verification XRDS document diff --git a/examples/djopenid/manage.py b/examples/djopenid/manage.py index ae94958..45a1ee6 100644 --- a/examples/djopenid/manage.py +++ b/examples/djopenid/manage.py @@ -2,10 +2,12 @@ from django.core.management import execute_manager try: - import settings # Assumed to be in the same directory. + import settings # Assumed to be in the same directory. except ImportError: import sys - sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) + sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've " + "customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If " + "the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) sys.exit(1) if __name__ == "__main__": diff --git a/examples/djopenid/server/models.py b/examples/djopenid/server/models.py index 71a8362..b194906 100644 --- a/examples/djopenid/server/models.py +++ b/examples/djopenid/server/models.py @@ -1,3 +1 @@ -from django.db import models - -# Create your models here. +"""Required module for Django application.""" diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index d86151b..6cae547 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -19,6 +19,7 @@ def dummyRequest(): request.META['SERVER_PROTOCOL'] = 'HTTP' return request + class TestProcessTrustResult(TestCase): def setUp(self): self.request = dummyRequest() @@ -32,12 +33,11 @@ class TestProcessTrustResult(TestCase): 'openid.identity': id_url, 'openid.return_to': 'http://127.0.0.1/%s' % (self.id(),), 'openid.sreg.required': 'postcode', - }) + }) self.openid_request = CheckIDRequest.fromMessage(message, op_endpoint) views.setRequest(self.request, self.openid_request) - def test_allow(self): self.request.POST['allow'] = 'Yes' @@ -61,7 +61,6 @@ class TestProcessTrustResult(TestCase): self.failIf('openid.sreg.postcode=12345' in finalURL, finalURL) - class TestShowDecidePage(TestCase): def test_unreachableRealm(self): self.request = dummyRequest() @@ -75,7 +74,7 @@ class TestShowDecidePage(TestCase): 'openid.identity': id_url, 'openid.return_to': 'http://unreachable.invalid/%s' % (self.id(),), 'openid.sreg.required': 'postcode', - }) + }) self.openid_request = CheckIDRequest.fromMessage(message, op_endpoint) views.setRequest(self.request, self.openid_request) @@ -85,7 +84,6 @@ class TestShowDecidePage(TestCase): response) - class TestGenericXRDS(TestCase): def test_genericRender(self): """Render an XRDS document with a single type URI and a single endpoint URL diff --git a/examples/djopenid/server/urls.py b/examples/djopenid/server/urls.py index d6931a4..6763d85 100644 --- a/examples/djopenid/server/urls.py +++ b/examples/djopenid/server/urls.py @@ -1,5 +1,4 @@ - -from django.conf.urls.defaults import * +from django.conf.urls.defaults import patterns urlpatterns = patterns( 'djopenid.server.views', diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index bb6d660..bbb9468 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -23,7 +23,7 @@ from django.views.generic.simple import direct_to_template from openid.consumer.discover import OPENID_IDP_2_0_TYPE from openid.extensions import pape, sreg from openid.fetchers import HTTPFetchingError -from openid.server.server import CheckIDRequest, EncodingError, ProtocolError, Server +from openid.server.server import EncodingError, ProtocolError, Server from openid.server.trustroot import verifyReturnTo from openid.yadis.discover import DiscoveryFailure @@ -38,12 +38,14 @@ def getOpenIDStore(): """ return util.getOpenIDStore('/tmp/djopenid_s_store', 's_') + def getServer(request): """ Get a Server object to perform OpenID authentication. """ return Server(getOpenIDStore(), getViewURL(request, endpoint)) + def setRequest(request, openid_request): """ Store the openid request information in the session. @@ -53,12 +55,14 @@ def setRequest(request, openid_request): else: request.session['openid_request'] = None + def getRequest(request): """ Get an openid request from the session, if any. """ return request.session.get('openid_request') + def server(request): """ Respond to requests for the server's primary web page. @@ -70,6 +74,7 @@ def server(request): 'server_xrds_url': getViewURL(request, idpXrds), }) + def idpXrds(request): """ Respond to requests for the IDP's XRDS document, which is used in @@ -78,6 +83,7 @@ def idpXrds(request): return util.renderXRDS( request, [OPENID_IDP_2_0_TYPE], [getViewURL(request, endpoint)]) + def idPage(request): """ Serve the identity page for OpenID URLs. @@ -87,6 +93,7 @@ def idPage(request): 'server/idPage.html', {'server_url': getViewURL(request, endpoint)}) + def trustPage(request): """ Display the trust page template, which allows the user to decide @@ -95,7 +102,8 @@ def trustPage(request): return direct_to_template( request, 'server/trust.html', - {'trust_handler_url':getViewURL(request, processTrustResult)}) + {'trust_handler_url': getViewURL(request, processTrustResult)}) + def endpoint(request): """ @@ -109,7 +117,7 @@ def endpoint(request): # library can use. try: openid_request = s.decodeRequest(query) - except ProtocolError, why: + except ProtocolError as why: # This means the incoming request was invalid. return direct_to_template( request, @@ -134,6 +142,7 @@ def endpoint(request): openid_response = s.handleRequest(openid_request) return displayResponse(request, openid_response) + def handleCheckIDRequest(request, openid_request): """ Handle checkid_* requests. Get input from the user to find out @@ -175,6 +184,7 @@ def handleCheckIDRequest(request, openid_request): setRequest(request, openid_request) return showDecidePage(request, openid_request) + def showDecidePage(request, openid_request): """ Render a page to the user so a trust decision can be made. @@ -186,11 +196,10 @@ def showDecidePage(request, openid_request): try: # Stringify because template's ifequal can only compare to strings. - trust_root_valid = verifyReturnTo(trust_root, return_to) \ - and "Valid" or "Invalid" - except DiscoveryFailure, err: + trust_root_valid = verifyReturnTo(trust_root, return_to) and "Valid" or "Invalid" + except DiscoveryFailure: trust_root_valid = "DISCOVERY_FAILED" - except HTTPFetchingError, err: + except HTTPFetchingError: trust_root_valid = "Unreachable" pape_request = pape.Request.fromOpenIDRequest(openid_request) @@ -199,11 +208,12 @@ def showDecidePage(request, openid_request): request, 'server/trust.html', {'trust_root': trust_root, - 'trust_handler_url':getViewURL(request, processTrustResult), + 'trust_handler_url': getViewURL(request, processTrustResult), 'trust_root_valid': trust_root_valid, 'pape_request': pape_request, }) + def processTrustResult(request): """ Handle the result of a trust decision and respond to the RP @@ -236,7 +246,7 @@ def processTrustResult(request): 'country': 'ES', 'language': 'eu', 'timezone': 'America/New_York', - } + } sreg_req = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, sreg_data) @@ -248,6 +258,7 @@ def processTrustResult(request): return displayResponse(request, openid_response) + def displayResponse(request, openid_response): """ Display an OpenID response. Errors will be displayed directly to @@ -260,7 +271,7 @@ def displayResponse(request, openid_response): # Encode the response into something that is renderable. try: webresponse = s.encodeResponse(openid_response) - except EncodingError, why: + except EncodingError as why: # If it couldn't be encoded, display an error. text = why.response.encodeToKVForm() return direct_to_template( diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index f2a7c87..1ba3ff4 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -6,9 +6,11 @@ import warnings try: import openid -except ImportError, e: +except ImportError as e: warnings.warn("Could not import OpenID library. Please consult the djopenid README.") sys.exit(1) +else: + del openid DEBUG = True TEMPLATE_DEBUG = DEBUG @@ -21,7 +23,7 @@ MANAGERS = ADMINS DATABASES = { 'default': { - 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. + 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. 'NAME': '/tmp/test.db', # Or path to database file if using sqlite3. 'USER': '', # Not used with sqlite3. 'PASSWORD': '', # Not used with sqlite3. @@ -61,7 +63,7 @@ SECRET_KEY = 'u^bw6lmsa6fah0$^lz-ct$)y7x7#ag92-z+y45-8!(jk0lkavy' TEMPLATE_LOADERS = ( 'django.template.loaders.filesystem.Loader', 'django.template.loaders.app_directories.Loader', -# 'django.template.loaders.eggs.load_template_source', + # 'django.template.loaders.eggs.load_template_source', ) MIDDLEWARE_CLASSES = ( diff --git a/examples/djopenid/urls.py b/examples/djopenid/urls.py index d91ee1f..3783317 100644 --- a/examples/djopenid/urls.py +++ b/examples/djopenid/urls.py @@ -1,4 +1,4 @@ -from django.conf.urls.defaults import * +from django.conf.urls.defaults import include, patterns urlpatterns = patterns( '', diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index f06e11f..2847d8e 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -1,17 +1,12 @@ - """ Utility code for the Django example consumer and server. """ - from urlparse import urljoin -from django import http from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import reverse as reverseURL from django.db import connection -from django.template import loader -from django.template.context import RequestContext from django.views.generic.simple import direct_to_template from openid.store import sqlstore @@ -41,7 +36,7 @@ def getOpenIDStore(filestore_path, table_prefix): The result of this function should be passed to the Consumer constructor as the store parameter. """ - if not settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE'): + if not settings.DATABASES.get('default', {'ENGINE': None}).get('ENGINE'): return FileOpenIDStore(filestore_path) # Possible side-effect: create a database connection if one isn't @@ -52,27 +47,23 @@ def getOpenIDStore(filestore_path, table_prefix): tablenames = { 'associations_table': table_prefix + 'openid_associations', 'nonces_table': table_prefix + 'openid_nonces', - } + } types = { 'django.db.backends.postgresql': sqlstore.PostgreSQLStore, 'django.db.backends.mysql': sqlstore.MySQLStore, 'django.db.backends.sqlite3': sqlstore.SQLiteStore, - } + } + engine = settings.DATABASES.get('default', {'ENGINE': None}).get('ENGINE') try: - s = types[settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE')](connection.connection, - **tablenames) + s = types[engine](connection.connection, **tablenames) except KeyError: - raise ImproperlyConfigured, \ - "Database engine %s not supported by OpenID library" % \ - (settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE'),) + raise ImproperlyConfigured("Database engine %s not supported by OpenID library" % engine) try: s.createTables() - except (SystemExit, KeyboardInterrupt, MemoryError), e: - raise - except: + except Exception: # XXX This is not the Right Way to do this, but because the # underlying database implementation might differ in behavior # at this point, we can't reliably catch the right @@ -85,11 +76,13 @@ def getOpenIDStore(filestore_path, table_prefix): return s + def getViewURL(req, view_name_or_obj, args=None, kwargs=None): relative_url = reverseURL(view_name_or_obj, args=args, kwargs=kwargs) full_path = req.META.get('SCRIPT_NAME', '') + relative_url return urljoin(getBaseURL(req), full_path) + def getBaseURL(req): """ Given a Django web request object, returns the OpenID 'trust root' @@ -101,12 +94,12 @@ def getBaseURL(req): name = req.META['HTTP_HOST'] try: name = name[:name.index(':')] - except: + except Exception: pass try: port = int(req.META['SERVER_PORT']) - except: + except Exception: port = 80 proto = req.META['SERVER_PROTOCOL'] @@ -124,6 +117,7 @@ def getBaseURL(req): url = "%s://%s%s/" % (proto, name, port) return url + def normalDict(request_data): """ Converts a django request MutliValueDict (e.g., request.GET, @@ -135,6 +129,7 @@ def normalDict(request_data): """ return dict((k, v) for k, v in request_data.iteritems()) + def renderXRDS(request, type_uris, endpoint_urls): """Render an XRDS page with the specified type URIs and endpoint URLs in one service block, and return a response with the @@ -142,6 +137,6 @@ def renderXRDS(request, type_uris, endpoint_urls): """ response = direct_to_template( request, 'xrds.xml', - {'type_uris':type_uris, 'endpoint_urls':endpoint_urls,}) + {'type_uris': type_uris, 'endpoint_urls': endpoint_urls}) response['Content-Type'] = YADIS_CONTENT_TYPE return response diff --git a/examples/djopenid/views.py b/examples/djopenid/views.py index 3f08324..5d7a4e2 100644 --- a/examples/djopenid/views.py +++ b/examples/djopenid/views.py @@ -12,4 +12,4 @@ def index(request): return direct_to_template( request, 'index.html', - {'consumer_url':consumer_url, 'server_url':server_url}) + {'consumer_url': consumer_url, 'server_url': server_url}) diff --git a/examples/server.py b/examples/server.py index 0b12597..2da8835 100644 --- a/examples/server.py +++ b/examples/server.py @@ -30,6 +30,7 @@ For more information, see the README in the root of the library distribution.""") sys.exit(1) else: + del openid from openid.consumer import discover from openid.extensions import sreg from openid.server import server @@ -41,6 +42,7 @@ class OpenIDHTTPServer(HTTPServer): http server that contains a reference to an OpenID Server and knows its base URL. """ + def __init__(self, *args, **kwargs): HTTPServer.__init__(self, *args, **kwargs) @@ -63,7 +65,6 @@ class ServerHandler(BaseHTTPRequestHandler): self.user = None BaseHTTPRequestHandler.__init__(self, *args, **kwargs) - def do_GET(self): try: self.parsed_uri = urlparse(self.path) @@ -94,9 +95,7 @@ class ServerHandler(BaseHTTPRequestHandler): self.send_response(404) self.end_headers() - except (KeyboardInterrupt, SystemExit): - raise - except: + except Exception: self.send_response(500) self.send_header('Content-type', 'text/html') self.end_headers() @@ -124,9 +123,7 @@ class ServerHandler(BaseHTTPRequestHandler): self.send_response(404) self.end_headers() - except (KeyboardInterrupt, SystemExit): - raise - except: + except Exception: self.send_response(500) self.send_header('Content-type', 'text/html') self.end_headers() @@ -160,7 +157,6 @@ class ServerHandler(BaseHTTPRequestHandler): self.displayResponse(response) - def setUser(self): cookies = self.headers.get('Cookie') if cookies: @@ -181,7 +177,7 @@ class ServerHandler(BaseHTTPRequestHandler): def serverEndPoint(self, query): try: request = self.server.openid.decodeRequest(query) - except server.ProtocolError, why: + except server.ProtocolError as why: self.displayResponse(why) return @@ -203,8 +199,8 @@ class ServerHandler(BaseHTTPRequestHandler): # and the user should be asked for permission to release # it. sreg_data = { - 'nickname':self.user - } + 'nickname': self.user + } sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, sreg_data) response.addExtension(sreg_resp) @@ -229,7 +225,7 @@ class ServerHandler(BaseHTTPRequestHandler): def displayResponse(self, response): try: webresponse = self.server.openid.encodeResponse(response) - except server.EncodingError, why: + except server.EncodingError as why: text = why.response.encodeToKVForm() self.showErrorPage('<pre>%s</pre>' % cgi.escape(text)) return @@ -287,7 +283,7 @@ class ServerHandler(BaseHTTPRequestHandler): ('http://www.openidenabled.com/', 'An OpenID community Web site, home of this library'), ('http://www.openid.net/', 'the official OpenID Web site'), - ] + ] resource_markup = ''.join([term(url, text) for url, text in resources]) @@ -336,14 +332,14 @@ class ServerHandler(BaseHTTPRequestHandler): ''' % error_message) def showDecidePage(self, request): - id_url_base = self.server.base_url+'id/' + id_url_base = self.server.base_url + 'id/' # XXX: This may break if there are any synonyms for id_url_base, # such as referring to it by IP address or a CNAME. - assert (request.identity.startswith(id_url_base) or + assert (request.identity.startswith(id_url_base) or request.idSelect()), repr((request.identity, id_url_base)) expected_user = request.identity[len(id_url_base):] - if request.idSelect(): # We are being asked to select an ID + if request.idSelect(): # We are being asked to select an ID msg = '''\ <p>A site has asked for your identity. You may select an identifier by which you would like this site to know you. @@ -355,7 +351,7 @@ class ServerHandler(BaseHTTPRequestHandler): fdata = { 'id_url_base': id_url_base, 'trust_root': request.trust_root, - } + } form = '''\ <form method="POST" action="/allow"> <table> @@ -370,7 +366,7 @@ class ServerHandler(BaseHTTPRequestHandler): <input type="submit" name="yes" value="yes" /> <input type="submit" name="no" value="no" /> </form> - '''%fdata + ''' % fdata elif expected_user == self.user: msg = '''\ <p>A new site has asked to confirm your identity. If you @@ -382,7 +378,7 @@ class ServerHandler(BaseHTTPRequestHandler): fdata = { 'identity': request.identity, 'trust_root': request.trust_root, - } + } form = '''\ <table> <tr><td>Identity:</td><td>%(identity)s</td></tr> @@ -400,7 +396,7 @@ class ServerHandler(BaseHTTPRequestHandler): mdata = { 'expected_user': expected_user, 'user': self.user, - } + } msg = '''\ <p>A site has asked for an identity belonging to %(expected_user)s, but you are logged in as %(user)s. To @@ -412,7 +408,7 @@ class ServerHandler(BaseHTTPRequestHandler): 'identity': request.identity, 'trust_root': request.trust_root, 'expected_user': expected_user, - } + } form = '''\ <table> <tr><td>Identity:</td><td>%(identity)s</td></tr> @@ -432,9 +428,9 @@ class ServerHandler(BaseHTTPRequestHandler): def showIdPage(self, path): link_tag = '<link rel="openid.server" href="%sopenidserver">' %\ - self.server.base_url - yadis_loc_tag = '<meta http-equiv="x-xrds-location" content="%s">'%\ - (self.server.base_url+'yadis/'+path[4:]) + self.server.base_url + yadis_loc_tag = '<meta http-equiv="x-xrds-location" content="%s">' %\ + (self.server.base_url + 'yadis/' + path[4:]) disco_tags = link_tag + yadis_loc_tag ident = self.server.base_url + path[1:] @@ -480,8 +476,8 @@ class ServerHandler(BaseHTTPRequestHandler): </XRD> </xrds:XRDS> -"""%(discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, - endpoint_url, user_url)) +""" % (discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, + endpoint_url, user_url)) def showServerYadis(self): self.send_response(200) @@ -503,10 +499,10 @@ class ServerHandler(BaseHTTPRequestHandler): </XRD> </xrds:XRDS> -"""%(discover.OPENID_IDP_2_0_TYPE, endpoint_url,)) +""" % (discover.OPENID_IDP_2_0_TYPE, endpoint_url,)) def showMainPage(self): - yadis_tag = '<meta http-equiv="x-xrds-location" content="%s">'%\ + yadis_tag = '<meta http-equiv="x-xrds-location" content="%s">' %\ (self.server.base_url + 'serveryadis') if self.user: openid_url = self.server.base_url + 'id/' + self.user @@ -521,7 +517,7 @@ class ServerHandler(BaseHTTPRequestHandler): order to simulate a standard Web user experience. You are not <a href='/login'>logged in</a>.</p>""" - self.showPage(200, 'Main Page', head_extras = yadis_tag, msg='''\ + self.showPage(200, 'Main Page', head_extras=yadis_tag, msg='''\ <p>This is a simple OpenID server implemented using the <a href="http://openid.schtuff.com/">Python OpenID library</a>.</p> @@ -557,13 +553,14 @@ class ServerHandler(BaseHTTPRequestHandler): if self.user is None: user_link = '<a href="/login">not logged in</a>.' else: - user_link = 'logged in as <a href="/id/%s">%s</a>.<br /><a href="/loginsubmit?submit=true&success_to=/login">Log out</a>' % \ + user_link = 'logged in as <a href="/id/%s">%s</a>.<br />' \ + '<a href="/loginsubmit?submit=true&success_to=/login">Log out</a>' % \ (self.user, self.user) body = '' if err is not None: - body += '''\ + body += '''\ <div class="error"> %s </div> @@ -588,7 +585,7 @@ class ServerHandler(BaseHTTPRequestHandler): 'head_extras': head_extras, 'body': body, 'user_link': user_link, - } + } self.send_response(response_code) self.writeUserHeader() @@ -689,6 +686,7 @@ def main(host, port, data_path): print httpserver.base_url httpserver.serve_forever() + if __name__ == '__main__': parser = optparse.OptionParser('Usage:\n %prog [options]') parser.add_option( diff --git a/openid/__init__.py b/openid/__init__.py index 8ecb033..b172b30 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -41,7 +41,7 @@ __all__ = [ 'store', 'urinorm', 'yadis', - ] +] # Parse the version info try: diff --git a/openid/association.py b/openid/association.py index f9cc91e..8a52b78 100644 --- a/openid/association.py +++ b/openid/association.py @@ -30,7 +30,7 @@ __all__ = [ 'encrypted_negotiator', 'SessionNegotiator', 'Association', - ] +] import time @@ -40,7 +40,7 @@ from openid.message import OPENID_NS all_association_types = [ 'HMAC-SHA1', 'HMAC-SHA256', - ] +] if hasattr(cryptutil, 'hmacSha256'): supported_association_types = list(all_association_types) @@ -50,32 +50,34 @@ if hasattr(cryptutil, 'hmacSha256'): ('HMAC-SHA1', 'no-encryption'), ('HMAC-SHA256', 'DH-SHA256'), ('HMAC-SHA256', 'no-encryption'), - ] + ] only_encrypted_association_order = [ ('HMAC-SHA1', 'DH-SHA1'), ('HMAC-SHA256', 'DH-SHA256'), - ] + ] else: supported_association_types = ['HMAC-SHA1'] default_association_order = [ ('HMAC-SHA1', 'DH-SHA1'), ('HMAC-SHA1', 'no-encryption'), - ] + ] only_encrypted_association_order = [ ('HMAC-SHA1', 'DH-SHA1'), - ] + ] + def getSessionTypes(assoc_type): """Return the allowed session types for a given association type""" assoc_to_session = { 'HMAC-SHA1': ['DH-SHA1', 'no-encryption'], 'HMAC-SHA256': ['DH-SHA256', 'no-encryption'], - } + } return assoc_to_session.get(assoc_type, []) + def checkSessionType(assoc_type, session_type): """Check to make sure that this pair of assoc type and session type are allowed""" @@ -84,6 +86,7 @@ def checkSessionType(assoc_type, session_type): 'Session type %r not valid for assocation type %r' % (session_type, assoc_type)) + class SessionNegotiator(object): """A session negotiator controls the allowed and preferred association types and association session types. Both the @@ -166,7 +169,6 @@ class SessionNegotiator(object): checkSessionType(assoc_type, session_type) self.allowed_types.append((assoc_type, session_type)) - def isAllowed(self, assoc_type, session_type): """Is this combination of association type and session type allowed?""" assoc_good = (assoc_type, session_type) in self.allowed_types @@ -181,9 +183,11 @@ class SessionNegotiator(object): except IndexError: return (None, None) + default_negotiator = SessionNegotiator(default_association_order) encrypted_negotiator = SessionNegotiator(only_encrypted_association_order) + def getSecretSize(assoc_type): if assoc_type == 'HMAC-SHA1': return 20 @@ -192,6 +196,7 @@ def getSecretSize(assoc_type): else: raise ValueError('Unsupported association type: %r' % (assoc_type,)) + class Association(object): """ This class represents an association between a server and a @@ -247,14 +252,12 @@ class Association(object): 'issued', 'lifetime', 'assoc_type', - ] - + ] _macs = { 'HMAC-SHA1': cryptutil.hmacSha1, 'HMAC-SHA256': cryptutil.hmacSha256, - } - + } def fromExpiresIn(cls, expires_in, handle, secret, assoc_type): """ @@ -378,7 +381,7 @@ class Association(object): @rtype: C{bool} """ - return type(self) is type(other) and self.__dict__ == other.__dict__ + return type(self) == type(other) and self.__dict__ == other.__dict__ def __ne__(self, other): """ @@ -403,13 +406,13 @@ class Association(object): @rtype: str """ data = { - 'version':'2', - 'handle':self.handle, - 'secret':oidutil.toBase64(self.secret), - 'issued':str(int(self.issued)), - 'lifetime':str(int(self.lifetime)), - 'assoc_type':self.assoc_type - } + 'version': '2', + 'handle': self.handle, + 'secret': oidutil.toBase64(self.secret), + 'issued': str(int(self.issued)), + 'lifetime': str(int(self.lifetime)), + 'assoc_type': self.assoc_type + } assert len(data) == len(self.assoc_keys) pairs = [] @@ -476,7 +479,6 @@ class Association(object): return mac(self.secret, kv) - def getMessageSignature(self, message): """Return the signature of a message. @@ -499,8 +501,7 @@ class Association(object): @return: a new Message object with a signature @rtype: L{openid.message.Message} """ - if (message.hasKey(OPENID_NS, 'sig') or - message.hasKey(OPENID_NS, 'signed')): + if (message.hasKey(OPENID_NS, 'sig') or message.hasKey(OPENID_NS, 'signed')): raise ValueError('Message already has signed list or signature') extant_handle = message.getArg(OPENID_NS, 'assoc_handle') @@ -532,7 +533,6 @@ class Association(object): calculated_sig = self.getMessageSignature(message) return cryptutil.const_eq(calculated_sig, message_sig) - def _makePairs(self, message): signed = message.getArg(OPENID_NS, 'signed') if not signed: diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index eaa4847..c811ce0 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -251,7 +251,6 @@ def _httpResponseToMessage(response, server_url): return response_message - class Consumer(object): """An OpenID consumer implementation that performs discovery and does session management. @@ -338,7 +337,7 @@ class Consumer(object): disco = Discovery(self.session, user_url, self.session_key_prefix) try: service = disco.getNextService(self._discover) - except fetchers.HTTPFetchingError, why: + except fetchers.HTTPFetchingError as why: raise DiscoveryFailure( 'Error fetching XRDS document: %s' % (why[0],), None) @@ -374,7 +373,7 @@ class Consumer(object): try: auth_req.setAnonymous(anonymous) - except ValueError, why: + except ValueError as why: raise ProtocolError(str(why)) return auth_req @@ -414,8 +413,7 @@ class Consumer(object): except KeyError: pass - if (response.status in ['success', 'cancel'] and - response.identity_url is not None): + if (response.status in ['success', 'cancel'] and response.identity_url is not None): disco = Discovery(self.session, response.identity_url, @@ -448,6 +446,7 @@ class Consumer(object): """ self.consumer.negotiator = SessionNegotiator(association_preferences) + class DiffieHellmanSHA1ConsumerSession(object): session_type = 'DH-SHA1' hash_func = staticmethod(cryptutil.sha1) @@ -469,7 +468,7 @@ class DiffieHellmanSHA1ConsumerSession(object): args.update({ 'dh_modulus': cryptutil.longToBase64(self.dh.modulus), 'dh_gen': cryptutil.longToBase64(self.dh.generator), - }) + }) return args @@ -481,12 +480,14 @@ class DiffieHellmanSHA1ConsumerSession(object): enc_mac_key = oidutil.fromBase64(enc_mac_key64) return self.dh.xorSecret(dh_server_public, enc_mac_key, self.hash_func) + class DiffieHellmanSHA256ConsumerSession(DiffieHellmanSHA1ConsumerSession): session_type = 'DH-SHA256' hash_func = staticmethod(cryptutil.sha256) secret_size = 32 allowed_assoc_types = ['HMAC-SHA256'] + class PlainTextConsumerSession(object): session_type = 'no-encryption' allowed_assoc_types = ['HMAC-SHA1', 'HMAC-SHA256'] @@ -498,17 +499,21 @@ class PlainTextConsumerSession(object): mac_key64 = response.getArg(OPENID_NS, 'mac_key', no_default) return oidutil.fromBase64(mac_key64) + class SetupNeededError(Exception): """Internally-used exception that indicates that an immediate-mode request cancelled.""" + def __init__(self, user_setup_url=None): Exception.__init__(self, user_setup_url) self.user_setup_url = user_setup_url + class ProtocolError(ValueError): """Exception that indicates that a message violated the protocol. It is raised and caught internally to this file.""" + class TypeURIMismatch(ProtocolError): """A protocol error arising from type URIs mismatching """ @@ -525,7 +530,6 @@ class TypeURIMismatch(ProtocolError): return s - class ServerError(Exception): """Exception that is raised when the server returns a 400 response code to a direct request.""" @@ -546,6 +550,7 @@ class ServerError(Exception): fromMessage = classmethod(fromMessage) + class GenericConsumer(object): """This is the implementation of the common logic for OpenID consumers. It is unaware of the application in which it is @@ -573,10 +578,10 @@ class GenericConsumer(object): openid1_return_to_identifier_name = 'openid1_claimed_id' session_types = { - 'DH-SHA1':DiffieHellmanSHA1ConsumerSession, - 'DH-SHA256':DiffieHellmanSHA256ConsumerSession, - 'no-encryption':PlainTextConsumerSession, - } + 'DH-SHA1': DiffieHellmanSHA1ConsumerSession, + 'DH-SHA256': DiffieHellmanSHA256ConsumerSession, + 'no-encryption': PlainTextConsumerSession, + } _discover = staticmethod(discover) @@ -635,12 +640,12 @@ class GenericConsumer(object): def _complete_id_res(self, message, endpoint, return_to): try: self._checkSetupNeeded(message) - except SetupNeededError, why: + except SetupNeededError as why: return SetupNeededResponse(endpoint, why.user_setup_url) else: try: return self._doIdRes(message, endpoint, return_to) - except (ProtocolError, DiscoveryFailure), why: + except (ProtocolError, DiscoveryFailure) as why: return FailureResponse(endpoint, why[0]) def _completeInvalid(self, message, endpoint, _): @@ -657,7 +662,7 @@ class GenericConsumer(object): # message. try: self._verifyReturnToArgs(message.toPostArgs()) - except ProtocolError, why: + except ProtocolError as why: _LOGGER.exception("Verifying return_to arguments: %s", why) return False @@ -721,7 +726,6 @@ class GenericConsumer(object): "return_to does not match return URL. Expected %r, got %r" % (return_to, message.getArg(OPENID_NS, 'return_to'))) - # Verify discovery information: endpoint = self._verifyDiscoveryResults(message, endpoint) _LOGGER.info("Received id_res response from %s using association %s", @@ -763,11 +767,10 @@ class GenericConsumer(object): try: timestamp, salt = splitNonce(nonce) - except ValueError, why: + except ValueError as why: raise ProtocolError('Malformed nonce: %s' % (why[0],)) - if (self.store is not None and - not self.store.useNonce(server_url, timestamp, salt)): + if (self.store is not None and not self.store.useNonce(server_url, timestamp, salt)): raise ProtocolError('Nonce already used or out of range') def _idResCheckSignature(self, message, server_url): @@ -811,15 +814,12 @@ class GenericConsumer(object): require_fields = { OPENID2_NS: basic_fields + ['op_endpoint'], OPENID1_NS: basic_fields + ['identity'], - } + } require_sigs = { - OPENID2_NS: basic_sig_fields + ['response_nonce', - 'claimed_id', - 'assoc_handle', - 'op_endpoint',], + OPENID2_NS: basic_sig_fields + ['response_nonce', 'claimed_id', 'assoc_handle', 'op_endpoint'], OPENID1_NS: basic_sig_fields, - } + } for field in require_fields[message.getOpenIDNamespace()]: if not message.hasKey(OPENID_NS, field): @@ -833,7 +833,6 @@ class GenericConsumer(object): if message.hasKey(OPENID_NS, field) and field not in signed_list: raise ProtocolError('"%s" not signed' % (field,)) - def _verifyReturnToArgs(query): """Verify that the arguments in the return_to URL are present in this response. @@ -883,7 +882,6 @@ class GenericConsumer(object): else: return self._verifyDiscoveryResultsOpenID1(resp_msg, endpoint) - def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): to_match = OpenIDServiceEndpoint() to_match.type_uris = [OPENID_2_0_TYPE] @@ -896,8 +894,7 @@ class GenericConsumer(object): # claimed_id and identifier must both be present or both # be absent - if (to_match.claimed_id is None and - to_match.local_id is not None): + if (to_match.claimed_id is None and to_match.local_id is not None): raise ProtocolError( 'openid.identity is present without openid.claimed_id') @@ -925,7 +922,7 @@ class GenericConsumer(object): # case. try: self._verifyDiscoverySingle(endpoint, to_match) - except ProtocolError, e: + except ProtocolError as e: _LOGGER.exception("Error attempting to use stored discovery information: %s", e) _LOGGER.info("Attempting discovery to verify endpoint") endpoint = self._discoverAndVerify( @@ -968,7 +965,7 @@ class GenericConsumer(object): self._verifyDiscoverySingle(endpoint, to_match) except TypeURIMismatch: self._verifyDiscoverySingle(endpoint, to_match_1_0) - except ProtocolError, e: + except ProtocolError as e: _LOGGER.exception("Error attempting to use stored discovery information: %s", e) _LOGGER.info("Attempting discovery to verify endpoint") else: @@ -1048,7 +1045,6 @@ class GenericConsumer(object): return self._verifyDiscoveredServices(claimed_id, services, to_match_endpoints) - def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): """See @L{_discoverAndVerify}""" @@ -1060,7 +1056,7 @@ class GenericConsumer(object): try: self._verifyDiscoverySingle( endpoint, to_match_endpoint) - except ProtocolError, why: + except ProtocolError as why: failure_messages.append(str(why)) else: # It matches, so discover verification has @@ -1087,7 +1083,7 @@ class GenericConsumer(object): return False try: response = self._makeKVPost(request, server_url) - except (fetchers.HTTPFetchingError, ServerError), e: + except (fetchers.HTTPFetchingError, ServerError) as e: _LOGGER.exception('check_authentication failed: %s', e) return False else: @@ -1167,7 +1163,7 @@ class GenericConsumer(object): try: assoc = self._requestAssociation( endpoint, assoc_type, session_type) - except ServerError, why: + except ServerError as why: supportedTypes = self._extractSupportedAssociationType(why, endpoint, assoc_type) @@ -1179,7 +1175,7 @@ class GenericConsumer(object): try: assoc = self._requestAssociation( endpoint, assoc_type, session_type) - except ServerError, why: + except ServerError as why: # Do not keep trying, since it rejected the # association type that it told us to use. _LOGGER.error('Server %s refused its suggested association type: session_type=%s, assoc_type=%s', @@ -1201,8 +1197,7 @@ class GenericConsumer(object): """ # Any error message whose code is not 'unsupported-type' # should be considered a total failure. - if server_error.error_code != 'unsupported-type' or \ - server_error.message.isOpenID1(): + if server_error.error_code != 'unsupported-type' or server_error.message.isOpenID1(): _LOGGER.error('Server error when requesting an association from %r: %s', endpoint.server_url, server_error.error_text) return None @@ -1227,7 +1222,6 @@ class GenericConsumer(object): else: return assoc_type, session_type - def _requestAssociation(self, endpoint, assoc_type, session_type): """Make and process one association request to this endpoint's OP endpoint URL. @@ -1242,16 +1236,16 @@ class GenericConsumer(object): try: response = self._makeKVPost(args, endpoint.server_url) - except fetchers.HTTPFetchingError, why: + except fetchers.HTTPFetchingError as why: _LOGGER.exception('openid.associate request failed: %s', why) return None try: assoc = self._extractAssociation(response, assoc_session) - except KeyError, why: + except KeyError as why: _LOGGER.exception('Missing required parameter in response from %s: %s', endpoint.server_url, why) return None - except ProtocolError, why: + except ProtocolError as why: _LOGGER.exception('Protocol error parsing response from %s: %s', endpoint.server_url, why) return None else: @@ -1287,15 +1281,14 @@ class GenericConsumer(object): args = { 'mode': 'associate', 'assoc_type': assoc_type, - } + } if not endpoint.compatibilityMode(): args['ns'] = OPENID2_NS # Leave out the session type if we're in compatibility mode # *and* it's no-encryption. - if (not endpoint.compatibilityMode() or - assoc_session.session_type != 'no-encryption'): + if (not endpoint.compatibilityMode() or assoc_session.session_type != 'no-encryption'): args['session_type'] = assoc_session.session_type args.update(assoc_session.getRequest()) @@ -1372,7 +1365,7 @@ class GenericConsumer(object): OPENID_NS, 'expires_in', no_default) try: expires_in = int(expires_in_str) - except ValueError, why: + except ValueError as why: raise ProtocolError('Invalid expires_in field: %s' % (why[0],)) # OpenID 1 has funny association session behaviour. @@ -1384,8 +1377,7 @@ class GenericConsumer(object): # Session type mismatch if assoc_session.session_type != session_type: - if (assoc_response.isOpenID1() and - session_type == 'no-encryption'): + if (assoc_response.isOpenID1() and session_type == 'no-encryption'): # In OpenID 1, any association request can result in a # 'no-encryption' association response. Setting # assoc_session to a new no-encryption session should @@ -1410,13 +1402,14 @@ class GenericConsumer(object): # type. try: secret = assoc_session.extractSecret(assoc_response) - except ValueError, why: + except ValueError as why: fmt = 'Malformed response for %s session: %s' raise ProtocolError(fmt % (assoc_session.session_type, why[0])) return Association.fromExpiresIn( expires_in, assoc_handle, secret, assoc_type) + class AuthRequest(object): """An object that holds the state necessary for generating an OpenID authentication request. This object holds the association @@ -1550,11 +1543,7 @@ class AuthRequest(object): realm_key = 'realm' message.updateArgs(OPENID_NS, - { - realm_key:realm, - 'mode':mode, - 'return_to':return_to, - }) + {realm_key: realm, 'mode': mode, 'return_to': return_to}) if not self._anonymous: if self.endpoint.isOPIdentifier(): @@ -1623,8 +1612,7 @@ class AuthRequest(object): message = self.getMessage(realm, return_to, immediate) return message.toURL(self.endpoint.server_url) - def formMarkup(self, realm, return_to=None, immediate=False, - form_tag_attrs=None): + def formMarkup(self, realm, return_to=None, immediate=False, form_tag_attrs=None): """Get html for a form to submit this request to the IDP. @param form_tag_attrs: Dictionary of attributes to be added to @@ -1634,11 +1622,9 @@ class AuthRequest(object): @type form_tag_attrs: {unicode: unicode} """ message = self.getMessage(realm, return_to, immediate) - return message.toFormMarkup(self.endpoint.server_url, - form_tag_attrs) + return message.toFormMarkup(self.endpoint.server_url, form_tag_attrs) - def htmlMarkup(self, realm, return_to=None, immediate=False, - form_tag_attrs=None): + def htmlMarkup(self, realm, return_to=None, immediate=False, form_tag_attrs=None): """Get an autosubmitting HTML page that submits this request to the IDP. This is just a wrapper for formMarkup. @@ -1646,10 +1632,7 @@ class AuthRequest(object): @returns: str """ - return oidutil.autoSubmitHTML(self.formMarkup(realm, - return_to, - immediate, - form_tag_attrs)) + return oidutil.autoSubmitHTML(self.formMarkup(realm, return_to, immediate, form_tag_attrs)) def shouldSendRedirect(self): """Should this OpenID authentication request be sent as a HTTP @@ -1659,11 +1642,13 @@ class AuthRequest(object): """ return self.endpoint.compatibilityMode() + FAILURE = 'failure' SUCCESS = 'success' CANCEL = 'cancel' SETUP_NEEDED = 'setup_needed' + class Response(object): status = None @@ -1694,6 +1679,7 @@ class Response(object): return self.endpoint.getDisplayIdentifier() return None + class SuccessResponse(Response): """A response with a status of SUCCESS. Indicates that this request is a successful acknowledgement from the OpenID server that the @@ -1854,6 +1840,7 @@ class CancelResponse(Response): def __init__(self, endpoint): self.setEndpoint(endpoint) + class SetupNeededResponse(Response): """A response with a status of SETUP_NEEDED. Indicates that the request was in immediate mode, and the server is unable to diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index 5764dc5..f847c63 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -11,12 +11,12 @@ __all__ = [ 'OPENID_IDP_2_0_TYPE', 'OpenIDServiceEndpoint', 'discover', - ] +] import logging import urlparse -from openid import fetchers, urinorm, yadis +from openid import fetchers, urinorm from openid.consumer import html_parse from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS, OPENID2_NS as OPENID_2_0_MESSAGE_NS from openid.yadis import filters, xri, xrires @@ -48,7 +48,7 @@ class OpenIDServiceEndpoint(object): OPENID_2_0_TYPE, OPENID_1_1_TYPE, OPENID_1_0_TYPE, - ] + ] def __init__(self): self.claimed_id = None @@ -56,15 +56,14 @@ class OpenIDServiceEndpoint(object): self.type_uris = [] self.local_id = None self.canonicalID = None - self.used_yadis = False # whether this came from an XRDS + self.used_yadis = False # whether this came from an XRDS self.display_identifier = None def usesExtension(self, extension_uri): return extension_uri in self.type_uris def preferredNamespace(self): - if (OPENID_IDP_2_0_TYPE in self.type_uris or - OPENID_2_0_TYPE in self.type_uris): + if (OPENID_IDP_2_0_TYPE in self.type_uris or OPENID_2_0_TYPE in self.type_uris): return OPENID_2_0_MESSAGE_NS else: return OPENID_1_0_MESSAGE_NS @@ -74,10 +73,7 @@ class OpenIDServiceEndpoint(object): I consider C{/server} endpoints to implicitly support C{/signon}. """ - return ( - (type_uri in self.type_uris) or - (type_uri == OPENID_2_0_TYPE and self.isOPIdentifier()) - ) + return ((type_uri in self.type_uris) or (type_uri == OPENID_2_0_TYPE and self.isOPIdentifier())) def getDisplayIdentifier(self): """Return the display_identifier if set, else return the claimed_id. @@ -155,7 +151,7 @@ class OpenIDServiceEndpoint(object): discovery_types = [ (OPENID_2_0_TYPE, 'openid2.provider', 'openid2.local_id'), (OPENID_1_1_TYPE, 'openid.server', 'openid.delegate'), - ] + ] link_attrs = html_parse.parseLinkAttrs(html) services = [] @@ -178,7 +174,6 @@ class OpenIDServiceEndpoint(object): fromHTML = classmethod(fromHTML) - def fromXRDS(cls, uri, xrds): """Parse the given document as XRDS looking for OpenID services. @@ -192,7 +187,6 @@ class OpenIDServiceEndpoint(object): fromXRDS = classmethod(fromXRDS) - def fromDiscoveryResult(cls, discoveryResult): """Create endpoints from a DiscoveryResult. @@ -213,7 +207,6 @@ class OpenIDServiceEndpoint(object): fromDiscoveryResult = classmethod(fromDiscoveryResult) - def fromOPEndpointURL(cls, op_endpoint_url): """Construct an OP-Identifier OpenIDServiceEndpoint object for a given OP Endpoint URL @@ -228,7 +221,6 @@ class OpenIDServiceEndpoint(object): fromOPEndpointURL = classmethod(fromOPEndpointURL) - def __str__(self): return ("<%s.%s " "server_url=%r " @@ -237,7 +229,7 @@ class OpenIDServiceEndpoint(object): "canonicalID=%r " "used_yadis=%s " ">" - % (self.__class__.__module__, self.__class__.__name__, + % (self.__class__.__module__, self.__class__.__name__, self.server_url, self.claimed_id, self.local_id, @@ -245,7 +237,6 @@ class OpenIDServiceEndpoint(object): self.used_yadis)) - def findOPLocalIdentifier(service_element, type_uris): """Find the OP-Local Identifier for this xrd:Service element. @@ -275,8 +266,7 @@ def findOPLocalIdentifier(service_element, type_uris): # Build the list of tags that could contain the OP-Local Identifier local_id_tags = [] - if (OPENID_1_1_TYPE in type_uris or - OPENID_1_0_TYPE in type_uris): + if (OPENID_1_1_TYPE in type_uris or OPENID_1_0_TYPE in type_uris): local_id_tags.append(nsTag(OPENID_1_0_NS, 'Delegate')) if OPENID_2_0_TYPE in type_uris: @@ -296,22 +286,25 @@ def findOPLocalIdentifier(service_element, type_uris): return local_id + def normalizeURL(url): """Normalize a URL, converting normalization failures to DiscoveryFailure""" try: normalized = urinorm.urinorm(url) - except ValueError, why: + except ValueError as why: raise DiscoveryFailure('Normalizing identifier: %s' % (why[0],), None) else: return urlparse.urldefrag(normalized)[0] + def normalizeXRI(xri): """Normalize an XRI, stripping its scheme if present""" if xri.startswith("xri://"): xri = xri[6:] return xri + def arrangeByType(service_list, preferred_types): """Rearrange service_list in a new list so services are ordered by types listed in preferred_types. Return the new list.""" @@ -333,9 +326,7 @@ def arrangeByType(service_list, preferred_types): # Build a list with the service elements in tuples whose # comparison will prefer the one with the best matching service - prio_services = [(bestMatchingService(s), orig_index, s) - for (orig_index, s) in enumerate(service_list)] - prio_services.sort() + prio_services = sorted((bestMatchingService(s), orig_index, s) for (orig_index, s) in enumerate(service_list)) # Now that the services are sorted by priority, remove the sort # keys from the list. @@ -344,6 +335,7 @@ def arrangeByType(service_list, preferred_types): return prio_services + def getOPOrUserServices(openid_services): """Extract OP Identifier services. If none found, return the rest, sorted with most preferred first according to @@ -360,6 +352,7 @@ def getOPOrUserServices(openid_services): return op_services or openid_services + def discoverYadis(uri): """Discover OpenID services for a URI. Tries Yadis and falls back on old-style <link rel='...'> discovery if Yadis fails. @@ -401,6 +394,7 @@ def discoverYadis(uri): return (yadis_url, getOPOrUserServices(openid_services)) + def discoverXRI(iname): endpoints = [] iname = normalizeXRI(iname) @@ -440,6 +434,7 @@ def discoverNoYadis(uri): claimed_id, http_resp.body) return claimed_id, openid_services + def discoverURI(uri): parsed = urlparse.urlparse(uri) if parsed[0] and parsed[1]: @@ -453,6 +448,7 @@ def discoverURI(uri): claimed_id = normalizeURL(claimed_id) return claimed_id, openid_services + def discover(identifier): if xri.identifierScheme(identifier) == "XRI": return discoverXRI(identifier) diff --git a/openid/consumer/html_parse.py b/openid/consumer/html_parse.py index 880dfda..14ff8cc 100644 --- a/openid/consumer/html_parse.py +++ b/openid/consumer/html_parse.py @@ -70,12 +70,17 @@ The parser deals with invalid markup in these ways: __all__ = ['parseLinkAttrs'] import re - -flags = ( re.DOTALL # Match newlines with '.' - | re.IGNORECASE - | re.VERBOSE # Allow comments and whitespace in patterns - | re.UNICODE # Make \b respect Unicode word boundaries - ) +from functools import partial + +flags = ( + # Match newlines with '.' + re.DOTALL | + re.IGNORECASE | + # Allow comments and whitespace in patterns + re.VERBOSE | + # Make \b respect Unicode word boundaries + re.UNICODE +) # Stuff to remove before we start looking for tags removed_re = re.compile(r''' @@ -123,6 +128,7 @@ tag_expr = r''' ) ''' + def tagMatcher(tag_name, *close_tags): if close_tags: options = '|'.join((tag_name,) + close_tags) @@ -133,6 +139,7 @@ def tagMatcher(tag_name, *close_tags): expr = tag_expr % locals() return re.compile(expr, flags) + # Must contain at least an open html and an open head tag html_find = tagMatcher('html') head_find = tagMatcher('head', 'body') @@ -160,17 +167,20 @@ attr_find = re.compile(r''' # Entity replacement: replacements = { - 'amp':'&', - 'lt':'<', - 'gt':'>', - 'quot':'"', - } + 'amp': '&', + 'lt': '<', + 'gt': '>', + 'quot': '"', +} ent_replace = re.compile(r'&(%s);' % '|'.join(replacements.keys())) + + def replaceEnt(mo): "Replace the entities that are specified by OpenID" return replacements.get(mo.group(1), mo.group()) + def parseLinkAttrs(html): """Find all link tags in a string representing a HTML document and return a list of their attributes. @@ -214,6 +224,7 @@ def parseLinkAttrs(html): return matches + def relMatches(rel_attr, target_rel): """Does this target_rel appear in the rel_str?""" # XXX: TESTME @@ -225,19 +236,22 @@ def relMatches(rel_attr, target_rel): return 0 + def linkHasRel(link_attrs, target_rel): """Does this link have target_rel as a relationship?""" # XXX: TESTME rel_attr = link_attrs.get('rel') return rel_attr and relMatches(rel_attr, target_rel) + def findLinksRel(link_attrs_list, target_rel): """Filter the list of link attributes on whether it has target_rel as a relationship.""" # XXX: TESTME - matchesTarget = lambda attrs: linkHasRel(attrs, target_rel) + matchesTarget = partial(linkHasRel, target_rel=target_rel) return filter(matchesTarget, link_attrs_list) + def findFirstHref(link_attrs_list, target_rel): """Return the value of the href attribute for the first link tag in the list that has target_rel as a relationship.""" diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 769aa6c..27ff796 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -21,7 +21,7 @@ __all__ = [ 'randrange', 'sha1', 'sha256', - ] +] import hashlib import hmac @@ -36,18 +36,23 @@ class HashContainer(object): self.new = hash_constructor self.digest_size = hash_constructor().digest_size + sha1_module = HashContainer(hashlib.sha1) sha256_module = HashContainer(hashlib.sha256) + def hmacSha1(key, text): return hmac.new(key, text, sha1_module).digest() + def sha1(s): return sha1_module.new(s).digest() + def hmacSha256(key, text): return hmac.new(key, text, sha256_module).digest() + def sha256(s): return sha256_module.new(s).digest() @@ -57,22 +62,22 @@ try: except ImportError: import pickle - def longToBinary(l): - if l == 0: + def longToBinary(value): + if value == 0: return '\x00' - return ''.join(reversed(pickle.encode_long(l))) + return ''.join(reversed(pickle.encode_long(value))) def binaryToLong(s): return pickle.decode_long(''.join(reversed(s))) else: # We have pycrypto - def longToBinary(l): - if l < 0: + def longToBinary(value): + if value < 0: raise ValueError('This function only supports positive integers') - bytes = long_to_bytes(l) + bytes = long_to_bytes(value) if ord(bytes[0]) > 127: return '\x00' + bytes else: @@ -112,6 +117,7 @@ except AttributeError: return ''.join(bytes) else: _pool = RandomPool() + def getBytes(n, pool=_pool): if pool.entropy < n: pool.randomize() @@ -125,9 +131,9 @@ except AttributeError: # numbers larger than sys.maxint for randrange. For simplicity, # use this implementation for any Python that does not have # random.SystemRandom - from math import log, ceil _duplicate_cache = {} + def randrange(start, stop=None, step=1): if stop is None: stop = start @@ -154,7 +160,7 @@ except AttributeError: _duplicate_cache[r] = (duplicate, nbytes) - while 1: + while True: bytes = '\x00' + getBytes(nbytes) n = binaryToLong(bytes) # Keep looping if this value is in the low duplicated range @@ -163,12 +169,15 @@ except AttributeError: return start + (n % r) * step + def longToBase64(l): return toBase64(longToBinary(l)) + def base64ToLong(s): return binaryToLong(fromBase64(s)) + def randomString(length, chrs=None): """Produce a string of length random bytes, chosen from chrs.""" if chrs is None: @@ -177,6 +186,7 @@ def randomString(length, chrs=None): n = len(chrs) return ''.join([chrs[randrange(n)] for _ in xrange(length)]) + def const_eq(s1, s2): if len(s1) != len(s2): return False diff --git a/openid/dh.py b/openid/dh.py index 3478240..b0400b9 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,15 +1,23 @@ -from openid import cryptutil, oidutil +from openid import cryptutil + + +def _xor(a_b): + a, b = a_b + return chr(ord(a) ^ ord(b)) def strxor(x, y): if len(x) != len(y): raise ValueError('Inputs to strxor must have the same length') - xor = lambda (a, b): chr(ord(a) ^ ord(b)) - return "".join(map(xor, zip(x, y))) + return "".join(map(_xor, zip(x, y))) + class DiffieHellman(object): - DEFAULT_MOD = 155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848253359305585439638443L + DEFAULT_MOD = int('155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698' + '188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681' + '476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848' + '253359305585439638443') DEFAULT_GEN = 2 diff --git a/openid/extension.py b/openid/extension.py index 6366f03..55e129b 100644 --- a/openid/extension.py +++ b/openid/extension.py @@ -1,3 +1,5 @@ +import warnings + from openid import message as message_module diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 6b21812..c8fac3f 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -5,12 +5,12 @@ """ __all__ = [ - 'AttributeRequest', + 'AttrInfo', 'FetchRequest', 'FetchResponse', 'StoreRequest', 'StoreResponse', - ] +] from openid import extension from openid.message import OPENID_NS, NamespaceMap @@ -24,6 +24,7 @@ UNLIMITED_VALUES = "unlimited" # completeness. MINIMUM_SUPPORTED_ALIAS_LENGTH = 32 + def checkAlias(alias): """ Check an alias for invalid characters; raise AXError if any are @@ -60,11 +61,6 @@ class AXMessage(extension.Extension): be overridden in subclasses. """ - # This class is abstract, so it's OK that it doesn't override the - # abstract method in Extension: - # - #pylint:disable-msg=W0223 - ns_alias = 'ax' mode = None ns_uri = 'http://openid.net/srv/ax/1.0' @@ -90,7 +86,7 @@ class AXMessage(extension.Extension): basic information that must be in every attribute exchange message. """ - return {'mode':self.mode} + return {'mode': self.mode} class AttrInfo(object): @@ -122,11 +118,6 @@ class AttrInfo(object): @type alias: str or NoneType """ - # It's OK that this class doesn't have public methods (it's just a - # holder for a bunch of attributes): - # - #pylint:disable-msg=R0903 - def __init__(self, type_uri, count=1, required=False, alias=None): self.required = required self.count = count @@ -146,6 +137,7 @@ class AttrInfo(object): """ return self.count == UNLIMITED_VALUES + def toTypeURIs(namespace_map, alias_list_s): """Given a namespace mapping and a string containing a comma-separated list of namespace aliases, return a list of type @@ -304,7 +296,7 @@ class FetchRequest(AXMessage): self = cls() try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None if self.update_url: @@ -413,11 +405,6 @@ class AXKeyValueMessage(AXMessage): fetch_response and store_request. """ - # This class is abstract, so it's OK that it doesn't override the - # abstract method in Extension: - # - #pylint:disable-msg=W0223 - def __init__(self): AXMessage.__init__(self) self.data = {} @@ -652,8 +639,7 @@ class FetchResponse(AXKeyValueMessage): values = [] zero_value_types.append(attr_info) - if (attr_info.count != UNLIMITED_VALUES) and \ - (attr_info.count < len(values)): + if (attr_info.count != UNLIMITED_VALUES) and (attr_info.count < len(values)): raise AXError( 'More than the number of requested values were ' 'specified for %r' % (attr_info.type_uri,)) @@ -671,8 +657,7 @@ class FetchResponse(AXKeyValueMessage): kv_args['type.' + alias] = attr_info.type_uri kv_args['count.' + alias] = '0' - update_url = ((self.request and self.request.update_url) - or self.update_url) + update_url = ((self.request and self.request.update_url) or self.update_url) if update_url: ax_args['update_url'] = update_url @@ -709,7 +694,7 @@ class FetchResponse(AXKeyValueMessage): try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None else: return self @@ -762,7 +747,7 @@ class StoreRequest(AXKeyValueMessage): self = cls() try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None return self @@ -782,8 +767,7 @@ class StoreResponse(AXMessage): AXMessage.__init__(self) if succeeded and error_message is not None: - raise AXError('An error message may only be included in a ' - 'failing fetch response') + raise AXError('An error message may only be included in a failing fetch response') if succeeded: self.mode = self.SUCCESS_MODE else: @@ -826,7 +810,7 @@ class StoreResponse(AXMessage): try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None else: return self diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index b800ce2..f9b84c8 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -13,7 +13,7 @@ __all__ = [ 'AUTH_PHISHING_RESISTANT', 'AUTH_MULTI_FACTOR', 'AUTH_MULTI_FACTOR_PHYSICAL', - ] +] import re @@ -30,6 +30,7 @@ AUTH_PHISHING_RESISTANT = \ TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') + class Request(Extension): """A Provider Authentication Policy request, sent from a relying party to a provider @@ -75,8 +76,8 @@ class Request(Extension): """@see: C{L{Extension.getExtensionArgs}} """ ns_args = { - 'preferred_auth_policies':' '.join(self.preferred_auth_policies) - } + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies) + } if self.max_auth_age is not None: ns_args['max_auth_age'] = str(self.max_auth_age) @@ -148,6 +149,7 @@ class Request(Extension): return filter(self.preferred_auth_policies.__contains__, supported_types) + Request.ns_uri = ns_uri @@ -254,12 +256,12 @@ class Response(Extension): """ if len(self.auth_policies) == 0: ns_args = { - 'auth_policies':'none', + 'auth_policies': 'none', } else: ns_args = { - 'auth_policies':' '.join(self.auth_policies), - } + 'auth_policies': ' '.join(self.auth_policies), + } if self.nist_auth_level is not None: if self.nist_auth_level not in range(0, 5): @@ -275,4 +277,5 @@ class Response(Extension): return ns_args + Response.ns_uri = ns_uri diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index e146873..6d0b1dd 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -15,7 +15,7 @@ __all__ = [ 'AUTH_MULTI_FACTOR_PHYSICAL', 'LEVELS_NIST', 'LEVELS_JISA', - ] +] import re import warnings @@ -38,11 +38,12 @@ TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') LEVELS_NIST = 'http://csrc.nist.gov/publications/nistpubs/800-63/SP800-63V1_0_2.pdf' LEVELS_JISA = 'http://www.jisa.or.jp/spec/auth_level.html' + class PAPEExtension(Extension): _default_auth_level_aliases = { 'nist': LEVELS_NIST, 'jisa': LEVELS_JISA, - } + } def __init__(self): self.auth_level_aliases = self._default_auth_level_aliases.copy() @@ -90,6 +91,7 @@ class PAPEExtension(Extension): raise KeyError(auth_level_uri) + class Request(PAPEExtension): """A Provider Authentication Policy request, sent from a relying party to a provider @@ -152,8 +154,8 @@ class Request(PAPEExtension): """@see: C{L{Extension.getExtensionArgs}} """ ns_args = { - 'preferred_auth_policies':' '.join(self.preferred_auth_policies), - } + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies), + } if self.max_auth_age is not None: ns_args['max_auth_age'] = str(self.max_auth_age) @@ -266,6 +268,7 @@ class Request(PAPEExtension): return filter(self.preferred_auth_policies.__contains__, supported_types) + Request.ns_uri = ns_uri @@ -455,8 +458,8 @@ class Response(PAPEExtension): } else: ns_args = { - 'auth_policies':' '.join(self.auth_policies), - } + 'auth_policies': ' '.join(self.auth_policies), + } for level_type, level in self.auth_levels.iteritems(): alias = self._getAlias(level_type) @@ -471,4 +474,5 @@ class Response(PAPEExtension): return ns_args + Response.ns_uri = ns_uri diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index e147cf1..786aeea 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -48,22 +48,23 @@ __all__ = [ 'ns_uri_1_0', 'ns_uri_1_1', 'supportsSReg', - ] +] _LOGGER = logging.getLogger(__name__) # The data fields that are listed in the sreg spec data_fields = { - 'fullname':'Full Name', - 'nickname':'Nickname', - 'dob':'Date of Birth', - 'email':'E-mail Address', - 'gender':'Gender', - 'postcode':'Postal Code', - 'country':'Country', - 'language':'Language', - 'timezone':'Time Zone', - } + 'fullname': 'Full Name', + 'nickname': 'Nickname', + 'dob': 'Date of Birth', + 'email': 'E-mail Address', + 'gender': 'Gender', + 'postcode': 'Postal Code', + 'country': 'Country', + 'language': 'Language', + 'timezone': 'Time Zone', +} + def checkFieldName(field_name): """Check to see that the given value is a valid simple @@ -76,6 +77,7 @@ def checkFieldName(field_name): raise ValueError('%r is not a defined simple registration field' % (field_name,)) + # URI used in the wild for Yadis documents advertising simple # registration support ns_uri_1_0 = 'http://openid.net/sreg/1.0' @@ -90,9 +92,10 @@ ns_uri = ns_uri_1_1 try: registerNamespaceAlias(ns_uri_1_1, 'sreg') -except NamespaceAliasRegistrationError, e: +except NamespaceAliasRegistrationError as e: _LOGGER.exception('registerNamespaceAlias(%r, %r) failed: %s', ns_uri_1_1, 'sreg', e) + def supportsSReg(endpoint): """Does the given endpoint advertise support for simple registration? @@ -106,6 +109,7 @@ def supportsSReg(endpoint): return (endpoint.usesExtension(ns_uri_1_1) or endpoint.usesExtension(ns_uri_1_0)) + class SRegNamespaceError(ValueError): """The simple registration namespace was not found and could not be created using the expected name (there's another extension @@ -120,6 +124,7 @@ class SRegNamespaceError(ValueError): the message that is being processed. """ + def getSRegNS(message): """Extract the simple registration namespace URI from the given OpenID message. Handles OpenID 1 and 2, as well as both sreg @@ -151,14 +156,13 @@ def getSRegNS(message): sreg_ns_uri = ns_uri_1_1 try: message.namespaces.addAlias(ns_uri_1_1, 'sreg') - except KeyError, why: + except KeyError as why: # An alias for the string 'sreg' already exists, but it's # defined for something other than simple registration raise SRegNamespaceError(why[0]) - # we know that sreg_ns_uri defined, because it's defined in the - # else clause of the loop as well, so disable the warning - return sreg_ns_uri #pylint:disable-msg=W0631 + return sreg_ns_uri + class SRegRequest(Extension): """An object to hold the state of a simple registration request. @@ -368,6 +372,7 @@ class SRegRequest(Extension): return args + class SRegResponse(Extension): """Represents the data returned in a simple registration response inside of an OpenID C{id_res} response. This object will be diff --git a/openid/fetchers.py b/openid/fetchers.py index b30f895..750b5f5 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -32,6 +32,7 @@ except ImportError: USER_AGENT = "python-openid/%s (%s)" % (openid.__version__, sys.platform) MAX_RESPONSE_KB = 1024 + def fetch(url, body=None, headers=None): """Invoke the fetch method on the default fetcher. Most users should need only this method. @@ -41,6 +42,7 @@ def fetch(url, body=None, headers=None): fetcher = getDefaultFetcher() return fetcher.fetch(url, body, headers) + def createHTTPFetcher(): """Create a default HTTP fetcher instance @@ -52,11 +54,13 @@ def createHTTPFetcher(): return fetcher + # Contains the currently set HTTP fetcher. If it is set to None, the # library will call createHTTPFetcher() to set it. Do not access this # variable outside of this module. _default_fetcher = None + def getDefaultFetcher(): """Return the default fetcher instance if no fetcher has been set, it will create a default fetcher. @@ -71,6 +75,7 @@ def getDefaultFetcher(): return _default_fetcher + def setDefaultFetcher(fetcher, wrap_exceptions=True): """Set the default fetcher @@ -91,6 +96,7 @@ def setDefaultFetcher(fetcher, wrap_exceptions=True): else: _default_fetcher = ExceptionWrappingFetcher(fetcher) + def usingCurl(): """Whether the currently set HTTP fetcher is a Curl HTTP fetcher.""" fetcher = getDefaultFetcher() @@ -98,6 +104,7 @@ def usingCurl(): fetcher = fetcher.fetcher return isinstance(fetcher, CurlHTTPFetcher) + class HTTPResponse(object): """XXX document attributes""" headers = None @@ -116,6 +123,7 @@ class HTTPResponse(object): self.status, self.final_url) + class HTTPFetcher(object): """ This class is the interface for openid HTTP fetchers. This @@ -145,19 +153,23 @@ class HTTPFetcher(object): """ raise NotImplementedError + def _allowedURL(url): return url.startswith('http://') or url.startswith('https://') + class HTTPFetchingError(Exception): """Exception that is wrapped around all exceptions that are raised by the underlying fetcher when using the ExceptionWrappingFetcher @ivar why: The exception that caused this exception """ + def __init__(self, why=None): Exception.__init__(self, why) self.why = why + class ExceptionWrappingFetcher(HTTPFetcher): """Fetcher wrapper which wraps all exceptions to `HTTPFetchingError`.""" @@ -175,6 +187,7 @@ class ExceptionWrappingFetcher(HTTPFetcher): raise HTTPFetchingError(why=exc_inst) + class Urllib2Fetcher(HTTPFetcher): """An C{L{HTTPFetcher}} that uses urllib2. """ @@ -201,7 +214,7 @@ class Urllib2Fetcher(HTTPFetcher): return self._makeResponse(f) finally: f.close() - except urllib2.HTTPError, why: + except urllib2.HTTPError as why: try: return self._makeResponse(why) finally: @@ -220,6 +233,7 @@ class Urllib2Fetcher(HTTPFetcher): return resp + class HTTPError(HTTPFetchingError): """ This exception is raised by the C{L{CurlHTTPFetcher}} when it @@ -228,12 +242,14 @@ class HTTPError(HTTPFetchingError): pass # XXX: define what we mean by paranoid, and make sure it is. + + class CurlHTTPFetcher(HTTPFetcher): """ An C{L{HTTPFetcher}} that uses pycurl for fetching. See U{http://pycurl.sourceforge.net/}. """ - ALLOWED_TIME = 20 # seconds + ALLOWED_TIME = 20 # seconds def __init__(self): HTTPFetcher.__init__(self) @@ -244,7 +260,7 @@ class CurlHTTPFetcher(HTTPFetcher): header_file.seek(0) # Remove the status line from the beginning of the input - unused_http_status_line = header_file.readline().lower () + unused_http_status_line = header_file.readline().lower() if unused_http_status_line.startswith('http/1.1 100 '): unused_http_status_line = header_file.readline() unused_http_status_line = header_file.readline() @@ -309,8 +325,9 @@ class CurlHTTPFetcher(HTTPFetcher): raise HTTPError("Fetching URL not allowed: %r" % (url,)) data = cStringIO.StringIO() + def write_data(chunk): - if data.tell() > 1024*MAX_RESPONSE_KB: + if data.tell() > 1024 * MAX_RESPONSE_KB: return 0 else: return data.write(chunk) @@ -350,6 +367,7 @@ class CurlHTTPFetcher(HTTPFetcher): finally: c.close() + class HTTPLib2Fetcher(HTTPFetcher): """A fetcher that uses C{httplib2} for performing HTTP requests. This implementation supports HTTP caching. @@ -419,4 +437,4 @@ class HTTPLib2Fetcher(HTTPFetcher): final_url=final_url, headers=dict(httplib2_response.items()), status=httplib2_response.status, - ) + ) diff --git a/openid/kvform.py b/openid/kvform.py index 8252d91..e0e91a0 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -9,6 +9,7 @@ _LOGGER = logging.getLogger(__name__) class KVFormError(ValueError): pass + def seqToKV(seq, strict=False): """Represent a sequence of pairs of strings as newline-terminated key:value pairs. The pairs are generated in the order given. @@ -62,6 +63,7 @@ def seqToKV(seq, strict=False): return ''.join(lines).encode('UTF8') + def kvToSeq(data, strict=False): """ @@ -116,10 +118,11 @@ def kvToSeq(data, strict=False): return pairs + def dictToKV(d): - seq = d.items() - seq.sort() + seq = sorted(d.items()) return seqToKV(seq) + def kvToDict(s): return dict(kvToSeq(s)) diff --git a/openid/message.py b/openid/message.py index 92706d9..9c487d6 100644 --- a/openid/message.py +++ b/openid/message.py @@ -55,23 +55,27 @@ OPENID_PROTOCOL_FIELDS = [ 'dh_consumer_public', 'claimed_id', 'identity', 'realm', 'invalidate_handle', 'op_endpoint', 'response_nonce', 'sig', 'assoc_handle', 'trust_root', 'openid', - ] +] + class UndefinedOpenIDNamespace(ValueError): """Raised if the generic OpenID namespace is accessed when there is no OpenID namespace set for this message.""" + class InvalidOpenIDNamespace(ValueError): """Raised if openid.ns is not a recognized value. For recognized values, see L{Message.allowed_openid_namespaces} """ + def __str__(self): s = "Invalid OpenID Namespace" if self.args: s += " %r" % (self.args[0],) return s + class InvalidNamespace(KeyError): """ Raised if there is problem with other namespaces than OpenID namespace @@ -86,12 +90,14 @@ no_default = object() # registerNamespaceAlias. registered_aliases = {} + class NamespaceAliasRegistrationError(Exception): """ Raised when an alias or namespace URI has already been registered. """ pass + def registerNamespaceAlias(namespace_uri, alias): """ Registers a (namespace URI, alias) mapping in a global namespace @@ -106,15 +112,14 @@ def registerNamespaceAlias(namespace_uri, alias): return if namespace_uri in registered_aliases.values(): - raise NamespaceAliasRegistrationError, \ - 'Namespace uri %r already registered' % (namespace_uri,) + raise NamespaceAliasRegistrationError('Namespace uri %r already registered' % (namespace_uri,)) if alias in registered_aliases: - raise NamespaceAliasRegistrationError, \ - 'Alias %r already registered' % (alias,) + raise NamespaceAliasRegistrationError('Alias %r already registered' % (alias,)) registered_aliases[alias] = namespace_uri + class Message(object): """ In the implementation of this object, None represents the global @@ -158,7 +163,6 @@ class Message(object): raise TypeError("query dict must have one value for each key, " "not lists of values. Query is %r" % (args,)) - try: prefix, rest = key.split('.', 1) except ValueError: @@ -348,7 +352,7 @@ class Message(object): form.append(ElementTree.Element(u'input', attrs)) submit = ElementTree.Element(u'input', - {u'type':'submit', u'value':oidutil.toUnicode(submit_text)}) + {u'type': 'submit', u'value': oidutil.toUnicode(submit_text)}) form.append(submit) return ElementTree.tostring(form, encoding='utf-8') @@ -367,8 +371,7 @@ class Message(object): def toURLEncoded(self): """Generate an x-www-urlencoded string""" - args = self.toPostArgs().items() - args.sort() + args = sorted(self.toPostArgs().items()) return urllib.urlencode(args) def _fixNS(self, namespace): @@ -464,7 +467,7 @@ class Message(object): for ((pair_ns, ns_key), value) in self.args.iteritems() if pair_ns == namespace - ]) + ]) def updateArgs(self, namespace, updates): """Set multiple key/value pairs in one call @@ -497,11 +500,9 @@ class Message(object): def __eq__(self, other): return self.args == other.args - def __ne__(self, other): return not (self == other) - def getAliasedArg(self, aliased_key, default=None): if aliased_key == 'ns': return self.getOpenIDNamespace() @@ -530,9 +531,11 @@ class Message(object): return self.getArg(ns, key, default) + class NamespaceMap(object): """Maintains a bijective map between namespace uris and aliases. """ + def __init__(self): self.alias_to_namespace = {} self.namespace_to_alias = {} @@ -564,8 +567,7 @@ class NamespaceMap(object): """ # Check that desired_alias is not an openid protocol field as # per the spec. - assert desired_alias not in OPENID_PROTOCOL_FIELDS, \ - "%r is not an allowed namespace alias" % (desired_alias,) + assert desired_alias not in OPENID_PROTOCOL_FIELDS, "%r is not an allowed namespace alias" % (desired_alias,) # Check that desired_alias does not contain a period as per # the spec. @@ -576,8 +578,7 @@ class NamespaceMap(object): # Check that there is not a namespace already defined for # the desired alias current_namespace_uri = self.alias_to_namespace.get(desired_alias) - if (current_namespace_uri is not None - and current_namespace_uri != namespace_uri): + if (current_namespace_uri is not None and current_namespace_uri != namespace_uri): fmt = ('Cannot map %r to alias %r. ' '%r is already mapped to alias %r') diff --git a/openid/oidutil.py b/openid/oidutil.py index a92b453..13954b7 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -9,8 +9,6 @@ __all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML', 'toU import binascii import logging -import sys -import urlparse from urllib import urlencode _LOGGER = logging.getLogger(__name__) @@ -21,7 +19,8 @@ elementtree_modules = [ 'xml.etree.ElementTree', 'cElementTree', 'elementtree.ElementTree', - ] +] + def toUnicode(value): """Returns the given argument as a unicode object. @@ -35,6 +34,7 @@ def toUnicode(value): return value.decode('utf-8') return unicode(value) + def autoSubmitHTML(form, title='OpenID transaction in progress'): return """ <html> @@ -53,6 +53,7 @@ for (var i = 0; i < elements.length; i++) { </html> """ % (title, form) + def importElementTree(module_names=None): """Find a working ElementTree implementation, trying the standard places that such a thing might show up. @@ -76,9 +77,7 @@ def importElementTree(module_names=None): # Make sure it can actually parse XML try: ElementTree.XML('<unused/>') - except (SystemExit, MemoryError, AssertionError): - raise - except: + except Exception: logging.exception('Not using ElementTree library %r because it failed to parse a trivial document: %s', mod_name) else: @@ -89,6 +88,7 @@ def importElementTree(module_names=None): 'Tried importing %r' % (module_names,) ) + def log(message, level=0): """Handle a log message from the OpenID library. @@ -109,6 +109,7 @@ def log(message, level=0): logging.error("This is a legacy log message, please use the logging module. Message: %s", message) + def appendArgs(url, args): """Append query arguments to a HTTP(s) URL. If the URL already has query arguemtns, these arguments will be added, and the existing @@ -129,8 +130,7 @@ def appendArgs(url, args): @rtype: str """ if hasattr(args, 'items'): - args = args.items() - args.sort() + args = sorted(args.items()) else: args = list(args) @@ -146,10 +146,10 @@ def appendArgs(url, args): # about the encodings of plain bytes (str). i = 0 for k, v in args: - if type(k) is not str: + if not isinstance(k, str): k = k.encode('UTF-8') - if type(v) is not str: + if not isinstance(v, str): v = v.encode('UTF-8') args[i] = (k, v) @@ -157,17 +157,20 @@ def appendArgs(url, args): return '%s%s%s' % (url, sep, urlencode(args)) + def toBase64(s): """Represent string s as base64, omitting newlines""" return binascii.b2a_base64(s)[:-1] + def fromBase64(s): try: return binascii.a2b_base64(s) - except binascii.Error, why: + except binascii.Error as why: # Convert to a common exception type raise ValueError(why[0]) + class Symbol(object): """This class implements an object that compares equal to others of the same type that have the same name. These are distict from @@ -178,13 +181,13 @@ class Symbol(object): self.name = name def __eq__(self, other): - return type(self) is type(other) and self.name == other.name + return type(self) == type(other) and self.name == other.name def __ne__(self, other): return not (self == other) def __hash__(self): return hash((self.__class__, self.name)) - + def __repr__(self): return '<Symbol %s>' % (self.name,) diff --git a/openid/server/server.py b/openid/server/server.py index 1e456e0..436b8ad 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -144,6 +144,7 @@ ENCODE_HTML_FORM = ('HTML form',) UNUSED = None + class OpenIDRequest(object): """I represent an incoming OpenID request. @@ -190,7 +191,6 @@ class CheckAuthRequest(OpenIDRequest): self.invalidate_handle = invalidate_handle self.namespace = OPENID2_NS - def fromMessage(klass, message, op_endpoint=UNUSED): """Construct me from an OpenID Message. @@ -206,7 +206,7 @@ class CheckAuthRequest(OpenIDRequest): self.sig = message.getArg(OPENID_NS, 'sig') if (self.assoc_handle is None or - self.sig is None): + self.sig is None): fmt = "%s request missing required parameter from message %s" raise ProtocolError( message, text=fmt % (self.mode, message)) @@ -253,7 +253,6 @@ class CheckAuthRequest(OpenIDRequest): OPENID_NS, 'invalidate_handle', self.invalidate_handle) return response - def __str__(self): if self.invalidate_handle: ih = " invalidate? %r" % (self.invalidate_handle,) @@ -330,7 +329,7 @@ class DiffieHellmanSHA1ServerSession(object): dh_modulus = message.getArg(OPENID_NS, 'dh_modulus') dh_gen = message.getArg(OPENID_NS, 'dh_gen') if (dh_modulus is None and dh_gen is not None or - dh_gen is None and dh_modulus is not None): + dh_gen is None and dh_modulus is not None): if dh_modulus is None: missing = 'modulus' @@ -367,13 +366,15 @@ class DiffieHellmanSHA1ServerSession(object): return { 'dh_server_public': cryptutil.longToBase64(self.dh.public), 'enc_mac_key': oidutil.toBase64(mac_key), - } + } + class DiffieHellmanSHA256ServerSession(DiffieHellmanSHA1ServerSession): session_type = 'DH-SHA256' hash_func = staticmethod(cryptutil.sha256) allowed_assoc_types = ['HMAC-SHA256'] + class AssociateRequest(OpenIDRequest): """A request to establish an X{association}. @@ -397,7 +398,7 @@ class AssociateRequest(OpenIDRequest): 'no-encryption': PlainTextServerSession, 'DH-SHA1': DiffieHellmanSHA1ServerSession, 'DH-SHA256': DiffieHellmanSHA256ServerSession, - } + } def __init__(self, session, assoc_type): """Construct me. @@ -410,7 +411,6 @@ class AssociateRequest(OpenIDRequest): self.assoc_type = assoc_type self.namespace = OPENID2_NS - def fromMessage(klass, message, op_endpoint=UNUSED): """Construct me from an OpenID Message. @@ -423,7 +423,7 @@ class AssociateRequest(OpenIDRequest): session_type = message.getArg(OPENID_NS, 'session_type') if session_type == 'no-encryption': _LOGGER.warn('Received OpenID 1 request with a no-encryption ' - 'assocaition session type. Continuing anyway.') + 'assocaition session type. Continuing anyway.') elif not session_type: session_type = 'no-encryption' @@ -449,7 +449,7 @@ class AssociateRequest(OpenIDRequest): try: session = session_class.fromMessage(message) - except ValueError, why: + except ValueError as why: raise ProtocolError(message, 'Error parsing %s session: %s' % (session_class.session_type, why[0])) @@ -479,7 +479,7 @@ class AssociateRequest(OpenIDRequest): 'expires_in': '%d' % (assoc.getExpiresIn(),), 'assoc_type': self.assoc_type, 'assoc_handle': assoc.handle, - }) + }) response.fields.updateArgs(OPENID_NS, self.session.answer(assoc.secret)) @@ -513,6 +513,7 @@ class AssociateRequest(OpenIDRequest): return response + class CheckIDRequest(OpenIDRequest): """A request to confirm the identity of a user. @@ -571,8 +572,7 @@ class CheckIDRequest(OpenIDRequest): self.immediate = False self.mode = "checkid_setup" - if self.return_to is not None and \ - not TrustRoot.parse(self.return_to): + if self.return_to is not None and not TrustRoot.parse(self.return_to): raise MalformedReturnURL(None, self.return_to) if not self.trustRootValid(): raise UntrustedReturnURL(None, self.return_to, self.trust_root) @@ -650,8 +650,7 @@ class CheckIDRequest(OpenIDRequest): # Using 'or' here is slightly different than sending a default # argument to getArg, as it will treat no value and an empty # string as equivalent. - self.trust_root = (message.getArg(OPENID_NS, trust_root_param) - or self.return_to) + self.trust_root = (message.getArg(OPENID_NS, trust_root_param) or self.return_to) if not message.isOpenID1(): if self.return_to is self.trust_root is None: @@ -666,8 +665,7 @@ class CheckIDRequest(OpenIDRequest): # is a valid URL. Not all trust roots are valid return_to URLs, # however (particularly ones with wildcards), so this is still a # little sketchy. - if self.return_to is not None and \ - not TrustRoot.parse(self.return_to): + if self.return_to is not None and not TrustRoot.parse(self.return_to): raise MalformedReturnURL(message, self.return_to) # I first thought that checking to see if the return_to is within @@ -798,10 +796,10 @@ class CheckIDRequest(OpenIDRequest): if allow: mode = 'id_res' elif self.message.isOpenID1(): - if self.immediate: - mode = 'id_res' - else: - mode = 'cancel' + if self.immediate: + mode = 'id_res' + else: + mode = 'cancel' else: if self.immediate: mode = 'setup_needed' @@ -829,8 +827,7 @@ class CheckIDRequest(OpenIDRequest): normalized_request_identity = urinorm(self.identity) normalized_answer_identity = urinorm(identity) - if (normalized_request_identity != - normalized_answer_identity): + if normalized_request_identity != normalized_answer_identity: raise ValueError( "Request was for identity %r, cannot reply " "with identity %r" % (self.identity, identity)) @@ -851,13 +848,13 @@ class CheckIDRequest(OpenIDRequest): raise ValueError( "Request was an OpenID 1 request, so response must " "include an identifier." - ) + ) response.fields.updateArgs(OPENID_NS, { 'mode': mode, 'return_to': self.return_to, 'response_nonce': mkNonce(), - }) + }) if server_url: response.fields.setArg(OPENID_NS, 'op_endpoint', server_url) @@ -888,7 +885,6 @@ class CheckIDRequest(OpenIDRequest): return response - def encodeToURL(self, server_url): """Encode this request as a URL to GET. @@ -922,7 +918,6 @@ class CheckIDRequest(OpenIDRequest): response.updateArgs(OPENID_NS, q) return response.toURL(server_url) - def getCancelURL(self): """Get the URL to cancel this request. @@ -949,7 +944,6 @@ class CheckIDRequest(OpenIDRequest): response.setArg(OPENID_NS, 'mode', 'cancel') return response.toURL(self.return_to) - def __repr__(self): return '<%s id:%r im:%s tr:%r ah:%r>' % (self.__class__.__name__, self.identity, @@ -958,7 +952,6 @@ class CheckIDRequest(OpenIDRequest): self.assoc_handle) - class OpenIDResponse(object): """I am a response to an OpenID request. @@ -995,7 +988,6 @@ class OpenIDResponse(object): self.request.__class__.__name__, self.fields) - def toFormMarkup(self, form_tag_attrs=None): """Returns the form markup for this response. @@ -1033,7 +1025,6 @@ class OpenIDResponse(object): """ return self.whichEncoding() == ENCODE_HTML_FORM - def needsSigning(self): """Does this response require signing? @@ -1041,7 +1032,6 @@ class OpenIDResponse(object): """ return self.fields.getArg(OPENID_NS, 'mode') == 'id_res' - # implements IEncodable def whichEncoding(self): @@ -1061,7 +1051,6 @@ class OpenIDResponse(object): else: return ENCODE_KVFORM - def encodeToURL(self): """Encode a response as a URL for the user agent to GET. @@ -1072,7 +1061,6 @@ class OpenIDResponse(object): """ return self.fields.toURL(self.request.return_to) - def addExtension(self, extension_response): """ Add an extension response to this response message. @@ -1086,7 +1074,6 @@ class OpenIDResponse(object): """ extension_response.toMessage(self.fields) - def encodeToKVForm(self): """Encode a response in key-value colon/newline format. @@ -1101,7 +1088,6 @@ class OpenIDResponse(object): return self.fields.toKVForm() - class WebResponse(object): """I am a response to an OpenID request in terms a web server understands. @@ -1132,7 +1118,6 @@ class WebResponse(object): self.body = body - class Signatory(object): """I sign things. @@ -1146,7 +1131,7 @@ class Signatory(object): @type SECRET_LIFETIME: int """ - SECRET_LIFETIME = 14 * 24 * 60 * 60 # 14 days, in seconds + SECRET_LIFETIME = 14 * 24 * 60 * 60 # 14 days, in seconds # keys have a bogus server URL in them because the filestore # really does expect that key to be a URL. This seems a little @@ -1155,7 +1140,6 @@ class Signatory(object): _normal_key = 'http://localhost/|normal' _dumb_key = 'http://localhost/|dumb' - def __init__(self, store): """Create a new Signatory. @@ -1165,7 +1149,6 @@ class Signatory(object): assert store is not None self.store = store - def verify(self, assoc_handle, message): """Verify that the signature for some data is valid. @@ -1186,12 +1169,11 @@ class Signatory(object): try: valid = assoc.checkMessageSignature(message) - except ValueError, ex: + except ValueError as ex: _LOGGER.exception("Error in verifying %s with %s: %s", message, assoc, ex) return False return valid - def sign(self, response): """Sign a response. @@ -1232,11 +1214,10 @@ class Signatory(object): try: signed_response.fields = assoc.signMessage(signed_response.fields) - except kvform.KVFormError, err: + except kvform.KVFormError as err: raise EncodingError(response, explanation=str(err)) return signed_response - def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): """Make a new association. @@ -1264,7 +1245,6 @@ class Signatory(object): self.store.storeAssociation(key, assoc) return assoc - def getAssociation(self, assoc_handle, dumb, checkExpiration=True): """Get the association with the specified handle. @@ -1299,7 +1279,6 @@ class Signatory(object): assoc = None return assoc - def invalidate(self, assoc_handle, dumb): """Invalidates the association with the given handle. @@ -1315,7 +1294,6 @@ class Signatory(object): self.store.removeAssociation(key, assoc_handle) - class Encoder(object): """I encode responses in to L{WebResponses<WebResponse>}. @@ -1327,7 +1305,6 @@ class Encoder(object): responseFactory = WebResponse - def encode(self, response): """Encode a response to a L{WebResponse}. @@ -1353,7 +1330,6 @@ class Encoder(object): return wr - class SigningEncoder(Encoder): """I encode responses in to L{WebResponses<WebResponse>}, signing them when required. """ @@ -1366,7 +1342,6 @@ class SigningEncoder(Encoder): """ self.signatory = signatory - def encode(self, response): """Encode a response to a L{WebResponse}, signing it first if appropriate. @@ -1390,7 +1365,6 @@ class SigningEncoder(Encoder): return super(SigningEncoder, self).encode(response) - class Decoder(object): """I decode an incoming web request in to a L{OpenIDRequest}. """ @@ -1400,7 +1374,7 @@ class Decoder(object): 'checkid_immediate': CheckIDRequest.fromMessage, 'check_authentication': CheckAuthRequest.fromMessage, 'associate': AssociateRequest.fromMessage, - } + } def __init__(self, server): """Construct a Decoder. @@ -1431,7 +1405,7 @@ class Decoder(object): try: message = Message.fromPostArgs(query) - except InvalidOpenIDNamespace, err: + except InvalidOpenIDNamespace as err: # It's useful to have a Message attached to a ProtocolError, so we # override the bad ns value to build a Message out of it. Kinda # kludgy, since it's made of lies, but the parts that aren't lies @@ -1440,7 +1414,7 @@ class Decoder(object): query['openid.ns'] = OPENID2_NS message = Message.fromPostArgs(query) raise ProtocolError(message, str(err)) - except InvalidNamespace, err: + except InvalidNamespace as err: # If openid.ns is OK, but there is problem with other namespaces # We keep only bare parts of query and we try to make a ProtocolError from it query = [(key, value) for key, value in query.items() if key.count('.') < 2] @@ -1455,7 +1429,6 @@ class Decoder(object): handler = self._handlers.get(mode, self.defaultDecoder) return handler(message, self.server.op_endpoint) - def defaultDecoder(self, message, server): """Called to decode queries when no handler for that mode is found. @@ -1467,7 +1440,6 @@ class Decoder(object): raise ProtocolError(message, text=fmt % (mode,)) - class Server(object): """I handle requests for an OpenID server. @@ -1521,13 +1493,7 @@ class Server(object): encoderClass = SigningEncoder decoderClass = Decoder - def __init__( - self, - store, - op_endpoint=None, - signatoryClass=None, - encoderClass=None, - decoderClass=None): + def __init__(self, store, op_endpoint=None, signatoryClass=None, encoderClass=None, decoderClass=None): """A new L{Server}. @param store: The back-end where my associations are stored. @@ -1570,7 +1536,6 @@ class Server(object): stacklevel=2) self.op_endpoint = op_endpoint - def handleRequest(self, request): """Handle a request. @@ -1592,7 +1557,6 @@ class Server(object): "%s has no handler for a request of mode %r." % (self, request.mode)) - def openid_check_authentication(self, request): """Handle and respond to C{check_authentication} requests. @@ -1600,7 +1564,6 @@ class Server(object): """ return request.answer(self.signatory) - def openid_associate(self, request): """Handle and respond to C{associate} requests. @@ -1616,14 +1579,12 @@ class Server(object): else: message = ('Association type %r is not supported with ' 'session type %r' % (assoc_type, session_type)) - (preferred_assoc_type, preferred_session_type) = \ - self.negotiator.getAllowedType() + (preferred_assoc_type, preferred_session_type) = self.negotiator.getAllowedType() return request.answerUnsupported( message, preferred_assoc_type, preferred_session_type) - def decodeRequest(self, query): """Transform query parameters into an L{OpenIDRequest}. @@ -1643,7 +1604,6 @@ class Server(object): """ return self.decoder.decode(query) - def encodeResponse(self, response): """Encode a response to a L{WebResponse}, signing it first if appropriate. @@ -1659,7 +1619,6 @@ class Server(object): return self.encoder.encode(response) - class ProtocolError(Exception): """A message did not conform to the OpenID protocol. @@ -1683,7 +1642,6 @@ class ProtocolError(Exception): assert type(message) not in [str, unicode] Exception.__init__(self, text) - def getReturnTo(self): """Get the return_to argument from the request, if any. @@ -1778,13 +1736,11 @@ class ProtocolError(Exception): return None - class VersionError(Exception): """Raised when an operation was attempted that is not compatible with the protocol version being used.""" - class NoReturnToError(Exception): """Raised when a response to a request cannot be generated because the request contains no return_to URL. @@ -1792,7 +1748,6 @@ class NoReturnToError(Exception): pass - class EncodingError(Exception): """Could not encode this as a protocol message. @@ -1821,7 +1776,6 @@ class AlreadySigned(EncodingError): """This response is already signed.""" - class UntrustedReturnURL(ProtocolError): """A return_to is outside the trust_root.""" @@ -1837,12 +1791,12 @@ class UntrustedReturnURL(ProtocolError): class MalformedReturnURL(ProtocolError): """The return_to URL doesn't look like a valid URL.""" + def __init__(self, openid_message, return_to): self.return_to = return_to ProtocolError.__init__(self, openid_message) - class MalformedTrustRoot(ProtocolError): """The trust root is not well-formed. @@ -1851,7 +1805,7 @@ class MalformedTrustRoot(ProtocolError): pass -#class IEncodable: # Interface +# class IEncodable: # Interface # def encodeToURL(return_to): # """Encode a response as a URL for redirection. # diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 955a0d8..ec771b9 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -12,10 +12,10 @@ the realm. __all__ = [ 'TrustRoot', 'RP_RETURN_TO_URL_TYPE', - 'extractReturnToURLs', + 'getAllowedReturnURLs', 'returnToMatches', 'verifyReturnTo', - ] +] import logging import re @@ -66,11 +66,13 @@ _top_level_domains = [ host_segment_re = re.compile( r"(?:[-a-zA-Z0-9!$&'\(\)\*+,;=._~]|%[a-zA-Z0-9]{2})+$") + class RealmVerificationRedirected(Exception): """Attempting to verify this realm resulted in a redirect. @since: 2.1.0 """ + def __init__(self, relying_party_url, rp_url_after_redirects): self.relying_party_url = relying_party_url self.rp_url_after_redirects = rp_url_after_redirects @@ -111,6 +113,7 @@ def _parseURL(url): return proto, host, port, path + class TrustRoot(object): """ This class represents an OpenID trust root. The C{L{parse}} @@ -178,7 +181,7 @@ class TrustRoot(object): if self.wildcard: if len(tld) == 2 and len(host_parts[-2]) <= 3: # It's a 2-letter tld with a short second to last segment - # so there needs to be more than two segments specified + # so there needs to be more than two segments specified # (e.g. *.co.uk is insane) return len(host_parts) > 2 @@ -239,8 +242,7 @@ class TrustRoot(object): else: allowed = '?/' - return (self.path[-1] in allowed or - path[path_len] in allowed) + return (self.path[-1] in allowed or path[path_len] in allowed) return True @@ -352,12 +354,14 @@ class TrustRoot(object): def __str__(self): return repr(self) + # The URI for relying party discovery, used in realm verification. # # XXX: This should probably live somewhere else (like in # openid.consumer or openid.yadis somewhere) RP_RETURN_TO_URL_TYPE = 'http://specs.openid.net/auth/2.0/return_to' + def _extractReturnURL(endpoint): """If the endpoint is a relying party OpenID return_to endpoint, return the endpoint URL. Otherwise, return None. @@ -380,6 +384,7 @@ def _extractReturnURL(endpoint): else: return None + def returnToMatches(allowed_return_to_urls, return_to): """Is the return_to URL under one of the supplied allowed return_to URLs? @@ -394,7 +399,8 @@ def returnToMatches(allowed_return_to_urls, return_to): # a wildcard. return_realm = TrustRoot.parse(allowed_return_to) - if (# Parses as a trust root + if ( + # Parses as a trust root return_realm is not None and # Does not have a wildcard @@ -402,12 +408,13 @@ def returnToMatches(allowed_return_to_urls, return_to): # Matches the return_to that we passed in with it return_realm.validateURL(return_to) - ): + ): return True # No URL in the list matched return False + def getAllowedReturnURLs(relying_party_url): """Given a relying party discovery URL return a list of return_to URLs. @@ -424,6 +431,8 @@ def getAllowedReturnURLs(relying_party_url): return return_to_urls # _vrfy parameter is there to make testing easier + + def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): """Verify that a return_to URL is valid for the given realm. @@ -444,7 +453,7 @@ def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): try: allowable_urls = _vrfy(realm.buildDiscoveryURL()) - except RealmVerificationRedirected, err: + except RealmVerificationRedirected as err: _LOGGER.exception(str(err)) return False diff --git a/openid/sreg.py b/openid/sreg.py index bf454d7..bceb53f 100644 --- a/openid/sreg.py +++ b/openid/sreg.py @@ -2,7 +2,9 @@ import warnings -from openid.extensions.sreg import * +from openid.extensions.sreg import SRegRequest, SRegResponse, data_fields, ns_uri, ns_uri_1_0, ns_uri_1_1, supportsSReg warnings.warn("openid.sreg has moved to openid.extensions.sreg", DeprecationWarning) + +__all__ = ['SRegRequest', 'SRegResponse', 'data_fields', 'ns_uri', 'ns_uri_1_0', 'ns_uri_1_1', 'supportsSReg'] diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 3ec4c59..0c5c044 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -21,6 +21,7 @@ _LOGGER = logging.getLogger(__name__) _filename_allowed = string.ascii_letters + string.digits + '.' _isFilenameSafe = set(_filename_allowed).__contains__ + def _safe64(s): h64 = oidutil.toBase64(cryptutil.sha1(s)) h64 = h64.replace('+', '_') @@ -28,6 +29,7 @@ def _safe64(s): h64 = h64.replace('=', '') return h64 + def _filenameEscape(s): filename_chunks = [] for c in s: @@ -37,6 +39,7 @@ def _filenameEscape(s): filename_chunks.append('_%02X' % ord(c)) return ''.join(filename_chunks) + def _removeIfPresent(filename): """Attempt to remove a file, returning whether the file existed at the time of the call. @@ -45,7 +48,7 @@ def _removeIfPresent(filename): """ try: os.unlink(filename) - except OSError, why: + except OSError as why: if why.errno == ENOENT: # Someone beat us to it, but it's gone, so that's OK return 0 @@ -55,6 +58,7 @@ def _removeIfPresent(filename): # File was present return 1 + def _ensureDir(dir_name): """Create dir_name as a directory if it does not exist. If it exists, make sure that it is, in fact, a directory. @@ -65,10 +69,11 @@ def _ensureDir(dir_name): """ try: os.makedirs(dir_name) - except OSError, why: + except OSError as why: if why.errno != EEXIST or not os.path.isdir(dir_name): raise + class FileOpenIDStore(OpenIDStore): """ This is a filesystem-based store for OpenID associations and @@ -108,7 +113,7 @@ class FileOpenIDStore(OpenIDStore): # directory self.temp_dir = os.path.join(directory, 'temp') - self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds + self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds self._setup() @@ -137,7 +142,9 @@ class FileOpenIDStore(OpenIDStore): try: file_obj = os.fdopen(fd, 'wb') return file_obj, name - except: + except Exception: + # If there was an error, don't leave the temporary file + # around. _removeIfPresent(name) raise @@ -183,7 +190,7 @@ class FileOpenIDStore(OpenIDStore): try: os.rename(tmp, filename) - except OSError, why: + except OSError as why: if why.errno != EEXIST: raise @@ -192,7 +199,7 @@ class FileOpenIDStore(OpenIDStore): # file, but not in putting the temporary file in place. try: os.unlink(filename) - except OSError, why: + except OSError as why: if why.errno == ENOENT: pass else: @@ -201,7 +208,7 @@ class FileOpenIDStore(OpenIDStore): # Now the target should not exist. Try renaming again, # giving up if it fails. os.rename(tmp, filename) - except: + except Exception: # If there was an error, don't leave the temporary file # around. _removeIfPresent(tmp) @@ -252,7 +259,7 @@ class FileOpenIDStore(OpenIDStore): def _getAssociation(self, filename): try: assoc_file = file(filename, 'rb') - except IOError, why: + except IOError as why: if why.errno == ENOENT: # No association exists for that URL and handle return None @@ -313,8 +320,8 @@ class FileOpenIDStore(OpenIDStore): filename = os.path.join(self.nonce_dir, filename) try: - fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0200) - except OSError, why: + fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o200) + except OSError as why: if why.errno == EEXIST: return False else: @@ -332,7 +339,7 @@ class FileOpenIDStore(OpenIDStore): for association_filename in association_filenames: try: association_file = file(association_filename, 'rb') - except IOError, why: + except IOError as why: if why.errno == ENOENT: _LOGGER.exception("%s disappeared during %s._allAssocs", association_filename, self.__class__.__name__) diff --git a/openid/store/interface.py b/openid/store/interface.py index bb90972..6377657 100644 --- a/openid/store/interface.py +++ b/openid/store/interface.py @@ -3,6 +3,7 @@ This module contains the definition of the C{L{OpenIDStore}} interface. """ + class OpenIDStore(object): """ This is the interface for the store objects the OpenID library diff --git a/openid/store/memstore.py b/openid/store/memstore.py index 89a16bd..366a596 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -49,12 +49,12 @@ class ServerAssocs(object): return len(remove), len(self.assocs) - class MemoryStore(object): """In-process memory store. Use for single long-running processes. No persistence supplied. """ + def __init__(self): self.server_assocs = {} self.nonces = {} diff --git a/openid/store/nonce.py b/openid/store/nonce.py index 89ef096..800dfec 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -2,7 +2,7 @@ __all__ = [ 'split', 'mkNonce', 'checkTimestamp', - ] +] import string from calendar import timegm @@ -20,6 +20,7 @@ SKEW = 60 * 60 * 5 time_fmt = '%Y-%m-%dT%H:%M:%SZ' time_str_len = len('0000-00-00T00:00:00Z') + def split(nonce_string): """Extract a timestamp from the given nonce string @@ -38,6 +39,7 @@ def split(nonce_string): raise ValueError('time out of range') return timestamp, nonce_string[time_str_len:] + def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): """Is the timestamp that is part of the specified nonce string within the allowed clock-skew of the current time? @@ -74,6 +76,7 @@ def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): # the past return past <= stamp <= future + def mkNonce(when=None): """Generate a nonce with the current timestamp diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index a629e72..c9e7b23 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -4,7 +4,8 @@ various SQL databases to back them. Example of how to initialize a store database:: - python -c 'from openid.store import sqlstore; import pysqlite2.dbapi2; sqlstore.SQLiteStore(pysqlite2.dbapi2.connect("cstore.db")).createTables()' + python -c 'from openid.store import sqlstore; import pysqlite2.dbapi2; + sqlstore.SQLiteStore(pysqlite2.dbapi2.connect("cstore.db")).createTables()' """ import re import time @@ -29,6 +30,7 @@ def _inTxn(func): return wrapped + class SQLStore(OpenIDStore): """ This is the parent class for the SQL stores, which contains the @@ -98,14 +100,13 @@ class SQLStore(OpenIDStore): self._table_names = { 'associations': associations_table or self.associations_table, 'nonces': nonces_table or self.nonces_table, - } - self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds + } + self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds # DB API extension: search for "Connection Attributes .Error, # .ProgrammingError, etc." in # http://www.python.org/dev/peps/pep-0249/ - if (hasattr(self.conn, 'IntegrityError') and - hasattr(self.conn, 'OperationalError')): + if hasattr(self.conn, 'IntegrityError') and hasattr(self.conn, 'OperationalError'): self.exceptions = self.conn if not (hasattr(self.exceptions, 'IntegrityError') and @@ -139,6 +140,7 @@ class SQLStore(OpenIDStore): # arguments if they are passed in as unicode instead of str. # Currently the strings in our tables just have ascii in them, # so this ought to be safe. + def unicode_to_str(arg): if isinstance(arg, unicode): return str(arg) @@ -153,6 +155,7 @@ class SQLStore(OpenIDStore): # as an attribute of this object and executes it. if attr[:3] == 'db_': sql_name = attr[3:] + '_sql' + def func(*args): return self._execSQL(sql_name, *args) setattr(self, attr, func) @@ -174,7 +177,7 @@ class SQLStore(OpenIDStore): finally: self.cur.close() self.cur = None - except: + except Exception: self.conn.rollback() raise else: @@ -248,7 +251,7 @@ class SQLStore(OpenIDStore): (str, str) -> bool """ self.db_remove_assoc(server_url, handle) - return self.cur.rowcount > 0 # -1 is undefined + return self.cur.rowcount > 0 # -1 is undefined removeAssociation = _inTxn(txn_removeAssociation) @@ -350,12 +353,13 @@ class SQLiteStore(SQLStore): # message from the OperationalError. try: return super(SQLiteStore, self).useNonce(*args, **kwargs) - except self.exceptions.OperationalError, why: + except self.exceptions.OperationalError as why: if re.match('^columns .* are not unique$', why[0]): return False else: raise + class MySQLStore(SQLStore): """ This is a MySQL-based specialization of C{L{SQLStore}}. @@ -417,13 +421,14 @@ class MySQLStore(SQLStore): clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < %%s;' def blobDecode(self, blob): - if type(blob) is str: + if isinstance(blob, str): # Versions of MySQLdb >= 1.2.2 return blob else: # Versions of MySQLdb prior to 1.2.2 (as far as we can tell) return blob.tostring() + class PostgreSQLStore(SQLStore): """ This is a PostgreSQL-based specialization of C{L{SQLStore}}. @@ -473,7 +478,7 @@ class PostgreSQLStore(SQLStore): REPLACE INTO is not supported by PostgreSQL (and is not standard SQL). """ - result = self.db_get_assoc(server_url, handle) + self.db_get_assoc(server_url, handle) rows = self.cur.fetchall() if len(rows): # Update the table since this associations already exists. diff --git a/openid/test/cryptutil.py b/openid/test/cryptutil.py index e52b6a3..cf6074c 100644 --- a/openid/test/cryptutil.py +++ b/openid/test/cryptutil.py @@ -7,6 +7,7 @@ from openid import cryptutil # Most of the purpose of this test is to make sure that cryptutil can # find a good source of randomness on this machine. + def test_cryptrand(): # It's possible, but HIGHLY unlikely that a correct implementation # will fail by returning the same number twice @@ -17,15 +18,16 @@ def test_cryptrand(): assert len(t) == 32 assert s != t - a = cryptutil.randrange(2L ** 128) - b = cryptutil.randrange(2L ** 128) - assert type(a) is long - assert type(b) is long + a = cryptutil.randrange(2 ** 128) + b = cryptutil.randrange(2 ** 128) + assert isinstance(a, long) + assert isinstance(b, long) assert b != a # Make sure that we can generate random numbers that are larger # than platform int size - cryptutil.randrange(long(sys.maxint) + 1L) + cryptutil.randrange(long(sys.maxsize) + 1) + def test_reversed(): if hasattr(cryptutil, 'reversed'): @@ -37,10 +39,10 @@ def test_reversed(): ('abcdefg', 'gfedcba'), ([], []), ([1], [1]), - ([1,2], [2,1]), - ([1,2,3], [3,2,1]), + ([1, 2], [2, 1]), + ([1, 2, 3], [3, 2, 1]), (range(1000), range(999, -1, -1)), - ] + ] for case, expected in cases: expected = list(expected) @@ -49,28 +51,29 @@ def test_reversed(): twice = list(cryptutil.reversed(actual)) assert twice == list(case), (actual, case, twice) + def test_binaryLongConvert(): - MAX = sys.maxint + MAX = sys.maxsize for iteration in xrange(500): - n = 0L + n = 0 for i in range(10): n += long(random.randrange(MAX)) s = cryptutil.longToBinary(n) - assert type(s) is str + assert isinstance(s, str) n_prime = cryptutil.binaryToLong(s) assert n == n_prime, (n, n_prime) cases = [ - ('\x00', 0L), - ('\x01', 1L), - ('\x7F', 127L), - ('\x00\xFF', 255L), - ('\x00\x80', 128L), - ('\x00\x81', 129L), - ('\x00\x80\x00', 32768L), - ('OpenID is cool', 1611215304203901150134421257416556L) - ] + ('\x00', 0), + ('\x01', 1), + ('\x7F', 127), + ('\x00\xFF', 255), + ('\x00\x80', 128), + ('\x00\x81', 129), + ('\x00\x80\x00', 32768), + ('OpenID is cool', 1611215304203901150134421257416556) + ] for s, n in cases: n_prime = cryptutil.binaryToLong(s) @@ -78,6 +81,7 @@ def test_binaryLongConvert(): assert n == n_prime, (s, n, n_prime) assert s == s_prime, (n, s, s_prime) + def test_longToBase64(): f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) try: @@ -87,6 +91,7 @@ def test_longToBase64(): finally: f.close() + def test_base64ToLong(): f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) try: @@ -104,5 +109,6 @@ def test(): test_longToBase64() test_base64ToLong() + if __name__ == '__main__': test() diff --git a/openid/test/datadriven.py b/openid/test/datadriven.py index c7dc4f7..aac6e9d 100644 --- a/openid/test/datadriven.py +++ b/openid/test/datadriven.py @@ -31,6 +31,7 @@ class DataDrivenTestCase(unittest.TestCase): def shortDescription(self): return '%s for %s' % (self.__class__.__name__, self.description) + def loadTests(module_name): loader = unittest.defaultTestLoader this_module = __import__(module_name, {}, {}, [None]) @@ -38,8 +39,7 @@ def loadTests(module_name): tests = [] for name in dir(this_module): obj = getattr(this_module, name) - if (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, unittest.TestCase)): + if isinstance(obj, (type, types.ClassType)) and issubclass(obj, unittest.TestCase): if hasattr(obj, 'loadTests'): tests.extend(obj.loadTests()) else: diff --git a/openid/test/dh.py b/openid/test/dh.py index 299730b..01a6ab5 100644 --- a/openid/test/dh.py +++ b/openid/test/dh.py @@ -16,7 +16,7 @@ def test_strxor(): ('\x01', '\x02', '\x03'), ('\xf0', '\x0f', '\xff'), ('\xff', '\x0f', '\xf0'), - ] + ] for aa, bb, expected in cases: actual = strxor(aa, bb) @@ -28,7 +28,7 @@ def test_strxor(): (NUL * 3, NUL * 4), (''.join(map(chr, xrange(256))), ''.join(map(chr, xrange(128)))), - ] + ] for aa, bb in exc_cases: try: @@ -38,6 +38,7 @@ def test_strxor(): else: assert False, 'Expected ValueError, got %r' % (unexpected,) + def test1(): dh1 = DiffieHellman.fromDefaults() dh2 = DiffieHellman.fromDefaults() @@ -46,11 +47,13 @@ def test1(): assert secret1 == secret2 return secret1 + def test_exchange(): s1 = test1() s2 = test1() assert s1 != s2 + def test_public(): f = file(os.path.join(os.path.dirname(__file__), 'dhpriv')) dh = DiffieHellman.fromDefaults() @@ -63,10 +66,12 @@ def test_public(): finally: f.close() + def test(): test_exchange() test_public() test_strxor() + if __name__ == '__main__': test() diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 1d906d8..32d9619 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -9,25 +9,26 @@ tests_dir = os.path.dirname(__file__) data_path = os.path.join(tests_dir, 'data') testlist = [ -# success, input_name, id_name, result_name - (True, "equiv", "equiv", "xrds"), - (True, "header", "header", "xrds"), - (True, "lowercase_header", "lowercase_header", "xrds"), - (True, "xrds", "xrds", "xrds"), - (True, "xrds_ctparam", "xrds_ctparam", "xrds_ctparam"), - (True, "xrds_ctcase", "xrds_ctcase", "xrds_ctcase"), - (False, "xrds_html", "xrds_html", "xrds_html"), - (True, "redir_equiv", "equiv", "xrds"), - (True, "redir_header", "header", "xrds"), - (True, "redir_xrds", "xrds", "xrds"), - (False, "redir_xrds_html", "xrds_html", "xrds_html"), - (True, "redir_redir_equiv", "equiv", "xrds"), - (False, "404_server_response", None, None), - (False, "404_with_header", None, None), - (False, "404_with_meta", None, None), - (False, "201_server_response", None, None), - (False, "500_server_response", None, None), - ] + # success, input_name, id_name, result_name + (True, "equiv", "equiv", "xrds"), + (True, "header", "header", "xrds"), + (True, "lowercase_header", "lowercase_header", "xrds"), + (True, "xrds", "xrds", "xrds"), + (True, "xrds_ctparam", "xrds_ctparam", "xrds_ctparam"), + (True, "xrds_ctcase", "xrds_ctcase", "xrds_ctcase"), + (False, "xrds_html", "xrds_html", "xrds_html"), + (True, "redir_equiv", "equiv", "xrds"), + (True, "redir_header", "header", "xrds"), + (True, "redir_xrds", "xrds", "xrds"), + (False, "redir_xrds_html", "xrds_html", "xrds_html"), + (True, "redir_redir_equiv", "equiv", "xrds"), + (False, "404_server_response", None, None), + (False, "404_with_header", None, None), + (False, "404_with_meta", None, None), + (False, "201_server_response", None, None), + (False, "500_server_response", None, None), +] + def getDataName(*components): sanitized = [] @@ -42,15 +43,18 @@ def getDataName(*components): return os.path.join(data_path, *sanitized) + def getExampleXRDS(): filename = getDataName('example-xrds.xml') return file(filename).read() + example_xrds = getExampleXRDS() default_test_file = getDataName('test1-discover.txt') discover_tests = {} + def readTests(filename): data = file(filename).read() tests = {} @@ -59,6 +63,7 @@ def readTests(filename): tests[name] = content return tests + def getData(filename, name): global discover_tests try: @@ -67,25 +72,27 @@ def getData(filename, name): file_tests = discover_tests[filename] = readTests(filename) return file_tests[name] + def fillTemplate(test_name, template, base_url, example_xrds): mapping = [ ('URL_BASE/', base_url), ('<XRDS Content>', example_xrds), ('YADIS_HEADER', YADIS_HEADER_NAME), ('NAME', test_name), - ] + ] for k, v in mapping: template = template.replace(k, v) return template + def generateSample(test_name, base_url, example_xrds=example_xrds, filename=default_test_file): try: template = getData(filename, test_name) - except IOError, why: + except IOError as why: import errno if why[0] == errno.ENOENT: raise KeyError(filename) @@ -94,6 +101,7 @@ def generateSample(test_name, base_url, return fillTemplate(test_name, template, base_url, example_xrds) + def generateResult(base_url, input_name, id_name, result_name, success): input_url = urlparse.urljoin(base_url, input_name) diff --git a/openid/test/kvform.py b/openid/test/kvform.py index b54a64b..7bbb5ce 100644 --- a/openid/test/kvform.py +++ b/openid/test/kvform.py @@ -17,6 +17,7 @@ class KVBaseTest(unittest.TestCase, CatchLogs): def tearDown(self): CatchLogs.tearDown(self) + class KVDictTest(KVBaseTest): def __init__(self, kv, dct, warnings): unittest.TestCase.__init__(self) @@ -40,6 +41,7 @@ class KVDictTest(KVBaseTest): d2 = kvform.kvToDict(kv) self.failUnlessEqual(d, d2) + class KVSeqTest(KVBaseTest): def __init__(self, seq, kv, expected_warnings): unittest.TestCase.__init__(self) @@ -52,9 +54,9 @@ class KVSeqTest(KVBaseTest): and end of each value of each pair""" clean = [] for k, v in self.seq: - if type(k) is str: + if isinstance(k, str): k = k.decode('utf8') - if type(v) is str: + if isinstance(v, str): v = v.decode('utf8') clean.append((k.strip(), v.strip())) return clean @@ -63,7 +65,7 @@ class KVSeqTest(KVBaseTest): # seq serializes to expected kvform actual = kvform.seqToKV(self.seq) self.failUnlessEqual(self.kvform, actual) - self.failUnless(type(actual) is str) + self.assertIsInstance(actual, str) # Parse back to sequence. Expected to be unchanged, except # stripping whitespace from start and end of values @@ -74,15 +76,14 @@ class KVSeqTest(KVBaseTest): self.failUnlessEqual(seq, clean_seq) self.checkWarnings(self.expected_warnings) + kvdict_cases = [ # (kvform, parsed dictionary, expected warnings) ('', {}, 0), - ('college:harvey mudd\n', {'college':'harvey mudd'}, 0), - ('city:claremont\nstate:CA\n', - {'city':'claremont', 'state':'CA'}, 0), + ('college:harvey mudd\n', {'college': 'harvey mudd'}, 0), + ('city:claremont\nstate:CA\n', {'city': 'claremont', 'state': 'CA'}, 0), ('is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n', - {'is_valid':'true', - 'invalidate_handle':'{HMAC-SHA1:2398410938412093}'}, 0), + {'is_valid': 'true', 'invalidate_handle': '{HMAC-SHA1:2398410938412093}'}, 0), # Warnings from lines with no colon: ('x\n', {}, 1), @@ -93,18 +94,18 @@ kvdict_cases = [ ('x\n\n', {}, 1), # Warning from empty key - (':\n', {'':''}, 1), - (':missing key\n', {'':'missing key'}, 1), + (':\n', {'': ''}, 1), + (':missing key\n', {'': 'missing key'}, 1), # Warnings from leading or trailing whitespace in key or value - (' street:foothill blvd\n', {'street':'foothill blvd'}, 1), - ('major: computer science\n', {'major':'computer science'}, 1), - (' dorm : east \n', {'dorm':'east'}, 2), + (' street:foothill blvd\n', {'street': 'foothill blvd'}, 1), + ('major: computer science\n', {'major': 'computer science'}, 1), + (' dorm : east \n', {'dorm': 'east'}, 2), # Warnings from missing trailing newline - ('e^(i*pi)+1:0', {'e^(i*pi)+1':'0'}, 1), - ('east:west\nnorth:south', {'east':'west', 'north':'south'}, 1), - ] + ('e^(i*pi)+1:0', {'e^(i*pi)+1': '0'}, 1), + ('east:west\nnorth:south', {'east': 'west', 'north': 'south'}, 1), +] kvseq_cases = [ ([], '', 0), @@ -131,7 +132,7 @@ kvseq_cases = [ (' a ', ' b ')], ' open id : use ful \n a : b \n', 8), ([(u'foo', 'bar')], 'foo:bar\n', 0), - ] +] kvexc_cases = [ [('openid', 'use\nful')], @@ -140,7 +141,8 @@ kvexc_cases = [ [('open:id', 'useful')], [('foo', 'bar'), ('ba\n d', 'seed')], [('foo', 'bar'), ('bad:', 'seed')], - ] +] + class KVExcTest(unittest.TestCase): def __init__(self, seq): @@ -153,14 +155,16 @@ class KVExcTest(unittest.TestCase): def runTest(self): self.failUnlessRaises(ValueError, kvform.seqToKV, self.seq) + class GeneralTest(KVBaseTest): kvform = '<None>' def test_convert(self): - result = kvform.seqToKV([(1,1)]) + result = kvform.seqToKV([(1, 1)]) self.failUnlessEqual(result, '1:1\n') self.checkWarnings(2) + def pyUnitTests(): tests = [KVDictTest(*case) for case in kvdict_cases] tests.extend([KVSeqTest(*case) for case in kvseq_cases]) @@ -168,6 +172,7 @@ def pyUnitTests(): tests.append(unittest.defaultTestLoader.loadTestsFromTestCase(GeneralTest)) return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/linkparse.py b/openid/test/linkparse.py index adcdfb3..f31f9ef 100644 --- a/openid/test/linkparse.py +++ b/openid/test/linkparse.py @@ -23,6 +23,7 @@ def parseLink(line): return (optional, attrs) + def parseCase(s): header, markup = s.split('\n\n', 1) lines = header.split('\n') @@ -31,6 +32,7 @@ def parseCase(s): desc = name[6:] return desc, markup, map(parseLink, lines) + def parseTests(s): tests = [] @@ -47,6 +49,7 @@ def parseTests(s): return num_tests, tests + class _LinkTest(unittest.TestCase): def __init__(self, desc, case, expected, raw): unittest.TestCase.__init__(self) @@ -84,6 +87,7 @@ class _LinkTest(unittest.TestCase): assert i == len(actual) + def pyUnitTests(): here = os.path.dirname(os.path.abspath(__file__)) test_data_file_name = os.path.join(here, 'linkparse.txt') @@ -105,6 +109,7 @@ def pyUnitTests(): return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/oidutil.py b/openid/test/oidutil.py index 568f16a..c7a002f 100644 --- a/openid/test/oidutil.py +++ b/openid/test/oidutil.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import codecs import random import string import unittest @@ -25,7 +24,7 @@ def test_base64(): '\x01', '\x00' * 100, ''.join(map(chr, range(256))), - ] + ] for s in cases: b64 = oidutil.toBase64(s) @@ -42,6 +41,7 @@ def test_base64(): s_prime = oidutil.fromBase64(b64) assert s_prime == s, (s, b64, s_prime) + class AppendArgsTest(unittest.TestCase): def __init__(self, desc, args, expected): unittest.TestCase.__init__(self) @@ -56,6 +56,7 @@ class AppendArgsTest(unittest.TestCase): def shortDescription(self): return self.desc + class TestUnicodeConversion(unittest.TestCase): def test_toUnicode(self): @@ -68,6 +69,7 @@ class TestUnicodeConversion(unittest.TestCase): # Other encodings raise exceptions self.assertRaises(UnicodeDecodeError, lambda: oidutil.toUnicode(u'fööbär'.encode('latin-1'))) + class TestSymbol(unittest.TestCase): def testCopyHash(self): import copy @@ -96,7 +98,7 @@ def buildAppendTests(): simple + '?a=b'), ('one dict', - (simple, {'a':'b'}), + (simple, {'a': 'b'}), simple + '?a=b'), ('two list (same)', @@ -112,7 +114,7 @@ def buildAppendTests(): simple + '?b=c&a=b'), ('two dict (order)', - (simple, {'b':'c', 'a':'b'}), + (simple, {'b': 'c', 'a': 'b'}), simple + '?a=b&b=c'), ('escape', @@ -144,17 +146,17 @@ def buildAppendTests(): simple + '?stuff=bother&ack=ack'), ('args exist (dict 2)', - (simple + '?stuff=bother', {'ack': 'ack', 'zebra':'lion'}), + (simple + '?stuff=bother', {'ack': 'ack', 'zebra': 'lion'}), simple + '?stuff=bother&ack=ack&zebra=lion'), ('three args (dict)', - (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra':'lion'}), + (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra': 'lion'}), simple + '?ack=ack&stuff=bother&zebra=lion'), ('three args (list)', (simple, [('stuff', 'bother'), ('ack', 'ack'), ('zebra', 'lion')]), simple + '?stuff=bother&ack=ack&zebra=lion'), - ] + ] tests = [] @@ -164,12 +166,14 @@ def buildAppendTests(): return unittest.TestSuite(tests) + def pyUnitTests(): some = buildAppendTests() some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestUnicodeConversion)) return some + def test_appendArgs(): suite = buildAppendTests() suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) @@ -181,10 +185,12 @@ def test_appendArgs(): # specified and tested in oidutil.py These include, but are not # limited to appendArgs + def test(skipPyUnit=True): test_base64() if not skipPyUnit: test_appendArgs() + if __name__ == '__main__': test(skipPyUnit=False) diff --git a/openid/test/storetest.py b/openid/test/storetest.py index 6d876fc..a3885b5 100644 --- a/openid/test/storetest.py +++ b/openid/test/storetest.py @@ -17,18 +17,20 @@ for c in string.printable: allowed_handle.append(c) allowed_handle = ''.join(allowed_handle) + def generateHandle(n): return randomString(n, allowed_handle) + generateSecret = randomString + def getTmpDbName(): hostname = socket.gethostname() hostname = hostname.replace('.', '_') hostname = hostname.replace('-', '_') - return "%s_%d_%s_openid_test" % \ - (hostname, os.getpid(), \ - random.randrange(1, int(time.time()))) + return "%s_%d_%s_openid_test" % (hostname, os.getpid(), random.randrange(1, int(time.time()))) + def testStore(store): """Make sure a given store has a minimum of API compliance. Call @@ -38,10 +40,11 @@ def testStore(store): OpenIDStore -> NoneType """ - ### Association functions + # Association functions now = int(time.time()) server_url = 'http://www.myopenid.com/openid' + def genAssoc(issued, lifetime=600): sec = generateSecret(20) hdl = generateHandle(128) @@ -146,15 +149,15 @@ def testStore(store): checkRemove(server_url, assoc.handle, False) checkRemove(server_url, assoc3.handle, False) - ### test expired associations + # test expired associations # assoc 1: server 1, valid # assoc 2: server 1, expired # assoc 3: server 2, expired # assoc 4: server 3, valid - assocValid1 = genAssoc(issued=-3600,lifetime=7200) + assocValid1 = genAssoc(issued=-3600, lifetime=7200) assocValid2 = genAssoc(issued=-5) - assocExpired1 = genAssoc(issued=-7200,lifetime=3600) - assocExpired2 = genAssoc(issued=-7200,lifetime=3600) + assocExpired1 = genAssoc(issued=-7200, lifetime=3600) + assocExpired2 = genAssoc(issued=-7200, lifetime=3600) store.cleanupAssociations() store.storeAssociation(server_url + '1', assocValid1) @@ -165,7 +168,7 @@ def testStore(store): cleaned = store.cleanupAssociations() assert cleaned == 2, cleaned - ### Nonce functions + # Nonce functions def checkUseNonce(nonce, expected, server_url, msg=''): stamp, salt = split(nonce) @@ -189,7 +192,6 @@ def testStore(store): old_nonce = mkNonce(3600) checkUseNonce(old_nonce, False, url, "Old nonce (%r) passed." % (old_nonce,)) - old_nonce1 = mkNonce(now - 20000) old_nonce2 = mkNonce(now - 10000) recent_nonce = mkNonce(now - 600) @@ -235,11 +237,12 @@ def test_filestore(): try: testStore(store) store.cleanup() - except: + except Exception: raise else: shutil.rmtree(temp_dir) + def test_sqlite(): from openid.store import sqlstore try: @@ -252,6 +255,7 @@ def test_sqlite(): store.createTables() testStore(store) + def test_mysql(): from openid.store import sqlstore try: @@ -263,12 +267,10 @@ def test_mysql(): db_passwd = '' db_name = getTmpDbName() - from MySQLdb.constants import ER - # Change this connect line to use the right user and password try: - conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host = db_host) - except MySQLdb.OperationalError, why: + conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host=db_host) + except MySQLdb.OperationalError as why: if why[0] == 2005: print ('Skipping MySQL store test (cannot connect ' 'to test server on host %r)' % (db_host,)) @@ -292,6 +294,7 @@ def test_mysql(): # failing test, comment out this line. conn.query('DROP DATABASE %s;' % db_name) + def test_postgresql(): """ Tests the PostgreSQLStore on a locally-hosted PostgreSQL database @@ -329,8 +332,7 @@ def test_postgresql(): # Connect once to create the database; reconnect to access the # new database. - conn_create = psycopg.connect(database = 'template1', user = db_user, - host = db_host) + conn_create = psycopg.connect(database='template1', user=db_user, host=db_host) conn_create.autocommit() # Create the test database. @@ -339,8 +341,7 @@ def test_postgresql(): conn_create.close() # Connect to the test database. - conn_test = psycopg.connect(database = db_name, user = db_user, - host = db_host) + conn_test = psycopg.connect(database=db_name, user=db_user, host=db_host) # OK, we're in the right environment. Create the store # instance and create the tables. @@ -361,31 +362,33 @@ def test_postgresql(): time.sleep(1) # Remove the database now that the test is over. - conn_remove = psycopg.connect(database = 'template1', user = db_user, - host = db_host) + conn_remove = psycopg.connect(database='template1', user=db_user, host=db_host) conn_remove.autocommit() cursor = conn_remove.cursor() cursor.execute('DROP DATABASE %s;' % (db_name,)) conn_remove.close() + def test_memstore(): from openid.store import memstore testStore(memstore.MemoryStore()) + test_functions = [ test_filestore, test_sqlite, test_mysql, test_postgresql, test_memstore, - ] +] + def pyUnitTests(): tests = map(unittest.FunctionTestCase, test_functions) - load = unittest.defaultTestLoader.loadTestsFromTestCase return unittest.TestSuite(tests) + if __name__ == '__main__': import sys suite = pyUnitTests() diff --git a/openid/test/support.py b/openid/test/support.py index d61973c..e864c89 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -7,7 +7,7 @@ from openid import message class TestHandler(BufferingHandler): def __init__(self, messages): BufferingHandler.__init__(self, 0) - self.messages = messages + self.messages = messages def shouldFlush(self): return False @@ -15,6 +15,7 @@ class TestHandler(BufferingHandler): def emit(self, record): self.messages.append(record) + class OpenIDTestMixin(object): def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): if ns is None: @@ -33,22 +34,23 @@ class OpenIDTestMixin(object): error_message = 'openid.%s unexpectedly present: %s' % (key, actual) self.failIf(actual is not None, error_message) + class CatchLogs(object): def setUp(self): - self.messages = [] - root_logger = logging.getLogger() - self.old_log_level = root_logger.getEffectiveLevel() - root_logger.setLevel(logging.DEBUG) + self.messages = [] + root_logger = logging.getLogger() + self.old_log_level = root_logger.getEffectiveLevel() + root_logger.setLevel(logging.DEBUG) - self.handler = TestHandler(self.messages) - formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") - self.handler.setFormatter(formatter) - root_logger.addHandler(self.handler) + self.handler = TestHandler(self.messages) + formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") + self.handler.setFormatter(formatter) + root_logger.addHandler(self.handler) def tearDown(self): root_logger = logging.getLogger() - root_logger.removeHandler(self.handler) - root_logger.setLevel(self.old_log_level) + root_logger.removeHandler(self.handler) + root_logger.setLevel(self.old_log_level) def failUnlessLogMatches(self, *prefixes): """ @@ -58,14 +60,10 @@ class CatchLogs(object): messages. """ messages = [r.getMessage() for r in self.messages] - assert len(prefixes) == len(messages), \ - "Expected log prefixes %r, got %r" % (prefixes, - messages) - - for prefix, message in zip(prefixes, messages): - assert message.startswith(prefix), \ - "Expected log prefixes %r, got %r" % (prefixes, - messages) + assert len(prefixes) == len(messages), "Expected log prefixes %r, got %r" % (prefixes, messages) + + for prefix, msg in zip(prefixes, messages): + assert msg.startswith(prefix), "Expected log prefixes %r, got %r" % (prefixes, messages) def failUnlessLogEmpty(self): self.failUnlessLogMatches() diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 547e42a..b8af670 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -17,6 +17,7 @@ def getTestData(): i += 1 return lines + def chunk(lines): """Return groups of lines separated by whitespace or comments @@ -38,6 +39,7 @@ def chunk(lines): return chunks + def parseLines(chunk): """Take the given chunk of lines and turn it into a test data dictionary @@ -51,6 +53,7 @@ def parseLines(chunk): return items + def parseAvailable(available_text): """Parse an Available: line's data @@ -58,6 +61,7 @@ def parseAvailable(available_text): """ return [s.strip() for s in available_text.split(',')] + def parseExpected(expected_text): """Parse an Expected: line's data @@ -78,6 +82,7 @@ def parseExpected(expected_text): return expected + class MatchAcceptTest(unittest.TestCase): def __init__(self, descr, accept_header, available, expected): unittest.TestCase.__init__(self) @@ -94,6 +99,7 @@ class MatchAcceptTest(unittest.TestCase): actual = accept.matchTypes(accepted, self.available) self.failUnlessEqual(self.expected, actual) + def pyUnitTests(): lines = getTestData() chunks = chunk(lines) @@ -107,7 +113,7 @@ def pyUnitTests(): lnos.append(lno) try: available = parseAvailable(avail_data) - except: + except Exception: print 'On line', lno raise @@ -115,7 +121,7 @@ def pyUnitTests(): lnos.append(lno) try: expected = parseExpected(exp_data) - except: + except Exception: print 'On line', lno raise @@ -124,6 +130,7 @@ def pyUnitTests(): cases.append(case) return unittest.TestSuite(cases) + if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(pyUnitTests()) diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 86c2883..929d5b6 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -1,14 +1,11 @@ import time import unittest -import warnings -from openid import association, cryptutil -from openid.consumer.consumer import (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, - PlainTextConsumerSession) +from openid import association +from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession, PlainTextConsumerSession from openid.dh import DiffieHellman from openid.message import BARE_NS, OPENID2_NS, OPENID_NS, Message -from openid.server.server import (DiffieHellmanSHA1ServerSession, DiffieHellmanSHA256ServerSession, - PlainTextServerSession) +from openid.server.server import DiffieHellmanSHA1ServerSession, PlainTextServerSession from openid.test import datadriven @@ -27,25 +24,24 @@ class AssociationSerializationTest(unittest.TestCase): self.failUnlessEqual(assoc.assoc_type, assoc2.assoc_type) - - def createNonstandardConsumerDH(): nonstandard_dh = DiffieHellman(1315291, 2) return DiffieHellmanSHA1ConsumerSession(nonstandard_dh) + class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase): secrets = [ '\x00' * 20, '\xff' * 20, ' ' * 20, 'This is a secret....', - ] + ] session_factories = [ (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA1ServerSession), (createNonstandardConsumerDH, DiffieHellmanSHA1ServerSession), (PlainTextConsumerSession, PlainTextServerSession), - ] + ] def generateCases(cls): return [(c, s, sec) @@ -69,7 +65,6 @@ class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase): self.failUnlessEqual(self.secret, check_secret) - class TestMakePairs(unittest.TestCase): """Check the key-value formatting methods of associations. """ @@ -81,29 +76,26 @@ class TestMakePairs(unittest.TestCase): 'identifier': '=example', 'signed': 'identifier,mode', 'sig': 'cephalopod', - }) + }) m.updateArgs(BARE_NS, {'xey': 'value'}) self.assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") - def testMakePairs(self): """Make pairs using the OpenID 1.x type signed list.""" pairs = self.assoc._makePairs(self.message) expected = [ ('identifier', '=example'), ('mode', 'id_res'), - ] + ] self.failUnlessEqual(pairs, expected) - class TestMac(unittest.TestCase): def setUp(self): self.pairs = [('key1', 'value1'), ('key2', 'value2')] - def test_sha1(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") @@ -121,7 +113,6 @@ class TestMac(unittest.TestCase): self.failUnlessEqual(sig, expected) - class TestMessageSigning(unittest.TestCase): def setUp(self): self.message = m = Message(OPENID2_NS) @@ -132,7 +123,6 @@ class TestMessageSigning(unittest.TestCase): 'openid.identifier': '=example', 'xey': 'value'} - def test_signSHA1(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") @@ -170,6 +160,7 @@ class TestCheckMessageSignature(unittest.TestCase): def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 11161fb..79be68c 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -5,10 +5,9 @@ this works for now. """ import unittest -from openid import oidutil -from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession, GenericConsumer, ProtocolError +from openid.consumer.consumer import GenericConsumer, ProtocolError from openid.consumer.discover import OPENID_1_1_TYPE, OPENID_2_0_TYPE, OpenIDServiceEndpoint -from openid.message import OPENID2_NS, OPENID_NS, Message, no_default +from openid.message import OPENID2_NS, OPENID_NS, Message from openid.server.server import DiffieHellmanSHA1ServerSession from openid.store import memstore from openid.test.test_consumer import CatchLogs @@ -16,11 +15,12 @@ from openid.test.test_consumer import CatchLogs # Some values we can use for convenience (see mkAssocResponse) association_response_values = { 'expires_in': '1000', - 'assoc_handle':'a handle', - 'assoc_type':'a type', - 'session_type':'a session type', - 'ns':OPENID2_NS, - } + 'assoc_handle': 'a handle', + 'assoc_type': 'a type', + 'session_type': 'a session type', + 'ns': OPENID2_NS, +} + def mkAssocResponse(*keys): """Build an association response message that contains the @@ -32,6 +32,7 @@ def mkAssocResponse(*keys): args = dict([(key, association_response_values[key]) for key in keys]) return Message.fromOpenIDArgs(args) + class BaseAssocTest(CatchLogs, unittest.TestCase): def setUp(self): CatchLogs.setUp(self) @@ -42,12 +43,13 @@ class BaseAssocTest(CatchLogs, unittest.TestCase): def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs): try: result = func(*args, **kwargs) - except ProtocolError, e: + except ProtocolError as e: message = 'Expected prefix %r, got %r' % (str_prefix, e[0]) self.failUnless(e[0].startswith(str_prefix), message) else: self.fail('Expected ProtocolError, got %r' % (result,)) + def mkExtractAssocMissingTest(keys): """Factory function for creating test methods for generating missing field tests. @@ -77,6 +79,7 @@ def mkExtractAssocMissingTest(keys): return test + class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest): """Test for returning an error upon missing fields in association responses for OpenID 2""" @@ -95,6 +98,7 @@ class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest): test_missingSessionType_openid2 = mkExtractAssocMissingTest( ['expires_in', 'assoc_handle', 'assoc_type', 'ns']) + class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest): """Test for returning an error upon missing fields in association responses for OpenID 2""" @@ -110,11 +114,13 @@ class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest): test_missingAssocType_openid1 = mkExtractAssocMissingTest( ['expires_in', 'assoc_handle']) + class DummyAssocationSession(object): def __init__(self, session_type, allowed_assoc_types=()): self.session_type = session_type self.allowed_assoc_types = allowed_assoc_types + class ExtractAssociationSessionTypeMismatch(BaseAssocTest): def mkTest(requested_session_type, response_session_type, openid1=False): def test(self): @@ -124,48 +130,47 @@ class ExtractAssociationSessionTypeMismatch(BaseAssocTest): keys.remove('ns') msg = mkAssocResponse(*keys) msg.setArg(OPENID_NS, 'session_type', response_session_type) - self.failUnlessProtocolError('Session type mismatch', - self.consumer._extractAssociation, msg, assoc_session) + self.failUnlessProtocolError('Session type mismatch', self.consumer._extractAssociation, msg, assoc_session) return test test_typeMismatchNoEncBlank_openid2 = mkTest( requested_session_type='no-encryption', response_session_type='', - ) + ) test_typeMismatchDHSHA1NoEnc_openid2 = mkTest( requested_session_type='DH-SHA1', response_session_type='no-encryption', - ) + ) test_typeMismatchDHSHA256NoEnc_openid2 = mkTest( requested_session_type='DH-SHA256', response_session_type='no-encryption', - ) + ) test_typeMismatchNoEncDHSHA1_openid2 = mkTest( requested_session_type='no-encryption', response_session_type='DH-SHA1', - ) + ) test_typeMismatchDHSHA1NoEnc_openid1 = mkTest( requested_session_type='DH-SHA1', response_session_type='DH-SHA256', openid1=True, - ) + ) test_typeMismatchDHSHA256NoEnc_openid1 = mkTest( requested_session_type='DH-SHA256', response_session_type='DH-SHA1', openid1=True, - ) + ) test_typeMismatchNoEncDHSHA1_openid1 = mkTest( requested_session_type='no-encryption', response_session_type='DH-SHA1', openid1=True, - ) + ) class TestOpenID1AssociationResponseSessionType(BaseAssocTest): @@ -174,6 +179,7 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): be used if the OpenID 1 response to an associate call sets the 'session_type' field to `session_type_value` """ + def test(self): self._doTest(expected_session_type, session_type_value) self.failUnlessLogEmpty() @@ -201,25 +207,25 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): test_none = mkTest( session_type_value=None, expected_session_type='no-encryption', - ) + ) test_empty = mkTest( session_type_value='', expected_session_type='no-encryption', - ) + ) # This one's different because it expects log messages def test_explicitNoEncryption(self): self._doTest( session_type_value='no-encryption', expected_session_type='no-encryption', - ) + ) self.failUnlessLogMatches('OpenID server sent "no-encryption"') test_dhSHA1 = mkTest( session_type_value='DH-SHA1', expected_session_type='DH-SHA1', - ) + ) # DH-SHA256 is not a valid session type for OpenID1, but this # function does not test that. This is mostly just to make sure @@ -229,7 +235,8 @@ class TestOpenID1AssociationResponseSessionType(BaseAssocTest): test_dhSHA256 = mkTest( session_type_value='DH-SHA256', expected_session_type='DH-SHA256', - ) + ) + class DummyAssociationSession(object): secret = "shh! don't tell!" @@ -243,6 +250,7 @@ class DummyAssociationSession(object): self.extract_secret_called = True return self.secret + class TestInvalidFields(BaseAssocTest): def setUp(self): BaseAssocTest.setUp(self) @@ -256,11 +264,11 @@ class TestInvalidFields(BaseAssocTest): # These arguments should all be valid self.assoc_response = Message.fromOpenIDArgs({ 'expires_in': '1000', - 'assoc_handle':self.assoc_handle, - 'assoc_type':self.assoc_type, - 'session_type':self.session_type, - 'ns':OPENID2_NS, - }) + 'assoc_handle': self.assoc_handle, + 'assoc_type': self.assoc_type, + 'session_type': self.session_type, + 'ns': OPENID2_NS, + }) self.assoc_session = DummyAssociationSession() @@ -283,15 +291,13 @@ class TestInvalidFields(BaseAssocTest): # for the given session. self.assoc_session.allowed_assoc_types = [] self.failUnlessProtocolError('Unsupported assoc_type for session', - self.consumer._extractAssociation, - self.assoc_response, self.assoc_session) + self.consumer._extractAssociation, self.assoc_response, self.assoc_session) def test_badExpiresIn(self): # Invalid value for expires_in should cause failure self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever') self.failUnlessProtocolError('Invalid expires_in', - self.consumer._extractAssociation, - self.assoc_response, self.assoc_session) + self.consumer._extractAssociation, self.assoc_response, self.assoc_session) # XXX: This is what causes most of the imports in this file. It is @@ -334,5 +340,4 @@ class TestExtractAssociationDiffieHellman(BaseAssocTest): def test_badDHValues(self): sess, server_resp = self._setUpDH() server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00') - self.failUnlessProtocolError('Malformed response for', - self.consumer._extractAssociation, server_resp, sess) + self.failUnlessProtocolError('Malformed response for', self.consumer._extractAssociation, server_resp, sess) diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 1419ab5..c92ccb3 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -1,4 +1,3 @@ -import cgi import unittest from openid import message @@ -21,9 +20,11 @@ class DummyEndpoint(object): def isOPIdentifier(self): return self.is_op_identifier + class DummyAssoc(object): handle = "assoc-handle" + class AuthRequestTestMixin(support.OpenIDTestMixin): """Mixin for AuthRequest tests for OpenID 1 and 2; DON'T add unittest.TestCase as a base class here.""" @@ -102,6 +103,7 @@ class AuthRequestTestMixin(support.OpenIDTestMixin): self.failUnlessHasIdentifiers( msg, self.endpoint.local_id, self.endpoint.claimed_id) + class TestAuthRequestOpenID2(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID2_NS @@ -152,13 +154,10 @@ class TestAuthRequestOpenID2(AuthRequestTestMixin, unittest.TestCase): self.failUnlessHasIdentifiers( msg, message.IDENTIFIER_SELECT, message.IDENTIFIER_SELECT) + class TestAuthRequestOpenID1(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID1_NS - def setUpEndpoint(self): - TestAuthRequestBase.setUpEndpoint(self) - self.endpoint.preferred_namespace = message.OPENID1_NS - def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): """Make sure claimed_is is *absent* in request.""" self.failUnlessOpenIDValueEquals(msg, 'identity', op_specific_id) @@ -195,13 +194,16 @@ class TestAuthRequestOpenID1(AuthRequestTestMixin, unittest.TestCase): self.failUnlessEqual(message.IDENTIFIER_SELECT, msg.getArg(message.OPENID1_NS, 'identity')) + class TestAuthRequestOpenID1Immediate(TestAuthRequestOpenID1): immediate = True expected_mode = 'checkid_immediate' + class TestAuthRequestOpenID2Immediate(TestAuthRequestOpenID2): immediate = True expected_mode = 'checkid_immediate' + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 28f90fa..80be5c8 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -13,10 +13,12 @@ class BogusAXMessage(ax.AXMessage): getExtensionArgs = ax.AXMessage._newArgs + class DummyRequest(object): def __init__(self, message): self.message = message + class AXMessageTest(unittest.TestCase): def setUp(self): self.bax = BogusAXMessage() @@ -24,10 +26,10 @@ class AXMessageTest(unittest.TestCase): def test_checkMode(self): check = self.bax._checkMode self.failUnlessRaises(ax.NotAXMessage, check, {}) - self.failUnlessRaises(ax.AXError, check, {'mode':'fetch_request'}) + self.failUnlessRaises(ax.AXError, check, {'mode': 'fetch_request'}) # does not raise an exception when the mode is right - check({'mode':self.bax.mode}) + check({'mode': self.bax.mode}) def test_checkMode_newArgs(self): """_newArgs generates something that has the correct mode""" @@ -80,6 +82,7 @@ class ToTypeURIsTest(unittest.TestCase): uris = ax.toTypeURIs(self.aliases, ','.join([alias1, alias2])) self.failUnlessEqual([uri1, uri2], uris) + class ParseAXValuesTest(unittest.TestCase): """Testing AXKeyValueMessage.parseExtensionArgs.""" @@ -97,27 +100,27 @@ class ParseAXValuesTest(unittest.TestCase): self.failUnlessAXValues({}, {}) def test_missingValueForAliasExplodes(self): - self.failUnlessAXKeyError({'type.foo':'urn:foo'}) + self.failUnlessAXKeyError({'type.foo': 'urn:foo'}) def test_countPresentButNotValue(self): - self.failUnlessAXKeyError({'type.foo':'urn:foo', - 'count.foo':'1'}) + self.failUnlessAXKeyError({'type.foo': 'urn:foo', + 'count.foo': '1'}) def test_invalidCountValue(self): msg = ax.FetchRequest() self.failUnlessRaises(ax.AXError, msg.parseExtensionArgs, - {'type.foo':'urn:foo', - 'count.foo':'bogus'}) + {'type.foo': 'urn:foo', + 'count.foo': 'bogus'}) def test_requestUnlimitedValues(self): msg = ax.FetchRequest() msg.parseExtensionArgs( - {'mode':'fetch_request', - 'required':'foo', - 'type.foo':'urn:foo', - 'count.foo':ax.UNLIMITED_VALUES}) + {'mode': 'fetch_request', + 'required': 'foo', + 'type.foo': 'urn:foo', + 'count.foo': ax.UNLIMITED_VALUES}) attrs = list(msg.iterAttrs()) foo = attrs[0] @@ -135,20 +138,20 @@ class ParseAXValuesTest(unittest.TestCase): {'type.%s' % (alias,): 'urn:foo', 'count.%s' % (alias,): '1', 'value.%s.1' % (alias,): 'first'} - ) + ) def test_invalidAlias(self): types = [ ax.AXKeyValueMessage, ax.FetchRequest - ] + ] inputs = [ - {'type.a.b':'urn:foo', - 'count.a.b':'1'}, - {'type.a,b':'urn:foo', - 'count.a,b':'1'}, - ] + {'type.a.b': 'urn:foo', + 'count.a.b': '1'}, + {'type.a,b': 'urn:foo', + 'count.a,b': '1'}, + ] for typ in types: for input in inputs: @@ -158,37 +161,37 @@ class ParseAXValuesTest(unittest.TestCase): def test_countPresentAndIsZero(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'count.foo':'0', - }, {'urn:foo':[]}) + {'type.foo': 'urn:foo', + 'count.foo': '0', + }, {'urn:foo': []}) def test_singletonEmpty(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'value.foo':'', - }, {'urn:foo':[]}) + {'type.foo': 'urn:foo', + 'value.foo': '', + }, {'urn:foo': []}) def test_doubleAlias(self): self.failUnlessAXKeyError( - {'type.foo':'urn:foo', - 'value.foo':'', - 'type.bar':'urn:foo', - 'value.bar':'', + {'type.foo': 'urn:foo', + 'value.foo': '', + 'type.bar': 'urn:foo', + 'value.bar': '', }) def test_doubleSingleton(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'value.foo':'', - 'type.bar':'urn:bar', - 'value.bar':'', - }, {'urn:foo':[], 'urn:bar':[]}) + {'type.foo': 'urn:foo', + 'value.foo': '', + 'type.bar': 'urn:bar', + 'value.bar': '', + }, {'urn:foo': [], 'urn:bar': []}) def test_singletonValue(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'value.foo':'Westfall', - }, {'urn:foo':['Westfall']}) + {'type.foo': 'urn:foo', + 'value.foo': 'Westfall', + }, {'urn:foo': ['Westfall']}) class FetchRequestTest(unittest.TestCase): @@ -197,7 +200,6 @@ class FetchRequestTest(unittest.TestCase): self.type_a = 'http://janrain.example.com/a' self.alias_a = 'a' - def test_mode(self): self.failUnlessEqual(self.msg.mode, 'fetch_request') @@ -230,14 +232,14 @@ class FetchRequestTest(unittest.TestCase): def test_getExtensionArgs_empty(self): expected_args = { - 'mode':'fetch_request', - } + 'mode': 'fetch_request', + } self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) def test_getExtensionArgs_noAlias(self): attr = ax.AttrInfo( - type_uri = 'type://of.transportation', - ) + type_uri='type://of.transportation', + ) self.msg.add(attr) ax_args = self.msg.getExtensionArgs() for k, v in ax_args.iteritems(): @@ -248,32 +250,32 @@ class FetchRequestTest(unittest.TestCase): self.fail("Didn't find the type definition") self.failUnlessExtensionArgs({ - 'type.' + alias:attr.type_uri, - 'if_available':alias, - }) + 'type.' + alias: attr.type_uri, + 'if_available': alias, + }) def test_getExtensionArgs_alias_if_available(self): attr = ax.AttrInfo( - type_uri = 'type://of.transportation', - alias = 'transport', - ) + type_uri='type://of.transportation', + alias='transport', + ) self.msg.add(attr) self.failUnlessExtensionArgs({ - 'type.' + attr.alias:attr.type_uri, - 'if_available':attr.alias, - }) + 'type.' + attr.alias: attr.type_uri, + 'if_available': attr.alias, + }) def test_getExtensionArgs_alias_req(self): attr = ax.AttrInfo( - type_uri = 'type://of.transportation', - alias = 'transport', - required = True, - ) + type_uri='type://of.transportation', + alias='transport', + required=True, + ) self.msg.add(attr) self.failUnlessExtensionArgs({ - 'type.' + attr.alias:attr.type_uri, - 'required':attr.alias, - }) + 'type.' + attr.alias: attr.type_uri, + 'required': attr.alias, + }) def failUnlessExtensionArgs(self, expected_args): """Make sure that getExtensionArgs has the expected result @@ -293,18 +295,18 @@ class FetchRequestTest(unittest.TestCase): def test_parseExtensionArgs_extraType(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + } self.failUnlessRaises(ValueError, self.msg.parseExtensionArgs, extension_args) def test_parseExtensionArgs(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'if_available':self.alias_a - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'if_available': self.alias_a + } self.msg.parseExtensionArgs(extension_args) self.failUnless(self.type_a in self.msg) self.failUnlessEqual([self.type_a], list(self.msg)) @@ -317,37 +319,37 @@ class FetchRequestTest(unittest.TestCase): def test_extensionArgs_idempotent(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'if_available':self.alias_a - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'if_available': self.alias_a + } self.msg.parseExtensionArgs(extension_args) self.failUnlessEqual(extension_args, self.msg.getExtensionArgs()) self.failIf(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_idempotent_count_required(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'count.' + self.alias_a:'2', - 'required':self.alias_a - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'count.' + self.alias_a: '2', + 'required': self.alias_a + } self.msg.parseExtensionArgs(extension_args) self.failUnlessEqual(extension_args, self.msg.getExtensionArgs()) self.failUnless(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_count1(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'count.' + self.alias_a:'1', - 'if_available':self.alias_a, - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'count.' + self.alias_a: '1', + 'if_available': self.alias_a, + } extension_args_norm = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'if_available':self.alias_a, - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'if_available': self.alias_a, + } self.msg.parseExtensionArgs(extension_args) self.failUnlessEqual(extension_args_norm, self.msg.getExtensionArgs()) @@ -358,7 +360,7 @@ class FetchRequestTest(unittest.TestCase): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://different.site/path', 'ax.mode': 'fetch_request', - }) + }) self.failUnlessRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, DummyRequest(openid_req_msg)) @@ -371,7 +373,7 @@ class FetchRequestTest(unittest.TestCase): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://different.site/path', 'ax.mode': 'fetch_request', - }) + }) self.failUnlessRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, @@ -385,9 +387,9 @@ class FetchRequestTest(unittest.TestCase): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://example.com/realm/update_path', 'ax.mode': 'fetch_request', - }) + }) - fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) + ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationSuccessReturnTo(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -397,16 +399,16 @@ class FetchRequestTest(unittest.TestCase): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://example.com/realm/update_path', 'ax.mode': 'fetch_request', - }) + }) - fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) + ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) def test_fromOpenIDRequestWithoutExtension(self): """return None for an OpenIDRequest without AX paramaters.""" openid_req_msg = Message.fromOpenIDArgs({ 'mode': 'checkid_setup', 'ns': OPENID2_NS, - }) + }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) self.failUnless(r is None, "%s is not None" % (r,)) @@ -420,7 +422,7 @@ class FetchRequestTest(unittest.TestCase): 'ns': OPENID2_NS, 'ns.ax': ax.AXMessage.ns_uri, 'ax.mode': 'fetch_request', - }) + }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) self.failUnless(r is not None) @@ -440,14 +442,14 @@ class FetchResponseTest(unittest.TestCase): def test_getExtensionArgs_empty(self): expected_args = { - 'mode':'fetch_response', - } + 'mode': 'fetch_response', + } self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) def test_getExtensionArgs_empty_request(self): expected_args = { - 'mode':'fetch_response', - } + 'mode': 'fetch_response', + } req = ax.FetchRequest() msg = ax.FetchResponse(request=req) self.failUnlessEqual(expected_args, msg.getExtensionArgs()) @@ -457,10 +459,10 @@ class FetchResponseTest(unittest.TestCase): alias = 'ext0' expected_args = { - 'mode':'fetch_response', + 'mode': 'fetch_response', 'type.%s' % (alias,): uri, 'count.%s' % (alias,): '0' - } + } req = ax.FetchRequest() req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) @@ -471,11 +473,11 @@ class FetchResponseTest(unittest.TestCase): alias = 'ext0' expected_args = { - 'mode':'fetch_response', + 'mode': 'fetch_response', 'update_url': self.request_update_url, 'type.%s' % (alias,): uri, 'count.%s' % (alias,): '0' - } + } req = ax.FetchRequest(update_url=self.request_update_url) req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) @@ -483,11 +485,11 @@ class FetchResponseTest(unittest.TestCase): def test_getExtensionArgs_some_request(self): expected_args = { - 'mode':'fetch_response', - 'type.' + self.alias_a:self.type_a, - 'value.' + self.alias_a + '.1':self.value_a, + 'mode': 'fetch_response', + 'type.' + self.alias_a: self.type_a, + 'value.' + self.alias_a + '.1': self.value_a, 'count.' + self.alias_a: '1' - } + } req = ax.FetchRequest() req.add(ax.AttrInfo(self.type_a, alias=self.alias_a)) msg = ax.FetchResponse(request=req) @@ -501,7 +503,6 @@ class FetchResponseTest(unittest.TestCase): self.failUnlessRaises(KeyError, msg.getExtensionArgs) def test_getSingle_success(self): - req = ax.FetchRequest() self.msg.addValue(self.type_a, self.value_a) self.failUnlessEqual(self.value_a, self.msg.getSingle(self.type_a)) @@ -520,9 +521,10 @@ class FetchResponseTest(unittest.TestCase): args = { 'mode': 'id_res', 'ns': OPENID2_NS, - } + } sf = ['openid.' + i for i in args.keys()] msg = Message.fromOpenIDArgs(args) + class Endpoint: claimed_id = 'http://invalid.' @@ -538,9 +540,10 @@ class FetchResponseTest(unittest.TestCase): 'ns': OPENID2_NS, 'ns.ax': ax.AXMessage.ns_uri, 'ax.mode': 'fetch_response', - } + } sf = ['openid.' + i for i in args.keys()] msg = Message.fromOpenIDArgs(args) + class Endpoint: claimed_id = 'http://invalid.' @@ -558,12 +561,13 @@ class FetchResponseTest(unittest.TestCase): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://example.com/realm/update_path', 'ax.mode': 'fetch_response', - 'ax.type.'+name: uri, - 'ax.count.'+name: '1', - 'ax.value.%s.1'%name: value, - } + 'ax.type.' + name: uri, + 'ax.count.' + name: '1', + 'ax.value.%s.1' % name: value, + } sf = ['openid.' + i for i in args.keys()] msg = Message.fromOpenIDArgs(args) + class Endpoint: claimed_id = 'http://invalid.' @@ -585,8 +589,8 @@ class StoreRequestTest(unittest.TestCase): def test_getExtensionArgs_empty(self): args = self.msg.getExtensionArgs() expected_args = { - 'mode':'store_request', - } + 'mode': 'store_request', + } self.failUnlessEqual(expected_args, args) def test_getExtensionArgs_nonempty(self): @@ -596,27 +600,28 @@ class StoreRequestTest(unittest.TestCase): msg.setValues(self.type_a, ['foo', 'bar']) args = msg.getExtensionArgs() expected_args = { - 'mode':'store_request', + 'mode': 'store_request', 'type.' + self.alias_a: self.type_a, 'count.' + self.alias_a: '2', - 'value.%s.1' % (self.alias_a,):'foo', - 'value.%s.2' % (self.alias_a,):'bar', - } + 'value.%s.1' % (self.alias_a,): 'foo', + 'value.%s.2' % (self.alias_a,): 'bar', + } self.failUnlessEqual(expected_args, args) + class StoreResponseTest(unittest.TestCase): def test_success(self): msg = ax.StoreResponse() self.failUnless(msg.succeeded()) self.failIf(msg.error_message) - self.failUnlessEqual({'mode':'store_response_success'}, + self.failUnlessEqual({'mode': 'store_response_success'}, msg.getExtensionArgs()) def test_fail_nomsg(self): msg = ax.StoreResponse(False) self.failIf(msg.succeeded()) self.failIf(msg.error_message) - self.failUnlessEqual({'mode':'store_response_failure'}, + self.failUnlessEqual({'mode': 'store_response_failure'}, msg.getExtensionArgs()) def test_fail_msg(self): @@ -624,5 +629,5 @@ class StoreResponseTest(unittest.TestCase): msg = ax.StoreResponse(False, reason) self.failIf(msg.succeeded()) self.failUnlessEqual(reason, msg.error_message) - self.failUnlessEqual({'mode':'store_response_failure', - 'error':reason}, msg.getExtensionArgs()) + self.failUnlessEqual({'mode': 'store_response_failure', + 'error': reason}, msg.getExtensionArgs()) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index acab7c0..0549663 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -2,9 +2,8 @@ import cgi import time import unittest import urlparse -import warnings -from openid import association, cryptutil, dh, fetchers, kvform, oidutil +from openid import association, cryptutil, fetchers, kvform, oidutil from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, FailureResponse, GenericConsumer, PlainTextConsumerSession, ProtocolError, @@ -26,7 +25,8 @@ from .support import CatchLogs assocs = [ ('another 20-byte key.', 'Snarky'), ('\x00' * 20, 'Zeros'), - ] +] + def mkSuccess(endpoint, q): """Convenience function to create a SuccessResponse with the given @@ -34,13 +34,15 @@ def mkSuccess(endpoint, q): signed_list = ['openid.' + k for k in q.keys()] return SuccessResponse(endpoint, Message.fromOpenIDArgs(q), signed_list) + def parseQuery(qs): q = {} for (k, v) in cgi.parse_qsl(qs): - assert not q.has_key(k) + assert k not in q q[k] = v return q + def associate(qs, assoc_secret, assoc_handle): """Do the server's half of the associate call, using the given secret and handle.""" @@ -48,10 +50,10 @@ def associate(qs, assoc_secret, assoc_handle): assert q['openid.mode'] == 'associate' assert q['openid.assoc_type'] == 'HMAC-SHA1' reply_dict = { - 'assoc_type':'HMAC-SHA1', - 'assoc_handle':assoc_handle, - 'expires_in':'600', - } + 'assoc_type': 'HMAC-SHA1', + 'assoc_handle': assoc_handle, + 'expires_in': '600', + } if q.get('openid.session_type') == 'DH-SHA1': assert len(q) == 6 or len(q) == 4 @@ -86,8 +88,9 @@ class GoodAssocStore(memstore.MemoryStore): class TestFetcher(object): - def __init__(self, user_url, user_page, (assoc_secret, assoc_handle)): - self.get_responses = {user_url:self.response(user_url, 200, user_page)} + def __init__(self, user_url, user_page, xxx_todo_changeme): + (assoc_secret, assoc_handle) = xxx_todo_changeme + self.get_responses = {user_url: self.response(user_url, 200, user_page)} self.assoc_secret = assoc_secret self.assoc_handle = assoc_handle self.num_assocs = 0 @@ -104,7 +107,7 @@ class TestFetcher(object): try: body.index('openid.mode=associate') except ValueError: - pass # fall through + pass # fall through else: assert body.find('DH-SHA1') != -1 response = associate( @@ -114,6 +117,7 @@ class TestFetcher(object): return self.response(url, 404, 'Not found') + def makeFastConsumerSession(): """ Create custom DH object so tests run quickly. @@ -121,9 +125,11 @@ def makeFastConsumerSession(): dh = DiffieHellman(100389557, 2) return DiffieHellmanSHA1ConsumerSession(dh) + def setConsumerSession(con): con.session_types = {'DH-SHA1': makeFastConsumerSession} + def _test_success(server_url, user_url, delegate_url, links, immediate=False): store = memstore.MemoryStore() if immediate: @@ -149,8 +155,6 @@ def _test_success(server_url, user_url, delegate_url, links, immediate=False): request = consumer.begin(endpoint) return_to = consumer_url - m = request.getMessage(trust_root, return_to, immediate) - redirect_url = request.redirectURL(trust_root, return_to, immediate) parsed = urlparse.urlparse(redirect_url) @@ -159,11 +163,11 @@ def _test_success(server_url, user_url, delegate_url, links, immediate=False): new_return_to = q['openid.return_to'] del q['openid.return_to'] assert q == { - 'openid.mode':mode, - 'openid.identity':delegate_url, - 'openid.trust_root':trust_root, - 'openid.assoc_handle':fetcher.assoc_handle, - }, (q, user_url, delegate_url, mode) + 'openid.mode': mode, + 'openid.identity': delegate_url, + 'openid.trust_root': trust_root, + 'openid.assoc_handle': fetcher.assoc_handle, + }, (q, user_url, delegate_url, mode) assert new_return_to.startswith(return_to) assert redirect_url.startswith(server_url) @@ -171,11 +175,11 @@ def _test_success(server_url, user_url, delegate_url, links, immediate=False): parsed = urlparse.urlparse(new_return_to) query = parseQuery(parsed[4]) query.update({ - 'openid.mode':'id_res', - 'openid.return_to':new_return_to, - 'openid.identity':delegate_url, - 'openid.assoc_handle':fetcher.assoc_handle, - }) + 'openid.mode': 'id_res', + 'openid.return_to': new_return_to, + 'openid.identity': delegate_url, + 'openid.assoc_handle': fetcher.assoc_handle, + }) assoc = store.getAssociation(server_url, fetcher.assoc_handle) @@ -207,6 +211,7 @@ http_server_url = 'http://server.example.com/' consumer_url = 'http://consumer.example.com/' https_server_url = 'https://server.example.com/' + class TestSuccess(unittest.TestCase, CatchLogs): server_url = http_server_url user_url = 'http://www.example.com/user.html' @@ -284,10 +289,12 @@ class TestIdRes(unittest.TestCase, CatchLogs): return True self.consumer._checkReturnTo = checkReturnTo complete = self.consumer.complete + def callCompleteWithoutReturnTo(message, endpoint): return complete(message, endpoint, None) self.consumer.complete = callCompleteWithoutReturnTo + class TestIdResCheckSignature(TestIdRes): def setUp(self): TestIdRes.setUp(self) @@ -302,22 +309,19 @@ class TestIdResCheckSignature(TestIdRes): 'openid.assoc_handle': self.assoc.handle, 'openid.signed': 'mode,identity,assoc_handle,signed', 'frobboz': 'banzit', - }) - + }) def test_sign(self): # assoc_handle to assoc with good sig self.consumer._idResCheckSignature(self.message, self.endpoint.server_url) - def test_signFailsWithBadSig(self): self.message.setArg(OPENID_NS, 'sig', 'BAD SIGNATURE') self.failUnlessRaises( ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) - def test_stateless(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") @@ -364,11 +368,12 @@ class TestQueryFormat(TestIdRes): query = {'openid.mode': ['cancel']} try: r = Message.fromPostArgs(query) - except TypeError, err: + except TypeError as err: self.failUnless(str(err).find('values') != -1, err) else: self.fail("expected TypeError, got this instead: %s" % (r,)) + class TestComplete(TestIdRes): """Testing GenericConsumer.complete. @@ -404,9 +409,7 @@ class TestComplete(TestIdRes): def test_error(self): msg = 'an error message' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, - }) + message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, FAILURE) @@ -416,10 +419,7 @@ class TestComplete(TestIdRes): def test_errorWithNoOptionalKeys(self): msg = 'an error message' contact = 'some contact info here' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, - 'openid.contact': contact, - }) + message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.contact': contact}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, FAILURE) @@ -432,10 +432,8 @@ class TestComplete(TestIdRes): msg = 'an error message' contact = 'me' reference = 'support ticket' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, 'openid.reference': reference, - 'openid.contact': contact, 'openid.ns': OPENID2_NS, - }) + message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.reference': reference, + 'openid.contact': contact, 'openid.ns': OPENID2_NS}) r = self.consumer.complete(message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) @@ -458,7 +456,8 @@ class TestComplete(TestIdRes): message, self.endpoint, None) def test_idResURLMismatch(self): - class VerifiedError(Exception): pass + class VerifiedError(Exception): + pass def discoverAndVerify(claimed_id, _to_match_endpoints): raise VerifiedError @@ -483,6 +482,7 @@ class TestComplete(TestIdRes): self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') + class TestCompleteMissingSig(unittest.TestCase, CatchLogs): def setUp(self): @@ -503,18 +503,17 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): 'signed': 'identity,return_to,response_nonce,assoc_handle,claimed_id,op_endpoint', 'claimed_id': claimed_id, 'op_endpoint': self.server_url, - 'ns':OPENID2_NS, + 'ns': OPENID2_NS, }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url self.endpoint.claimed_id = claimed_id - self.consumer._checkReturnTo = lambda unused1, unused2 : True + self.consumer._checkReturnTo = lambda unused1, unused2: True def tearDown(self): CatchLogs.tearDown(self) - def test_idResMissingNoSigs(self): def _vrfy(resp_msg, endpoint=None): return endpoint @@ -523,7 +522,6 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) - def test_idResNoIdentity(self): self.message.delArg(OPENID_NS, 'identity') self.message.delArg(OPENID_NS, 'claimed_id') @@ -532,37 +530,31 @@ class TestCompleteMissingSig(unittest.TestCase, CatchLogs): r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) - def test_idResMissingIdentitySig(self): self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def test_idResMissingReturnToSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def test_idResMissingAssocHandleSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def test_idResMissingClaimedIDSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,assoc_handle') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def failUnlessSuccess(self, response): if response.status != SUCCESS: self.fail("Non-successful response: %s" % (response,)) - class TestCheckAuthResponse(TestIdRes, CatchLogs): def setUp(self): CatchLogs.setUp(self) @@ -583,7 +575,7 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): def test_goodResponse(self): """successful response to check_authentication""" - response = Message.fromOpenIDArgs({'is_valid':'true',}) + response = Message.fromOpenIDArgs({'is_valid': 'true'}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) @@ -595,7 +587,7 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): def test_badResponse(self): """check_authentication returns false when is_valid is false""" - response = Message.fromOpenIDArgs({'is_valid':'false',}) + response = Message.fromOpenIDArgs({'is_valid': 'false'}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) @@ -610,9 +602,9 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): """ self._createAssoc() response = Message.fromOpenIDArgs({ - 'is_valid':'false', - 'invalidate_handle':'handle', - }) + 'is_valid': 'false', + 'invalidate_handle': 'handle', + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) self.failUnless( @@ -621,21 +613,21 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): def test_invalidateMissing(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ - 'is_valid':'true', - 'invalidate_handle':'missing', - }) + 'is_valid': 'true', + 'invalidate_handle': 'missing', + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnlessLogMatches( 'Received "invalidate_handle"' - ) + ) def test_invalidateMissing_noStore(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ - 'is_valid':'true', - 'invalidate_handle':'missing', - }) + 'is_valid': 'true', + 'invalidate_handle': 'missing', + }) self.consumer.store = None r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) @@ -654,19 +646,20 @@ class TestCheckAuthResponse(TestIdRes, CatchLogs): """ self._createAssoc() response = Message.fromOpenIDArgs({ - 'is_valid':'true', - 'invalidate_handle':'handle', - }) + 'is_valid': 'true', + 'invalidate_handle': 'handle', + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnless( self.consumer.store.getAssociation(self.server_url) is None) + class TestSetupNeeded(TestIdRes): def failUnlessSetupNeeded(self, expected_setup_url, message): try: self.consumer._checkSetupNeeded(message) - except SetupNeededError, why: + except SetupNeededError as why: self.failUnlessEqual(expected_setup_url, why.user_setup_url) else: self.fail("Expected to find an immediate-mode response") @@ -677,7 +670,7 @@ class TestSetupNeeded(TestIdRes): message = Message.fromPostArgs({ 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, - }) + }) self.failUnless(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) @@ -688,7 +681,7 @@ class TestSetupNeeded(TestIdRes): 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, 'openid.identity': 'bogus', - }) + }) self.failUnless(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) @@ -703,9 +696,9 @@ class TestSetupNeeded(TestIdRes): def test_setupNeededOpenID2(self): message = Message.fromOpenIDArgs({ - 'mode':'setup_needed', - 'ns':OPENID2_NS, - }) + 'mode': 'setup_needed', + 'ns': OPENID2_NS, + }) self.failUnless(message.isOpenID2()) response = self.consumer.complete(message, None, None) self.failUnlessEqual('setup_needed', response.status) @@ -713,8 +706,8 @@ class TestSetupNeeded(TestIdRes): def test_setupNeededDoesntWorkForOpenID1(self): message = Message.fromOpenIDArgs({ - 'mode':'setup_needed', - }) + 'mode': 'setup_needed', + }) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) @@ -725,15 +718,16 @@ class TestSetupNeeded(TestIdRes): def test_noSetupNeededOpenID2(self): message = Message.fromOpenIDArgs({ - 'mode':'id_res', - 'game':'puerto_rico', - 'ns':OPENID2_NS, - }) + 'mode': 'id_res', + 'game': 'puerto_rico', + 'ns': OPENID2_NS, + }) self.failUnless(message.isOpenID2()) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) + class IdResCheckForFieldsTest(TestIdRes): def setUp(self): self.consumer = GenericConsumer(None) @@ -746,32 +740,32 @@ class IdResCheckForFieldsTest(TestIdRes): return test test_openid1Success = mkSuccessTest( - {'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', + {'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', }, ['return_to', 'identity']) test_openid2Success = mkSuccessTest( - {'ns':OPENID2_NS, - 'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'op_endpoint':'my favourite server', - 'response_nonce':'use only once', + {'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'op_endpoint': 'my favourite server', + 'response_nonce': 'use only once', }, ['return_to', 'response_nonce', 'assoc_handle', 'op_endpoint']) test_openid2Success_identifiers = mkSuccessTest( - {'ns':OPENID2_NS, - 'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'claimed_id':'i claim to be me', - 'identity':'my server knows me as me', - 'op_endpoint':'my favourite server', - 'response_nonce':'use only once', + {'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'claimed_id': 'i claim to be me', + 'identity': 'my server knows me as me', + 'op_endpoint': 'my favourite server', + 'response_nonce': 'use only once', }, ['return_to', 'response_nonce', 'identity', 'claimed_id', 'assoc_handle', 'op_endpoint']) @@ -781,7 +775,7 @@ class IdResCheckForFieldsTest(TestIdRes): message = Message.fromOpenIDArgs(openid_args) try: self.consumer._idResCheckForFields(message) - except ProtocolError, why: + except ProtocolError as why: self.failUnless(why[0].startswith('Missing required')) else: self.fail('Expected an error, but none occurred') @@ -792,53 +786,56 @@ class IdResCheckForFieldsTest(TestIdRes): message = Message.fromOpenIDArgs(openid_args) try: self.consumer._idResCheckForFields(message) - except ProtocolError, why: + except ProtocolError as why: self.failUnless(why[0].endswith('not signed')) else: self.fail('Expected an error, but none occurred') return test test_openid1Missing_returnToSig = mkMissingSignedTest( - {'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', - 'signed':'identity', + {'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + 'signed': 'identity', }) test_openid1Missing_identitySig = mkMissingSignedTest( - {'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', - 'signed':'return_to' + {'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + 'signed': 'return_to' }) test_openid2Missing_opEndpointSig = mkMissingSignedTest( - {'ns':OPENID2_NS, - 'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', - 'op_endpoint':'the endpoint', - 'signed':'return_to,identity,assoc_handle' + {'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + 'op_endpoint': 'the endpoint', + 'signed': 'return_to,identity,assoc_handle' }) test_openid1MissingReturnTo = mkMissingFieldTest( - {'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', + {'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', }) test_openid1MissingAssocHandle = mkMissingFieldTest( - {'return_to':'return', - 'sig':'a signature', - 'identity':'someone', + {'return_to': 'return', + 'sig': 'a signature', + 'identity': 'someone', }) # XXX: I could go on... -class CheckAuthHappened(Exception): pass + +class CheckAuthHappened(Exception): + pass + class CheckNonceVerifyTest(TestIdRes, CatchLogs): def setUp(self): @@ -869,24 +866,21 @@ class CheckNonceVerifyTest(TestIdRes, CatchLogs): """OpenID 2 does not use consumer-generated nonce""" self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) self.response = Message.fromOpenIDArgs( - {'return_to': self.return_to, 'ns':OPENID2_NS}) + {'return_to': self.return_to, 'ns': OPENID2_NS}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonce(self): """use server-generated nonce""" - self.response = Message.fromOpenIDArgs( - {'ns':OPENID2_NS, 'response_nonce': mkNonce(),}) + self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': mkNonce()}) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" self.response = Message.fromOpenIDArgs( - {'ns':OPENID1_NS, - 'return_to': 'http://return.to/', - 'response_nonce': mkNonce(),}) + {'ns': OPENID1_NS, 'return_to': 'http://return.to/', 'response_nonce': mkNonce()}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() @@ -905,38 +899,31 @@ class CheckNonceVerifyTest(TestIdRes, CatchLogs): nonce = mkNonce() stamp, salt = splitNonce(nonce) self.store.useNonce(self.server_url, stamp, salt) - self.response = Message.fromOpenIDArgs( - {'response_nonce': nonce, - 'ns':OPENID2_NS, - }) + self.response = Message.fromOpenIDArgs({'response_nonce': nonce, 'ns': OPENID2_NS}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_successWithNoStore(self): """When there is no store, checking the nonce succeeds""" self.consumer.store = None - self.response = Message.fromOpenIDArgs( - {'response_nonce': mkNonce(), - 'ns':OPENID2_NS, - }) + self.response = Message.fromOpenIDArgs({'response_nonce': mkNonce(), 'ns': OPENID2_NS}) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_tamperedNonce(self): """Malformed nonce""" - self.response = Message.fromOpenIDArgs( - {'ns':OPENID2_NS, - 'response_nonce':'malformed'}) + self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': 'malformed'}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_missingNonce(self): """no nonce parameter on the return_to""" self.response = Message.fromOpenIDArgs( - {'return_to': self.return_to}) + {'return_to': self.return_to}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) + class CheckAuthDetectingConsumer(GenericConsumer): def _checkAuth(self, *args): raise CheckAuthHappened(args) @@ -946,6 +933,7 @@ class CheckAuthDetectingConsumer(GenericConsumer): when it asks.""" return True + class TestCheckAuthTriggered(TestIdRes, CatchLogs): consumer_class = CheckAuthDetectingConsumer @@ -956,12 +944,12 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): def test_checkAuthTriggered(self): message = Message.fromPostArgs({ - 'openid.return_to':self.return_to, - 'openid.identity':self.server_id, - 'openid.assoc_handle':'not_found', + 'openid.return_to': self.return_to, + 'openid.identity': self.server_id, + 'openid.assoc_handle': 'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) self.disableReturnToChecking() try: result = self.consumer._doIdRes(message, self.endpoint, None) @@ -981,12 +969,12 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): self.store.storeAssociation(self.server_url, assoc) self.disableReturnToChecking() message = Message.fromPostArgs({ - 'openid.return_to':self.return_to, - 'openid.identity':self.server_id, - 'openid.assoc_handle':'not_found', + 'openid.return_to': self.return_to, + 'openid.identity': self.server_id, + 'openid.assoc_handle': 'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) try: result = self.consumer._doIdRes(message, self.endpoint, None) except CheckAuthHappened: @@ -1006,12 +994,12 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): self.store.storeAssociation(self.server_url, assoc) message = Message.fromPostArgs({ - 'openid.return_to':self.return_to, - 'openid.identity':self.server_id, - 'openid.assoc_handle':handle, + 'openid.return_to': self.return_to, + 'openid.identity': self.server_id, + 'openid.assoc_handle': handle, 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) self.disableReturnToChecking() self.failUnlessRaises(ProtocolError, self.consumer._doIdRes, message, self.endpoint, None) @@ -1032,10 +1020,10 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): self.store.storeAssociation(self.server_url, bad_assoc) query = { - 'return_to':self.return_to, - 'identity':self.server_id, - 'assoc_handle':good_handle, - } + 'return_to': self.return_to, + 'identity': self.server_id, + 'assoc_handle': good_handle, + } message = Message.fromOpenIDArgs(query) message = good_assoc.signMessage(message) @@ -1045,7 +1033,6 @@ class TestCheckAuthTriggered(TestIdRes, CatchLogs): self.failUnlessEqual(self.consumer_id, info.identity_url) - class TestReturnToArgs(unittest.TestCase): """Verifying the Return URL paramaters. From the specification "Verifying the Return URL":: @@ -1073,7 +1060,7 @@ class TestReturnToArgs(unittest.TestCase): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', 'foo': 'bar', - } + } # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) @@ -1082,7 +1069,7 @@ class TestReturnToArgs(unittest.TestCase): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=', 'foo': '', - } + } # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) @@ -1091,7 +1078,7 @@ class TestReturnToArgs(unittest.TestCase): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/', 'foo': 'bar', - } + } # no return value, success is assumed if there are no exceptions. self.failUnlessRaises(ProtocolError, self.consumer._verifyReturnToArgs, query) @@ -1100,7 +1087,7 @@ class TestReturnToArgs(unittest.TestCase): query = { 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', - } + } # fail, query has no key 'foo'. self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) @@ -1110,7 +1097,6 @@ class TestReturnToArgs(unittest.TestCase): self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) - def test_noReturnTo(self): query = {'openid.mode': 'id_res'} self.failUnlessRaises(ValueError, @@ -1135,12 +1121,11 @@ class TestReturnToArgs(unittest.TestCase): # Query args differ "http://some.url/path?foo=bar2", "http://some.url/path?foo2=bar", - ] + ] m = Message(OPENID1_NS) m.setArg(OPENID_NS, 'mode', 'cancel') m.setArg(BARE_NS, 'foo', 'bar') - endpoint = None for bad in bad_return_tos: m.setArg(OPENID_NS, 'return_to', bad) @@ -1156,12 +1141,12 @@ class TestReturnToArgs(unittest.TestCase): (return_to, {}), (return_to + "?another=arg", {(BARE_NS, 'another'): 'arg'}), (return_to + "?another=arg#fragment", {(BARE_NS, 'another'): 'arg'}), - ("HTTP"+return_to[4:], {}), - (return_to.replace('url','URL'), {}), + ("HTTP" + return_to[4:], {}), + (return_to.replace('url', 'URL'), {}), ("http://some.url:80/path", {}), ("http://some.url/p%61th", {}), ("http://some.url/./path", {}), - ] + ] endpoint = None @@ -1174,9 +1159,10 @@ class TestReturnToArgs(unittest.TestCase): m.setArg(OPENID_NS, 'return_to', good) result = self.consumer.complete(m, endpoint, return_to) - self.failUnless(isinstance(result, CancelResponse), \ + self.failUnless(isinstance(result, CancelResponse), "Expected CancelResponse, got %r for %s" % (result, good,)) + class MockFetcher(object): def __init__(self, response=None): self.response = response or HTTPResponse() @@ -1186,6 +1172,7 @@ class MockFetcher(object): self.fetches.append((url, body, headers)) return self.response + class ExceptionRaisingMockFetcher(object): class MyException(Exception): pass @@ -1193,15 +1180,17 @@ class ExceptionRaisingMockFetcher(object): def fetch(self, url, body=None, headers=None): raise self.MyException('mock fetcher exception') + class BadArgCheckingConsumer(GenericConsumer): def _makeKVPost(self, args, _): assert args == { - 'openid.mode':'check_authentication', - 'openid.signed':'foo', - 'openid.ns':OPENID1_NS - }, args + 'openid.mode': 'check_authentication', + 'openid.signed': 'foo', + 'openid.ns': OPENID1_NS + }, args return None + class TestCheckAuth(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer @@ -1223,7 +1212,7 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): self.fetcher.response = HTTPResponse( "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') query = {'openid.signed': 'stuff', - 'openid.stuff':'a value'} + 'openid.stuff': 'a value'} r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) self.failIf(r) @@ -1231,13 +1220,12 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): def test_bad_args(self): query = { - 'openid.signed':'foo', - 'closid.foo':'something', - } + 'openid.signed': 'foo', + 'closid.foo': 'something', + } consumer = BadArgCheckingConsumer(self.store) consumer._checkAuth(Message.fromPostArgs(query), 'does://not.matter') - def test_signedList(self): query = Message.fromOpenIDArgs({ 'mode': 'id_res', @@ -1248,41 +1236,41 @@ class TestCheckAuth(unittest.TestCase, CatchLogs): 'sreg.email': 'bogus@example.com', 'signed': 'identity,mode,ns.sreg,sreg.email', 'foo': 'bar', - }) + }) args = self.consumer._createCheckAuthRequest(query) self.failUnless(args.isOpenID1()) for signed_arg in query.getArg(OPENID_NS, 'signed').split(','): - self.failUnless(args.getAliasedArg(signed_arg), signed_arg) + self.failUnless(args.getAliasedArg(signed_arg), signed_arg) def test_112(self): - args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' - } + args = { + 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' + 'ns.pape,pape.nist_auth_level,pape.auth_policies'} self.failUnlessEqual(OPENID2_NS, args['openid.ns']) incoming = Message.fromPostArgs(args) self.failUnless(incoming.isOpenID2()) car = self.consumer._createCheckAuthRequest(incoming) expected_args = args.copy() expected_args['openid.mode'] = 'check_authentication' - expected =Message.fromPostArgs(expected_args) + expected = Message.fromPostArgs(expected_args) self.failUnless(expected.isOpenID2()) self.failUnlessEqual(expected, car) self.failUnlessEqual(expected_args, car.toPostArgs()) - class TestFetchAssoc(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer @@ -1300,7 +1288,7 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): self.failUnlessRaises( fetchers.HTTPFetchingError, self.consumer._makeKVPost, - Message.fromPostArgs({'mode':'associate'}), + Message.fromPostArgs({'mode': 'associate'}), "http://server_url") def test_error_exception_unwrapped(self): @@ -1311,7 +1299,7 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): fetchers.setDefaultFetcher(self.fetcher, wrap_exceptions=False) self.failUnlessRaises(self.fetcher.MyException, self.consumer._makeKVPost, - Message.fromPostArgs({'mode':'associate'}), + Message.fromPostArgs({'mode': 'associate'}), "http://server_url") # exception fetching returns no association @@ -1322,7 +1310,7 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): self.failUnlessRaises(self.fetcher.MyException, self.consumer._checkAuth, - Message.fromPostArgs({'openid.signed':''}), + Message.fromPostArgs({'openid.signed': ''}), 'some://url') def test_error_exception_wrapped(self): @@ -1334,7 +1322,7 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): fetchers.setDefaultFetcher(self.fetcher) self.failUnlessRaises(fetchers.HTTPFetchingError, self.consumer._makeKVPost, - Message.fromOpenIDArgs({'mode':'associate'}), + Message.fromOpenIDArgs({'mode': 'associate'}), "http://server_url") # exception fetching returns no association @@ -1342,7 +1330,7 @@ class TestFetchAssoc(unittest.TestCase, CatchLogs): e.server_url = 'some://url' self.failUnless(self.consumer._getAssociation(e) is None) - msg = Message.fromPostArgs({'openid.signed':''}) + msg = Message.fromPostArgs({'openid.signed': ''}) self.failIf(self.consumer._checkAuth(msg, 'some://url')) @@ -1353,33 +1341,33 @@ class TestSuccessResponse(unittest.TestCase): def test_extensionResponse(self): resp = mkSuccess(self.endpoint, { - 'ns.sreg':'urn:sreg', - 'ns.unittest':'urn:unittest', - 'unittest.one':'1', - 'unittest.two':'2', - 'sreg.nickname':'j3h', - 'return_to':'return_to', - }) + 'ns.sreg': 'urn:sreg', + 'ns.unittest': 'urn:unittest', + 'unittest.one': '1', + 'unittest.two': '2', + 'sreg.nickname': 'j3h', + 'return_to': 'return_to', + }) utargs = resp.extensionResponse('urn:unittest', False) - self.failUnlessEqual(utargs, {'one':'1', 'two':'2'}) + self.failUnlessEqual(utargs, {'one': '1', 'two': '2'}) sregargs = resp.extensionResponse('urn:sreg', False) - self.failUnlessEqual(sregargs, {'nickname':'j3h'}) + self.failUnlessEqual(sregargs, {'nickname': 'j3h'}) def test_extensionResponseSigned(self): args = { - 'ns.sreg':'urn:sreg', - 'ns.unittest':'urn:unittest', - 'unittest.one':'1', - 'unittest.two':'2', - 'sreg.nickname':'j3h', - 'sreg.dob':'yesterday', - 'return_to':'return_to', + 'ns.sreg': 'urn:sreg', + 'ns.unittest': 'urn:unittest', + 'unittest.one': '1', + 'unittest.two': '2', + 'sreg.nickname': 'j3h', + 'sreg.dob': 'yesterday', + 'return_to': 'return_to', 'signed': 'sreg.nickname,unittest.one,sreg.dob', - } + } signed_list = ['openid.sreg.nickname', 'openid.unittest.one', - 'openid.sreg.dob',] + 'openid.sreg.dob'] # Don't use mkSuccess because it creates an all-inclusive # signed list. @@ -1388,7 +1376,7 @@ class TestSuccessResponse(unittest.TestCase): # All args in this NS are signed, so expect all. sregargs = resp.extensionResponse('urn:sreg', True) - self.failUnlessEqual(sregargs, {'nickname':'j3h', 'dob': 'yesterday'}) + self.failUnlessEqual(sregargs, {'nickname': 'j3h', 'dob': 'yesterday'}) # Not all args in this NS are signed, so expect None when # asking for them. @@ -1400,7 +1388,7 @@ class TestSuccessResponse(unittest.TestCase): self.failUnless(resp.getReturnTo() is None) def test_returnTo(self): - resp = mkSuccess(self.endpoint, {'return_to':'return_to'}) + resp = mkSuccess(self.endpoint, {'return_to': 'return_to'}) self.failUnlessEqual(resp.getReturnTo(), 'return_to') def test_displayIdentifierClaimedId(self): @@ -1414,6 +1402,7 @@ class TestSuccessResponse(unittest.TestCase): self.failUnlessEqual(resp.getDisplayIdentifier(), "http://input.url/") + class StubConsumer(object): def __init__(self): self.assoc = object() @@ -1429,11 +1418,13 @@ class StubConsumer(object): assert endpoint is self.endpoint return self.response + class ConsumerTest(unittest.TestCase): """Tests for high-level consumer.Consumer functions. Its GenericConsumer component is stubbed out with StubConsumer. """ + def setUp(self): self.endpoint = OpenIDServiceEndpoint() self.endpoint.claimed_id = self.identity_url = 'http://identity.url/' @@ -1473,13 +1464,14 @@ class ConsumerTest(unittest.TestCase): def test_beginHTTPError(self): """Make sure that the discovery HTTP failure case behaves properly """ + def getNextService(self, ignored): raise HTTPFetchingError("Unit test") def test(): try: self.consumer.begin('unused in this test') - except DiscoveryFailure, why: + except DiscoveryFailure as why: self.failUnless(why[0].startswith('Error fetching')) self.failIf(why[0].find('Unit test') == -1) else: @@ -1492,10 +1484,11 @@ class ConsumerTest(unittest.TestCase): return None url = 'http://a.user.url/' + def test(): try: self.consumer.begin(url) - except DiscoveryFailure, why: + except DiscoveryFailure as why: self.failUnless(why[0].startswith('No usable OpenID')) self.failIf(why[0].find(url) == -1) else: @@ -1503,7 +1496,6 @@ class ConsumerTest(unittest.TestCase): self.withDummyDiscovery(test, getNextService) - def test_beginWithoutDiscovery(self): # Does this really test anything non-trivial? result = self.consumer.beginWithoutDiscovery(self.endpoint) @@ -1631,9 +1623,7 @@ class ConsumerTest(unittest.TestCase): resp_endpoint = OpenIDServiceEndpoint() resp_endpoint.claimed_id = "http://user.url/" - resp = self._doRespDisco( - True, - mkSuccess(resp_endpoint, {})) + self._doRespDisco(True, mkSuccess(resp_endpoint, {})) self.failUnless(self.discovery.getManager(force=True) is None) def test_begin(self): @@ -1646,7 +1636,6 @@ class ConsumerTest(unittest.TestCase): self.failUnless(auth_req.assoc is self.consumer.consumer.assoc) - class IDPDrivenTest(unittest.TestCase): def setUp(self): @@ -1655,12 +1644,10 @@ class IDPDrivenTest(unittest.TestCase): self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = "http://idp.unittest/" - def test_idpDrivenBegin(self): # Testing here that the token-handling doesn't explode... self.consumer.begin(self.endpoint) - def test_idpDrivenComplete(self): identifier = '=directed_identifier' message = Message.fromPostArgs({ @@ -1669,20 +1656,21 @@ class IDPDrivenTest(unittest.TestCase): 'openid.assoc_handle': 'z', 'openid.signed': 'identity,return_to', 'openid.sig': GOODSIG, - }) + }) discovered_endpoint = OpenIDServiceEndpoint() discovered_endpoint.claimed_id = identifier discovered_endpoint.server_url = self.endpoint.server_url discovered_endpoint.local_id = identifier iverified = [] + def verifyDiscoveryResults(identifier, endpoint): self.failUnless(endpoint is self.endpoint) iverified.append(discovered_endpoint) return discovered_endpoint self.consumer._verifyDiscoveryResults = verifyDiscoveryResults self.consumer._idResCheckNonce = lambda *args: True - self.consumer._checkReturnTo = lambda unused1, unused2 : True + self.consumer._checkReturnTo = lambda unused1, unused2: True response = self.consumer._doIdRes(message, self.endpoint, None) self.failUnlessSuccess(response) @@ -1691,7 +1679,6 @@ class IDPDrivenTest(unittest.TestCase): # assert that discovery attempt happens and returns good self.failUnlessEqual(iverified, [discovered_endpoint]) - def test_idpDrivenCompleteFraud(self): # crap with an identifier that doesn't match discovery info message = Message.fromPostArgs({ @@ -1700,21 +1687,20 @@ class IDPDrivenTest(unittest.TestCase): 'openid.assoc_handle': 'z', 'openid.signed': 'identity,return_to', 'openid.sig': GOODSIG, - }) + }) + def verifyDiscoveryResults(identifier, endpoint): raise DiscoveryFailure("PHREAK!", None) self.consumer._verifyDiscoveryResults = verifyDiscoveryResults - self.consumer._checkReturnTo = lambda unused1, unused2 : True + self.consumer._checkReturnTo = lambda unused1, unused2: True self.failUnlessRaises(DiscoveryFailure, self.consumer._doIdRes, message, self.endpoint, None) - def failUnlessSuccess(self, response): if response.status != SUCCESS: self.fail("Non-successful response: %s" % (response,)) - class TestDiscoveryVerification(unittest.TestCase): services = [] @@ -1732,7 +1718,7 @@ class TestDiscoveryVerification(unittest.TestCase): 'openid.identity': self.identifier, 'openid.claimed_id': self.identifier, 'openid.op_endpoint': self.server_url, - }) + }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url @@ -1748,7 +1734,6 @@ class TestDiscoveryVerification(unittest.TestCase): self.failUnlessEqual(r, endpoint) - def test_otherServer(self): text = "verify failed" @@ -1769,12 +1754,11 @@ class TestDiscoveryVerification(unittest.TestCase): self.services = [endpoint] try: r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - except ProtocolError, e: + except ProtocolError as e: # Should we make more ProtocolError subclasses? self.failUnless(str(e), text) else: self.fail("expected ProtocolError, %r returned." % (r,)) - def test_foreignDelegate(self): text = "verify failed" @@ -1796,7 +1780,7 @@ class TestDiscoveryVerification(unittest.TestCase): try: r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - except ProtocolError, e: + except ProtocolError as e: self.failUnlessEqual(str(e), text) else: self.fail("Exepected ProtocolError, %r returned" % (r,)) @@ -1808,7 +1792,6 @@ class TestDiscoveryVerification(unittest.TestCase): self.consumer._verifyDiscoveryResults, self.message, self.endpoint) - def discoveryFunc(self, identifier): return identifier, self.services @@ -1832,10 +1815,10 @@ class TestCreateAssociationRequest(unittest.TestCase): self.failUnless(isinstance(session, PlainTextConsumerSession)) expected = Message.fromOpenIDArgs( - {'ns':OPENID2_NS, - 'session_type':session_type, - 'mode':'associate', - 'assoc_type':self.assoc_type, + {'ns': OPENID2_NS, + 'session_type': session_type, + 'mode': 'associate', + 'assoc_type': self.assoc_type, }) self.failUnlessEqual(expected, args) @@ -1847,9 +1830,9 @@ class TestCreateAssociationRequest(unittest.TestCase): self.endpoint, self.assoc_type, session_type) self.failUnless(isinstance(session, PlainTextConsumerSession)) - self.failUnlessEqual(Message.fromOpenIDArgs({'mode':'associate', - 'assoc_type':self.assoc_type, - }), args) + self.failUnlessEqual( + Message.fromOpenIDArgs({'mode': 'associate', 'assoc_type': self.assoc_type}), + args) def test_dhSHA1Compatibility(self): # Set the consumer's session type to a fast session since we @@ -1870,9 +1853,9 @@ class TestCreateAssociationRequest(unittest.TestCase): # OK, session_type is set here and not for no-encryption # compatibility - expected = Message.fromOpenIDArgs({'mode':'associate', - 'session_type':'DH-SHA1', - 'assoc_type':self.assoc_type, + expected = Message.fromOpenIDArgs({'mode': 'associate', + 'session_type': 'DH-SHA1', + 'assoc_type': self.assoc_type, 'dh_modulus': 'BfvStQ==', 'dh_gen': 'Ag==', }) @@ -1881,6 +1864,7 @@ class TestCreateAssociationRequest(unittest.TestCase): # XXX: test the other types + class TestDiffieHellmanResponseParameters(object): session_cls = None message_namespace = None @@ -1933,10 +1917,12 @@ class TestDiffieHellmanResponseParameters(object): self.failUnlessRaises(ValueError, self.consumer_session.extractSecret, self.msg) + class TestOpenID1SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID1_NS + class TestOpenID2SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID2_NS @@ -1960,18 +1946,17 @@ class TestNoStore(unittest.TestCase): endpoint.claimed_id = 'identity_url' self.consumer._getAssociation = notCalled - auth_request = self.consumer.begin(endpoint) + self.consumer.begin(endpoint) # _getAssociation was not called - - class NonAnonymousAuthRequest(object): endpoint = 'unused' def setAnonymous(self, unused): raise ValueError('Should trigger ProtocolError') + class TestConsumerAnonymous(unittest.TestCase): def test_beginWithoutDiscoveryAnonymousFail(self): """Make sure that ValueError for setting an auth request @@ -1979,6 +1964,7 @@ class TestConsumerAnonymous(unittest.TestCase): """ sess = {} consumer = Consumer(sess, None) + def bogusBegin(unused): return NonAnonymousAuthRequest() consumer.consumer.begin = bogusBegin @@ -1991,6 +1977,7 @@ class TestDiscoverAndVerify(unittest.TestCase): def setUp(self): self.consumer = GenericConsumer(None) self.discovery_result = None + def dummyDiscover(unused_identifier): return self.discovery_result self.consumer._discover = dummyDiscover @@ -2014,6 +2001,7 @@ class TestDiscoverAndVerify(unittest.TestCase): assertion, then we end up raising a ProtocolError """ self.discovery_result = (None, ['unused']) + def raiseProtocolError(unused1, unused2): raise ProtocolError('unit test') self.consumer._verifyDiscoverySingle = raiseProtocolError @@ -2037,12 +2025,14 @@ class TestDiscoverAndVerify(unittest.TestCase): 'http://claimed.id/', [self.to_match]) self.failUnlessEqual(matching_endpoint, result) + class SillyExtension(Extension): ns_uri = 'http://silly.example.com/' ns_alias = 'silly' def getExtensionArgs(self): - return {'i_am':'silly'} + return {'i_am': 'silly'} + class TestAddExtension(unittest.TestCase): @@ -2054,7 +2044,6 @@ class TestAddExtension(unittest.TestCase): self.failUnlessEqual(ext.getExtensionArgs(), ext_args) - class TestKVPost(unittest.TestCase): def setUp(self): self.server_url = 'http://unittest/%s' % (self.id(),) @@ -2065,23 +2054,21 @@ class TestKVPost(unittest.TestCase): response.status = 200 response.body = "foo:bar\nbaz:quux\n" r = _httpResponseToMessage(response, self.server_url) - expected_msg = Message.fromOpenIDArgs({'foo':'bar','baz':'quux'}) + expected_msg = Message.fromOpenIDArgs({'foo': 'bar', 'baz': 'quux'}) self.failUnlessEqual(expected_msg, r) - def test_400(self): response = HTTPResponse() response.status = 400 response.body = "error:bonk\nerror_code:7\n" try: r = _httpResponseToMessage(response, self.server_url) - except ServerError, e: + except ServerError as e: self.failUnlessEqual(e.error_text, 'bonk') self.failUnlessEqual(e.error_code, '7') else: self.fail("Expected ServerError, got return %r" % (r,)) - def test_500(self): # 500 as an example of any non-200, non-400 code. response = HTTPResponse() @@ -2092,7 +2079,5 @@ class TestKVPost(unittest.TestCase): self.server_url) - - if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index a09a1f2..29a73ff 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -2,7 +2,6 @@ import os.path import sys import unittest -import warnings from urlparse import urlsplit from openid import fetchers, message @@ -14,7 +13,8 @@ from openid.yadis.xri import XRI from . import datadriven -### Tests for conditions that trigger DiscoveryFailure +# Tests for conditions that trigger DiscoveryFailure + class SimpleMockFetcher(object): def __init__(self, responses): @@ -26,6 +26,7 @@ class SimpleMockFetcher(object): assert response.final_url == url return response + class TestDiscoveryFailure(datadriven.DataDrivenTestCase): cases = [ [HTTPResponse('http://network.error/', None)], @@ -33,9 +34,9 @@ class TestDiscoveryFailure(datadriven.DataDrivenTestCase): [HTTPResponse('http://bad.request/', 400)], [HTTPResponse('http://server.error/', 500)], [HTTPResponse('http://header.found/', 200, - headers={'x-xrds-location':'http://xrds.missing/'}), + headers={'x-xrds-location': 'http://xrds.missing/'}), HTTPResponse('http://xrds.missing/', 404)], - ] + ] def __init__(self, responses): self.url = responses[0].final_url @@ -53,14 +54,14 @@ class TestDiscoveryFailure(datadriven.DataDrivenTestCase): expected_status = self.responses[-1].status try: discover.discover(self.url) - except DiscoveryFailure, why: + except DiscoveryFailure as why: self.failUnlessEqual(why.http_response.status, expected_status) else: self.fail('Did not raise DiscoveryFailure') -### Tests for raising/catching exceptions from the fetcher through the -### discover function +# Tests for raising/catching exceptions from the fetcher through the +# discover function class ErrorRaisingFetcher(object): """Just raise an exception when fetch is called""" @@ -71,9 +72,11 @@ class ErrorRaisingFetcher(object): def fetch(self, url, body=None, headers=None): raise self.thing_to_raise + class DidFetch(Exception): """Custom exception just to make sure it's not handled differently""" + class TestFetchException(datadriven.DataDrivenTestCase): """Make sure exceptions get passed through discover function from fetcher.""" @@ -83,7 +86,7 @@ class TestFetchException(datadriven.DataDrivenTestCase): DidFetch(), ValueError(), RuntimeError(), - ] + ] def __init__(self, exc): datadriven.DataDrivenTestCase.__init__(self, repr(exc)) @@ -99,7 +102,7 @@ class TestFetchException(datadriven.DataDrivenTestCase): def runOneTest(self): try: discover.discover('http://doesnt.matter/') - except: + except Exception: exc = sys.exc_info()[1] if exc is None: # str exception @@ -110,7 +113,7 @@ class TestFetchException(datadriven.DataDrivenTestCase): self.fail('Expected %r', self.exc) -### Tests for openid.consumer.discover.discover +# Tests for openid.consumer.discover.discover class TestNormalization(unittest.TestCase): def testAddingProtocol(self): @@ -119,10 +122,10 @@ class TestNormalization(unittest.TestCase): try: discover.discover('users.stompy.janrain.com:8000/x') - except DiscoveryFailure, why: + except DiscoveryFailure: self.fail('failed to parse url with port correctly') except RuntimeError: - pass #expected + pass # expected fetchers.setDefaultFetcher(None) @@ -154,6 +157,7 @@ class DiscoveryMockFetcher(object): # from twisted.trial import unittest as trialtest + class BaseTestDiscovery(unittest.TestCase): id_url = "http://someuser.unittest/" @@ -195,7 +199,7 @@ class BaseTestDiscovery(unittest.TestCase): '1.0': discover.OPENID_1_0_TYPE, '2.0': discover.OPENID_2_0_TYPE, '2.0 OP': discover.OPENID_IDP_2_0_TYPE, - } + } type_uris = [openid_types[t] for t in types] self.failUnlessEqual(type_uris, s.type_uris) @@ -217,12 +221,14 @@ class BaseTestDiscovery(unittest.TestCase): def tearDown(self): fetchers.setDefaultFetcher(None) + def readDataFile(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join( module_directory, 'data', 'test_discover', filename) return file(filename).read() + class TestDiscovery(BaseTestDiscovery): def _discover(self, content_type, data, expected_services, expected_id=None): @@ -254,8 +260,7 @@ class TestDiscovery(BaseTestDiscovery): """ data = readDataFile('unicode2.html') self.failUnlessRaises(UnicodeDecodeError, data.decode, 'utf-8') - self._discover(content_type='text/html;charset=utf-8', - data=data, expected_services=0) + self._discover(content_type='text/html;charset=utf-8', data=data, expected_services=0) def test_unicode_undecodable_html2(self): """ @@ -267,8 +272,7 @@ class TestDiscovery(BaseTestDiscovery): data = readDataFile('unicode3.html') self.failUnlessRaises(UnicodeDecodeError, data.decode, 'utf-8') - self._discover(content_type='text/html;charset=utf-8', - data=data, expected_services=1) + self._discover(content_type='text/html;charset=utf-8', data=data, expected_services=1) def test_noOpenID(self): services = self._discover(content_type='text/plain', @@ -279,7 +283,7 @@ class TestDiscovery(BaseTestDiscovery): content_type='text/html', data=readDataFile('openid_no_delegate.html'), expected_services=1, - ) + ) self._checkService( services[0], @@ -288,7 +292,7 @@ class TestDiscovery(BaseTestDiscovery): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id=self.id_url, - ) + ) def test_html1(self): services = self._discover( @@ -296,7 +300,6 @@ class TestDiscovery(BaseTestDiscovery): data=readDataFile('openid.html'), expected_services=1) - self._checkService( services[0], used_yadis=False, @@ -305,7 +308,7 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_html1Fragment(self): """Ensure that the Claimed Identifier does not have a fragment @@ -329,14 +332,14 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=expected_id, local_id='http://smoker.myopenid.com/', display_identifier=expected_id, - ) + ) def test_html2(self): services = self._discover( content_type='text/html', data=readDataFile('openid2.html'), expected_services=1, - ) + ) self._checkService( services[0], @@ -346,14 +349,14 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_html1And2(self): services = self._discover( content_type='text/html', data=readDataFile('openid_1_and_2.html'), expected_services=2, - ) + ) for t, s in zip(['2.0', '1.1'], services): self._checkService( @@ -364,12 +367,11 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadisEmpty(self): - services = self._discover(content_type='application/xrds+xml', - data=readDataFile('yadis_0entries.xml'), - expected_services=0) + self._discover(content_type='application/xrds+xml', data=readDataFile('yadis_0entries.xml'), + expected_services=0) def test_htmlEmptyYadis(self): """HTML document has discovery information, but points to an @@ -390,7 +392,7 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadis1NoDelegate(self): services = self._discover(content_type='application/xrds+xml', @@ -405,14 +407,14 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id=self.id_url, display_identifier=self.id_url, - ) + ) def test_yadis2NoLocalID(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid2_xrds_no_local_id.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -422,14 +424,14 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id=self.id_url, display_identifier=self.id_url, - ) + ) def test_yadis2(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid2_xrds.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -439,14 +441,14 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadis2OP(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('yadis_idp.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -454,7 +456,7 @@ class TestDiscovery(BaseTestDiscovery): types=['2.0 OP'], server_url="http://www.myopenid.com/server", display_identifier=self.id_url, - ) + ) def test_yadis2OPDelegate(self): """The delegate tag isn't meaningful for OP entries.""" @@ -462,7 +464,7 @@ class TestDiscovery(BaseTestDiscovery): content_type='application/xrds+xml', data=readDataFile('yadis_idp_delegate.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -470,21 +472,20 @@ class TestDiscovery(BaseTestDiscovery): types=['2.0 OP'], server_url="http://www.myopenid.com/server", display_identifier=self.id_url, - ) + ) def test_yadis2BadLocalID(self): self.failUnlessRaises(DiscoveryFailure, self._discover, - content_type='application/xrds+xml', - data=readDataFile('yadis_2_bad_local_id.xml'), - expected_services=1, - ) + content_type='application/xrds+xml', + data=readDataFile('yadis_2_bad_local_id.xml'), + expected_services=1) def test_yadis1And2(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid_1_and_2_xrds.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -494,14 +495,14 @@ class TestDiscovery(BaseTestDiscovery): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadis1And2BadLocalID(self): self.failUnlessRaises(DiscoveryFailure, self._discover, - content_type='application/xrds+xml', - data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), - expected_services=1, - ) + content_type='application/xrds+xml', + data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), + expected_services=1) + class MockFetcherForXRIProxy(object): @@ -510,12 +511,10 @@ class MockFetcherForXRIProxy(object): self.fetchlog = [] self.proxy_url = None - def fetch(self, url, body=None, headers=None): self.fetchlog.append((url, body, headers)) u = urlsplit(url) - proxy_host = u[1] xri = u[2] query = u[3] @@ -544,7 +543,7 @@ class TestXRIDiscovery(BaseTestDiscovery): documents = {'=smoker': ('application/xrds+xml', readDataFile('yadis_2entries_delegate.xml')), '=smoker*bad': ('application/xrds+xml', - readDataFile('yadis_another_delegate.xml')) } + readDataFile('yadis_another_delegate.xml'))} def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') @@ -558,7 +557,7 @@ class TestXRIDiscovery(BaseTestDiscovery): canonical_id=XRI("=!1000"), local_id='http://smoker.myopenid.com/', display_identifier='=smoker' - ) + ) self._checkService( services[1], @@ -569,7 +568,7 @@ class TestXRIDiscovery(BaseTestDiscovery): canonical_id=XRI("=!1000"), local_id='http://frank.livejournal.com/', display_identifier='=smoker' - ) + ) def test_xri_normalize(self): user_xri, services = discover.discoverXRI('xri://=smoker') @@ -583,7 +582,7 @@ class TestXRIDiscovery(BaseTestDiscovery): canonical_id=XRI("=!1000"), local_id='http://smoker.myopenid.com/', display_identifier='=smoker' - ) + ) self._checkService( services[1], @@ -594,7 +593,7 @@ class TestXRIDiscovery(BaseTestDiscovery): canonical_id=XRI("=!1000"), local_id='http://frank.livejournal.com/', display_identifier='=smoker' - ) + ) def test_xriNoCanonicalID(self): user_xri, services = discover.discoverXRI('=smoker*bad') @@ -613,7 +612,7 @@ class TestXRIDiscoveryIDP(BaseTestDiscovery): fetcherClass = MockFetcherForXRIProxy documents = {'=smoker': ('application/xrds+xml', - readDataFile('yadis_2entries_idp.xml')) } + readDataFile('yadis_2entries_idp.xml'))} def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') @@ -646,7 +645,8 @@ class TestPreferredNamespace(datadriven.DataDrivenTestCase): discover.OPENID_1_0_TYPE]), (message.OPENID2_NS, [discover.OPENID_1_0_TYPE, discover.OPENID_2_0_TYPE]), - ] + ] + class TestIsOPIdentifier(unittest.TestCase): def setUp(self): @@ -682,6 +682,7 @@ class TestIsOPIdentifier(unittest.TestCase): discover.OPENID_IDP_2_0_TYPE] self.failUnless(self.endpoint.isOPIdentifier()) + class TestFromOPEndpointURL(unittest.TestCase): def setUp(self): self.op_endpoint_url = 'http://example.com/op/endpoint' @@ -704,6 +705,7 @@ class TestFromOPEndpointURL(unittest.TestCase): def test_serverURL(self): self.failUnlessEqual(self.endpoint.server_url, self.op_endpoint_url) + class TestDiscoverFunction(unittest.TestCase): def setUp(self): self._old_discoverURI = discover.discoverURI @@ -734,6 +736,7 @@ class TestDiscoverFunction(unittest.TestCase): def test_xriChar(self): self.failUnlessEqual('XRI', discover.discover('=something')) + class TestEndpointSupportsType(unittest.TestCase): def setUp(self): self.endpoint = discover.OpenIDServiceEndpoint() @@ -745,7 +748,7 @@ class TestEndpointSupportsType(unittest.TestCase): discover.OPENID_1_0_TYPE, discover.OPENID_2_0_TYPE, discover.OPENID_IDP_2_0_TYPE, - ]: + ]: if t in types: self.failUnless(self.endpoint.supportsType(t), "Must support %r" % (t,)) @@ -799,6 +802,7 @@ class TestEndpointDisplayIdentifier(unittest.TestCase): def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index c3ff68a..cae2712 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -8,7 +8,8 @@ def datapath(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) return os.path.join(module_directory, 'data', 'test_etxrd', filename) -XRD_FILE = datapath('valid-populated-xrds.xml') + +XRD_FILE = datapath('valid-populated-xrds.xml') NOXRDS_FILE = datapath('not-xrds.xml') NOXRD_FILE = datapath('no-xrd.xml') @@ -18,6 +19,7 @@ NOXRD_FILE = datapath('no-xrd.xml') LID_2_0 = "http://lid.netmesh.org/sso/2.0b5" TYPEKEY_1_0 = "http://typekey.com/services/1.0" + def simpleOpenIDTransformer(endpoint): """Function to extract information from an OpenID service element""" if 'http://openid.net/signon/1.0' not in endpoint.type_uris: @@ -29,6 +31,7 @@ def simpleOpenIDTransformer(endpoint): delegate = delegates[0].text return (endpoint.uri, delegate) + class TestServiceParser(unittest.TestCase): def setUp(self): self.xmldoc = file(XRD_FILE).read() @@ -39,7 +42,7 @@ class TestServiceParser(unittest.TestCase): def testParse(self): """Make sure that parsing succeeds at all""" - services = self._getServices() + self._getServices() def testParseOpenID(self): """Parse for OpenID services with a transformer function""" @@ -50,7 +53,7 @@ class TestServiceParser(unittest.TestCase): ("http://www.schtuff.com/openid", "http://users.schtuff.com/josh"), ("http://www.livejournal.com/openid/server.bml", "http://www.livejournal.com/users/nedthealpaca/"), - ] + ] it = iter(services) for (server_url, delegate) in expectedServices: @@ -79,16 +82,13 @@ class TestServiceParser(unittest.TestCase): # type, URL (TYPEKEY_1_0, None), (LID_2_0, "http://mylid.net/josh"), - ] + ] self._checkServices(expectedServices) def testGetSeveralForOne(self): """Getting services for one Service with several Type elements.""" - types = [ 'http://lid.netmesh.org/sso/2.0b5' - , 'http://lid.netmesh.org/2.0b5' - ] - + types = ['http://lid.netmesh.org/sso/2.0b5', 'http://lid.netmesh.org/2.0b5'] uri = "http://mylid.net/josh" for service in self._getServices(): @@ -131,6 +131,7 @@ class TestCanonicalID(unittest.TestCase): test for the given set of inputs""" filename = datapath(filename) + def test(self): xrds = etxrd.parseXRDS(file(filename).read()) self._getCanonicalID(iname, xrds, expectedID) diff --git a/openid/test/test_examples.py b/openid/test/test_examples.py index ca83d83..e1c5797 100644 --- a/openid/test/test_examples.py +++ b/openid/test/test_examples.py @@ -52,6 +52,7 @@ def splitDir(d, count): d = os.path.dirname(d) return d + def runExampleServer(host, port, data_path): thisfile = os.path.abspath(sys.modules[__name__].__file__) topDir = splitDir(thisfile, 3) @@ -64,7 +65,6 @@ def runExampleServer(host, port, data_path): serverMain(host, port, data_path) - class TestServer(unittest.TestCase): """Acceptance tests for examples/server.py. @@ -88,13 +88,11 @@ class TestServer(unittest.TestCase): twill.commands.reset_browser() - def runExampleServer(self): """Zero-arg run-the-server function to be passed to TestInfo.""" # FIXME - make sure sstore starts clean. runExampleServer('127.0.0.1', self.server_port, 'sstore') - def v1endpoint(self, port): """Return an OpenID 1.1 OpenIDServiceEndpoint for the server.""" base = "http://%s:%s" % (socket.getfqdn('127.0.0.1'), port) @@ -104,7 +102,6 @@ class TestServer(unittest.TestCase): ep.type_uris = [OPENID_1_1_TYPE] return ep - # TODO: test discovery def test_checkidv1(self): @@ -116,7 +113,6 @@ class TestServer(unittest.TestCase): if self.twillErr.getvalue(): self.fail(self.twillErr.getvalue()) - def test_allowed(self): """OpenID 1.1 checkid_setup request.""" ti = TwillTest(self.twill_allowed, self.runExampleServer, @@ -126,7 +122,6 @@ class TestServer(unittest.TestCase): if self.twillErr.getvalue(): self.fail(self.twillErr.getvalue()) - def twill_checkidv1(self, twillInfo): endpoint = self.v1endpoint(self.server_port) authreq = AuthRequest(endpoint, assoc=None) @@ -143,12 +138,11 @@ class TestServer(unittest.TestCase): finalURL = headers['Location'] self.failUnless('openid.mode=id_res' in finalURL, finalURL) self.failUnless('openid.identity=' in finalURL, finalURL) - except twill.commands.TwillAssertionError, e: + except twill.commands.TwillAssertionError as e: msg = '%s\nFinal page:\n%s' % ( str(e), c.get_browser().get_html()) self.fail(msg) - def twill_allowed(self, twillInfo): endpoint = self.v1endpoint(self.server_port) authreq = AuthRequest(endpoint, assoc=None) @@ -171,7 +165,7 @@ class TestServer(unittest.TestCase): headers = c.get_browser()._browser.response().info() finalURL = headers['Location'] self.failUnless(finalURL.startswith(self.return_to)) - except twill.commands.TwillAssertionError, e: + except twill.commands.TwillAssertionError: from traceback import format_exc msg = '%s\nTwill output:%s\nTwill errors:%s\nFinal page:\n%s' % ( format_exc(), @@ -180,7 +174,6 @@ class TestServer(unittest.TestCase): c.get_browser().get_html()) self.fail(msg) - def tearDown(self): twill.set_output(None) twill.set_errout(None) diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 11ba1b2..0f714c6 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -10,6 +10,7 @@ class DummyExtension(extension.Extension): def getExtensionArgs(self): return {} + class ToMessageTest(unittest.TestCase): def test_OpenID1(self): oid1_msg = message.Message(message.OPENID1_NS) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 1ec5641..4cf5a22 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -13,6 +13,7 @@ from openid import fetchers # XXX: make these separate test cases + def failUnlessResponseExpected(expected, actual): assert expected.final_url == actual.final_url, ( "%r != %r" % (expected.final_url, actual.final_url)) @@ -32,7 +33,7 @@ def test_fetcher(fetcher, exc, server): server.socket.getsockname()[1], path) - expected_headers = {'content-type':'text/plain'} + expected_headers = {'content-type': 'text/plain'} def plain(path, code): path = '/' + path @@ -53,15 +54,13 @@ def test_fetcher(fetcher, exc, server): plain('forbidden', 403), plain('error', 500), plain('server_error', 503), - ] + ] for path, expected in cases: fetch_url = geturl(path) try: actual = fetcher.fetch(fetch_url) - except (SystemExit, KeyboardInterrupt): - pass - except: + except Exception: print fetcher, fetch_url raise else: @@ -73,29 +72,28 @@ def test_fetcher(fetcher, exc, server): 'ftp://janrain.com/pub/']: try: result = fetcher.fetch(err_url) - except (KeyboardInterrupt, SystemExit): - raise - except fetchers.HTTPError, why: + except fetchers.HTTPError: # This is raised by the Curl fetcher for bad cases # detected by the fetchers module, but it's a subclass of # HTTPFetchingError, so we have to catch it explicitly. assert exc - except fetchers.HTTPFetchingError, why: + except fetchers.HTTPFetchingError: assert not exc, (fetcher, exc, server) - except: + except Exception: assert exc else: assert False, 'An exception was expected for %r (%r)' % (fetcher, result) + def run_fetcher_tests(server): exc_fetchers = [] for klass, library_name in [ (fetchers.CurlHTTPFetcher, 'pycurl'), (fetchers.HTTPLib2Fetcher, 'httplib2'), - ]: + ]: try: exc_fetchers.append(klass()) - except RuntimeError, why: + except RuntimeError as why: if why[0].startswith('Cannot find %s library' % (library_name,)): try: __import__(library_name) @@ -122,17 +120,17 @@ def run_fetcher_tests(server): class FetcherTestHandler(BaseHTTPRequestHandler): cases = { - '/success':(200, None), - '/301redirect':(301, '/success'), - '/302redirect':(302, '/success'), - '/303redirect':(303, '/success'), - '/307redirect':(307, '/success'), - '/notfound':(404, None), - '/badreq':(400, None), - '/forbidden':(403, None), - '/error':(500, None), - '/server_error':(503, None), - } + '/success': (200, None), + '/301redirect': (301, '/success'), + '/302redirect': (302, '/success'), + '/303redirect': (303, '/success'), + '/307redirect': (307, '/success'), + '/notfound': (404, None), + '/badreq': (400, None), + '/forbidden': (403, None), + '/error': (500, None), + '/server_error': (503, None), + } def log_request(self, *args): pass @@ -173,7 +171,7 @@ class FetcherTestHandler(BaseHTTPRequestHandler): req = [ ('HTTP method', self.command), ('path', self.path), - ] + ] if message: req.append(('message', message)) @@ -197,6 +195,7 @@ class FetcherTestHandler(BaseHTTPRequestHandler): self.wfile.close() self.rfile.close() + def test(): import socket host = socket.getfqdn('127.0.0.1') @@ -215,12 +214,14 @@ def test(): run_fetcher_tests(server) + class FakeFetcher(object): sentinel = object() def fetch(self, *args, **kwargs): return self.sentinel + class DefaultFetcherTest(unittest.TestCase): def setUp(self): """reset the default fetcher to None""" @@ -276,7 +277,7 @@ class DefaultFetcherTest(unittest.TestCase): fetchers.fetch('http://invalid.janrain.com/') except fetchers.HTTPFetchingError: self.fail('Should not be wrapping exception') - except: + except Exception: exc = sys.exc_info()[1] self.failUnless(isinstance(exc, urllib2.URLError), exc) pass diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index e310435..188565b 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -8,7 +8,7 @@ class BadLinksTestCase(datadriven.DataDrivenTestCase): '', "http://not.in.a.link.tag/", '<link rel="openid.server" href="not.in.html.or.head" />', - ] + ] def __init__(self, data): datadriven.DataDrivenTestCase.__init__(self, data) @@ -19,5 +19,6 @@ class BadLinksTestCase(datadriven.DataDrivenTestCase): expected = [] self.failUnlessEqual(expected, actual) + def pyUnitTests(): return datadriven.loadTests(__name__) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 571eec9..be6fc21 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -24,6 +24,7 @@ def mkGetArgTest(ns, key, expected=None): return test + class EmptyMessageTest(unittest.TestCase): def setUp(self): self.msg = message.Message() @@ -94,7 +95,7 @@ class EmptyMessageTest(unittest.TestCase): 'openid.test.flub': 'bogus'}) actual_uri = msg.getAliasedArg('ns.test', message.no_default) self.assertEquals("urn://foo", actual_uri) - + def test_getAliasedArgFailure(self): msg = message.Message.fromPostArgs({'openid.test.flub': 'bogus'}) self.assertRaises(KeyError, @@ -136,13 +137,13 @@ class EmptyMessageTest(unittest.TestCase): def test_updateArgs(self): self.failUnlessRaises(message.UndefinedOpenIDNamespace, self.msg.updateArgs, message.OPENID_NS, - {'does not':'matter'}) + {'does not': 'matter'}) def _test_updateArgsNS(self, ns): update_args = { - 'Camper van Beethoven':'David Lowery', - 'Magnolia Electric Co.':'Jason Molina', - } + 'Camper van Beethoven': 'David Lowery', + 'Magnolia Electric Co.': 'Jason Molina', + } self.failUnlessEqual(self.msg.getArgs(ns), {}) self.msg.updateArgs(ns, update_args) @@ -219,19 +220,20 @@ class EmptyMessageTest(unittest.TestCase): def test_isOpenID2(self): self.failIf(self.msg.isOpenID2()) + class OpenID1MessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test'}) + self.msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test'}) def test_toPostArgs(self): self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test'}) + {'openid.mode': 'error', + 'openid.error': 'unit test'}) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {'mode':'error', - 'error':'unit test'}) + self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', + 'error': 'unit test'}) def test_toKVForm(self): self.failUnlessEqual(self.msg.toKVForm(), @@ -249,8 +251,8 @@ class OpenID1MessageTest(unittest.TestCase): self.failUnlessEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = cgi.parse_qs(query) - self.failUnlessEqual(parsed, {'openid.mode':['error'], - 'openid.error':['unit test']}) + self.failUnlessEqual(parsed, {'openid.mode': ['error'], + 'openid.error': ['unit test']}) def test_getOpenID(self): self.failUnlessEqual(self.msg.getOpenIDNamespace(), message.OPENID1_NS) @@ -298,18 +300,14 @@ class OpenID1MessageTest(unittest.TestCase): def test_getArgs(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), {}) def test_getArgsNS1(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID1_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS2(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), {}) @@ -321,9 +319,9 @@ class OpenID1MessageTest(unittest.TestCase): if before is None: before = {} update_args = { - 'Camper van Beethoven':'David Lowery', - 'Magnolia Electric Co.':'Jason Molina', - } + 'Camper van Beethoven': 'David Lowery', + 'Magnolia Electric Co.': 'Jason Molina', + } self.failUnlessEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) @@ -333,14 +331,14 @@ class OpenID1MessageTest(unittest.TestCase): def test_updateArgs(self): self._test_updateArgsNS(message.OPENID_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsBARE(self): self._test_updateArgsNS(message.BARE_NS) def test_updateArgsNS1(self): self._test_updateArgsNS(message.OPENID1_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS2(self): self._test_updateArgsNS(message.OPENID2_NS) @@ -395,40 +393,40 @@ class OpenID1MessageTest(unittest.TestCase): def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') - def test_isOpenID1(self): self.failUnless(self.msg.isOpenID1()) def test_isOpenID2(self): self.failIf(self.msg.isOpenID2()) + class OpenID1ExplicitMessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID1_NS + self.msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID1_NS }) def test_toPostArgs(self): self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID1_NS + {'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID1_NS }) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {'mode':'error', - 'error':'unit test', - 'ns':message.OPENID1_NS}) + self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', + 'error': 'unit test', + 'ns': message.OPENID1_NS}) def test_toKVForm(self): self.failUnlessEqual(self.msg.toKVForm(), - 'error:unit test\nmode:error\nns:%s\n' - %message.OPENID1_NS) + 'error:unit test\nmode:error\nns:%s\n' % message.OPENID1_NS) def test_toURLEncoded(self): - self.failUnlessEqual(self.msg.toURLEncoded(), - 'openid.error=unit+test&openid.mode=error&openid.ns=http%3A%2F%2Fopenid.net%2Fsignon%2F1.0') + self.failUnlessEqual( + self.msg.toURLEncoded(), + 'openid.error=unit+test&openid.mode=error&openid.ns=http%3A%2F%2Fopenid.net%2Fsignon%2F1.0') def test_toURL(self): base_url = 'http://base.url/' @@ -438,50 +436,48 @@ class OpenID1ExplicitMessageTest(unittest.TestCase): self.failUnlessEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = cgi.parse_qs(query) - self.failUnlessEqual(parsed, {'openid.mode':['error'], - 'openid.error':['unit test'], - 'openid.ns':[message.OPENID1_NS] - }) + self.failUnlessEqual( + parsed, + {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) def test_isOpenID1(self): self.failUnless(self.msg.isOpenID1()) + class OpenID2MessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS - }) + self.msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS}) self.msg.setArg(message.BARE_NS, "xey", "value") def test_toPostArgs(self): self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS, + {'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS, 'xey': 'value', }) def test_toPostArgs_bug_with_utf8_encoded_values(self): - msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS - }) + msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS + }) msg.setArg(message.BARE_NS, 'ünicöde_key', 'ünicöde_välüe') self.failUnlessEqual(msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS, + {'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS, 'ünicöde_key': 'ünicöde_välüe', }) - def test_toArgs(self): # This method can't tolerate BARE_NS. self.msg.delArg(message.BARE_NS, "xey") - self.failUnlessEqual(self.msg.toArgs(), {'mode':'error', - 'error':'unit test', - 'ns':message.OPENID2_NS, + self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', + 'error': 'unit test', + 'ns': message.OPENID2_NS, }) def test_toKVForm(self): @@ -492,12 +488,10 @@ class OpenID2MessageTest(unittest.TestCase): (message.OPENID2_NS,)) def _test_urlencoded(self, s): - expected = ('openid.error=unit+test&openid.mode=error&' - 'openid.ns=%s&xey=value' % ( - urllib.quote(message.OPENID2_NS, ''),)) + expected = ('openid.error=unit+test&openid.mode=error&openid.ns=%s&xey=value' % + urllib.quote(message.OPENID2_NS, '')) self.failUnlessEqual(s, expected) - def test_toURLEncoded(self): self._test_urlencoded(self.msg.toURLEncoded()) @@ -558,9 +552,7 @@ class OpenID2MessageTest(unittest.TestCase): def test_getArgsOpenID(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), @@ -571,9 +563,7 @@ class OpenID2MessageTest(unittest.TestCase): def test_getArgsNS2(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS3(self): self.failUnlessEqual(self.msg.getArgs('urn:nothing-significant'), {}) @@ -582,9 +572,9 @@ class OpenID2MessageTest(unittest.TestCase): if before is None: before = {} update_args = { - 'Camper van Beethoven':'David Lowery', - 'Magnolia Electric Co.':'Jason Molina', - } + 'Camper van Beethoven': 'David Lowery', + 'Magnolia Electric Co.': 'Jason Molina', + } self.failUnlessEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) @@ -594,18 +584,18 @@ class OpenID2MessageTest(unittest.TestCase): def test_updateArgsOpenID(self): self._test_updateArgsNS(message.OPENID_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsBARE(self): self._test_updateArgsNS(message.BARE_NS, - before={'xey':'value'}) + before={'xey': 'value'}) def test_updateArgsNS1(self): self._test_updateArgsNS(message.OPENID1_NS) def test_updateArgsNS2(self): self._test_updateArgsNS(message.OPENID2_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') @@ -649,52 +639,53 @@ class OpenID2MessageTest(unittest.TestCase): def test_mysterious_missing_namespace_bug(self): """A failing test for bug #112""" openid_args = { - 'assoc_handle': '{{HMAC-SHA256}{1211477242.29743}{v5cadg==}', - 'claimed_id': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', - 'ns.sreg': 'http://openid.net/extensions/sreg/1.1', - 'response_nonce': '2008-05-22T17:27:22ZUoW5.\\NV', - 'signed': 'return_to,identity,claimed_id,op_endpoint,response_nonce,ns.sreg,sreg.email,sreg.nickname,assoc_handle', - 'sig': 'e3eGZ10+TNRZitgq5kQlk5KmTKzFaCRI8OrRoXyoFa4=', - 'mode': 'check_authentication', - 'op_endpoint': 'http://nerdbank.org/OPAffirmative/ProviderNoAssoc.aspx', - 'sreg.nickname': 'Andy', - 'return_to': 'http://localhost.localdomain:8001/process?janrain_nonce=2008-05-22T17%3A27%3A21ZnxHULd', - 'invalidate_handle': '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', - 'identity': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', - 'sreg.email': 'a@b.com' - } + 'assoc_handle': '{{HMAC-SHA256}{1211477242.29743}{v5cadg==}', + 'claimed_id': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', + 'ns.sreg': 'http://openid.net/extensions/sreg/1.1', + 'response_nonce': '2008-05-22T17:27:22ZUoW5.\\NV', + 'signed': 'return_to,identity,claimed_id,op_endpoint,response_nonce,ns.sreg,sreg.email,sreg.nickname,' + 'assoc_handle', + 'sig': 'e3eGZ10+TNRZitgq5kQlk5KmTKzFaCRI8OrRoXyoFa4=', + 'mode': 'check_authentication', + 'op_endpoint': 'http://nerdbank.org/OPAffirmative/ProviderNoAssoc.aspx', + 'sreg.nickname': 'Andy', + 'return_to': 'http://localhost.localdomain:8001/process?janrain_nonce=2008-05-22T17%3A27%3A21ZnxHULd', + 'invalidate_handle': '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', + 'identity': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', + 'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) self.failUnless(('http://openid.net/extensions/sreg/1.1', 'sreg') in list(m.namespaces.iteritems())) missing = [] for k in openid_args['signed'].split(','): - if not ("openid."+k) in m.toPostArgs().keys(): + if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) self.assertEqual([], missing, missing) self.assertEqual(openid_args, m.toArgs()) self.failUnless(m.isOpenID1()) def test_112B(self): - args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' - } + args = { + 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' + 'ns.pape,pape.nist_auth_level,pape.auth_policies'} m = message.Message.fromPostArgs(args) missing = [] for k in args['openid.signed'].split(','): - if not ("openid."+k) in m.toPostArgs().keys(): + if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) self.assertEqual([], missing, missing) self.assertEqual(args, m.toPostArgs()) @@ -704,27 +695,27 @@ class OpenID2MessageTest(unittest.TestCase): """ Message that raises KeyError during encoding, because openid namespace is used in attributes """ - args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies', - 'openid.ns.pape': 'http://specs.openid.net/auth/2.0', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - } + args = { + 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' + 'ns.pape,pape.nist_auth_level,pape.auth_policies', + 'openid.ns.pape': 'http://specs.openid.net/auth/2.0', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + } self.failUnlessRaises(message.InvalidNamespace, message.Message.fromPostArgs, args) def test_implicit_sreg_ns(self): - openid_args = { - 'sreg.email': 'a@b.com' - } + openid_args = {'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) self.failUnless((sreg.ns_uri, 'sreg') in list(m.namespaces.iteritems())) @@ -778,6 +769,7 @@ class OpenID2MessageTest(unittest.TestCase): def test_isOpenID2(self): self.failUnless(self.msg.isOpenID2()) + class MessageTest(unittest.TestCase): def setUp(self): self.postargs = { @@ -786,24 +778,24 @@ class MessageTest(unittest.TestCase): 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', - } + } self.action_url = 'scheme://host:port/path?query' self.form_tag_attrs = { 'company': 'janrain', 'class': 'fancyCSS', - } + } self.submit_text = 'GO!' - ### Expected data regardless of input + # Expected data regardless of input self.required_form_attrs = { - 'accept-charset':'UTF-8', - 'enctype':'application/x-www-form-urlencoded', + 'accept-charset': 'UTF-8', + 'enctype': 'application/x-www-form-urlencoded', 'method': 'post', - } + } def _checkForm(self, html, message_, action_url, form_tag_attrs, submit_text): @@ -818,8 +810,7 @@ class MessageTest(unittest.TestCase): # Check required form attributes for k, v in self.required_form_attrs.iteritems(): assert form.attrib[k] == v, \ - "Expected '%s' for required form attribute '%s', got '%s'" % \ - (v, k, form.attrib[k]) + "Expected '%s' for required form attribute '%s', got '%s'" % (v, k, form.attrib[k]) # Check extra form attributes for k, v in form_tag_attrs.iteritems(): @@ -831,13 +822,11 @@ class MessageTest(unittest.TestCase): continue assert form.attrib[k] == v, \ - "Form attribute '%s' should be '%s', found '%s'" % \ - (k, v, form.attrib[k]) + "Form attribute '%s' should be '%s', found '%s'" % (k, v, form.attrib[k]) # Check hidden fields against post args - hiddens = [e for e in form \ - if e.tag.upper() == 'INPUT' and \ - e.attrib['type'].upper() == 'HIDDEN'] + hiddens = [e for e in form + if e.tag.upper() == 'INPUT' and e.attrib['type'].upper() == 'HIDDEN'] # For each post arg, make sure there is a hidden with that # value. Make sure there are no other hiddens. @@ -845,34 +834,29 @@ class MessageTest(unittest.TestCase): for e in hiddens: if e.attrib['name'] == name: assert e.attrib['value'] == value, \ - "Expected value of hidden input '%s' to be '%s', got '%s'" % \ - (e.attrib['name'], value, e.attrib['value']) + "Expected value of hidden input '%s' to be '%s', got '%s'" % \ + (e.attrib['name'], value, e.attrib['value']) break else: self.fail("Post arg '%s' not found in form" % (name,)) for e in hiddens: assert e.attrib['name'] in message_.toPostArgs().keys(), \ - "Form element for '%s' not in " + \ - "original message" % (e.attrib['name']) + "Form element for '%s' not in original message" % (e.attrib['name']) # Check action URL assert form.attrib['action'] == action_url, \ - "Expected form 'action' to be '%s', got '%s'" % \ - (action_url, form.attrib['action']) + "Expected form 'action' to be '%s', got '%s'" % (action_url, form.attrib['action']) # Check submit text - submits = [e for e in form \ - if e.tag.upper() == 'INPUT' and \ - e.attrib['type'].upper() == 'SUBMIT'] + submits = [e for e in form + if e.tag.upper() == 'INPUT' and e.attrib['type'].upper() == 'SUBMIT'] assert len(submits) == 1, \ - "Expected only one 'input' with type = 'submit', got %d" % \ - (len(submits),) + "Expected only one 'input' with type = 'submit', got %d" % (len(submits),) assert submits[0].attrib['value'] == submit_text, \ - "Expected submit value to be '%s', got '%s'" % \ - (submit_text, submits[0].attrib['value']) + "Expected submit value to be '%s', got '%s'" % (submit_text, submits[0].attrib['value']) def test_toFormMarkup(self): m = message.Message.fromPostArgs(self.postargs) @@ -888,8 +872,8 @@ class MessageTest(unittest.TestCase): 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', - 'ünicöde_key' : 'ünicöde_välüe', - } + 'ünicöde_key': 'ünicöde_välüe', + } m = message.Message.fromPostArgs(postargs) # Calling m.toFormMarkup with lxml used for ElementTree will throw # a ValueError. @@ -930,7 +914,6 @@ class MessageTest(unittest.TestCase): self._checkForm(html, m, self.action_url, tag_attrs, self.submit_text) - def test_setOpenIDNamespace_invalid(self): m = message.Message() invalid_things = [ @@ -944,19 +927,18 @@ class MessageTest(unittest.TestCase): 'http%3A%2F%2Fspecs.openid.net%2Fauth%2F2.0', # This is a Type URI, not a openid.ns value. 'http://specs.openid.net/auth/2.0/signon', - ] + ] for x in invalid_things: self.failUnlessRaises(message.InvalidOpenIDNamespace, m.setOpenIDNamespace, x, False) - def test_isOpenID1(self): v1_namespaces = [ # Yes, there are two of them. 'http://openid.net/signon/1.1', 'http://openid.net/signon/1.0', - ] + ] for ns in v1_namespaces: m = message.Message(ns) @@ -983,14 +965,13 @@ class MessageTest(unittest.TestCase): m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, True) self.failUnless(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) - def test_explicitOpenID11NSSerialzation(self): m = message.Message() m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, implicit=False) post_args = m.toPostArgs() self.failUnlessEqual(post_args, - {'openid.ns':message.THE_OTHER_OPENID1_NS}) + {'openid.ns': message.THE_OTHER_OPENID1_NS}) def test_fromPostArgs_ns11(self): # An example of the stuff that some Drupal installations send us, @@ -1005,12 +986,11 @@ class MessageTest(unittest.TestCase): u'openid.return_to': u'http://drupal.invalid/return_to', u'openid.sreg.required': u'nickname,email', u'openid.trust_root': u'http://drupal.invalid', - } + } m = message.Message.fromPostArgs(query) self.failUnless(m.isOpenID1()) - class NamespaceMapTest(unittest.TestCase): def test_onealias(self): nsm = message.NamespaceMap() @@ -1024,16 +1004,16 @@ class NamespaceMapTest(unittest.TestCase): nsm = message.NamespaceMap() uripat = 'http://example.com/foo%r' - nsm.add(uripat%0) - for n in range(1,23): - self.failUnless(uripat%(n-1) in nsm) - self.failUnless(nsm.isDefined(uripat%(n-1))) - nsm.add(uripat%n) + nsm.add(uripat % 0) + for n in range(1, 23): + self.failUnless(uripat % (n - 1) in nsm) + self.failUnless(nsm.isDefined(uripat % (n - 1))) + nsm.add(uripat % n) for (uri, alias) in nsm.iteritems(): - self.failUnless(uri[22:]==alias[3:]) + self.failUnless(uri[22:] == alias[3:]) - i=0 + i = 0 it = nsm.iterAliases() try: while True: @@ -1042,7 +1022,7 @@ class NamespaceMapTest(unittest.TestCase): except StopIteration: self.failUnless(i == 23) - i=0 + i = 0 it = nsm.iterNamespaceURIs() try: while True: diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index c23ef96..8936ecd 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,10 +1,9 @@ - import unittest from openid import association from openid.consumer.consumer import GenericConsumer, ServerError from openid.consumer.discover import OPENID_2_0_TYPE, OpenIDServiceEndpoint -from openid.message import OPENID1_NS, OPENID2_NS, OPENID_NS, Message +from openid.message import OPENID1_NS, OPENID_NS, Message from .support import CatchLogs @@ -29,11 +28,13 @@ class ErrorRaisingConsumer(GenericConsumer): else: return m + class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): """ Test the session type negotiation behavior of an OpenID 2 consumer. """ + def setUp(self): CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) @@ -141,7 +142,7 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') self.consumer.return_messages = [msg, - Message(self.endpoint.preferredNamespace())] + Message(self.endpoint.preferredNamespace())] self.failUnlessEqual(self.consumer._negotiateAssociation(self.endpoint), None) @@ -160,6 +161,7 @@ class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogEmpty() + class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): """ Tests for the OpenID 1 consumer association session behavior. See @@ -170,6 +172,7 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): these tests pass openid2-style messages to the openid 1 association processing logic to be sure it ignores the extra data. """ + def setUp(self): CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) @@ -247,12 +250,13 @@ class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogEmpty() + class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): def setUp(self): self.allowed_types = [ ('HMAC-SHA1', 'no-encryption'), ('HMAC-SHA256', 'no-encryption'), - ] + ] self.n = association.SessionNegotiator(self.allowed_types) @@ -269,5 +273,6 @@ class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): for typ in association.getSessionTypes(assoc_type): self.failUnless((assoc_type, typ) in self.n.allowed_types) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index fe17151..7b27134 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -1,5 +1,4 @@ import re -import time import unittest from openid.store.nonce import checkTimestamp, mkNonce, split as splitNonce @@ -7,6 +6,7 @@ from openid.test import datadriven nonce_re = re.compile(r'\A\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ') + class NonceTest(unittest.TestCase): def test_mkNonce(self): nonce = mkNonce() @@ -35,6 +35,7 @@ class NonceTest(unittest.TestCase): self.failUnlessEqual(len(salt), 6) self.failUnlessEqual(et, t) + class BadSplitTest(datadriven.DataDrivenTestCase): cases = [ '', @@ -44,7 +45,7 @@ class BadSplitTest(datadriven.DataDrivenTestCase): '1970.01-01T00:00:00Z', 'Thu Sep 7 13:29:31 PDT 2006', 'monkeys', - ] + ] def __init__(self, nonce_str): datadriven.DataDrivenTestCase.__init__(self, nonce_str) @@ -53,6 +54,7 @@ class BadSplitTest(datadriven.DataDrivenTestCase): def runOneTest(self): self.failUnlessRaises(ValueError, splitNonce, self.nonce_str) + class CheckTimestampTest(datadriven.DataDrivenTestCase): cases = [ # exact, no allowed skew @@ -78,7 +80,7 @@ class CheckTimestampTest(datadriven.DataDrivenTestCase): # malformed nonce string ('monkeys', 0, 0, False), - ] + ] def __init__(self, nonce_string, allowed_skew, now, expected): datadriven.DataDrivenTestCase.__init__( @@ -92,9 +94,11 @@ class CheckTimestampTest(datadriven.DataDrivenTestCase): actual = checkTimestamp(self.nonce_string, self.allowed_skew, self.now) self.failUnlessEqual(bool(self.expected), bool(actual)) + def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 4b7cca4..16aebea 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -15,9 +15,11 @@ XRDS_BOILERPLATE = '''\ </xrds:XRDS> ''' + def mkXRDS(services): return XRDS_BOILERPLATE % (services,) + def mkService(uris=None, type_uris=None, local_id=None, dent=' '): chunks = [dent, '<Service>\n'] dent2 = dent + ' ' @@ -27,7 +29,7 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): if uris: for uri in uris: - if type(uri) is tuple: + if isinstance(uri, tuple): uri, prio = uri else: prio = None @@ -45,18 +47,21 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): return ''.join(chunks) + # Different sets of server URLs for use in the URI tag server_url_options = [ - [], # This case should not generate an endpoint object + [], # This case should not generate an endpoint object ['http://server.url/'], ['https://server.url/'], ['https://server.url/', 'http://server.url/'], ['https://server.url/', 'http://server.url/', 'http://example.server.url/'], - ] +] # Used for generating test data + + def subsets(l): """Generate all non-empty sublists of a list""" subsets_list = [[]] @@ -64,12 +69,13 @@ def subsets(l): subsets_list += [[x] + t for t in subsets_list] return subsets_list + # A couple of example extension type URIs. These are not at all # official, but are just here for testing. ext_types = [ 'http://janrain.com/extension/blah', 'http://openid.net/sreg/1.0', - ] +] # All valid combinations of Type tags that should produce an OpenID endpoint type_uri_options = [ @@ -81,14 +87,14 @@ type_uri_options = [ # All combinations of extension types (including empty extenstion list) for exts in subsets(ext_types) - ] +] # Range of valid Delegate tag values for generating test data local_id_options = [ None, 'http://vanity.domain/', 'https://somewhere/yadis/', - ] +] # All combinations of valid URIs, Type URIs and Delegate tags data = [ @@ -96,7 +102,8 @@ data = [ for uris in server_url_options for type_uris in type_uri_options for local_id in local_id_options - ] +] + class OpenIDYadisTest(unittest.TestCase): def __init__(self, uris, type_uris, local_id): @@ -129,8 +136,7 @@ class OpenIDYadisTest(unittest.TestCase): self.failUnlessEqual(len(self.uris), len(endpoints)) # So that we can check equality on the endpoint types - type_uris = list(self.type_uris) - type_uris.sort() + type_uris = sorted(self.type_uris) seen_uris = [] for endpoint in endpoints: @@ -143,19 +149,18 @@ class OpenIDYadisTest(unittest.TestCase): self.failUnlessEqual(self.local_id, endpoint.local_id) # and types - actual_types = list(endpoint.type_uris) - actual_types.sort() + actual_types = sorted(endpoint.type_uris) self.failUnlessEqual(actual_types, type_uris) # So that they will compare equal, because we don't care what # order they are in seen_uris.sort() - uris = list(self.uris) - uris.sort() + uris = sorted(self.uris) # Make sure we saw all URIs, and saw each one once self.failUnlessEqual(uris, seen_uris) + def pyUnitTests(): cases = [] for args in data: diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index f468015..be76550 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -1,8 +1,7 @@ - import unittest from openid.extensions.draft import pape2 as pape -from openid.message import * +from openid.message import OPENID2_NS, Message from openid.server import server @@ -39,14 +38,15 @@ class PapeRequestTestCase(unittest.TestCase): self.req.addPolicyURI('http://zig') self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) self.req.max_auth_age = 789 - self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}, self.req.getExtensionArgs()) + self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}, + self.req.getExtensionArgs()) def test_parseExtensionArgs(self): args = {'preferred_auth_policies': 'http://foo http://bar', 'max_auth_age': '9'} self.req.parseExtensionArgs(args) self.failUnlessEqual(9, self.req.max_auth_age) - self.failUnlessEqual(['http://foo','http://bar'], self.req.preferred_auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.preferred_auth_policies) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) @@ -55,12 +55,12 @@ class PapeRequestTestCase(unittest.TestCase): def test_fromOpenIDRequest(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.max_auth_age': '5476' - }) + 'mode': 'checkid_setup', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.preferred_auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.max_auth_age': '5476' + }) oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) @@ -81,6 +81,7 @@ class PapeRequestTestCase(unittest.TestCase): pape.AUTH_MULTI_FACTOR_PHYSICAL]) self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], pt) + class DummySuccessResponse: def __init__(self, message, signed_stuff): self.message = message @@ -89,6 +90,7 @@ class DummySuccessResponse: def getSignedNS(self, ns_uri): return self.signed_stuff + class PapeResponseTestCase(unittest.TestCase): def setUp(self): self.req = pape.Response() @@ -122,9 +124,13 @@ class PapeResponseTestCase(unittest.TestCase): self.req.addPolicyURI('http://zig') self.failUnlessEqual({'auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) self.req.auth_time = "1776-07-04T14:43:12Z" - self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}, self.req.getExtensionArgs()) + self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}, + self.req.getExtensionArgs()) self.req.nist_auth_level = 3 - self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", 'nist_auth_level': '3'}, self.req.getExtensionArgs()) + self.failUnlessEqual({'auth_policies': 'http://uri http://zig', + 'auth_time': "1776-07-04T14:43:12Z", + 'nist_auth_level': '3'}, + self.req.getExtensionArgs()) def test_getExtensionArgs_error_auth_age(self): self.req.auth_time = "long ago" @@ -143,13 +149,13 @@ class PapeResponseTestCase(unittest.TestCase): 'auth_time': '1970-01-01T00:00:00Z'} self.req.parseExtensionArgs(args) self.failUnlessEqual('1970-01-01T00:00:00Z', self.req.auth_time) - self.failUnlessEqual(['http://foo','http://bar'], self.req.auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) self.failUnlessEqual(None, self.req.auth_time) self.failUnlessEqual([], self.req.auth_policies) - + def test_parseExtensionArgs_strict_bogus1(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': 'yesterday'} @@ -162,13 +168,13 @@ class PapeResponseTestCase(unittest.TestCase): 'nist_auth_level': 'some'} self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, args, True) - + def test_parseExtensionArgs_strict_good(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': '1970-01-01T00:00:00Z', 'nist_auth_level': '0'} self.req.parseExtensionArgs(args, True) - self.failUnlessEqual(['http://foo','http://bar'], self.req.auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) self.failUnlessEqual('1970-01-01T00:00:00Z', self.req.auth_time) self.failUnlessEqual(0, self.req.nist_auth_level) @@ -177,21 +183,21 @@ class PapeResponseTestCase(unittest.TestCase): 'auth_time': 'when the cows come home', 'nist_auth_level': 'some'} self.req.parseExtensionArgs(args) - self.failUnlessEqual(['http://foo','http://bar'], self.req.auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) self.failUnlessEqual(None, self.req.auth_time) self.failUnlessEqual(None, self.req.nist_auth_level) def test_fromSuccessResponse(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = { - 'auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'auth_time': '1970-01-01T00:00:00Z' + 'auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'auth_time': '1970-01-01T00:00:00Z' } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) @@ -200,12 +206,12 @@ class PapeResponseTestCase(unittest.TestCase): def test_fromSuccessResponseNoSignedArgs(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = {} diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 9693fad..243eae5 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,9 +1,8 @@ - import unittest import warnings from openid.extensions.draft import pape5 as pape -from openid.message import * +from openid.message import OPENID2_NS, Message from openid.server import server warnings.filterwarnings('ignore', module=__name__, @@ -111,7 +110,7 @@ class PapeRequestTestCase(unittest.TestCase): ('auth_level.ns.%s' % alias2): uri2, 'preferred_auth_level_types': ' '.join([alias, alias2]), 'preferred_auth_policies': '', - } + } self.failUnlessEqual(expected_args, self.req.getExtensionArgs()) @@ -127,7 +126,7 @@ class PapeRequestTestCase(unittest.TestCase): ('auth_level.ns.%s' % alias2): uri2, 'preferred_auth_level_types': ' '.join([alias, alias2]), 'preferred_auth_policies': '', - } + } # Check request object state self.req.parseExtensionArgs(request_args, is_openid1=False, strict=False) @@ -141,8 +140,8 @@ class PapeRequestTestCase(unittest.TestCase): def test_parseExtensionArgsWithAuthLevels_openID1(self): request_args = { - 'preferred_auth_level_types':'nist jisa', - } + 'preferred_auth_level_types': 'nist jisa', + } expected_auth_levels = [pape.LEVELS_NIST, pape.LEVELS_JISA] self.req.parseExtensionArgs(request_args, is_openid1=True) self.assertEqual(expected_auth_levels, @@ -159,12 +158,12 @@ class PapeRequestTestCase(unittest.TestCase): request_args, is_openid1=False, strict=True) def test_parseExtensionArgs_ignoreBadAuthLevels(self): - request_args = {'preferred_auth_level_types':'monkeys'} + request_args = {'preferred_auth_level_types': 'monkeys'} self.req.parseExtensionArgs(request_args, False) self.assertEqual([], self.req.preferred_auth_level_types) def test_parseExtensionArgs_strictBadAuthLevels(self): - request_args = {'preferred_auth_level_types':'monkeys'} + request_args = {'preferred_auth_level_types': 'monkeys'} self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) @@ -173,7 +172,7 @@ class PapeRequestTestCase(unittest.TestCase): 'max_auth_age': '9'} self.req.parseExtensionArgs(args, False) self.failUnlessEqual(9, self.req.max_auth_age) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.preferred_auth_policies) self.failUnlessEqual([], self.req.preferred_auth_level_types) @@ -191,12 +190,12 @@ class PapeRequestTestCase(unittest.TestCase): def test_fromOpenIDRequest(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join(policy_uris), - 'pape.max_auth_age': '5476' - }) + 'mode': 'checkid_setup', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.preferred_auth_policies': ' '.join(policy_uris), + 'pape.max_auth_age': '5476' + }) oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) @@ -217,6 +216,7 @@ class PapeRequestTestCase(unittest.TestCase): pape.AUTH_MULTI_FACTOR_PHYSICAL]) self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], pt) + class DummySuccessResponse: def __init__(self, message, signed_stuff): self.message = message @@ -228,6 +228,7 @@ class DummySuccessResponse: def getSignedNS(self, ns_uri): return self.signed_stuff + class PapeResponseTestCase(unittest.TestCase): def setUp(self): self.resp = pape.Response() @@ -293,7 +294,7 @@ class PapeResponseTestCase(unittest.TestCase): 'auth_time': '1970-01-01T00:00:00Z'} self.resp.parseExtensionArgs(args, is_openid1=False) self.failUnlessEqual('1970-01-01T00:00:00Z', self.resp.auth_time) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.resp.auth_policies) def test_parseExtensionArgs_valid_none(self): @@ -327,7 +328,7 @@ class PapeResponseTestCase(unittest.TestCase): args = { 'auth_policies': ' '.join(policies), - } + } self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) @@ -339,7 +340,7 @@ class PapeResponseTestCase(unittest.TestCase): args = { 'auth_policies': ' '.join(policies), - } + } self.failUnlessRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) @@ -385,7 +386,7 @@ class PapeResponseTestCase(unittest.TestCase): 'auth_level.nist': '0', 'auth_level.ns.nist': pape.LEVELS_NIST} self.resp.parseExtensionArgs(args, is_openid1=False, strict=True) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.resp.auth_policies) self.failUnlessEqual('1970-01-01T00:00:00Z', self.resp.auth_time) self.failUnlessEqual(0, self.resp.nist_auth_level) @@ -395,7 +396,7 @@ class PapeResponseTestCase(unittest.TestCase): 'auth_time': 'when the cows come home', 'nist_auth_level': 'some'} self.resp.parseExtensionArgs(args, is_openid1=False) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.resp.auth_policies) self.failUnlessEqual(None, self.resp.auth_time) self.failUnlessEqual(None, self.resp.nist_auth_level) @@ -403,15 +404,15 @@ class PapeResponseTestCase(unittest.TestCase): def test_fromSuccessResponse(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join(policy_uris), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = { - 'auth_policies': ' '.join(policy_uris), - 'auth_time': '1970-01-01T00:00:00Z' + 'auth_policies': ' '.join(policy_uris), + 'auth_time': '1970-01-01T00:00:00Z' } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) @@ -421,12 +422,12 @@ class PapeResponseTestCase(unittest.TestCase): def test_fromSuccessResponseNoSignedArgs(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join(policy_uris), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = {} @@ -438,5 +439,6 @@ class PapeResponseTestCase(unittest.TestCase): resp = pape.Response.fromSuccessResponse(oid_req) self.failUnless(resp is None) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index fe90ac7..a3c038d 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -18,9 +18,9 @@ class _TestCase(unittest.TestCase): def runTest(self): p = YadisHTMLParser() - try: + try: p.feed(self.case) - except ParseDone, why: + except ParseDone as why: found = why[0] # make sure we protect outselves against accidental bogus @@ -44,6 +44,7 @@ class _TestCase(unittest.TestCase): self.__class__.__module__, os.path.basename(self.filename)) + def parseCases(data): cases = [] for chunk in data.split('\f\n'): @@ -51,6 +52,7 @@ def parseCases(data): cases.append((expected, case)) return cases + def pyUnitTests(): """Make a pyunit TestSuite from a file defining test cases.""" s = unittest.TestSuite() @@ -58,10 +60,12 @@ def pyUnitTests(): s.addTest(_TestCase(filename, str(test_num), expected, case)) return s + def test(): runner = unittest.TextTestRunner() return runner.run(pyUnitTests()) + filenames = ['data/test1-parsehtml.txt'] default_test_files = [] @@ -70,6 +74,7 @@ for filename in filenames: full_name = os.path.join(base, filename) default_test_files.append(full_name) + def getCases(test_files=default_test_files): cases = [] for filename in test_files: diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index d37b594..c106981 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -11,8 +11,6 @@ from openid.yadis import services from openid.yadis.discover import DiscoveryFailure, DiscoveryResult -# Too many methods does not apply to unit test objects -#pylint:disable-msg=R0904 class TestBuildDiscoveryURL(unittest.TestCase): """Tests for building the discovery URL from a realm and a return_to URL @@ -44,6 +42,7 @@ class TestBuildDiscoveryURL(unittest.TestCase): self.failUnlessDiscoURL('http://*.example.com:8001/foo', 'http://www.example.com:8001/foo') + class TestExtractReturnToURLs(unittest.TestCase): disco_url = 'http://example.com/' @@ -141,8 +140,7 @@ class TestExtractReturnToURLs(unittest.TestCase): </Service> </XRD> </xrds:XRDS> -''', ['http://rp.example.com/return', - 'http://other.rp.example.com/return']) +''', ['http://rp.example.com/return', 'http://other.rp.example.com/return']) def test_twoEntries_withOther(self): self.failUnlessXRDSHasReturnURLs('''\ @@ -165,9 +163,7 @@ class TestExtractReturnToURLs(unittest.TestCase): </Service> </XRD> </xrds:XRDS> -''', ['http://rp.example.com/return', - 'http://other.rp.example.com/return']) - +''', ['http://rp.example.com/return', 'http://other.rp.example.com/return']) class TestReturnToMatches(unittest.TestCase): @@ -203,6 +199,7 @@ class TestReturnToMatches(unittest.TestCase): [r], 'http://example.com/xss_exploit')) + class TestVerifyReturnTo(unittest.TestCase, CatchLogs): def setUp(self): @@ -210,7 +207,7 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): def tearDown(self): CatchLogs.tearDown(self) - + def test_bogusRealm(self): self.failIf(trustroot.verifyReturnTo('', 'http://example.com/')) @@ -250,5 +247,6 @@ class TestVerifyReturnTo(unittest.TestCase, CatchLogs): trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogMatches("Attempting to verify") + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_server.py b/openid/test/test_server.py index a19a734..171be4c 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -2,9 +2,11 @@ """ import cgi import unittest +from functools import partial from urlparse import urlparse from openid import association, cryptutil, oidutil +from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default from openid.server import server from openid.store import memstore @@ -16,9 +18,13 @@ from openid.test.support import CatchLogs # for more, see /etc/ssh/moduli -ALT_MODULUS = 0xCAADDDEC1667FC68B5FA15D53C4E1532DD24561A1A2D47A12C01ABEA1E00731F6921AAC40742311FDF9E634BB7131BEE1AF240261554389A910425E044E88C8359B010F5AD2B80E29CB1A5B027B19D9E01A6F63A6F45E5D7ED2FF6A2A0085050A7D0CF307C3DB51D2490355907B4427C23A98DF1EB8ABEF2BA209BB7AFFE86A7 +ALT_MODULUS = int('1423261515703355186607439952816216983770573549498844689430217675736088990483613604225135575535147900' + '4551229946895343158530081254885941985717109436635815890343316791551733211386105974742540867014420109' + '9811846875730766487278261498262568348338476437200556998366087779709990807518291581860338635288400119' + '293970087') ALT_GEN = 5 + class TestProtocolError(unittest.TestCase): def test_browserWithReturnTo(self): return_to = "http://rp.unittest/consumer" @@ -27,13 +33,13 @@ class TestProtocolError(unittest.TestCase): 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) @@ -48,14 +54,14 @@ class TestProtocolError(unittest.TestCase): 'openid.identity': 'http://wagu.unittest/', 'openid.claimed_id': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) expected_args = { 'openid.ns': [OPENID2_NS], 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) @@ -70,15 +76,9 @@ class TestProtocolError(unittest.TestCase): 'openid.identity': 'http://wagu.unittest/', 'openid.claimed_id': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) - expected_args = { - 'openid.ns': [OPENID2_NS], - 'openid.mode': ['error'], - 'openid.error': ['plucky'], - } - self.failUnless(e.whichEncoding() == server.ENCODE_HTML_FORM) self.failUnless(e.toFormMarkup() == e.toMessage().toFormMarkup( args.getArg(OPENID_NS, 'return_to'))) @@ -90,13 +90,13 @@ class TestProtocolError(unittest.TestCase): 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } self.failUnless(e.whichEncoding() == server.ENCODE_URL) @@ -109,7 +109,7 @@ class TestProtocolError(unittest.TestCase): args = Message.fromPostArgs({ 'openid.mode': 'zebradance', 'openid.identity': 'http://wagu.unittest/', - }) + }) e = server.ProtocolError(args, "waffles") self.failIf(e.hasReturnTo()) expected = """error:waffles @@ -117,7 +117,6 @@ mode:error """ self.failUnlessEqual(e.encodeToKVForm(), expected) - def test_noMessage(self): e = server.ProtocolError(None, "no moar pancakes") self.failIf(e.hasReturnTo()) @@ -146,14 +145,14 @@ class TestDecode(unittest.TestCase): args = { 'pony': 'spotted', 'sreg.mutant_power': 'decaffinator', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_bad(self): args = { 'openid.mode': 'twos-compliment', 'openid.pants': 'zippered', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_dictOfLists(self): @@ -163,10 +162,10 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, - } + } try: result = self.decode(args) - except TypeError, err: + except TypeError as err: self.failUnless(str(err).find('values') != -1, err) else: self.fail("Expected TypeError, but got result %s" % (result,)) @@ -180,7 +179,7 @@ class TestDecode(unittest.TestCase): 'openid.trust_root': self.tr_url, # should be ignored 'openid.some.extension': 'junk', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_immediate") @@ -197,7 +196,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_setup") @@ -215,7 +214,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_setup") @@ -233,7 +232,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoIdentityOpenID2(self): @@ -243,7 +242,7 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_setup") @@ -261,7 +260,7 @@ class TestDecode(unittest.TestCase): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.trust_root': self.tr_url, - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoReturnOpenID2(self): @@ -276,7 +275,7 @@ class TestDecode(unittest.TestCase): 'openid.claimed_id': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.realm': self.tr_url, - } + } self.failUnless(isinstance(self.decode(args), server.CheckIDRequest)) req = self.decode(args) @@ -294,7 +293,7 @@ class TestDecode(unittest.TestCase): 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_checkidSetupBadReturn(self): @@ -303,10 +302,10 @@ class TestDecode(unittest.TestCase): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': 'not a url', - } + } try: result = self.decode(args) - except server.ProtocolError, err: + except server.ProtocolError as err: self.failUnless(err.openid_message) else: self.fail("Expected ProtocolError, instead returned with %s" % @@ -319,10 +318,10 @@ class TestDecode(unittest.TestCase): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': 'http://not-the-return-place.unittest/', - } + } try: result = self.decode(args) - except server.UntrustedReturnURL, err: + except server.UntrustedReturnURL as err: self.failUnless(err.openid_message) else: self.fail("Expected UntrustedReturnURL, instead returned with %s" % @@ -338,13 +337,12 @@ class TestDecode(unittest.TestCase): 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckAuthRequest)) self.failUnlessEqual(r.mode, 'check_authentication') self.failUnlessEqual(r.sig, 'sigblob') - def test_checkAuthMissingSignature(self): args = { 'openid.mode': 'check_authentication', @@ -353,10 +351,9 @@ class TestDecode(unittest.TestCase): 'openid.foo': 'signedval1', 'openid.bar': 'signedval2', 'openid.baz': 'unsigned', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_checkAuthAndInvalidate(self): args = { 'openid.mode': 'check_authentication', @@ -368,18 +365,17 @@ class TestDecode(unittest.TestCase): 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckAuthRequest)) self.failUnlessEqual(r.invalidate_handle, '[[SMART_handle]]') - def test_associateDH(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", - } + } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) self.failUnlessEqual(r.mode, "associate") @@ -392,20 +388,18 @@ class TestDecode(unittest.TestCase): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', - } + } # Using DH-SHA1 without supplying dh_consumer_public is an error. self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associateDHpubKeyNotB64(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "donkeydonkeydonkey", - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associateDHModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { @@ -413,8 +407,8 @@ class TestDecode(unittest.TestCase): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': cryptutil.longToBase64(ALT_MODULUS), - 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN) , - } + 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN), + } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) self.failUnlessEqual(r.mode, "associate") @@ -424,7 +418,6 @@ class TestDecode(unittest.TestCase): self.failUnlessEqual(r.session.dh.generator, ALT_GEN) self.failUnless(r.session.consumer_pubkey) - def test_associateDHCorruptModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { @@ -433,10 +426,9 @@ class TestDecode(unittest.TestCase): 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', 'openid.dh_gen': 'gnocchi', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associateDHMissingModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { @@ -444,7 +436,7 @@ class TestDecode(unittest.TestCase): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) @@ -461,20 +453,18 @@ class TestDecode(unittest.TestCase): # self.failUnlessRaises(server.ProtocolError, self.decode, args) # test_associateDHInvalidModGen.todo = "low-priority feature" - def test_associateWeirdSession(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'FLCL6', 'openid.dh_consumer_public': "YQ==\n", - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associatePlain(self): args = { 'openid.mode': 'associate', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) self.failUnlessEqual(r.mode, "associate") @@ -485,16 +475,16 @@ class TestDecode(unittest.TestCase): args = { 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "my public keeey", - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_invalidns(self): - args = {'openid.ns': 'Tuesday', - 'openid.mode': 'associate'} + args = {'openid.ns': 'Tuesday', + 'openid.mode': 'associate'} try: r = self.decode(args) - except server.ProtocolError, err: + except server.ProtocolError as err: # Assert that the ProtocolError does have a Message attached # to it, even though the request wasn't a well-formed Message. self.failUnless(err.openid_message) @@ -519,12 +509,12 @@ class TestEncode(unittest.TestCase): issued. """ request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -533,12 +523,12 @@ class TestEncode(unittest.TestCase): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': request.return_to, - }) + }) self.failIf(response.renderAsForm()) self.failUnless(response.whichEncoding() == server.ENCODE_URL) webresponse = self.encode(response) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) def test_id_res_OpenID2_POST(self): """ @@ -547,12 +537,12 @@ class TestEncode(unittest.TestCase): returned. """ request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -561,7 +551,7 @@ class TestEncode(unittest.TestCase): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) self.failUnless(response.renderAsForm()) self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) @@ -571,12 +561,12 @@ class TestEncode(unittest.TestCase): def test_toFormMarkup(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -585,19 +575,19 @@ class TestEncode(unittest.TestCase): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) - form_markup = response.toFormMarkup({'foo':'bar'}) + form_markup = response.toFormMarkup({'foo': 'bar'}) self.failUnless(' foo="bar"' in form_markup) def test_toHTML(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -606,7 +596,7 @@ class TestEncode(unittest.TestCase): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) html = response.toHTML() self.failUnless('<html>' in html) self.failUnless('</html>' in html) @@ -622,19 +612,19 @@ class TestEncode(unittest.TestCase): place to preserve the status quo for OpenID 1. """ request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'id_res', 'identity': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) self.failIf(response.renderAsForm()) self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) @@ -644,22 +634,22 @@ class TestEncode(unittest.TestCase): def test_id_res(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'id_res', 'identity': request.identity, 'return_to': request.return_to, - }) + }) webresponse = self.encode(response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] self.failUnless(location.startswith(request.return_to), @@ -672,34 +662,34 @@ class TestEncode(unittest.TestCase): def test_cancel(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'cancel', - }) + }) webresponse = self.encode(response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) def test_cancelToForm(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'cancel', - }) + }) form = response.toFormMarkup() self.failUnless(form) @@ -726,7 +716,7 @@ class TestEncode(unittest.TestCase): response.fields = Message.fromOpenIDArgs({ 'is_valid': 'true', 'invalidate_handle': 'xXxX:xXXx' - }) + }) body = """invalidate_handle:xXxX:xXXx is_valid:true """ @@ -738,7 +728,7 @@ is_valid:true def test_unencodableError(self): args = Message.fromPostArgs({ 'openid.identity': 'http://limu.unittest/', - }) + }) e = server.ProtocolError(args, "wet paint") self.failUnlessRaises(server.EncodingError, self.encode, e) @@ -746,15 +736,14 @@ is_valid:true args = Message.fromPostArgs({ 'openid.mode': 'associate', 'openid.identity': 'http://limu.unittest/', - }) - body="error:snoot\nmode:error\n" + }) + body = "error:snoot\nmode:error\n" webresponse = self.encode(server.ProtocolError(args, "snoot")) self.failUnlessEqual(webresponse.code, server.HTTP_ERROR) self.failUnlessEqual(webresponse.headers, {}) self.failUnlessEqual(webresponse.body, body) - class TestSigningEncode(unittest.TestCase): def setUp(self): self._dumb_key = server.Signatory._dumb_key @@ -762,19 +751,19 @@ class TestSigningEncode(unittest.TestCase): self.store = memstore.MemoryStore() self.server = server.Server(self.store, "http://signing.unittest/enc") self.request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.request.message = Message(OPENID2_NS) self.response = server.OpenIDResponse(self.request) self.response.fields = Message.fromOpenIDArgs({ 'mode': 'id_res', 'identity': self.request.identity, 'return_to': self.request.return_to, - }) + }) self.signatory = server.Signatory(self.store) self.encoder = server.SigningEncoder(self.signatory) self.encode = self.encoder.encode @@ -788,7 +777,7 @@ class TestSigningEncode(unittest.TestCase): self.request.assoc_handle = assoc_handle webresponse = self.encode(self.response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) @@ -799,7 +788,7 @@ class TestSigningEncode(unittest.TestCase): def test_idresDumb(self): webresponse = self.encode(self.response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) @@ -813,18 +802,18 @@ class TestSigningEncode(unittest.TestCase): def test_cancel(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields.setArg(OPENID_NS, 'mode', 'cancel') webresponse = self.encode(response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) self.failIf('openid.sig' in query, response.fields.toPostArgs()) @@ -847,18 +836,19 @@ class TestSigningEncode(unittest.TestCase): self.response.fields.setArg(OPENID_NS, 'sig', 'priorSig==') self.failUnlessRaises(server.AlreadySigned, self.encode, self.response) + class TestCheckID(unittest.TestCase): def setUp(self): self.op_endpoint = 'http://endpoint.unittest/' self.store = memstore.MemoryStore() self.server = server.Server(self.store, self.op_endpoint) self.request = server.CheckIDRequest( - identity = 'http://bambam.unittest/', - trust_root = 'http://bar.unittest/', - return_to = 'http://bar.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bambam.unittest/', + trust_root='http://bar.unittest/', + return_to='http://bar.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.request.message = Message(OPENID2_NS) def test_trustRootInvalid(self): @@ -878,7 +868,7 @@ class TestCheckID(unittest.TestCase): self.request.message = sentinel try: result = self.request.trustRootValid() - except server.MalformedTrustRoot, why: + except server.MalformedTrustRoot as why: self.failUnless(sentinel is why.openid_message) else: self.fail('Expected MalformedTrustRoot exception. Got %r' @@ -886,12 +876,12 @@ class TestCheckID(unittest.TestCase): def test_trustRootValidNoReturnTo(self): request = server.CheckIDRequest( - identity = 'http://bambam.unittest/', - trust_root = 'http://bar.unittest/', - return_to = None, - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bambam.unittest/', + trust_root='http://bar.unittest/', + return_to=None, + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.failUnless(request.trustRootValid()) @@ -909,6 +899,7 @@ class TestCheckID(unittest.TestCase): # Ensure that exceptions are passed through sentinel = Exception() + def vrfyExc(trust_root, return_to): self.failUnlessEqual(self.request.trust_root, trust_root) self.failUnlessEqual(self.request.return_to, return_to) @@ -916,7 +907,7 @@ class TestCheckID(unittest.TestCase): try: withVerifyReturnTo(vrfyExc, self.request.returnToVerified) - except Exception, e: + except Exception as e: self.failUnless(e is sentinel, e) # Ensure that True and False are passed through unchanged @@ -938,7 +929,7 @@ class TestCheckID(unittest.TestCase): ('mode', 'id_res'), ('return_to', self.request.return_to), ('op_endpoint', self.op_endpoint), - ] + ] if identity: expected_list.append(('identity', identity)) if claimed_id: @@ -1140,13 +1131,13 @@ class TestCheckID(unittest.TestCase): def test_fromMessageWithEmptyTrustRoot(self): return_to = u'http://someplace.invalid/?go=thing' msg = Message.fromPostArgs({ - u'openid.assoc_handle': u'{blah}{blah}{OZivdQ==}', - u'openid.claimed_id': u'http://delegated.invalid/', - u'openid.identity': u'http://op-local.example.com/', - u'openid.mode': u'checkid_setup', - u'openid.ns': u'http://openid.net/signon/1.0', - u'openid.return_to': return_to, - u'openid.trust_root': u''}) + u'openid.assoc_handle': u'{blah}{blah}{OZivdQ==}', + u'openid.claimed_id': u'http://delegated.invalid/', + u'openid.identity': u'http://op-local.example.com/', + u'openid.mode': u'checkid_setup', + u'openid.ns': u'http://openid.net/signon/1.0', + u'openid.return_to': return_to, + u'openid.trust_root': u''}) result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) @@ -1172,7 +1163,7 @@ class TestCheckID(unittest.TestCase): 'identity': identity, 'trust_root': 'http://bar.unittest/', 'return_to': 'http://bar.unittest/999', - }) + }) self.request = server.CheckIDRequest.fromMessage(reqmessage, None) answer = self.request.answer(True) @@ -1180,7 +1171,7 @@ class TestCheckID(unittest.TestCase): ('mode', 'id_res'), ('return_to', self.request.return_to), ('identity', identity), - ] + ] for k, expected in expected_list: actual = answer.fields.getArg(OPENID_NS, k) @@ -1241,7 +1232,7 @@ class TestCheckID(unittest.TestCase): answer = self.request.answer(False) self.failUnlessEqual(answer.fields.getArgs(OPENID_NS), { 'mode': 'cancel', - }) + }) def test_encodeToURL(self): server_url = 'http://openid-server.unittest/' @@ -1262,8 +1253,8 @@ class TestCheckID(unittest.TestCase): rt, query_string = url.split('?') self.failUnlessEqual(self.request.return_to, rt) query = dict(cgi.parse_qsl(query_string)) - self.failUnlessEqual(query, {'openid.mode':'cancel', - 'openid.ns':OPENID2_NS}) + self.failUnlessEqual(query, {'openid.mode': 'cancel', + 'openid.ns': OPENID2_NS}) def test_getCancelURLimmed(self): self.request.mode = 'checkid_immediate' @@ -1271,7 +1262,6 @@ class TestCheckID(unittest.TestCase): self.failUnlessRaises(ValueError, self.request.getCancelURL) - class TestCheckIDExtension(unittest.TestCase): def setUp(self): @@ -1279,18 +1269,17 @@ class TestCheckIDExtension(unittest.TestCase): self.store = memstore.MemoryStore() self.server = server.Server(self.store, self.op_endpoint) self.request = server.CheckIDRequest( - identity = 'http://bambam.unittest/', - trust_root = 'http://bar.unittest/', - return_to = 'http://bar.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bambam.unittest/', + trust_root='http://bar.unittest/', + return_to='http://bar.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.request.message = Message(OPENID2_NS) self.response = server.OpenIDResponse(self.request) self.response.fields.setArg(OPENID_NS, 'mode', 'id_res') self.response.fields.setArg(OPENID_NS, 'blue', 'star') - def test_addField(self): namespace = 'something:' self.response.fields.setArg(namespace, 'bright', 'potato') @@ -1300,13 +1289,12 @@ class TestCheckIDExtension(unittest.TestCase): }) self.failUnlessEqual(self.response.fields.getArgs(namespace), - {'bright':'potato'}) - + {'bright': 'potato'}) def test_addFields(self): namespace = 'mi5:' - args = {'tangy': 'suspenders', - 'bravo': 'inclusion'} + args = {'tangy': 'suspenders', + 'bravo': 'inclusion'} self.response.fields.updateArgs(namespace, args) self.failUnlessEqual(self.response.fields.getArgs(OPENID_NS), {'blue': 'star', @@ -1315,7 +1303,6 @@ class TestCheckIDExtension(unittest.TestCase): self.failUnlessEqual(self.response.fields.getArgs(namespace), args) - class MockSignatory(object): isValid = True @@ -1349,7 +1336,7 @@ class TestCheckAuth(unittest.TestCase): 'openid.sig': 'signarture', 'one': 'alpha', 'two': 'beta', - }) + }) self.request = server.CheckAuthRequest( self.assoc_handle, self.message) @@ -1420,7 +1407,8 @@ class TestAssociate(unittest.TestCase): session = DiffieHellmanSHA1ServerSession(server_dh, cpub) self.request = server.AssociateRequest(session, 'HMAC-SHA1') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) @@ -1444,7 +1432,8 @@ class TestAssociate(unittest.TestCase): session = DiffieHellmanSHA256ServerSession(server_dh, cpub) self.request = server.AssociateRequest(session, 'HMAC-SHA256') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA256") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) @@ -1458,23 +1447,18 @@ class TestAssociate(unittest.TestCase): self.failUnlessEqual(secret, self.assoc.secret) def test_protoError256(self): - from openid.consumer.consumer import \ - DiffieHellmanSHA256ConsumerSession - s256_session = DiffieHellmanSHA256ConsumerSession() - invalid_s256 = {'openid.assoc_type':'HMAC-SHA1', - 'openid.session_type':'DH-SHA256',} + invalid_s256 = {'openid.assoc_type': 'HMAC-SHA1', 'openid.session_type': 'DH-SHA256'} invalid_s256.update(s256_session.getRequest()) - invalid_s256_2 = {'openid.assoc_type':'MONKEY-PIRATE', - 'openid.session_type':'DH-SHA256',} + invalid_s256_2 = {'openid.assoc_type': 'MONKEY-PIRATE', 'openid.session_type': 'DH-SHA256'} invalid_s256_2.update(s256_session.getRequest()) bad_request_argss = [ invalid_s256, invalid_s256_2, - ] + ] for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) @@ -1487,19 +1471,17 @@ class TestAssociate(unittest.TestCase): s1_session = DiffieHellmanSHA1ConsumerSession() - invalid_s1 = {'openid.assoc_type':'HMAC-SHA256', - 'openid.session_type':'DH-SHA1',} + invalid_s1 = {'openid.assoc_type': 'HMAC-SHA256', 'openid.session_type': 'DH-SHA1'} invalid_s1.update(s1_session.getRequest()) - invalid_s1_2 = {'openid.assoc_type':'ROBOT-NINJA', - 'openid.session_type':'DH-SHA1',} + invalid_s1_2 = {'openid.assoc_type': 'ROBOT-NINJA', 'openid.session_type': 'DH-SHA1'} invalid_s1_2.update(s1_session.getRequest()) bad_request_argss = [ - {'openid.assoc_type':'Wha?'}, + {'openid.assoc_type': 'Wha?'}, invalid_s1, invalid_s1_2, - ] + ] for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) @@ -1516,7 +1498,7 @@ class TestAssociate(unittest.TestCase): openid1_args = { 'openid.identitiy': 'invalid', 'openid.mode': 'checkid_setup', - } + } openid2_args = dict(openid1_args) openid2_args.update({'openid.ns': OPENID2_NS}) @@ -1545,7 +1527,7 @@ class TestAssociate(unittest.TestCase): # Slop is necessary because the tests can sometimes get run # right on a second boundary - slop = 1 # second + slop = 1 # second difference = expected_expires_in - expires_in error_message = ('"expires_in" value not within %s of expected: ' @@ -1556,7 +1538,8 @@ class TestAssociate(unittest.TestCase): def test_plaintext(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA1') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) @@ -1578,7 +1561,7 @@ class TestAssociate(unittest.TestCase): 'openid.mode': 'associate', 'openid.assoc_type': 'HMAC-SHA1', 'openid.session_type': 'no-encryption', - } + } self.request = server.AssociateRequest.fromMessage( Message.fromPostArgs(args)) @@ -1587,7 +1570,8 @@ class TestAssociate(unittest.TestCase): self.assoc = self.signatory.createAssociation( dumb=False, assoc_type='HMAC-SHA1') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) @@ -1605,7 +1589,8 @@ class TestAssociate(unittest.TestCase): def test_plaintext256(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA256') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) @@ -1632,8 +1617,9 @@ class TestAssociate(unittest.TestCase): message=message, preferred_session_type=allowed_sess, preferred_association_type=allowed_assoc, - ) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + ) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg('error_code'), 'unsupported-type') self.failUnlessEqual(rfg('assoc_type'), allowed_assoc) self.failUnlessEqual(rfg('error'), message) @@ -1647,12 +1633,14 @@ class TestAssociate(unittest.TestCase): self.request.message = Message(OPENID2_NS) response = self.request.answerUnsupported(message) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg('error_code'), 'unsupported-type') self.failUnlessEqual(rfg('assoc_type'), None) self.failUnlessEqual(rfg('error'), message) self.failUnlessEqual(rfg('session_type'), None) + class Counter(object): def __init__(self): self.count = 0 @@ -1660,6 +1648,7 @@ class Counter(object): def inc(self): self.count += 1 + class TestServer(unittest.TestCase, CatchLogs): def setUp(self): self.store = memstore.MemoryStore() @@ -1668,6 +1657,7 @@ class TestServer(unittest.TestCase, CatchLogs): def test_dispatch(self): monkeycalled = Counter() + def monkeyDo(request): monkeycalled.inc() r = server.OpenIDResponse(request) @@ -1676,7 +1666,7 @@ class TestServer(unittest.TestCase, CatchLogs): request = server.OpenIDRequest() request.mode = "monkeymode" request.namespace = OPENID1_NS - webresult = self.server.handleRequest(request) + self.server.handleRequest(request) self.failUnlessEqual(monkeycalled.count, 1) def test_associate(self): @@ -1698,7 +1688,7 @@ class TestServer(unittest.TestCase, CatchLogs): 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', 'openid.assoc_type': 'HMAC-SHA1', - }) + }) request = server.AssociateRequest.fromMessage(msg) @@ -1721,7 +1711,7 @@ class TestServer(unittest.TestCase, CatchLogs): 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', 'openid.assoc_type': 'HMAC-SHA1', - }) + }) request = server.AssociateRequest.fromMessage(msg) response = self.server.openid_associate(request) @@ -1745,7 +1735,7 @@ class TestServer(unittest.TestCase, CatchLogs): '1WxJY3jHd5k1/ZReyRZOxZTKdF/dnIqwF8ZXUwI6peV0TyS/K1fOfF/s', 'openid.assoc_type': 'HMAC-SHA256', 'openid.session_type': 'DH-SHA256', - } + } message = Message.fromPostArgs(query) request = server.AssociateRequest.fromMessage(message) response = self.server.openid_associate(request) @@ -1755,7 +1745,7 @@ class TestServer(unittest.TestCase, CatchLogs): """Make sure session_type is required in OpenID 2""" msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, - }) + }) self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) @@ -1765,7 +1755,7 @@ class TestServer(unittest.TestCase, CatchLogs): msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', - }) + }) self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) @@ -1775,6 +1765,7 @@ class TestServer(unittest.TestCase, CatchLogs): response = self.server.openid_check_authentication(request) self.failUnless(response.fields.hasKey(OPENID_NS, "is_valid")) + class TestSignatory(unittest.TestCase, CatchLogs): def setUp(self): self.store = memstore.MemoryStore() @@ -1797,7 +1788,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) self.failUnlessEqual( sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), @@ -1816,8 +1807,8 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - 'ns':OPENID2_NS, - }) + 'ns': OPENID2_NS, + }) sresponse = self.signatory.sign(response) assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.failUnless(assoc_handle) @@ -1857,7 +1848,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') @@ -1881,7 +1872,6 @@ class TestSignatory(unittest.TestCase, CatchLogs): self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) self.failUnless(self.messages) - def test_signInvalidHandle(self): request = server.OpenIDRequest() request.namespace = OPENID2_NS @@ -1893,7 +1883,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') @@ -1913,7 +1903,6 @@ class TestSignatory(unittest.TestCase, CatchLogs): self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) self.failIf(self.messages, self.messages) - def test_verify(self): assoc_handle = '{vroom}{zoom}' assoc = association.Association.fromExpiresIn( @@ -1927,13 +1916,12 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'openid.assoc_handle': assoc_handle, 'openid.signed': 'apple,assoc_handle,foo,signed', 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco=', - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(self.messages, self.messages) self.failUnless(verified) - def test_verifyBadSig(self): assoc_handle = '{vroom}{zoom}' assoc = association.Association.fromExpiresIn( @@ -1947,7 +1935,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'openid.assoc_handle': assoc_handle, 'openid.signed': 'apple,assoc_handle,foo,signed', 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco='.encode('rot13'), - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(self.messages, self.messages) @@ -1959,13 +1947,12 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'bar', 'apple': 'orange', 'openid.sig': "Ylu0KcIR7PvNegB/K41KpnRgJl0=", - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(verified) self.failUnless(self.messages) - def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" assoc_handle = '{vroom}{zoom}' @@ -1978,7 +1965,7 @@ class TestSignatory(unittest.TestCase, CatchLogs): 'foo': 'bar', 'apple': 'orange', 'openid.sig': "d71xlHtqnq98DonoSgoK/nD+QRM=", - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(verified) @@ -1992,10 +1979,10 @@ class TestSignatory(unittest.TestCase, CatchLogs): self.failIf(self.messages, self.messages) def test_getAssocExpired(self): - assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) + assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) assoc = self.signatory.getAssociation(assoc_handle, True) self.failIf(assoc, assoc) - self.failUnless(self.messages) + self.failUnless(self.messages) def test_getAssocInvalid(self): ah = 'no-such-handle' @@ -2052,6 +2039,5 @@ class TestSignatory(unittest.TestCase, CatchLogs): self.failIf(self.messages, self.messages) - if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 0abbc5e..ddcf9dc 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -1,7 +1,7 @@ import unittest from openid.extensions import sreg -from openid.message import Message, NamespaceMap, registerNamespaceAlias +from openid.message import Message, NamespaceMap from openid.server.server import OpenIDRequest, OpenIDResponse @@ -9,6 +9,7 @@ class SRegURITest(unittest.TestCase): def test_is11(self): self.failUnlessEqual(sreg.ns_uri_1_1, sreg.ns_uri) + class CheckFieldNameTest(unittest.TestCase): def test_goodNamePasses(self): for field_name in sreg.data_fields: @@ -21,6 +22,8 @@ class CheckFieldNameTest(unittest.TestCase): self.failUnlessRaises(ValueError, sreg.checkFieldName, None) # For supportsSReg test + + class FakeEndpoint(object): def __init__(self, supported): self.supported = supported @@ -30,6 +33,7 @@ class FakeEndpoint(object): self.checked_uris.append(namespace_uri) return namespace_uri in self.supported + class SupportsSRegTest(unittest.TestCase): def test_unsupported(self): endpoint = FakeEndpoint([]) @@ -48,6 +52,7 @@ class SupportsSRegTest(unittest.TestCase): self.failUnlessEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0], endpoint.checked_uris) + class FakeMessage(object): def __init__(self): self.openid1 = False @@ -56,6 +61,7 @@ class FakeMessage(object): def isOpenID1(self): return self.openid1 + class GetNSTest(unittest.TestCase): def setUp(self): self.msg = FakeMessage() @@ -110,13 +116,14 @@ class GetNSTest(unittest.TestCase): args = { 'sreg.optional': 'nickname', 'sreg.required': 'dob', - } + } m = Message.fromOpenIDArgs(args) self.failUnless(m.getArg(sreg.ns_uri_1_1, 'optional') == 'nickname') self.failUnless(m.getArg(sreg.ns_uri_1_1, 'required') == 'dob') + class SRegRequestTest(unittest.TestCase): def test_constructEmpty(self): req = sreg.SRegRequest() @@ -142,7 +149,6 @@ class SRegRequestTest(unittest.TestCase): sreg.SRegRequest, ['elvis']) def test_fromOpenIDRequest(self): - args = {} ns_sentinel = object() args_sentinel = object() @@ -173,7 +179,7 @@ class SRegRequestTest(unittest.TestCase): openid_req.message = msg req = TestingReq.fromOpenIDRequest(openid_req) - self.failUnless(type(req) is TestingReq) + self.assertIsInstance(req, TestingReq) self.failUnless(msg.copied) def test_parseExtensionArgs_empty(self): @@ -183,60 +189,60 @@ class SRegRequestTest(unittest.TestCase): def test_parseExtensionArgs_extraIgnored(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'janrain':'inc'}) + req.parseExtensionArgs({'janrain': 'inc'}) def test_parseExtensionArgs_nonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'required':'beans'}) + req.parseExtensionArgs({'required': 'beans'}) self.failUnlessEqual([], req.required) def test_parseExtensionArgs_strict(self): req = sreg.SRegRequest() self.failUnlessRaises( ValueError, - req.parseExtensionArgs, {'required':'beans'}, strict=True) + req.parseExtensionArgs, {'required': 'beans'}, strict=True) def test_parseExtensionArgs_policy(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'policy_url':'http://policy'}, strict=True) + req.parseExtensionArgs({'policy_url': 'http://policy'}, strict=True) self.failUnlessEqual('http://policy', req.policy_url) def test_parseExtensionArgs_requiredEmpty(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'required':''}, strict=True) + req.parseExtensionArgs({'required': ''}, strict=True) self.failUnlessEqual([], req.required) def test_parseExtensionArgs_optionalEmpty(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':''}, strict=True) + req.parseExtensionArgs({'optional': ''}, strict=True) self.failUnlessEqual([], req.optional) def test_parseExtensionArgs_optionalSingle(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname'}, strict=True) + req.parseExtensionArgs({'optional': 'nickname'}, strict=True) self.failUnlessEqual(['nickname'], req.optional) def test_parseExtensionArgs_optionalList(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname,email'}, strict=True) - self.failUnlessEqual(['nickname','email'], req.optional) + req.parseExtensionArgs({'optional': 'nickname,email'}, strict=True) + self.failUnlessEqual(['nickname', 'email'], req.optional) def test_parseExtensionArgs_optionalListBadNonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname,email,beer'}) - self.failUnlessEqual(['nickname','email'], req.optional) + req.parseExtensionArgs({'optional': 'nickname,email,beer'}) + self.failUnlessEqual(['nickname', 'email'], req.optional) def test_parseExtensionArgs_optionalListBadStrict(self): req = sreg.SRegRequest() self.failUnlessRaises( ValueError, - req.parseExtensionArgs, {'optional':'nickname,email,beer'}, + req.parseExtensionArgs, {'optional': 'nickname,email,beer'}, strict=True) def test_parseExtensionArgs_bothNonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname', - 'required':'nickname'}) + req.parseExtensionArgs({'optional': 'nickname', + 'required': 'nickname'}) self.failUnlessEqual([], req.optional) self.failUnlessEqual(['nickname'], req.required) @@ -245,16 +251,16 @@ class SRegRequestTest(unittest.TestCase): self.failUnlessRaises( ValueError, req.parseExtensionArgs, - {'optional':'nickname', - 'required':'nickname'}, + {'optional': 'nickname', + 'required': 'nickname'}, strict=True) def test_parseExtensionArgs_bothList(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname,email', - 'required':'country,postcode'}, strict=True) - self.failUnlessEqual(['nickname','email'], req.optional) - self.failUnlessEqual(['country','postcode'], req.required) + req.parseExtensionArgs({'optional': 'nickname,email', + 'required': 'country,postcode'}, strict=True) + self.failUnlessEqual(['nickname', 'email'], req.optional) + self.failUnlessEqual(['country', 'postcode'], req.required) def test_allRequestedFields(self): req = sreg.SRegRequest() @@ -262,8 +268,7 @@ class SRegRequestTest(unittest.TestCase): req.requestField('nickname') self.failUnlessEqual(['nickname'], req.allRequestedFields()) req.requestField('gender', required=True) - requested = req.allRequestedFields() - requested.sort() + requested = sorted(req.allRequestedFields()) self.failUnlessEqual(['gender', 'nickname'], requested) def test_wereFieldsRequested(self): @@ -378,38 +383,40 @@ class SRegRequestTest(unittest.TestCase): self.failUnlessEqual({}, req.getExtensionArgs()) req.requestField('nickname') - self.failUnlessEqual({'optional':'nickname'}, req.getExtensionArgs()) + self.failUnlessEqual({'optional': 'nickname'}, req.getExtensionArgs()) req.requestField('email') - self.failUnlessEqual({'optional':'nickname,email'}, + self.failUnlessEqual({'optional': 'nickname,email'}, req.getExtensionArgs()) req.requestField('gender', required=True) - self.failUnlessEqual({'optional':'nickname,email', - 'required':'gender'}, + self.failUnlessEqual({'optional': 'nickname,email', + 'required': 'gender'}, req.getExtensionArgs()) req.requestField('postcode', required=True) - self.failUnlessEqual({'optional':'nickname,email', - 'required':'gender,postcode'}, + self.failUnlessEqual({'optional': 'nickname,email', + 'required': 'gender,postcode'}, req.getExtensionArgs()) req.policy_url = 'http://policy.invalid/' - self.failUnlessEqual({'optional':'nickname,email', - 'required':'gender,postcode', - 'policy_url':'http://policy.invalid/'}, + self.failUnlessEqual({'optional': 'nickname,email', + 'required': 'gender,postcode', + 'policy_url': 'http://policy.invalid/'}, req.getExtensionArgs()) + data = { - 'nickname':'linusaur', - 'postcode':'12345', - 'country':'US', - 'gender':'M', - 'fullname':'Leonhard Euler', - 'email':'president@whitehouse.gov', - 'dob':'0000-00-00', - 'language':'en-us', - } + 'nickname': 'linusaur', + 'postcode': '12345', + 'country': 'US', + 'gender': 'M', + 'fullname': 'Leonhard Euler', + 'email': 'president@whitehouse.gov', + 'dob': '0000-00-00', + 'language': 'en-us', +} + class DummySuccessResponse(object): def __init__(self, message, signed_stuff): @@ -419,6 +426,7 @@ class DummySuccessResponse(object): def getSignedNS(self, ns_uri): return self.signed_stuff + class SRegResponseTest(unittest.TestCase): def test_construct(self): resp = sreg.SRegResponse(data) @@ -432,22 +440,23 @@ class SRegResponseTest(unittest.TestCase): def test_fromSuccessResponse_signed(self): message = Message.fromOpenIDArgs({ - 'sreg.nickname':'The Mad Stork', - }) + 'sreg.nickname': 'The Mad Stork', + }) success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp) self.failIf(sreg_resp) def test_fromSuccessResponse_unsigned(self): message = Message.fromOpenIDArgs({ - 'sreg.nickname':'The Mad Stork', - }) + 'sreg.nickname': 'The Mad Stork', + }) success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp, signed_only=False) self.failUnlessEqual([('nickname', 'The Mad Stork')], sreg_resp.items()) + class SendFieldsTest(unittest.TestCase): def test(self): # Create a request message with simple registration fields @@ -476,10 +485,11 @@ class SendFieldsTest(unittest.TestCase): # Extract the fields that were sent sreg_data_resp = resp_msg.getArgs(sreg.ns_uri) self.failUnlessEqual( - {'nickname':'linusaur', - 'email':'president@whitehouse.gov', - 'fullname':'Leonhard Euler', + {'nickname': 'linusaur', + 'email': 'president@whitehouse.gov', + 'fullname': 'Leonhard Euler', }, sreg_data_resp) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index 7f9b79b..7425222 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -32,5 +32,6 @@ class SymbolTest(unittest.TestCase): y = oidutil.Symbol('yyy') self.failUnless(x != y) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 154f751..98e9f49 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -17,7 +17,7 @@ class UrinormTest(unittest.TestCase): def runTest(self): try: actual = openid.urinorm.urinorm(self.case) - except ValueError, why: + except ValueError as why: self.assertEqual(self.expected, 'fail', why) else: self.assertEqual(actual, self.expected) @@ -43,6 +43,7 @@ def parseTests(test_data): return result + def pyUnitTests(): here = os.path.dirname(os.path.abspath(__file__)) test_data_file_name = os.path.join(here, 'urinorm.txt') diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 43664bc..57ead86 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -14,11 +14,12 @@ def const(result): return constResult + class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): def failUnlessProtocolError(self, prefix, callable, *args, **kwargs): try: result = callable(*args, **kwargs) - except consumer.ProtocolError, e: + except consumer.ProtocolError as e: self.failUnless( e[0].startswith(prefix), 'Expected message prefix %r, got message %r' % (prefix, e[0])) @@ -37,30 +38,30 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): self.failUnlessLogEmpty() def test_openID1NoEndpoint(self): - msg = message.Message.fromOpenIDArgs({'identity':'snakes on a plane'}) + msg = message.Message.fromOpenIDArgs({'identity': 'snakes on a plane'}) self.failUnlessRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2NoOPEndpointArg(self): - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS}) self.failUnlessRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2LocalIDNoClaimed(self): - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS, - 'op_endpoint':'Phone Home', - 'identity':'Jose Lius Borges'}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, + 'op_endpoint': 'Phone Home', + 'identity': 'Jose Lius Borges'}) self.failUnlessProtocolError( 'openid.identity is present without', self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2NoLocalIDClaimed(self): - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS, - 'op_endpoint':'Phone Home', - 'claimed_id':'Manuel Noriega'}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, + 'op_endpoint': 'Phone Home', + 'claimed_id': 'Manuel Noriega'}) self.failUnlessProtocolError( 'openid.claimed_id is present without', self.consumer._verifyDiscoveryResults, msg) @@ -68,8 +69,8 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): def test_openID2NoIdentifiers(self): op_endpoint = 'Phone Home' - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS, - 'op_endpoint':op_endpoint}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, + 'op_endpoint': op_endpoint}) result_endpoint = self.consumer._verifyDiscoveryResults(msg) self.failUnless(result_endpoint.isOPIdentifier()) self.failUnlessEqual(op_endpoint, result_endpoint.server_url) @@ -82,10 +83,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): sentinel.claimed_id = 'monkeysoft' self.consumer._discoverAndVerify = const(sentinel) msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':'sour grapes', - 'claimed_id':'monkeysoft', - 'op_endpoint':op_endpoint}) + {'ns': message.OPENID2_NS, + 'identity': 'sour grapes', + 'claimed_id': 'monkeysoft', + 'op_endpoint': op_endpoint}) result = self.consumer._verifyDiscoveryResults(msg) self.failUnlessEqual(sentinel, result) self.failUnlessLogMatches('No pre-discovered') @@ -100,10 +101,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): sentinel.claimed_id = 'monkeysoft' self.consumer._discoverAndVerify = const(sentinel) msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':'sour grapes', - 'claimed_id':'monkeysoft', - 'op_endpoint':op_endpoint}) + {'ns': message.OPENID2_NS, + 'identity': 'sour grapes', + 'claimed_id': 'monkeysoft', + 'op_endpoint': op_endpoint}) result = self.consumer._verifyDiscoveryResults(msg, mismatched) self.failUnlessEqual(sentinel, result) self.failUnlessLogMatches('Error attempting to use stored', @@ -117,10 +118,10 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.type_uris = [discover.OPENID_2_0_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':endpoint.local_id, - 'claimed_id':endpoint.claimed_id, - 'op_endpoint':endpoint.server_url}) + {'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, + 'claimed_id': endpoint.claimed_id, + 'op_endpoint': endpoint.server_url}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.failUnless(result is endpoint) self.failUnlessLogEmpty() @@ -143,14 +144,14 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): self.consumer._discoverAndVerify = discoverAndVerify msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':endpoint.local_id, - 'claimed_id':endpoint.claimed_id, - 'op_endpoint':endpoint.server_url}) + {'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, + 'claimed_id': endpoint.claimed_id, + 'op_endpoint': endpoint.server_url}) try: r = self.consumer._verifyDiscoveryResults(msg, endpoint) - except consumer.ProtocolError, e: + except consumer.ProtocolError as e: # Should we make more ProtocolError subclasses? self.failUnless(str(e), text) else: @@ -167,14 +168,15 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.type_uris = [discover.OPENID_1_1_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID1_NS, - 'identity':endpoint.local_id}) + {'ns': message.OPENID1_NS, + 'identity': endpoint.local_id}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.failUnless(result is endpoint) self.failUnlessLogEmpty() def test_openid1UsePreDiscoveredWrongType(self): - class VerifiedError(Exception): pass + class VerifiedError(Exception): + pass def discoverAndVerify(claimed_id, _to_match): raise VerifiedError @@ -188,8 +190,8 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.type_uris = [discover.OPENID_2_0_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID1_NS, - 'identity':endpoint.local_id}) + {'ns': message.OPENID1_NS, + 'identity': endpoint.local_id}) self.failUnlessRaises( VerifiedError, @@ -208,18 +210,18 @@ class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): endpoint.type_uris = [discover.OPENID_2_0_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':endpoint.local_id, + {'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, 'claimed_id': claimed_id_frag, 'op_endpoint': endpoint.server_url}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) - + self.failUnlessEqual(result.local_id, endpoint.local_id) self.failUnlessEqual(result.server_url, endpoint.server_url) self.failUnlessEqual(result.type_uris, endpoint.type_uris) self.failUnlessEqual(result.claimed_id, claimed_id_frag) - + self.failUnlessLogEmpty() def test_openid1Fallback1_0(self): @@ -267,5 +269,6 @@ class TestVerifyDiscoverySingle(TestIdRes): self.failUnlessEqual(result, None) self.failUnlessLogEmpty() + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index 33ea0e0..6e6ac8e 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -18,7 +18,6 @@ class XriEscapingTestCase(TestCase): self.failUnlessEqual(xri.escapeForIRI('@example/abc%2Fd/ef'), '@example/abc%252Fd/ef') - def test_escaping_xref(self): # no escapes esc = xri.escapeForIRI @@ -33,7 +32,6 @@ class XriEscapingTestCase(TestCase): esc('@example/foo/(@baz?p=q#r)?i=j#k')) - class XriTransformationTestCase(TestCase): def test_to_iri_normal(self): self.failUnlessEqual(xri.toIRINormal('@example'), 'xri://@example') @@ -53,7 +51,6 @@ class XriTransformationTestCase(TestCase): self.failUnlessEqual(xri.iriToURI(s), expected) - class CanonicalIDTest(TestCase): def mkTest(providerID, canonicalID, isAuthoritative): def test(self): @@ -73,6 +70,7 @@ class CanonicalIDTest(TestCase): test_atEqualsAndTooDeepFails = mkTest('@!1234!ABCD', '=!1234', False) test_differentBeginningFails = mkTest('=!BABE', '=!D00D', False) + class TestGetRootAuthority(TestCase): def mkTest(the_xri, expected_root): def test(self): @@ -96,8 +94,9 @@ class TestGetRootAuthority(TestCase): # Looking at the ABNF in XRI Syntax 2.0, I don't think you can # have example.com*bar. You can do (example.com)*bar, but that # would mean something else. - ##("example.com*bar/(=baz)", "example.com*bar"), - ##("baz.example.com!01/foo", "baz.example.com!01"), + # ("example.com*bar/(=baz)", "example.com*bar"), + # ("baz.example.com!01/foo", "baz.example.com!01"), + if __name__ == '__main__': import unittest diff --git a/openid/test/test_xrires.py b/openid/test/test_xrires.py index 873255c..b06a8b2 100644 --- a/openid/test/test_xrires.py +++ b/openid/test/test_xrires.py @@ -11,7 +11,6 @@ class ProxyQueryTestCase(TestCase): self.servicetype = 'xri://+i-service*(+forwarding)*($v*1.0)' self.servicetype_enc = 'xri%3A%2F%2F%2Bi-service%2A%28%2Bforwarding%29%2A%28%24v%2A1.0%29' - def test_proxy_url(self): st = self.servicetype ste = self.servicetype_enc @@ -30,7 +29,6 @@ class ProxyQueryTestCase(TestCase): args_esc = "_xrd_r=application%2Fxrds%2Bxml%3Bsep%3Dfalse" self.failUnlessEqual(h + '=foo?' + args_esc, pqu('=foo', None)) - def test_proxy_url_qmarks(self): st = self.servicetype ste = self.servicetype_enc diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index 8c222d0..c7ba05c 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -24,7 +24,10 @@ Content-Type: text/plain No such file %s """ -class QuitServer(Exception): pass + +class QuitServer(Exception): + pass + def mkResponse(data): status_mo = status_header_re.match(data) @@ -40,6 +43,7 @@ def mkResponse(data): headers=headers, body=body) + class TestFetcher(object): def __init__(self, base_url): self.base_url = base_url @@ -64,16 +68,18 @@ class TestFetcher(object): response.final_url = current_url return response + class TestSecondGet(unittest.TestCase): class MockFetcher(object): def __init__(self): self.count = 0 + def fetch(self, uri, headers=None, body=None): self.count += 1 if self.count == 1: headers = { 'X-XRDS-Location'.lower(): 'http://unittest/404', - } + } return fetchers.HTTPResponse(uri, 200, headers, '') else: return fetchers.HTTPResponse(uri, 404) @@ -137,10 +143,8 @@ class _TestCase(unittest.TestCase): self.failUnlessEqual( self.expected.response_text, result.response_text, msg) - expected_keys = dir(self.expected) - expected_keys.sort() - actual_keys = dir(result) - actual_keys.sort() + expected_keys = sorted(dir(self.expected)) + actual_keys = sorted(dir(result)) self.failUnlessEqual(actual_keys, expected_keys) for k in dir(self.expected): @@ -162,6 +166,7 @@ class _TestCase(unittest.TestCase): n, self.__class__.__module__) + def pyUnitTests(): s = unittest.TestSuite() for success, input_name, id_name, result_name in discoverdata.testlist: @@ -170,9 +175,11 @@ def pyUnitTests(): return s + def test(): runner = unittest.TextTestRunner() return runner.run(pyUnitTests()) + if __name__ == '__main__': test() diff --git a/openid/test/trustroot.py b/openid/test/trustroot.py index f934ce3..c9a0f72 100644 --- a/openid/test/trustroot.py +++ b/openid/test/trustroot.py @@ -23,6 +23,7 @@ class _ParseTest(unittest.TestCase): else: assert tr is None, tr + class _MatchTest(unittest.TestCase): def __init__(self, match, desc, line): unittest.TestCase.__init__(self) @@ -45,6 +46,7 @@ class _MatchTest(unittest.TestCase): else: assert not match + def getTests(t, grps, head, dat): tests = [] top = head.strip() @@ -61,6 +63,7 @@ def getTests(t, grps, head, dat): i += 2 return tests + def parseTests(data): parts = map(str.strip, data.split('=' * 40 + '\n')) assert not parts[0] @@ -71,6 +74,7 @@ def parseTests(data): tests.extend(getTests(_MatchTest, [1, 0], mh, mdat)) return tests + def pyUnitTests(): here = os.path.dirname(os.path.abspath(__file__)) test_data_file_name = os.path.join(here, 'data', 'trustroot.txt') @@ -81,6 +85,7 @@ def pyUnitTests(): tests = parseTests(test_data) return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/urinorm.py b/openid/urinorm.py index 5bdbaef..21869c8 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -29,11 +29,11 @@ except ValueError: (0xA0, 0xD7FF), (0xF900, 0xFDCF), (0xFDF0, 0xFFEF), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), - ] + ] else: UCSCHAR = [ (0xA0, 0xD7FF), @@ -53,19 +53,22 @@ else: (0xC0000, 0xCFFFD), (0xD0000, 0xDFFFD), (0xE1000, 0xEFFFD), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), (0xF0000, 0xFFFFD), (0x100000, 0x10FFFD), - ] + ] _unreserved = [False] * 256 -for _ in range(ord('A'), ord('Z') + 1): _unreserved[_] = True -for _ in range(ord('0'), ord('9') + 1): _unreserved[_] = True -for _ in range(ord('a'), ord('z') + 1): _unreserved[_] = True +for _ in range(ord('A'), ord('Z') + 1): + _unreserved[_] = True +for _ in range(ord('0'), ord('9') + 1): + _unreserved[_] = True +for _ in range(ord('a'), ord('z') + 1): + _unreserved[_] = True _unreserved[ord('-')] = True _unreserved[ord('.')] = True _unreserved[ord('_')] = True @@ -73,7 +76,7 @@ _unreserved[ord('~')] = True _escapeme_re = re.compile('[%s]' % (''.join( - map(lambda (m, n): u'%s-%s' % (unichr(m), unichr(n)), + map(lambda m_n: u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])), UCSCHAR + IPRIVATE)),)) @@ -176,9 +179,7 @@ def urinorm(uri): host = host.lower() if port: - if (port == ':' or - (scheme == 'http' and port == ':80') or - (scheme == 'https' and port == ':443')): + if port == ':' or (scheme == 'http' and port == ':80') or (scheme == 'https' and port == ':443'): port = '' else: port = '' diff --git a/openid/yadis/__init__.py b/openid/yadis/__init__.py index cfa5f1e..68a0d44 100644 --- a/openid/yadis/__init__.py +++ b/openid/yadis/__init__.py @@ -10,7 +10,7 @@ __all__ = [ 'services', 'xri', 'xrires', - ] +] __version__ = '[library version:1.1.0-rc1]'[17:-1] diff --git a/openid/yadis/accept.py b/openid/yadis/accept.py index d750813..2353bfb 100644 --- a/openid/yadis/accept.py +++ b/openid/yadis/accept.py @@ -1,6 +1,8 @@ """Functions for generating and parsing HTTP Accept: headers for supporting server-directed content negotiation. """ +from operator import itemgetter + def generateAcceptHeader(*elements): """Generate an accept header value @@ -9,7 +11,7 @@ def generateAcceptHeader(*elements): """ parts = [] for element in elements: - if type(element) is str: + if isinstance(element, str): qs = "1.0" mtype = element else: @@ -32,6 +34,7 @@ def generateAcceptHeader(*elements): return ', '.join(chunks) + def parseAcceptHeader(value): """Parse an accept header, ignoring any accept-extensions @@ -65,11 +68,11 @@ def parseAcceptHeader(value): else: q = 1.0 - accept.append((q, main, sub)) + accept.append((main, sub, q)) + + # Sort in order q, main, sub + return sorted(accept, key=itemgetter(2, 0, 1), reverse=True) - accept.sort() - accept.reverse() - return [(main, sub, q) for (q, main, sub) in accept] def matchTypes(accept_types, have_types): """Given the result of parsing an Accept: header, and the @@ -93,31 +96,32 @@ def matchTypes(accept_types, have_types): match_main = {} match_sub = {} - for (main, sub, q) in accept_types: + for (main, sub, qvalue) in accept_types: if main == '*': - default = max(default, q) + default = max(default, qvalue) continue elif sub == '*': - match_main[main] = max(match_main.get(main, 0), q) + match_main[main] = max(match_main.get(main, 0), qvalue) else: - match_sub[(main, sub)] = max(match_sub.get((main, sub), 0), q) + match_sub[(main, sub)] = max(match_sub.get((main, sub), 0), qvalue) accepted_list = [] order_maintainer = 0 for mtype in have_types: main, sub = mtype.split('/') if (main, sub) in match_sub: - q = match_sub[(main, sub)] + quality = match_sub[(main, sub)] else: - q = match_main.get(main, default) + quality = match_main.get(main, default) - if q: - accepted_list.append((1 - q, order_maintainer, q, mtype)) + if quality: + accepted_list.append((1 - quality, order_maintainer, quality, mtype)) order_maintainer += 1 accepted_list.sort() return [(mtype, q) for (_, _, q, mtype) in accepted_list] + def getAcceptable(accept_header, have_types): """Parse the accept header and return a list of available types in preferred order. If a type is unacceptable, it will not be in the diff --git a/openid/yadis/constants.py b/openid/yadis/constants.py index 75ff96e..d160c66 100644 --- a/openid/yadis/constants.py +++ b/openid/yadis/constants.py @@ -10,4 +10,4 @@ YADIS_ACCEPT_HEADER = generateAcceptHeader( ('text/html', 0.3), ('application/xhtml+xml', 0.5), (YADIS_CONTENT_TYPE, 1.0), - ) +) diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index 27fcd01..83655a9 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -16,6 +16,7 @@ class DiscoveryFailure(Exception): Exception.__init__(self, message) self.http_response = http_response + class DiscoveryResult(object): """Contains the result of performing Yadis discovery on a URI""" @@ -53,6 +54,7 @@ class DiscoveryResult(object): return (self.usedYadisLocation() or self.content_type == YADIS_CONTENT_TYPE) + def discover(uri): """Discover services for a given URI. @@ -97,7 +99,6 @@ def discover(uri): return result - def whereIsYadis(resp): """Given a HTTPResponse, return the location of the Yadis document. @@ -114,8 +115,7 @@ def whereIsYadis(resp): # According to the spec, the content-type header must be an exact # match, or else we have to look for an indirection. - if (content_type and - content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE): + if content_type and content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE: return resp.final_url else: # Try the header diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index 52a8ab3..563a1f2 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -16,7 +16,7 @@ __all__ = [ 'iterServices', 'expandService', 'expandServices', - ] +] import random import sys @@ -36,9 +36,9 @@ try: # Make the parser raise an exception so we can sniff out the type # of exceptions ElementTree.XML('> purposely malformed XML <') -except (SystemExit, MemoryError, AssertionError, ImportError): +except (MemoryError, AssertionError, ImportError): raise -except: +except Exception: XMLError = sys.exc_info()[0] @@ -49,14 +49,12 @@ class XRDSError(Exception): reason = None - class XRDSFraud(XRDSError): """Raised when there's an assertion in the XRDS that it does not have the authority to make. """ - def parseXRDS(text): """Parse the given text as an XRDS document. @@ -67,7 +65,7 @@ def parseXRDS(text): """ try: element = ElementTree.XML(text) - except XMLError, why: + except XMLError as why: exc = XRDSError('Error parsing document as XML') exc.reason = why raise exc @@ -78,12 +76,15 @@ def parseXRDS(text): return tree + XRD_NS_2_0 = 'xri://$xrd*($v*2.0)' XRDS_NS = 'xri://$xrds' + def nsTag(ns, t): return '{%s}%s' % (ns, t) + def mkXRDTag(t): """basestring -> basestring @@ -92,6 +93,7 @@ def mkXRDTag(t): """ return nsTag(XRD_NS_2_0, t) + def mkXRDSTag(t): """basestring -> basestring @@ -100,6 +102,7 @@ def mkXRDSTag(t): """ return nsTag(XRDS_NS, t) + # Tags that are used in Yadis documents root_tag = mkXRDSTag('XRDS') service_tag = mkXRDTag('Service') @@ -111,11 +114,13 @@ expires_tag = mkXRDTag('Expires') # Other XRD tags canonicalID_tag = mkXRDTag('CanonicalID') + def isXRDS(xrd_tree): """Is this document an XRDS document?""" root = xrd_tree.getroot() return root.tag == root_tag + def getYadisXRD(xrd_tree): """Return the XRD element that should contain the Yadis services""" xrd = None @@ -132,6 +137,7 @@ def getYadisXRD(xrd_tree): return xrd + def getXRDExpiration(xrd_element, default=None): """Return the expiration date of this XRD element, or None if no expiration was specified. @@ -156,6 +162,7 @@ def getXRDExpiration(xrd_element, default=None): expires_time = strptime(expires_string, "%Y-%m-%dT%H:%M:%SZ") return datetime(*expires_time[0:6]) + def getCanonicalID(iname, xrd_tree): """Return the CanonicalID from this XRDS document. @@ -194,20 +201,22 @@ def getCanonicalID(iname, xrd_tree): return canonicalID - class _Max(object): """Value that compares greater than any other value. Should only be used as a singleton. Implemented for use as a priority value for when a priority is not specified.""" + def __cmp__(self, other): if other is self: return 0 return 1 + Max = _Max() + def getPriorityStrict(element): """Get the priority of this element. @@ -226,6 +235,7 @@ def getPriorityStrict(element): # Any errors in parsing the priority fall through to here return Max + def getPriority(element): """Get the priority of this element @@ -236,17 +246,18 @@ def getPriority(element): except ValueError: return Max + def prioSort(elements): """Sort a list of elements that have priority attributes""" # Randomize the services before sorting so that equal priority # elements are load-balanced. random.shuffle(elements) - prio_elems = [(getPriority(e), e) for e in elements] - prio_elems.sort() + prio_elems = sorted((getPriority(e), e) for e in elements) sorted_elems = [s for (_, s) in prio_elems] return sorted_elems + def iterServices(xrd_tree): """Return an iterable over the Service elements in the Yadis XRD @@ -254,18 +265,21 @@ def iterServices(xrd_tree): xrd = getYadisXRD(xrd_tree) return prioSort(xrd.findall(service_tag)) + def sortedURIs(service_element): """Given a Service element, return a list of the contents of all URI tags in priority order.""" return [uri_element.text for uri_element in prioSort(service_element.findall(uri_tag))] + def getTypeURIs(service_element): """Given a Service element, return a list of the contents of all Type tags""" return [type_element.text for type_element in service_element.findall(type_tag)] + def expandService(service_element): """Take a service element and expand it into an iterator of: ([type_uri], uri, service_element) @@ -281,6 +295,7 @@ def expandService(service_element): return expanded + def expandServices(service_elements): """Take a sorted iterator of service elements and expand it into a sorted iterator of: diff --git a/openid/yadis/filters.py b/openid/yadis/filters.py index 43e4f3f..1a9d3e7 100644 --- a/openid/yadis/filters.py +++ b/openid/yadis/filters.py @@ -9,7 +9,7 @@ __all__ = [ 'IFilter', 'TransformFilterMaker', 'CompoundFilter', - ] +] from openid.yadis.etxrd import expandService @@ -27,6 +27,7 @@ class BasicServiceEndpoint(object): The simplest kind of filter you can write implements fromBasicServiceEndpoint, which takes one of these objects. """ + def __init__(self, yadis_url, type_uris, uri, service_element): self.type_uris = type_uris self.yadis_url = yadis_url @@ -61,6 +62,7 @@ class BasicServiceEndpoint(object): fromBasicServiceEndpoint = staticmethod(fromBasicServiceEndpoint) + class IFilter(object): """Interface for Yadis filter objects. Other filter-like things are convertable to this class.""" @@ -69,6 +71,7 @@ class IFilter(object): """Returns an iterator of endpoint objects""" raise NotImplementedError + class TransformFilterMaker(object): """Take a list of basic filters and makes a filter that transforms the basic filter into a top-level filter. This is mostly useful @@ -124,10 +127,12 @@ class TransformFilterMaker(object): return None + class CompoundFilter(object): """Create a new filter that applies a set of filters to an endpoint and collects their results. """ + def __init__(self, subfilters): self.subfilters = subfilters @@ -140,10 +145,12 @@ class CompoundFilter(object): subfilter.getServiceEndpoints(yadis_url, service_element)) return endpoints + # Exception raised when something is not able to be turned into a filter filter_type_error = TypeError( 'Expected a filter, an endpoint, a callable or a list of any of these.') + def mkFilter(parts): """Convert a filter-convertable thing into a filter @@ -160,6 +167,7 @@ def mkFilter(parts): else: return mkCompoundFilter(parts) + def mkCompoundFilter(parts): """Create a filter out of a list of filter-like things diff --git a/openid/yadis/manager.py b/openid/yadis/manager.py index 709adb7..afd55ee 100644 --- a/openid/yadis/manager.py +++ b/openid/yadis/manager.py @@ -54,6 +54,7 @@ class YadisServiceManager(object): """Store this object in the session, by its session key.""" session[self.session_key] = self + class Discovery(object): """State management for discovery. @@ -133,7 +134,7 @@ class Discovery(object): return service - ### Lower-level methods + # Lower-level methods def getSessionKey(self): """Get the session key for this starting URL and suffix diff --git a/openid/yadis/parsehtml.py b/openid/yadis/parsehtml.py index c2f8029..4ecef3b 100644 --- a/openid/yadis/parsehtml.py +++ b/openid/yadis/parsehtml.py @@ -8,17 +8,20 @@ from openid.yadis.constants import YADIS_HEADER_NAME # Size of the chunks to search at a time (also the amount that gets # read at a time) -CHUNK_SIZE = 1024 * 16 # 16 KB +CHUNK_SIZE = 1024 * 16 # 16 KB + class ParseDone(Exception): """Exception to hold the URI that was located when the parse is finished. If the parse finishes without finding the URI, set it to None.""" + class MetaNotFound(Exception): """Exception to hold the content of the page if we did not find the appropriate <meta> tag""" + re_flags = re.IGNORECASE | re.UNICODE | re.VERBOSE ent_pat = r''' & @@ -32,6 +35,7 @@ ent_pat = r''' ent_re = re.compile(ent_pat, re_flags) + def substituteMO(mo): if mo.lastgroup == 'hex': codepoint = int(mo.group('hex'), 16) @@ -46,9 +50,11 @@ def substituteMO(mo): else: return unichr(codepoint) + def substituteEntities(s): return ent_re.sub(substituteMO, s) + class YadisHTMLParser(HTMLParser): """Parser that finds a meta http-equiv tag in the head of a html document. @@ -107,7 +113,7 @@ class YadisHTMLParser(HTMLParser): # if we ever see a start body tag, bail out right away, since # we want to prevent the meta tag from appearing in the body # [2] - if tag=='body': + if tag == 'body': self._terminate() if self.phase == self.TOP: @@ -155,6 +161,7 @@ class YadisHTMLParser(HTMLParser): return HTMLParser.feed(self, chars) + def findHTMLMeta(stream): """Look for a meta http-equiv tag with the YADIS header name. @@ -171,7 +178,7 @@ def findHTMLMeta(stream): parser = YadisHTMLParser() chunks = [] - while 1: + while True: chunk = stream.read(CHUNK_SIZE) if not chunk: # End of file @@ -180,11 +187,11 @@ def findHTMLMeta(stream): chunks.append(chunk) try: parser.feed(chunk) - except HTMLParseError, why: + except HTMLParseError as why: # HTML parse error, so bail chunks.append(stream.read()) break - except ParseDone, why: + except ParseDone as why: uri = why[0] if uri is None: # Parse finished, but we may need the rest of the file diff --git a/openid/yadis/services.py b/openid/yadis/services.py index 65d8834..740fec0 100644 --- a/openid/yadis/services.py +++ b/openid/yadis/services.py @@ -27,10 +27,11 @@ def getServiceEndpoints(input_url, flt=None): try: endpoints = applyFilter(result.normalized_uri, result.response_text, flt) - except XRDSError, err: + except XRDSError as err: raise DiscoveryFailure(str(err), None) return (result.normalized_uri, endpoints) + def applyFilter(normalized_uri, xrd_data, flt=None): """Generate an iterable of endpoint objects given this input data, presumably from the result of performing the Yadis protocol. diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 3a39a6b..bd3b29e 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -1,10 +1,12 @@ # -*- test-case-name: openid.test.test_xri -*- """Utility functions for handling XRIs. -@see: XRI Syntax v2.0 at the U{OASIS XRI Technical Committee<http://www.oasis-open.org/committees/tc_home.php?wg_abbrev=xri>} +@see: XRI Syntax v2.0 at the + U{OASIS XRI Technical Committee<http://www.oasis-open.org/committees/tc_home.php?wg_abbrev=xri>} """ import re +from functools import reduce XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '('] @@ -16,11 +18,11 @@ except ValueError: (0xA0, 0xD7FF), (0xF900, 0xFDCF), (0xFDF0, 0xFFEF), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), - ] + ] else: UCSCHAR = [ (0xA0, 0xD7FF), @@ -40,17 +42,17 @@ else: (0xC0000, 0xCFFFD), (0xD0000, 0xDFFFD), (0xE1000, 0xEFFFD), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), (0xF0000, 0xFFFFD), (0x100000, 0x10FFFD), - ] + ] _escapeme_re = re.compile('[%s]' % (''.join( - map(lambda (m, n): u'%s-%s' % (unichr(m), unichr(n)), + map(lambda m_n: u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])), UCSCHAR + IPRIVATE)),)) @@ -59,8 +61,7 @@ def identifierScheme(identifier): @returns: C{"XRI"} or C{"URI"} """ - if identifier.startswith('xri://') or ( - identifier and identifier[0] in XRI_AUTHORITIES): + if identifier.startswith('xri://') or (identifier and identifier[0] in XRI_AUTHORITIES): return "XRI" else: return "URI" @@ -146,8 +147,7 @@ def rootAuthority(xri): else: # IRI reference. XXX: Can IRI authorities have segments? segments = authority.split('!') - segments = reduce(list.__add__, - map(lambda s: s.split('*'), segments)) + segments = reduce(list.__add__, map(lambda s: s.split('*'), segments)) root = segments[0] return XRI(root) diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index e8fd7e4..4a36595 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -11,13 +11,14 @@ from openid.yadis.xri import toURINormal DEFAULT_PROXY = 'http://proxy.xri.net/' + class ProxyResolver(object): """Python interface to a remote XRI proxy resolver. """ + def __init__(self, proxy_url=DEFAULT_PROXY): self.proxy_url = proxy_url - def queryURL(self, xri, service_type=None): """Build a URL to query the proxy resolver. @@ -42,7 +43,7 @@ class ProxyResolver(object): # 11:13:42), then we could ask for application/xrd+xml instead, # which would give us a bit less to process. '_xrd_r': 'application/xrds+xml', - } + } if service_type: args['_xrd_t'] = service_type else: @@ -51,7 +52,6 @@ class ProxyResolver(object): query = _appendArgs(hxri, args) return query - def query(self, xri, service_types): """Resolve some services for an XRI. @@ -103,8 +103,7 @@ def _appendArgs(url, args): """ # to be merged with oidutil.appendArgs when we combine the projects. if hasattr(args, 'items'): - args = args.items() - args.sort() + args = sorted(args.items()) if len(args) == 0: return url diff --git a/pylintrc b/pylintrc deleted file mode 100644 index fb36e4c..0000000 --- a/pylintrc +++ /dev/null @@ -1,40 +0,0 @@ -[REPORTS] - -include-ids=y - -[BASIC] - -# Required attributes for module, separated by a comma -required-attributes=__all__ - -# Regular expression which should only match functions or classes name which do -# not require a docstring -no-docstring-rgx=__.*__ - -# Regular expression which should only match correct module names -module-rgx=[a-z_][a-z0-9_]*$ - -# Regular expression which should only match correct module level names -const-rgx=(([a-z_][a-z0-9_]{3,30})|(__.*__)|([A-Z_][A-Z0-9_]{3,30}))$ - -# Regular expression which should only match correct class names -class-rgx=[A-Z_][a-zA-Z0-9]+$ - -# Regular expression which should only match correct function names -function-rgx=[a-z_][A-Za-z0-9_]{2,30}$ - -# Regular expression which should only match correct method names -method-rgx=[a-z_][A-Za-z0-9_]{2,30}$ - -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ - -# Good variable names which should always be accepted, separated by a comma -good-names=i,j,k,ex,Run,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names=foo,bar,baz,toto,tutu,tata - -# List of builtins function names that should not be used, separated by a comma -bad-functions=input @@ -35,15 +35,15 @@ and support for a variety of storage back-ends.''', author_email='openid@janrain.com', download_url='http://github.com/openid/python-openid/tarball/%s' % (version,), classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX", - "Programming Language :: Python", - "Topic :: Internet :: WWW/HTTP", - "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: System :: Systems Administration :: Authentication/Directory", + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX", + "Programming Language :: Python", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: System :: Systems Administration :: Authentication/Directory", ], - ) +) |