summaryrefslogtreecommitdiff
path: root/src/saml2test/check.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/saml2test/check.py')
-rw-r--r--src/saml2test/check.py262
1 files changed, 262 insertions, 0 deletions
diff --git a/src/saml2test/check.py b/src/saml2test/check.py
new file mode 100644
index 00000000..1f2f062a
--- /dev/null
+++ b/src/saml2test/check.py
@@ -0,0 +1,262 @@
+import inspect
+import json
+
+__author__ = 'rolandh'
+
+import traceback
+import sys
+
+INFORMATION = 0
+OK = 1
+WARNING = 2
+ERROR = 3
+CRITICAL = 4
+INTERACTION = 5
+
+STATUSCODE = ["INFORMATION", "OK", "WARNING", "ERROR", "CRITICAL",
+ "INTERACTION"]
+
+CONT_JSON = "application/json"
+CONT_JWT = "application/jwt"
+
+
+class Check(object):
+ """ General test
+ """
+
+ cid = "check"
+ msg = "OK"
+
+ def __init__(self, **kwargs):
+ self._status = OK
+ self._message = ""
+ self.content = None
+ self.url = ""
+ self._kwargs = kwargs
+
+ def _func(self, conv):
+ return {}
+
+ def __call__(self, conv=None, output=None):
+ _stat = self.response(**self._func(conv))
+ if output is not None:
+ output.append(_stat)
+ return _stat
+
+ def response(self, **kwargs):
+ try:
+ name = " ".join(
+ [s.strip() for s in self.__doc__.strip().split("\n")])
+ except AttributeError:
+ name = ""
+
+ res = {
+ "id": self.cid,
+ "status": self._status,
+ "name": name
+ }
+
+ if self._message:
+ res["message"] = self._message
+
+ if kwargs:
+ res.update(kwargs)
+
+ return res
+
+
+class ExpectedError(Check):
+ pass
+
+
+class CriticalError(Check):
+ status = CRITICAL
+
+
+class Information(Check):
+ status = INFORMATION
+
+
+class Error(Check):
+ status = ERROR
+
+
+class ResponseInfo(Information):
+ """Response information"""
+
+ def _func(self, conv=None):
+ self._status = self.status
+ _msg = conv.last_content
+
+ if isinstance(_msg, basestring):
+ self._message = _msg
+ else:
+ self._message = _msg.to_dict()
+
+ return {}
+
+
+class CheckErrorResponse(ExpectedError):
+ """
+ Checks that the HTTP response status is outside the 200 or 300 range
+ or that an JSON encoded error message has been received
+ """
+ cid = "check-error-response"
+ msg = "OP error"
+
+ def _func(self, conv):
+ _response = conv.last_response
+ _content = conv.last_content
+
+ res = {}
+ if _response.status_code >= 400:
+ content_type = _response.headers["content-type"]
+ if content_type is None:
+ res["content"] = _content
+ else:
+ res["content"] = _content
+
+ return res
+
+
+class VerifyBadRequestResponse(ExpectedError):
+ """
+ Verifies that the OP returned a 400 Bad Request response containing a
+ Error message.
+ """
+ cid = "verify-bad-request-response"
+ msg = "OP error"
+
+ def _func(self, conv):
+ _response = conv.last_response
+ _content = conv.last_content
+ res = {}
+ if _response.status_code == 400:
+ pass
+ else:
+ self._message = "Expected a 400 error message"
+ self._status = CRITICAL
+
+ return res
+
+
+class VerifyError(Error):
+ cid = "verify-error"
+
+ def _func(self, conv):
+ response = conv.last_response
+ if response.status_code == 400:
+ try:
+ resp = json.loads(response.text)
+ if "error" in resp:
+ return {}
+ except Exception:
+ pass
+
+ item, msg = conv.protocol_response[-1]
+ try:
+ assert item.type().endswith("ErrorResponse")
+ except AssertionError:
+ self._message = "Expected an error response"
+ self._status = self.status
+ return {}
+
+ try:
+ assert item["error"] in self._kwargs["error"]
+ except AssertionError:
+ self._message = "Wrong type of error, got %s" % item["error"]
+ self._status = self.status
+
+ return {}
+
+
+class WrapException(CriticalError):
+ """
+ A runtime exception
+ """
+ cid = "exception"
+ msg = "Test tool exception"
+
+ def _func(self, conv=None):
+ self._status = self.status
+ self._message = traceback.format_exception(*sys.exc_info())
+ return {}
+
+
+class Other(CriticalError):
+ """ Other error """
+ msg = "Other error"
+
+
+class CheckHTTPResponse(CriticalError):
+ """
+ Checks that the HTTP response status is within the 200 or 300 range
+ """
+ cid = "check-http-response"
+ msg = "OP error"
+
+ def _func(self, conv):
+ _response = conv.last_response
+ _content = conv.last_content
+
+ res = {}
+ if _response.status_code >= 400:
+ self._status = self.status
+ self._message = self.msg
+ res["content"] = _content
+ res["url"] = conv.position
+ res["http_status"] = _response.status_code
+
+ return res
+
+
+class MissingRedirect(CriticalError):
+ """ At this point in the flow a redirect back to the client was expected.
+ """
+ cid = "missing-redirect"
+ msg = "Expected redirect to the RP, got something else"
+
+ def _func(self, conv=None):
+ self._status = self.status
+ return {"url": conv.position}
+
+
+class Parse(CriticalError):
+ """ Parsing the response """
+ cid = "response-parse"
+ errmsg = "Parse error"
+
+ def _func(self, conv=None):
+ if conv.exception:
+ self._status = self.status
+ err = conv.exception
+ self._message = "%s: %s" % (err.__class__.__name__, err)
+ else:
+ _rmsg = conv.response_message
+ cname = _rmsg.type()
+ if conv.response_type != cname:
+ self._status = self.status
+ self._message = (
+ "Didn't get a response of the type I expected:",
+ " '%s' instead of '%s', content:'%s'" % (
+ cname, conv.response_type, _rmsg))
+ return {
+ "response_type": conv.response_type,
+ "url": conv.position
+ }
+
+ return {}
+
+def factory(cid, classes):
+ if len(classes) == 0:
+ for name, obj in inspect.getmembers(sys.modules[__name__]):
+ if inspect.isclass(obj):
+ try:
+ classes[obj.cid] = obj
+ except AttributeError:
+ pass
+
+ if cid in classes:
+ return classes[cid]
+ else:
+ return None