summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authordaniel-a-nguyen <daniel.a.nguyen@hp.com>2013-02-14 11:22:06 -0800
committerdaniel-a-nguyen <daniel.a.nguyen@hp.com>2013-03-05 14:17:36 -0800
commit0a71ef9e880996597f6f94d06fa2d0934051d420 (patch)
tree784fb4660548d3ee3434234fb3e7fa849ecc3647
parent8de8507493f66c2e90a1d2fd4a332c8610be60ac (diff)
downloadtrove-0a71ef9e880996597f6f94d06fa2d0934051d420.tar.gz
Rate limits implementation
added unittest for limits reverted changes to openstack/common removed commented code cleaned up unittest added int-tests updated reference to XMLNS removed 1.1 XMLMS in wsgi Implements: blueprint rate-limits Change-Id: I842de3a6cae1859cc246264a5836abfd97fb8074
-rw-r--r--etc/reddwarf/api-paste.ini5
-rw-r--r--etc/reddwarf/reddwarf.conf.sample6
-rw-r--r--etc/reddwarf/reddwarf.conf.test5
-rw-r--r--etc/tests/localhost.test.conf18
-rw-r--r--reddwarf/common/api.py7
-rw-r--r--reddwarf/common/cfg.py4
-rw-r--r--reddwarf/common/limits.py464
-rw-r--r--reddwarf/common/schemas/atom-link.rng141
-rw-r--r--reddwarf/common/schemas/atom.rng597
-rw-r--r--reddwarf/common/schemas/v1.1/limits.rng28
-rw-r--r--reddwarf/common/wsgi.py237
-rw-r--r--reddwarf/common/xmlutil.py910
-rw-r--r--reddwarf/limits/__init__.py0
-rw-r--r--reddwarf/limits/service.py49
-rw-r--r--reddwarf/limits/views.py98
-rw-r--r--reddwarf/tests/api/limits.py109
-rw-r--r--reddwarf/tests/unittests/api/common/__init__.py0
-rw-r--r--reddwarf/tests/unittests/api/common/test_limits.py741
-rw-r--r--reddwarf/tests/unittests/util/matchers.py454
-rw-r--r--run_tests.py1
-rw-r--r--tools/test-requires1
21 files changed, 3872 insertions, 3 deletions
diff --git a/etc/reddwarf/api-paste.ini b/etc/reddwarf/api-paste.ini
index 896c6c18..7f197ca5 100644
--- a/etc/reddwarf/api-paste.ini
+++ b/etc/reddwarf/api-paste.ini
@@ -7,7 +7,7 @@ use = call:reddwarf.common.wsgi:versioned_urlmap
paste.app_factory = reddwarf.versions:app_factory
[pipeline:reddwarfapi]
-pipeline = faultwrapper tokenauth authorization contextwrapper extensions reddwarfapp
+pipeline = faultwrapper tokenauth authorization contextwrapper ratelimit extensions reddwarfapp
#pipeline = debug extensions reddwarfapp
[filter:extensions]
@@ -34,6 +34,9 @@ paste.filter_factory = reddwarf.common.wsgi:ContextMiddleware.factory
[filter:faultwrapper]
paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory
+[filter:ratelimit]
+paste.filter_factory = reddwarf.common.limits:RateLimitingMiddleware.factory
+
[app:reddwarfapp]
paste.app_factory = reddwarf.common.api:app_factory
diff --git a/etc/reddwarf/reddwarf.conf.sample b/etc/reddwarf/reddwarf.conf.sample
index 7399fbc6..89b4cf67 100644
--- a/etc/reddwarf/reddwarf.conf.sample
+++ b/etc/reddwarf/reddwarf.conf.sample
@@ -54,6 +54,12 @@ max_instances_per_user = 5
max_volumes_per_user = 100
volume_time_out=30
+# Config options for rate limits
+http_get_rate = 200
+http_post_rate = 200
+http_put_rate = 200
+http_delete_rate = 200
+
# Reddwarf DNS
reddwarf_dns_support = False
diff --git a/etc/reddwarf/reddwarf.conf.test b/etc/reddwarf/reddwarf.conf.test
index 03a0569d..96867724 100644
--- a/etc/reddwarf/reddwarf.conf.test
+++ b/etc/reddwarf/reddwarf.conf.test
@@ -106,7 +106,7 @@ use = call:reddwarf.common.wsgi:versioned_urlmap
paste.app_factory = reddwarf.versions:app_factory
[pipeline:reddwarfapi]
-pipeline = faultwrapper tokenauth authorization contextwrapper extensions reddwarfapp
+pipeline = faultwrapper tokenauth authorization contextwrapper ratelimit extensions reddwarfapp
# pipeline = debug reddwarfapp
[filter:extensions]
@@ -132,6 +132,9 @@ paste.filter_factory = reddwarf.common.wsgi:ContextMiddleware.factory
[filter:faultwrapper]
paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory
+[filter:ratelimit]
+paste.filter_factory = reddwarf.common.limits:RateLimitingMiddleware.factory
+
[app:reddwarfapp]
paste.app_factory = reddwarf.common.api:app_factory
diff --git a/etc/tests/localhost.test.conf b/etc/tests/localhost.test.conf
index cae44a17..d9efcd27 100644
--- a/etc/tests/localhost.test.conf
+++ b/etc/tests/localhost.test.conf
@@ -45,6 +45,24 @@
"is_admin":false,
"services": ["reddwarf"]
}
+ },
+ {
+ "auth_user":"rate_limit",
+ "auth_key":"password",
+ "tenant":"4000",
+ "requirements": {
+ "is_admin":false,
+ "services": ["reddwarf"]
+ }
+ },
+ {
+ "auth_user":"rate_limit_exceeded",
+ "auth_key":"password",
+ "tenant":"4050",
+ "requirements": {
+ "is_admin":false,
+ "services": ["reddwarf"]
+ }
}
],
diff --git a/reddwarf/common/api.py b/reddwarf/common/api.py
index cb80ace4..3bc2ab50 100644
--- a/reddwarf/common/api.py
+++ b/reddwarf/common/api.py
@@ -19,6 +19,7 @@ from reddwarf.common import wsgi
from reddwarf.extensions.mgmt.host.instance import service as hostservice
from reddwarf.flavor.service import FlavorController
from reddwarf.instance.service import InstanceController
+from reddwarf.limits.service import LimitsController
from reddwarf.openstack.common import log as logging
from reddwarf.openstack.common import rpc
from reddwarf.versions import VersionsController
@@ -32,6 +33,7 @@ class API(wsgi.Router):
self._instance_router(mapper)
self._flavor_router(mapper)
self._versions_router(mapper)
+ self._limits_router(mapper)
def _versions_router(self, mapper):
versions_resource = VersionsController().create_resource()
@@ -48,6 +50,11 @@ class API(wsgi.Router):
path = "/{tenant_id}/flavors"
mapper.resource("flavor", path, controller=flavor_resource)
+ def _limits_router(self, mapper):
+ limits_resource = LimitsController().create_resource()
+ path = "/{tenant_id}/limits"
+ mapper.resource("limits", path, controller=limits_resource)
+
def app_factory(global_conf, **local_conf):
return API()
diff --git a/reddwarf/common/cfg.py b/reddwarf/common/cfg.py
index 937ba750..f6523a1d 100644
--- a/reddwarf/common/cfg.py
+++ b/reddwarf/common/cfg.py
@@ -100,6 +100,10 @@ common_opts = [
cfg.IntOpt('revert_time_out', default=60 * 10),
cfg.ListOpt('root_grant', default=['ALL']),
cfg.BoolOpt('root_grant_option', default=True),
+ cfg.IntOpt('http_get_rate', default=200),
+ cfg.IntOpt('http_post_rate', default=200),
+ cfg.IntOpt('http_delete_rate', default=200),
+ cfg.IntOpt('http_put_rate', default=200),
]
diff --git a/reddwarf/common/limits.py b/reddwarf/common/limits.py
new file mode 100644
index 00000000..bfe233f6
--- /dev/null
+++ b/reddwarf/common/limits.py
@@ -0,0 +1,464 @@
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Module dedicated functions/classes dealing with rate limiting requests.
+"""
+
+import collections
+import copy
+import httplib
+import math
+import re
+import time
+import webob.dec
+import webob.exc
+import xmlutil
+
+from reddwarf.common import cfg
+from reddwarf.common import wsgi as base_wsgi
+from reddwarf.openstack.common import importutils
+from reddwarf.openstack.common import jsonutils
+from reddwarf.openstack.common import wsgi
+from reddwarf.openstack.common.gettextutils import _
+
+#
+# TODO: come back to this later
+# Dan Nguyen
+#
+#from nova import quota
+#QUOTAS = quota.QUOTAS
+
+
+CONF = cfg.CONF
+
+# Convenience constants for the limits dictionary passed to Limiter().
+PER_SECOND = 1
+PER_MINUTE = 60
+PER_HOUR = 60 * 60
+PER_DAY = 60 * 60 * 24
+
+
+limits_nsmap = {None: xmlutil.XMLNS_COMMON_V10, 'atom': xmlutil.XMLNS_ATOM}
+
+
+class LimitsTemplate(xmlutil.TemplateBuilder):
+ def construct(self):
+ root = xmlutil.TemplateElement('limits', selector='limits')
+
+ rates = xmlutil.SubTemplateElement(root, 'rates')
+ rate = xmlutil.SubTemplateElement(rates, 'rate', selector='rate')
+ rate.set('uri', 'uri')
+ rate.set('regex', 'regex')
+ limit = xmlutil.SubTemplateElement(rate, 'limit', selector='limit')
+ limit.set('value', 'value')
+ limit.set('verb', 'verb')
+ limit.set('remaining', 'remaining')
+ limit.set('unit', 'unit')
+ limit.set('next-available', 'next-available')
+
+ absolute = xmlutil.SubTemplateElement(root, 'absolute',
+ selector='absolute')
+ limit = xmlutil.SubTemplateElement(absolute, 'limit',
+ selector=xmlutil.get_items)
+ limit.set('name', 0)
+ limit.set('value', 1)
+
+ return xmlutil.MasterTemplate(root, 1, nsmap=limits_nsmap)
+
+
+class Limit(object):
+ """
+ Stores information about a limit for HTTP requests.
+ """
+
+ UNITS = {
+ 1: "SECOND",
+ 60: "MINUTE",
+ 60 * 60: "HOUR",
+ 60 * 60 * 24: "DAY",
+ }
+
+ UNIT_MAP = dict([(v, k) for k, v in UNITS.items()])
+
+ def __init__(self, verb, uri, regex, value, unit):
+ """
+ Initialize a new `Limit`.
+
+ @param verb: HTTP verb (POST, PUT, etc.)
+ @param uri: Human-readable URI
+ @param regex: Regular expression format for this limit
+ @param value: Integer number of requests which can be made
+ @param unit: Unit of measure for the value parameter
+ """
+ self.verb = verb
+ self.uri = uri
+ self.regex = regex
+ self.value = int(value)
+ self.unit = unit
+ self.unit_string = self.display_unit().lower()
+ self.remaining = int(value)
+
+ if value <= 0:
+ raise ValueError("Limit value must be > 0")
+
+ self.last_request = None
+ self.next_request = None
+
+ self.water_level = 0
+ self.capacity = self.unit
+ self.request_value = float(self.capacity) / float(self.value)
+ msg = _("Only %(value)s %(verb)s request(s) can be "
+ "made to %(uri)s every %(unit_string)s.")
+ self.error_message = msg % self.__dict__
+
+ def __call__(self, verb, url):
+ """
+ Represents a call to this limit from a relevant request.
+
+ @param verb: string http verb (POST, GET, etc.)
+ @param url: string URL
+ """
+ if self.verb != verb or not re.match(self.regex, url):
+ return
+
+ now = self._get_time()
+
+ if self.last_request is None:
+ self.last_request = now
+
+ leak_value = now - self.last_request
+
+ self.water_level -= leak_value
+ self.water_level = max(self.water_level, 0)
+ self.water_level += self.request_value
+
+ difference = self.water_level - self.capacity
+
+ self.last_request = now
+
+ if difference > 0:
+ self.water_level -= self.request_value
+ self.next_request = now + difference
+ return difference
+
+ cap = self.capacity
+ water = self.water_level
+ val = self.value
+
+ self.remaining = math.floor(((cap - water) / cap) * val)
+ self.next_request = now
+
+ def _get_time(self):
+ """Retrieve the current time. Broken out for testability."""
+ return time.time()
+
+ def display_unit(self):
+ """Display the string name of the unit."""
+ return self.UNITS.get(self.unit, "UNKNOWN")
+
+ def display(self):
+ """Return a useful representation of this class."""
+ return {
+ "verb": self.verb,
+ "URI": self.uri,
+ "regex": self.regex,
+ "value": self.value,
+ "remaining": int(self.remaining),
+ "unit": self.display_unit(),
+ "resetTime": int(self.next_request or self._get_time()),
+ }
+
+# "Limit" format is a dictionary with the HTTP verb, human-readable URI,
+# a regular-expression to match, value and unit of measure (PER_DAY, etc.)
+DEFAULT_LIMITS = [
+ Limit("POST", "*", ".*", CONF.http_post_rate, PER_MINUTE),
+ Limit("PUT", "*", ".*", CONF.http_put_rate, PER_MINUTE),
+ Limit("DELETE", "*", ".*", CONF.http_delete_rate, PER_MINUTE),
+ Limit("GET", "*", ".*", CONF.http_get_rate, PER_MINUTE),
+]
+
+
+class RateLimitingMiddleware(base_wsgi.ReddwarfMiddleware):
+ """
+ Rate-limits requests passing through this middleware. All limit information
+ is stored in memory for this implementation.
+ """
+
+ def __init__(self, application, limits=None, limiter=None, **kwargs):
+ """
+ Initialize new `RateLimitingMiddleware`, which wraps the given WSGI
+ application and sets up the given limits.
+
+ @param application: WSGI application to wrap
+ @param limits: String describing limits
+ @param limiter: String identifying class for representing limits
+
+ Other parameters are passed to the constructor for the limiter.
+ """
+ base_wsgi.Middleware.__init__(self, application)
+
+ # Select the limiter class
+ if limiter is None:
+ limiter = Limiter
+ else:
+ limiter = importutils.import_class(limiter)
+
+ # Parse the limits, if any are provided
+ if limits is not None:
+ limits = limiter.parse_limits(limits)
+
+ self._limiter = limiter(limits or DEFAULT_LIMITS, **kwargs)
+
+ @webob.dec.wsgify(RequestClass=wsgi.Request)
+ def __call__(self, req):
+ """
+ Represents a single call through this middleware. We should record the
+ request if we have a limit relevant to it. If no limit is relevant to
+ the request, ignore it.
+
+ If the request should be rate limited, return a fault telling the user
+ they are over the limit and need to retry later.
+ """
+ verb = req.method
+ url = req.url
+ context = req.environ.get(base_wsgi.CONTEXT_KEY)
+
+ tenant_id = None
+ if context:
+ tenant_id = context.tenant
+
+ delay, error = self._limiter.check_for_delay(verb, url, tenant_id)
+
+ if delay:
+ msg = _("This request was rate-limited.")
+ retry = time.time() + delay
+ return base_wsgi.OverLimitFault(msg, error, retry)
+
+ req.environ["reddwarf.limits"] = self._limiter.get_limits(tenant_id)
+
+ return self.application
+
+
+class Limiter(object):
+ """
+ Rate-limit checking class which handles limits in memory.
+ """
+
+ def __init__(self, limits, **kwargs):
+ """
+ Initialize the new `Limiter`.
+
+ @param limits: List of `Limit` objects
+ """
+ self.limits = copy.deepcopy(limits)
+ self.levels = collections.defaultdict(lambda: copy.deepcopy(limits))
+
+ # Pick up any per-user limit information
+ for key, value in kwargs.items():
+ if key.startswith('user:'):
+ username = key[5:]
+ self.levels[username] = self.parse_limits(value)
+
+ def get_limits(self, username=None):
+ """
+ Return the limits for a given user.
+ """
+ return [limit.display() for limit in self.levels[username]]
+
+ def check_for_delay(self, verb, url, username=None):
+ """
+ Check the given verb/user/user triplet for limit.
+
+ @return: Tuple of delay (in seconds) and error message (or None, None)
+ """
+ delays = []
+
+ for limit in self.levels[username]:
+ delay = limit(verb, url)
+ if delay:
+ delays.append((delay, limit.error_message))
+
+ if delays:
+ delays.sort()
+ return delays[0]
+
+ return None, None
+
+ # This was ported from nova.
+ # Keeping it as a static method for the sake of consistency
+ #
+ # Note: This method gets called before the class is instantiated,
+ # so this must be either a static method or a class method. It is
+ # used to develop a list of limits to feed to the constructor. We
+ # put this in the class so that subclasses can override the
+ # default limit parsing.
+ @staticmethod
+ def parse_limits(limits):
+ """
+ Convert a string into a list of Limit instances. This
+ implementation expects a semicolon-separated sequence of
+ parenthesized groups, where each group contains a
+ comma-separated sequence consisting of HTTP method,
+ user-readable URI, a URI reg-exp, an integer number of
+ requests which can be made, and a unit of measure. Valid
+ values for the latter are "SECOND", "MINUTE", "HOUR", and
+ "DAY".
+
+ @return: List of Limit instances.
+ """
+
+ # Handle empty limit strings
+ limits = limits.strip()
+ if not limits:
+ return []
+
+ # Split up the limits by semicolon
+ result = []
+ for group in limits.split(';'):
+ group = group.strip()
+ if group[:1] != '(' or group[-1:] != ')':
+ raise ValueError("Limit rules must be surrounded by "
+ "parentheses")
+ group = group[1:-1]
+
+ # Extract the Limit arguments
+ args = [a.strip() for a in group.split(',')]
+ if len(args) != 5:
+ raise ValueError("Limit rules must contain the following "
+ "arguments: verb, uri, regex, value, unit")
+
+ # Pull out the arguments
+ verb, uri, regex, value, unit = args
+
+ # Upper-case the verb
+ verb = verb.upper()
+
+ # Convert value--raises ValueError if it's not integer
+ value = int(value)
+
+ # Convert unit
+ unit = unit.upper()
+ if unit not in Limit.UNIT_MAP:
+ raise ValueError("Invalid units specified")
+ unit = Limit.UNIT_MAP[unit]
+
+ # Build a limit
+ result.append(Limit(verb, uri, regex, value, unit))
+
+ return result
+
+
+class WsgiLimiter(object):
+ """
+ Rate-limit checking from a WSGI application. Uses an in-memory `Limiter`.
+
+ To use, POST ``/<username>`` with JSON data such as::
+
+ {
+ "verb" : GET,
+ "path" : "/servers"
+ }
+
+ and receive a 204 No Content, or a 403 Forbidden with an X-Wait-Seconds
+ header containing the number of seconds to wait before the action would
+ succeed.
+ """
+
+ def __init__(self, limits=None):
+ """
+ Initialize the new `WsgiLimiter`.
+
+ @param limits: List of `Limit` objects
+ """
+ self._limiter = Limiter(limits or DEFAULT_LIMITS)
+
+ @webob.dec.wsgify(RequestClass=wsgi.Request)
+ def __call__(self, request):
+ """
+ Handles a call to this application. Returns 204 if the request is
+ acceptable to the limiter, else a 403 is returned with a relevant
+ header indicating when the request *will* succeed.
+ """
+ if request.method != "POST":
+ raise webob.exc.HTTPMethodNotAllowed()
+
+ try:
+ info = dict(jsonutils.loads(request.body))
+ except ValueError:
+ raise webob.exc.HTTPBadRequest()
+
+ username = request.path_info_pop()
+ verb = info.get("verb")
+ path = info.get("path")
+
+ delay, error = self._limiter.check_for_delay(verb, path, username)
+
+ if delay:
+ headers = {"X-Wait-Seconds": "%.2f" % delay}
+ return webob.exc.HTTPForbidden(headers=headers, explanation=error)
+ else:
+ return webob.exc.HTTPNoContent()
+
+
+class WsgiLimiterProxy(object):
+ """
+ Rate-limit requests based on answers from a remote source.
+ """
+
+ def __init__(self, limiter_address):
+ """
+ Initialize the new `WsgiLimiterProxy`.
+
+ @param limiter_address: IP/port combination of where to request limit
+ """
+ self.limiter_address = limiter_address
+
+ def check_for_delay(self, verb, path, username=None):
+ body = jsonutils.dumps({"verb": verb, "path": path})
+ headers = {"Content-Type": "application/json"}
+
+ conn = httplib.HTTPConnection(self.limiter_address)
+
+ if username:
+ conn.request("POST", "/%s" % (username), body, headers)
+ else:
+ conn.request("POST", "/", body, headers)
+
+ resp = conn.getresponse()
+
+ if 200 >= resp.status < 300:
+ return None, None
+
+ return resp.getheader("X-Wait-Seconds"), resp.read() or None
+
+ # This was ported from nova.
+ # Keeping it as a static method for the sake of consistency
+ #
+ # Note: This method gets called before the class is instantiated,
+ # so this must be either a static method or a class method. It is
+ # used to develop a list of limits to feed to the constructor.
+ # This implementation returns an empty list, since all limit
+ # decisions are made by a remote server.
+ @staticmethod
+ def parse_limits(limits):
+ """
+ Ignore a limits string--simply doesn't apply for the limit
+ proxy.
+
+ @return: Empty list.
+ """
+
+ return []
diff --git a/reddwarf/common/schemas/atom-link.rng b/reddwarf/common/schemas/atom-link.rng
new file mode 100644
index 00000000..edba5eee
--- /dev/null
+++ b/reddwarf/common/schemas/atom-link.rng
@@ -0,0 +1,141 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ -*- rnc -*-
+ RELAX NG Compact Syntax Grammar for the
+ Atom Format Specification Version 11
+-->
+<grammar xmlns:xhtml="http://www.w3.org/1999/xhtml" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:s="http://www.ascc.net/xml/schematron" xmlns="http://relaxng.org/ns/structure/1.0" datatypeLibrary="http://www.w3.org/2001/XMLSchema-datatypes">
+ <start>
+ <choice>
+ <ref name="atomLink"/>
+ </choice>
+ </start>
+ <!-- Common attributes -->
+ <define name="atomCommonAttributes">
+ <optional>
+ <attribute name="xml:base">
+ <ref name="atomUri"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="xml:lang">
+ <ref name="atomLanguageTag"/>
+ </attribute>
+ </optional>
+ <zeroOrMore>
+ <ref name="undefinedAttribute"/>
+ </zeroOrMore>
+ </define>
+ <!-- atom:link -->
+ <define name="atomLink">
+ <element name="atom:link">
+ <ref name="atomCommonAttributes"/>
+ <attribute name="href">
+ <ref name="atomUri"/>
+ </attribute>
+ <optional>
+ <attribute name="rel">
+ <choice>
+ <ref name="atomNCName"/>
+ <ref name="atomUri"/>
+ </choice>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="type">
+ <ref name="atomMediaType"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="hreflang">
+ <ref name="atomLanguageTag"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="title"/>
+ </optional>
+ <optional>
+ <attribute name="length"/>
+ </optional>
+ <ref name="undefinedContent"/>
+ </element>
+ </define>
+ <!-- Low-level simple types -->
+ <define name="atomNCName">
+ <data type="string">
+ <param name="minLength">1</param>
+ <param name="pattern">[^:]*</param>
+ </data>
+ </define>
+ <!-- Whatever a media type is, it contains at least one slash -->
+ <define name="atomMediaType">
+ <data type="string">
+ <param name="pattern">.+/.+</param>
+ </data>
+ </define>
+ <!-- As defined in RFC 3066 -->
+ <define name="atomLanguageTag">
+ <data type="string">
+ <param name="pattern">[A-Za-z]{1,8}(-[A-Za-z0-9]{1,8})*</param>
+ </data>
+ </define>
+ <!--
+ Unconstrained; it's not entirely clear how IRI fit into
+ xsd:anyURI so let's not try to constrain it here
+ -->
+ <define name="atomUri">
+ <text/>
+ </define>
+ <!-- Other Extensibility -->
+ <define name="undefinedAttribute">
+ <attribute>
+ <anyName>
+ <except>
+ <name>xml:base</name>
+ <name>xml:lang</name>
+ <nsName ns=""/>
+ </except>
+ </anyName>
+ </attribute>
+ </define>
+ <define name="undefinedContent">
+ <zeroOrMore>
+ <choice>
+ <text/>
+ <ref name="anyForeignElement"/>
+ </choice>
+ </zeroOrMore>
+ </define>
+ <define name="anyElement">
+ <element>
+ <anyName/>
+ <zeroOrMore>
+ <choice>
+ <attribute>
+ <anyName/>
+ </attribute>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+ <define name="anyForeignElement">
+ <element>
+ <anyName>
+ <except>
+ <nsName ns="http://www.w3.org/2005/Atom"/>
+ </except>
+ </anyName>
+ <zeroOrMore>
+ <choice>
+ <attribute>
+ <anyName/>
+ </attribute>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+</grammar>
diff --git a/reddwarf/common/schemas/atom.rng b/reddwarf/common/schemas/atom.rng
new file mode 100644
index 00000000..c2df4e41
--- /dev/null
+++ b/reddwarf/common/schemas/atom.rng
@@ -0,0 +1,597 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ -*- rnc -*-
+ RELAX NG Compact Syntax Grammar for the
+ Atom Format Specification Version 11
+-->
+<grammar xmlns:xhtml="http://www.w3.org/1999/xhtml" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:s="http://www.ascc.net/xml/schematron" xmlns="http://relaxng.org/ns/structure/1.0" datatypeLibrary="http://www.w3.org/2001/XMLSchema-datatypes">
+ <start>
+ <choice>
+ <ref name="atomFeed"/>
+ <ref name="atomEntry"/>
+ </choice>
+ </start>
+ <!-- Common attributes -->
+ <define name="atomCommonAttributes">
+ <optional>
+ <attribute name="xml:base">
+ <ref name="atomUri"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="xml:lang">
+ <ref name="atomLanguageTag"/>
+ </attribute>
+ </optional>
+ <zeroOrMore>
+ <ref name="undefinedAttribute"/>
+ </zeroOrMore>
+ </define>
+ <!-- Text Constructs -->
+ <define name="atomPlainTextConstruct">
+ <ref name="atomCommonAttributes"/>
+ <optional>
+ <attribute name="type">
+ <choice>
+ <value>text</value>
+ <value>html</value>
+ </choice>
+ </attribute>
+ </optional>
+ <text/>
+ </define>
+ <define name="atomXHTMLTextConstruct">
+ <ref name="atomCommonAttributes"/>
+ <attribute name="type">
+ <value>xhtml</value>
+ </attribute>
+ <ref name="xhtmlDiv"/>
+ </define>
+ <define name="atomTextConstruct">
+ <choice>
+ <ref name="atomPlainTextConstruct"/>
+ <ref name="atomXHTMLTextConstruct"/>
+ </choice>
+ </define>
+ <!-- Person Construct -->
+ <define name="atomPersonConstruct">
+ <ref name="atomCommonAttributes"/>
+ <interleave>
+ <element name="atom:name">
+ <text/>
+ </element>
+ <optional>
+ <element name="atom:uri">
+ <ref name="atomUri"/>
+ </element>
+ </optional>
+ <optional>
+ <element name="atom:email">
+ <ref name="atomEmailAddress"/>
+ </element>
+ </optional>
+ <zeroOrMore>
+ <ref name="extensionElement"/>
+ </zeroOrMore>
+ </interleave>
+ </define>
+ <!-- Date Construct -->
+ <define name="atomDateConstruct">
+ <ref name="atomCommonAttributes"/>
+ <data type="dateTime"/>
+ </define>
+ <!-- atom:feed -->
+ <define name="atomFeed">
+ <element name="atom:feed">
+ <s:rule context="atom:feed">
+ <s:assert test="atom:author or not(atom:entry[not(atom:author)])">An atom:feed must have an atom:author unless all of its atom:entry children have an atom:author.</s:assert>
+ </s:rule>
+ <ref name="atomCommonAttributes"/>
+ <interleave>
+ <zeroOrMore>
+ <ref name="atomAuthor"/>
+ </zeroOrMore>
+ <zeroOrMore>
+ <ref name="atomCategory"/>
+ </zeroOrMore>
+ <zeroOrMore>
+ <ref name="atomContributor"/>
+ </zeroOrMore>
+ <optional>
+ <ref name="atomGenerator"/>
+ </optional>
+ <optional>
+ <ref name="atomIcon"/>
+ </optional>
+ <ref name="atomId"/>
+ <zeroOrMore>
+ <ref name="atomLink"/>
+ </zeroOrMore>
+ <optional>
+ <ref name="atomLogo"/>
+ </optional>
+ <optional>
+ <ref name="atomRights"/>
+ </optional>
+ <optional>
+ <ref name="atomSubtitle"/>
+ </optional>
+ <ref name="atomTitle"/>
+ <ref name="atomUpdated"/>
+ <zeroOrMore>
+ <ref name="extensionElement"/>
+ </zeroOrMore>
+ </interleave>
+ <zeroOrMore>
+ <ref name="atomEntry"/>
+ </zeroOrMore>
+ </element>
+ </define>
+ <!-- atom:entry -->
+ <define name="atomEntry">
+ <element name="atom:entry">
+ <s:rule context="atom:entry">
+ <s:assert test="atom:link[@rel='alternate'] or atom:link[not(@rel)] or atom:content">An atom:entry must have at least one atom:link element with a rel attribute of 'alternate' or an atom:content.</s:assert>
+ </s:rule>
+ <s:rule context="atom:entry">
+ <s:assert test="atom:author or ../atom:author or atom:source/atom:author">An atom:entry must have an atom:author if its feed does not.</s:assert>
+ </s:rule>
+ <ref name="atomCommonAttributes"/>
+ <interleave>
+ <zeroOrMore>
+ <ref name="atomAuthor"/>
+ </zeroOrMore>
+ <zeroOrMore>
+ <ref name="atomCategory"/>
+ </zeroOrMore>
+ <optional>
+ <ref name="atomContent"/>
+ </optional>
+ <zeroOrMore>
+ <ref name="atomContributor"/>
+ </zeroOrMore>
+ <ref name="atomId"/>
+ <zeroOrMore>
+ <ref name="atomLink"/>
+ </zeroOrMore>
+ <optional>
+ <ref name="atomPublished"/>
+ </optional>
+ <optional>
+ <ref name="atomRights"/>
+ </optional>
+ <optional>
+ <ref name="atomSource"/>
+ </optional>
+ <optional>
+ <ref name="atomSummary"/>
+ </optional>
+ <ref name="atomTitle"/>
+ <ref name="atomUpdated"/>
+ <zeroOrMore>
+ <ref name="extensionElement"/>
+ </zeroOrMore>
+ </interleave>
+ </element>
+ </define>
+ <!-- atom:content -->
+ <define name="atomInlineTextContent">
+ <element name="atom:content">
+ <ref name="atomCommonAttributes"/>
+ <optional>
+ <attribute name="type">
+ <choice>
+ <value>text</value>
+ <value>html</value>
+ </choice>
+ </attribute>
+ </optional>
+ <zeroOrMore>
+ <text/>
+ </zeroOrMore>
+ </element>
+ </define>
+ <define name="atomInlineXHTMLContent">
+ <element name="atom:content">
+ <ref name="atomCommonAttributes"/>
+ <attribute name="type">
+ <value>xhtml</value>
+ </attribute>
+ <ref name="xhtmlDiv"/>
+ </element>
+ </define>
+ <define name="atomInlineOtherContent">
+ <element name="atom:content">
+ <ref name="atomCommonAttributes"/>
+ <optional>
+ <attribute name="type">
+ <ref name="atomMediaType"/>
+ </attribute>
+ </optional>
+ <zeroOrMore>
+ <choice>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+ <define name="atomOutOfLineContent">
+ <element name="atom:content">
+ <ref name="atomCommonAttributes"/>
+ <optional>
+ <attribute name="type">
+ <ref name="atomMediaType"/>
+ </attribute>
+ </optional>
+ <attribute name="src">
+ <ref name="atomUri"/>
+ </attribute>
+ <empty/>
+ </element>
+ </define>
+ <define name="atomContent">
+ <choice>
+ <ref name="atomInlineTextContent"/>
+ <ref name="atomInlineXHTMLContent"/>
+ <ref name="atomInlineOtherContent"/>
+ <ref name="atomOutOfLineContent"/>
+ </choice>
+ </define>
+ <!-- atom:author -->
+ <define name="atomAuthor">
+ <element name="atom:author">
+ <ref name="atomPersonConstruct"/>
+ </element>
+ </define>
+ <!-- atom:category -->
+ <define name="atomCategory">
+ <element name="atom:category">
+ <ref name="atomCommonAttributes"/>
+ <attribute name="term"/>
+ <optional>
+ <attribute name="scheme">
+ <ref name="atomUri"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="label"/>
+ </optional>
+ <ref name="undefinedContent"/>
+ </element>
+ </define>
+ <!-- atom:contributor -->
+ <define name="atomContributor">
+ <element name="atom:contributor">
+ <ref name="atomPersonConstruct"/>
+ </element>
+ </define>
+ <!-- atom:generator -->
+ <define name="atomGenerator">
+ <element name="atom:generator">
+ <ref name="atomCommonAttributes"/>
+ <optional>
+ <attribute name="uri">
+ <ref name="atomUri"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="version"/>
+ </optional>
+ <text/>
+ </element>
+ </define>
+ <!-- atom:icon -->
+ <define name="atomIcon">
+ <element name="atom:icon">
+ <ref name="atomCommonAttributes"/>
+ <ref name="atomUri"/>
+ </element>
+ </define>
+ <!-- atom:id -->
+ <define name="atomId">
+ <element name="atom:id">
+ <ref name="atomCommonAttributes"/>
+ <ref name="atomUri"/>
+ </element>
+ </define>
+ <!-- atom:logo -->
+ <define name="atomLogo">
+ <element name="atom:logo">
+ <ref name="atomCommonAttributes"/>
+ <ref name="atomUri"/>
+ </element>
+ </define>
+ <!-- atom:link -->
+ <define name="atomLink">
+ <element name="atom:link">
+ <ref name="atomCommonAttributes"/>
+ <attribute name="href">
+ <ref name="atomUri"/>
+ </attribute>
+ <optional>
+ <attribute name="rel">
+ <choice>
+ <ref name="atomNCName"/>
+ <ref name="atomUri"/>
+ </choice>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="type">
+ <ref name="atomMediaType"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="hreflang">
+ <ref name="atomLanguageTag"/>
+ </attribute>
+ </optional>
+ <optional>
+ <attribute name="title"/>
+ </optional>
+ <optional>
+ <attribute name="length"/>
+ </optional>
+ <ref name="undefinedContent"/>
+ </element>
+ </define>
+ <!-- atom:published -->
+ <define name="atomPublished">
+ <element name="atom:published">
+ <ref name="atomDateConstruct"/>
+ </element>
+ </define>
+ <!-- atom:rights -->
+ <define name="atomRights">
+ <element name="atom:rights">
+ <ref name="atomTextConstruct"/>
+ </element>
+ </define>
+ <!-- atom:source -->
+ <define name="atomSource">
+ <element name="atom:source">
+ <ref name="atomCommonAttributes"/>
+ <interleave>
+ <zeroOrMore>
+ <ref name="atomAuthor"/>
+ </zeroOrMore>
+ <zeroOrMore>
+ <ref name="atomCategory"/>
+ </zeroOrMore>
+ <zeroOrMore>
+ <ref name="atomContributor"/>
+ </zeroOrMore>
+ <optional>
+ <ref name="atomGenerator"/>
+ </optional>
+ <optional>
+ <ref name="atomIcon"/>
+ </optional>
+ <optional>
+ <ref name="atomId"/>
+ </optional>
+ <zeroOrMore>
+ <ref name="atomLink"/>
+ </zeroOrMore>
+ <optional>
+ <ref name="atomLogo"/>
+ </optional>
+ <optional>
+ <ref name="atomRights"/>
+ </optional>
+ <optional>
+ <ref name="atomSubtitle"/>
+ </optional>
+ <optional>
+ <ref name="atomTitle"/>
+ </optional>
+ <optional>
+ <ref name="atomUpdated"/>
+ </optional>
+ <zeroOrMore>
+ <ref name="extensionElement"/>
+ </zeroOrMore>
+ </interleave>
+ </element>
+ </define>
+ <!-- atom:subtitle -->
+ <define name="atomSubtitle">
+ <element name="atom:subtitle">
+ <ref name="atomTextConstruct"/>
+ </element>
+ </define>
+ <!-- atom:summary -->
+ <define name="atomSummary">
+ <element name="atom:summary">
+ <ref name="atomTextConstruct"/>
+ </element>
+ </define>
+ <!-- atom:title -->
+ <define name="atomTitle">
+ <element name="atom:title">
+ <ref name="atomTextConstruct"/>
+ </element>
+ </define>
+ <!-- atom:updated -->
+ <define name="atomUpdated">
+ <element name="atom:updated">
+ <ref name="atomDateConstruct"/>
+ </element>
+ </define>
+ <!-- Low-level simple types -->
+ <define name="atomNCName">
+ <data type="string">
+ <param name="minLength">1</param>
+ <param name="pattern">[^:]*</param>
+ </data>
+ </define>
+ <!-- Whatever a media type is, it contains at least one slash -->
+ <define name="atomMediaType">
+ <data type="string">
+ <param name="pattern">.+/.+</param>
+ </data>
+ </define>
+ <!-- As defined in RFC 3066 -->
+ <define name="atomLanguageTag">
+ <data type="string">
+ <param name="pattern">[A-Za-z]{1,8}(-[A-Za-z0-9]{1,8})*</param>
+ </data>
+ </define>
+ <!--
+ Unconstrained; it's not entirely clear how IRI fit into
+ xsd:anyURI so let's not try to constrain it here
+ -->
+ <define name="atomUri">
+ <text/>
+ </define>
+ <!-- Whatever an email address is, it contains at least one @ -->
+ <define name="atomEmailAddress">
+ <data type="string">
+ <param name="pattern">.+@.+</param>
+ </data>
+ </define>
+ <!-- Simple Extension -->
+ <define name="simpleExtensionElement">
+ <element>
+ <anyName>
+ <except>
+ <nsName ns="http://www.w3.org/2005/Atom"/>
+ </except>
+ </anyName>
+ <text/>
+ </element>
+ </define>
+ <!-- Structured Extension -->
+ <define name="structuredExtensionElement">
+ <element>
+ <anyName>
+ <except>
+ <nsName ns="http://www.w3.org/2005/Atom"/>
+ </except>
+ </anyName>
+ <choice>
+ <group>
+ <oneOrMore>
+ <attribute>
+ <anyName/>
+ </attribute>
+ </oneOrMore>
+ <zeroOrMore>
+ <choice>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </group>
+ <group>
+ <zeroOrMore>
+ <attribute>
+ <anyName/>
+ </attribute>
+ </zeroOrMore>
+ <group>
+ <optional>
+ <text/>
+ </optional>
+ <oneOrMore>
+ <ref name="anyElement"/>
+ </oneOrMore>
+ <zeroOrMore>
+ <choice>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </group>
+ </group>
+ </choice>
+ </element>
+ </define>
+ <!-- Other Extensibility -->
+ <define name="extensionElement">
+ <choice>
+ <ref name="simpleExtensionElement"/>
+ <ref name="structuredExtensionElement"/>
+ </choice>
+ </define>
+ <define name="undefinedAttribute">
+ <attribute>
+ <anyName>
+ <except>
+ <name>xml:base</name>
+ <name>xml:lang</name>
+ <nsName ns=""/>
+ </except>
+ </anyName>
+ </attribute>
+ </define>
+ <define name="undefinedContent">
+ <zeroOrMore>
+ <choice>
+ <text/>
+ <ref name="anyForeignElement"/>
+ </choice>
+ </zeroOrMore>
+ </define>
+ <define name="anyElement">
+ <element>
+ <anyName/>
+ <zeroOrMore>
+ <choice>
+ <attribute>
+ <anyName/>
+ </attribute>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+ <define name="anyForeignElement">
+ <element>
+ <anyName>
+ <except>
+ <nsName ns="http://www.w3.org/2005/Atom"/>
+ </except>
+ </anyName>
+ <zeroOrMore>
+ <choice>
+ <attribute>
+ <anyName/>
+ </attribute>
+ <text/>
+ <ref name="anyElement"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+ <!-- XHTML -->
+ <define name="anyXHTML">
+ <element>
+ <nsName ns="http://www.w3.org/1999/xhtml"/>
+ <zeroOrMore>
+ <choice>
+ <attribute>
+ <anyName/>
+ </attribute>
+ <text/>
+ <ref name="anyXHTML"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+ <define name="xhtmlDiv">
+ <element name="xhtml:div">
+ <zeroOrMore>
+ <choice>
+ <attribute>
+ <anyName/>
+ </attribute>
+ <text/>
+ <ref name="anyXHTML"/>
+ </choice>
+ </zeroOrMore>
+ </element>
+ </define>
+</grammar>
diff --git a/reddwarf/common/schemas/v1.1/limits.rng b/reddwarf/common/schemas/v1.1/limits.rng
new file mode 100644
index 00000000..a66af4b9
--- /dev/null
+++ b/reddwarf/common/schemas/v1.1/limits.rng
@@ -0,0 +1,28 @@
+<element name="limits" ns="http://docs.openstack.org/common/api/v1.0"
+ xmlns="http://relaxng.org/ns/structure/1.0">
+ <element name="rates">
+ <zeroOrMore>
+ <element name="rate">
+ <attribute name="uri"> <text/> </attribute>
+ <attribute name="regex"> <text/> </attribute>
+ <zeroOrMore>
+ <element name="limit">
+ <attribute name="value"> <text/> </attribute>
+ <attribute name="verb"> <text/> </attribute>
+ <attribute name="remaining"> <text/> </attribute>
+ <attribute name="unit"> <text/> </attribute>
+ <attribute name="next-available"> <text/> </attribute>
+ </element>
+ </zeroOrMore>
+ </element>
+ </zeroOrMore>
+ </element>
+ <element name="absolute">
+ <zeroOrMore>
+ <element name="limit">
+ <attribute name="name"> <text/> </attribute>
+ <attribute name="value"> <text/> </attribute>
+ </element>
+ </zeroOrMore>
+ </element>
+</element>
diff --git a/reddwarf/common/wsgi.py b/reddwarf/common/wsgi.py
index 673d6f38..d68562c7 100644
--- a/reddwarf/common/wsgi.py
+++ b/reddwarf/common/wsgi.py
@@ -17,12 +17,15 @@
"""Wsgi helper utilities for reddwarf"""
import eventlet.wsgi
+import math
import paste.urlmap
import re
+import time
import traceback
import webob
import webob.dec
import webob.exc
+from lxml import etree
from paste import deploy
from xml.dom import minidom
@@ -30,13 +33,14 @@ from reddwarf.common import context as rd_context
from reddwarf.common import exception
from reddwarf.common import utils
from reddwarf.openstack.common.gettextutils import _
+from reddwarf.openstack.common import jsonutils
+
from reddwarf.openstack.common import pastedeploy
from reddwarf.openstack.common import service
from reddwarf.openstack.common import wsgi as openstack_wsgi
from reddwarf.openstack.common import log as logging
from reddwarf.common import cfg
-
CONTEXT_KEY = 'reddwarf.context'
Router = openstack_wsgi.Router
Debug = openstack_wsgi.Debug
@@ -130,6 +134,54 @@ def launch(app_name, port, paste_config_file, data={},
return service.launch(server)
+# Note: taken from Nova
+def serializers(**serializers):
+ """Attaches serializers to a method.
+
+ This decorator associates a dictionary of serializers with a
+ method. Note that the function attributes are directly
+ manipulated; the method is not wrapped.
+ """
+
+ def decorator(func):
+ if not hasattr(func, 'wsgi_serializers'):
+ func.wsgi_serializers = {}
+ func.wsgi_serializers.update(serializers)
+ return func
+ return decorator
+
+
+class ReddwarfMiddleware(Middleware):
+
+ # Note: taken from nova
+ @classmethod
+ def factory(cls, global_config, **local_config):
+ """Used for paste app factories in paste.deploy config files.
+
+ Any local configuration (that is, values under the [filter:APPNAME]
+ section of the paste config) will be passed into the `__init__` method
+ as kwargs.
+
+ A hypothetical configuration would look like:
+
+ [filter:analytics]
+ redis_host = 127.0.0.1
+ paste.filter_factory = nova.api.analytics:Analytics.factory
+
+ which would result in a call to the `Analytics` class as
+
+ import nova.api.analytics
+ analytics.Analytics(app_from_paste, redis_host='127.0.0.1')
+
+ You could of course re-implement the `factory` method in subclasses,
+ but using the kwarg passing it shouldn't be necessary.
+
+ """
+ def _factory(app):
+ return cls(app, **local_config)
+ return _factory
+
+
class VersionedURLMap(object):
def __init__(self, urlmap):
@@ -591,3 +643,186 @@ class FaultWrapper(openstack_wsgi.Middleware):
def _factory(app):
return cls(app)
return _factory
+
+
+# ported from Nova
+class OverLimitFault(webob.exc.HTTPException):
+ """
+ Rate-limited request response.
+ """
+
+ def __init__(self, message, details, retry_time):
+ """
+ Initialize new `OverLimitFault` with relevant information.
+ """
+ hdrs = OverLimitFault._retry_after(retry_time)
+ self.wrapped_exc = webob.exc.HTTPRequestEntityTooLarge(headers=hdrs)
+ self.content = {"overLimit": {"code": self.wrapped_exc.status_int,
+ "message": message,
+ "details": details,
+ "retryAfter": hdrs['Retry-After'],
+ },
+ }
+
+ @staticmethod
+ def _retry_after(retry_time):
+ delay = int(math.ceil(retry_time - time.time()))
+ retry_after = delay if delay > 0 else 0
+ headers = {'Retry-After': '%d' % retry_after}
+ return headers
+
+ @webob.dec.wsgify(RequestClass=Request)
+ def __call__(self, request):
+ """
+ Return the wrapped exception with a serialized body conforming to our
+ error format.
+ """
+ content_type = request.best_match_content_type()
+ metadata = {"attributes": {"overLimit": ["code", "retryAfter"]}}
+
+ xml_serializer = XMLDictSerializer(metadata, XMLNS)
+ serializer = {'application/xml': xml_serializer,
+ 'application/json': JSONDictSerializer(),
+ }[content_type]
+
+ content = serializer.serialize(self.content)
+ self.wrapped_exc.body = content
+ self.wrapped_exc.content_type = content_type
+
+ return self.wrapped_exc
+
+
+class ActionDispatcher(object):
+ """Maps method name to local methods through action name."""
+
+ def dispatch(self, *args, **kwargs):
+ """Find and call local method."""
+ action = kwargs.pop('action', 'default')
+ action_method = getattr(self, str(action), self.default)
+ return action_method(*args, **kwargs)
+
+ def default(self, data):
+ raise NotImplementedError()
+
+
+class DictSerializer(ActionDispatcher):
+ """Default request body serialization."""
+
+ def serialize(self, data, action='default'):
+ return self.dispatch(data, action=action)
+
+ def default(self, data):
+ return ""
+
+
+class JSONDictSerializer(DictSerializer):
+ """Default JSON request body serialization."""
+
+ def default(self, data):
+ return jsonutils.dumps(data)
+
+
+class XMLDictSerializer(DictSerializer):
+
+ def __init__(self, metadata=None, xmlns=None):
+ """
+ :param metadata: information needed to deserialize xml into
+ a dictionary.
+ :param xmlns: XML namespace to include with serialized xml
+ """
+ super(XMLDictSerializer, self).__init__()
+ self.metadata = metadata or {}
+ self.xmlns = xmlns
+
+ def default(self, data):
+ # We expect data to contain a single key which is the XML root.
+ root_key = data.keys()[0]
+ doc = minidom.Document()
+ node = self._to_xml_node(doc, self.metadata, root_key, data[root_key])
+
+ return self.to_xml_string(node)
+
+ def to_xml_string(self, node, has_atom=False):
+ self._add_xmlns(node, has_atom)
+ return node.toxml('UTF-8')
+
+ #NOTE (ameade): the has_atom should be removed after all of the
+ # xml serializers and view builders have been updated to the current
+ # spec that required all responses include the xmlns:atom, the has_atom
+ # flag is to prevent current tests from breaking
+ def _add_xmlns(self, node, has_atom=False):
+ if self.xmlns is not None:
+ node.setAttribute('xmlns', self.xmlns)
+ if has_atom:
+ node.setAttribute('xmlns:atom', "http://www.w3.org/2005/Atom")
+
+ def _to_xml_node(self, doc, metadata, nodename, data):
+ """Recursive method to convert data members to XML nodes."""
+ result = doc.createElement(nodename)
+
+ # Set the xml namespace if one is specified
+ # TODO(justinsb): We could also use prefixes on the keys
+ xmlns = metadata.get('xmlns', None)
+ if xmlns:
+ result.setAttribute('xmlns', xmlns)
+
+ #TODO(bcwaldon): accomplish this without a type-check
+ if isinstance(data, list):
+ collections = metadata.get('list_collections', {})
+ if nodename in collections:
+ metadata = collections[nodename]
+ for item in data:
+ node = doc.createElement(metadata['item_name'])
+ node.setAttribute(metadata['item_key'], str(item))
+ result.appendChild(node)
+ return result
+ singular = metadata.get('plurals', {}).get(nodename, None)
+ if singular is None:
+ if nodename.endswith('s'):
+ singular = nodename[:-1]
+ else:
+ singular = 'item'
+ for item in data:
+ node = self._to_xml_node(doc, metadata, singular, item)
+ result.appendChild(node)
+ #TODO(bcwaldon): accomplish this without a type-check
+ elif isinstance(data, dict):
+ collections = metadata.get('dict_collections', {})
+ if nodename in collections:
+ metadata = collections[nodename]
+ for k, v in data.items():
+ node = doc.createElement(metadata['item_name'])
+ node.setAttribute(metadata['item_key'], str(k))
+ text = doc.createTextNode(str(v))
+ node.appendChild(text)
+ result.appendChild(node)
+ return result
+ attrs = metadata.get('attributes', {}).get(nodename, {})
+ for k, v in data.items():
+ if k in attrs:
+ result.setAttribute(k, str(v))
+ else:
+ if k == "deleted":
+ v = str(bool(v))
+ node = self._to_xml_node(doc, metadata, k, v)
+ result.appendChild(node)
+ else:
+ # Type is atom
+ node = doc.createTextNode(str(data))
+ result.appendChild(node)
+ return result
+
+ def _create_link_nodes(self, xml_doc, links):
+ link_nodes = []
+ for link in links:
+ link_node = xml_doc.createElement('atom:link')
+ link_node.setAttribute('rel', link['rel'])
+ link_node.setAttribute('href', link['href'])
+ if 'type' in link:
+ link_node.setAttribute('type', link['type'])
+ link_nodes.append(link_node)
+ return link_nodes
+
+ def _to_xml(self, root):
+ """Convert the xml object to an xml string."""
+ return etree.tostring(root, encoding='UTF-8', xml_declaration=True)
diff --git a/reddwarf/common/xmlutil.py b/reddwarf/common/xmlutil.py
new file mode 100644
index 00000000..934da12c
--- /dev/null
+++ b/reddwarf/common/xmlutil.py
@@ -0,0 +1,910 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os.path
+
+from lxml import etree
+
+
+XMLNS_V10 = 'http://docs.rackspacecloud.com/servers/api/v1.0'
+XMLNS_V11 = 'http://docs.openstack.org/database/api/v1.1'
+XMLNS_COMMON_V10 = 'http://docs.openstack.org/common/api/v1.0'
+XMLNS_ATOM = 'http://www.w3.org/2005/Atom'
+
+
+def validate_schema(xml, schema_name):
+ if isinstance(xml, str):
+ xml = etree.fromstring(xml)
+ base_path = 'reddwarf/common/schemas/v1.1/'
+ if schema_name in ('atom', 'atom-link'):
+ base_path = 'reddwarf/common/schemas/'
+
+ # TODO: need to figure out our schema paths later
+ import reddwarf
+ schema_path = os.path.join(os.path.abspath(reddwarf.__file__)
+ .split('reddwarf/__init__.py')[0],
+ '%s%s.rng' % (base_path, schema_name))
+
+ schema_doc = etree.parse(schema_path)
+ relaxng = etree.RelaxNG(schema_doc)
+ relaxng.assertValid(xml)
+
+
+class Selector(object):
+ """Selects datum to operate on from an object."""
+
+ def __init__(self, *chain):
+ """Initialize the selector.
+
+ Each argument is a subsequent index into the object.
+ """
+
+ self.chain = chain
+
+ def __repr__(self):
+ """Return a representation of the selector."""
+
+ return "Selector" + repr(self.chain)
+
+ def __call__(self, obj, do_raise=False):
+ """Select a datum to operate on.
+
+ Selects the relevant datum within the object.
+
+ :param obj: The object from which to select the object.
+ :param do_raise: If False (the default), return None if the
+ indexed datum does not exist. Otherwise,
+ raise a KeyError.
+ """
+
+ # Walk the selector list
+ for elem in self.chain:
+ # If it's callable, call it
+ if callable(elem):
+ obj = elem(obj)
+ else:
+ # Use indexing
+ try:
+ obj = obj[elem]
+ except (KeyError, IndexError):
+ # No sense going any further
+ if do_raise:
+ # Convert to a KeyError, for consistency
+ raise KeyError(elem)
+ return None
+
+ # Return the finally-selected object
+ return obj
+
+
+def get_items(obj):
+ """Get items in obj."""
+
+ return list(obj.items())
+
+
+class EmptyStringSelector(Selector):
+ """Returns the empty string if Selector would return None."""
+ def __call__(self, obj, do_raise=False):
+ """Returns empty string if the selected value does not exist."""
+
+ try:
+ return super(EmptyStringSelector, self).__call__(obj, True)
+ except KeyError:
+ return ""
+
+
+class ConstantSelector(object):
+ """Returns a constant."""
+
+ def __init__(self, value):
+ """Initialize the selector.
+
+ :param value: The value to return.
+ """
+
+ self.value = value
+
+ def __repr__(self):
+ """Return a representation of the selector."""
+
+ return repr(self.value)
+
+ def __call__(self, _obj, _do_raise=False):
+ """Select a datum to operate on.
+
+ Returns a constant value. Compatible with
+ Selector.__call__().
+ """
+
+ return self.value
+
+
+class TemplateElement(object):
+ """Represent an element in the template."""
+
+ def __init__(self, tag, attrib=None, selector=None, subselector=None,
+ **extra):
+ """Initialize an element.
+
+ Initializes an element in the template. Keyword arguments
+ specify attributes to be set on the element; values must be
+ callables. See TemplateElement.set() for more information.
+
+ :param tag: The name of the tag to create.
+ :param attrib: An optional dictionary of element attributes.
+ :param selector: An optional callable taking an object and
+ optional boolean do_raise indicator and
+ returning the object bound to the element.
+ :param subselector: An optional callable taking an object and
+ optional boolean do_raise indicator and
+ returning the object bound to the element.
+ This is used to further refine the datum
+ object returned by selector in the event
+ that it is a list of objects.
+ """
+
+ # Convert selector into a Selector
+ if selector is None:
+ selector = Selector()
+ elif not callable(selector):
+ selector = Selector(selector)
+
+ # Convert subselector into a Selector
+ if subselector is not None and not callable(subselector):
+ subselector = Selector(subselector)
+
+ self.tag = tag
+ self.selector = selector
+ self.subselector = subselector
+ self.attrib = {}
+ self._text = None
+ self._children = []
+ self._childmap = {}
+
+ # Run the incoming attributes through set() so that they
+ # become selectorized
+ if not attrib:
+ attrib = {}
+ attrib.update(extra)
+ for k, v in attrib.items():
+ self.set(k, v)
+
+ def __repr__(self):
+ """Return a representation of the template element."""
+
+ return ('<%s.%s %r at %#x>' %
+ (self.__class__.__module__, self.__class__.__name__,
+ self.tag, id(self)))
+
+ def __len__(self):
+ """Return the number of child elements."""
+
+ return len(self._children)
+
+ def __contains__(self, key):
+ """Determine whether a child node named by key exists."""
+
+ return key in self._childmap
+
+ def __getitem__(self, idx):
+ """Retrieve a child node by index or name."""
+
+ if isinstance(idx, basestring):
+ # Allow access by node name
+ return self._childmap[idx]
+ else:
+ return self._children[idx]
+
+ def append(self, elem):
+ """Append a child to the element."""
+
+ # Unwrap templates...
+ elem = elem.unwrap()
+
+ # Avoid duplications
+ if elem.tag in self._childmap:
+ raise KeyError(elem.tag)
+
+ self._children.append(elem)
+ self._childmap[elem.tag] = elem
+
+ def extend(self, elems):
+ """Append children to the element."""
+
+ # Pre-evaluate the elements
+ elemmap = {}
+ elemlist = []
+ for elem in elems:
+ # Unwrap templates...
+ elem = elem.unwrap()
+
+ # Avoid duplications
+ if elem.tag in self._childmap or elem.tag in elemmap:
+ raise KeyError(elem.tag)
+
+ elemmap[elem.tag] = elem
+ elemlist.append(elem)
+
+ # Update the children
+ self._children.extend(elemlist)
+ self._childmap.update(elemmap)
+
+ def insert(self, idx, elem):
+ """Insert a child element at the given index."""
+
+ # Unwrap templates...
+ elem = elem.unwrap()
+
+ # Avoid duplications
+ if elem.tag in self._childmap:
+ raise KeyError(elem.tag)
+
+ self._children.insert(idx, elem)
+ self._childmap[elem.tag] = elem
+
+ def remove(self, elem):
+ """Remove a child element."""
+
+ # Unwrap templates...
+ elem = elem.unwrap()
+
+ # Check if element exists
+ if elem.tag not in self._childmap or self._childmap[elem.tag] != elem:
+ raise ValueError(_('element is not a child'))
+
+ self._children.remove(elem)
+ del self._childmap[elem.tag]
+
+ def get(self, key):
+ """Get an attribute.
+
+ Returns a callable which performs datum selection.
+
+ :param key: The name of the attribute to get.
+ """
+
+ return self.attrib[key]
+
+ def set(self, key, value=None):
+ """Set an attribute.
+
+ :param key: The name of the attribute to set.
+
+ :param value: A callable taking an object and optional boolean
+ do_raise indicator and returning the datum bound
+ to the attribute. If None, a Selector() will be
+ constructed from the key. If a string, a
+ Selector() will be constructed from the string.
+ """
+
+ # Convert value to a selector
+ if value is None:
+ value = Selector(key)
+ elif not callable(value):
+ value = Selector(value)
+
+ self.attrib[key] = value
+
+ def keys(self):
+ """Return the attribute names."""
+
+ return self.attrib.keys()
+
+ def items(self):
+ """Return the attribute names and values."""
+
+ return self.attrib.items()
+
+ def unwrap(self):
+ """Unwraps a template to return a template element."""
+
+ # We are a template element
+ return self
+
+ def wrap(self):
+ """Wraps a template element to return a template."""
+
+ # Wrap in a basic Template
+ return Template(self)
+
+ def apply(self, elem, obj):
+ """Apply text and attributes to an etree.Element.
+
+ Applies the text and attribute instructions in the template
+ element to an etree.Element instance.
+
+ :param elem: An etree.Element instance.
+ :param obj: The base object associated with this template
+ element.
+ """
+
+ # Start with the text...
+ if self.text is not None:
+ elem.text = unicode(self.text(obj))
+
+ # Now set up all the attributes...
+ for key, value in self.attrib.items():
+ try:
+ elem.set(key, unicode(value(obj, True)))
+ except KeyError:
+ # Attribute has no value, so don't include it
+ pass
+
+ def _render(self, parent, datum, patches, nsmap):
+ """Internal rendering.
+
+ Renders the template node into an etree.Element object.
+ Returns the etree.Element object.
+
+ :param parent: The parent etree.Element instance.
+ :param datum: The datum associated with this template element.
+ :param patches: A list of other template elements that must
+ also be applied.
+ :param nsmap: An optional namespace dictionary to be
+ associated with the etree.Element instance.
+ """
+
+ # Allocate a node
+ if callable(self.tag):
+ tagname = self.tag(datum)
+ else:
+ tagname = self.tag
+ elem = etree.Element(tagname, nsmap=nsmap)
+
+ # If we have a parent, append the node to the parent
+ if parent is not None:
+ parent.append(elem)
+
+ # If the datum is None, do nothing else
+ if datum is None:
+ return elem
+
+ # Apply this template element to the element
+ self.apply(elem, datum)
+
+ # Additionally, apply the patches
+ for patch in patches:
+ patch.apply(elem, datum)
+
+ # We have fully rendered the element; return it
+ return elem
+
+ def render(self, parent, obj, patches=[], nsmap=None):
+ """Render an object.
+
+ Renders an object against this template node. Returns a list
+ of two-item tuples, where the first item is an etree.Element
+ instance and the second item is the datum associated with that
+ instance.
+
+ :param parent: The parent for the etree.Element instances.
+ :param obj: The object to render this template element
+ against.
+ :param patches: A list of other template elements to apply
+ when rendering this template element.
+ :param nsmap: An optional namespace dictionary to attach to
+ the etree.Element instances.
+ """
+
+ # First, get the datum we're rendering
+ data = None if obj is None else self.selector(obj)
+
+ # Check if we should render at all
+ if not self.will_render(data):
+ return []
+ elif data is None:
+ return [(self._render(parent, None, patches, nsmap), None)]
+
+ # Make the data into a list if it isn't already
+ if not isinstance(data, list):
+ data = [data]
+ elif parent is None:
+ raise ValueError(_('root element selecting a list'))
+
+ # Render all the elements
+ elems = []
+ for datum in data:
+ if self.subselector is not None:
+ datum = self.subselector(datum)
+ elems.append((self._render(parent, datum, patches, nsmap), datum))
+
+ # Return all the elements rendered, as well as the
+ # corresponding datum for the next step down the tree
+ return elems
+
+ def will_render(self, datum):
+ """Hook method.
+
+ An overridable hook method to determine whether this template
+ element will be rendered at all. By default, returns False
+ (inhibiting rendering) if the datum is None.
+
+ :param datum: The datum associated with this template element.
+ """
+
+ # Don't render if datum is None
+ return datum is not None
+
+ def _text_get(self):
+ """Template element text.
+
+ Either None or a callable taking an object and optional
+ boolean do_raise indicator and returning the datum bound to
+ the text of the template element.
+ """
+
+ return self._text
+
+ def _text_set(self, value):
+ # Convert value to a selector
+ if value is not None and not callable(value):
+ value = Selector(value)
+
+ self._text = value
+
+ def _text_del(self):
+ self._text = None
+
+ text = property(_text_get, _text_set, _text_del)
+
+ def tree(self):
+ """Return string representation of the template tree.
+
+ Returns a representation of the template rooted at this
+ element as a string, suitable for inclusion in debug logs.
+ """
+
+ # Build the inner contents of the tag...
+ contents = [self.tag, '!selector=%r' % self.selector]
+
+ # Add the text...
+ if self.text is not None:
+ contents.append('!text=%r' % self.text)
+
+ # Add all the other attributes
+ for key, value in self.attrib.items():
+ contents.append('%s=%r' % (key, value))
+
+ # If there are no children, return it as a closed tag
+ if len(self) == 0:
+ return '<%s/>' % ' '.join([str(i) for i in contents])
+
+ # OK, recurse to our children
+ children = [c.tree() for c in self]
+
+ # Return the result
+ return ('<%s>%s</%s>' %
+ (' '.join(contents), ''.join(children), self.tag))
+
+
+def SubTemplateElement(parent, tag, attrib=None, selector=None,
+ subselector=None, **extra):
+ """Create a template element as a child of another.
+
+ Corresponds to the etree.SubElement interface. Parameters are as
+ for TemplateElement, with the addition of the parent.
+ """
+
+ # Convert attributes
+ attrib = attrib or {}
+ attrib.update(extra)
+
+ # Get a TemplateElement
+ elem = TemplateElement(tag, attrib=attrib, selector=selector,
+ subselector=subselector)
+
+ # Append the parent safely
+ if parent is not None:
+ parent.append(elem)
+
+ return elem
+
+
+class Template(object):
+ """Represent a template."""
+
+ def __init__(self, root, nsmap=None):
+ """Initialize a template.
+
+ :param root: The root element of the template.
+ :param nsmap: An optional namespace dictionary to be
+ associated with the root element of the
+ template.
+ """
+
+ self.root = root.unwrap() if root is not None else None
+ self.nsmap = nsmap or {}
+ self.serialize_options = dict(encoding='UTF-8', xml_declaration=True)
+
+ def _serialize(self, parent, obj, siblings, nsmap=None):
+ """Internal serialization.
+
+ Recursive routine to build a tree of etree.Element instances
+ from an object based on the template. Returns the first
+ etree.Element instance rendered, or None.
+
+ :param parent: The parent etree.Element instance. Can be
+ None.
+ :param obj: The object to render.
+ :param siblings: The TemplateElement instances against which
+ to render the object.
+ :param nsmap: An optional namespace dictionary to be
+ associated with the etree.Element instance
+ rendered.
+ """
+
+ # First step, render the element
+ elems = siblings[0].render(parent, obj, siblings[1:], nsmap)
+
+ # Now, recurse to all child elements
+ seen = set()
+ for idx, sibling in enumerate(siblings):
+ for child in sibling:
+ # Have we handled this child already?
+ if child.tag in seen:
+ continue
+ seen.add(child.tag)
+
+ # Determine the child's siblings
+ nieces = [child]
+ for sib in siblings[idx + 1:]:
+ if child.tag in sib:
+ nieces.append(sib[child.tag])
+
+ # Now we recurse for every data element
+ for elem, datum in elems:
+ self._serialize(elem, datum, nieces)
+
+ # Return the first element; at the top level, this will be the
+ # root element
+ if elems:
+ return elems[0][0]
+
+ def serialize(self, obj, *args, **kwargs):
+ """Serialize an object.
+
+ Serializes an object against the template. Returns a string
+ with the serialized XML. Positional and keyword arguments are
+ passed to etree.tostring().
+
+ :param obj: The object to serialize.
+ """
+
+ elem = self.make_tree(obj)
+ if elem is None:
+ return ''
+
+ for k, v in self.serialize_options.items():
+ kwargs.setdefault(k, v)
+
+ # Serialize it into XML
+ return etree.tostring(elem, *args, **kwargs)
+
+ def make_tree(self, obj):
+ """Create a tree.
+
+ Serializes an object against the template. Returns an Element
+ node with appropriate children.
+
+ :param obj: The object to serialize.
+ """
+
+ # If the template is empty, return the empty string
+ if self.root is None:
+ return None
+
+ # Get the siblings and nsmap of the root element
+ siblings = self._siblings()
+ nsmap = self._nsmap()
+
+ # Form the element tree
+ return self._serialize(None, obj, siblings, nsmap)
+
+ def _siblings(self):
+ """Hook method for computing root siblings.
+
+ An overridable hook method to return the siblings of the root
+ element. By default, this is the root element itself.
+ """
+
+ return [self.root]
+
+ def _nsmap(self):
+ """Hook method for computing the namespace dictionary.
+
+ An overridable hook method to return the namespace dictionary.
+ """
+
+ return self.nsmap.copy()
+
+ def unwrap(self):
+ """Unwraps a template to return a template element."""
+
+ # Return the root element
+ return self.root
+
+ def wrap(self):
+ """Wraps a template element to return a template."""
+
+ # We are a template
+ return self
+
+ def apply(self, master):
+ """Hook method for determining slave applicability.
+
+ An overridable hook method used to determine if this template
+ is applicable as a slave to a given master template.
+
+ :param master: The master template to test.
+ """
+
+ return True
+
+ def tree(self):
+ """Return string representation of the template tree.
+
+ Returns a representation of the template as a string, suitable
+ for inclusion in debug logs.
+ """
+
+ return "%r: %s" % (self, self.root.tree())
+
+
+class MasterTemplate(Template):
+ """Represent a master template.
+
+ Master templates are versioned derivatives of templates that
+ additionally allow slave templates to be attached. Slave
+ templates allow modification of the serialized result without
+ directly changing the master.
+ """
+
+ def __init__(self, root, version, nsmap=None):
+ """Initialize a master template.
+
+ :param root: The root element of the template.
+ :param version: The version number of the template.
+ :param nsmap: An optional namespace dictionary to be
+ associated with the root element of the
+ template.
+ """
+
+ super(MasterTemplate, self).__init__(root, nsmap)
+ self.version = version
+ self.slaves = []
+
+ def __repr__(self):
+ """Return string representation of the template."""
+
+ return ("<%s.%s object version %s at %#x>" %
+ (self.__class__.__module__, self.__class__.__name__,
+ self.version, id(self)))
+
+ def _siblings(self):
+ """Hook method for computing root siblings.
+
+ An overridable hook method to return the siblings of the root
+ element. This is the root element plus the root elements of
+ all the slave templates.
+ """
+
+ return [self.root] + [slave.root for slave in self.slaves]
+
+ def _nsmap(self):
+ """Hook method for computing the namespace dictionary.
+
+ An overridable hook method to return the namespace dictionary.
+ The namespace dictionary is computed by taking the master
+ template's namespace dictionary and updating it from all the
+ slave templates.
+ """
+
+ nsmap = self.nsmap.copy()
+ for slave in self.slaves:
+ nsmap.update(slave._nsmap())
+ return nsmap
+
+ def attach(self, *slaves):
+ """Attach one or more slave templates.
+
+ Attaches one or more slave templates to the master template.
+ Slave templates must have a root element with the same tag as
+ the master template. The slave template's apply() method will
+ be called to determine if the slave should be applied to this
+ master; if it returns False, that slave will be skipped.
+ (This allows filtering of slaves based on the version of the
+ master template.)
+ """
+
+ slave_list = []
+ for slave in slaves:
+ slave = slave.wrap()
+
+ # Make sure we have a tree match
+ if slave.root.tag != self.root.tag:
+ slavetag = slave.root.tag
+ mastertag = self.root.tag
+ msg = _("Template tree mismatch; adding slave %(slavetag)s "
+ "to master %(mastertag)s") % locals()
+ raise ValueError(msg)
+
+ # Make sure slave applies to this template
+ if not slave.apply(self):
+ continue
+
+ slave_list.append(slave)
+
+ # Add the slaves
+ self.slaves.extend(slave_list)
+
+ def copy(self):
+ """Return a copy of this master template."""
+
+ # Return a copy of the MasterTemplate
+ tmp = self.__class__(self.root, self.version, self.nsmap)
+ tmp.slaves = self.slaves[:]
+ return tmp
+
+
+class SlaveTemplate(Template):
+ """Represent a slave template.
+
+ Slave templates are versioned derivatives of templates. Each
+ slave has a minimum version and optional maximum version of the
+ master template to which they can be attached.
+ """
+
+ def __init__(self, root, min_vers, max_vers=None, nsmap=None):
+ """Initialize a slave template.
+
+ :param root: The root element of the template.
+ :param min_vers: The minimum permissible version of the master
+ template for this slave template to apply.
+ :param max_vers: An optional upper bound for the master
+ template version.
+ :param nsmap: An optional namespace dictionary to be
+ associated with the root element of the
+ template.
+ """
+
+ super(SlaveTemplate, self).__init__(root, nsmap)
+ self.min_vers = min_vers
+ self.max_vers = max_vers
+
+ def __repr__(self):
+ """Return string representation of the template."""
+
+ return ("<%s.%s object versions %s-%s at %#x>" %
+ (self.__class__.__module__, self.__class__.__name__,
+ self.min_vers, self.max_vers, id(self)))
+
+ def apply(self, master):
+ """Hook method for determining slave applicability.
+
+ An overridable hook method used to determine if this template
+ is applicable as a slave to a given master template. This
+ version requires the master template to have a version number
+ between min_vers and max_vers.
+
+ :param master: The master template to test.
+ """
+
+ # Does the master meet our minimum version requirement?
+ if master.version < self.min_vers:
+ return False
+
+ # How about our maximum version requirement?
+ if self.max_vers is not None and master.version > self.max_vers:
+ return False
+
+ return True
+
+
+class TemplateBuilder(object):
+ """Template builder.
+
+ This class exists to allow templates to be lazily built without
+ having to build them each time they are needed. It must be
+ subclassed, and the subclass must implement the construct()
+ method, which must return a Template (or subclass) instance. The
+ constructor will always return the template returned by
+ construct(), or, if it has a copy() method, a copy of that
+ template.
+ """
+
+ _tmpl = None
+
+ def __new__(cls, copy=True):
+ """Construct and return a template.
+
+ :param copy: If True (the default), a copy of the template
+ will be constructed and returned, if possible.
+ """
+
+ # Do we need to construct the template?
+ if cls._tmpl is None:
+ tmp = super(TemplateBuilder, cls).__new__(cls)
+
+ # Construct the template
+ cls._tmpl = tmp.construct()
+
+ # If the template has a copy attribute, return the result of
+ # calling it
+ if copy and hasattr(cls._tmpl, 'copy'):
+ return cls._tmpl.copy()
+
+ # Return the template
+ return cls._tmpl
+
+ def construct(self):
+ """Construct a template.
+
+ Called to construct a template instance, which it must return.
+ Only called once.
+ """
+
+ raise NotImplementedError(_("subclasses must implement construct()!"))
+
+
+def make_links(parent, selector=None):
+ """
+ Attach an Atom <links> element to the parent.
+ """
+
+ elem = SubTemplateElement(parent, '{%s}link' % XMLNS_ATOM,
+ selector=selector)
+ elem.set('rel')
+ elem.set('type')
+ elem.set('href')
+
+ # Just for completeness...
+ return elem
+
+
+def make_flat_dict(name, selector=None, subselector=None, ns=None):
+ """
+ Utility for simple XML templates that traditionally used
+ XMLDictSerializer with no metadata. Returns a template element
+ where the top-level element has the given tag name, and where
+ sub-elements have tag names derived from the object's keys and
+ text derived from the object's values. This only works for flat
+ dictionary objects, not dictionaries containing nested lists or
+ dictionaries.
+ """
+
+ # Set up the names we need...
+ if ns is None:
+ elemname = name
+ tagname = Selector(0)
+ else:
+ elemname = '{%s}%s' % (ns, name)
+ tagname = lambda obj, do_raise=False: '{%s}%s' % (ns, obj[0])
+
+ if selector is None:
+ selector = name
+
+ # Build the root element
+ root = TemplateElement(elemname, selector=selector,
+ subselector=subselector)
+
+ # Build an element to represent all the keys and values
+ elem = SubTemplateElement(root, tagname, selector=get_items)
+ elem.text = 1
+
+ # Return the template
+ return root
diff --git a/reddwarf/limits/__init__.py b/reddwarf/limits/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/reddwarf/limits/__init__.py
diff --git a/reddwarf/limits/service.py b/reddwarf/limits/service.py
new file mode 100644
index 00000000..3bb8a7b3
--- /dev/null
+++ b/reddwarf/limits/service.py
@@ -0,0 +1,49 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack LLC
+# Copyright 2013 Hewlett-Packard Development Company, L.P.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+from reddwarf.common import wsgi as base_wsgi
+from reddwarf.common.limits import LimitsTemplate
+from reddwarf.limits import views
+from reddwarf.openstack.common import wsgi
+
+
+class LimitsController(base_wsgi.Controller):
+ """
+ Controller for accessing limits in the OpenStack API.
+ Note: this is a little different than how other controllers are implemented
+ """
+
+ @base_wsgi.serializers(xml=LimitsTemplate)
+ def index(self, req, tenant_id):
+ """
+ Return all global and rate limit information.
+ """
+ context = req.environ[base_wsgi.CONTEXT_KEY]
+
+ #
+ # TODO: hook this in later
+ #quotas = QUOTAS.get_project_quotas(context, context.project_id,
+ # usages=False)
+ #abs_limits = dict((k, v['limit']) for k, v in quotas.items())
+ abs_limits = {}
+ rate_limits = req.environ.get("reddwarf.limits", [])
+
+ builder = self._get_view_builder(req)
+ return builder.build(rate_limits, abs_limits)
+
+ def _get_view_builder(self, req):
+ return views.ViewBuilder()
diff --git a/reddwarf/limits/views.py b/reddwarf/limits/views.py
new file mode 100644
index 00000000..6158cc2b
--- /dev/null
+++ b/reddwarf/limits/views.py
@@ -0,0 +1,98 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010-2011 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import datetime
+
+from reddwarf.openstack.common import timeutils
+
+
+class ViewBuilder(object):
+ """OpenStack API base limits view builder."""
+
+ def build(self, rate_limits, absolute_limits):
+ rate_limits = self._build_rate_limits(rate_limits)
+ absolute_limits = self._build_absolute_limits(absolute_limits)
+
+ output = {
+ "limits": {
+ "rate": rate_limits,
+ "absolute": absolute_limits,
+ },
+ }
+
+ return output
+
+ def _build_absolute_limits(self, absolute_limits):
+ """Builder for absolute limits
+
+ absolute_limits should be given as a dict of limits.
+ For example: {"ram": 512, "gigabytes": 1024}.
+
+ """
+ limit_names = {
+ "ram": ["maxTotalRAMSize"],
+ "instances": ["maxTotalInstances"],
+ "cores": ["maxTotalCores"],
+ "metadata_items": ["maxServerMeta", "maxImageMeta"],
+ "injected_files": ["maxPersonality"],
+ "injected_file_content_bytes": ["maxPersonalitySize"],
+ "security_groups": ["maxSecurityGroups"],
+ "security_group_rules": ["maxSecurityGroupRules"],
+ }
+ limits = {}
+ for name, value in absolute_limits.iteritems():
+ if name in limit_names and value is not None:
+ for name in limit_names[name]:
+ limits[name] = value
+ return limits
+
+ def _build_rate_limits(self, rate_limits):
+ limits = []
+ for rate_limit in rate_limits:
+ _rate_limit_key = None
+ _rate_limit = self._build_rate_limit(rate_limit)
+
+ # check for existing key
+ for limit in limits:
+ if (limit["uri"] == rate_limit["URI"] and
+ limit["regex"] == rate_limit["regex"]):
+ _rate_limit_key = limit
+ break
+
+ # ensure we have a key if we didn't find one
+ if not _rate_limit_key:
+ _rate_limit_key = {
+ "uri": rate_limit["URI"],
+ "regex": rate_limit["regex"],
+ "limit": [],
+ }
+ limits.append(_rate_limit_key)
+
+ _rate_limit_key["limit"].append(_rate_limit)
+
+ return limits
+
+ def _build_rate_limit(self, rate_limit):
+ _get_utc = datetime.datetime.utcfromtimestamp
+ next_avail = _get_utc(rate_limit["resetTime"])
+ return {
+ "verb": rate_limit["verb"],
+ "value": rate_limit["value"],
+ "remaining": int(rate_limit["remaining"]),
+ "unit": rate_limit["unit"],
+ "next-available": timeutils.isotime(at=next_avail),
+ }
diff --git a/reddwarf/tests/api/limits.py b/reddwarf/tests/api/limits.py
new file mode 100644
index 00000000..c22fa26f
--- /dev/null
+++ b/reddwarf/tests/api/limits.py
@@ -0,0 +1,109 @@
+from nose.tools import assert_equal
+from nose.tools import assert_false
+from nose.tools import assert_true
+
+from proboscis import before_class
+from proboscis import test
+
+from reddwarf.openstack.common import timeutils
+from reddwarf.tests.util import create_dbaas_client
+from reddwarf.tests.util import test_config
+from reddwarfclient import exceptions
+
+from datetime import datetime
+
+GROUP = "dbaas.api.limits"
+DEFAULT_RATE = 200
+# Note: This should not be enabled until rd-client merges
+RD_CLIENT_OK = False
+
+
+@test(groups=[GROUP])
+class Limits(object):
+
+ @before_class
+ def setUp(self):
+ rate_user = self._get_user('rate_limit')
+ self.rd_client = create_dbaas_client(rate_user)
+
+ def _get_user(self, name):
+ return test_config.users.find_user_by_name(name)
+
+ def _get_next_available(self, resource):
+ return resource.__dict__['next-available']
+
+ def __is_available(self, next_available):
+ dt_next = timeutils.parse_isotime(next_available)
+ dt_now = datetime.now()
+ return dt_next.time() < dt_now.time()
+
+ @test(enabled=RD_CLIENT_OK)
+ def test_limits_index(self):
+ """test_limits_index"""
+ r1, r2, r3, r4 = self.rd_client.limits.index()
+
+ assert_equal(r1.verb, "POST")
+ assert_equal(r1.unit, "MINUTE")
+ assert_true(r1.remaining <= DEFAULT_RATE)
+
+ next_available = self._get_next_available(r1)
+ assert_true(next_available is not None)
+
+ assert_equal(r2.verb, "PUT")
+ assert_equal(r2.unit, "MINUTE")
+ assert_true(r2.remaining <= DEFAULT_RATE)
+
+ next_available = self._get_next_available(r2)
+ assert_true(next_available is not None)
+
+ assert_equal(r3.verb, "DELETE")
+ assert_equal(r3.unit, "MINUTE")
+ assert_true(r3.remaining <= DEFAULT_RATE)
+
+ next_available = self._get_next_available(r3)
+ assert_true(next_available is not None)
+
+ assert_equal(r4.verb, "GET")
+ assert_equal(r4.unit, "MINUTE")
+ assert_true(r4.remaining <= DEFAULT_RATE)
+
+ next_available = self._get_next_available(r4)
+ assert_true(next_available is not None)
+
+ @test(enabled=RD_CLIENT_OK)
+ def test_limits_get_remaining(self):
+ """test_limits_get_remaining"""
+ gets = None
+ for i in xrange(5):
+ r1, r2, r3, r4 = self.rd_client.limits.index()
+ gets = r4
+
+ assert_equal(gets.verb, "GET")
+ assert_equal(gets.unit, "MINUTE")
+ assert_true(gets.remaining <= DEFAULT_RATE - 5)
+
+ next_available = self._get_next_available(gets)
+ assert_true(next_available is not None)
+
+ @test(enabled=RD_CLIENT_OK)
+ def test_limits_exception(self):
+ """test_limits_exception"""
+
+ # use a different user to avoid throttling tests run out of order
+ rate_user_exceeded = self._get_user('rate_limit_exceeded')
+ rd_client = create_dbaas_client(rate_user_exceeded)
+
+ gets = None
+ encountered = False
+ for i in xrange(DEFAULT_RATE + 50):
+ try:
+ r1, r2, r3, r4 = rd_client.limits.index()
+ gets = r4
+ assert_equal(gets.verb, "GET")
+ assert_equal(gets.unit, "MINUTE")
+
+ except exceptions.OverLimit:
+ encountered = True
+
+ assert_true(encountered)
+ assert_true(gets.remaining <= 50)
diff --git a/reddwarf/tests/unittests/api/common/__init__.py b/reddwarf/tests/unittests/api/common/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/reddwarf/tests/unittests/api/common/__init__.py
diff --git a/reddwarf/tests/unittests/api/common/test_limits.py b/reddwarf/tests/unittests/api/common/test_limits.py
new file mode 100644
index 00000000..a7b19c74
--- /dev/null
+++ b/reddwarf/tests/unittests/api/common/test_limits.py
@@ -0,0 +1,741 @@
+# Copyright 2011 OpenStack LLC.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+Tests dealing with HTTP rate-limiting.
+"""
+
+import httplib
+import StringIO
+from xml.dom import minidom
+from lxml import etree
+import testtools
+import webob
+
+from mockito import when
+
+from reddwarf.common import limits
+from reddwarf.common import xmlutil
+from reddwarf.common.limits import Limit
+from reddwarf.limits import views
+from reddwarf.openstack.common import jsonutils
+
+from reddwarf.tests.unittests.util.matchers import DictMatches
+
+TEST_LIMITS = [
+ Limit("GET", "/delayed", "^/delayed", 1, limits.PER_MINUTE),
+ Limit("POST", "*", ".*", 7, limits.PER_MINUTE),
+ Limit("POST", "/servers", "^/servers", 3, limits.PER_MINUTE),
+ Limit("PUT", "*", "", 10, limits.PER_MINUTE),
+ Limit("PUT", "/servers", "^/servers", 5, limits.PER_MINUTE),
+]
+NS = {
+ 'atom': 'http://www.w3.org/2005/Atom',
+ 'ns': 'http://docs.openstack.org/common/api/v1.0'
+}
+
+
+class BaseLimitTestSuite(testtools.TestCase):
+ """Base test suite which provides relevant stubs and time abstraction."""
+
+ def setUp(self):
+ super(BaseLimitTestSuite, self).setUp()
+
+ self.absolute_limits = {}
+
+
+class LimitsControllerTest(BaseLimitTestSuite):
+ """
+ Tests for `limits.LimitsController` class.
+ TODO: add test cases once absolute limits are integrated
+ """
+ pass
+
+
+class TestLimiter(limits.Limiter):
+ pass
+
+
+class LimitMiddlewareTest(BaseLimitTestSuite):
+ """
+ Tests for the `limits.RateLimitingMiddleware` class.
+ """
+
+ @webob.dec.wsgify
+ def _empty_app(self, request):
+ """Do-nothing WSGI app."""
+ pass
+
+ def setUp(self):
+ """Prepare middleware for use through fake WSGI app."""
+ super(LimitMiddlewareTest, self).setUp()
+ _limits = '(GET, *, .*, 1, MINUTE)'
+ self.app = limits.RateLimitingMiddleware(self._empty_app, _limits,
+ "%s.TestLimiter" %
+ self.__class__.__module__)
+
+ def test_limit_class(self):
+ # Test that middleware selected correct limiter class.
+ assert isinstance(self.app._limiter, TestLimiter)
+
+ def test_good_request(self):
+ # Test successful GET request through middleware.
+ request = webob.Request.blank("/")
+ response = request.get_response(self.app)
+ self.assertEqual(200, response.status_int)
+
+ def test_limited_request_json(self):
+ # Test a rate-limited (413) GET request through middleware.
+ request = webob.Request.blank("/")
+ response = request.get_response(self.app)
+ self.assertEqual(200, response.status_int)
+
+ request = webob.Request.blank("/")
+ response = request.get_response(self.app)
+ self.assertEqual(response.status_int, 413)
+
+ self.assertTrue('Retry-After' in response.headers)
+ retry_after = int(response.headers['Retry-After'])
+ self.assertAlmostEqual(retry_after, 60, 1)
+
+ body = jsonutils.loads(response.body)
+ expected = "Only 1 GET request(s) can be made to * every minute."
+ value = body["overLimit"]["details"].strip()
+ self.assertEqual(value, expected)
+
+ self.assertTrue("retryAfter" in body["overLimit"])
+ retryAfter = body["overLimit"]["retryAfter"]
+ self.assertEqual(retryAfter, "60")
+
+ def test_limited_request_xml(self):
+ # Test a rate-limited (413) response as XML.
+ request = webob.Request.blank("/")
+ response = request.get_response(self.app)
+ self.assertEqual(200, response.status_int)
+
+ request = webob.Request.blank("/")
+ request.accept = "application/xml"
+ response = request.get_response(self.app)
+ self.assertEqual(response.status_int, 413)
+
+ root = minidom.parseString(response.body).childNodes[0]
+ expected = "Only 1 GET request(s) can be made to * every minute."
+
+ self.assertNotEqual(root.attributes.getNamedItem("retryAfter"), None)
+ retryAfter = root.attributes.getNamedItem("retryAfter").value
+ self.assertEqual(retryAfter, "60")
+
+ details = root.getElementsByTagName("details")
+ self.assertEqual(details.length, 1)
+
+ value = details.item(0).firstChild.data.strip()
+ self.assertEqual(value, expected)
+
+
+class LimitTest(BaseLimitTestSuite):
+ """
+ Tests for the `limits.Limit` class.
+ """
+
+ def test_GET_no_delay(self):
+ # Test a limit handles 1 GET per second.
+ limit = Limit("GET", "*", ".*", 1, 1)
+ when(limit)._get_time().thenReturn(0.0)
+ delay = limit("GET", "/anything")
+ self.assertEqual(None, delay)
+ self.assertEqual(0, limit.next_request)
+ self.assertEqual(0, limit.last_request)
+
+ def test_GET_delay(self):
+ # Test two calls to 1 GET per second limit.
+ limit = Limit("GET", "*", ".*", 1, 1)
+ when(limit)._get_time().thenReturn(0.0)
+
+ delay = limit("GET", "/anything")
+ self.assertEqual(None, delay)
+
+ delay = limit("GET", "/anything")
+ self.assertEqual(1, delay)
+ self.assertEqual(1, limit.next_request)
+ self.assertEqual(0, limit.last_request)
+
+ when(limit)._get_time().thenReturn(4.0)
+
+ delay = limit("GET", "/anything")
+ self.assertEqual(None, delay)
+ self.assertEqual(4, limit.next_request)
+ self.assertEqual(4, limit.last_request)
+
+
+class ParseLimitsTest(BaseLimitTestSuite):
+ """
+ Tests for the default limits parser in the in-memory
+ `limits.Limiter` class.
+ """
+
+ def test_invalid(self):
+ # Test that parse_limits() handles invalid input correctly.
+ self.assertRaises(ValueError, limits.Limiter.parse_limits,
+ ';;;;;')
+
+ def test_bad_rule(self):
+ # Test that parse_limits() handles bad rules correctly.
+ self.assertRaises(ValueError, limits.Limiter.parse_limits,
+ 'GET, *, .*, 20, minute')
+
+ def test_missing_arg(self):
+ # Test that parse_limits() handles missing args correctly.
+ self.assertRaises(ValueError, limits.Limiter.parse_limits,
+ '(GET, *, .*, 20)')
+
+ def test_bad_value(self):
+ # Test that parse_limits() handles bad values correctly.
+ self.assertRaises(ValueError, limits.Limiter.parse_limits,
+ '(GET, *, .*, foo, minute)')
+
+ def test_bad_unit(self):
+ # Test that parse_limits() handles bad units correctly.
+ self.assertRaises(ValueError, limits.Limiter.parse_limits,
+ '(GET, *, .*, 20, lightyears)')
+
+ def test_multiple_rules(self):
+ # Test that parse_limits() handles multiple rules correctly.
+ try:
+ l = limits.Limiter.parse_limits('(get, *, .*, 20, minute);'
+ '(PUT, /foo*, /foo.*, 10, hour);'
+ '(POST, /bar*, /bar.*, 5, second);'
+ '(Say, /derp*, /derp.*, 1, day)')
+ except ValueError, e:
+ assert False, str(e)
+
+ # Make sure the number of returned limits are correct
+ self.assertEqual(len(l), 4)
+
+ # Check all the verbs...
+ expected = ['GET', 'PUT', 'POST', 'SAY']
+ self.assertEqual([t.verb for t in l], expected)
+
+ # ...the URIs...
+ expected = ['*', '/foo*', '/bar*', '/derp*']
+ self.assertEqual([t.uri for t in l], expected)
+
+ # ...the regexes...
+ expected = ['.*', '/foo.*', '/bar.*', '/derp.*']
+ self.assertEqual([t.regex for t in l], expected)
+
+ # ...the values...
+ expected = [20, 10, 5, 1]
+ self.assertEqual([t.value for t in l], expected)
+
+ # ...and the units...
+ expected = [limits.PER_MINUTE, limits.PER_HOUR,
+ limits.PER_SECOND, limits.PER_DAY]
+ self.assertEqual([t.unit for t in l], expected)
+
+
+class LimiterTest(BaseLimitTestSuite):
+ """
+ Tests for the in-memory `limits.Limiter` class.
+ """
+
+ def update_limits(self, delay):
+ for l in TEST_LIMITS:
+ when(l)._get_time().thenReturn(delay)
+
+ def setUp(self):
+ """Run before each test."""
+ super(LimiterTest, self).setUp()
+ userlimits = {'user:user3': ''}
+
+ self.update_limits(0.0)
+
+ self.limiter = limits.Limiter(TEST_LIMITS, **userlimits)
+
+ def _check(self, num, verb, url, username=None):
+ """Check and yield results from checks."""
+ for x in xrange(num):
+ yield self.limiter.check_for_delay(verb, url, username)[0]
+
+ def _check_sum(self, num, verb, url, username=None):
+ """Check and sum results from checks."""
+ results = self._check(num, verb, url, username)
+ return sum(item for item in results if item)
+
+ def test_no_delay_GET(self):
+ """
+ Simple test to ensure no delay on a single call for a limit verb we
+ didn"t set.
+ """
+ delay = self.limiter.check_for_delay("GET", "/anything")
+ self.assertEqual(delay, (None, None))
+
+ def test_no_delay_PUT(self):
+ # Simple test to ensure no delay on a single call for a known limit.
+ delay = self.limiter.check_for_delay("PUT", "/anything")
+ self.assertEqual(delay, (None, None))
+
+ def test_delay_PUT(self):
+ """
+ Ensure the 11th PUT will result in a delay of 6.0 seconds until
+ the next request will be granced.
+ """
+ expected = [None] * 10 + [6.0]
+ results = list(self._check(11, "PUT", "/anything"))
+
+ self.assertEqual(expected, results)
+
+ def test_delay_POST(self):
+ """
+ Ensure the 8th POST will result in a delay of 6.0 seconds until
+ the next request will be granced.
+ """
+ expected = [None] * 7
+ results = list(self._check(7, "POST", "/anything"))
+ self.assertEqual(expected, results)
+
+ expected = 60.0 / 7.0
+ results = self._check_sum(1, "POST", "/anything")
+ self.failUnlessAlmostEqual(expected, results, 8)
+
+ def test_delay_GET(self):
+ # Ensure the 11th GET will result in NO delay.
+ expected = [None] * 11
+ results = list(self._check(11, "GET", "/anything"))
+
+ self.assertEqual(expected, results)
+
+ def test_delay_PUT_servers(self):
+ """
+ Ensure PUT on /servers limits at 5 requests, and PUT elsewhere is still
+ OK after 5 requests...but then after 11 total requests, PUT limiting
+ kicks in.
+ """
+ # First 6 requests on PUT /servers
+ expected = [None] * 5 + [12.0]
+ results = list(self._check(6, "PUT", "/servers"))
+ self.assertEqual(expected, results)
+
+ # Next 5 request on PUT /anything
+ expected = [None] * 4 + [6.0]
+ results = list(self._check(5, "PUT", "/anything"))
+ self.assertEqual(expected, results)
+
+ def test_delay_PUT_wait(self):
+ """
+ Ensure after hitting the limit and then waiting for the correct
+ amount of time, the limit will be lifted.
+ """
+ expected = [None] * 10 + [6.0]
+ results = list(self._check(11, "PUT", "/anything"))
+ self.assertEqual(expected, results)
+
+ # Advance time
+ self.update_limits(6.0)
+
+ expected = [None, 6.0]
+ results = list(self._check(2, "PUT", "/anything"))
+ self.assertEqual(expected, results)
+
+ def test_multiple_delays(self):
+ # Ensure multiple requests still get a delay.
+ expected = [None] * 10 + [6.0] * 10
+ results = list(self._check(20, "PUT", "/anything"))
+ self.assertEqual(expected, results)
+
+ self.update_limits(1.0)
+
+ expected = [5.0] * 10
+ results = list(self._check(10, "PUT", "/anything"))
+ self.assertEqual(expected, results)
+
+ def test_user_limit(self):
+ # Test user-specific limits.
+ self.assertEqual(self.limiter.levels['user3'], [])
+
+ def test_multiple_users(self):
+ # Tests involving multiple users.
+ # User1
+ self.update_limits(0.0)
+ expected = [None] * 10 + [6.0] * 10
+ results = list(self._check(20, "PUT", "/anything", "user1"))
+ self.assertEqual(expected, results)
+
+ # User2
+ expected = [None] * 10 + [6.0] * 5
+ results = list(self._check(15, "PUT", "/anything", "user2"))
+ self.assertEqual(expected, results)
+
+ # User3
+ expected = [None] * 20
+ results = list(self._check(20, "PUT", "/anything", "user3"))
+ self.assertEqual(expected, results)
+
+ self.update_limits(1.0)
+ # User1 again
+ expected = [5.0] * 10
+ results = list(self._check(10, "PUT", "/anything", "user1"))
+ self.assertEqual(expected, results)
+
+ self.update_limits(2.0)
+
+ # User1 again
+ expected = [4.0] * 5
+ results = list(self._check(5, "PUT", "/anything", "user2"))
+ self.assertEqual(expected, results)
+
+
+class WsgiLimiterTest(BaseLimitTestSuite):
+ """
+ Tests for `limits.WsgiLimiter` class.
+ """
+
+ def setUp(self):
+ """Run before each test."""
+ super(WsgiLimiterTest, self).setUp()
+ self.app = limits.WsgiLimiter(TEST_LIMITS)
+
+ def _request_data(self, verb, path):
+ """Get data describing a limit request verb/path."""
+ return jsonutils.dumps({"verb": verb, "path": path})
+
+ def _request(self, verb, url, username=None):
+ """Make sure that POSTing to the given url causes the given username
+ to perform the given action. Make the internal rate limiter return
+ delay and make sure that the WSGI app returns the correct response.
+ """
+ if username:
+ request = webob.Request.blank("/%s" % username)
+ else:
+ request = webob.Request.blank("/")
+
+ request.method = "POST"
+ request.body = self._request_data(verb, url)
+ response = request.get_response(self.app)
+
+ if "X-Wait-Seconds" in response.headers:
+ self.assertEqual(response.status_int, 403)
+ return response.headers["X-Wait-Seconds"]
+
+ self.assertEqual(response.status_int, 204)
+
+ def test_invalid_methods(self):
+ # Only POSTs should work.
+ requests = []
+ for method in ["GET", "PUT", "DELETE", "HEAD", "OPTIONS"]:
+ request = webob.Request.blank("/", method=method)
+ response = request.get_response(self.app)
+ self.assertEqual(response.status_int, 405)
+
+ def test_good_url(self):
+ delay = self._request("GET", "/something")
+ self.assertEqual(delay, None)
+
+ def test_escaping(self):
+ delay = self._request("GET", "/something/jump%20up")
+ self.assertEqual(delay, None)
+
+ def test_response_to_delays(self):
+ delay = self._request("GET", "/delayed")
+ self.assertEqual(delay, None)
+
+ delay = self._request("GET", "/delayed")
+ self.assertEqual(delay, '60.00')
+
+ def test_response_to_delays_usernames(self):
+ delay = self._request("GET", "/delayed", "user1")
+ self.assertEqual(delay, None)
+
+ delay = self._request("GET", "/delayed", "user2")
+ self.assertEqual(delay, None)
+
+ delay = self._request("GET", "/delayed", "user1")
+ self.assertEqual(delay, '60.00')
+
+ delay = self._request("GET", "/delayed", "user2")
+ self.assertEqual(delay, '60.00')
+
+
+class FakeHttplibSocket(object):
+ """
+ Fake `httplib.HTTPResponse` replacement.
+ """
+
+ def __init__(self, response_string):
+ """Initialize new `FakeHttplibSocket`."""
+ self._buffer = StringIO.StringIO(response_string)
+
+ def makefile(self, _mode, _other):
+ """Returns the socket's internal buffer."""
+ return self._buffer
+
+
+class FakeHttplibConnection(object):
+ """
+ Fake `httplib.HTTPConnection`.
+ """
+
+ def __init__(self, app, host):
+ """
+ Initialize `FakeHttplibConnection`.
+ """
+ self.app = app
+ self.host = host
+
+ def request(self, method, path, body="", headers=None):
+ """
+ Requests made via this connection actually get translated and routed
+ into our WSGI app, we then wait for the response and turn it back into
+ an `httplib.HTTPResponse`.
+ """
+ if not headers:
+ headers = {}
+
+ req = webob.Request.blank(path)
+ req.method = method
+ req.headers = headers
+ req.host = self.host
+ req.body = body
+
+ resp = str(req.get_response(self.app))
+ resp = "HTTP/1.0 %s" % resp
+ sock = FakeHttplibSocket(resp)
+ self.http_response = httplib.HTTPResponse(sock)
+ self.http_response.begin()
+
+ def getresponse(self):
+ """Return our generated response from the request."""
+ return self.http_response
+
+
+def wire_HTTPConnection_to_WSGI(host, app):
+ """Monkeypatches HTTPConnection so that if you try to connect to host, you
+ are instead routed straight to the given WSGI app.
+
+ After calling this method, when any code calls
+
+ httplib.HTTPConnection(host)
+
+ the connection object will be a fake. Its requests will be sent directly
+ to the given WSGI app rather than through a socket.
+
+ Code connecting to hosts other than host will not be affected.
+
+ This method may be called multiple times to map different hosts to
+ different apps.
+
+ This method returns the original HTTPConnection object, so that the caller
+ can restore the default HTTPConnection interface (for all hosts).
+ """
+
+ class HTTPConnectionDecorator(object):
+ """Wraps the real HTTPConnection class so that when you instantiate
+ the class you might instead get a fake instance."""
+
+ def __init__(self, wrapped):
+ self.wrapped = wrapped
+
+ def __call__(self, connection_host, *args, **kwargs):
+ if connection_host == host:
+ return FakeHttplibConnection(app, host)
+ else:
+ return self.wrapped(connection_host, *args, **kwargs)
+
+ oldHTTPConnection = httplib.HTTPConnection
+ httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection)
+ return oldHTTPConnection
+
+
+class WsgiLimiterProxyTest(BaseLimitTestSuite):
+ """
+ Tests for the `limits.WsgiLimiterProxy` class.
+ """
+
+ def setUp(self):
+ """
+ Do some nifty HTTP/WSGI magic which allows for WSGI to be called
+ directly by something like the `httplib` library.
+ """
+ super(WsgiLimiterProxyTest, self).setUp()
+ self.app = limits.WsgiLimiter(TEST_LIMITS)
+ self.oldHTTPConnection = (
+ wire_HTTPConnection_to_WSGI("169.254.0.1:80", self.app))
+ self.proxy = limits.WsgiLimiterProxy("169.254.0.1:80")
+
+ def test_200(self):
+ # Successful request test.
+ delay = self.proxy.check_for_delay("GET", "/anything")
+ self.assertEqual(delay, (None, None))
+
+ def test_403(self):
+ # Forbidden request test.
+ delay = self.proxy.check_for_delay("GET", "/delayed")
+ self.assertEqual(delay, (None, None))
+
+ delay, error = self.proxy.check_for_delay("GET", "/delayed")
+ error = error.strip()
+
+ expected = ("60.00", "403 Forbidden\n\nOnly 1 GET request(s) can be "
+ "made to /delayed every minute.")
+
+ self.assertEqual((delay, error), expected)
+
+ def tearDown(self):
+ # restore original HTTPConnection object
+ httplib.HTTPConnection = self.oldHTTPConnection
+ super(WsgiLimiterProxyTest, self).tearDown()
+
+
+class LimitsViewBuilderTest(testtools.TestCase):
+ def setUp(self):
+ super(LimitsViewBuilderTest, self).setUp()
+ self.view_builder = views.ViewBuilder()
+ self.rate_limits = [{"URI": "*",
+ "regex": ".*",
+ "value": 10,
+ "verb": "POST",
+ "remaining": 2,
+ "unit": "MINUTE",
+ "resetTime": 1311272226},
+ {"URI": "*/servers",
+ "regex": "^/servers",
+ "value": 50,
+ "verb": "POST",
+ "remaining": 10,
+ "unit": "DAY",
+ "resetTime": 1311272226}]
+ self.absolute_limits = {"metadata_items": 1,
+ "injected_files": 5,
+ "injected_file_content_bytes": 5}
+
+ def test_build_limits(self):
+ expected_limits = {"limits": {
+ "rate": [{"uri": "*",
+ "regex": ".*",
+ "limit": [{"value": 10,
+ "verb": "POST",
+ "remaining": 2,
+ "unit": "MINUTE",
+ "next-available": "2011-07-21T18:17:06Z"}]},
+ {"uri": "*/servers",
+ "regex": "^/servers",
+ "limit": [{"value": 50,
+ "verb": "POST",
+ "remaining": 10,
+ "unit": "DAY",
+ "next-available": "2011-07-21T18:17:06Z"}]}],
+ "absolute": {
+ "maxServerMeta": 1,
+ "maxImageMeta": 1,
+ "maxPersonality": 5,
+ "maxPersonalitySize": 5}}}
+
+ output = self.view_builder.build(self.rate_limits,
+ self.absolute_limits)
+ self.assertThat(output, DictMatches(expected_limits))
+
+ def test_build_limits_empty_limits(self):
+ expected_limits = {"limits": {"rate": [],
+ "absolute": {}}}
+
+ abs_limits = {}
+ rate_limits = []
+ output = self.view_builder.build(rate_limits, abs_limits)
+ self.assertThat(output, DictMatches(expected_limits))
+
+
+class LimitsXMLSerializationTest(testtools.TestCase):
+ def test_xml_declaration(self):
+ serializer = limits.LimitsTemplate()
+
+ fixture = {"limits": {
+ "rate": [],
+ "absolute": {}}}
+
+ output = serializer.serialize(fixture)
+ has_dec = output.startswith("<?xml version='1.0' encoding='UTF-8'?>")
+ self.assertTrue(has_dec)
+
+ def test_index(self):
+ serializer = limits.LimitsTemplate()
+ fixture = {
+ "limits": {
+ "rate": [{"uri": "*",
+ "regex": ".*",
+ "limit": [
+ {"value": 10,
+ "verb": "POST",
+ "remaining": 2,
+ "unit": "MINUTE",
+ "next-available": "2011-12-15T22:42:45Z"}]},
+ {"uri": "*/servers",
+ "regex": "^/servers",
+ "limit": [
+ {"value": 50,
+ "verb": "POST",
+ "remaining": 10,
+ "unit": "DAY",
+ "next-available": "2011-12-15T22:42:45Z"}]}],
+ "absolute": {
+ "maxServerMeta": 1,
+ "maxImageMeta": 1,
+ "maxPersonality": 5,
+ "maxPersonalitySize": 10240}}}
+
+ output = serializer.serialize(fixture)
+ root = etree.XML(output)
+ xmlutil.validate_schema(root, 'limits')
+
+ #verify absolute limits
+ absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
+ self.assertEqual(len(absolutes), 4)
+ for limit in absolutes:
+ name = limit.get('name')
+ value = limit.get('value')
+ self.assertEqual(value, str(fixture['limits']['absolute'][name]))
+
+ #verify rate limits
+ rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
+ self.assertEqual(len(rates), 2)
+ for i, rate in enumerate(rates):
+ for key in ['uri', 'regex']:
+ self.assertEqual(rate.get(key),
+ str(fixture['limits']['rate'][i][key]))
+ rate_limits = rate.xpath('ns:limit', namespaces=NS)
+ self.assertEqual(len(rate_limits), 1)
+ for j, limit in enumerate(rate_limits):
+ for key in ['verb', 'value', 'remaining', 'unit',
+ 'next-available']:
+ self.assertEqual(limit.get(key),
+ str(fixture['limits']['rate'][i]['limit']
+ [j][key]))
+
+ def test_index_no_limits(self):
+ serializer = limits.LimitsTemplate()
+
+ fixture = {"limits": {
+ "rate": [],
+ "absolute": {}}}
+
+ output = serializer.serialize(fixture)
+ root = etree.XML(output)
+ xmlutil.validate_schema(root, 'limits')
+
+ #verify absolute limits
+ absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
+ self.assertEqual(len(absolutes), 0)
+
+ #verify rate limits
+ rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
+ self.assertEqual(len(rates), 0)
diff --git a/reddwarf/tests/unittests/util/matchers.py b/reddwarf/tests/unittests/util/matchers.py
new file mode 100644
index 00000000..be65da82
--- /dev/null
+++ b/reddwarf/tests/unittests/util/matchers.py
@@ -0,0 +1,454 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2012 Hewlett-Packard Development Company, L.P.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""Matcher classes to be used inside of the testtools assertThat framework."""
+
+import pprint
+
+from lxml import etree
+
+
+class DictKeysMismatch(object):
+ def __init__(self, d1only, d2only):
+ self.d1only = d1only
+ self.d2only = d2only
+
+ def describe(self):
+ return ('Keys in d1 and not d2: %(d1only)s.'
+ ' Keys in d2 and not d1: %(d2only)s' % self.__dict__)
+
+ def get_details(self):
+ return {}
+
+
+class DictMismatch(object):
+ def __init__(self, key, d1_value, d2_value):
+ self.key = key
+ self.d1_value = d1_value
+ self.d2_value = d2_value
+
+ def describe(self):
+ return ("Dictionaries do not match at %(key)s."
+ " d1: %(d1_value)s d2: %(d2_value)s" % self.__dict__)
+
+ def get_details(self):
+ return {}
+
+
+class DictMatches(object):
+
+ def __init__(self, d1, approx_equal=False, tolerance=0.001):
+ self.d1 = d1
+ self.approx_equal = approx_equal
+ self.tolerance = tolerance
+
+ def __str__(self):
+ return 'DictMatches(%s)' % (pprint.pformat(self.d1))
+
+ # Useful assertions
+ def match(self, d2):
+ """Assert two dicts are equivalent.
+
+ This is a 'deep' match in the sense that it handles nested
+ dictionaries appropriately.
+
+ NOTE:
+
+ If you don't care (or don't know) a given value, you can specify
+ the string DONTCARE as the value. This will cause that dict-item
+ to be skipped.
+
+ """
+
+ d1keys = set(self.d1.keys())
+ d2keys = set(d2.keys())
+ if d1keys != d2keys:
+ d1only = d1keys - d2keys
+ d2only = d2keys - d1keys
+ return DictKeysMismatch(d1only, d2only)
+
+ for key in d1keys:
+ d1value = self.d1[key]
+ d2value = d2[key]
+ try:
+ error = abs(float(d1value) - float(d2value))
+ within_tolerance = error <= self.tolerance
+ except (ValueError, TypeError):
+ # If both values aren't convertible to float, just ignore
+ # ValueError if arg is a str, TypeError if it's something else
+ # (like None)
+ within_tolerance = False
+
+ if hasattr(d1value, 'keys') and hasattr(d2value, 'keys'):
+ matcher = DictMatches(d1value)
+ did_match = matcher.match(d2value)
+ if did_match is not None:
+ return did_match
+ elif 'DONTCARE' in (d1value, d2value):
+ continue
+ elif self.approx_equal and within_tolerance:
+ continue
+ elif d1value != d2value:
+ return DictMismatch(key, d1value, d2value)
+
+
+class ListLengthMismatch(object):
+ def __init__(self, len1, len2):
+ self.len1 = len1
+ self.len2 = len2
+
+ def describe(self):
+ return ('Length mismatch: len(L1)=%(len1)d != '
+ 'len(L2)=%(len2)d' % self.__dict__)
+
+ def get_details(self):
+ return {}
+
+
+class DictListMatches(object):
+
+ def __init__(self, l1, approx_equal=False, tolerance=0.001):
+ self.l1 = l1
+ self.approx_equal = approx_equal
+ self.tolerance = tolerance
+
+ def __str__(self):
+ return 'DictListMatches(%s)' % (pprint.pformat(self.l1))
+
+ # Useful assertions
+ def match(self, l2):
+ """Assert a list of dicts are equivalent."""
+
+ l1count = len(self.l1)
+ l2count = len(l2)
+ if l1count != l2count:
+ return ListLengthMismatch(l1count, l2count)
+
+ for d1, d2 in zip(self.l1, l2):
+ matcher = DictMatches(d2,
+ approx_equal=self.approx_equal,
+ tolerance=self.tolerance)
+ did_match = matcher.match(d1)
+ if did_match:
+ return did_match
+
+
+class SubDictMismatch(object):
+ def __init__(self,
+ key=None,
+ sub_value=None,
+ super_value=None,
+ keys=False):
+ self.key = key
+ self.sub_value = sub_value
+ self.super_value = super_value
+ self.keys = keys
+
+ def describe(self):
+ if self.keys:
+ return "Keys between dictionaries did not match"
+ else:
+ return("Dictionaries do not match at %s. d1: %s d2: %s"
+ % (self.key,
+ self.super_value,
+ self.sub_value))
+
+ def get_details(self):
+ return {}
+
+
+class IsSubDictOf(object):
+
+ def __init__(self, super_dict):
+ self.super_dict = super_dict
+
+ def __str__(self):
+ return 'IsSubDictOf(%s)' % (self.super_dict)
+
+ def match(self, sub_dict):
+ """Assert a sub_dict is subset of super_dict."""
+ if not set(sub_dict.keys()).issubset(set(self.super_dict.keys())):
+ return SubDictMismatch(keys=True)
+ for k, sub_value in sub_dict.items():
+ super_value = self.super_dict[k]
+ if isinstance(sub_value, dict):
+ matcher = IsSubDictOf(super_value)
+ did_match = matcher.match(sub_value)
+ if did_match is not None:
+ return did_match
+ elif 'DONTCARE' in (sub_value, super_value):
+ continue
+ else:
+ if sub_value != super_value:
+ return SubDictMismatch(k, sub_value, super_value)
+
+
+class FunctionCallMatcher(object):
+
+ def __init__(self, expected_func_calls):
+ self.expected_func_calls = expected_func_calls
+ self.actual_func_calls = []
+
+ def call(self, *args, **kwargs):
+ func_call = {'args': args, 'kwargs': kwargs}
+ self.actual_func_calls.append(func_call)
+
+ def match(self):
+ dict_list_matcher = DictListMatches(self.expected_func_calls)
+ return dict_list_matcher.match(self.actual_func_calls)
+
+
+class XMLMismatch(object):
+ """Superclass for XML mismatch."""
+
+ def __init__(self, state):
+ self.path = str(state)
+ self.expected = state.expected
+ self.actual = state.actual
+
+ def describe(self):
+ return "%(path)s: XML does not match" % self.__dict__
+
+ def get_details(self):
+ return {
+ 'expected': self.expected,
+ 'actual': self.actual,
+ }
+
+
+class XMLTagMismatch(XMLMismatch):
+ """XML tags don't match."""
+
+ def __init__(self, state, idx, expected_tag, actual_tag):
+ super(XMLTagMismatch, self).__init__(state)
+ self.idx = idx
+ self.expected_tag = expected_tag
+ self.actual_tag = actual_tag
+
+ def describe(self):
+ return ("%(path)s: XML tag mismatch at index %(idx)d: "
+ "expected tag <%(expected_tag)s>; "
+ "actual tag <%(actual_tag)s>" % self.__dict__)
+
+
+class XMLAttrKeysMismatch(XMLMismatch):
+ """XML attribute keys don't match."""
+
+ def __init__(self, state, expected_only, actual_only):
+ super(XMLAttrKeysMismatch, self).__init__(state)
+ self.expected_only = ', '.join(sorted(expected_only))
+ self.actual_only = ', '.join(sorted(actual_only))
+
+ def describe(self):
+ return ("%(path)s: XML attributes mismatch: "
+ "keys only in expected: %(expected_only)s; "
+ "keys only in actual: %(actual_only)s" % self.__dict__)
+
+
+class XMLAttrValueMismatch(XMLMismatch):
+ """XML attribute values don't match."""
+
+ def __init__(self, state, key, expected_value, actual_value):
+ super(XMLAttrValueMismatch, self).__init__(state)
+ self.key = key
+ self.expected_value = expected_value
+ self.actual_value = actual_value
+
+ def describe(self):
+ return ("%(path)s: XML attribute value mismatch: "
+ "expected value of attribute %(key)s: %(expected_value)r; "
+ "actual value: %(actual_value)r" % self.__dict__)
+
+
+class XMLTextValueMismatch(XMLMismatch):
+ """XML text values don't match."""
+
+ def __init__(self, state, expected_text, actual_text):
+ super(XMLTextValueMismatch, self).__init__(state)
+ self.expected_text = expected_text
+ self.actual_text = actual_text
+
+ def describe(self):
+ return ("%(path)s: XML text value mismatch: "
+ "expected text value: %(expected_text)r; "
+ "actual value: %(actual_text)r" % self.__dict__)
+
+
+class XMLUnexpectedChild(XMLMismatch):
+ """Unexpected child present in XML."""
+
+ def __init__(self, state, tag, idx):
+ super(XMLUnexpectedChild, self).__init__(state)
+ self.tag = tag
+ self.idx = idx
+
+ def describe(self):
+ return ("%(path)s: XML unexpected child element <%(tag)s> "
+ "present at index %(idx)d" % self.__dict__)
+
+
+class XMLExpectedChild(XMLMismatch):
+ """Expected child not present in XML."""
+
+ def __init__(self, state, tag, idx):
+ super(XMLExpectedChild, self).__init__(state)
+ self.tag = tag
+ self.idx = idx
+
+ def describe(self):
+ return ("%(path)s: XML expected child element <%(tag)s> "
+ "not present at index %(idx)d" % self.__dict__)
+
+
+class XMLMatchState(object):
+ """
+ Maintain some state for matching.
+
+ Tracks the XML node path and saves the expected and actual full
+ XML text, for use by the XMLMismatch subclasses.
+ """
+
+ def __init__(self, expected, actual):
+ self.path = []
+ self.expected = expected
+ self.actual = actual
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ self.path.pop()
+ return False
+
+ def __str__(self):
+ return '/' + '/'.join(self.path)
+
+ def node(self, tag, idx):
+ """
+ Adds tag and index to the path; they will be popped off when
+ the corresponding 'with' statement exits.
+
+ :param tag: The element tag
+ :param idx: If not None, the integer index of the element
+ within its parent. Not included in the path
+ element if None.
+ """
+
+ if idx is not None:
+ self.path.append("%s[%d]" % (tag, idx))
+ else:
+ self.path.append(tag)
+ return self
+
+
+class XMLMatches(object):
+ """Compare XML strings. More complete than string comparison."""
+
+ def __init__(self, expected):
+ self.expected_xml = expected
+ self.expected = etree.fromstring(expected)
+
+ def __str__(self):
+ return 'XMLMatches(%r)' % self.expected_xml
+
+ def match(self, actual_xml):
+ actual = etree.fromstring(actual_xml)
+
+ state = XMLMatchState(self.expected_xml, actual_xml)
+ result = self._compare_node(self.expected, actual, state, None)
+
+ if result is False:
+ return XMLMismatch(state)
+ elif result is not True:
+ return result
+
+ def _compare_node(self, expected, actual, state, idx):
+ """Recursively compares nodes within the XML tree."""
+
+ # Start by comparing the tags
+ if expected.tag != actual.tag:
+ return XMLTagMismatch(state, idx, expected.tag, actual.tag)
+
+ with state.node(expected.tag, idx):
+ # Compare the attribute keys
+ expected_attrs = set(expected.attrib.keys())
+ actual_attrs = set(actual.attrib.keys())
+ if expected_attrs != actual_attrs:
+ expected_only = expected_attrs - actual_attrs
+ actual_only = actual_attrs - expected_attrs
+ return XMLAttrKeysMismatch(state, expected_only, actual_only)
+
+ # Compare the attribute values
+ for key in expected_attrs:
+ expected_value = expected.attrib[key]
+ actual_value = actual.attrib[key]
+
+ if 'DONTCARE' in (expected_value, actual_value):
+ continue
+ elif expected_value != actual_value:
+ return XMLAttrValueMismatch(state, key, expected_value,
+ actual_value)
+
+ # Compare the contents of the node
+ if len(expected) == 0 and len(actual) == 0:
+ # No children, compare text values
+ if ('DONTCARE' not in (expected.text, actual.text) and
+ expected.text != actual.text):
+ return XMLTextValueMismatch(state, expected.text,
+ actual.text)
+ else:
+ expected_idx = 0
+ actual_idx = 0
+ while (expected_idx < len(expected) and
+ actual_idx < len(actual)):
+ # Ignore comments and processing instructions
+ # TODO(Vek): may interpret PIs in the future, to
+ # allow for, say, arbitrary ordering of some
+ # elements
+ if (expected[expected_idx].tag in
+ (etree.Comment, etree.ProcessingInstruction)):
+ expected_idx += 1
+ continue
+
+ # Compare the nodes
+ result = self._compare_node(expected[expected_idx],
+ actual[actual_idx], state,
+ actual_idx)
+ if result is not True:
+ return result
+
+ # Step on to comparing the next nodes...
+ expected_idx += 1
+ actual_idx += 1
+
+ # Make sure we consumed all nodes in actual
+ if actual_idx < len(actual):
+ return XMLUnexpectedChild(state, actual[actual_idx].tag,
+ actual_idx)
+
+ # Make sure we consumed all nodes in expected
+ if expected_idx < len(expected):
+ for node in expected[expected_idx:]:
+ if (node.tag in
+ (etree.Comment, etree.ProcessingInstruction)):
+ continue
+
+ return XMLExpectedChild(state, node.tag, actual_idx)
+
+ # The nodes match
+ return True
diff --git a/run_tests.py b/run_tests.py
index 71701f83..7e7fdbe7 100644
--- a/run_tests.py
+++ b/run_tests.py
@@ -111,6 +111,7 @@ if __name__ == "__main__":
# Initialize the test configuration.
CONFIG.load_from_file('etc/tests/localhost.test.conf')
+ from reddwarf.tests.api import limits
from reddwarf.tests.api import flavors
from reddwarf.tests.api import versions
from reddwarf.tests.api import instances
diff --git a/tools/test-requires b/tools/test-requires
index a095ae4f..2096d766 100644
--- a/tools/test-requires
+++ b/tools/test-requires
@@ -17,3 +17,4 @@ testtools>=0.9.22
pexpect
discover
testrepository>=0.0.8
+mockito