summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--oslo_middleware/cors.py41
-rw-r--r--oslo_middleware/tests/test_cors.py13
2 files changed, 40 insertions, 14 deletions
diff --git a/oslo_middleware/cors.py b/oslo_middleware/cors.py
index 314d776..65d7be0 100644
--- a/oslo_middleware/cors.py
+++ b/oslo_middleware/cors.py
@@ -53,6 +53,14 @@ CORS_OPTS = [
]
+class InvalidOriginError(Exception):
+ """Exception raised when Origin is invalid."""
+ def __init__(self, origin):
+ self.origin = origin
+ super(InvalidOriginError, self).__init__(
+ 'CORS request from origin \'%s\' not permitted.' % origin)
+
+
class CORS(base.ConfigurableMiddleware):
"""CORS Middleware.
@@ -263,15 +271,11 @@ class CORS(base.ConfigurableMiddleware):
return response
# Is this origin registered? (Section 6.2.2)
- origin = request.headers['Origin']
- if origin not in self.allowed_origins:
- if '*' in self.allowed_origins:
- origin = '*'
- else:
- LOG.debug('CORS request from origin \'%s\' not permitted.'
- % (origin,))
- return response
- cors_config = self.allowed_origins[origin]
+ try:
+ origin, cors_config = self._get_cors_config_by_origin(
+ request.headers['Origin'])
+ except InvalidOriginError:
+ return response
# If there's no request method, exit. (Section 6.2.3)
if 'Access-Control-Request-Method' not in request.headers:
@@ -335,6 +339,16 @@ class CORS(base.ConfigurableMiddleware):
return response
+ def _get_cors_config_by_origin(self, origin):
+ if origin not in self.allowed_origins:
+ if '*' in self.allowed_origins:
+ origin = '*'
+ else:
+ LOG.debug('CORS request from origin \'%s\' not permitted.'
+ % origin)
+ raise InvalidOriginError(origin)
+ return origin, self.allowed_origins[origin]
+
def _apply_cors_request_headers(self, request, response):
"""Handle Basic CORS Request (Section 6.1)
@@ -347,12 +361,11 @@ class CORS(base.ConfigurableMiddleware):
return
# Is this origin registered? (Section 6.1.2)
- origin = request.headers['Origin']
- if origin not in self.allowed_origins:
- LOG.debug('CORS request from origin \'%s\' not permitted.'
- % (origin,))
+ try:
+ origin, cors_config = self._get_cors_config_by_origin(
+ request.headers['Origin'])
+ except InvalidOriginError:
return
- cors_config = self.allowed_origins[origin]
# Set the default origin permission headers. (Sections 6.1.3 & 6.4)
response.headers['Vary'] = 'Origin'
diff --git a/oslo_middleware/tests/test_cors.py b/oslo_middleware/tests/test_cors.py
index 1499fc6..d1c1885 100644
--- a/oslo_middleware/tests/test_cors.py
+++ b/oslo_middleware/tests/test_cors.py
@@ -1023,6 +1023,19 @@ class CORSTestWildcard(CORSTestBase):
allow_credentials='true',
expose_headers=None)
+ # Test valid domain
+ request = webob.Request.blank('/')
+ request.method = "GET"
+ request.headers['Origin'] = 'http://default.example.com'
+ response = request.get_response(self.application)
+ self.assertCORSResponse(response,
+ status='200 OK',
+ allow_origin='http://default.example.com',
+ max_age=None,
+ allow_headers='',
+ allow_credentials='true',
+ expose_headers=None)
+
# Test invalid domain
request = webob.Request.blank('/')
request.method = "OPTIONS"