summaryrefslogtreecommitdiff
path: root/src/saml2test
diff options
context:
space:
mode:
authorRoland Hedberg <roland.hedberg@adm.umu.se>2013-08-26 09:45:39 +0200
committerRoland Hedberg <roland.hedberg@adm.umu.se>2013-08-26 09:45:39 +0200
commitedda40422c78443f40d4171d381c5f3ee731778b (patch)
tree0d60c036314b08ee2d96a738b0242c9cc21730e0 /src/saml2test
parent6e86b548397858938f13f5616cf0fa4e8d1dbf98 (diff)
downloadpysaml2-edda40422c78443f40d4171d381c5f3ee731778b.tar.gz
Changed directory name.
Diffstat (limited to 'src/saml2test')
-rw-r--r--src/saml2test/__init__.py126
-rw-r--r--src/saml2test/check.py262
-rw-r--r--src/saml2test/interaction.py394
-rw-r--r--src/saml2test/opfunc.py373
-rw-r--r--src/saml2test/status.py11
-rw-r--r--src/saml2test/tool.py286
6 files changed, 1452 insertions, 0 deletions
diff --git a/src/saml2test/__init__.py b/src/saml2test/__init__.py
new file mode 100644
index 00000000..15e096d8
--- /dev/null
+++ b/src/saml2test/__init__.py
@@ -0,0 +1,126 @@
+import logging
+import time
+import traceback
+import requests
+import sys
+
+from subprocess import Popen, PIPE
+from saml2test.check import CRITICAL
+
+logger = logging.getLogger(__name__)
+
+__author__ = 'rolandh'
+
+
+class FatalError(Exception):
+ pass
+
+
+class CheckError(Exception):
+ pass
+
+
+class HTTP_ERROR(Exception):
+ pass
+
+
+class Unknown(Exception):
+ pass
+
+
+# class Trace(object):
+# def __init__(self):
+# self.trace = []
+# self.start = time.time()
+#
+# def request(self, msg):
+# delta = time.time() - self.start
+# self.trace.append("%f --> %s" % (delta, msg))
+#
+# def reply(self, msg):
+# delta = time.time() - self.start
+# self.trace.append("%f <-- %s" % (delta, msg))
+#
+# def info(self, msg, who="saml2client"):
+# delta = time.time() - self.start
+# self.trace.append("%f - INFO - [%s] %s" % (delta, who, msg))
+#
+# def error(self, msg, who="saml2client"):
+# delta = time.time() - self.start
+# self.trace.append("%f - ERROR - [%s] %s" % (delta, who, msg))
+#
+# def warning(self, msg, who="saml2client"):
+# delta = time.time() - self.start
+# self.trace.append("%f - WARNING - [%s] %s" % (delta, who, msg))
+#
+# def __str__(self):
+# return "\n". join([t.encode("utf-8") for t in self.trace])
+#
+# def clear(self):
+# self.trace = []
+#
+# def __getitem__(self, item):
+# return self.trace[item]
+#
+# def next(self):
+# for line in self.trace:
+# yield line
+
+
+class ContextFilter(logging.Filter):
+ """
+ This is a filter which injects time laps information into the log.
+ """
+
+ def start(self):
+ self.start = time.time()
+
+ def filter(self, record):
+ record.delta = time.time() - self.start
+ return True
+
+
+def start_script(path, *args):
+ popen_args = [path]
+ popen_args.extend(args)
+ return Popen(popen_args, stdout=PIPE, stderr=PIPE)
+
+
+def stop_script_by_name(name):
+ import subprocess
+ import signal
+ import os
+
+ p = subprocess.Popen(['ps', '-A'], stdout=subprocess.PIPE)
+ out, err = p.communicate()
+
+ for line in out.splitlines():
+ if name in line:
+ pid = int(line.split(None, 1)[0])
+ os.kill(pid, signal.SIGKILL)
+
+
+def stop_script_by_pid(pid):
+ import signal
+ import os
+
+ os.kill(pid, signal.SIGKILL)
+
+
+def get_page(url):
+ resp = requests.get(url)
+ if resp.status_code == 200:
+ return resp.text
+ else:
+ raise HTTP_ERROR(resp.status)
+
+
+def exception_trace(tag, exc, log=None):
+ message = traceback.format_exception(*sys.exc_info())
+
+ try:
+ _exc = "Exception: %s" % exc
+ except UnicodeEncodeError:
+ _exc = "Exception: %s" % exc.message.encode("utf-8", "replace")
+
+ return {"status": CRITICAL, "message": _exc, "content": "".join(message)}
diff --git a/src/saml2test/check.py b/src/saml2test/check.py
new file mode 100644
index 00000000..1f2f062a
--- /dev/null
+++ b/src/saml2test/check.py
@@ -0,0 +1,262 @@
+import inspect
+import json
+
+__author__ = 'rolandh'
+
+import traceback
+import sys
+
+INFORMATION = 0
+OK = 1
+WARNING = 2
+ERROR = 3
+CRITICAL = 4
+INTERACTION = 5
+
+STATUSCODE = ["INFORMATION", "OK", "WARNING", "ERROR", "CRITICAL",
+ "INTERACTION"]
+
+CONT_JSON = "application/json"
+CONT_JWT = "application/jwt"
+
+
+class Check(object):
+ """ General test
+ """
+
+ cid = "check"
+ msg = "OK"
+
+ def __init__(self, **kwargs):
+ self._status = OK
+ self._message = ""
+ self.content = None
+ self.url = ""
+ self._kwargs = kwargs
+
+ def _func(self, conv):
+ return {}
+
+ def __call__(self, conv=None, output=None):
+ _stat = self.response(**self._func(conv))
+ if output is not None:
+ output.append(_stat)
+ return _stat
+
+ def response(self, **kwargs):
+ try:
+ name = " ".join(
+ [s.strip() for s in self.__doc__.strip().split("\n")])
+ except AttributeError:
+ name = ""
+
+ res = {
+ "id": self.cid,
+ "status": self._status,
+ "name": name
+ }
+
+ if self._message:
+ res["message"] = self._message
+
+ if kwargs:
+ res.update(kwargs)
+
+ return res
+
+
+class ExpectedError(Check):
+ pass
+
+
+class CriticalError(Check):
+ status = CRITICAL
+
+
+class Information(Check):
+ status = INFORMATION
+
+
+class Error(Check):
+ status = ERROR
+
+
+class ResponseInfo(Information):
+ """Response information"""
+
+ def _func(self, conv=None):
+ self._status = self.status
+ _msg = conv.last_content
+
+ if isinstance(_msg, basestring):
+ self._message = _msg
+ else:
+ self._message = _msg.to_dict()
+
+ return {}
+
+
+class CheckErrorResponse(ExpectedError):
+ """
+ Checks that the HTTP response status is outside the 200 or 300 range
+ or that an JSON encoded error message has been received
+ """
+ cid = "check-error-response"
+ msg = "OP error"
+
+ def _func(self, conv):
+ _response = conv.last_response
+ _content = conv.last_content
+
+ res = {}
+ if _response.status_code >= 400:
+ content_type = _response.headers["content-type"]
+ if content_type is None:
+ res["content"] = _content
+ else:
+ res["content"] = _content
+
+ return res
+
+
+class VerifyBadRequestResponse(ExpectedError):
+ """
+ Verifies that the OP returned a 400 Bad Request response containing a
+ Error message.
+ """
+ cid = "verify-bad-request-response"
+ msg = "OP error"
+
+ def _func(self, conv):
+ _response = conv.last_response
+ _content = conv.last_content
+ res = {}
+ if _response.status_code == 400:
+ pass
+ else:
+ self._message = "Expected a 400 error message"
+ self._status = CRITICAL
+
+ return res
+
+
+class VerifyError(Error):
+ cid = "verify-error"
+
+ def _func(self, conv):
+ response = conv.last_response
+ if response.status_code == 400:
+ try:
+ resp = json.loads(response.text)
+ if "error" in resp:
+ return {}
+ except Exception:
+ pass
+
+ item, msg = conv.protocol_response[-1]
+ try:
+ assert item.type().endswith("ErrorResponse")
+ except AssertionError:
+ self._message = "Expected an error response"
+ self._status = self.status
+ return {}
+
+ try:
+ assert item["error"] in self._kwargs["error"]
+ except AssertionError:
+ self._message = "Wrong type of error, got %s" % item["error"]
+ self._status = self.status
+
+ return {}
+
+
+class WrapException(CriticalError):
+ """
+ A runtime exception
+ """
+ cid = "exception"
+ msg = "Test tool exception"
+
+ def _func(self, conv=None):
+ self._status = self.status
+ self._message = traceback.format_exception(*sys.exc_info())
+ return {}
+
+
+class Other(CriticalError):
+ """ Other error """
+ msg = "Other error"
+
+
+class CheckHTTPResponse(CriticalError):
+ """
+ Checks that the HTTP response status is within the 200 or 300 range
+ """
+ cid = "check-http-response"
+ msg = "OP error"
+
+ def _func(self, conv):
+ _response = conv.last_response
+ _content = conv.last_content
+
+ res = {}
+ if _response.status_code >= 400:
+ self._status = self.status
+ self._message = self.msg
+ res["content"] = _content
+ res["url"] = conv.position
+ res["http_status"] = _response.status_code
+
+ return res
+
+
+class MissingRedirect(CriticalError):
+ """ At this point in the flow a redirect back to the client was expected.
+ """
+ cid = "missing-redirect"
+ msg = "Expected redirect to the RP, got something else"
+
+ def _func(self, conv=None):
+ self._status = self.status
+ return {"url": conv.position}
+
+
+class Parse(CriticalError):
+ """ Parsing the response """
+ cid = "response-parse"
+ errmsg = "Parse error"
+
+ def _func(self, conv=None):
+ if conv.exception:
+ self._status = self.status
+ err = conv.exception
+ self._message = "%s: %s" % (err.__class__.__name__, err)
+ else:
+ _rmsg = conv.response_message
+ cname = _rmsg.type()
+ if conv.response_type != cname:
+ self._status = self.status
+ self._message = (
+ "Didn't get a response of the type I expected:",
+ " '%s' instead of '%s', content:'%s'" % (
+ cname, conv.response_type, _rmsg))
+ return {
+ "response_type": conv.response_type,
+ "url": conv.position
+ }
+
+ return {}
+
+def factory(cid, classes):
+ if len(classes) == 0:
+ for name, obj in inspect.getmembers(sys.modules[__name__]):
+ if inspect.isclass(obj):
+ try:
+ classes[obj.cid] = obj
+ except AttributeError:
+ pass
+
+ if cid in classes:
+ return classes[cid]
+ else:
+ return None
diff --git a/src/saml2test/interaction.py b/src/saml2test/interaction.py
new file mode 100644
index 00000000..d5f48ed1
--- /dev/null
+++ b/src/saml2test/interaction.py
@@ -0,0 +1,394 @@
+__author__ = 'rohe0002'
+
+import json
+import logging
+
+from urlparse import urlparse
+from bs4 import BeautifulSoup
+
+from mechanize import ParseResponseEx
+from mechanize._form import ControlNotFoundError, AmbiguityError
+from mechanize._form import ListControl
+
+logger = logging.getLogger(__name__)
+
+NO_CTRL = "No submit control with the name='%s' and value='%s' could be found"
+
+
+class FlowException(Exception):
+ def __init__(self, function="", content="", url=""):
+ Exception.__init__(self)
+ self.function = function
+ self.content = content
+ self.url = url
+
+ def __str__(self):
+ return json.dumps(self.__dict__)
+
+
+class InteractionNeeded(Exception):
+ pass
+
+
+def NoneFunc():
+ return None
+
+
+class RResponse():
+ """
+ A Response class that behaves in the way that mechanize expects it.
+ Links to a requests.Response
+ """
+ def __init__(self, resp):
+ self._resp = resp
+ self.index = 0
+ self.text = resp.text
+ if isinstance(self.text, unicode):
+ if resp.encoding == "UTF-8":
+ self.text = self.text.encode("utf-8")
+ else:
+ self.text = self.text.encode("latin-1")
+ self._len = len(self.text)
+ self.url = str(resp.url)
+ self.statuscode = resp.status_code
+
+ def geturl(self):
+ return self._resp.url
+
+ def __getitem__(self, item):
+ try:
+ return getattr(self._resp, item)
+ except AttributeError:
+ return getattr(self._resp.headers, item)
+
+ def __getattribute__(self, item):
+ try:
+ return getattr(self._resp, item)
+ except AttributeError:
+ return getattr(self._resp.headers, item)
+
+ def read(self, size=0):
+ """
+ Read from the content of the response. The class remembers what has
+ been read so it's possible to read small consecutive parts of the
+ content.
+
+ :param size: The number of bytes to read
+ :return: Somewhere between zero and 'size' number of bytes depending
+ on how much it left in the content buffer to read.
+ """
+ if size:
+ if self._len < size:
+ return self.text
+ else:
+ if self._len == self.index:
+ part = None
+ elif self._len - self.index < size:
+ part = self.text[self.index:]
+ self.index = self._len
+ else:
+ part = self.text[self.index:self.index + size]
+ self.index += size
+ return part
+ else:
+ return self.text
+
+
+class Interaction(object):
+ def __init__(self, httpc, interactions=None):
+ self.httpc = httpc
+ self.interactions = interactions
+ self.who = "Form process"
+
+ def pick_interaction(self, _base="", content="", req=None):
+ unic = content
+ if content:
+ _bs = BeautifulSoup(content)
+ else:
+ _bs = None
+
+ for interaction in self.interactions:
+ _match = 0
+ for attr, val in interaction["matches"].items():
+ if attr == "url":
+ if val == _base:
+ _match += 1
+ elif attr == "title":
+ if _bs is None:
+ break
+ if _bs.title is None:
+ break
+ if val in _bs.title.contents:
+ _match += 1
+ else:
+ _c = _bs.title.contents
+ if isinstance(_c, list) and not isinstance(_c,
+ basestring):
+ for _line in _c:
+ if val in _line:
+ _match += 1
+ continue
+ elif attr == "content":
+ if unic and val in unic:
+ _match += 1
+ elif attr == "class":
+ if req and val == req:
+ _match += 1
+
+ if _match == len(interaction["matches"]):
+ logger.info("Matched: %s" % interaction["matches"])
+ return interaction
+
+ raise InteractionNeeded("No interaction matched")
+
+ def pick_form(self, response, url=None, **kwargs):
+ """
+ Picks which form in a web-page that should be used
+
+ :param response: A HTTP request response. A DResponse instance
+ :param content: The HTTP response content
+ :param url: The url the request was sent to
+ :param kwargs: Extra key word arguments
+ :return: The picked form or None of no form matched the criteria.
+ """
+
+ forms = ParseResponseEx(response)
+ if not forms:
+ raise FlowException(content=response.text, url=url)
+
+ #if len(forms) == 1:
+ # return forms[0]
+ #else:
+
+ _form = None
+ # ignore the first form, because I use ParseResponseEx which adds
+ # one form at the top of the list
+ forms = forms[1:]
+ if len(forms) == 1:
+ _form = forms[0]
+ else:
+ if "pick" in kwargs:
+ _dict = kwargs["pick"]
+ for form in forms:
+ if _form:
+ break
+ for key, _ava in _dict.items():
+ if key == "form":
+ _keys = form.attrs.keys()
+ for attr, val in _ava.items():
+ if attr in _keys and val == form.attrs[attr]:
+ _form = form
+ elif key == "control":
+ prop = _ava["id"]
+ _default = _ava["value"]
+ try:
+ orig_val = form[prop]
+ if isinstance(orig_val, basestring):
+ if orig_val == _default:
+ _form = form
+ elif _default in orig_val:
+ _form = form
+ except KeyError:
+ pass
+ except ControlNotFoundError:
+ pass
+ elif key == "method":
+ if form.method == _ava:
+ _form = form
+ else:
+ _form = None
+
+ if not _form:
+ break
+ elif "index" in kwargs:
+ _form = forms[int(kwargs["index"])]
+
+ return _form
+
+ def do_click(self, form, **kwargs):
+ """
+ Emulates the user clicking submit on a form.
+
+ :param form: The form that should be submitted
+ :return: What do_request() returns
+ """
+
+ if "click" in kwargs:
+ request = None
+ _name = kwargs["click"]
+ try:
+ _ = form.find_control(name=_name)
+ request = form.click(name=_name)
+ except AmbiguityError:
+ # more than one control with that name
+ _val = kwargs["set"][_name]
+ _nr = 0
+ while True:
+ try:
+ cntrl = form.find_control(name=_name, nr=_nr)
+ if cntrl.value == _val:
+ request = form.click(name=_name, nr=_nr)
+ break
+ else:
+ _nr += 1
+ except ControlNotFoundError:
+ raise Exception(NO_CTRL % (_name, _val))
+ else:
+ request = form.click()
+
+ headers = {}
+ for key, val in request.unredirected_hdrs.items():
+ headers[key] = val
+
+ url = request._Request__original
+
+ if form.method == "POST":
+ return self.httpc.send(url, "POST", data=request.data,
+ headers=headers)
+ else:
+ return self.httpc.send(url, "GET", headers=headers)
+
+ def select_form(self, orig_response, **kwargs):
+ """
+ Pick a form on a web page, possibly enter some information and submit
+ the form.
+
+ :param orig_response: The original response (as returned by requests)
+ :return: The response do_click() returns
+ """
+ logger.info("select_form")
+ response = RResponse(orig_response)
+ try:
+ _url = response.url
+ except KeyError:
+ _url = kwargs["location"]
+
+ form = self.pick_form(response, _url, **kwargs)
+ #form.backwards_compatible = False
+ if not form:
+ raise Exception("Can't pick a form !!")
+
+ if "set" in kwargs:
+ for key, val in kwargs["set"].items():
+ if key.startswith("_"):
+ continue
+ if "click" in kwargs and kwargs["click"] == key:
+ continue
+
+ try:
+ form[key] = val
+ except ControlNotFoundError:
+ pass
+ except TypeError:
+ cntrl = form.find_control(key)
+ if isinstance(cntrl, ListControl):
+ form[key] = [val]
+ else:
+ raise
+
+ if form.action in kwargs["conv"].my_endpoints():
+ return {"SAMLResponse": form["SAMLResponse"],
+ "RelayState": form["RelayState"]}
+
+ return self.do_click(form, **kwargs)
+
+ #noinspection PyUnusedLocal
+ def chose(self, orig_response, path, **kwargs):
+ """
+ Sends a HTTP GET to a url given by the present url and the given
+ relative path.
+
+ :param orig_response: The original response
+ :param content: The content of the response
+ :param path: The relative path to add to the base URL
+ :return: The response do_click() returns
+ """
+
+ if not path.startswith("http"):
+ try:
+ _url = orig_response.url
+ except KeyError:
+ _url = kwargs["location"]
+
+ part = urlparse(_url)
+ url = "%s://%s%s" % (part[0], part[1], path)
+ else:
+ url = path
+
+ logger.info("GET %s" % url)
+ return self.httpc.send(url, "GET")
+ #return resp, ""
+
+ def post_form(self, orig_response, **kwargs):
+ """
+ The same as select_form but with no possibility of change the content
+ of the form.
+
+ :param httpc: A HTTP Client instance
+ :param orig_response: The original response (as returned by requests)
+ :param content: The content of the response
+ :return: The response do_click() returns
+ """
+ response = RResponse(orig_response)
+
+ form = self.pick_form(response, **kwargs)
+
+ return self.do_click(form, **kwargs)
+
+ #noinspection PyUnusedLocal
+ def parse(self, orig_response, **kwargs):
+ # content is a form from which I get the SAMLResponse
+ response = RResponse(orig_response)
+
+ form = self.pick_form(response, **kwargs)
+ #form.backwards_compatible = False
+ if not form:
+ raise InteractionNeeded("Can't pick a form !!")
+
+ return {"SAMLResponse": form["SAMLResponse"],
+ "RelayState": form["RelayState"]}
+
+ #noinspection PyUnusedLocal
+ def interaction(self, args):
+ _type = args["type"]
+ if _type == "form":
+ return self.select_form
+ elif _type == "link":
+ return self.chose
+ elif _type == "response":
+ return self.parse
+ else:
+ return NoneFunc
+
+# ========================================================================
+
+
+class Action(object):
+ def __init__(self, args):
+ self.args = args or {}
+ self.request = None
+
+ def update(self, dic):
+ self.args.update(dic)
+
+ #noinspection PyUnusedLocal
+ def post_op(self, result, conv, args):
+ pass
+
+ def __call__(self, httpc, conv, location, response, content, features):
+ intact = Interaction(httpc)
+ function = intact.interaction(self.args)
+
+ try:
+ _args = self.args.copy()
+ except (KeyError, AttributeError):
+ _args = {}
+
+ _args.update({"location": location, "features": features, "conv": conv})
+
+ logger.info("<-- FUNCTION: %s" % function.__name__)
+ logger.info("<-- ARGS: %s" % _args)
+
+ result = function(response, **_args)
+ self.post_op(result, conv, _args)
+ return result
diff --git a/src/saml2test/opfunc.py b/src/saml2test/opfunc.py
new file mode 100644
index 00000000..2f88c70b
--- /dev/null
+++ b/src/saml2test/opfunc.py
@@ -0,0 +1,373 @@
+import logging
+import json
+
+from urlparse import urlparse
+
+from mechanize import ParseResponseEx
+from mechanize._form import ControlNotFoundError, AmbiguityError
+from mechanize._form import ListControl
+
+__author__ = 'rohe0002'
+
+logger = logging.getLogger(__name__)
+
+
+class FlowException(Exception):
+ def __init__(self, function="", content="", url=""):
+ Exception.__init__(self)
+ self.function = function
+ self.content = content
+ self.url = url
+
+ def __str__(self):
+ return json.dumps(self.__dict__)
+
+
+class DResponse():
+ """ A Response class that behaves in the way that mechanize expects it
+ """
+ def __init__(self, **kwargs):
+ self.status = 200 # default
+ self.index = 0
+ self._message = ""
+ self.url = ""
+ if kwargs:
+ for key, val in kwargs.items():
+ if val:
+ self.__setitem__(key, val)
+
+ def __setitem__(self, key, value):
+ setattr(self, key, value)
+
+ def __getitem__(self, item):
+ if item == "content-location":
+ return self.url
+ elif item == "content-length":
+ return len(self._message)
+ else:
+ return getattr(self, item)
+
+ def geturl(self):
+ """
+ The base url for the response
+
+ :return: The url
+ """
+ return self.url
+
+ def read(self, size=0):
+ """
+ Read from the content of the response. The class remembers what has
+ been read so it's possible to read small consecutive parts of the
+ content.
+
+ :param size: The number of bytes to read
+ :return: Somewhere between zero and 'size' number of bytes depending
+ on how much it left in the content buffer to read.
+ """
+ if size:
+ if self._len < size:
+ return self._message
+ else:
+ if self._len == self.index:
+ part = None
+ elif self._len - self.index < size:
+ part = self._message[self.index:]
+ self.index = self._len
+ else:
+ part = self._message[self.index:self.index + size]
+ self.index += size
+ return part
+ else:
+ return self._message
+
+ def write(self, message):
+ """
+ Write the message into the content buffer
+
+ :param message: The message
+ """
+ self._message = message
+ self._len = len(message)
+
+
+def do_request(client, url, method, body="", headers=None):
+ """
+ Sends a HTTP request.
+
+ :param client: The client instance
+ :param url: Where to send the request
+ :param method: The HTTP method to use for the request
+ :param body: The request body
+ :param headers: The requset headers
+ :return: A tuple of
+ url - the url the request was sent to
+ response - the response to the request
+ content - the content of the response if any
+ """
+ if headers is None:
+ headers = {}
+
+ logger.info("--> URL: %s" % url)
+ logger.info("--> BODY: %s" % body)
+ logger.info("--> Headers: %s" % (headers,))
+
+ response = client.http_request(url, method=method, data=body,
+ headers=headers)
+
+ logger.info("<-- RESPONSE: %s" % response)
+ logger.info("<-- CONTENT: %s" % response.text)
+ if response.cookies:
+ logger.info("<-- COOKIES: %s" % response.cookies)
+
+ return url, response, response.text
+
+
+def pick_form(response, content, url=None, **kwargs):
+ """
+ Picks which form in a web-page that should be used
+
+ :param response: A HTTP request response. A DResponse instance
+ :param content: The HTTP response content
+ :param url: The url the request was sent to
+ :return: The picked form or None of no form matched the criteria.
+ """
+
+ forms = ParseResponseEx(response)
+ if not forms:
+ raise FlowException(content=content, url=url)
+
+ #if len(forms) == 1:
+ # return forms[0]
+ #else:
+
+ _form = None
+ # ignore the first form for now
+ forms = forms[1:]
+ if len(forms) == 1:
+ _form = forms[0]
+ else:
+ if "pick" in kwargs:
+ _dict = kwargs["pick"]
+ for form in forms:
+ if _form:
+ break
+ for key, _ava in _dict.items():
+ if key == "form":
+ _keys = form.attrs.keys()
+ for attr, val in _ava.items():
+ if attr in _keys and val == form.attrs[attr]:
+ _form = form
+ elif key == "control":
+ prop = _ava["id"]
+ _default = _ava["value"]
+ try:
+ orig_val = form[prop]
+ if isinstance(orig_val, basestring):
+ if orig_val == _default:
+ _form = form
+ elif _default in orig_val:
+ _form = form
+ except KeyError:
+ pass
+ elif key == "method":
+ if form.method == _ava:
+ _form = form
+ else:
+ _form = None
+
+ if not _form:
+ break
+ elif "index" in kwargs:
+ _form = forms[int(kwargs["index"])]
+
+ return _form
+
+
+def do_click(client, form, **kwargs):
+ """
+ Emulates the user clicking submit on a form.
+
+ :param client: The Client instance
+ :param form: The form that should be submitted
+ :return: What do_request() returns
+ """
+
+ if "click" in kwargs:
+ request = None
+ _name = kwargs["click"]
+ try:
+ _ = form.find_control(name=_name)
+ request = form.click(name=_name)
+ except AmbiguityError:
+ # more than one control with that name
+ _val = kwargs["set"][_name]
+ _nr = 0
+ while True:
+ try:
+ cntrl = form.find_control(name=_name, nr=_nr)
+ if cntrl.value == _val:
+ request = form.click(name=_name, nr=_nr)
+ break
+ else:
+ _nr += 1
+ except ControlNotFoundError:
+ raise Exception("No submit control with the name='%s' and "
+ "value='%s' could be found" % (_name,
+ _val))
+ else:
+ request = form.click()
+
+ headers = {}
+ for key, val in request.unredirected_hdrs.items():
+ headers[key] = val
+
+ url = request._Request__original
+
+ if form.method == "POST":
+ return do_request(client, url, "POST", request.data, headers)
+ else:
+ return do_request(client, url, "GET", headers=headers)
+
+
+def select_form(client, orig_response, content, **kwargs):
+ """
+ Pick a form on a web page, possibly enter some information and submit
+ the form.
+
+ :param client: The Client
+ :param orig_response: The original response (as returned by httplib2)
+ :param content: The content of the response
+ :return: The response do_click() returns
+ """
+ try:
+ _url = orig_response.url
+ except KeyError:
+ _url = kwargs["location"]
+ # content is a form to be filled in and returned
+ if isinstance(content, unicode):
+ content = content.encode("utf-8")
+
+ response = DResponse(status=orig_response.status_code, url=_url)
+ response.write(content)
+
+ form = pick_form(response, content, _url, **kwargs)
+ #form.backwards_compatible = False
+ if not form:
+ raise Exception("Can't pick a form !!")
+
+ if "set" in kwargs:
+ for key, val in kwargs["set"].items():
+ if key.startswith("_"):
+ continue
+ if "click" in kwargs and kwargs["click"] == key:
+ continue
+
+ try:
+ form[key] = val
+ except ControlNotFoundError:
+ pass
+ except TypeError:
+ cntrl = form.find_control(key)
+ if isinstance(cntrl, ListControl):
+ form[key] = [val]
+ else:
+ raise
+
+ return do_click(client, form, **kwargs)
+
+
+#noinspection PyUnusedLocal
+def chose(client, orig_response, content, path, **kwargs):
+ """
+ Sends a HTTP GET to a url given by the present url and the given
+ relative path.
+
+ :param orig_response: The original response
+ :param content: The content of the response
+ :param path: The relative path to add to the base URL
+ :return: The response do_click() returns
+ """
+
+ if not path.startswith("http"):
+ try:
+ _url = orig_response.url
+ except KeyError:
+ _url = kwargs["location"]
+
+ part = urlparse(_url)
+ url = "%s://%s%s" % (part[0], part[1], path)
+ else:
+ url = path
+
+ return do_request(client, url, "GET")
+
+
+def post_form(client, orig_response, content, **kwargs):
+ """
+ The same as select_form but with no possibility of change the content
+ of the form.
+
+ :param client: The Client instance
+ :param orig_response: The original response (as returned by httplib2)
+ :param content: The content of the response
+ :return: The response do_click() returns
+ """
+ _url = orig_response.url
+ # content is a form to be filled in and returned
+ response = DResponse(status=orig_response.status_code, url=_url)
+ response.write(content)
+
+ form = pick_form(response, content, _url, **kwargs)
+
+ return do_click(client, form, **kwargs)
+
+
+def NoneFunc():
+ return None
+
+
+def interaction(args):
+ _type = args["type"]
+ if _type == "form":
+ return select_form
+ elif _type == "link":
+ return chose
+ else:
+ return NoneFunc
+
+# ========================================================================
+
+
+class Operation(object):
+ def __init__(self, conv, args=None, features=None):
+ if args:
+ self.function = interaction(args)
+
+ self.args = args or {}
+ self.request = None
+ self.conv = conv
+ self.features = features
+ self.cconf = conv.client_config
+
+ def update(self, dic):
+ self.args.update(dic)
+
+ #noinspection PyUnusedLocal
+ def post_op(self, result, environ, args):
+ pass
+
+ def __call__(self, location, response, content, feature=None):
+ try:
+ _args = self.args.copy()
+ except (KeyError, AttributeError):
+ _args = {}
+
+ _args["location"] = location
+
+ logger.info("--> FUNCTION: %s" % self.function.__name__)
+ logger.info("--> ARGS: %s" % (_args,))
+
+ result = self.function(self.conv.client, response, content, **_args)
+ self.post_op(result, self.conv, _args)
+ return result
diff --git a/src/saml2test/status.py b/src/saml2test/status.py
new file mode 100644
index 00000000..4f5ba840
--- /dev/null
+++ b/src/saml2test/status.py
@@ -0,0 +1,11 @@
+__author__ = 'rolandh'
+
+INFORMATION = 0
+OK = 1
+WARNING = 2
+ERROR = 3
+CRITICAL = 4
+INTERACTION = 5
+
+STATUSCODE = ["INFORMATION", "OK", "WARNING", "ERROR", "CRITICAL",
+ "INTERACTION"]
diff --git a/src/saml2test/tool.py b/src/saml2test/tool.py
new file mode 100644
index 00000000..db7c528b
--- /dev/null
+++ b/src/saml2test/tool.py
@@ -0,0 +1,286 @@
+import cookielib
+import sys
+import traceback
+import logging
+from urlparse import parse_qs
+
+from saml2test.opfunc import Operation
+from saml2test import FatalError
+from saml2test.check import ExpectedError
+from saml2test.check import INTERACTION
+from saml2test.interaction import Interaction
+from saml2test.interaction import Action
+from saml2test.interaction import InteractionNeeded
+from saml2test.status import STATUSCODE
+
+__author__ = 'rolandh'
+
+logger = logging.getLogger(__name__)
+
+
+class Conversation(object):
+ """
+ :ivar response: The received HTTP messages
+ :ivar protocol_response: List of the received protocol messages
+ """
+
+ def __init__(self, client, config, interaction,
+ check_factory=None, msg_factory=None,
+ features=None, verbose=False, expect_exception=None):
+ self.client = client
+ self.client_config = config
+ self.test_output = []
+ self.features = features
+ self.verbose = verbose
+ self.check_factory = check_factory
+ self.msg_factory = msg_factory
+ self.expect_exception = expect_exception
+
+ self.cjar = {"browser": cookielib.CookieJar(),
+ "rp": cookielib.CookieJar(),
+ "service": cookielib.CookieJar()}
+
+ self.protocol_response = []
+ self.last_response = None
+ self.last_content = None
+ self.response = None
+ self.interaction = Interaction(self.client, interaction)
+ self.exception = None
+
+ def check_severity(self, stat):
+ if stat["status"] >= 4:
+ logger.error("WHERE: %s" % stat["id"])
+ logger.error("STATUS:%s" % STATUSCODE[stat["status"]])
+ try:
+ logger.error("HTTP STATUS: %s" % stat["http_status"])
+ except KeyError:
+ pass
+ try:
+ logger.error("INFO: %s" % stat["message"])
+ except KeyError:
+ pass
+
+ raise FatalError
+
+ def do_check(self, test, **kwargs):
+ if isinstance(test, basestring):
+ chk = self.check_factory(test)(**kwargs)
+ else:
+ chk = test(**kwargs)
+ stat = chk(self, self.test_output)
+ self.check_severity(stat)
+
+ def err_check(self, test, err=None, bryt=True):
+ if err:
+ self.exception = err
+ chk = self.check_factory(test)()
+ chk(self, self.test_output)
+ if bryt:
+ e = FatalError("%s" % err)
+ e.trace = "".join(traceback.format_exception(*sys.exc_info()))
+ raise e
+
+ def test_sequence(self, sequence):
+ for test in sequence:
+ if isinstance(test, tuple):
+ test, kwargs = test
+ else:
+ kwargs = {}
+ self.do_check(test, **kwargs)
+ if test == ExpectedError:
+ return False
+ return True
+
+ def my_endpoints(self):
+ pass
+
+ def intermit(self):
+ _response = self.last_response
+ _last_action = None
+ _same_actions = 0
+ if _response.status_code >= 400:
+ done = True
+ else:
+ done = False
+
+ url = _response.url
+ content = _response.text
+ while not done:
+ rdseq = []
+ while _response.status_code in [302, 301, 303]:
+ url = _response.headers["location"]
+ if url in rdseq:
+ raise FatalError("Loop detected in redirects")
+ else:
+ rdseq.append(url)
+ if len(rdseq) > 8:
+ raise FatalError(
+ "Too long sequence of redirects: %s" % rdseq)
+
+ logger.info("HTTP %d Location: %s" % (_response.status_code,
+ url))
+ # If back to me
+ for_me = False
+ for redirect_uri in self.my_endpoints():
+ if url.startswith(redirect_uri):
+ # Back at the RP
+ self.client.cookiejar = self.cjar["rp"]
+ for_me = True
+ try:
+ base, query = url.split("?")
+ except ValueError:
+ pass
+ else:
+ _response = parse_qs(query)
+ self.last_response = _response
+ self.last_content = _response
+ return _response
+
+ if for_me:
+ done = True
+ break
+ else:
+ try:
+ logger.info("GET %s" % url)
+ _response = self.client.send(url, "GET")
+ except Exception, err:
+ raise FatalError("%s" % err)
+
+ content = _response.text
+ logger.info("<-- CONTENT: %s" % content)
+ self.position = url
+ self.last_content = content
+ self.response = _response
+
+ if _response.status_code >= 400:
+ done = True
+ break
+
+ if done or url is None:
+ break
+
+ _base = url.split("?")[0]
+
+ try:
+ _spec = self.interaction.pick_interaction(_base, content)
+ except InteractionNeeded:
+ self.position = url
+ logger.error("Page Content: %s" % content)
+ raise
+ except KeyError:
+ self.position = url
+ logger.error("Page Content: %s" % content)
+ self.err_check("interaction-needed")
+
+ if _spec == _last_action:
+ _same_actions += 1
+ if _same_actions >= 3:
+ raise InteractionNeeded("Interaction loop detection")
+ else:
+ _last_action = _spec
+
+ if len(_spec) > 2:
+ logger.info(">> %s <<" % _spec["page-type"])
+ if _spec["page-type"] == "login":
+ self.login_page = content
+
+ _op = Action(_spec["control"])
+
+ try:
+ _response = _op(self.client, self, url, _response, content,
+ self.features)
+ if isinstance(_response, dict):
+ self.last_response = _response
+ self.last_content = _response
+ return _response
+ content = _response.text
+ self.position = url
+ self.last_content = content
+ self.response = _response
+
+ if _response.status_code >= 400:
+ break
+ except (FatalError, InteractionNeeded):
+ raise
+ except Exception, err:
+ self.err_check("exception", err, False)
+
+ self.last_response = _response
+ try:
+ self.last_content = _response.text
+ except AttributeError:
+ self.last_content = None
+
+ def init(self, phase):
+ self.creq, self.cresp = phase
+
+ def setup_request(self):
+ self.request_spec = req = self.creq(conv=self)
+
+ if isinstance(req, Operation):
+ for intact in self.interaction.interactions:
+ try:
+ if req.__class__.__name__ == intact["matches"]["class"]:
+ req.args = intact["args"]
+ break
+ except KeyError:
+ pass
+ else:
+ try:
+ self.request_args = req.request_args
+ except KeyError:
+ pass
+ try:
+ self.args = req.kw_args
+ except KeyError:
+ pass
+
+ # The authorization dance is all done through the browser
+ if req.request == "AuthorizationRequest":
+ self.client.cookiejar = self.cjar["browser"]
+ # everything else by someone else, assuming the RP
+ else:
+ self.client.cookiejar = self.cjar["rp"]
+
+ self.req = req
+
+ def send(self):
+ pass
+
+ def handle_result(self):
+ pass
+
+ def do_query(self):
+ self.setup_request()
+ self.send()
+ if not self.handle_result():
+ self.intermit()
+ self.handle_result()
+
+ def do_sequence(self, oper):
+ try:
+ self.test_sequence(oper["tests"]["pre"])
+ except KeyError:
+ pass
+
+ for phase in oper["sequence"]:
+ self.init(phase)
+ try:
+ self.do_query()
+ except InteractionNeeded:
+ self.test_output.append({"status": INTERACTION,
+ "message": self.last_content,
+ "id": "exception",
+ "name": "interaction needed",
+ "url": self.position})
+ break
+ except FatalError:
+ raise
+ except Exception, err:
+ #self.err_check("exception", err)
+ raise
+
+ try:
+ self.test_sequence(oper["tests"]["post"])
+ except KeyError:
+ pass