diff options
author | Ivan Kanakarakis <ivan.kanak@gmail.com> | 2018-08-04 02:20:52 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-08-04 02:20:52 +0300 |
commit | cf144871dc8d457d4bbc04e0b019bc937be4e9de (patch) | |
tree | a46a50c5bb12b0b0f8c443835d2f98e0febc8857 | |
parent | 134f5d78d5354f0be48a2461e0b9d8aeb3735329 (diff) | |
parent | 96948b580f70ab69d53b04cb657b498582eed32b (diff) | |
download | pysaml2-cf144871dc8d457d4bbc04e0b019bc937be4e9de.tar.gz |
Merge pull request #530 from c00kiemon5ter/feature-set-id-attr-name
Allow configuration and specification of id attribute name
-rw-r--r-- | src/saml2/config.py | 23 | ||||
-rw-r--r-- | src/saml2/httputil.py | 2 | ||||
-rw-r--r-- | src/saml2/sigver.py | 1094 | ||||
-rw-r--r-- | tests/okta_assertion | 2 | ||||
-rw-r--r-- | tests/okta_response.xml | 47 | ||||
-rw-r--r-- | tests/test_40_sigver.py | 30 |
6 files changed, 629 insertions, 569 deletions
diff --git a/src/saml2/config.py b/src/saml2/config.py index 1a832871..8f90afc8 100644 --- a/src/saml2/config.py +++ b/src/saml2/config.py @@ -28,10 +28,21 @@ __author__ = 'rolandh' COMMON_ARGS = [ - "entityid", "xmlsec_binary", "debug", "key_file", "cert_file", - "encryption_keypairs", "additional_cert_files", - "metadata_key_usage", "secret", "accepted_time_diff", "name", "ca_certs", - "description", "valid_for", "verify_ssl_cert", + "debug", + "entityid", + "xmlsec_binary", + "key_file", + "cert_file", + "encryption_keypairs", + "additional_cert_files", + "metadata_key_usage", + "secret", + "accepted_time_diff", + "name", + "ca_certs", + "description", + "valid_for", + "verify_ssl_cert", "organization", "contact_person", "name_form", @@ -54,7 +65,8 @@ COMMON_ARGS = [ "validate_certificate", "extensions", "allow_unknown_attributes", - "crypto_backend" + "crypto_backend", + "id_attr_name", ] SP_ARGS = [ @@ -210,6 +222,7 @@ class Config(object): self.name_qualifier = "" self.entity_category = "" self.crypto_backend = 'xmlsec1' + self.id_attr_name = None self.scope = "" self.allow_unknown_attributes = False self.extension_schema = {} diff --git a/src/saml2/httputil.py b/src/saml2/httputil.py index 0e7f32a6..253ff570 100644 --- a/src/saml2/httputil.py +++ b/src/saml2/httputil.py @@ -157,7 +157,7 @@ class BadGateway(Response): _status = "502 Bad Gateway" -class HttpParameters(): +class HttpParameters(object): """GET or POST signature parameters for Redirect or POST-SimpleSign bindings because they are not contained in XML unlike the POST binding """ diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index de399c4d..576993be 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -1,7 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# - """ Functions connected to signing and verifying. Based on the use of xmlsec1 binaries and not the python xmlsec module. """ @@ -60,15 +56,13 @@ from saml2.xmlenc import CipherData from saml2.xmlenc import CipherValue from saml2.xmlenc import EncryptedData + logger = logging.getLogger(__name__) -SIG = "{%s#}%s" % (ds.NAMESPACE, "Signature") +SIG = '{{{ns}#}}{attribute}'.format(ns=ds.NAMESPACE, attribute='Signature') -RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5" -TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc" -XMLTAG = "<?xml version='1.0'?>" -PREFIX1 = "<?xml version='1.0' encoding='UTF-8'?>" -PREFIX2 = '<?xml version="1.0" encoding="UTF-8"?>' +RSA_1_5 = 'http://www.w3.org/2001/04/xmlenc#rsa-1_5' +TRIPLE_DES_CBC = 'http://www.w3.org/2001/04/xmlenc#tripledes-cbc' class SigverError(SAMLError): @@ -114,10 +108,14 @@ def read_file(*args, **kwargs): def rm_xmltag(statement): + XMLTAG = "<?xml version='1.0'?>" + PREFIX1 = "<?xml version='1.0' encoding='UTF-8'?>" + PREFIX2 = '<?xml version="1.0" encoding="UTF-8"?>' + try: _t = statement.startswith(XMLTAG) except TypeError: - statement = statement.decode("utf8") + statement = statement.decode() _t = statement.startswith(XMLTAG) if _t: @@ -167,12 +165,12 @@ def get_xmlsec_binary(paths=None): :return: full name of the xmlsec1 binary found. If no binaries are found then an exception is raised. """ - if os.name == "posix": - bin_name = ["xmlsec1"] - elif os.name == "nt": - bin_name = ["xmlsec.exe", "xmlsec1.exe"] + if os.name == 'posix': + bin_name = ['xmlsec1'] + elif os.name == 'nt': + bin_name = ['xmlsec.exe', 'xmlsec1.exe'] else: # Default !? - bin_name = ["xmlsec1"] + bin_name = ['xmlsec1'] if paths: for bname in bin_name: @@ -184,7 +182,7 @@ def get_xmlsec_binary(paths=None): except OSError: pass - for path in os.environ["PATH"].split(os.pathsep): + for path in os.environ['PATH'].split(os.pathsep): for bname in bin_name: fil = os.path.join(path, bname) try: @@ -193,10 +191,10 @@ def get_xmlsec_binary(paths=None): except OSError: pass - raise SigverError("Can't find %s" % bin_name) + raise SigverError('Cannot find {binary}'.format(binary=bin_name)) -def _get_xmlsec_cryptobackend(path=None, search_paths=None, debug=False): +def _get_xmlsec_cryptobackend(path=None, search_paths=None): """ Initialize a CryptoBackendXmlSec1 crypto backend. @@ -204,18 +202,12 @@ def _get_xmlsec_cryptobackend(path=None, search_paths=None, debug=False): """ if path is None: path = get_xmlsec_binary(paths=search_paths) - return CryptoBackendXmlSec1(path, debug=debug) - - -ID_ATTR = "ID" -NODE_NAME = "urn:oasis:names:tc:SAML:2.0:assertion:Assertion" -ENC_NODE_NAME = "urn:oasis:names:tc:SAML:2.0:assertion:EncryptedAssertion" -ENC_KEY_CLASS = "EncryptedKey" - -_TEST_ = True + return CryptoBackendXmlSec1(path) -# -------------------------------------------------------------------------- +NODE_NAME = 'urn:oasis:names:tc:SAML:2.0:assertion:Assertion' +ENC_NODE_NAME = 'urn:oasis:names:tc:SAML:2.0:assertion:EncryptedAssertion' +ENC_KEY_CLASS = 'EncryptedKey' def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, @@ -235,8 +227,6 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, """ cinst = None - # print("make_vals(%s, %s)" % (val, klass)) - if isinstance(val, dict): cinst = _instance(klass, val, seccont, base64encode=base64encode, elements_to_sign=elements_to_sign) @@ -245,9 +235,18 @@ def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, cinst = klass().set_text(val) except ValueError: if not part: - cis = [_make_vals(sval, klass, seccont, klass_inst, prop, - True, base64encode, elements_to_sign) for sval - in val] + cis = [ + _make_vals( + sval, + klass, + seccont, + klass_inst, + prop, + True, + base64encode, + elements_to_sign) + for sval in val + ] setattr(klass_inst, prop, cis) else: raise @@ -264,22 +263,19 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): instance = klass() for prop in instance.c_attributes.values(): - # print("# %s" % (prop)) if prop in ava: if isinstance(ava[prop], bool): - setattr(instance, prop, str(ava[prop]).encode('utf-8')) + setattr(instance, prop, str(ava[prop]).encode()) elif isinstance(ava[prop], int): - setattr(instance, prop, "%d" % ava[prop]) + setattr(instance, prop, str(ava[prop])) else: setattr(instance, prop, ava[prop]) - if "text" in ava: - instance.set_text(ava["text"], base64encode) + if 'text' in ava: + instance.set_text(ava['text'], base64encode) for prop, klassdef in instance.c_children.values(): - # print("## %s, %s" % (prop, klassdef)) if prop in ava: - # print("### %s" % ava[prop]) if isinstance(klassdef, list): # means there can be a list of values _make_vals(ava[prop], klassdef[0], seccont, instance, prop, @@ -290,16 +286,16 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None): True, base64encode, elements_to_sign) setattr(instance, prop, cis) - if "extension_elements" in ava: - for item in ava["extension_elements"]: + if 'extension_elements' in ava: + for item in ava['extension_elements']: instance.extension_elements.append( - ExtensionElement(item["tag"]).loadd(item)) + ExtensionElement(item['tag']).loadd(item)) - if "extension_attributes" in ava: - for key, val in ava["extension_attributes"].items(): + if 'extension_attributes' in ava: + for key, val in ava['extension_attributes'].items(): instance.extension_attributes[key] = val - if "signature" in ava: + if 'signature' in ava: elements_to_sign.append((class_name(instance), instance.id)) return instance @@ -318,26 +314,12 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): for (node_name, nodeid) in elements_to_sign: signed_xml = seccont.sign_statement( signed_xml, node_name=node_name, node_id=nodeid) - - # print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") - # print("%s" % signed_xml) - # print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") return signed_xml else: return instance -# -------------------------------------------------------------------------- -# def create_id(): -# """ Create a string of 40 random characters from the set [a-p], -# can be used as a unique identifier of objects. -# -# :return: The string of random characters -# """ -# return rndstr(40, "abcdefghijklmonp") - - -def make_temp(string, suffix="", decode=True, delete=True): +def make_temp(string, suffix='', decode=True, delete=True): """ xmlsec needs files in some cases where only strings exist, hence the need for this function. It creates a temporary file with the string as only content. @@ -354,7 +336,7 @@ def make_temp(string, suffix="", decode=True, delete=True): ntf = NamedTemporaryFile(suffix=suffix, delete=delete) # Python3 tempfile requires byte-like object if not isinstance(string, six.binary_type): - string = string.encode("utf8") + string = string.encode() if decode: ntf.write(base64.b64decode(string)) @@ -368,13 +350,11 @@ def split_len(seq, length): return [seq[i:i + length] for i in range(0, len(seq), length)] -# -------------------------------------------------------------------------- - -M2_TIME_FORMAT = "%b %d %H:%M:%S %Y" +M2_TIME_FORMAT = '%b %d %H:%M:%S %Y' def to_time(_time): - assert _time.endswith(" GMT") + assert _time.endswith(' GMT') _time = _time[:-4] return mktime(str_to_time(_time, M2_TIME_FORMAT)) @@ -412,15 +392,14 @@ def cert_from_key_info(key_info, ignore_age=False): """ res = [] for x509_data in key_info.x509_data: - # print("X509Data",x509_data) x509_certificate = x509_data.x509_certificate cert = x509_certificate.text.strip() - cert = "\n".join(split_len("".join([s.strip() for s in + cert = '\n'.join(split_len(''.join([s.strip() for s in cert.split()]), 64)) if ignore_age or active_cert(cert): res.append(cert) else: - logger.info("Inactive cert") + logger.info('Inactive cert') return res @@ -436,18 +415,18 @@ def cert_from_key_info_dict(key_info, ignore_age=False): :return: A possibly empty list of certs in their text representation """ res = [] - if not "x509_data" in key_info: + if 'x509_data' not in key_info: return res - for x509_data in key_info["x509_data"]: - x509_certificate = x509_data["x509_certificate"] - cert = x509_certificate["text"].strip() - cert = "\n".join(split_len("".join([s.strip() for s in - cert.split()]), 64)) + for x509_data in key_info['x509_data']: + x509_certificate = x509_data['x509_certificate'] + cert = x509_certificate['text'].strip() + cert = '\n'.join(split_len(''.join( + [s.strip() for s in cert.split()]), 64)) if ignore_age or active_cert(cert): res.append(cert) else: - logger.info("Inactive cert") + logger.info('Inactive cert') return res @@ -464,31 +443,17 @@ def cert_from_instance(instance): return [] -# ============================================================================= - - -def intarr2long(arr): - return long(''.join(["%02x" % byte for byte in arr]), 16) - - -def dehexlify(bi): - s = hexlify(bi) - return [int(s[i] + s[i + 1], 16) for i in range(0, len(s), 2)] - - -def base64_to_long(data): - _d = base64.urlsafe_b64decode(data + '==') - return intarr2long(dehexlify(_d)) - - def extract_rsa_key_from_x509_cert(pem): cert = saml2.cryptography.pki.load_pem_x509_certificate(pem) return cert.public_key() def pem_format(key): - return "\n".join(["-----BEGIN CERTIFICATE-----", - key, "-----END CERTIFICATE-----"]).encode('ascii') + return '\n'.join([ + '-----BEGIN CERTIFICATE-----', + key, + '-----END CERTIFICATE-----' + ]).encode('ascii') def import_rsa_key_from_file(filename): @@ -505,9 +470,9 @@ def parse_xmlsec_output(output): :return: A boolean; True if the command was a success otherwise False """ for line in output.splitlines(): - if line == "OK": + if line == 'OK': return True - elif line == "FAIL": + elif line == 'FAIL': raise XmlsecError(output) raise XmlsecError(output) @@ -553,8 +518,17 @@ SIGNER_ALGS = { SIG_RSA_SHA512: RSASigner(saml2.cryptography.asymmetric.hashes.SHA512()), } -REQ_ORDER = ["SAMLRequest", "RelayState", "SigAlg"] -RESP_ORDER = ["SAMLResponse", "RelayState", "SigAlg"] +REQ_ORDER = [ + 'SAMLRequest', + 'RelayState', + 'SigAlg', +] + +RESP_ORDER = [ + 'SAMLResponse', + 'RelayState', + 'SigAlg', +] class RSACrypto(object): @@ -585,22 +559,23 @@ def verify_redirect_signature(saml_msg, crypto, cert=None, sigkey=None): """ try: - signer = crypto.get_signer(saml_msg["SigAlg"], sigkey) + signer = crypto.get_signer(saml_msg['SigAlg'], sigkey) except KeyError: - raise Unsupported("Signature algorithm: %s" % saml_msg["SigAlg"]) + raise Unsupported('Signature algorithm: {alg}'.format( + alg=saml_msg['SigAlg'])) else: - if saml_msg["SigAlg"] in SIGNER_ALGS: - if "SAMLRequest" in saml_msg: + if saml_msg['SigAlg'] in SIGNER_ALGS: + if 'SAMLRequest' in saml_msg: _order = REQ_ORDER - elif "SAMLResponse" in saml_msg: + elif 'SAMLResponse' in saml_msg: _order = RESP_ORDER else: raise Unsupported( - "Verifying signature on something that should not be " - "signed") + 'Verifying signature on something that should not be ' + 'signed') _args = saml_msg.copy() - del _args["Signature"] # everything but the signature - string = "&".join( + del _args['Signature'] # everything but the signature + string = '&'.join( [urlencode({k: _args[k]}) for k in _order if k in _args]).encode('ascii') @@ -609,23 +584,20 @@ def verify_redirect_signature(saml_msg, crypto, cert=None, sigkey=None): else: _key = sigkey - _sign = base64.b64decode(saml_msg["Signature"]) + _sign = base64.b64decode(saml_msg['Signature']) return bool(signer.verify(string, _sign, _key)) -LOG_LINE = 60 * "=" + "\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" -LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "=" +LOG_LINE = 60 * '=' + '\n%s\n' + 60 * '-' + '\n%s' + 60 * '=' +LOG_LINE_2 = 60 * '=' + '\n%s\n%s\n' + 60 * '-' + '\n%s' + 60 * '=' def make_str(txt): if isinstance(txt, six.string_types): return txt else: - return txt.decode("utf8") - - -# --------------------------------------------------------------------------- + return txt.decode() def read_cert_from_file(cert_file, cert_type): @@ -638,64 +610,62 @@ def read_cert_from_file(cert_file, cert_type): """ if not cert_file: - return "" + return '' - if cert_type == "pem": - _a = read_file(cert_file, 'rb').decode("utf8") - _b = _a.replace("\r\n", "\n") - lines = _b.split("\n") + if cert_type == 'pem': + _a = read_file(cert_file, 'rb').decode() + _b = _a.replace('\r\n', '\n') + lines = _b.split('\n') - for pattern in ("-----BEGIN CERTIFICATE-----", - "-----BEGIN PUBLIC KEY-----"): + for pattern in ( + '-----BEGIN CERTIFICATE-----', + '-----BEGIN PUBLIC KEY-----'): if pattern in lines: lines = lines[lines.index(pattern) + 1:] break else: - raise CertificateError("Strange beginning of PEM file") + raise CertificateError('Strange beginning of PEM file') - for pattern in ("-----END CERTIFICATE-----", - "-----END PUBLIC KEY-----"): + for pattern in ( + '-----END CERTIFICATE-----', + '-----END PUBLIC KEY-----'): if pattern in lines: lines = lines[:lines.index(pattern)] break else: - raise CertificateError("Strange end of PEM file") - return make_str("".join(lines).encode("utf8")) + raise CertificateError('Strange end of PEM file') + return make_str(''.join(lines).encode()) - if cert_type in ["der", "cer", "crt"]: + if cert_type in ['der', 'cer', 'crt']: data = read_file(cert_file, 'rb') _cert = base64.b64encode(data) return make_str(_cert) -class CryptoBackend(): - def __init__(self, debug=False): - self.debug = debug - +class CryptoBackend(object): def version(self): raise NotImplementedError() def encrypt(self, text, recv_key, template, key_type): raise NotImplementedError() - def encrypt_assertion(self, statement, enc_key, template, key_type, - node_xpath): + def encrypt_assertion(self, statement, enc_key, template, key_type, node_xpath): raise NotImplementedError() - def decrypt(self, enctext, key_file): + def decrypt(self, enctext, key_file, id_attr): raise NotImplementedError() - def sign_statement(self, statement, node_name, key_file, node_id, - id_attr): + def sign_statement(self, statement, node_name, key_file, node_id, id_attr): raise NotImplementedError() - def validate_signature(self, enctext, cert_file, cert_type, node_name, - node_id, id_attr): + def validate_signature(self, enctext, cert_file, cert_type, node_name, node_id, id_attr): raise NotImplementedError() -ASSERT_XPATH = ''.join(["/*[local-name()=\"%s\"]" % v for v in [ - "Response", "EncryptedAssertion", "Assertion"]]) +ASSERT_XPATH = ''.join([ + '/*[local-name()=\'{name}\']'.format(name=n) + for n in ['Response', 'EncryptedAssertion', 'Assertion'] +]) class CryptoBackendXmlSec1(CryptoBackend): @@ -721,44 +691,49 @@ class CryptoBackendXmlSec1(CryptoBackend): pass def version(self): - com_list = [self.xmlsec, "--version"] + com_list = [self.xmlsec, '--version'] pof = Popen(com_list, stderr=PIPE, stdout=PIPE) content, _ = pof.communicate() content = content.decode('ascii') try: - return content.split(" ")[1] + return content.split(' ')[1] except IndexError: - return "" + return '' - def encrypt(self, text, recv_key, template, session_key_type, xpath=""): + def encrypt(self, text, recv_key, template, session_key_type, xpath=''): """ :param text: The text to be compiled :param recv_key: Filename of a file where the key resides :param template: Filename of a file with the pre-encryption part :param session_key_type: Type and size of a new session key - "des-192" generates a new 192 bits DES key for DES3 encryption + 'des-192' generates a new 192 bits DES key for DES3 encryption :param xpath: What should be encrypted :return: """ - logger.debug("Encryption input len: %d", len(text)) - _, fil = make_temp(str(text).encode('utf-8'), decode=False) + logger.debug('Encryption input len: %d', len(text)) + _, fil = make_temp(str(text).encode(), decode=False) - com_list = [self.xmlsec, "--encrypt", "--pubkey-cert-pem", recv_key, - "--session-key", session_key_type, "--xml-data", fil] + com_list = [ + self.xmlsec, + '--encrypt', + '--pubkey-cert-pem', recv_key, + '--session-key', session_key_type, + '--xml-data', fil, + ] if xpath: com_list.extend(['--node-xpath', xpath]) - (_stdout, _stderr, output) = self._run_xmlsec(com_list, [template], - exception=DecryptError, - validate_output=False) - if isinstance(output, six.binary_type): - output = output.decode('utf-8') + (_stdout, _stderr, output) = self._run_xmlsec( + com_list, + [template], + exception=DecryptError, + validate_output=False) + return output - def encrypt_assertion(self, statement, enc_key, template, - key_type="des-192", node_xpath=None, node_id=None): + def encrypt_assertion(self, statement, enc_key, template, key_type='des-192', node_xpath=None, node_id=None): """ Will encrypt an assertion @@ -772,29 +747,38 @@ class CryptoBackendXmlSec1(CryptoBackend): if isinstance(statement, SamlBase): statement = pre_encrypt_assertion(statement) - _, fil = make_temp(str(statement).encode('utf-8'), decode=False, + _, fil = make_temp(str(statement).encode(), decode=False, delete=False) - _, tmpl = make_temp(str(template).encode('utf-8'), decode=False) + _, tmpl = make_temp(str(template).encode(), decode=False) if not node_xpath: node_xpath = ASSERT_XPATH - com_list = [self.xmlsec, "encrypt", "--pubkey-cert-pem", enc_key, - "--session-key", key_type, "--xml-data", fil, - "--node-xpath", node_xpath] + com_list = [ + self.xmlsec, + '--encrypt', + '--pubkey-cert-pem', enc_key, + '--session-key', key_type, + '--xml-data', fil, + '--node-xpath', node_xpath, + ] + if node_id: - com_list.extend(["--node-id", node_id]) + com_list.extend(['--node-id', node_id]) (_stdout, _stderr, output) = self._run_xmlsec( - com_list, [tmpl], exception=EncryptError, validate_output=False) + com_list, + [tmpl], + exception=EncryptError, + validate_output=False) os.unlink(fil) if not output: raise EncryptError(_stderr) - return output.decode('utf-8') + return output.decode() - def decrypt(self, enctext, key_file): + def decrypt(self, enctext, key_file, id_attr): """ :param enctext: XML document containing an encrypted part @@ -802,19 +786,26 @@ class CryptoBackendXmlSec1(CryptoBackend): :return: The decrypted document """ - logger.debug("Decrypt input len: %d", len(enctext)) - _, fil = make_temp(str(enctext).encode('utf-8'), decode=False) + logger.debug('Decrypt input len: %d', len(enctext)) + _, fil = make_temp(str(enctext).encode(), decode=False) + + com_list = [ + self.xmlsec, + '--decrypt', + '--privkey-pem', key_file, + '--id-attr:{id_attr}'.format(id_attr=id_attr), + ENC_KEY_CLASS, + ] - com_list = [self.xmlsec, "--decrypt", "--privkey-pem", - key_file, "--id-attr:%s" % ID_ATTR, ENC_KEY_CLASS] + (_stdout, _stderr, output) = self._run_xmlsec( + com_list, + [fil], + exception=DecryptError, + validate_output=False) - (_stdout, _stderr, output) = self._run_xmlsec(com_list, [fil], - exception=DecryptError, - validate_output=False) - return output.decode('utf-8') + return output.decode() - def sign_statement(self, statement, node_name, key_file, node_id, - id_attr): + def sign_statement(self, statement, node_name, key_file, node_id, id_attr): """ Sign an XML statement. @@ -829,31 +820,40 @@ class CryptoBackendXmlSec1(CryptoBackend): if isinstance(statement, SamlBase): statement = str(statement) - _, fil = make_temp(statement, suffix=".xml", - decode=False, delete=self._xmlsec_delete_tmpfiles) + _, fil = make_temp( + statement, + suffix='.xml', + decode=False, + delete=self._xmlsec_delete_tmpfiles) + + com_list = [ + self.xmlsec, + '--sign', + '--privkey-pem', key_file, + '--id-attr:{id_attr_name}'.format(id_attr_name=id_attr), + node_name, + ] - com_list = [self.xmlsec, "--sign", - "--privkey-pem", key_file, - "--id-attr:%s" % id_attr, node_name] if node_id: - com_list.extend(["--node-id", node_id]) + com_list.extend(['--node-id', node_id]) try: (stdout, stderr, signed_statement) = self._run_xmlsec( - com_list, [fil], validate_output=False) + com_list, + [fil], + validate_output=False) + # this doesn't work if --store-signatures are used - if stdout == "": + if stdout == '': if signed_statement: - return signed_statement.decode('utf-8') - logger.error( - "Signing operation failed :\nstdout : %s\nstderr : %s", - stdout, stderr) + return signed_statement.decode() + + logger.error('Signing operation failed :\nstdout : %s\nstderr : %s', stdout, stderr) raise SigverError(stderr) except DecryptError: - raise SigverError("Signing failed") + raise SigverError('Signing failed') - def validate_signature(self, signedtext, cert_file, cert_type, node_name, - node_id, id_attr): + def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_id, id_attr): """ Validate signature on XML document. @@ -862,43 +862,38 @@ class CryptoBackendXmlSec1(CryptoBackend): :param cert_type: The file type of the certificate :param node_name: The name of the class that is signed :param node_id: The identifier of the node - :param id_attr: Should normally be one of "id", "Id" or "ID" + :param id_attr: Should normally be one of 'id', 'Id' or 'ID' :return: Boolean True if the signature was correct otherwise False. """ if not isinstance(signedtext, six.binary_type): - signedtext = signedtext.encode('utf-8') - _, fil = make_temp(signedtext, suffix=".xml", - decode=False, delete=self._xmlsec_delete_tmpfiles) - - com_list = [self.xmlsec, "--verify", - "--enabled-reference-uris", "empty,same-doc", - "--pubkey-cert-%s" % cert_type, cert_file, - "--id-attr:%s" % id_attr, node_name] - - if self.debug: - com_list.append("--store-signatures") + signedtext = signedtext.encode() + + _, fil = make_temp( + signedtext, + suffix='.xml', + decode=False, + delete=self._xmlsec_delete_tmpfiles) + + com_list = [ + self.xmlsec, + '--verify', + '--enabled-reference-uris', 'empty,same-doc', + '--pubkey-cert-{type}'.format(type=cert_type), cert_file, + '--id-attr:{id_attr_name}'.format(id_attr_name=id_attr), + node_name, + ] if node_id: - com_list.extend(["--node-id", node_id]) + com_list.extend(['--node-id', node_id]) - if self.__DEBUG: - try: - print(" ".join(com_list)) - except TypeError: - print("cert_type", cert_type) - print("cert_file", cert_file) - print("node_name", node_name) - print("fil", fil) - raise - print("%s: %s" % (cert_file, os.access(cert_file, os.F_OK))) - print("%s: %s" % (fil, os.access(fil, os.F_OK))) + (_stdout, stderr, _output) = self._run_xmlsec( + com_list, + [fil], + exception=SignatureError) - (_stdout, stderr, _output) = self._run_xmlsec(com_list, [fil], - exception=SignatureError) return parse_xmlsec_output(stderr) - def _run_xmlsec(self, com_list, extra_args, validate_output=True, - exception=XmlsecError): + def _run_xmlsec(self, com_list, extra_args, validate_output=True, exception=XmlsecError): """ Common code to invoke xmlsec and parse the output. :param com_list: Key-value parameter list for xmlsec @@ -908,20 +903,21 @@ class CryptoBackendXmlSec1(CryptoBackend): :param exception: The exception class to raise on errors :result: Whatever xmlsec wrote to an --output temporary file """ - with NamedTemporaryFile(suffix=".xml", delete=self._xmlsec_delete_tmpfiles) as ntf: - com_list.extend(["--output", ntf.name]) + with NamedTemporaryFile(suffix='.xml', delete=self._xmlsec_delete_tmpfiles) as ntf: + com_list.extend(['--output', ntf.name]) com_list += extra_args - logger.debug("xmlsec command: %s", " ".join(com_list)) + logger.debug('xmlsec command: %s', ' '.join(com_list)) pof = Popen(com_list, stderr=PIPE, stdout=PIPE) p_out, p_err = pof.communicate() - p_out = p_out.decode('utf-8') - p_err = p_err.decode('utf-8') + p_out = p_out.decode() + p_err = p_err.decode() if pof.returncode is not None and pof.returncode < 0: logger.error(LOG_LINE, p_out, p_err) - raise XmlsecError("%d:%s" % (pof.returncode, p_err)) + raise XmlsecError('{err_code}:{err_msg}'.format( + err_code=pof.returncode, err_msg=p_err)) try: if validate_output: @@ -947,17 +943,15 @@ class CryptoBackendXMLSecurity(CryptoBackend): to an external PKCS#11 module. """ - def __init__(self, debug=False): + def __init__(self): CryptoBackend.__init__(self) - self.debug = debug def version(self): # XXX if XMLSecurity.__init__ included a __version__, that would be # better than static 0.0 here. - return "XMLSecurity 0.0" + return 'XMLSecurity 0.0' - def sign_statement(self, statement, node_name, key_file, node_id, - _id_attr): + def sign_statement(self, statement, node_name, key_file, node_id, id_attr): """ Sign an XML statement. @@ -967,7 +961,7 @@ class CryptoBackendXMLSecurity(CryptoBackend): :param statement: XML as string :param node_name: Name of the node to sign :param key_file: xmlsec key_spec string(), filename, - "pkcs11://" URI or PEM data + 'pkcs11://' URI or PEM data :returns: Signed XML as string """ import xmlsec @@ -977,8 +971,7 @@ class CryptoBackendXMLSecurity(CryptoBackend): signed = xmlsec.sign(xml, key_file) return lxml.etree.tostring(signed, xml_declaration=True) - def validate_signature(self, signedtext, cert_file, cert_type, node_name, - node_id, id_attr): + def validate_signature(self, signedtext, cert_file, cert_type, node_name, node_id, id_attr): """ Validate signature on XML document. @@ -987,22 +980,23 @@ class CryptoBackendXMLSecurity(CryptoBackend): :param signedtext: The signed XML data as string :param cert_file: xmlsec key_spec string(), filename, - "pkcs11://" URI or PEM data + 'pkcs11://' URI or PEM data :param cert_type: string, must be 'pem' for now :returns: True on successful validation, False otherwise """ - if cert_type != "pem": - raise Unsupported("Only PEM certs supported here") - import xmlsec + if cert_type != 'pem': + raise Unsupported('Only PEM certs supported here') + import xmlsec xml = xmlsec.parse_xml(signedtext) + try: return xmlsec.verify(xml, cert_file) except xmlsec.XMLSigException: return False -def security_context(conf, debug=None): +def security_context(conf): """ Creates a security context based on the configuration :param conf: The configuration, this is a Config instance @@ -1011,72 +1005,76 @@ def security_context(conf, debug=None): if not conf: return None - if debug is None: - try: - debug = conf.debug - except AttributeError: - pass - try: metadata = conf.metadata except AttributeError: metadata = None - _only_md = conf.only_use_keys_in_metadata - if _only_md is None: - _only_md = False + try: + id_attr = conf.id_attr_name + except AttributeError: + id_attr = None sec_backend = None if conf.crypto_backend == 'xmlsec1': xmlsec_binary = conf.xmlsec_binary + if not xmlsec_binary: try: _path = conf.xmlsec_path except AttributeError: _path = [] xmlsec_binary = get_xmlsec_binary(_path) - # verify that xmlsec is where it's supposed to be + + # verify that xmlsec is where it's supposed to be if not os.path.exists(xmlsec_binary): # if not os.access(, os.F_OK): - raise SigverError( - "xmlsec binary not in '%s' !" % xmlsec_binary) - crypto = _get_xmlsec_cryptobackend(xmlsec_binary, debug=debug) - _file_name = conf.getattr("key_file", "") + err_msg = 'xmlsec binary not found: {binary}' + err_msg = err_msg.format(binary=xmlsec_binary) + raise SigverError(err_msg) + + crypto = _get_xmlsec_cryptobackend(xmlsec_binary) + + _file_name = conf.getattr('key_file', '') if _file_name: try: rsa_key = import_rsa_key_from_file(_file_name) except Exception as err: - logger.error("Could not import key from {}: {}".format(_file_name, - err)) + logger.error('Cannot import key from {file}: {err_msg}'.format( + file=_file_name, err_msg=err)) raise else: sec_backend = RSACrypto(rsa_key) elif conf.crypto_backend == 'XMLSecurity': # new and somewhat untested pyXMLSecurity crypto backend. - crypto = CryptoBackendXMLSecurity(debug=debug) + crypto = CryptoBackendXMLSecurity() else: - raise SigverError('Unknown crypto_backend %s' % ( - repr(conf.crypto_backend))) - + err_msg = 'Unknown crypto_backend {backend}' + err_msg = err_msg.format(backend=conf.crypto_backend) + raise SigverError(err_msg) enc_key_files = [] if conf.encryption_keypairs is not None: for _encryption_keypair in conf.encryption_keypairs: - if "key_file" in _encryption_keypair: - enc_key_files.append(_encryption_keypair["key_file"]) + if 'key_file' in _encryption_keypair: + enc_key_files.append(_encryption_keypair['key_file']) return SecurityContext( - crypto, conf.key_file, cert_file=conf.cert_file, metadata=metadata, - debug=debug, only_use_keys_in_metadata=_only_md, - cert_handler_extra_class=conf.cert_handler_extra_class, - generate_cert_info=conf.generate_cert_info, - tmp_cert_file=conf.tmp_cert_file, - tmp_key_file=conf.tmp_key_file, - validate_certificate=conf.validate_certificate, - enc_key_files=enc_key_files, - encryption_keypairs=conf.encryption_keypairs, - sec_backend=sec_backend) + crypto, + conf.key_file, + cert_file=conf.cert_file, + metadata=metadata, + only_use_keys_in_metadata=conf.only_use_keys_in_metadata, + cert_handler_extra_class=conf.cert_handler_extra_class, + generate_cert_info=conf.generate_cert_info, + tmp_cert_file=conf.tmp_cert_file, + tmp_key_file=conf.tmp_key_file, + validate_certificate=conf.validate_certificate, + enc_key_files=enc_key_files, + encryption_keypairs=conf.encryption_keypairs, + sec_backend=sec_backend, + id_attr=id_attr) def encrypt_cert_from_item(item): @@ -1098,25 +1096,14 @@ def encrypt_cert_from_item(item): _encrypt_cert = _tmp_key_info.x509_data[ 0].x509_certificate.text break - # _encrypt_cert = _elem[0].x509_data[ - # 0].x509_certificate.text - # else: - # certs = cert_from_instance(item) - # if len(certs) > 0: - # _encrypt_cert = certs[0] except Exception as _exception: pass - # if _encrypt_cert is None: - # certs = cert_from_instance(item) - # if len(certs) > 0: - # _encrypt_cert = certs[0] - if _encrypt_cert is not None: - if _encrypt_cert.find("-----BEGIN CERTIFICATE-----\n") == -1: - _encrypt_cert = "-----BEGIN CERTIFICATE-----\n" + _encrypt_cert - if _encrypt_cert.find("\n-----END CERTIFICATE-----") == -1: - _encrypt_cert = _encrypt_cert + "\n-----END CERTIFICATE-----" + if _encrypt_cert.find('-----BEGIN CERTIFICATE-----\n') == -1: + _encrypt_cert = '-----BEGIN CERTIFICATE-----\n' + _encrypt_cert + if _encrypt_cert.find('\n-----END CERTIFICATE-----') == -1: + _encrypt_cert = _encrypt_cert + '\n-----END CERTIFICATE-----' return _encrypt_cert @@ -1125,26 +1112,32 @@ class CertHandlerExtra(object): pass def use_generate_cert_func(self): - raise Exception("use_generate_cert_func function must be implemented") + raise Exception('use_generate_cert_func function must be implemented') def generate_cert(self, generate_cert_info, root_cert_string, root_key_string): - raise Exception("generate_cert function must be implemented") + raise Exception('generate_cert function must be implemented') # Excepts to return (cert_string, key_string) def use_validate_cert_func(self): - raise Exception("use_validate_cert_func function must be implemented") + raise Exception('use_validate_cert_func function must be implemented') def validate_cert(self, cert_str, root_cert_string, root_key_string): - raise Exception("validate_cert function must be implemented") + raise Exception('validate_cert function must be implemented') # Excepts to return True/False class CertHandler(object): - def __init__(self, security_context, cert_file=None, cert_type="pem", - key_file=None, key_type="pem", generate_cert_info=None, - cert_handler_extra_class=None, tmp_cert_file=None, - tmp_key_file=None, verify_cert=False): + def __init__( + self, + security_context, + cert_file=None, cert_type='pem', + key_file=None, key_type='pem', + generate_cert_info=None, + cert_handler_extra_class=None, + tmp_cert_file=None, + tmp_key_file=None, + verify_cert=False): """ Initiates the class for handling certificates. Enables the certificates to either be a single certificate as base functionality or makes it @@ -1168,19 +1161,19 @@ class CertHandler(object): # validated. self._last_cert_verified = None self._last_validated_cert = None - if cert_type == "pem" and key_type == "pem": + if cert_type == 'pem' and key_type == 'pem': self._verify_cert = verify_cert is True self._security_context = security_context self._osw = OpenSSLWrapper() if key_file and os.path.isfile(key_file): self._key_str = self._osw.read_str_from_file(key_file, key_type) else: - self._key_str = "" + self._key_str = '' if cert_file and os.path.isfile(cert_file): self._cert_str = self._osw.read_str_from_file(cert_file, cert_type) else: - self._cert_str = "" + self._cert_str = '' self._tmp_cert_str = self._cert_str self._tmp_key_str = self._key_str @@ -1189,9 +1182,11 @@ class CertHandler(object): self._cert_info = None self._generate_cert_func_active = False - if generate_cert_info is not None and len(self._cert_str) > 0 and \ - len(self._key_str) > 0 and tmp_key_file is not \ - None and tmp_cert_file is not None: + if generate_cert_info is not None \ + and len(self._cert_str) > 0 \ + and len(self._key_str) > 0 \ + and tmp_key_file is not None \ + and tmp_cert_file is not None: self._generate_cert = True self._cert_info = generate_cert_info self._cert_handler_extra_class = cert_handler_extra_class @@ -1199,7 +1194,7 @@ class CertHandler(object): def verify_cert(self, cert_file): if self._verify_cert: if cert_file and os.path.isfile(cert_file): - cert_str = self._osw.read_str_from_file(cert_file, "pem") + cert_str = self._osw.read_str_from_file(cert_file, 'pem') else: return False self._last_validated_cert = cert_str @@ -1209,7 +1204,7 @@ class CertHandler(object): cert_str, self._cert_str, self._key_str) else: valid, mess = self._osw.verify(self._cert_str, cert_str) - logger.info("CertHandler.verify_cert: %s", mess) + logger.info('CertHandler.verify_cert: %s', mess) return valid return True @@ -1221,7 +1216,7 @@ class CertHandler(object): if client_crt is not None: self._tmp_cert_str = client_crt # No private key for signing - self._tmp_key_str = "" + self._tmp_key_str = '' elif self._cert_handler_extra_class is not None and \ self._cert_handler_extra_class.use_generate_cert_func(): (self._tmp_cert_str, self._tmp_key_str) = \ @@ -1229,8 +1224,7 @@ class CertHandler(object): self._cert_info, self._cert_str, self._key_str) else: self._tmp_cert_str, self._tmp_key_str = self._osw \ - .create_certificate( - self._cert_info, request=True) + .create_certificate(self._cert_info, request=True) self._tmp_cert_str = self._osw.create_cert_signed_certificate( self._cert_str, self._key_str, self._tmp_cert_str) valid, mess = self._osw.verify(self._cert_str, @@ -1239,8 +1233,8 @@ class CertHandler(object): self._osw.write_str_to_file(self._tmp_key_file, self._tmp_key_str) self._security_context.key_file = self._tmp_key_file self._security_context.cert_file = self._tmp_cert_file - self._security_context.key_type = "pem" - self._security_context.cert_type = "pem" + self._security_context.key_type = 'pem' + self._security_context.cert_type = 'pem' self._security_context.my_cert = read_cert_from_file( self._security_context.cert_file, self._security_context.cert_type) @@ -1250,17 +1244,29 @@ class CertHandler(object): # openssl x509 -inform pem -noout -in server.crt -pubkey > publickey.pem # openssl rsa -inform pem -noout -in publickey.pem -pubin -modulus class SecurityContext(object): + DEFAULT_ID_ATTR_NAME = 'ID' my_cert = None - def __init__(self, crypto, key_file="", key_type="pem", - cert_file="", cert_type="pem", metadata=None, - debug=False, template="", encrypt_key_type="des-192", - only_use_keys_in_metadata=False, cert_handler_extra_class=None, - generate_cert_info=None, tmp_cert_file=None, - tmp_key_file=None, validate_certificate=None, - enc_key_files=None, enc_key_type="pem", - encryption_keypairs=None, enc_cert_type="pem", - sec_backend=None): + def __init__( + self, + crypto, + key_file='', key_type='pem', + cert_file='', cert_type='pem', + metadata=None, + template='', + encrypt_key_type='des-192', + only_use_keys_in_metadata=False, + cert_handler_extra_class=None, + generate_cert_info=None, + tmp_cert_file=None, tmp_key_file=None, + validate_certificate=None, + enc_key_files=None, enc_key_type='pem', + encryption_keypairs=None, + enc_cert_type='pem', + sec_backend=None, + id_attr=''): + + self.id_attr = id_attr or SecurityContext.DEFAULT_ID_ATTR_NAME self.crypto = crypto assert (isinstance(self.crypto, CryptoBackend)) @@ -1287,20 +1293,24 @@ class SecurityContext(object): self.my_cert = read_cert_from_file(cert_file, cert_type) - self.cert_handler = CertHandler(self, cert_file, cert_type, key_file, - key_type, generate_cert_info, - cert_handler_extra_class, tmp_cert_file, - tmp_key_file, validate_certificate) + self.cert_handler = CertHandler( + self, + cert_file, cert_type, + key_file, key_type, + generate_cert_info, + cert_handler_extra_class, + tmp_cert_file, + tmp_key_file, + validate_certificate) self.cert_handler.update_cert(True) self.metadata = metadata self.only_use_keys_in_metadata = only_use_keys_in_metadata - self.debug = debug if not template: this_dir, this_filename = os.path.split(__file__) - self.template = os.path.join(this_dir, "xml_template", "template.xml") + self.template = os.path.join(this_dir, 'xml_template', 'template.xml') else: self.template = template @@ -1312,10 +1322,10 @@ class SecurityContext(object): self._xmlsec_delete_tmpfiles = True def correctly_signed(self, xml, must=False): - logger.debug("verify correct signature") + logger.debug('verify correct signature') return self.correctly_signed_response(xml, must) - def encrypt(self, text, recv_key="", template="", key_type=""): + def encrypt(self, text, recv_key='', template='', key_type=''): """ xmlsec encrypt --pubkey-pem pub-userkey.pem --session-key aes128-cbc --xml-data doc-plain.xml @@ -1334,8 +1344,7 @@ class SecurityContext(object): return self.crypto.encrypt(text, recv_key, template, key_type) - def encrypt_assertion(self, statement, enc_key, template, - key_type="des-192", node_xpath=None): + def encrypt_assertion(self, statement, enc_key, template, key_type='des-192', node_xpath=None): """ Will encrypt an assertion @@ -1345,53 +1354,65 @@ class SecurityContext(object): :param key_type: The type of session key to use. :return: The encrypted text """ - return self.crypto.encrypt_assertion(statement, enc_key, template, - key_type, node_xpath) + return self.crypto.encrypt_assertion( + statement, enc_key, template, key_type, node_xpath) - def decrypt_keys(self, enctext, keys=None): + def decrypt_keys(self, enctext, keys=None, id_attr=''): """ Decrypting an encrypted text by the use of a private key. :param enctext: The encrypted text as a string :return: The decrypted text """ _enctext = None + + if not id_attr: + id_attr = self.id_attr + if not isinstance(keys, list): keys = [keys] + if self.enc_key_files is not None: for _enc_key_file in self.enc_key_files: - _enctext = self.crypto.decrypt(enctext, _enc_key_file) + _enctext = self.crypto.decrypt(enctext, _enc_key_file, id_attr) if _enctext is not None and len(_enctext) > 0: return _enctext + for _key in keys: if _key is not None and len(_key.strip()) > 0: if not isinstance(_key, six.binary_type): _key = str(_key).encode('ascii') _, key_file = make_temp(_key, decode=False) - _enctext = self.crypto.decrypt(enctext, key_file) + _enctext = self.crypto.decrypt(enctext, key_file, id_attr) if _enctext is not None and len(_enctext) > 0: return _enctext + return enctext - def decrypt(self, enctext, key_file=None): + def decrypt(self, enctext, key_file=None, id_attr=''): """ Decrypting an encrypted text by the use of a private key. :param enctext: The encrypted text as a string :return: The decrypted text """ _enctext = None + + if not id_attr: + id_attr = self.id_attr + if self.enc_key_files is not None: for _enc_key_file in self.enc_key_files: - _enctext = self.crypto.decrypt(enctext, _enc_key_file) + _enctext = self.crypto.decrypt(enctext, _enc_key_file, id_attr) if _enctext is not None and len(_enctext) > 0: return _enctext + if key_file is not None and len(key_file.strip()) > 0: - _enctext = self.crypto.decrypt(enctext, key_file) + _enctext = self.crypto.decrypt(enctext, key_file, id_attr) if _enctext is not None and len(_enctext) > 0: return _enctext + return enctext - def verify_signature(self, signedtext, cert_file=None, cert_type="pem", - node_name=NODE_NAME, node_id=None, id_attr=""): + def verify_signature(self, signedtext, cert_file=None, cert_type='pem', node_name=NODE_NAME, node_id=None, id_attr=''): """ Verifies the signature of a XML document. :param signedtext: The XML document as a string @@ -1399,7 +1420,7 @@ class SecurityContext(object): :param cert_type: The file type of the certificate :param node_name: The name of the class that is signed :param node_id: The identifier of the node - :param id_attr: Should normally be one of "id", "Id" or "ID" + :param id_attr: Should normally be one of 'id', 'Id' or 'ID' :return: Boolean True if the signature was correct otherwise False. """ # This is only for testing purposes, otherwise when would you receive @@ -1409,17 +1430,17 @@ class SecurityContext(object): cert_type = self.cert_type if not id_attr: - id_attr = ID_ATTR + id_attr = self.id_attr - return self.crypto.validate_signature(signedtext, cert_file=cert_file, - cert_type=cert_type, - node_name=node_name, - node_id=node_id, id_attr=id_attr) + return self.crypto.validate_signature( + signedtext, + cert_file=cert_file, + cert_type=cert_type, + node_name=node_name, + node_id=node_id, + id_attr=id_attr) - def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, - origdoc=None, id_attr="", must=False, - only_valid_cert=False, issuer=None): - # print(item) + def _check_signature(self, decoded_xml, item, node_name=NODE_NAME, origdoc=None, id_attr='', must=False, only_valid_cert=False, issuer=None): try: _issuer = item.issuer.text.strip() except AttributeError: @@ -1430,68 +1451,76 @@ class SecurityContext(object): _issuer = issuer.text.strip() except AttributeError: _issuer = None + # More trust in certs from metadata then certs in the XML document if self.metadata: try: - _certs = self.metadata.certs(_issuer, "any", "signing") + _certs = self.metadata.certs(_issuer, 'any', 'signing') except KeyError: _certs = [] certs = [] + for cert in _certs: if isinstance(cert, six.string_types): - certs.append(make_temp(pem_format(cert), suffix=".pem", - decode=False, - delete=self._xmlsec_delete_tmpfiles)) + certs.append(make_temp( + pem_format(cert), + suffix='.pem', + decode=False, + delete=self._xmlsec_delete_tmpfiles)) else: certs.append(cert) else: certs = [] if not certs and not self.only_use_keys_in_metadata: - logger.debug("==== Certs from instance ====") - certs = [make_temp(pem_format(cert), suffix=".pem", - decode=False, - delete=self._xmlsec_delete_tmpfiles) - for cert in cert_from_instance(item)] + logger.debug('==== Certs from instance ====') + certs = [ + make_temp( + pem_format(cert), + suffix='.pem', + decode=False, + delete=self._xmlsec_delete_tmpfiles) + for cert in cert_from_instance(item) + ] else: - logger.debug("==== Certs from metadata ==== %s: %s ====", issuer, - certs) + logger.debug('==== Certs from metadata ==== %s: %s ====', issuer, certs) if not certs: - raise MissingKey("%s" % issuer) - - # print(certs) + raise MissingKey(issuer) verified = False last_pem_file = None + for _, pem_file in certs: try: last_pem_file = pem_file - if self.verify_signature(decoded_xml, pem_file, - node_name=node_name, - node_id=item.id, id_attr=id_attr): + if self.verify_signature( + decoded_xml, + pem_file, + node_name=node_name, + node_id=item.id, + id_attr=id_attr): verified = True break except XmlsecError as exc: - logger.error("check_sig: %s", exc) + logger.error('check_sig: %s', exc) pass except SignatureError as exc: - logger.error("check_sig: %s", exc) + logger.error('check_sig: %s', exc) pass except Exception as exc: - logger.error("check_sig: %s", exc) + logger.error('check_sig: %s', exc) raise - if (not verified) and (not only_valid_cert): - raise SignatureError("Failed to verify signature") - else: + if verified or only_valid_cert: if not self.cert_handler.verify_cert(last_pem_file): - raise CertificateError("Invalid certificate!") + raise CertificateError('Invalid certificate!') + else: + raise SignatureError('Failed to verify signature') return item - def check_signature(self, item, node_name=NODE_NAME, origdoc=None, - id_attr="", must=False, issuer=None): + def check_signature(self, item, node_name=NODE_NAME, origdoc=None, id_attr='', must=False, issuer=None): """ :param item: Parsed entity @@ -1501,11 +1530,16 @@ class SecurityContext(object): :param must: :return: """ - return self._check_signature(origdoc, item, node_name, origdoc, - id_attr=id_attr, must=must, issuer=issuer) - - def correctly_signed_message(self, decoded_xml, msgtype, must=False, - origdoc=None, only_valid_cert=False): + return self._check_signature( + origdoc, + item, + node_name, + origdoc, + id_attr=id_attr, + must=must, + issuer=issuer) + + def correctly_signed_message(self, decoded_xml, msgtype, must=False, origdoc=None, only_valid_cert=False): """Check if a request is correctly signed, if we have metadata for the entity that sent the info use that, if not use the key that are in the message if any. @@ -1517,136 +1551,76 @@ class SecurityContext(object): :return: """ - try: - _func = getattr(samlp, "%s_from_string" % msgtype) - except AttributeError: - _func = getattr(saml, "%s_from_string" % msgtype) + attr = '{type}_from_string'.format(type=msgtype) + _func = getattr(saml, attr, None) + _func = getattr(samlp, attr, _func) msg = _func(decoded_xml) if not msg: - raise TypeError("Not a %s" % msgtype) + raise TypeError('Not a {type}'.format(type=msgtype)) if not msg.signature: if must: - raise SignatureError( - "Required signature missing on %s" % msgtype) + err_msg = 'Required signature missing on {type}' + err_msg = err_msg.format(type=msgtype) + raise SignatureError(err_msg) else: return msg - return self._check_signature(decoded_xml, msg, class_name(msg), - origdoc, must=must, - only_valid_cert=only_valid_cert) - - def correctly_signed_authn_request(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, "authn_request", - must, origdoc, - only_valid_cert=only_valid_cert) - - def correctly_signed_authn_query(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, "authn_query", - must, origdoc, only_valid_cert) - - def correctly_signed_logout_request(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, "logout_request", - must, origdoc, only_valid_cert) - - def correctly_signed_logout_response(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, "logout_response", - must, origdoc, only_valid_cert) - - def correctly_signed_attribute_query(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, "attribute_query", - must, origdoc, only_valid_cert) - - def correctly_signed_authz_decision_query(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "authz_decision_query", must, - origdoc, only_valid_cert) - - def correctly_signed_authz_decision_response(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "authz_decision_response", must, - origdoc, only_valid_cert) - - def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "name_id_mapping_request", - must, origdoc, only_valid_cert) - - def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "name_id_mapping_response", - must, origdoc, only_valid_cert) - - def correctly_signed_artifact_request(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "artifact_request", - must, origdoc, only_valid_cert) - - def correctly_signed_artifact_response(self, decoded_xml, must=False, - origdoc=None, only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "artifact_response", - must, origdoc, only_valid_cert) - - def correctly_signed_manage_name_id_request(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "manage_name_id_request", - must, origdoc, only_valid_cert) - - def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "manage_name_id_response", must, - origdoc, only_valid_cert) - - def correctly_signed_assertion_id_request(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, - **kwargs): - return self.correctly_signed_message(decoded_xml, - "assertion_id_request", must, - origdoc, only_valid_cert) - - def correctly_signed_assertion_id_response(self, decoded_xml, must=False, - origdoc=None, - only_valid_cert=False, **kwargs): - return self.correctly_signed_message(decoded_xml, "assertion", must, - origdoc, only_valid_cert) - - def correctly_signed_response(self, decoded_xml, must=False, origdoc=None, - only_valid_cert=False, - require_response_signature=False, **kwargs): + return self._check_signature( + decoded_xml, + msg, + class_name(msg), + origdoc, + must=must, + only_valid_cert=only_valid_cert) + + def correctly_signed_authn_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'authn_request', must, origdoc, only_valid_cert=only_valid_cert) + + def correctly_signed_authn_query(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'authn_query', must, origdoc, only_valid_cert) + + def correctly_signed_logout_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'logout_request', must, origdoc, only_valid_cert) + + def correctly_signed_logout_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'logout_response', must, origdoc, only_valid_cert) + + def correctly_signed_attribute_query(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'attribute_query', must, origdoc, only_valid_cert) + + def correctly_signed_authz_decision_query(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'authz_decision_query', must, origdoc, only_valid_cert) + + def correctly_signed_authz_decision_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'authz_decision_response', must, origdoc, only_valid_cert) + + def correctly_signed_name_id_mapping_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'name_id_mapping_request', must, origdoc, only_valid_cert) + + def correctly_signed_name_id_mapping_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'name_id_mapping_response', must, origdoc, only_valid_cert) + + def correctly_signed_artifact_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'artifact_request', must, origdoc, only_valid_cert) + + def correctly_signed_artifact_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'artifact_response', must, origdoc, only_valid_cert) + + def correctly_signed_manage_name_id_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'manage_name_id_request', must, origdoc, only_valid_cert) + + def correctly_signed_manage_name_id_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'manage_name_id_response', must, origdoc, only_valid_cert) + + def correctly_signed_assertion_id_request(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'assertion_id_request', must, origdoc, only_valid_cert) + + def correctly_signed_assertion_id_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, **kwargs): + return self.correctly_signed_message(decoded_xml, 'assertion', must, origdoc, only_valid_cert) + + def correctly_signed_response(self, decoded_xml, must=False, origdoc=None, only_valid_cert=False, require_response_signature=False, **kwargs): """ Check if a instance is correctly signed, if we have metadata for the IdP that sent the info use that, if not use the key that are in the message if any. @@ -1661,28 +1635,24 @@ class SecurityContext(object): response = samlp.any_response_from_string(decoded_xml) if not response: - raise TypeError("Not a Response") + raise TypeError('Not a Response') if response.signature: - if "do_not_verify" in kwargs: + if 'do_not_verify' in kwargs: pass else: self._check_signature(decoded_xml, response, class_name(response), origdoc) elif require_response_signature: - raise SignatureError("Signature missing for response") + raise SignatureError('Signature missing for response') return response - # -------------------------------------------------------------------------- - # SIGNATURE PART - # -------------------------------------------------------------------------- def sign_statement_using_xmlsec(self, statement, **kwargs): """ Deprecated function. See sign_statement(). """ return self.sign_statement(statement, **kwargs) - def sign_statement(self, statement, node_name, key=None, - key_file=None, node_id=None, id_attr=""): + def sign_statement(self, statement, node_name, key=None, key_file=None, node_id=None, id_attr=''): """Sign a SAML statement. :param statement: The statement to be signed @@ -1695,21 +1665,25 @@ class SecurityContext(object): :return: The signed statement """ if not id_attr: - id_attr = ID_ATTR + id_attr = self.id_attr if not key_file and key: - _, key_file = make_temp(str(key).encode('utf-8'), ".pem") + _, key_file = make_temp(str(key).encode(), '.pem') if not key and not key_file: key_file = self.key_file - return self.crypto.sign_statement(statement, node_name, key_file, - node_id, id_attr) + return self.crypto.sign_statement( + statement, + node_name, + key_file, + node_id, + id_attr) def sign_assertion_using_xmlsec(self, statement, **kwargs): """ Deprecated function. See sign_assertion(). """ - return self.sign_statement(statement, class_name(saml.Assertion()), - **kwargs) + return self.sign_statement( + statement, class_name(saml.Assertion()), **kwargs) def sign_assertion(self, statement, **kwargs): """Sign a SAML assertion. @@ -1719,8 +1693,8 @@ class SecurityContext(object): :param statement: The statement to be signed :return: The signed statement """ - return self.sign_statement(statement, class_name(saml.Assertion()), - **kwargs) + return self.sign_statement( + statement, class_name(saml.Assertion()), **kwargs) def sign_attribute_query_using_xmlsec(self, statement, **kwargs): """ Deprecated function. See sign_attribute_query(). """ @@ -1734,11 +1708,10 @@ class SecurityContext(object): :param statement: The statement to be signed :return: The signed statement """ - return self.sign_statement(statement, class_name( - samlp.AttributeQuery()), **kwargs) + return self.sign_statement( + statement, class_name(samlp.AttributeQuery()), **kwargs) - def multiple_signatures(self, statement, to_sign, key=None, key_file=None, - sign_alg=None, digest_alg=None): + def multiple_signatures(self, statement, to_sign, key=None, key_file=None, sign_alg=None, digest_alg=None): """ Sign multiple parts of a statement @@ -1757,21 +1730,24 @@ class SecurityContext(object): sid = item.id if not item.signature: - item.signature = pre_signature_part(sid, self.cert_file, - sign_alg=sign_alg, - digest_alg=digest_alg) + item.signature = pre_signature_part( + sid, + self.cert_file, + sign_alg=sign_alg, + digest_alg=digest_alg) + + statement = self.sign_statement( + statement, + class_name(item), + key=key, + key_file=key_file, + node_id=sid, + id_attr=id_attr) - statement = self.sign_statement(statement, class_name(item), - key=key, key_file=key_file, - node_id=sid, id_attr=id_attr) return statement -# =========================================================================== - - -def pre_signature_part(ident, public_key=None, identifier=None, - digest_alg=None, sign_alg=None): +def pre_signature_part(ident, public_key=None, identifier=None, digest_alg=None, sign_alg=None): """ If an assertion is to be signed the signature part has to be preset with which algorithms to be used, this function returns such a @@ -1796,18 +1772,23 @@ def pre_signature_part(ident, public_key=None, identifier=None, transforms = ds.Transforms(transform=[trans0, trans1]) digest_method = ds.DigestMethod(algorithm=digest_alg) - reference = ds.Reference(uri="#%s" % ident, digest_value=ds.DigestValue(), - transforms=transforms, digest_method=digest_method) + reference = ds.Reference( + uri='#{id}'.format(id=ident), + digest_value=ds.DigestValue(), + transforms=transforms, + digest_method=digest_method) - signed_info = ds.SignedInfo(signature_method=signature_method, - canonicalization_method=canonicalization_method, - reference=reference) + signed_info = ds.SignedInfo( + signature_method=signature_method, + canonicalization_method=canonicalization_method, + reference=reference) - signature = ds.Signature(signed_info=signed_info, - signature_value=ds.SignatureValue()) + signature = ds.Signature( + signed_info=signed_info, + signature_value=ds.SignatureValue()) if identifier: - signature.id = "Signature%d" % identifier + signature.id = 'Signature{n}'.format(n=identifier) if public_key: x509_data = ds.X509Data( @@ -1845,8 +1826,8 @@ def pre_signature_part(ident, public_key=None, identifier=None, # </CipherData> # </EncryptedData> -def pre_encryption_part(msg_enc=TRIPLE_DES_CBC, key_enc=RSA_1_5, - key_name="my-rsa-key"): + +def pre_encryption_part(msg_enc=TRIPLE_DES_CBC, key_enc=RSA_1_5, key_name='my-rsa-key'): """ :param msg_enc: @@ -1856,19 +1837,20 @@ def pre_encryption_part(msg_enc=TRIPLE_DES_CBC, key_enc=RSA_1_5, """ msg_encryption_method = EncryptionMethod(algorithm=msg_enc) key_encryption_method = EncryptionMethod(algorithm=key_enc) - encrypted_key = EncryptedKey(id="EK", - encryption_method=key_encryption_method, - key_info=ds.KeyInfo( - key_name=ds.KeyName(text=key_name)), - cipher_data=CipherData( - cipher_value=CipherValue(text=""))) + encrypted_key = EncryptedKey( + id='EK', + encryption_method=key_encryption_method, + key_info=ds.KeyInfo( + key_name=ds.KeyName(text=key_name)), + cipher_data=CipherData( + cipher_value=CipherValue(text=''))) key_info = ds.KeyInfo(encrypted_key=encrypted_key) encrypted_data = EncryptedData( - id="ED", - type="http://www.w3.org/2001/04/xmlenc#Element", + id='ED', + type='http://www.w3.org/2001/04/xmlenc#Element', encryption_method=msg_encryption_method, key_info=key_info, - cipher_data=CipherData(cipher_value=CipherValue(text=""))) + cipher_data=CipherData(cipher_value=CipherValue(text=''))) return encrypted_data @@ -1887,13 +1869,6 @@ def pre_encrypt_assertion(response): response.encrypted_assertion.add_extension_elements(assertion) else: response.encrypted_assertion.add_extension_element(assertion) - # txt = "%s" % response - # _ass = "%s" % assertion - # _ass = rm_xmltag(_ass) - # txt.replace( - # "<ns1:EncryptedAssertion/>", - # "<ns1:EncryptedAssertion>%s</ns1:EncryptedAssertion>" % _ass) - return response @@ -1903,8 +1878,8 @@ def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, issue_instant=instant()) if sign: - response.signature = pre_signature_part(kwargs["id"], sign_alg=sign_alg, - digest_alg=digest_alg) + response.signature = pre_signature_part( + kwargs['id'], sign_alg=sign_alg, digest_alg=digest_alg) if encrypt: pass @@ -1914,14 +1889,13 @@ def response_factory(sign=False, encrypt=False, sign_alg=None, digest_alg=None, return response -# ---------------------------------------------------------------------------- if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument('-s', '--list-sigalgs', dest='listsigalgs', - action='store_true', - help='List implemented signature algorithms') + action='store_true', + help='List implemented signature algorithms') args = parser.parse_args() if args.listsigalgs: diff --git a/tests/okta_assertion b/tests/okta_assertion new file mode 100644 index 00000000..35da4890 --- /dev/null +++ b/tests/okta_assertion @@ -0,0 +1,2 @@ +<?xml version='1.0' encoding='UTF-8'?> +<ns0:Assertion xmlns:ns0="urn:oasis:names:tc:SAML:2.0:assertion" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" ID="id615445004975895274776851" IssueInstant="2018-03-07T01:08:39.444Z" Version="2.0"><ns0:Issuer Format="urn:oasis:names:tc:SAML:2.0:nameid-format:entity">http://www.okta.com/Issuer</ns0:Issuer><ns0:Subject><ns0:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified">userName</ns0:NameID><ns0:SubjectConfirmation Method="urn:oasis:names:tc:SAML:2.0:cm:bearer"><ns0:SubjectConfirmationData NotOnOrAfter="2018-03-07T01:13:39.448Z" Recipient="https://example.com" /></ns0:SubjectConfirmation></ns0:Subject><ns0:Conditions NotBefore="2018-03-07T01:03:39.448Z" NotOnOrAfter="2018-03-07T01:13:39.448Z"><ns0:AudienceRestriction><ns0:Audience>audience</ns0:Audience></ns0:AudienceRestriction></ns0:Conditions><ns0:AuthnStatement AuthnInstant="2018-03-07T01:08:39.444Z"><ns0:AuthnContext><ns0:AuthnContextClassRef>urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport</ns0:AuthnContextClassRef></ns0:AuthnContext></ns0:AuthnStatement><ns0:AttributeStatement><ns0:Attribute Name="name" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:unspecified"><ns0:AttributeValue xsi:type="xs:string">John Doe</ns0:AttributeValue></ns0:Attribute></ns0:AttributeStatement></ns0:Assertion>
\ No newline at end of file diff --git a/tests/okta_response.xml b/tests/okta_response.xml new file mode 100644 index 00000000..3c0d52ba --- /dev/null +++ b/tests/okta_response.xml @@ -0,0 +1,47 @@ +<?xml version="1.0" encoding="UTF-8"?> +<saml2:EncryptedAssertion + xmlns:saml2="urn:oasis:names:tc:SAML:2.0:assertion"> + <xenc:EncryptedData + xmlns:xenc="http://www.w3.org/2001/04/xmlenc#" Id="_09c17cde95c4d7b85e79b1e4fb70907c" Type="http://www.w3.org/2001/04/xmlenc#Element"> + <xenc:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#aes256-cbc" + xmlns:xenc="http://www.w3.org/2001/04/xmlenc#"/> + <ds:KeyInfo + xmlns:ds="http://www.w3.org/2000/09/xmldsig#"> + <ds:RetrievalMethod Type="http://www.w3.org/2001/04/xmlenc#EncryptedKey" URI="#_d3dd57a740decc239fa7818eabe6aaf7"/> + </ds:KeyInfo> + <xenc:CipherData + xmlns:xenc="http://www.w3.org/2001/04/xmlenc#"> + <xenc:CipherValue>kqkk8wwlKllIjzPxj8YTuEIRiH7sqrZF99ONbZ1GrSnZljUXsGl/XttJbFDsJ3ty2JdCD68N3HeinNp+QtMwuHUE10DQhfRCb5cvgKRhCPXbxnZk3p70OVQxXK1Gh1d8jkEz7/oHNT0EnnHu8ysZEFy82LrkMhttspoiIwJyXEvbSVQTzh2hPsJTx3z4SvewiQQre4ST6zY7f7QzAK0Y7wCj8JV3qtTJPeiqwMFbDhQM/WYYYzbt7EqfWYYKWmkTaznnxUG73N+iwSVU6DkRWgk0o4RAUzAmEypsSirD6Vdzbgx468bmPnfhGa3IjG4URMXbHEgjTJDP+StPGLma0oa5GSSB27So9oR5LGB9UJkK2y1gF5xFce8i6BrxPtnEEtGoctiiB2PQ1gVP2L1opn32dajWBXTFnRWmXJQ+ekZbeWejX6LadcQY1/X56xIUwbRobkcowLGtRtd1SpxFA4RJkQYzJF7zYLHRLtbJzAFKJjlaOWTS3aPj1LOPRfPWjObrKpdd0XfN2NkveA/wgwbnAxhBDi1J6apJkjX18494GLVtRH15MjDWWSMM1Kui7aZR5CHVxUIQwuE+0+zwaO89BOh6pGO/Qrev7YzYVCkQlJf7V2OsKLKTUobqFN6RLIBI8036dhkcEaCrqs3uI5xAr3C5IwLZ6O2MUGzSEnhYToTnlO4I7lotAev3BArPQeXeo1Tcw8tedFSIbLN0GIyzxfnMDegjdw6vIYZRbolJ5QDurWlhyAQDEe10OUWtvq43GZD0scIsWtnClLMusZoQv/y2+9pjtsKl+iXiUxlV6NmDJdMf7R2KUC5wkwoRpJ2x+wGH7xs8Xf7suSE/9Y3Kn7aLONnCb+SpAG7ItBISHj6lTnl87r4Cieylc3/85pBKX5YbUTrEtrZdZvnGrxn2TGeY0QN4PNYtS+UoL/KQMmnrBblddSyFTEOFrT0CLw77AHVzjAnzXxT3fyMvomubLW7UefG/d2kcX0b/A9rBPIHEl0IJQPxuEZRAsp52ilcI7crXX2HLDMzN8KOfxKlT62gnAh/dXVTdU1PHRa6f/dUEEI2B/1Dxt/aZ6uDC4dPDseKaJ9r9h37dZHMzJudZ18Jt0Smj2KGySGY+cAMy9/8sPXEJzp1vRWIXLccfnpWWu7PlENfIeIncyyDl1cBhYWYKFmQjoEkRL3vQsXiMh4n4UzGUQ8Xm9dTk+NpQKvwRtXBmQvx8giST5jzQH4GGl2Pq3aWLFLGAsJ25ndnamK50eUJGDZKlEPRjrEFv6HXnRQDH6j0Kyj1sTmitUWd2TGn7TcEqEkrvhON0xVeUxFvUYRH/+RLjzyxps8gXtY/CQkE5P9kjCCDIbMabKihg1TTSO8AmEIJzv4Dkrq+LJcR9ovYChvLFflK38eM4swr8ofYvP4nighNgKruZpgDMrW0bKnsSlh997feEyhuaiUFsqHfONvB+ACn6Sw9w105eHLsvFrDf733GTg+iDbxkzofJrXIibBLv15i91pi5A9VfduRVFFXQ1Iu2Fvmrgz/x27w47Lhj0OIhLO3mJui4dbKRe/yZnoasJVspNpXJzAt3X7mfUFOEB9GoGQgpdP71LNmgJ2FFukOl10wKBoDh4WVKqSarQFWn+EoPp/HgrEDgoJUt+ewzfBviQXTgya2v+f7B7yNrYyP6+ZAv0HSCo+1iDCZaKmyNpW646pN3wOMqYPnL7pEY3bnVJ6mQfQX54wEgKbCVxHjKfH4F1aWfk8fvMK0J8X/KYpmTxyw3KghrnEL5BHyYob5a496BhqJW164oGJbYi0/SnTQ/sPWIZ8Qa2vMCUJV6okyDUerpsG4Z320h48156TM0gSt+VMbjyp5Py8JyCu67jLYb9DRRgl14EwIvwhq7ARNH32W3LYeefWZ1cFVQJ52zkVYK6JmLuNZI7Tp1K6QeircZeIpoaXqQ4mzAOFMxqPIe4qE=</xenc:CipherValue> + </xenc:CipherData> + </xenc:EncryptedData> + <xenc:EncryptedKey + xmlns:xenc="http://www.w3.org/2001/04/xmlenc#" Id="_d3dd57a740decc239fa7818eabe6aaf7"> + <xenc:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p" + xmlns:xenc="http://www.w3.org/2001/04/xmlenc#"> + <ds:DigestMethod + xmlns:ds="http://www.w3.org/2000/09/xmldsig#" Algorithm="http://www.w3.org/2000/09/xmldsig#sha1"/> + </xenc:EncryptionMethod> + <ds:KeyInfo + xmlns:ds="http://www.w3.org/2000/09/xmldsig#"> + <ds:X509Data> + <ds:X509Certificate>MIICHzCCAYgCAQEwDQYJKoZIhvcNAQELBQAwWDELMAkGA1UEBhMCenoxCzAJBgNVBAgMAnp6MQ0w +CwYDVQQHDAR6enp6MQ4wDAYDVQQKDAVaenp6ejEOMAwGA1UECwwFWnp6enoxDTALBgNVBAMMBHRl +c3QwHhcNMTUwNjAyMDc0MjI2WhcNMjUwNTMwMDc0MjI2WjBYMQswCQYDVQQGEwJ6ejELMAkGA1UE +CAwCenoxDTALBgNVBAcMBHp6enoxDjAMBgNVBAoMBVp6enp6MQ4wDAYDVQQLDAVaenp6ejENMAsG +A1UEAwwEdGVzdDCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAx3I/NFlP1wbHfRZckJn4z1HX +5nnYQhQ3ekxEJmTTaj/1BvlZBmvgV40SBzH4nP1sT02xoQo7+vHItFAzaJlF2oBXsSxjaZMGu/gk +VbaHP9cYKvskhOjOJ4XArrUnKMTb1jZ+XkkOuot1NLE7/dTILF8ahHU2omYNASLnxHN3bnkCAwEA +ATANBgkqhkiG9w0BAQsFAAOBgQCQam1Oz7iQcD9+OurBM5a+Hth53m5hbAFuguSvERPCuJ/CfP1+ +g7CIZN/GnsIsg9QW77NvdOyxjXxzoJJmokl1qz/qy3FY3mJ0gIUxDyPD9DL3c9/03MDv5YmWsoP+ +HNqK8QtNJ/JDEOhBr/Eo/MokRo4gtMNeLF/soveWNoNiUg==</ds:X509Certificate> + </ds:X509Data> + </ds:KeyInfo> + <xenc:CipherData + xmlns:xenc="http://www.w3.org/2001/04/xmlenc#"> + <xenc:CipherValue>L6uCTESvYUr0/noTZHibyglkfaV8zMpXBk36bm8sofAv9JoYSO3HeWXkSKN7QT8vjOX9JZ32sg4fgJyE0uoSph9zx3YyRu2MMstY+zaD3mM3FdXgSBkmMwLcQ1ESBNXlp/8bLyTkQlE4cBhLnsJbgK/nR1Dss0DR1vZXRg3yg+g=</xenc:CipherValue> + </xenc:CipherData> + <xenc:ReferenceList> + <xenc:DataReference URI="#_09c17cde95c4d7b85e79b1e4fb70907c"/> + </xenc:ReferenceList> + </xenc:EncryptedKey> + </saml2:EncryptedAssertion>
\ No newline at end of file diff --git a/tests/test_40_sigver.py b/tests/test_40_sigver.py index c963533e..f975b5ea 100644 --- a/tests/test_40_sigver.py +++ b/tests/test_40_sigver.py @@ -26,6 +26,8 @@ from pathutils import full_path SIGNED = full_path("saml_signed.xml") UNSIGNED = full_path("saml_unsigned.xml") SIMPLE_SAML_PHP_RESPONSE = full_path("simplesamlphp_authnresponse.xml") +OKTA_RESPONSE = full_path("okta_response.xml") +OKTA_ASSERTION = full_path("okta_assertion") PUB_KEY = full_path("test.pem") PRIV_KEY = full_path("test.key") @@ -112,7 +114,6 @@ class FakeConfig(): key_file = PRIV_KEY encryption_keypairs = [{"key_file": ENC_PRIV_KEY, "cert_file": ENC_PUB_KEY}] enc_key_files = [ENC_PRIV_KEY] - debug = False cert_handler_extra_class = None generate_cert_func = None generate_cert_info = False @@ -137,7 +138,7 @@ class TestSecurity(): # (TestSecurityMetadata below) excersise the SPConfig() mechanism. # conf = FakeConfig() - self.sec = sigver.security_context(FakeConfig()) + self.sec = sigver.security_context(conf) self._assertion = factory( saml.Assertion, @@ -474,7 +475,6 @@ def test_xbox(): str(encrypted_assertion), conf.cert_file, pre, "des-192", '/*[local-name()="EncryptedAssertion"]/*[local-name()="Assertion"]') - decr_text = sec.decrypt(enctext) _seass = saml.encrypted_assertion_from_string(decr_text) assertions = [] @@ -495,6 +495,30 @@ def test_xbox(): print(assertions) +def test_okta(): + conf = config.Config() + conf.load_file("server_conf") + conf.id_attr_name = 'Id' + md = MetadataStore([saml, samlp], None, conf) + md.load("local", full_path("idp_example.xml")) + + conf.metadata = md + conf.only_use_keys_in_metadata = False + sec = sigver.security_context(conf) + with open(OKTA_RESPONSE) as f: + enctext = f.read() + decr_text = sec.decrypt(enctext) + _seass = saml.encrypted_assertion_from_string(decr_text) + assers = extension_elements_to_elements(_seass.extension_elements, + [saml, samlp]) + + with open(OKTA_ASSERTION) as f: + okta_assertion = f.read() + expected_assert = assertion_from_string(okta_assertion) + assert len(assers) == 1 + assert assers[0] == expected_assert + + def test_xmlsec_err(): conf = config.SPConfig() conf.load_file("server_conf") |