"""Unit tests for verification of return_to URLs for a realm."""
from __future__ import unicode_literals
import unittest
from mock import patch, sentinel
from testfixtures import LogCapture, StringComparison
from openid.server import trustroot
from openid.server.trustroot import getAllowedReturnURLs
from openid.yadis import services
from openid.yadis.discover import DiscoveryFailure, DiscoveryResult
# Export every test case defined in this module, not only the first one.
__all__ = [
    'TestBuildDiscoveryURL',
    'TestExtractReturnToURLs',
    'TestGetAllowedReturnURLs',
    'TestReturnToMatches',
    'TestVerifyReturnTo',
]
class TestBuildDiscoveryURL(unittest.TestCase):
    """Tests for the discovery URL that a parsed realm builds.
    """

    def assertDiscoveryURL(self, realm, expected_discovery_url):
        """Parse `realm` as a trust root, build its discovery URL and
        compare the result with `expected_discovery_url`.
        """
        built_url = trustroot.TrustRoot.parse(realm).buildDiscoveryURL()
        self.assertEqual(built_url, expected_discovery_url)

    def test_trivial(self):
        """A realm without a wildcard is expected back unchanged.
        """
        plain_realm = 'http://example.com/foo'
        self.assertDiscoveryURL(plain_realm, 'http://example.com/foo')

    def test_wildcard(self):
        """A wildcard realm is expected to yield the `www` host.
        """
        wildcard_realm = 'http://*.example.com/foo'
        self.assertDiscoveryURL(wildcard_realm, 'http://www.example.com/foo')

    def test_wildcard_port(self):
        """A wildcard realm with an explicit port is expected to keep that port.
        """
        wildcard_realm = 'http://*.example.com:8001/foo'
        self.assertDiscoveryURL(wildcard_realm, 'http://www.example.com:8001/foo')
class TestExtractReturnToURLs(unittest.TestCase):
    """Tests for reading the allowed return_to URLs out of an XRDS document.

    Every test replaces `services.discover` with `mockDiscover`, which hands back
    the document stored in `self.data` as the response text of the discovery result.
    """

    disco_url = 'http://example.com/'

    def setUp(self):
        # Swap the real discovery for the stub; `tearDown` puts the original back.
        self.original_discover = services.discover
        services.discover = self.mockDiscover
        self.data = None

    def tearDown(self):
        services.discover = self.original_discover

    def mockDiscover(self, uri):
        """Return a `DiscoveryResult` for `uri` whose response text is `self.data`."""
        result = DiscoveryResult(uri)
        result.response_text = self.data
        result.normalized_uri = uri
        return result

    def assertReturnURLs(self, data, expected_return_urls):
        """Serve `data` as the discovered document and check the extracted return_to URLs."""
        self.data = data
        actual_return_urls = trustroot.getAllowedReturnURLs(self.disco_url)
        self.assertEqual(actual_return_urls, expected_return_urls)

    def assertDiscoveryFailure(self, text):
        """Serve `text` as the discovered document and check that `DiscoveryFailure` is raised."""
        self.data = text
        self.assertRaises(DiscoveryFailure, trustroot.getAllowedReturnURLs, self.disco_url)

    def test_empty(self):
        self.assertDiscoveryFailure('')

    def test_badXML(self):
        self.assertDiscoveryFailure('>')

    # The fixtures below are complete XRDS documents. A bare list of URLs is not
    # XML, and an empty document is already expected to fail in `test_empty`.

    def test_noEntries(self):
        """An XRDS document with an empty XRD yields no return_to URLs."""
        self.assertReturnURLs(b'''\
<?xml version="1.0" encoding="UTF-8"?>
<xrds:XRDS xmlns:xrds="xri://$xrds"
           xmlns="xri://$xrd*($v*2.0)"
           >
  <XRD>
  </XRD>
</xrds:XRDS>
''', [])

    def test_noReturnToEntries(self):
        """A service of a type other than return_to is ignored."""
        self.assertReturnURLs(b'''\
<?xml version="1.0" encoding="UTF-8"?>
<xrds:XRDS xmlns:xrds="xri://$xrds"
           xmlns="xri://$xrd*($v*2.0)"
           >
  <XRD>
    <Service priority="10">
      <Type>http://specs.openid.net/auth/2.0/server</Type>
      <URI>http://www.myopenid.com/server</URI>
    </Service>
  </XRD>
</xrds:XRDS>
''', [])

    def test_oneEntry(self):
        """A single return_to service yields its URI."""
        self.assertReturnURLs(b'''\
<?xml version="1.0" encoding="UTF-8"?>
<xrds:XRDS xmlns:xrds="xri://$xrds"
           xmlns="xri://$xrd*($v*2.0)"
           >
  <XRD>
    <Service>
      <Type>http://specs.openid.net/auth/2.0/return_to</Type>
      <URI>http://rp.example.com/return</URI>
    </Service>
  </XRD>
</xrds:XRDS>
''', ['http://rp.example.com/return'])

    def test_twoEntries(self):
        """Two return_to services yield both URIs."""
        # Explicit priorities pin the order in which the URIs are expected back.
        self.assertReturnURLs(b'''\
<?xml version="1.0" encoding="UTF-8"?>
<xrds:XRDS xmlns:xrds="xri://$xrds"
           xmlns="xri://$xrd*($v*2.0)"
           >
  <XRD>
    <Service priority="0">
      <Type>http://specs.openid.net/auth/2.0/return_to</Type>
      <URI>http://rp.example.com/return</URI>
    </Service>
    <Service priority="1">
      <Type>http://specs.openid.net/auth/2.0/return_to</Type>
      <URI>http://other.rp.example.com/return</URI>
    </Service>
  </XRD>
</xrds:XRDS>
''', ['http://rp.example.com/return', 'http://other.rp.example.com/return'])

    def test_twoEntries_withOther(self):
        """An unrelated service next to two return_to services is left out of the result."""
        # Explicit priorities pin the order in which the URIs are expected back.
        self.assertReturnURLs(b'''\
<?xml version="1.0" encoding="UTF-8"?>
<xrds:XRDS xmlns:xrds="xri://$xrds"
           xmlns="xri://$xrd*($v*2.0)"
           >
  <XRD>
    <Service priority="0">
      <Type>http://specs.openid.net/auth/2.0/return_to</Type>
      <URI>http://rp.example.com/return</URI>
    </Service>
    <Service priority="1">
      <Type>http://specs.openid.net/auth/2.0/return_to</Type>
      <URI>http://other.rp.example.com/return</URI>
    </Service>
    <Service priority="0">
      <Type>http://example.com/LOLCATS</Type>
      <URI>http://example.com/invisible+uri</URI>
    </Service>
  </XRD>
</xrds:XRDS>
''', ['http://rp.example.com/return', 'http://other.rp.example.com/return'])
class TestReturnToMatches(unittest.TestCase):
    """Tests for `trustroot.returnToMatches`.
    """

    # The return_to URL shared by most of the tests below.
    RETURN_TO = 'http://example.com/return.to'

    def test_noEntries(self):
        """An empty list of allowed URLs matches nothing.
        """
        matched = trustroot.returnToMatches([], 'anything')
        self.assertFalse(matched)

    def test_exactMatch(self):
        """A return_to equal to an allowed URL matches.
        """
        allowed = [self.RETURN_TO]
        self.assertTrue(trustroot.returnToMatches(allowed, self.RETURN_TO))

    def test_garbageMatch(self):
        """A non-URL entry ahead of a valid one does not prevent the match.
        """
        garbage = 'This is not a URL at all. In fact, it has characters, like "<" that are not allowed in URLs'
        allowed = [garbage, self.RETURN_TO]
        self.assertTrue(trustroot.returnToMatches(allowed, self.RETURN_TO))

    def test_descendant(self):
        """A return_to below an allowed URL matches.
        """
        allowed = [self.RETURN_TO]
        self.assertTrue(trustroot.returnToMatches(allowed, 'http://example.com/return.to/user:joe'))

    def test_wildcard(self):
        """A wildcard entry in the allowed URLs is not matched.
        """
        allowed = ['http://*.example.com/return.to']
        self.assertFalse(trustroot.returnToMatches(allowed, 'http://example.com/return.to'))

    def test_noMatch(self):
        """A return_to with a different path does not match.
        """
        allowed = [self.RETURN_TO]
        self.assertFalse(trustroot.returnToMatches(allowed, 'http://example.com/xss_exploit'))
class TestGetAllowedReturnURLs(unittest.TestCase):
    """Tests for `getAllowedReturnURLs` with `getServiceEndpoints` patched out.
    """

    def _stub_service_endpoints(self):
        """Return a patcher that makes `getServiceEndpoints` report
        'http://example.com/' together with `sentinel.endpoints`.
        """
        return patch('openid.yadis.services.getServiceEndpoints', autospec=True,
                     return_value=('http://example.com/', sentinel.endpoints))

    def test_equal(self):
        """The endpoints come back when the reported URL equals the requested one.
        """
        with self._stub_service_endpoints():
            found = getAllowedReturnURLs('http://example.com/')
        self.assertEqual(found, sentinel.endpoints)

    def test_normalized(self):
        """The endpoints come back when the reported URL is the normalized form of the requested one.
        """
        # A redirect must not be reported when the returned URL is merely normalized.
        with self._stub_service_endpoints():
            found = getAllowedReturnURLs('http://example.com:80')
        self.assertEqual(found, sentinel.endpoints)
class TestVerifyReturnTo(unittest.TestCase):
    """Tests for `trustroot.verifyReturnTo`, with discovery replaced through `_vrfy`.
    """

    # Realm and return_to shared by the tests that go through discovery.
    WILDCARD_REALM = 'http://*.example.com/'
    RETURN_TO = 'http://www.example.com/foo'

    def test_bogusRealm(self):
        """An empty realm is rejected.
        """
        verified = trustroot.verifyReturnTo('', 'http://example.com/')
        self.assertFalse(verified)

    def test_verifyWithDiscoveryCalled(self):
        """Verification succeeds and logs nothing when discovery returns the return_to URL.
        """
        def fake_discovery(disco_url):
            self.assertEqual(disco_url, 'http://www.example.com/')
            return [self.RETURN_TO]

        with LogCapture() as captured:
            verified = trustroot.verifyReturnTo(self.WILDCARD_REALM, self.RETURN_TO, _vrfy=fake_discovery)
        self.assertTrue(verified)
        self.assertEqual(captured.records, [])

    def test_verifyFailWithDiscoveryCalled(self):
        """Verification fails and is logged when discovery returns only an unrelated URL.
        """
        def fake_discovery(disco_url):
            self.assertEqual(disco_url, 'http://www.example.com/')
            return ['http://something-else.invalid/']

        with LogCapture() as captured:
            verified = trustroot.verifyReturnTo(self.WILDCARD_REALM, self.RETURN_TO, _vrfy=fake_discovery)
        self.assertFalse(verified)
        expected_record = ('openid.server.trustroot', 'INFO', StringComparison('Failed to validate return_to .*'))
        captured.check(expected_record)

    def test_verifyFailIfDiscoveryRedirects(self):
        """Verification fails and is logged when discovery raises `RealmVerificationRedirected`.
        """
        def redirecting_discovery(disco_url):
            raise trustroot.RealmVerificationRedirected(disco_url, "http://redirected.invalid")

        with LogCapture() as captured:
            verified = trustroot.verifyReturnTo(self.WILDCARD_REALM, self.RETURN_TO, _vrfy=redirecting_discovery)
        self.assertFalse(verified)
        expected_record = ('openid.server.trustroot', 'INFO', StringComparison('Attempting to verify .*'))
        captured.check(expected_record)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()