summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormsg555 <msg555@gmail.com>2020-04-20 00:06:09 -0700
committerGitHub <noreply@github.com>2020-04-20 09:06:09 +0200
commitc0ceabe1ee0904477ed0b0b9371ae43915d7bd2f (patch)
tree8bef73c797d92ffa606f64a06d02d6bf88710cb9
parentb692adc13c560551e3ea87e0c0d01bc983248893 (diff)
downloadcython-c0ceabe1ee0904477ed0b0b9371ae43915d7bd2f.tar.gz
Update GetItem to support __class_getitem__ for type objects (GH-3518)
Closes #2753.
-rw-r--r--Cython/Compiler/Nodes.py12
-rw-r--r--Cython/Utility/ObjectHandling.c45
-rw-r--r--tests/run/test_genericclass.py139
3 files changed, 181 insertions, 15 deletions
diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index 4912ce035..84654bf12 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -39,6 +39,9 @@ else:
_py_int_types = (int, long)
+IMPLICIT_CLASSMETHODS = {"__init_subclass__", "__class_getitem__"}
+
+
def relative_position(pos):
return (pos[0].get_filenametable_entry(), pos[1])
@@ -2438,7 +2441,7 @@ class CFuncDefNode(FuncDefNode):
py_func_body = self.call_self_node(is_module_scope=env.is_module_scope)
if self.is_static_method:
from .ExprNodes import NameNode
- decorators = [DecoratorNode(self.pos, decorator=NameNode(self.pos, name='staticmethod'))]
+ decorators = [DecoratorNode(self.pos, decorator=NameNode(self.pos, name=EncodedString('staticmethod')))]
decorators[0].decorator.analyse_types(env)
else:
decorators = []
@@ -2883,7 +2886,12 @@ class DefNode(FuncDefNode):
self.is_staticmethod = False
if self.name == '__new__' and env.is_py_class_scope:
- self.is_staticmethod = 1
+ self.is_staticmethod = True
+ if not self.is_classmethod and self.name in IMPLICIT_CLASSMETHODS and env.is_py_class_scope:
+ from .ExprNodes import NameNode
+ self.decorators = self.decorators or []
+ self.decorators.insert(0, DecoratorNode(self.pos, decorator=NameNode(self.pos, name=EncodedString('classmethod'))))
+ self.is_classmethod = True
self.analyse_argument_types(env)
if self.name == '<lambda>':
diff --git a/Cython/Utility/ObjectHandling.c b/Cython/Utility/ObjectHandling.c
index f31d7af88..2e9b50e8d 100644
--- a/Cython/Utility/ObjectHandling.c
+++ b/Cython/Utility/ObjectHandling.c
@@ -273,24 +273,21 @@ static CYTHON_INLINE int __Pyx_IterFinish(void) {
/////////////// ObjectGetItem.proto ///////////////
#if CYTHON_USE_TYPE_SLOTS
-static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key);/*proto*/
+static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key);/*proto*/
#else
#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key)
#endif
/////////////// ObjectGetItem ///////////////
// //@requires: GetItemInt - added in IndexNode as it uses templating.
+//@requires: PyObjectGetAttrStrNoError
+//@requires: PyObjectCallOneArg
#if CYTHON_USE_TYPE_SLOTS
-static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) {
+static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) {
+ // Get element from sequence object `obj` at index `index`.
PyObject *runerr;
Py_ssize_t key_value;
- PySequenceMethods *m = Py_TYPE(obj)->tp_as_sequence;
- if (unlikely(!(m && m->sq_item))) {
- PyErr_Format(PyExc_TypeError, "'%.200s' object is not subscriptable", Py_TYPE(obj)->tp_name);
- return NULL;
- }
-
key_value = __Pyx_PyIndex_AsSsize_t(index);
if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) {
return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1);
@@ -304,12 +301,34 @@ static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) {
return NULL;
}
-static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key) {
- PyMappingMethods *m = Py_TYPE(obj)->tp_as_mapping;
- if (likely(m && m->mp_subscript)) {
- return m->mp_subscript(obj, key);
+static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) {
+ // Handles less common slow-path checks for GetItem
+ if (likely(PyType_Check(obj))) {
+ PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, PYIDENT("__class_getitem__"));
+ if (meth) {
+ PyObject *result = __Pyx_PyObject_CallOneArg(meth, key);
+ Py_DECREF(meth);
+ return result;
+ }
+ }
+
+ PyErr_Format(PyExc_TypeError, "'%.200s' object is not subscriptable", Py_TYPE(obj)->tp_name);
+ return NULL;
+}
+
+static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) {
+ PyTypeObject *tp = Py_TYPE(obj);
+ PyMappingMethods *mm = tp->tp_as_mapping;
+ if (likely(mm && mm->mp_subscript)) {
+ return mm->mp_subscript(obj, key);
+ }
+
+ PySequenceMethods *sm = tp->tp_as_sequence;
+ if (likely(sm && sm->sq_item)) {
+ return __Pyx_PyObject_GetIndex(obj, key);
}
- return __Pyx_PyObject_GetIndex(obj, key);
+
+ return __Pyx_PyObject_GetItem_Slow(obj, key);
}
#endif
diff --git a/tests/run/test_genericclass.py b/tests/run/test_genericclass.py
new file mode 100644
index 000000000..e6313cb84
--- /dev/null
+++ b/tests/run/test_genericclass.py
@@ -0,0 +1,139 @@
+# mode: run
+# tag: pure3.7
+# cython: language_level=3
+
+# COPIED FROM CPython 3.7
+
+import unittest
+import sys
+
+
+class TestClassGetitem(unittest.TestCase):
+ # BEGIN - Additional tests from cython
+ def test_no_class_getitem(self):
+ class C: ...
+ with self.assertRaises(TypeError):
+ C[int]
+
+ # END - Additional tests from cython
+
+ def test_class_getitem(self):
+ getitem_args = []
+ class C:
+ def __class_getitem__(*args, **kwargs):
+ getitem_args.extend([args, kwargs])
+ return None
+ C[int, str]
+ self.assertEqual(getitem_args[0], (C, (int, str)))
+ self.assertEqual(getitem_args[1], {})
+
+ def test_class_getitem_format(self):
+ class C:
+ def __class_getitem__(cls, item):
+ return f'C[{item.__name__}]'
+ self.assertEqual(C[int], 'C[int]')
+ self.assertEqual(C[C], 'C[C]')
+
+ def test_class_getitem_inheritance(self):
+ class C:
+ def __class_getitem__(cls, item):
+ return f'{cls.__name__}[{item.__name__}]'
+ class D(C): ...
+ self.assertEqual(D[int], 'D[int]')
+ self.assertEqual(D[D], 'D[D]')
+
+ def test_class_getitem_inheritance_2(self):
+ class C:
+ def __class_getitem__(cls, item):
+ return 'Should not see this'
+ class D(C):
+ def __class_getitem__(cls, item):
+ return f'{cls.__name__}[{item.__name__}]'
+ self.assertEqual(D[int], 'D[int]')
+ self.assertEqual(D[D], 'D[D]')
+
+ def test_class_getitem_classmethod(self):
+ class C:
+ @classmethod
+ def __class_getitem__(cls, item):
+ return f'{cls.__name__}[{item.__name__}]'
+ class D(C): ...
+ self.assertEqual(D[int], 'D[int]')
+ self.assertEqual(D[D], 'D[D]')
+
+ @unittest.skipIf(sys.version_info < (3, 6), "__init_subclass__() requires Py3.6+ (PEP 487)")
+ def test_class_getitem_patched(self):
+ class C:
+ def __init_subclass__(cls):
+ def __class_getitem__(cls, item):
+ return f'{cls.__name__}[{item.__name__}]'
+ cls.__class_getitem__ = classmethod(__class_getitem__)
+ class D(C): ...
+ self.assertEqual(D[int], 'D[int]')
+ self.assertEqual(D[D], 'D[D]')
+
+ def test_class_getitem_with_builtins(self):
+ class A(dict):
+ called_with = None
+
+ def __class_getitem__(cls, item):
+ cls.called_with = item
+ class B(A):
+ pass
+ self.assertIs(B.called_with, None)
+ B[int]
+ self.assertIs(B.called_with, int)
+
+ def test_class_getitem_errors(self):
+ class C_too_few:
+ def __class_getitem__(cls):
+ return None
+ with self.assertRaises(TypeError):
+ C_too_few[int]
+ class C_too_many:
+ def __class_getitem__(cls, one, two):
+ return None
+ with self.assertRaises(TypeError):
+ C_too_many[int]
+
+ def test_class_getitem_errors_2(self):
+ class C:
+ def __class_getitem__(cls, item):
+ return None
+ with self.assertRaises(TypeError):
+ C()[int]
+ class E: ...
+ e = E()
+ e.__class_getitem__ = lambda cls, item: 'This will not work'
+ with self.assertRaises(TypeError):
+ e[int]
+ class C_not_callable:
+ __class_getitem__ = "Surprise!"
+ with self.assertRaises(TypeError):
+ C_not_callable[int]
+
+ def test_class_getitem_metaclass(self):
+ class Meta(type):
+ def __class_getitem__(cls, item):
+ return f'{cls.__name__}[{item.__name__}]'
+ self.assertEqual(Meta[int], 'Meta[int]')
+
+ def test_class_getitem_with_metaclass(self):
+ class Meta(type): pass
+ class C(metaclass=Meta):
+ def __class_getitem__(cls, item):
+ return f'{cls.__name__}[{item.__name__}]'
+ self.assertEqual(C[int], 'C[int]')
+
+ def test_class_getitem_metaclass_first(self):
+ class Meta(type):
+ def __getitem__(cls, item):
+ return 'from metaclass'
+ class C(metaclass=Meta):
+ def __class_getitem__(cls, item):
+ return 'from __class_getitem__'
+ self.assertEqual(C[int], 'from metaclass')
+
+
+if __name__ == '__main__':
+ unittest.main()