diff options
author | Roland Hedberg <roland.hedberg@adm.umu.se> | 2013-08-26 09:45:39 +0200 |
---|---|---|
committer | Roland Hedberg <roland.hedberg@adm.umu.se> | 2013-08-26 09:45:39 +0200 |
commit | edda40422c78443f40d4171d381c5f3ee731778b (patch) | |
tree | 0d60c036314b08ee2d96a738b0242c9cc21730e0 /src/saml2test | |
parent | 6e86b548397858938f13f5616cf0fa4e8d1dbf98 (diff) | |
download | pysaml2-edda40422c78443f40d4171d381c5f3ee731778b.tar.gz |
Changed directory name.
Diffstat (limited to 'src/saml2test')
-rw-r--r-- | src/saml2test/__init__.py | 126 | ||||
-rw-r--r-- | src/saml2test/check.py | 262 | ||||
-rw-r--r-- | src/saml2test/interaction.py | 394 | ||||
-rw-r--r-- | src/saml2test/opfunc.py | 373 | ||||
-rw-r--r-- | src/saml2test/status.py | 11 | ||||
-rw-r--r-- | src/saml2test/tool.py | 286 |
6 files changed, 1452 insertions, 0 deletions
diff --git a/src/saml2test/__init__.py b/src/saml2test/__init__.py new file mode 100644 index 00000000..15e096d8 --- /dev/null +++ b/src/saml2test/__init__.py @@ -0,0 +1,126 @@ +import logging +import time +import traceback +import requests +import sys + +from subprocess import Popen, PIPE +from saml2test.check import CRITICAL + +logger = logging.getLogger(__name__) + +__author__ = 'rolandh' + + +class FatalError(Exception): + pass + + +class CheckError(Exception): + pass + + +class HTTP_ERROR(Exception): + pass + + +class Unknown(Exception): + pass + + +# class Trace(object): +# def __init__(self): +# self.trace = [] +# self.start = time.time() +# +# def request(self, msg): +# delta = time.time() - self.start +# self.trace.append("%f --> %s" % (delta, msg)) +# +# def reply(self, msg): +# delta = time.time() - self.start +# self.trace.append("%f <-- %s" % (delta, msg)) +# +# def info(self, msg, who="saml2client"): +# delta = time.time() - self.start +# self.trace.append("%f - INFO - [%s] %s" % (delta, who, msg)) +# +# def error(self, msg, who="saml2client"): +# delta = time.time() - self.start +# self.trace.append("%f - ERROR - [%s] %s" % (delta, who, msg)) +# +# def warning(self, msg, who="saml2client"): +# delta = time.time() - self.start +# self.trace.append("%f - WARNING - [%s] %s" % (delta, who, msg)) +# +# def __str__(self): +# return "\n". join([t.encode("utf-8") for t in self.trace]) +# +# def clear(self): +# self.trace = [] +# +# def __getitem__(self, item): +# return self.trace[item] +# +# def next(self): +# for line in self.trace: +# yield line + + +class ContextFilter(logging.Filter): + """ + This is a filter which injects time laps information into the log. + """ + + def start(self): + self.start = time.time() + + def filter(self, record): + record.delta = time.time() - self.start + return True + + +def start_script(path, *args): + popen_args = [path] + popen_args.extend(args) + return Popen(popen_args, stdout=PIPE, stderr=PIPE) + + +def stop_script_by_name(name): + import subprocess + import signal + import os + + p = subprocess.Popen(['ps', '-A'], stdout=subprocess.PIPE) + out, err = p.communicate() + + for line in out.splitlines(): + if name in line: + pid = int(line.split(None, 1)[0]) + os.kill(pid, signal.SIGKILL) + + +def stop_script_by_pid(pid): + import signal + import os + + os.kill(pid, signal.SIGKILL) + + +def get_page(url): + resp = requests.get(url) + if resp.status_code == 200: + return resp.text + else: + raise HTTP_ERROR(resp.status) + + +def exception_trace(tag, exc, log=None): + message = traceback.format_exception(*sys.exc_info()) + + try: + _exc = "Exception: %s" % exc + except UnicodeEncodeError: + _exc = "Exception: %s" % exc.message.encode("utf-8", "replace") + + return {"status": CRITICAL, "message": _exc, "content": "".join(message)} diff --git a/src/saml2test/check.py b/src/saml2test/check.py new file mode 100644 index 00000000..1f2f062a --- /dev/null +++ b/src/saml2test/check.py @@ -0,0 +1,262 @@ +import inspect +import json + +__author__ = 'rolandh' + +import traceback +import sys + +INFORMATION = 0 +OK = 1 +WARNING = 2 +ERROR = 3 +CRITICAL = 4 +INTERACTION = 5 + +STATUSCODE = ["INFORMATION", "OK", "WARNING", "ERROR", "CRITICAL", + "INTERACTION"] + +CONT_JSON = "application/json" +CONT_JWT = "application/jwt" + + +class Check(object): + """ General test + """ + + cid = "check" + msg = "OK" + + def __init__(self, **kwargs): + self._status = OK + self._message = "" + self.content = None + self.url = "" + self._kwargs = kwargs + + def _func(self, conv): + return {} + + def __call__(self, conv=None, output=None): + _stat = self.response(**self._func(conv)) + if output is not None: + output.append(_stat) + return _stat + + def response(self, **kwargs): + try: + name = " ".join( + [s.strip() for s in self.__doc__.strip().split("\n")]) + except AttributeError: + name = "" + + res = { + "id": self.cid, + "status": self._status, + "name": name + } + + if self._message: + res["message"] = self._message + + if kwargs: + res.update(kwargs) + + return res + + +class ExpectedError(Check): + pass + + +class CriticalError(Check): + status = CRITICAL + + +class Information(Check): + status = INFORMATION + + +class Error(Check): + status = ERROR + + +class ResponseInfo(Information): + """Response information""" + + def _func(self, conv=None): + self._status = self.status + _msg = conv.last_content + + if isinstance(_msg, basestring): + self._message = _msg + else: + self._message = _msg.to_dict() + + return {} + + +class CheckErrorResponse(ExpectedError): + """ + Checks that the HTTP response status is outside the 200 or 300 range + or that an JSON encoded error message has been received + """ + cid = "check-error-response" + msg = "OP error" + + def _func(self, conv): + _response = conv.last_response + _content = conv.last_content + + res = {} + if _response.status_code >= 400: + content_type = _response.headers["content-type"] + if content_type is None: + res["content"] = _content + else: + res["content"] = _content + + return res + + +class VerifyBadRequestResponse(ExpectedError): + """ + Verifies that the OP returned a 400 Bad Request response containing a + Error message. + """ + cid = "verify-bad-request-response" + msg = "OP error" + + def _func(self, conv): + _response = conv.last_response + _content = conv.last_content + res = {} + if _response.status_code == 400: + pass + else: + self._message = "Expected a 400 error message" + self._status = CRITICAL + + return res + + +class VerifyError(Error): + cid = "verify-error" + + def _func(self, conv): + response = conv.last_response + if response.status_code == 400: + try: + resp = json.loads(response.text) + if "error" in resp: + return {} + except Exception: + pass + + item, msg = conv.protocol_response[-1] + try: + assert item.type().endswith("ErrorResponse") + except AssertionError: + self._message = "Expected an error response" + self._status = self.status + return {} + + try: + assert item["error"] in self._kwargs["error"] + except AssertionError: + self._message = "Wrong type of error, got %s" % item["error"] + self._status = self.status + + return {} + + +class WrapException(CriticalError): + """ + A runtime exception + """ + cid = "exception" + msg = "Test tool exception" + + def _func(self, conv=None): + self._status = self.status + self._message = traceback.format_exception(*sys.exc_info()) + return {} + + +class Other(CriticalError): + """ Other error """ + msg = "Other error" + + +class CheckHTTPResponse(CriticalError): + """ + Checks that the HTTP response status is within the 200 or 300 range + """ + cid = "check-http-response" + msg = "OP error" + + def _func(self, conv): + _response = conv.last_response + _content = conv.last_content + + res = {} + if _response.status_code >= 400: + self._status = self.status + self._message = self.msg + res["content"] = _content + res["url"] = conv.position + res["http_status"] = _response.status_code + + return res + + +class MissingRedirect(CriticalError): + """ At this point in the flow a redirect back to the client was expected. + """ + cid = "missing-redirect" + msg = "Expected redirect to the RP, got something else" + + def _func(self, conv=None): + self._status = self.status + return {"url": conv.position} + + +class Parse(CriticalError): + """ Parsing the response """ + cid = "response-parse" + errmsg = "Parse error" + + def _func(self, conv=None): + if conv.exception: + self._status = self.status + err = conv.exception + self._message = "%s: %s" % (err.__class__.__name__, err) + else: + _rmsg = conv.response_message + cname = _rmsg.type() + if conv.response_type != cname: + self._status = self.status + self._message = ( + "Didn't get a response of the type I expected:", + " '%s' instead of '%s', content:'%s'" % ( + cname, conv.response_type, _rmsg)) + return { + "response_type": conv.response_type, + "url": conv.position + } + + return {} + +def factory(cid, classes): + if len(classes) == 0: + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isclass(obj): + try: + classes[obj.cid] = obj + except AttributeError: + pass + + if cid in classes: + return classes[cid] + else: + return None diff --git a/src/saml2test/interaction.py b/src/saml2test/interaction.py new file mode 100644 index 00000000..d5f48ed1 --- /dev/null +++ b/src/saml2test/interaction.py @@ -0,0 +1,394 @@ +__author__ = 'rohe0002' + +import json +import logging + +from urlparse import urlparse +from bs4 import BeautifulSoup + +from mechanize import ParseResponseEx +from mechanize._form import ControlNotFoundError, AmbiguityError +from mechanize._form import ListControl + +logger = logging.getLogger(__name__) + +NO_CTRL = "No submit control with the name='%s' and value='%s' could be found" + + +class FlowException(Exception): + def __init__(self, function="", content="", url=""): + Exception.__init__(self) + self.function = function + self.content = content + self.url = url + + def __str__(self): + return json.dumps(self.__dict__) + + +class InteractionNeeded(Exception): + pass + + +def NoneFunc(): + return None + + +class RResponse(): + """ + A Response class that behaves in the way that mechanize expects it. + Links to a requests.Response + """ + def __init__(self, resp): + self._resp = resp + self.index = 0 + self.text = resp.text + if isinstance(self.text, unicode): + if resp.encoding == "UTF-8": + self.text = self.text.encode("utf-8") + else: + self.text = self.text.encode("latin-1") + self._len = len(self.text) + self.url = str(resp.url) + self.statuscode = resp.status_code + + def geturl(self): + return self._resp.url + + def __getitem__(self, item): + try: + return getattr(self._resp, item) + except AttributeError: + return getattr(self._resp.headers, item) + + def __getattribute__(self, item): + try: + return getattr(self._resp, item) + except AttributeError: + return getattr(self._resp.headers, item) + + def read(self, size=0): + """ + Read from the content of the response. The class remembers what has + been read so it's possible to read small consecutive parts of the + content. + + :param size: The number of bytes to read + :return: Somewhere between zero and 'size' number of bytes depending + on how much it left in the content buffer to read. + """ + if size: + if self._len < size: + return self.text + else: + if self._len == self.index: + part = None + elif self._len - self.index < size: + part = self.text[self.index:] + self.index = self._len + else: + part = self.text[self.index:self.index + size] + self.index += size + return part + else: + return self.text + + +class Interaction(object): + def __init__(self, httpc, interactions=None): + self.httpc = httpc + self.interactions = interactions + self.who = "Form process" + + def pick_interaction(self, _base="", content="", req=None): + unic = content + if content: + _bs = BeautifulSoup(content) + else: + _bs = None + + for interaction in self.interactions: + _match = 0 + for attr, val in interaction["matches"].items(): + if attr == "url": + if val == _base: + _match += 1 + elif attr == "title": + if _bs is None: + break + if _bs.title is None: + break + if val in _bs.title.contents: + _match += 1 + else: + _c = _bs.title.contents + if isinstance(_c, list) and not isinstance(_c, + basestring): + for _line in _c: + if val in _line: + _match += 1 + continue + elif attr == "content": + if unic and val in unic: + _match += 1 + elif attr == "class": + if req and val == req: + _match += 1 + + if _match == len(interaction["matches"]): + logger.info("Matched: %s" % interaction["matches"]) + return interaction + + raise InteractionNeeded("No interaction matched") + + def pick_form(self, response, url=None, **kwargs): + """ + Picks which form in a web-page that should be used + + :param response: A HTTP request response. A DResponse instance + :param content: The HTTP response content + :param url: The url the request was sent to + :param kwargs: Extra key word arguments + :return: The picked form or None of no form matched the criteria. + """ + + forms = ParseResponseEx(response) + if not forms: + raise FlowException(content=response.text, url=url) + + #if len(forms) == 1: + # return forms[0] + #else: + + _form = None + # ignore the first form, because I use ParseResponseEx which adds + # one form at the top of the list + forms = forms[1:] + if len(forms) == 1: + _form = forms[0] + else: + if "pick" in kwargs: + _dict = kwargs["pick"] + for form in forms: + if _form: + break + for key, _ava in _dict.items(): + if key == "form": + _keys = form.attrs.keys() + for attr, val in _ava.items(): + if attr in _keys and val == form.attrs[attr]: + _form = form + elif key == "control": + prop = _ava["id"] + _default = _ava["value"] + try: + orig_val = form[prop] + if isinstance(orig_val, basestring): + if orig_val == _default: + _form = form + elif _default in orig_val: + _form = form + except KeyError: + pass + except ControlNotFoundError: + pass + elif key == "method": + if form.method == _ava: + _form = form + else: + _form = None + + if not _form: + break + elif "index" in kwargs: + _form = forms[int(kwargs["index"])] + + return _form + + def do_click(self, form, **kwargs): + """ + Emulates the user clicking submit on a form. + + :param form: The form that should be submitted + :return: What do_request() returns + """ + + if "click" in kwargs: + request = None + _name = kwargs["click"] + try: + _ = form.find_control(name=_name) + request = form.click(name=_name) + except AmbiguityError: + # more than one control with that name + _val = kwargs["set"][_name] + _nr = 0 + while True: + try: + cntrl = form.find_control(name=_name, nr=_nr) + if cntrl.value == _val: + request = form.click(name=_name, nr=_nr) + break + else: + _nr += 1 + except ControlNotFoundError: + raise Exception(NO_CTRL % (_name, _val)) + else: + request = form.click() + + headers = {} + for key, val in request.unredirected_hdrs.items(): + headers[key] = val + + url = request._Request__original + + if form.method == "POST": + return self.httpc.send(url, "POST", data=request.data, + headers=headers) + else: + return self.httpc.send(url, "GET", headers=headers) + + def select_form(self, orig_response, **kwargs): + """ + Pick a form on a web page, possibly enter some information and submit + the form. + + :param orig_response: The original response (as returned by requests) + :return: The response do_click() returns + """ + logger.info("select_form") + response = RResponse(orig_response) + try: + _url = response.url + except KeyError: + _url = kwargs["location"] + + form = self.pick_form(response, _url, **kwargs) + #form.backwards_compatible = False + if not form: + raise Exception("Can't pick a form !!") + + if "set" in kwargs: + for key, val in kwargs["set"].items(): + if key.startswith("_"): + continue + if "click" in kwargs and kwargs["click"] == key: + continue + + try: + form[key] = val + except ControlNotFoundError: + pass + except TypeError: + cntrl = form.find_control(key) + if isinstance(cntrl, ListControl): + form[key] = [val] + else: + raise + + if form.action in kwargs["conv"].my_endpoints(): + return {"SAMLResponse": form["SAMLResponse"], + "RelayState": form["RelayState"]} + + return self.do_click(form, **kwargs) + + #noinspection PyUnusedLocal + def chose(self, orig_response, path, **kwargs): + """ + Sends a HTTP GET to a url given by the present url and the given + relative path. + + :param orig_response: The original response + :param content: The content of the response + :param path: The relative path to add to the base URL + :return: The response do_click() returns + """ + + if not path.startswith("http"): + try: + _url = orig_response.url + except KeyError: + _url = kwargs["location"] + + part = urlparse(_url) + url = "%s://%s%s" % (part[0], part[1], path) + else: + url = path + + logger.info("GET %s" % url) + return self.httpc.send(url, "GET") + #return resp, "" + + def post_form(self, orig_response, **kwargs): + """ + The same as select_form but with no possibility of change the content + of the form. + + :param httpc: A HTTP Client instance + :param orig_response: The original response (as returned by requests) + :param content: The content of the response + :return: The response do_click() returns + """ + response = RResponse(orig_response) + + form = self.pick_form(response, **kwargs) + + return self.do_click(form, **kwargs) + + #noinspection PyUnusedLocal + def parse(self, orig_response, **kwargs): + # content is a form from which I get the SAMLResponse + response = RResponse(orig_response) + + form = self.pick_form(response, **kwargs) + #form.backwards_compatible = False + if not form: + raise InteractionNeeded("Can't pick a form !!") + + return {"SAMLResponse": form["SAMLResponse"], + "RelayState": form["RelayState"]} + + #noinspection PyUnusedLocal + def interaction(self, args): + _type = args["type"] + if _type == "form": + return self.select_form + elif _type == "link": + return self.chose + elif _type == "response": + return self.parse + else: + return NoneFunc + +# ======================================================================== + + +class Action(object): + def __init__(self, args): + self.args = args or {} + self.request = None + + def update(self, dic): + self.args.update(dic) + + #noinspection PyUnusedLocal + def post_op(self, result, conv, args): + pass + + def __call__(self, httpc, conv, location, response, content, features): + intact = Interaction(httpc) + function = intact.interaction(self.args) + + try: + _args = self.args.copy() + except (KeyError, AttributeError): + _args = {} + + _args.update({"location": location, "features": features, "conv": conv}) + + logger.info("<-- FUNCTION: %s" % function.__name__) + logger.info("<-- ARGS: %s" % _args) + + result = function(response, **_args) + self.post_op(result, conv, _args) + return result diff --git a/src/saml2test/opfunc.py b/src/saml2test/opfunc.py new file mode 100644 index 00000000..2f88c70b --- /dev/null +++ b/src/saml2test/opfunc.py @@ -0,0 +1,373 @@ +import logging +import json + +from urlparse import urlparse + +from mechanize import ParseResponseEx +from mechanize._form import ControlNotFoundError, AmbiguityError +from mechanize._form import ListControl + +__author__ = 'rohe0002' + +logger = logging.getLogger(__name__) + + +class FlowException(Exception): + def __init__(self, function="", content="", url=""): + Exception.__init__(self) + self.function = function + self.content = content + self.url = url + + def __str__(self): + return json.dumps(self.__dict__) + + +class DResponse(): + """ A Response class that behaves in the way that mechanize expects it + """ + def __init__(self, **kwargs): + self.status = 200 # default + self.index = 0 + self._message = "" + self.url = "" + if kwargs: + for key, val in kwargs.items(): + if val: + self.__setitem__(key, val) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __getitem__(self, item): + if item == "content-location": + return self.url + elif item == "content-length": + return len(self._message) + else: + return getattr(self, item) + + def geturl(self): + """ + The base url for the response + + :return: The url + """ + return self.url + + def read(self, size=0): + """ + Read from the content of the response. The class remembers what has + been read so it's possible to read small consecutive parts of the + content. + + :param size: The number of bytes to read + :return: Somewhere between zero and 'size' number of bytes depending + on how much it left in the content buffer to read. + """ + if size: + if self._len < size: + return self._message + else: + if self._len == self.index: + part = None + elif self._len - self.index < size: + part = self._message[self.index:] + self.index = self._len + else: + part = self._message[self.index:self.index + size] + self.index += size + return part + else: + return self._message + + def write(self, message): + """ + Write the message into the content buffer + + :param message: The message + """ + self._message = message + self._len = len(message) + + +def do_request(client, url, method, body="", headers=None): + """ + Sends a HTTP request. + + :param client: The client instance + :param url: Where to send the request + :param method: The HTTP method to use for the request + :param body: The request body + :param headers: The requset headers + :return: A tuple of + url - the url the request was sent to + response - the response to the request + content - the content of the response if any + """ + if headers is None: + headers = {} + + logger.info("--> URL: %s" % url) + logger.info("--> BODY: %s" % body) + logger.info("--> Headers: %s" % (headers,)) + + response = client.http_request(url, method=method, data=body, + headers=headers) + + logger.info("<-- RESPONSE: %s" % response) + logger.info("<-- CONTENT: %s" % response.text) + if response.cookies: + logger.info("<-- COOKIES: %s" % response.cookies) + + return url, response, response.text + + +def pick_form(response, content, url=None, **kwargs): + """ + Picks which form in a web-page that should be used + + :param response: A HTTP request response. A DResponse instance + :param content: The HTTP response content + :param url: The url the request was sent to + :return: The picked form or None of no form matched the criteria. + """ + + forms = ParseResponseEx(response) + if not forms: + raise FlowException(content=content, url=url) + + #if len(forms) == 1: + # return forms[0] + #else: + + _form = None + # ignore the first form for now + forms = forms[1:] + if len(forms) == 1: + _form = forms[0] + else: + if "pick" in kwargs: + _dict = kwargs["pick"] + for form in forms: + if _form: + break + for key, _ava in _dict.items(): + if key == "form": + _keys = form.attrs.keys() + for attr, val in _ava.items(): + if attr in _keys and val == form.attrs[attr]: + _form = form + elif key == "control": + prop = _ava["id"] + _default = _ava["value"] + try: + orig_val = form[prop] + if isinstance(orig_val, basestring): + if orig_val == _default: + _form = form + elif _default in orig_val: + _form = form + except KeyError: + pass + elif key == "method": + if form.method == _ava: + _form = form + else: + _form = None + + if not _form: + break + elif "index" in kwargs: + _form = forms[int(kwargs["index"])] + + return _form + + +def do_click(client, form, **kwargs): + """ + Emulates the user clicking submit on a form. + + :param client: The Client instance + :param form: The form that should be submitted + :return: What do_request() returns + """ + + if "click" in kwargs: + request = None + _name = kwargs["click"] + try: + _ = form.find_control(name=_name) + request = form.click(name=_name) + except AmbiguityError: + # more than one control with that name + _val = kwargs["set"][_name] + _nr = 0 + while True: + try: + cntrl = form.find_control(name=_name, nr=_nr) + if cntrl.value == _val: + request = form.click(name=_name, nr=_nr) + break + else: + _nr += 1 + except ControlNotFoundError: + raise Exception("No submit control with the name='%s' and " + "value='%s' could be found" % (_name, + _val)) + else: + request = form.click() + + headers = {} + for key, val in request.unredirected_hdrs.items(): + headers[key] = val + + url = request._Request__original + + if form.method == "POST": + return do_request(client, url, "POST", request.data, headers) + else: + return do_request(client, url, "GET", headers=headers) + + +def select_form(client, orig_response, content, **kwargs): + """ + Pick a form on a web page, possibly enter some information and submit + the form. + + :param client: The Client + :param orig_response: The original response (as returned by httplib2) + :param content: The content of the response + :return: The response do_click() returns + """ + try: + _url = orig_response.url + except KeyError: + _url = kwargs["location"] + # content is a form to be filled in and returned + if isinstance(content, unicode): + content = content.encode("utf-8") + + response = DResponse(status=orig_response.status_code, url=_url) + response.write(content) + + form = pick_form(response, content, _url, **kwargs) + #form.backwards_compatible = False + if not form: + raise Exception("Can't pick a form !!") + + if "set" in kwargs: + for key, val in kwargs["set"].items(): + if key.startswith("_"): + continue + if "click" in kwargs and kwargs["click"] == key: + continue + + try: + form[key] = val + except ControlNotFoundError: + pass + except TypeError: + cntrl = form.find_control(key) + if isinstance(cntrl, ListControl): + form[key] = [val] + else: + raise + + return do_click(client, form, **kwargs) + + +#noinspection PyUnusedLocal +def chose(client, orig_response, content, path, **kwargs): + """ + Sends a HTTP GET to a url given by the present url and the given + relative path. + + :param orig_response: The original response + :param content: The content of the response + :param path: The relative path to add to the base URL + :return: The response do_click() returns + """ + + if not path.startswith("http"): + try: + _url = orig_response.url + except KeyError: + _url = kwargs["location"] + + part = urlparse(_url) + url = "%s://%s%s" % (part[0], part[1], path) + else: + url = path + + return do_request(client, url, "GET") + + +def post_form(client, orig_response, content, **kwargs): + """ + The same as select_form but with no possibility of change the content + of the form. + + :param client: The Client instance + :param orig_response: The original response (as returned by httplib2) + :param content: The content of the response + :return: The response do_click() returns + """ + _url = orig_response.url + # content is a form to be filled in and returned + response = DResponse(status=orig_response.status_code, url=_url) + response.write(content) + + form = pick_form(response, content, _url, **kwargs) + + return do_click(client, form, **kwargs) + + +def NoneFunc(): + return None + + +def interaction(args): + _type = args["type"] + if _type == "form": + return select_form + elif _type == "link": + return chose + else: + return NoneFunc + +# ======================================================================== + + +class Operation(object): + def __init__(self, conv, args=None, features=None): + if args: + self.function = interaction(args) + + self.args = args or {} + self.request = None + self.conv = conv + self.features = features + self.cconf = conv.client_config + + def update(self, dic): + self.args.update(dic) + + #noinspection PyUnusedLocal + def post_op(self, result, environ, args): + pass + + def __call__(self, location, response, content, feature=None): + try: + _args = self.args.copy() + except (KeyError, AttributeError): + _args = {} + + _args["location"] = location + + logger.info("--> FUNCTION: %s" % self.function.__name__) + logger.info("--> ARGS: %s" % (_args,)) + + result = self.function(self.conv.client, response, content, **_args) + self.post_op(result, self.conv, _args) + return result diff --git a/src/saml2test/status.py b/src/saml2test/status.py new file mode 100644 index 00000000..4f5ba840 --- /dev/null +++ b/src/saml2test/status.py @@ -0,0 +1,11 @@ +__author__ = 'rolandh' + +INFORMATION = 0 +OK = 1 +WARNING = 2 +ERROR = 3 +CRITICAL = 4 +INTERACTION = 5 + +STATUSCODE = ["INFORMATION", "OK", "WARNING", "ERROR", "CRITICAL", + "INTERACTION"] diff --git a/src/saml2test/tool.py b/src/saml2test/tool.py new file mode 100644 index 00000000..db7c528b --- /dev/null +++ b/src/saml2test/tool.py @@ -0,0 +1,286 @@ +import cookielib +import sys +import traceback +import logging +from urlparse import parse_qs + +from saml2test.opfunc import Operation +from saml2test import FatalError +from saml2test.check import ExpectedError +from saml2test.check import INTERACTION +from saml2test.interaction import Interaction +from saml2test.interaction import Action +from saml2test.interaction import InteractionNeeded +from saml2test.status import STATUSCODE + +__author__ = 'rolandh' + +logger = logging.getLogger(__name__) + + +class Conversation(object): + """ + :ivar response: The received HTTP messages + :ivar protocol_response: List of the received protocol messages + """ + + def __init__(self, client, config, interaction, + check_factory=None, msg_factory=None, + features=None, verbose=False, expect_exception=None): + self.client = client + self.client_config = config + self.test_output = [] + self.features = features + self.verbose = verbose + self.check_factory = check_factory + self.msg_factory = msg_factory + self.expect_exception = expect_exception + + self.cjar = {"browser": cookielib.CookieJar(), + "rp": cookielib.CookieJar(), + "service": cookielib.CookieJar()} + + self.protocol_response = [] + self.last_response = None + self.last_content = None + self.response = None + self.interaction = Interaction(self.client, interaction) + self.exception = None + + def check_severity(self, stat): + if stat["status"] >= 4: + logger.error("WHERE: %s" % stat["id"]) + logger.error("STATUS:%s" % STATUSCODE[stat["status"]]) + try: + logger.error("HTTP STATUS: %s" % stat["http_status"]) + except KeyError: + pass + try: + logger.error("INFO: %s" % stat["message"]) + except KeyError: + pass + + raise FatalError + + def do_check(self, test, **kwargs): + if isinstance(test, basestring): + chk = self.check_factory(test)(**kwargs) + else: + chk = test(**kwargs) + stat = chk(self, self.test_output) + self.check_severity(stat) + + def err_check(self, test, err=None, bryt=True): + if err: + self.exception = err + chk = self.check_factory(test)() + chk(self, self.test_output) + if bryt: + e = FatalError("%s" % err) + e.trace = "".join(traceback.format_exception(*sys.exc_info())) + raise e + + def test_sequence(self, sequence): + for test in sequence: + if isinstance(test, tuple): + test, kwargs = test + else: + kwargs = {} + self.do_check(test, **kwargs) + if test == ExpectedError: + return False + return True + + def my_endpoints(self): + pass + + def intermit(self): + _response = self.last_response + _last_action = None + _same_actions = 0 + if _response.status_code >= 400: + done = True + else: + done = False + + url = _response.url + content = _response.text + while not done: + rdseq = [] + while _response.status_code in [302, 301, 303]: + url = _response.headers["location"] + if url in rdseq: + raise FatalError("Loop detected in redirects") + else: + rdseq.append(url) + if len(rdseq) > 8: + raise FatalError( + "Too long sequence of redirects: %s" % rdseq) + + logger.info("HTTP %d Location: %s" % (_response.status_code, + url)) + # If back to me + for_me = False + for redirect_uri in self.my_endpoints(): + if url.startswith(redirect_uri): + # Back at the RP + self.client.cookiejar = self.cjar["rp"] + for_me = True + try: + base, query = url.split("?") + except ValueError: + pass + else: + _response = parse_qs(query) + self.last_response = _response + self.last_content = _response + return _response + + if for_me: + done = True + break + else: + try: + logger.info("GET %s" % url) + _response = self.client.send(url, "GET") + except Exception, err: + raise FatalError("%s" % err) + + content = _response.text + logger.info("<-- CONTENT: %s" % content) + self.position = url + self.last_content = content + self.response = _response + + if _response.status_code >= 400: + done = True + break + + if done or url is None: + break + + _base = url.split("?")[0] + + try: + _spec = self.interaction.pick_interaction(_base, content) + except InteractionNeeded: + self.position = url + logger.error("Page Content: %s" % content) + raise + except KeyError: + self.position = url + logger.error("Page Content: %s" % content) + self.err_check("interaction-needed") + + if _spec == _last_action: + _same_actions += 1 + if _same_actions >= 3: + raise InteractionNeeded("Interaction loop detection") + else: + _last_action = _spec + + if len(_spec) > 2: + logger.info(">> %s <<" % _spec["page-type"]) + if _spec["page-type"] == "login": + self.login_page = content + + _op = Action(_spec["control"]) + + try: + _response = _op(self.client, self, url, _response, content, + self.features) + if isinstance(_response, dict): + self.last_response = _response + self.last_content = _response + return _response + content = _response.text + self.position = url + self.last_content = content + self.response = _response + + if _response.status_code >= 400: + break + except (FatalError, InteractionNeeded): + raise + except Exception, err: + self.err_check("exception", err, False) + + self.last_response = _response + try: + self.last_content = _response.text + except AttributeError: + self.last_content = None + + def init(self, phase): + self.creq, self.cresp = phase + + def setup_request(self): + self.request_spec = req = self.creq(conv=self) + + if isinstance(req, Operation): + for intact in self.interaction.interactions: + try: + if req.__class__.__name__ == intact["matches"]["class"]: + req.args = intact["args"] + break + except KeyError: + pass + else: + try: + self.request_args = req.request_args + except KeyError: + pass + try: + self.args = req.kw_args + except KeyError: + pass + + # The authorization dance is all done through the browser + if req.request == "AuthorizationRequest": + self.client.cookiejar = self.cjar["browser"] + # everything else by someone else, assuming the RP + else: + self.client.cookiejar = self.cjar["rp"] + + self.req = req + + def send(self): + pass + + def handle_result(self): + pass + + def do_query(self): + self.setup_request() + self.send() + if not self.handle_result(): + self.intermit() + self.handle_result() + + def do_sequence(self, oper): + try: + self.test_sequence(oper["tests"]["pre"]) + except KeyError: + pass + + for phase in oper["sequence"]: + self.init(phase) + try: + self.do_query() + except InteractionNeeded: + self.test_output.append({"status": INTERACTION, + "message": self.last_content, + "id": "exception", + "name": "interaction needed", + "url": self.position}) + break + except FatalError: + raise + except Exception, err: + #self.err_check("exception", err) + raise + + try: + self.test_sequence(oper["tests"]["post"]) + except KeyError: + pass |