diff options
-rw-r--r-- | oslo_middleware/cors.py | 41 | ||||
-rw-r--r-- | oslo_middleware/tests/test_cors.py | 13 |
2 files changed, 40 insertions, 14 deletions
diff --git a/oslo_middleware/cors.py b/oslo_middleware/cors.py index 314d776..65d7be0 100644 --- a/oslo_middleware/cors.py +++ b/oslo_middleware/cors.py @@ -53,6 +53,14 @@ CORS_OPTS = [ ] +class InvalidOriginError(Exception): + """Exception raised when Origin is invalid.""" + def __init__(self, origin): + self.origin = origin + super(InvalidOriginError, self).__init__( + 'CORS request from origin \'%s\' not permitted.' % origin) + + class CORS(base.ConfigurableMiddleware): """CORS Middleware. @@ -263,15 +271,11 @@ class CORS(base.ConfigurableMiddleware): return response # Is this origin registered? (Section 6.2.2) - origin = request.headers['Origin'] - if origin not in self.allowed_origins: - if '*' in self.allowed_origins: - origin = '*' - else: - LOG.debug('CORS request from origin \'%s\' not permitted.' - % (origin,)) - return response - cors_config = self.allowed_origins[origin] + try: + origin, cors_config = self._get_cors_config_by_origin( + request.headers['Origin']) + except InvalidOriginError: + return response # If there's no request method, exit. (Section 6.2.3) if 'Access-Control-Request-Method' not in request.headers: @@ -335,6 +339,16 @@ class CORS(base.ConfigurableMiddleware): return response + def _get_cors_config_by_origin(self, origin): + if origin not in self.allowed_origins: + if '*' in self.allowed_origins: + origin = '*' + else: + LOG.debug('CORS request from origin \'%s\' not permitted.' + % origin) + raise InvalidOriginError(origin) + return origin, self.allowed_origins[origin] + def _apply_cors_request_headers(self, request, response): """Handle Basic CORS Request (Section 6.1) @@ -347,12 +361,11 @@ class CORS(base.ConfigurableMiddleware): return # Is this origin registered? (Section 6.1.2) - origin = request.headers['Origin'] - if origin not in self.allowed_origins: - LOG.debug('CORS request from origin \'%s\' not permitted.' - % (origin,)) + try: + origin, cors_config = self._get_cors_config_by_origin( + request.headers['Origin']) + except InvalidOriginError: return - cors_config = self.allowed_origins[origin] # Set the default origin permission headers. (Sections 6.1.3 & 6.4) response.headers['Vary'] = 'Origin' diff --git a/oslo_middleware/tests/test_cors.py b/oslo_middleware/tests/test_cors.py index 1499fc6..d1c1885 100644 --- a/oslo_middleware/tests/test_cors.py +++ b/oslo_middleware/tests/test_cors.py @@ -1023,6 +1023,19 @@ class CORSTestWildcard(CORSTestBase): allow_credentials='true', expose_headers=None) + # Test valid domain + request = webob.Request.blank('/') + request.method = "GET" + request.headers['Origin'] = 'http://default.example.com' + response = request.get_response(self.application) + self.assertCORSResponse(response, + status='200 OK', + allow_origin='http://default.example.com', + max_age=None, + allow_headers='', + allow_credentials='true', + expose_headers=None) + # Test invalid domain request = webob.Request.blank('/') request.method = "OPTIONS" |