diff options
author | Ivan Kanakarakis <ivan.kanak@gmail.com> | 2019-09-01 21:17:01 +0200 |
---|---|---|
committer | Ivan Kanakarakis <ivan.kanak@gmail.com> | 2019-11-26 14:02:27 +0200 |
commit | 2109a65b1a233d42da84cc2aad982bf8a4b49816 (patch) | |
tree | b7b265b89e013f90be80502fabddefc74e761a04 | |
parent | 2a0dda1beeeb1b5d7f6fa31b628306ae8146996a (diff) | |
download | pysaml2-2109a65b1a233d42da84cc2aad982bf8a4b49816.tar.gz |
Read from env var PYSAML2_DELETE_TMPFILES
Signed-off-by: Ivan Kanakarakis <ivan.kanak@gmail.com>
-rw-r--r-- | docs/howto/config.rst | 16 | ||||
-rw-r--r-- | src/saml2/entity.py | 8 | ||||
-rw-r--r-- | src/saml2/sigver.py | 136 | ||||
-rw-r--r-- | tests/test_40_sigver.py | 8 | ||||
-rw-r--r-- | tests/test_50_server.py | 80 |
5 files changed, 110 insertions, 138 deletions
diff --git a/docs/howto/config.rst b/docs/howto/config.rst index 3a27d199..ddb41194 100644 --- a/docs/howto/config.rst +++ b/docs/howto/config.rst @@ -1,5 +1,15 @@ .. _howto_config: +Environment variables +===================== + +PYSAML2_DELETE_TMPFILES +^^^^^^^^^^^^^^^^^^^^^^^ + +If set to "False" will keep temporary xml files in the system temporary storage. +Default: "true"; delete temporary files. + + Configuration of pySAML2 entities ================================= @@ -157,12 +167,6 @@ Format:: Whether debug information should be sent to the log file. -os.environ['PYSAML2_DELETE_XMLSEC_TMP'] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -If set to "False" will keep temporary xml files in `/tmp`. -Default: True, delete temporary files. - entityid ^^^^^^^^ diff --git a/src/saml2/entity.py b/src/saml2/entity.py index 0e2cc94c..f76ae67d 100644 --- a/src/saml2/entity.py +++ b/src/saml2/entity.py @@ -144,8 +144,8 @@ class Entity(HTTPBase): if _val.startswith("http"): r = requests.request("GET", _val) if r.status_code == 200: - _, filename = make_temp(r.text, ".pem", False) - setattr(self.config, item, filename) + tmp = make_temp(r.text, ".pem", False) + setattr(self.config, item, tmp.name) else: raise Exception( "Could not fetch certificate from %s" % _val) @@ -560,8 +560,8 @@ class Entity(HTTPBase): _cert = "%s%s" % (begin_cert, _cert) if end_cert not in _cert: _cert = "%s%s" % (_cert, end_cert) - _, cert_file = make_temp(_cert.encode('ascii'), decode=False) - response = self.sec.encrypt_assertion(response, cert_file, + tmp = make_temp(_cert.encode('ascii'), decode=False) + response = self.sec.encrypt_assertion(response, tmp.name, pre_encryption_part(), node_xpath=node_xpath) return response diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index 036cad09..aaaa412c 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -210,15 +210,20 @@ NODE_NAME = 'urn:oasis:names:tc:SAML:2.0:assertion:Assertion' ENC_NODE_NAME = 'urn:oasis:names:tc:SAML:2.0:assertion:EncryptedAssertion' ENC_KEY_CLASS = 'EncryptedKey' + def get_environ_delete_tmpfiles(): - xmlsec_delete_tmpfiles = os.environ.get('PYSAML2_DELETE_XMLSEC_TMP', "True") - if xmlsec_delete_tmpfiles.upper() == 'FALSE': - xmlsec_delete_tmpfiles = False - logger.warn('PYSAML2_DELETE_XMLSEC_TMP set to False, ' - 'temporary xml files will not be deleted.') - else: - xmlsec_delete_tmpfiles = True - return xmlsec_delete_tmpfiles + default = "true" + value = os.environ.get("PYSAML2_DELETE_TMPFILES", default) + result = value.lower() == default + + if not result: + logger.warning( + "PYSAML2_DELETE_TMPFILES set to False, " + "temporary xml files will not be deleted." + ) + + return result + def _make_vals(val, klass, seccont, klass_inst=None, prop=None, part=False, base64encode=False, elements_to_sign=None): @@ -331,31 +336,31 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None): return instance -def make_temp(string, suffix='', decode=True, delete=True): - """ xmlsec needs files in some cases where only strings exist, hence the - need for this function. It creates a temporary file with the - string as only content. +def make_temp(content, suffix="", decode=True): + """ + Create a temporary file with the given content. + + This is needed by xmlsec in some cases where only strings exist when files + are expected. - :param string: The information to be placed in the file + :param content: The information to be placed in the file :param suffix: The temporary file might have to have a specific suffix in certain circumstances. - :param decode: The input string might be base64 coded. If so it + :param decode: The input content might be base64 coded. If so it must, in some cases, be decoded before being placed in the file. :return: 2-tuple with file pointer ( so the calling function can close the file) and filename (which is for instance needed by the xmlsec function). """ - ntf = NamedTemporaryFile(suffix=suffix, delete=delete) - # Python3 tempfile requires byte-like object - if not isinstance(string, six.binary_type): - string = string.encode('utf-8') - - if decode: - ntf.write(base64.b64decode(string)) - else: - ntf.write(string) + content_encoded = ( + content.encode("utf-8") if not isinstance(content, six.binary_type) else content + ) + content_raw = base64.b64decode(content_encoded) if decode else content_encoded + delete_tmpfiles = get_environ_delete_tmpfiles() + ntf = NamedTemporaryFile(suffix=suffix, delete=delete_tmpfiles) + ntf.write(content_raw) ntf.seek(0) - return ntf, ntf.name + return ntf def split_len(seq, length): @@ -688,7 +693,6 @@ class CryptoBackendXmlSec1(CryptoBackend): CryptoBackend.__init__(self, **kwargs) assert (isinstance(xmlsec_binary, six.string_types)) self.xmlsec = xmlsec_binary - self._xmlsec_delete_tmpfiles = get_environ_delete_tmpfiles() try: self.non_xml_crypto = RSACrypto(kwargs['rsa_key']) @@ -717,13 +721,13 @@ class CryptoBackendXmlSec1(CryptoBackend): :return: """ logger.debug('Encryption input len: %d', len(text)) - f, fil = make_temp(text, decode=False) + tmp = make_temp(text, decode=False) com_list = [ self.xmlsec, '--encrypt', '--pubkey-cert-pem', recv_key, '--session-key', session_key_type, - '--xml-data', fil, + '--xml-data', tmp.name, ] if xpath: @@ -754,9 +758,8 @@ class CryptoBackendXmlSec1(CryptoBackend): if isinstance(statement, SamlBase): statement = pre_encrypt_assertion(statement) - f, fil = make_temp( - _str(statement), decode=False) - t, tmpl = make_temp(_str(template), decode=False) + tmp = make_temp(_str(statement), decode=False) + tmp2 = make_temp(_str(template), decode=False) if not node_xpath: node_xpath = ASSERT_XPATH @@ -766,7 +769,7 @@ class CryptoBackendXmlSec1(CryptoBackend): '--encrypt', '--pubkey-cert-pem', enc_key, '--session-key', key_type, - '--xml-data', fil, + '--xml-data', tmp.name, '--node-xpath', node_xpath, ] @@ -774,7 +777,7 @@ class CryptoBackendXmlSec1(CryptoBackend): com_list.extend(['--node-id', node_id]) try: - (_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmpl]) + (_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmp2.name]) except XmlsecError as e: six.raise_from(EncryptError(com_list), e) @@ -789,7 +792,7 @@ class CryptoBackendXmlSec1(CryptoBackend): """ logger.debug('Decrypt input len: %d', len(enctext)) - _, fil = make_temp(enctext, decode=False) + tmp = make_temp(enctext, decode=False) com_list = [ self.xmlsec, @@ -800,7 +803,7 @@ class CryptoBackendXmlSec1(CryptoBackend): ] try: - (_stdout, _stderr, output) = self._run_xmlsec(com_list, [fil]) + (_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmp.name]) except XmlsecError as e: six.raise_from(DecryptError(com_list), e) @@ -821,12 +824,7 @@ class CryptoBackendXmlSec1(CryptoBackend): if isinstance(statement, SamlBase): statement = str(statement) - _, fil = make_temp( - statement, - suffix='.xml', - decode=False, - delete=self._xmlsec_delete_tmpfiles, - ) + tmp = make_temp(statement, suffix=".xml", decode=False) com_list = [ self.xmlsec, @@ -840,7 +838,7 @@ class CryptoBackendXmlSec1(CryptoBackend): com_list.extend(['--node-id', node_id]) try: - (stdout, stderr, output) = self._run_xmlsec(com_list, [fil]) + (stdout, stderr, output) = self._run_xmlsec(com_list, [tmp.name]) except XmlsecError as e: raise SignatureError(com_list) @@ -867,12 +865,7 @@ class CryptoBackendXmlSec1(CryptoBackend): if not isinstance(signedtext, six.binary_type): signedtext = signedtext.encode('utf-8') - _, fil = make_temp( - signedtext, - suffix='.xml', - decode=False, - delete=self._xmlsec_delete_tmpfiles, - ) + tmp = make_temp(signedtext, suffix=".xml", decode=False) com_list = [ self.xmlsec, @@ -887,7 +880,7 @@ class CryptoBackendXmlSec1(CryptoBackend): com_list.extend(['--node-id', node_id]) try: - (_stdout, stderr, _output) = self._run_xmlsec(com_list, [fil]) + (_stdout, stderr, _output) = self._run_xmlsec(com_list, [tmp.name]) except XmlsecError as e: six.raise_from(SignatureError(com_list), e) @@ -901,7 +894,7 @@ class CryptoBackendXmlSec1(CryptoBackend): key-value parameters :result: Whatever xmlsec wrote to an --output temporary file """ - with NamedTemporaryFile(suffix='.xml', delete=self._xmlsec_delete_tmpfiles) as ntf: + with NamedTemporaryFile(suffix='.xml') as ntf: com_list.extend(['--output', ntf.name]) com_list += extra_args @@ -1311,8 +1304,6 @@ class SecurityContext(object): self.template = template self.encrypt_key_type = encrypt_key_type - # keep certificate files to debug xmlsec invocations - self._xmlsec_delete_tmpfiles = get_environ_delete_tmpfiles() def correctly_signed(self, xml, must=False): logger.debug('verify correct signature') @@ -1364,15 +1355,16 @@ class SecurityContext(object): if not isinstance(keys, list): keys = [keys] - keys = [key for key in keys if key] - for key in keys: - if not isinstance(key, six.binary_type): - key = key.encode("ascii") - key_file, _ = make_temp(key, decode=False) - key_files.append(key_file) + keys_filtered = (key for key in keys if key) + keys_encoded = ( + key.encode("ascii") if not isinstance(key, six.binary_type) else key + for key in keys_filtered + ) + key_files = list(make_temp(key, decode=False) for key in keys_encoded) + key_file_names = list(tmp.name for tmp in key_files) try: - dectext = self.decrypt(enctext, key_file=[x.name for x in key_files], id_attr=id_attr) + dectext = self.decrypt(enctext, key_file=key_file_names, id_attr=id_attr) except DecryptError as e: raise else: @@ -1457,14 +1449,9 @@ class SecurityContext(object): for cert in _certs: if isinstance(cert, six.string_types): - certs.append( - make_temp( - pem_format(cert), - suffix='.pem', - decode=False, - delete=self._xmlsec_delete_tmpfiles, - ) - ) + content = pem_format(cert) + tmp = make_temp(content, suffix=".pem", decode=False) + certs.append(tmp) else: certs.append(cert) else: @@ -1473,12 +1460,7 @@ class SecurityContext(object): if not certs and not self.only_use_keys_in_metadata: logger.debug('==== Certs from instance ====') certs = [ - make_temp( - pem_format(cert), - suffix='.pem', - decode=False, - delete=self._xmlsec_delete_tmpfiles, - ) + make_temp(content=pem_format(cert), suffix=".pem", decode=False) for cert in cert_from_instance(item) ] else: @@ -1490,12 +1472,12 @@ class SecurityContext(object): verified = False last_pem_file = None - for _, pem_file in certs: + for pem_fd in certs: try: - last_pem_file = pem_file + last_pem_file = pem_fd.name if self.verify_signature( decoded_xml, - pem_file, + pem_fd.name, node_name=node_name, node_id=item.id, id_attr=id_attr): @@ -1665,7 +1647,9 @@ class SecurityContext(object): id_attr = self.id_attr if not key_file and key: - _, key_file = make_temp(str(key).encode(), '.pem') + content = str(key).encode() + tmp = make_temp(content, suffix=".pem") + key_file = tmp.name if not key and not key_file: key_file = self.key_file diff --git a/tests/test_40_sigver.py b/tests/test_40_sigver.py index b7566297..89296dc0 100644 --- a/tests/test_40_sigver.py +++ b/tests/test_40_sigver.py @@ -798,13 +798,13 @@ def test_xbox(): encrypted_assertion = EncryptedAssertion() encrypted_assertion.add_extension_element(_ass0) - _, pre = make_temp( + tmp = make_temp( str(pre_encryption_part()).encode('utf-8'), decode=False ) enctext = sec.crypto.encrypt( str(encrypted_assertion), conf.cert_file, - pre, + tmp.name, "des-192", '/*[local-name()="EncryptedAssertion"]/*[local-name()="Assertion"]', ) @@ -860,13 +860,13 @@ def test_xbox_non_ascii_ava(): encrypted_assertion = EncryptedAssertion() encrypted_assertion.add_extension_element(_ass0) - _, pre = make_temp( + tmp = make_temp( str(pre_encryption_part()).encode('utf-8'), decode=False ) enctext = sec.crypto.encrypt( str(encrypted_assertion), conf.cert_file, - pre, + tmp.name, "des-192", '/*[local-name()="EncryptedAssertion"]/*[local-name()="Assertion"]', ) diff --git a/tests/test_50_server.py b/tests/test_50_server.py index 3d828171..4a262976 100644 --- a/tests/test_50_server.py +++ b/tests/test_50_server.py @@ -563,9 +563,8 @@ class TestServer1(): assert valid - _, key_file = make_temp(cert_key_str, decode=False) - - decr_text = self.server.sec.decrypt(signed_resp, key_file) + key_fd = make_temp(cert_key_str, decode=False) + decr_text = self.server.sec.decrypt(signed_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -658,9 +657,8 @@ class TestServer1(): id_attr="") assert valid - _, key_file = make_temp(cert_key_str, decode=False) - - decr_text = self.server.sec.decrypt(signed_resp, key_file) + key_fd = make_temp(cert_key_str, decode=False) + decr_text = self.server.sec.decrypt(signed_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -720,9 +718,8 @@ class TestServer1(): assert valid - _, key_file = make_temp(cert_key_str, decode=False) - - decr_text = self.server.sec.decrypt(decr_text, key_file) + key_fd = make_temp(cert_key_str, decode=False) + decr_text = self.server.sec.decrypt(decr_text, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -763,9 +760,8 @@ class TestServer1(): assert sresponse.signature is None - _, key_file = make_temp(cert_key_str_advice, decode=False) - - decr_text = self.server.sec.decrypt(_resp, key_file) + key_fd = make_temp(cert_key_str_advice, decode=False) + decr_text = self.server.sec.decrypt(_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -795,9 +791,8 @@ class TestServer1(): decr_text_1 = self.server.sec.decrypt(_resp, self.client.config.encryption_keypairs[1]["key_file"]) - _, key_file = make_temp(cert_key_str_advice, decode=False) - - decr_text_2 = self.server.sec.decrypt(decr_text_1, key_file) + key_fd = make_temp(cert_key_str_advice, decode=False) + decr_text_2 = self.server.sec.decrypt(decr_text_1, key_fd.name) resp = samlp.response_from_string(decr_text_2) @@ -826,9 +821,8 @@ class TestServer1(): assert sresponse.signature is None - _, key_file = make_temp(cert_key_str_assertion, decode=False) - - decr_text = self.server.sec.decrypt(_resp, key_file) + key_fd = make_temp(cert_key_str_assertion, decode=False) + decr_text = self.server.sec.decrypt(_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -918,13 +912,11 @@ class TestServer1(): assert sresponse.signature is None - _, key_file = make_temp(cert_key_str_assertion, decode=False) - - decr_text_1 = _server.sec.decrypt(_resp, key_file) - - _, key_file = make_temp(cert_key_str_advice, decode=False) + key_fd1 = make_temp(cert_key_str_assertion, decode=False) + decr_text_1 = _server.sec.decrypt(_resp, key_fd1.name) - decr_text_2 = _server.sec.decrypt(decr_text_1, key_file) + key_fd2 = make_temp(cert_key_str_advice, decode=False) + decr_text_2 = _server.sec.decrypt(decr_text_1, key_fd2.name) resp = samlp.response_from_string(decr_text_2) @@ -1637,9 +1629,8 @@ class TestServer1NonAsciiAva(): assert valid - _, key_file = make_temp(cert_key_str, decode=False) - - decr_text = self.server.sec.decrypt(signed_resp, key_file) + key_fd = make_temp(cert_key_str, decode=False) + decr_text = self.server.sec.decrypt(signed_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -1732,9 +1723,8 @@ class TestServer1NonAsciiAva(): id_attr="") assert valid - _, key_file = make_temp(cert_key_str, decode=False) - - decr_text = self.server.sec.decrypt(signed_resp, key_file) + key_fd = make_temp(cert_key_str, decode=False) + decr_text = self.server.sec.decrypt(signed_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -1794,9 +1784,8 @@ class TestServer1NonAsciiAva(): assert valid - _, key_file = make_temp(cert_key_str, decode=False) - - decr_text = self.server.sec.decrypt(decr_text, key_file) + key_fd = make_temp(cert_key_str, decode=False) + decr_text = self.server.sec.decrypt(decr_text, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -1837,9 +1826,8 @@ class TestServer1NonAsciiAva(): assert sresponse.signature is None - _, key_file = make_temp(cert_key_str_advice, decode=False) - - decr_text = self.server.sec.decrypt(_resp, key_file) + key_fd = make_temp(cert_key_str_advice, decode=False) + decr_text = self.server.sec.decrypt(_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -1869,9 +1857,8 @@ class TestServer1NonAsciiAva(): decr_text_1 = self.server.sec.decrypt(_resp, self.client.config.encryption_keypairs[1]["key_file"]) - _, key_file = make_temp(cert_key_str_advice, decode=False) - - decr_text_2 = self.server.sec.decrypt(decr_text_1, key_file) + key_fd = make_temp(cert_key_str_advice, decode=False) + decr_text_2 = self.server.sec.decrypt(decr_text_1, key_fd.name) resp = samlp.response_from_string(decr_text_2) @@ -1900,9 +1887,8 @@ class TestServer1NonAsciiAva(): assert sresponse.signature is None - _, key_file = make_temp(cert_key_str_assertion, decode=False) - - decr_text = self.server.sec.decrypt(_resp, key_file) + key_fd = make_temp(cert_key_str_assertion, decode=False) + decr_text = self.server.sec.decrypt(_resp, key_fd.name) resp = samlp.response_from_string(decr_text) @@ -1992,13 +1978,11 @@ class TestServer1NonAsciiAva(): assert sresponse.signature is None - _, key_file = make_temp(cert_key_str_assertion, decode=False) - - decr_text_1 = _server.sec.decrypt(_resp, key_file) - - _, key_file = make_temp(cert_key_str_advice, decode=False) + key_fd1 = make_temp(cert_key_str_assertion, decode=False) + decr_text_1 = _server.sec.decrypt(_resp, key_fd1.name) - decr_text_2 = _server.sec.decrypt(decr_text_1, key_file) + key_fd2 = make_temp(cert_key_str_advice, decode=False) + decr_text_2 = _server.sec.decrypt(decr_text_1, key_fd2.name) resp = samlp.response_from_string(decr_text_2) |