From bc96c3856dfd6ffbd27e0f59acacfef2b71e4edd Mon Sep 17 00:00:00 2001 From: Ivan Kanakarakis Date: Thu, 10 Sep 2020 02:28:38 +0300 Subject: Replace assert with proper checks Signed-off-by: Ivan Kanakarakis --- src/saml2/assertion.py | 4 +-- src/saml2/authn_context/__init__.py | 4 ++- src/saml2/client.py | 7 ++++- src/saml2/client_base.py | 6 +++-- src/saml2/discovery.py | 31 ++++++++++++---------- src/saml2/ecp_client.py | 18 +++++++++++-- src/saml2/entity.py | 14 +++++++--- src/saml2/mdstore.py | 8 ++---- src/saml2/pack.py | 14 ++++++++-- src/saml2/request.py | 16 +++++++---- src/saml2/response.py | 52 ++++++++++++++++++------------------ src/saml2/saml.py | 27 ++++++++++++------- src/saml2/sigver.py | 10 ++++--- src/saml2/soap.py | 53 +++++++++++++++++++++++++++++-------- src/saml2/validate.py | 6 ----- 15 files changed, 173 insertions(+), 97 deletions(-) diff --git a/src/saml2/assertion.py b/src/saml2/assertion.py index ed043d9a..b47f14b9 100644 --- a/src/saml2/assertion.py +++ b/src/saml2/assertion.py @@ -296,9 +296,7 @@ def post_entity_categories(maps, **kwargs): else: attrs = atlist for _key in key: - try: - assert _key in ecs - except AssertionError: + if _key not in ecs: attrs = [] break elif key in ecs: diff --git a/src/saml2/authn_context/__init__.py b/src/saml2/authn_context/__init__.py index 97461396..3abdf74c 100644 --- a/src/saml2/authn_context/__init__.py +++ b/src/saml2/authn_context/__init__.py @@ -90,7 +90,9 @@ class AuthnBroker(object): if _ref is None: _ref = str(self.next) - assert _ref not in self.db["info"] + if _ref in self.db["info"]: + raise Exception("Internal error: reference is not unique") + self.db["info"][_ref] = _info try: self.db["key"][key].append(_ref) diff --git a/src/saml2/client.py b/src/saml2/client.py index 3dd447df..e283420a 100644 --- a/src/saml2/client.py +++ b/src/saml2/client.py @@ -75,7 +75,12 @@ class Saml2Client(Base): response_binding=response_binding, **kwargs) - assert negotiated_binding == binding + if negotiated_binding != binding: + raise ValueError( + "Negotiated binding '{}' does not match binding to use '{}'".format( + negotiated_binding, binding + ) + ) return reqid, info diff --git a/src/saml2/client_base.py b/src/saml2/client_base.py index 564f657e..871f3f2c 100644 --- a/src/saml2/client_base.py +++ b/src/saml2/client_base.py @@ -646,8 +646,10 @@ class Base(Entity): :return: """ - # One of them must be present - assert name_id or base_id or encrypted_id + if not name_id and not base_id and not encrypted_id: + raise ValueError( + "At least one of name_id, base_id or encrypted_id must be present." + ) if name_id: return self._message(NameIDMappingRequest, destination, message_id, diff --git a/src/saml2/discovery.py b/src/saml2/discovery.py index d3e42500..f85ebf44 100644 --- a/src/saml2/discovery.py +++ b/src/saml2/discovery.py @@ -24,10 +24,10 @@ class DiscoveryServer(Entity): # verify - for key in ["isPassive", "return", "returnIDParam", "policy", - 'entityID']: + for key in ["isPassive", "return", "returnIDParam", "policy", 'entityID']: try: - assert len(dsr[key]) == 1 + if len(dsr[key]) != 1: + raise Exception("Invalid DS request keys: {k}".format(k=key)) dsr[key] = dsr[key][0] except KeyError: pass @@ -37,9 +37,13 @@ class DiscoveryServer(Entity): if part.query: qp = parse.parse_qs(part.query) if "returnIDParam" in dsr: - assert dsr["returnIDParam"] not in qp.keys() + if dsr["returnIDParam"] in qp.keys(): + raise Exception( + "returnIDParam value should not be in the query params" + ) else: - assert "entityID" not in qp.keys() + if "entityID" in qp.keys(): + raise Exception("entityID should not be in the query params") else: # If metadata not used this is mandatory raise VerificationError("Missing mandatory parameter 'return'") @@ -47,10 +51,13 @@ class DiscoveryServer(Entity): if "policy" not in dsr: dsr["policy"] = IDPDISC_POLICY - try: - assert dsr["isPassive"] in ["true", "false"] - except KeyError: - pass + is_passive = dsr.get("isPassive") + if is_passive not in ["true", "false"]: + raise ValueError( + "Invalid value '{v}' for attribute '{attr}'".format( + v=is_passive, attr="isPassive" + ) + ) if "isPassive" in dsr and dsr["isPassive"] == "true": dsr["isPassive"] = True @@ -93,10 +100,6 @@ class DiscoveryServer(Entity): def verify_return(self, entity_id, return_url): for endp in self.metadata.discovery_response(entity_id): - try: - assert return_url.startswith(endp["location"]) - except AssertionError: - pass - else: + if not return_url.startswith(endp["location"]): return True return False diff --git a/src/saml2/ecp_client.py b/src/saml2/ecp_client.py index f0183f45..5265f99d 100644 --- a/src/saml2/ecp_client.py +++ b/src/saml2/ecp_client.py @@ -136,7 +136,14 @@ class Client(Entity): logger.debug("[P2] IdP response dict: %s", respdict) idp_response = respdict["body"] - assert idp_response.c_tag == "Response" + + expected_tag = "Response" + if idp_response.c_tag != expected_tag: + raise ValueError( + "Invalid Response tag '{invalid}' should be '{valid}'".format( + invalid=idp_response.c_tag, valid=expected_tag + ) + ) logger.debug("[P2] IdP AUTHN response: %s", idp_response) @@ -165,7 +172,14 @@ class Client(Entity): # AuthnRequest in the body or not authn_request = respdict["body"] - assert authn_request.c_tag == "AuthnRequest" + + expected_tag = "AuthnRequest" + if authn_request.c_tag != expected_tag: + raise ValueError( + "Invalid AuthnRequest tag '{invalid}' should be '{valid}'".format( + invalid=authn_request.c_tag, valid=expected_tag + ) + ) # ecp.RelayState among headers _relay_state = None diff --git a/src/saml2/entity.py b/src/saml2/entity.py index 91219d72..9f564ffa 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -848,7 +848,7 @@ class Entity(HTTPBase): _log_debug("Loaded request") if _request: - _request = _request.verify() + _request.verify() _log_debug("Verified request") if not _request: @@ -1192,14 +1192,14 @@ class Entity(HTTPBase): response.require_signature = True # Verify that the assertion is syntactically correct and the # signature on the assertion is correct if present. - response = response.verify(keys) + response.verify(keys) except SignatureError as err: if require_signature: logger.error("Signature Error: %s", err) raise else: response.require_signature = require_signature - response = response.verify(keys) + response.verify(keys) else: assertions_are_signed = True finally: @@ -1260,7 +1260,13 @@ class Entity(HTTPBase): _art = base64.b64decode(artifact) - assert _art[:2] == ARTIFACT_TYPECODE + typecode = _art[:2] + if typecode != ARTIFACT_TYPECODE: + raise ValueError( + "Invalid artifact typecode '{invalid}' should be {valid}".format( + invalid=typecode, valid=ARTIFACT_TYPECODE + ) + ) try: endpoint_index = str(int(_art[2:4])) diff --git a/src/saml2/mdstore.py b/src/saml2/mdstore.py index 32778af9..24fccb4d 100644 --- a/src/saml2/mdstore.py +++ b/src/saml2/mdstore.py @@ -405,9 +405,7 @@ class MetaData(object): return res def __eq__(self, other): - try: - assert isinstance(other, MetaData) - except AssertionError: + if not isinstance(other, MetaData): return False if len(self.entity) != len(other.entity): @@ -417,9 +415,7 @@ class MetaData(object): return False for key, item in self.entity.items(): - try: - assert item == other[key] - except AssertionError: + if item != other[key]: return False return True diff --git a/src/saml2/pack.py b/src/saml2/pack.py index 5a2534c2..090b13b4 100644 --- a/src/saml2/pack.py +++ b/src/saml2/pack.py @@ -179,7 +179,10 @@ def http_redirect_message(message, location, relay_state="", typ="SAMLRequest", if signer: # sigalgs, should be one defined in xmldsig - assert sigalg in [b for a, b in SIG_ALLOWED_ALG] + if sigalg not in [long_name for short_name, long_name in SIG_ALLOWED_ALG]: + raise Exception( + "Signature algo not in allowed list: {algo}".format(algo=sigalg) + ) args["SigAlg"] = sigalg string = "&".join([urlencode({k: args[k]}) @@ -269,7 +272,14 @@ def parse_soap_enveloped_saml(text, body_class, header_class=None): :return: header parts and body as saml.samlbase instances """ envelope = defusedxml.ElementTree.fromstring(text) - assert envelope.tag == '{%s}Envelope' % NAMESPACE + + envelope_tag = "{%s}Envelope" % NAMESPACE + if envelope.tag != envelope_tag: + raise ValueError( + "Invalid envelope tag '{invalid}' should be '{valid}'".format( + invalid=envelope.tag, valid=envelope_tag + ) + ) # print(len(envelope)) body = None diff --git a/src/saml2/request.py b/src/saml2/request.py index 479b993f..20b711cf 100644 --- a/src/saml2/request.py +++ b/src/saml2/request.py @@ -80,15 +80,21 @@ class Request(object): return issued_at > lower and issued_at < upper def _verify(self): - assert self.message.version == "2.0" + valid_version = "2.0" + if self.message.version != valid_version: + raise VersionMismatch( + "Invalid version {invalid} should be {valid}".format( + invalid=self.message.version, valid=valid_version + ) + ) + if self.message.destination and self.receiver_addrs and \ self.message.destination not in self.receiver_addrs: - logger.error("%s not in %s", self.message.destination, - self.receiver_addrs) + logger.error("%s not in %s", self.message.destination, self.receiver_addrs) raise OtherError("Not destined for me!") - assert self.issue_instant_ok() - return self + valid = self.issue_instant_ok() + return valid def loads(self, xmldata, binding, origdoc=None, must=None, only_valid_cert=False): diff --git a/src/saml2/response.py b/src/saml2/response.py index bcf3285f..5fd8043a 100644 --- a/src/saml2/response.py +++ b/src/saml2/response.py @@ -394,9 +394,7 @@ class StatusResponse(object): self.in_response_to, self.request_id) return None - try: - assert self.response.version == "2.0" - except AssertionError: + if self.response.version != "2.0": _ver = float(self.response.version) if _ver < 2.0: raise RequestVersionTooLow() @@ -410,9 +408,8 @@ class StatusResponse(object): self.return_addrs) return None - assert self.issue_instant_ok() - assert self.status_ok() - return self + valid = self.issue_instant_ok() and self.status_ok() + return valid def loads(self, xmldata, decode=True, origxml=None): return self._loads(xmldata, decode, origxml) @@ -509,9 +506,7 @@ class AuthnResponse(StatusResponse): def check_subject_confirmation_in_response_to(self, irp): for assertion in self.response.assertion: for _sc in assertion.subject.subject_confirmation: - try: - assert _sc.subject_confirmation_data.in_response_to == irp - except AssertionError: + if _sc.subject_confirmation_data.in_response_to != irp: return False return True @@ -546,15 +541,13 @@ class AuthnResponse(StatusResponse): self.assertion = None def authn_statement_ok(self, optional=False): - try: - # the assertion MUST contain one AuthNStatement - assert len(self.assertion.authn_statement) == 1 - except AssertionError: + n_authn_statements = len(self.assertion.authn_statement) + if n_authn_statements != 1: if optional: return True else: - logger.error("No AuthnStatement") - raise + msg = "Invalid number of AuthnStatement found in Response: {n}".format(n=n_authn_statements) + raise ValueError(msg) authn_statement = self.assertion.authn_statement[0] if authn_statement.session_not_on_or_after: @@ -664,9 +657,11 @@ class AuthnResponse(StatusResponse): if _assertion.advice.assertion: for tmp_assertion in _assertion.advice.assertion: if tmp_assertion.attribute_statement: - assert len(tmp_assertion.attribute_statement) == 1 - ava.update(self.read_attribute_statement( - tmp_assertion.attribute_statement[0])) + n_attr_statements = len(tmp_assertion.attribute_statement) + if n_attr_statements != 1: + msg = "Invalid number of AuthnStatement found in Response: {n}".format(n=n_attr_statements) + raise ValueError(msg) + ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0])) if _assertion.attribute_statement: logger.debug("Assertion contains %s attribute statement(s)", (len(self.assertion.attribute_statement))) @@ -729,7 +724,12 @@ class AuthnResponse(StatusResponse): def get_subject(self): """ The assertion must contain a Subject """ - assert self.assertion.subject + if not self.assertion.subject: + raise ValueError( + "Invalid assertion subject: {subject}".format( + subject=self.assertion.subject + ) + ) subject = self.assertion.subject subjconf = [] @@ -916,14 +916,14 @@ class AuthnResponse(StatusResponse): pass else: # This is a saml2int limitation - try: - assert ( - len(self.response.assertion) == 1 - or len(self.response.encrypted_assertion) == 1 - or self.assertion is not None + n_assertions = len(self.response.assertion) + n_assertions_enc = len(self.response.encrypted_assertion) + if n_assertions != 1 and n_assertions_enc != 1 and self.assertion is None: + raise Exception( + "Invalid number of assertions in Response: {n}".format( + n=n_assertions+n_assertions_enc + ) ) - except AssertionError: - raise Exception("No assertion part") if self.response.assertion: logger.debug("***Unencrypted assertion***") diff --git a/src/saml2/saml.py b/src/saml2/saml.py index e551bcb6..9753def0 100644 --- a/src/saml2/saml.py +++ b/src/saml2/saml.py @@ -116,8 +116,14 @@ class AttributeValueBase(SamlBase): def verify(self): if not self.text: - assert self.extension_attributes - assert self.extension_attributes[XSI_NIL] == "true" + if not self.extension_attributes: + raise Exception( + "Attribute value base should not have extension attributes" + ) + if self.extension_attributes[XSI_NIL] != "true": + raise Exception( + "Attribute value base should not have extension attributes" + ) return True else: SamlBase.verify(self) @@ -1029,12 +1035,11 @@ class AuthnContextType_(SamlBase): self.authenticating_authority = authenticating_authority or [] def verify(self): - # either or not both - if self.authn_context_decl: - assert self.authn_context_decl_ref is None - elif self.authn_context_decl_ref: - assert self.authn_context_decl is None - + if self.authn_context_decl and self.authn_context_decl_ref: + raise Exception( + "Invalid Response: " + "Cannot have both and " + ) return SamlBase.verify(self) @@ -1264,9 +1269,11 @@ class ConditionsType_(SamlBase): def verify(self): if self.one_time_use: - assert len(self.one_time_use) == 1 + if len(self.one_time_use) != 1: + raise Exception("Cannot be used more than once") if self.proxy_restriction: - assert len(self.proxy_restriction) == 1 + if len(self.proxy_restriction) != 1: + raise Exception("Cannot be used more than once") return SamlBase.verify(self) diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 3cf7c215..0d960d21 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -679,7 +679,8 @@ class CryptoBackendXmlSec1(CryptoBackend): def __init__(self, xmlsec_binary, delete_tmpfiles=True, **kwargs): CryptoBackend.__init__(self, **kwargs) - assert (isinstance(xmlsec_binary, six.string_types)) + if not isinstance(xmlsec_binary, six.string_types): + raise ValueError("xmlsec_binary should be of type string") self.xmlsec = xmlsec_binary self.delete_tmpfiles = delete_tmpfiles try: @@ -1237,11 +1238,12 @@ class SecurityContext(object): sec_backend=None, delete_tmpfiles=True): + if not isinstance(crypto, CryptoBackend): + raise ValueError("crypto should be of type CryptoBackend") self.crypto = crypto - assert (isinstance(self.crypto, CryptoBackend)) - if sec_backend: - assert (isinstance(sec_backend, RSACrypto)) + if sec_backend and not isinstance(sec_backend, RSACrypto): + raise ValueError("sec_backend should be of type RSACrypto") self.sec_backend = sec_backend # Your private key for signing diff --git a/src/saml2/soap.py b/src/saml2/soap.py index 46c08b76..94af4f1f 100644 --- a/src/saml2/soap.py +++ b/src/saml2/soap.py @@ -6,6 +6,7 @@ Suppport for the client part of the SAML2.0 SOAP binding. """ import logging +import re from saml2 import create_class_from_element_tree from saml2.samlp import NAMESPACE as SAMLP_NAMESPACE @@ -136,14 +137,25 @@ def parse_soap_enveloped_saml_thingy(text, expected_tags): """ envelope = defusedxml.ElementTree.fromstring(text) - # Make sure it's a SOAP message - assert envelope.tag == '{%s}Envelope' % soapenv.NAMESPACE + envelope_tag = "{%s}Envelope" % soapenv.NAMESPACE + if envelope.tag != envelope_tag: + raise ValueError( + "Invalid envelope tag '{invalid}' should be '{valid}'".format( + invalid=envelope.tag, valid=envelope_tag + ) + ) + + if len(envelope) < 1: + raise Exception("No items in envelope.") - assert len(envelope) >= 1 body = None for part in envelope: if part.tag == '{%s}Body' % soapenv.NAMESPACE: - assert len(part) == 1 + n_children = len(part) + if n_children != 1: + raise Exception( + "Expected a single child element, found {n}".format(n=n_children) + ) body = part break @@ -157,7 +169,6 @@ def parse_soap_enveloped_saml_thingy(text, expected_tags): raise WrongMessageType("Was '%s' expected one of %s" % (saml_part.tag, expected_tags)) -import re NS_AND_TAG = re.compile(r"\{([^}]+)\}(.*)") @@ -188,13 +199,23 @@ def class_instances_from_soap_enveloped_saml_thingies(text, modules): except Exception as exc: raise XmlParseError("%s" % exc) - assert envelope.tag == '{%s}Envelope' % soapenv.NAMESPACE - assert len(envelope) >= 1 + envelope_tag = "{%s}Envelope" % soapenv.NAMESPACE + if envelope.tag != envelope_tag: + raise ValueError( + "Invalid envelope tag '{invalid}' should be '{valid}'".format( + invalid=envelope.tag, valid=envelope_tag + ) + ) + + if len(envelope) < 1: + raise Exception("No items in envelope.") + env = {"header": [], "body": None} for part in envelope: if part.tag == '{%s}Body' % soapenv.NAMESPACE: - assert len(part) == 1 + if len(envelope) < 1: + raise Exception("No items in envelope part.") env["body"] = instanciate_class(part[0], modules) elif part.tag == "{%s}Header" % soapenv.NAMESPACE: for item in part: @@ -214,13 +235,23 @@ def open_soap_envelope(text): except Exception as exc: raise XmlParseError("%s" % exc) - assert envelope.tag == '{%s}Envelope' % soapenv.NAMESPACE - assert len(envelope) >= 1 + envelope_tag = "{%s}Envelope" % soapenv.NAMESPACE + if envelope.tag != envelope_tag: + raise ValueError( + "Invalid envelope tag '{invalid}' should be '{valid}'".format( + invalid=envelope.tag, valid=envelope_tag + ) + ) + + if len(envelope) < 1: + raise Exception("No items in envelope.") + content = {"header": [], "body": None} for part in envelope: if part.tag == '{%s}Body' % soapenv.NAMESPACE: - assert len(part) == 1 + if len(envelope) < 1: + raise Exception("No items in envelope part.") content["body"] = ElementTree.tostring(part[0], encoding="UTF-8") elif part.tag == "{%s}Header" % soapenv.NAMESPACE: for item in part: diff --git a/src/saml2/validate.py b/src/saml2/validate.py index b098de59..334b3d14 100644 --- a/src/saml2/validate.py +++ b/src/saml2/validate.py @@ -425,12 +425,6 @@ def valid_instance(instance): "Class '%s' instance cardinality error: %s" % ( class_name, "too few values on %s" % name)) -# if not _has_val: -# if class_name != "RequestedAttribute": -# # Not allow unless xsi:nil="true" -# assert instance.extension_attributes -# assert instance.extension_attributes[XSI_NIL] == "true" - return True -- cgit v1.2.1