diff options
author | Robert Bradshaw <robertwb@gmail.com> | 2020-05-23 22:37:59 -0700 |
---|---|---|
committer | Robert Bradshaw <robertwb@gmail.com> | 2020-05-23 23:27:54 -0700 |
commit | e6a812402b0368cf930a55ed465a38820f606054 (patch) | |
tree | 9772127f5473bf23a2ad4bd2940caa351a7d3931 | |
parent | 9c78524a726a898b6abaa33c2f6054c760f1d7b0 (diff) | |
download | cython-e6a812402b0368cf930a55ed465a38820f606054.tar.gz |
Python-style binary operation methods.
-rw-r--r-- | CHANGES.rst | 6 | ||||
-rw-r--r-- | Cython/Compiler/ModuleNode.py | 43 | ||||
-rw-r--r-- | Cython/Compiler/Options.py | 1 | ||||
-rw-r--r-- | Cython/Compiler/TypeSlots.py | 44 | ||||
-rw-r--r-- | Cython/Utility/ExtensionTypes.c | 56 | ||||
-rw-r--r-- | tests/run/binop_reverse_methods_GH2056.pyx | 72 |
6 files changed, 205 insertions, 17 deletions
diff --git a/CHANGES.rst b/CHANGES.rst index 534bf5f83..c82aae518 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -56,6 +56,12 @@ Bugs fixed * The signature of the NumPy C-API function ``PyArray_SearchSorted()`` was fixed. Patch by Brock Mendel. (Github issue #3606) +* Added support for Python binary operator semantics. + One can now define, e.g. both ``__add__`` and ``__radd__`` for cdef classes + as for standard Python classes rather than a single ``__add__`` method where + self can be either the first or second argument. (Github issue #2056) + This behavior is guarded by the c_api_binop_methods directive. + 0.29.17 (2020-04-26) ==================== diff --git a/Cython/Compiler/ModuleNode.py b/Cython/Compiler/ModuleNode.py index ede21ff09..193e01ee8 100644 --- a/Cython/Compiler/ModuleNode.py +++ b/Cython/Compiler/ModuleNode.py @@ -29,7 +29,7 @@ from . import Pythran from .Errors import error, warning from .PyrexTypes import py_object_type from ..Utils import open_new_file, replace_suffix, decode_filename, build_hex_version -from .Code import UtilityCode, IncludeCode +from .Code import UtilityCode, IncludeCode, TempitaUtilityCode from .StringEncoding import EncodedString from .Pythran import has_np_pythran @@ -1255,6 +1255,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): self.generate_dict_getter_function(scope, code) if scope.defines_any_special(TypeSlots.richcmp_special_methods): self.generate_richcmp_function(scope, code) + for slot in TypeSlots.PyNumberMethods: + if slot.is_binop and scope.defines_any_special(slot.user_methods): + self.generate_binop_function(scope, slot, code) self.generate_property_accessors(scope, code) self.generate_method_table(scope, code) self.generate_getset_table(scope, code) @@ -1892,6 +1895,44 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): code.putln("}") # switch code.putln("}") + def generate_binop_function(self, scope, slot, code): + func_name = scope.mangle_internal(slot.slot_name) + code.putln() + preprocessor_guard = slot.preprocessor_guard_code() + if preprocessor_guard: + code.putln(preprocessor_guard) + if scope.directives['c_api_binop_methods']: + code.putln('#define %s %s' % (func_name, slot.left_slot.slot_code(scope))) + else: + def has_slot_method(method_name): + entry = scope.lookup(method_name) + return bool(entry and entry.is_special and entry.func_cname) + def call_slot_method(method_name, reverse): + entry = scope.lookup(method_name) + if reverse: + operands = "right, left" + else: + operands = "left, right" + if entry and entry.is_special and entry.func_cname: + return "%s(%s)" % (entry.func_cname, operands) + else: + py_ident = code.intern_identifier(EncodedString(method_name)) + return "%s_maybe_call_super(%s, %s)" % (func_name, operands, py_ident) + code.putln( + TempitaUtilityCode.load_cached( + "BinopSlot", "ExtensionTypes.c", + context={ + "func_name": func_name, + "slot_name": slot.slot_name, + "overloads_left": int(has_slot_method(slot.left_slot.method_name)), + "call_left": call_slot_method(slot.left_slot.method_name, reverse=False), + "call_right": call_slot_method(slot.right_slot.method_name, reverse=True), + "type_cname": '((PyTypeObject*) %s)' % scope.namespace_cname, + }).impl.strip()) + code.putln() + if preprocessor_guard: + code.putln("#endif") + def generate_getattro_function(self, scope, code): # First try to get the attribute using __getattribute__, if defined, or # PyObject_GenericGetAttr. diff --git a/Cython/Compiler/Options.py b/Cython/Compiler/Options.py index a634aaf56..db497aa40 100644 --- a/Cython/Compiler/Options.py +++ b/Cython/Compiler/Options.py @@ -178,6 +178,7 @@ _directive_defaults = { 'auto_pickle': None, 'cdivision': False, # was True before 0.12 'cdivision_warnings': False, + 'c_api_binop_methods': True, # Change for 3.0 'overflowcheck': False, 'overflowcheck.fold': True, 'always_allow_keywords': False, diff --git a/Cython/Compiler/TypeSlots.py b/Cython/Compiler/TypeSlots.py index 137ea4eba..50e09cd74 100644 --- a/Cython/Compiler/TypeSlots.py +++ b/Cython/Compiler/TypeSlots.py @@ -180,13 +180,14 @@ class SlotDescriptor(object): # ifdef Full #ifdef string that slot is wrapped in. Using this causes py3, py2 and flags to be ignored.) def __init__(self, slot_name, dynamic=False, inherited=False, - py3=True, py2=True, ifdef=None): + py3=True, py2=True, ifdef=None, is_binop=False): self.slot_name = slot_name self.is_initialised_dynamically = dynamic self.is_inherited = inherited self.ifdef = ifdef self.py3 = py3 self.py2 = py2 + self.is_binop = is_binop def preprocessor_guard_code(self): ifdef = self.ifdef @@ -405,6 +406,17 @@ class SyntheticSlot(InternalMethodSlot): return self.default_value +class BinopSlot(SyntheticSlot): + def __init__(self, signature, slot_name, left_method, **kargs): + assert left_method.startswith('__') + right_method = '__r' + left_method[2:] + SyntheticSlot.__init__( + self, slot_name, [left_method, right_method], "0", is_binop=True, **kargs) + # MethodSlot causes special method registration. + self.left_slot = MethodSlot(signature, "", left_method) + self.right_slot = MethodSlot(signature, "", right_method) + + class RichcmpSlot(MethodSlot): def slot_code(self, scope): entry = scope.lookup_here(self.method_name) @@ -728,23 +740,23 @@ property_accessor_signatures = { PyNumberMethods_Py3_GUARD = "PY_MAJOR_VERSION < 3 || (CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x03050000)" PyNumberMethods = ( - MethodSlot(binaryfunc, "nb_add", "__add__"), - MethodSlot(binaryfunc, "nb_subtract", "__sub__"), - MethodSlot(binaryfunc, "nb_multiply", "__mul__"), - MethodSlot(binaryfunc, "nb_divide", "__div__", ifdef = PyNumberMethods_Py3_GUARD), - MethodSlot(binaryfunc, "nb_remainder", "__mod__"), - MethodSlot(binaryfunc, "nb_divmod", "__divmod__"), - MethodSlot(ternaryfunc, "nb_power", "__pow__"), + BinopSlot(binaryfunc, "nb_add", "__add__"), + BinopSlot(binaryfunc, "nb_subtract", "__sub__"), + BinopSlot(binaryfunc, "nb_multiply", "__mul__"), + BinopSlot(binaryfunc, "nb_divide", "__div__", ifdef = PyNumberMethods_Py3_GUARD), + BinopSlot(binaryfunc, "nb_remainder", "__mod__"), + BinopSlot(binaryfunc, "nb_divmod", "__divmod__"), + BinopSlot(ternaryfunc, "nb_power", "__pow__"), MethodSlot(unaryfunc, "nb_negative", "__neg__"), MethodSlot(unaryfunc, "nb_positive", "__pos__"), MethodSlot(unaryfunc, "nb_absolute", "__abs__"), MethodSlot(inquiry, "nb_nonzero", "__nonzero__", py3 = ("nb_bool", "__bool__")), MethodSlot(unaryfunc, "nb_invert", "__invert__"), - MethodSlot(binaryfunc, "nb_lshift", "__lshift__"), - MethodSlot(binaryfunc, "nb_rshift", "__rshift__"), - MethodSlot(binaryfunc, "nb_and", "__and__"), - MethodSlot(binaryfunc, "nb_xor", "__xor__"), - MethodSlot(binaryfunc, "nb_or", "__or__"), + BinopSlot(binaryfunc, "nb_lshift", "__lshift__"), + BinopSlot(binaryfunc, "nb_rshift", "__rshift__"), + BinopSlot(binaryfunc, "nb_and", "__and__"), + BinopSlot(binaryfunc, "nb_xor", "__xor__"), + BinopSlot(binaryfunc, "nb_or", "__or__"), EmptySlot("nb_coerce", ifdef = PyNumberMethods_Py3_GUARD), MethodSlot(unaryfunc, "nb_int", "__int__", fallback="__long__"), MethodSlot(unaryfunc, "nb_long", "__long__", fallback="__int__", py3 = "<RESERVED>"), @@ -767,8 +779,8 @@ PyNumberMethods = ( # Added in release 2.2 # The following require the Py_TPFLAGS_HAVE_CLASS flag - MethodSlot(binaryfunc, "nb_floor_divide", "__floordiv__"), - MethodSlot(binaryfunc, "nb_true_divide", "__truediv__"), + BinopSlot(binaryfunc, "nb_floor_divide", "__floordiv__"), + BinopSlot(binaryfunc, "nb_true_divide", "__truediv__"), MethodSlot(ibinaryfunc, "nb_inplace_floor_divide", "__ifloordiv__"), MethodSlot(ibinaryfunc, "nb_inplace_true_divide", "__itruediv__"), @@ -776,7 +788,7 @@ PyNumberMethods = ( MethodSlot(unaryfunc, "nb_index", "__index__"), # Added in release 3.5 - MethodSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"), + BinopSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"), MethodSlot(ibinaryfunc, "nb_inplace_matrix_multiply", "__imatmul__", ifdef="PY_VERSION_HEX >= 0x03050000"), ) diff --git a/Cython/Utility/ExtensionTypes.c b/Cython/Utility/ExtensionTypes.c index 1b39c9e42..d34916642 100644 --- a/Cython/Utility/ExtensionTypes.c +++ b/Cython/Utility/ExtensionTypes.c @@ -278,3 +278,59 @@ __PYX_GOOD: Py_XDECREF(setstate_cython); return ret; } + +/////////////// BinopSlot /////////////// + +static CYTHON_INLINE PyObject *{{func_name}}_maybe_call_super(PyObject *self, PyObject *other, PyObject* name) { + PyObject *res; + PyObject *method; + if (!Py_TYPE(self)->tp_base) { + return Py_INCREF(Py_NotImplemented), Py_NotImplemented; + } + // TODO: Use _PyType_LookupId or similar. + method = PyObject_GetAttr((PyObject*) Py_TYPE(self)->tp_base, name); + if (!method) { + PyErr_Clear(); + return Py_INCREF(Py_NotImplemented), Py_NotImplemented; + } + res = __Pyx_PyObject_Call2Args(method, self, other); + Py_DECREF(method); + if (!res) { + return Py_INCREF(Py_NotImplemented), Py_NotImplemented; + } + return res; +} + +static PyObject *{{func_name}}(PyObject *left, PyObject *right) { + PyObject *res; + int maybe_self_is_left, maybe_self_is_right = 0; + maybe_self_is_left = Py_TYPE(left) == Py_TYPE(right) + || (Py_TYPE(left)->tp_as_number && Py_TYPE(left)->tp_as_number->{{slot_name}} == &{{func_name}}) + || PyType_IsSubtype(Py_TYPE(left), {{type_cname}}); + // Optimize for the common case where the left operation is defined (and successful). + if (!{{overloads_left}}) { + maybe_self_is_right = Py_TYPE(left) == Py_TYPE(right) + || (Py_TYPE(right)->tp_as_number && Py_TYPE(right)->tp_as_number->{{slot_name}} == &{{func_name}}) + || PyType_IsSubtype(Py_TYPE(right), {{type_cname}}); + } + if (maybe_self_is_left) { + if (maybe_self_is_right && !{{overloads_left}}) { + res = {{call_right}}; + if (res != Py_NotImplemented) return res; + Py_DECREF(res); + maybe_self_is_right = 0; // Don't bother calling it again. + } + res = {{call_left}}; + if (res != Py_NotImplemented) return res; + Py_DECREF(res); + } + if ({{overloads_left}}) { + maybe_self_is_right = Py_TYPE(left) == Py_TYPE(right) + || (Py_TYPE(right)->tp_as_number && Py_TYPE(right)->tp_as_number->{{slot_name}} == &{{func_name}}) + || PyType_IsSubtype(Py_TYPE(right), {{type_cname}}); + } + if (maybe_self_is_right) { + return {{call_right}}; + } + return Py_INCREF(Py_NotImplemented), Py_NotImplemented; +} diff --git a/tests/run/binop_reverse_methods_GH2056.pyx b/tests/run/binop_reverse_methods_GH2056.pyx new file mode 100644 index 000000000..480853d93 --- /dev/null +++ b/tests/run/binop_reverse_methods_GH2056.pyx @@ -0,0 +1,72 @@ +cimport cython + +@cython.c_api_binop_methods(False) +@cython.cclass +class Base(object): + """ + >>> Base() + 2 + 'Base.__add__(Base(), 2)' + >>> 2 + Base() + 'Base.__radd__(Base(), 2)' + """ + def __add__(self, other): + return "Base.__add__(%s, %s)" % (self, other) + + def __radd__(self, other): + return "Base.__radd__(%s, %s)" % (self, other) + + def __repr__(self): + return "%s()" % (self.__class__.__name__) + +@cython.c_api_binop_methods(False) +@cython.cclass +class OverloadLeft(Base): + """ + >>> OverloadLeft() + 2 + 'OverloadLeft.__add__(OverloadLeft(), 2)' + >>> 2 + OverloadLeft() + 'Base.__radd__(OverloadLeft(), 2)' + + >>> OverloadLeft() + Base() + 'OverloadLeft.__add__(OverloadLeft(), Base())' + >>> Base() + OverloadLeft() + 'Base.__add__(Base(), OverloadLeft())' + """ + def __add__(self, other): + return "OverloadLeft.__add__(%s, %s)" % (self, other) + + +@cython.c_api_binop_methods(False) +@cython.cclass +class OverloadRight(Base): + """ + >>> OverloadRight() + 2 + 'Base.__add__(OverloadRight(), 2)' + >>> 2 + OverloadRight() + 'OverloadRight.__radd__(OverloadRight(), 2)' + + >>> OverloadRight() + Base() + 'Base.__add__(OverloadRight(), Base())' + >>> Base() + OverloadRight() + 'OverloadRight.__radd__(OverloadRight(), Base())' + """ + def __radd__(self, other): + return "OverloadRight.__radd__(%s, %s)" % (self, other) + +@cython.c_api_binop_methods(True) +@cython.cclass +class OverloadCApi(Base): + """ + >>> OverloadCApi() + 2 + 'OverloadCApi.__add__(OverloadCApi(), 2)' + >>> 2 + OverloadCApi() + 'OverloadCApi.__add__(2, OverloadCApi())' + + >>> OverloadCApi() + Base() + 'OverloadCApi.__add__(OverloadCApi(), Base())' + >>> Base() + OverloadCApi() + 'OverloadCApi.__add__(Base(), OverloadCApi())' + """ + def __add__(self, other): + return "OverloadCApi.__add__(%s, %s)" % (self, other) + |