diff options
-rw-r--r-- | oslo_middleware/cors.py | 50 | ||||
-rw-r--r-- | oslo_middleware/tests/test_cors.py | 155 |
2 files changed, 149 insertions, 56 deletions
diff --git a/oslo_middleware/cors.py b/oslo_middleware/cors.py index 513bdbc..a0d9322 100644 --- a/oslo_middleware/cors.py +++ b/oslo_middleware/cors.py @@ -117,15 +117,29 @@ class CORS(base.Middleware): self.allowed_origins[allowed_origin] = conf[section] - def process_request(self, req): - '''If we detect an OPTIONS request, handle it immediately.''' - if req.method == 'OPTIONS': - resp = webob.response.Response(status=webob.exc.HTTPOk.code) - self._apply_cors_preflight_headers(request=req, response=resp) - return resp - def process_response(self, response, request=None): - '''Detect CORS headers on the request, and decorate the response.''' + '''Check for CORS headers, and decorate if necessary. + + Perform two checks. First, if an OPTIONS request was issued, let the + application handle it, and (if necessary) decorate the response with + preflight headers. In this case, if a 404 is thrown by the underlying + application (i.e. if the underlying application does not handle + OPTIONS requests, the response code is overridden. + + In the case of all other requests, regular request headers are applied. + ''' + + # Sanity precheck: If we detect CORS headers provided by something in + # in the middleware chain, assume that it knows better. + if 'Access-Control-Allow-Origin' in response.headers: + return response + + # Doublecheck for an OPTIONS request. + if request.method == 'OPTIONS': + return self._apply_cors_preflight_headers(request=request, + response=response) + + # Apply regular CORS headers. self._apply_cors_request_headers(request=request, response=response) # Finally, return the response. @@ -148,9 +162,15 @@ class CORS(base.Middleware): appropriate for the request. """ + # If the response contains a 2XX code, we have to assume that the + # underlying middleware's response content needs to be persisted. + # Otherwise, create a new response. + if 200 > response.status_code or response.status_code >= 300: + response = webob.response.Response(status=webob.exc.HTTPOk.code) + # Does the request have an origin header? (Section 6.2.1) if 'Origin' not in request.headers: - return + return response # Is this origin registered? (Section 6.2.2) origin = request.headers['Origin'] @@ -160,12 +180,12 @@ class CORS(base.Middleware): else: LOG.debug('CORS request from origin \'%s\' not permitted.' % (origin,)) - return + return response cors_config = self.allowed_origins[origin] # If there's no request method, exit. (Section 6.2.3) if 'Access-Control-Request-Method' not in request.headers: - return + return response request_method = request.headers['Access-Control-Request-Method'] # Extract Request headers. If parsing fails, exit. (Section 6.2.4) @@ -175,11 +195,11 @@ class CORS(base.Middleware): 'Access-Control-Request-Headers') except Exception: LOG.debug('Cannot parse request headers.') - return + return response # Compare request method to permitted methods (Section 6.2.5) if request_method not in cors_config.allow_methods: - return + return response # Compare request headers to permitted headers, case-insensitively. # (Section 6.2.6) @@ -188,7 +208,7 @@ class CORS(base.Middleware): permitted_headers = cors_config.allow_headers + self.simple_headers if upper_header not in (header.upper() for header in permitted_headers): - return + return response # Set the default origin permission headers. (Sections 6.2.7, 6.4) response.headers['Vary'] = 'Origin' @@ -211,6 +231,8 @@ class CORS(base.Middleware): response.headers['Access-Control-Allow-Headers'] = \ ','.join(request_headers) + return response + def _apply_cors_request_headers(self, request, response): """Handle Basic CORS Request (Section 6.1) diff --git a/oslo_middleware/tests/test_cors.py b/oslo_middleware/tests/test_cors.py index 34001e0..229bb80 100644 --- a/oslo_middleware/tests/test_cors.py +++ b/oslo_middleware/tests/test_cors.py @@ -17,10 +17,32 @@ from oslo_config import fixture from oslotest import base as test_base import webob import webob.dec +import webob.exc as exc from oslo_middleware import cors +@webob.dec.wsgify +def test_application(req): + if req.path_info == '/server_cors': + # Mirror back the origin in the request. + response = webob.Response(status=200) + response.headers['Access-Control-Allow-Origin'] = \ + req.headers['Origin'] + response.headers['X-Server-Generated-Response'] = '1' + return response + + if req.path_info == '/server_no_cors': + # Send a response with no CORS headers. + response = webob.Response(status=200) + return response + + if req.method == 'OPTIONS': + raise exc.HTTPNotFound() + + return 'Hello World' + + class CORSTestBase(test_base.BaseTestCase): """Base class for all CORS tests. @@ -102,10 +124,6 @@ class CORSRegularRequestTest(CORSTestBase): """Setup the tests.""" super(CORSRegularRequestTest, self).setUp() - @webob.dec.wsgify - def application(req): - return 'Hello, World!!!' - # Set up the config fixture. config = self.useFixture(fixture.Config(cfg.CONF)) @@ -138,7 +156,7 @@ class CORSRegularRequestTest(CORSTestBase): allow_methods='GET,PUT,POST,DELETE,HEAD') # Now that the config is set up, create our application. - self.application = cors.CORS(application, cfg.CONF) + self.application = cors.CORS(test_application, cfg.CONF) def test_config_overrides(self): """Assert that the configuration options are properly registered.""" @@ -205,8 +223,7 @@ class CORSRegularRequestTest(CORSTestBase): request is outside the scope of this specification. """ for method in self.methods: - request = webob.Request({}) - request.method = method + request = webob.Request.blank('/') response = request.get_response(self.application) self.assertCORSResponse(response, status='200 OK', @@ -227,7 +244,7 @@ class CORSRegularRequestTest(CORSTestBase): # Test valid origin header. for method in self.methods: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = method request.headers['Origin'] = 'http://valid.example.com' response = request.get_response(self.application) @@ -242,7 +259,7 @@ class CORSRegularRequestTest(CORSTestBase): # Test origin header not present in configuration. for method in self.methods: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = method request.headers['Origin'] = 'http://invalid.example.com' response = request.get_response(self.application) @@ -257,7 +274,7 @@ class CORSRegularRequestTest(CORSTestBase): # Test valid, but case-mismatched origin header. for method in self.methods: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = method request.headers['Origin'] = 'http://VALID.EXAMPLE.COM' response = request.get_response(self.application) @@ -285,7 +302,7 @@ class CORSRegularRequestTest(CORSTestBase): """ # Test valid origin header without credentials. for method in self.methods: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = method request.headers['Origin'] = 'http://valid.example.com' response = request.get_response(self.application) @@ -300,7 +317,7 @@ class CORSRegularRequestTest(CORSTestBase): # Test valid origin header with credentials for method in self.methods: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = method request.headers['Origin'] = 'http://creds.example.com' response = request.get_response(self.application) @@ -321,7 +338,7 @@ class CORSRegularRequestTest(CORSTestBase): names given in the list of exposed headers. """ for method in self.methods: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = method request.headers['Origin'] = 'http://headers.example.com' response = request.get_response(self.application) @@ -334,6 +351,29 @@ class CORSRegularRequestTest(CORSTestBase): allow_credentials=None, expose_headers='X-Header-1,X-Header-2') + def test_application_options_response(self): + """Assert that an application provided OPTIONS response is honored. + + If the underlying application, via middleware or other, provides a + CORS response, its response should be honored. + """ + test_origin = 'http://creds.example.com' + + request = webob.Request.blank('/server_cors') + request.method = "GET" + request.headers['Origin'] = test_origin + request.headers['Access-Control-Request-Method'] = 'GET' + + response = request.get_response(self.application) + + # If the regular CORS handling catches this request, it should set + # the allow credentials header. This makes sure that it doesn't. + self.assertNotIn('Access-Control-Allow-Credentials', response.headers) + self.assertEqual(response.headers['Access-Control-Allow-Origin'], + test_origin) + self.assertEqual(response.headers['X-Server-Generated-Response'], + '1') + class CORSPreflightRequestTest(CORSTestBase): """CORS Specification Section 6.2 @@ -344,10 +384,6 @@ class CORSPreflightRequestTest(CORSTestBase): def setUp(self): super(CORSPreflightRequestTest, self).setUp() - @webob.dec.wsgify - def application(req): - return 'Hello, World!!!' - # Set up the config fixture. config = self.useFixture(fixture.Config(cfg.CONF)) @@ -380,7 +416,7 @@ class CORSPreflightRequestTest(CORSTestBase): allow_methods='GET,PUT,POST,DELETE,HEAD') # Now that the config is set up, create our application. - self.application = cors.CORS(application, cfg.CONF) + self.application = cors.CORS(test_application, cfg.CONF) def test_config_overrides(self): """Assert that the configuration options are properly registered.""" @@ -446,7 +482,7 @@ class CORSPreflightRequestTest(CORSTestBase): If the Origin header is not present terminate this set of steps. The request is outside the scope of this specification. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" response = request.get_response(self.application) self.assertCORSResponse(response, @@ -467,7 +503,7 @@ class CORSPreflightRequestTest(CORSTestBase): """ # Test valid domain - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://valid.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -482,7 +518,7 @@ class CORSPreflightRequestTest(CORSTestBase): expose_headers=None) # Test invalid domain - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://invalid.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -497,7 +533,7 @@ class CORSPreflightRequestTest(CORSTestBase): expose_headers=None) # Test case-sensitive mismatch domain - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://VALID.EXAMPLE.COM' request.headers['Access-Control-Request-Method'] = 'GET' @@ -520,7 +556,7 @@ class CORSPreflightRequestTest(CORSTestBase): """ # Test valid domain, valid method. - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://get.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -535,7 +571,7 @@ class CORSPreflightRequestTest(CORSTestBase): expose_headers=None) # Test valid domain, invalid HTTP method. - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://valid.example.com' request.headers['Access-Control-Request-Method'] = 'TEAPOT' @@ -550,7 +586,7 @@ class CORSPreflightRequestTest(CORSTestBase): expose_headers=None) # Test valid domain, no HTTP method. - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://valid.example.com' response = request.get_response(self.application) @@ -570,7 +606,7 @@ class CORSPreflightRequestTest(CORSTestBase): list of methods do not set any additional headers and terminate this set of steps. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://get.example.com' request.headers['Access-Control-Request-Method'] = 'get' @@ -594,7 +630,7 @@ class CORSPreflightRequestTest(CORSTestBase): this set of steps. The request is outside the scope of this specification. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://headers.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -615,7 +651,7 @@ class CORSPreflightRequestTest(CORSTestBase): If there are no Access-Control-Request-Headers headers let header field-names be the empty list. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://headers.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -639,7 +675,7 @@ class CORSPreflightRequestTest(CORSTestBase): If there are no Access-Control-Request-Headers headers let header field-names be the empty list. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://headers.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -665,7 +701,7 @@ class CORSPreflightRequestTest(CORSTestBase): match for any of the values in list of headers do not set any additional headers and terminate this set of steps. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://headers.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -694,7 +730,7 @@ class CORSPreflightRequestTest(CORSTestBase): NOTE: We never use the "*" as origin. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://creds.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -715,7 +751,7 @@ class CORSPreflightRequestTest(CORSTestBase): the amount of seconds the user agent is allowed to cache the result of the request. """ - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://cached.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -740,7 +776,7 @@ class CORSPreflightRequestTest(CORSTestBase): enough. """ for method in ['GET', 'PUT', 'POST', 'DELETE']: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://all.example.com' request.headers['Access-Control-Request-Method'] = method @@ -755,7 +791,7 @@ class CORSPreflightRequestTest(CORSTestBase): expose_headers=None) for method in ['PUT', 'POST', 'DELETE']: - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://get.example.com' request.headers['Access-Control-Request-Method'] = method @@ -786,7 +822,7 @@ class CORSPreflightRequestTest(CORSTestBase): requested_headers = 'Content-Type,X-Header-1,Cache-Control,Expires,' \ 'Last-Modified,Pragma' - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://headers.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -801,6 +837,45 @@ class CORSPreflightRequestTest(CORSTestBase): allow_credentials=None, expose_headers=None) + def test_application_options_response(self): + """Assert that an application provided OPTIONS response is honored. + + If the underlying application, via middleware or other, provides a + CORS response, its response should be honored. + """ + test_origin = 'http://creds.example.com' + + request = webob.Request.blank('/server_cors') + request.method = "OPTIONS" + request.headers['Origin'] = test_origin + request.headers['Access-Control-Request-Method'] = 'GET' + + response = request.get_response(self.application) + + # If the regular CORS handling catches this request, it should set + # the allow credentials header. This makes sure that it doesn't. + self.assertNotIn('Access-Control-Allow-Credentials', response.headers) + self.assertEqual(response.headers['Access-Control-Allow-Origin'], + test_origin) + self.assertEqual(response.headers['X-Server-Generated-Response'], + '1') + + # If the application returns an OPTIONS response without CORS + # headers, assert that we apply headers. + request = webob.Request.blank('/server_no_cors') + request.method = "OPTIONS" + request.headers['Origin'] = 'http://get.example.com' + request.headers['Access-Control-Request-Method'] = 'GET' + response = request.get_response(self.application) + self.assertCORSResponse(response, + status='200 OK', + allow_origin='http://get.example.com', + max_age=None, + allow_methods='GET', + allow_headers=None, + allow_credentials=None, + expose_headers=None) + class CORSTestWildcard(CORSTestBase): """Test the CORS wildcard specification.""" @@ -808,10 +883,6 @@ class CORSTestWildcard(CORSTestBase): def setUp(self): super(CORSTestWildcard, self).setUp() - @webob.dec.wsgify - def application(req): - return 'Hello, World!!!' - # Set up the config fixture. config = self.useFixture(fixture.Config(cfg.CONF)) @@ -828,7 +899,7 @@ class CORSTestWildcard(CORSTestBase): allow_methods='GET') # Now that the config is set up, create our application. - self.application = cors.CORS(application, cfg.CONF) + self.application = cors.CORS(test_application, cfg.CONF) def test_config_overrides(self): """Assert that the configuration options are properly registered.""" @@ -861,7 +932,7 @@ class CORSTestWildcard(CORSTestBase): """ # Test valid domain - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://default.example.com' request.headers['Access-Control-Request-Method'] = 'GET' @@ -876,7 +947,7 @@ class CORSTestWildcard(CORSTestBase): expose_headers=None) # Test invalid domain - request = webob.Request({}) + request = webob.Request.blank('/') request.method = "OPTIONS" request.headers['Origin'] = 'http://invalid.example.com' request.headers['Access-Control-Request-Method'] = 'GET' |