diff options
author | msg555 <msg555@gmail.com> | 2020-04-20 00:06:09 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-20 09:06:09 +0200 |
commit | c0ceabe1ee0904477ed0b0b9371ae43915d7bd2f (patch) | |
tree | 8bef73c797d92ffa606f64a06d02d6bf88710cb9 | |
parent | b692adc13c560551e3ea87e0c0d01bc983248893 (diff) | |
download | cython-c0ceabe1ee0904477ed0b0b9371ae43915d7bd2f.tar.gz |
Update GetItem to support __class_getitem__ for type objects (GH-3518)
Closes #2753.
-rw-r--r-- | Cython/Compiler/Nodes.py | 12 | ||||
-rw-r--r-- | Cython/Utility/ObjectHandling.c | 45 | ||||
-rw-r--r-- | tests/run/test_genericclass.py | 139 |
3 files changed, 181 insertions, 15 deletions
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 4912ce035..84654bf12 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -39,6 +39,9 @@ else: _py_int_types = (int, long) +IMPLICIT_CLASSMETHODS = {"__init_subclass__", "__class_getitem__"} + + def relative_position(pos): return (pos[0].get_filenametable_entry(), pos[1]) @@ -2438,7 +2441,7 @@ class CFuncDefNode(FuncDefNode): py_func_body = self.call_self_node(is_module_scope=env.is_module_scope) if self.is_static_method: from .ExprNodes import NameNode - decorators = [DecoratorNode(self.pos, decorator=NameNode(self.pos, name='staticmethod'))] + decorators = [DecoratorNode(self.pos, decorator=NameNode(self.pos, name=EncodedString('staticmethod')))] decorators[0].decorator.analyse_types(env) else: decorators = [] @@ -2883,7 +2886,12 @@ class DefNode(FuncDefNode): self.is_staticmethod = False if self.name == '__new__' and env.is_py_class_scope: - self.is_staticmethod = 1 + self.is_staticmethod = True + if not self.is_classmethod and self.name in IMPLICIT_CLASSMETHODS and env.is_py_class_scope: + from .ExprNodes import NameNode + self.decorators = self.decorators or [] + self.decorators.insert(0, DecoratorNode(self.pos, decorator=NameNode(self.pos, name=EncodedString('classmethod')))) + self.is_classmethod = True self.analyse_argument_types(env) if self.name == '<lambda>': diff --git a/Cython/Utility/ObjectHandling.c b/Cython/Utility/ObjectHandling.c index f31d7af88..2e9b50e8d 100644 --- a/Cython/Utility/ObjectHandling.c +++ b/Cython/Utility/ObjectHandling.c @@ -273,24 +273,21 @@ static CYTHON_INLINE int __Pyx_IterFinish(void) { /////////////// ObjectGetItem.proto /////////////// #if CYTHON_USE_TYPE_SLOTS -static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key);/*proto*/ +static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key);/*proto*/ #else #define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key) #endif /////////////// ObjectGetItem /////////////// // //@requires: GetItemInt - added in IndexNode as it uses templating. +//@requires: PyObjectGetAttrStrNoError +//@requires: PyObjectCallOneArg #if CYTHON_USE_TYPE_SLOTS -static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) { +static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) { + // Get element from sequence object `obj` at index `index`. PyObject *runerr; Py_ssize_t key_value; - PySequenceMethods *m = Py_TYPE(obj)->tp_as_sequence; - if (unlikely(!(m && m->sq_item))) { - PyErr_Format(PyExc_TypeError, "'%.200s' object is not subscriptable", Py_TYPE(obj)->tp_name); - return NULL; - } - key_value = __Pyx_PyIndex_AsSsize_t(index); if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) { return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1); @@ -304,12 +301,34 @@ static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) { return NULL; } -static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key) { - PyMappingMethods *m = Py_TYPE(obj)->tp_as_mapping; - if (likely(m && m->mp_subscript)) { - return m->mp_subscript(obj, key); +static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) { + // Handles less common slow-path checks for GetItem + if (likely(PyType_Check(obj))) { + PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, PYIDENT("__class_getitem__")); + if (meth) { + PyObject *result = __Pyx_PyObject_CallOneArg(meth, key); + Py_DECREF(meth); + return result; + } + } + + PyErr_Format(PyExc_TypeError, "'%.200s' object is not subscriptable", Py_TYPE(obj)->tp_name); + return NULL; +} + +static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) { + PyTypeObject *tp = Py_TYPE(obj); + PyMappingMethods *mm = tp->tp_as_mapping; + if (likely(mm && mm->mp_subscript)) { + return mm->mp_subscript(obj, key); + } + + PySequenceMethods *sm = tp->tp_as_sequence; + if (likely(sm && sm->sq_item)) { + return __Pyx_PyObject_GetIndex(obj, key); } - return __Pyx_PyObject_GetIndex(obj, key); + + return __Pyx_PyObject_GetItem_Slow(obj, key); } #endif diff --git a/tests/run/test_genericclass.py b/tests/run/test_genericclass.py new file mode 100644 index 000000000..e6313cb84 --- /dev/null +++ b/tests/run/test_genericclass.py @@ -0,0 +1,139 @@ +# mode: run +# tag: pure3.7 +# cython: language_level=3 + +# COPIED FROM CPython 3.7 + +import unittest +import sys + + +class TestClassGetitem(unittest.TestCase): + # BEGIN - Additional tests from cython + def test_no_class_getitem(self): + class C: ... + with self.assertRaises(TypeError): + C[int] + + # END - Additional tests from cython + + def test_class_getitem(self): + getitem_args = [] + class C: + def __class_getitem__(*args, **kwargs): + getitem_args.extend([args, kwargs]) + return None + C[int, str] + self.assertEqual(getitem_args[0], (C, (int, str))) + self.assertEqual(getitem_args[1], {}) + + def test_class_getitem_format(self): + class C: + def __class_getitem__(cls, item): + return f'C[{item.__name__}]' + self.assertEqual(C[int], 'C[int]') + self.assertEqual(C[C], 'C[C]') + + def test_class_getitem_inheritance(self): + class C: + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + class D(C): ... + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_inheritance_2(self): + class C: + def __class_getitem__(cls, item): + return 'Should not see this' + class D(C): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_classmethod(self): + class C: + @classmethod + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + class D(C): ... + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + @unittest.skipIf(sys.version_info < (3, 6), "__init_subclass__() requires Py3.6+ (PEP 487)") + def test_class_getitem_patched(self): + class C: + def __init_subclass__(cls): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + cls.__class_getitem__ = classmethod(__class_getitem__) + class D(C): ... + self.assertEqual(D[int], 'D[int]') + self.assertEqual(D[D], 'D[D]') + + def test_class_getitem_with_builtins(self): + class A(dict): + called_with = None + + def __class_getitem__(cls, item): + cls.called_with = item + class B(A): + pass + self.assertIs(B.called_with, None) + B[int] + self.assertIs(B.called_with, int) + + def test_class_getitem_errors(self): + class C_too_few: + def __class_getitem__(cls): + return None + with self.assertRaises(TypeError): + C_too_few[int] + class C_too_many: + def __class_getitem__(cls, one, two): + return None + with self.assertRaises(TypeError): + C_too_many[int] + + def test_class_getitem_errors_2(self): + class C: + def __class_getitem__(cls, item): + return None + with self.assertRaises(TypeError): + C()[int] + class E: ... + e = E() + e.__class_getitem__ = lambda cls, item: 'This will not work' + with self.assertRaises(TypeError): + e[int] + class C_not_callable: + __class_getitem__ = "Surprise!" + with self.assertRaises(TypeError): + C_not_callable[int] + + def test_class_getitem_metaclass(self): + class Meta(type): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + self.assertEqual(Meta[int], 'Meta[int]') + + def test_class_getitem_with_metaclass(self): + class Meta(type): pass + class C(metaclass=Meta): + def __class_getitem__(cls, item): + return f'{cls.__name__}[{item.__name__}]' + self.assertEqual(C[int], 'C[int]') + + def test_class_getitem_metaclass_first(self): + class Meta(type): + def __getitem__(cls, item): + return 'from metaclass' + class C(metaclass=Meta): + def __class_getitem__(cls, item): + return 'from __class_getitem__' + self.assertEqual(C[int], 'from metaclass') + + +if __name__ == '__main__': + unittest.main() |