diff options
-rw-r--r-- | src/OpenSSL/crypto.py | 26 | ||||
-rw-r--r-- | tests/test_crypto.py | 13 |
2 files changed, 23 insertions, 16 deletions
diff --git a/src/OpenSSL/crypto.py b/src/OpenSSL/crypto.py index 77b2c6c..894d5fa 100644 --- a/src/OpenSSL/crypto.py +++ b/src/OpenSSL/crypto.py @@ -1,9 +1,8 @@ import calendar import datetime - from base64 import b16encode +import functools from functools import partial -from operator import __eq__, __ne__, __lt__, __le__, __gt__, __ge__ from cryptography import utils, x509 from cryptography.hazmat.primitives.asymmetric import dsa, rsa @@ -528,6 +527,7 @@ def get_elliptic_curve(name): raise ValueError("unknown curve name", name) +@functools.total_ordering class X509Name: """ An X.509 Distinguished Name. @@ -642,23 +642,17 @@ class X509Name: _lib.OPENSSL_free(result_buffer[0]) return result - def _cmp(op): - def f(self, other): - if not isinstance(other, X509Name): - return NotImplemented - result = _lib.X509_NAME_cmp(self._name, other._name) - return op(result, 0) - - return f + def __eq__(self, other): + if not isinstance(other, X509Name): + return NotImplemented - __eq__ = _cmp(__eq__) - __ne__ = _cmp(__ne__) + return _lib.X509_NAME_cmp(self._name, other._name) == 0 - __lt__ = _cmp(__lt__) - __le__ = _cmp(__le__) + def __lt__(self, other): + if not isinstance(other, X509Name): + return NotImplemented - __gt__ = _cmp(__gt__) - __ge__ = _cmp(__ge__) + return _lib.X509_NAME_cmp(self._name, other._name) < 0 def __repr__(self): """ diff --git a/tests/test_crypto.py b/tests/test_crypto.py index d66e257..ca2a17a 100644 --- a/tests/test_crypto.py +++ b/tests/test_crypto.py @@ -1398,6 +1398,19 @@ class TestX509Name: # other X509Name. assert_greater_than(x509_name(CN="def"), x509_name(CN="abc")) + def assert_raises(a, b): + with pytest.raises(TypeError): + a < b + with pytest.raises(TypeError): + a <= b + with pytest.raises(TypeError): + a > b + with pytest.raises(TypeError): + a >= b + + # Only X509Name objects can be compared with lesser than / greater than + assert_raises(x509_name(), object()) + def test_hash(self): """ `X509Name.hash` returns an integer hash based on the value of the name. |