summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorda-woods <dw-git@d-woods.co.uk>2022-09-16 08:17:21 +0100
committerda-woods <dw-git@d-woods.co.uk>2022-09-16 08:17:21 +0100
commit4d772106395f40ed83ed4bc8f6ff59a43a880756 (patch)
tree09c6d3e15e37ab46295a85d35db5ff377322e185
parent3eb193be37f47a894737150c6833f707d729362d (diff)
downloadcython-4d772106395f40ed83ed4bc8f6ff59a43a880756.tar.gz
Simplify some type checking
-rw-r--r--Cython/Utility/MatchCase.c89
1 files changed, 37 insertions, 52 deletions
diff --git a/Cython/Utility/MatchCase.c b/Cython/Utility/MatchCase.c
index 6f9b83360..cea40d8b2 100644
--- a/Cython/Utility/MatchCase.c
+++ b/Cython/Utility/MatchCase.c
@@ -1,28 +1,14 @@
///////////////////////////// ABCCheck //////////////////////////////
#if PY_VERSION_HEX < 0x030A0000
-static int __Pyx_MatchCase_IsExactSequence(PyObject *o) {
+static CYTHON_INLINE int __Pyx_MatchCase_IsExactSequence(PyObject *o) {
// is one of the small list of builtin types known to be a sequence
- if (PyList_CheckExact(o) || PyTuple_CheckExact(o)) {
+ if (PyList_CheckExact(o) || PyTuple_CheckExact(o) ||
+ PyType_CheckExact(o, PyRange_Type) || PyType_CheckExact(o, PyMemoryView_Type)) {
// Use exact type match for these checks. I in the event of inheritence we need to make sure
// that it isn't a mapping too
return 1;
}
- if (PyRange_Check(o) || PyMemoryView_Check(o)) {
- // Exact check isn't possible so do exact check in another way
- PyObject *mro = PyObject_GetAttrString((PyObject*)Py_TYPE(o), "__mro__");
- if (mro) {
- Py_ssize_t len = PyObject_Length(mro);
- Py_DECREF(mro);
- if (len < 0) {
- PyErr_Clear(); // doesn't really matter, just proceed with other checks
- } else if (len == 2) {
- return 1; // the type and "object" and no other bases
- }
- } else {
- PyErr_Clear(); // doesn't really matter, just proceed with other checks
- }
- }
return 0;
}
@@ -34,10 +20,13 @@ static CYTHON_INLINE int __Pyx_MatchCase_IsExactMapping(PyObject *o) {
}
static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) {
- if (PyUnicode_Check(o) || PyBytes_Check(o) || PyByteArray_Check(o)) {
+ if (PyType_GetFlags(Py_TYPE(o)) & (Py_TPFLAGS_BYTES_SUBCLASS | Py_TPFLAGS_UNICODE_SUBCLASS)) ||
+ PyByteArray_Check(o)) {
return 1; // these types are deliberately excluded from the sequence test
// even though they look like sequences for most other purposes.
- // They're therefore "inexact" checks
+ // Leave them as inexact checks since they do pass
+ // "isinstance(o, collections.abc.Sequence)" so it's very hard to
+ // reason about their subclasses
}
if (o == Py_None || PyLong_CheckExact(o) || PyFloat_CheckExact(o)) {
return 1;
@@ -73,6 +62,16 @@ static int __Pyx_MatchCase_IsExactNeitherSequenceNorMapping(PyObject *o) {
#define __PYX_SEQUENCE_MAPPING_ERROR (1U<<4) // only used by the ABCCheck function
#endif
+static int __Pyx_MatchCase_InitAndIsInstanceAbc(PyObject *o, PyObject *abc_module,
+ PyObject **abc_type, PyObject *name) {
+ assert(!abc_type);
+ abc_type = PyObject_GetAttr(abc_module, name);
+ if (!abc_type) {
+ return -1;
+ }
+ return PyObject_IsInstance(o, abc_type);
+}
+
// the result is defined using the specification for sequence_mapping_temp
// (detailed in "is_sequence")
static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, int definitely_not_sequence, int definitely_not_mapping) {
@@ -101,12 +100,7 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in
result = __PYX_DEFINITELY_SEQUENCE_FLAG;
goto end;
}
- sequence_type = PyObject_GetAttr(abc_module, PYIDENT("Sequence"));
- if (!sequence_type) {
- result = __PYX_SEQUENCE_MAPPING_ERROR;
- goto end;
- }
- sequence_result = PyObject_IsInstance(o, sequence_type);
+ sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence"));
if (sequence_result < 0) {
result = __PYX_SEQUENCE_MAPPING_ERROR;
goto end;
@@ -114,41 +108,32 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in
result |= __PYX_DEFINITELY_NOT_SEQUENCE_FLAG;
goto end;
}
- // else wait to see what mapping is
+ // else wait to see what mapping is
}
if (!definitely_not_mapping) {
- mapping_type = PyObject_GetAttr(abc_module, PYIDENT("Mapping"));
- if (!mapping_type) {
+ mapping_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &mapping_type, PYIDENT("Mapping"));
+ if (mapping_result < 0) {
+ result = __PYX_SEQUENCE_MAPPING_ERROR;
goto end;
- }
- mapping_result = PyObject_IsInstance(o, mapping_type);
- }
- if (mapping_result < 0) {
- result = __PYX_SEQUENCE_MAPPING_ERROR;
- goto end;
- } else if (mapping_result == 0) {
- result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG;
- if (sequence_first) {
- assert(sequence_result);
- result |= __PYX_DEFINITELY_SEQUENCE_FLAG;
- }
- goto end;
- } else /* mapping_result == 1 */ {
- if (sequence_first && !sequence_result) {
- result |= __PYX_DEFINITELY_MAPPING_FLAG;
+ } else if (mapping_result == 0) {
+ result |= __PYX_DEFINITELY_NOT_MAPPING_FLAG;
+ if (sequence_first) {
+ assert(sequence_result);
+ result |= __PYX_DEFINITELY_SEQUENCE_FLAG;
+ }
goto end;
+ } else /* mapping_result == 1 */ {
+ if (sequence_first && !sequence_result) {
+ result |= __PYX_DEFINITELY_MAPPING_FLAG;
+ goto end;
+ }
}
}
if (!sequence_first) {
// here we know mapping_result is true because we'd have returned otherwise
assert(mapping_result);
if (!definitely_not_sequence) {
- sequence_type = PyObject_GetAttr(abc_module, PYIDENT("Sequence"));
- if (!sequence_type) {
- result = __PYX_SEQUENCE_MAPPING_ERROR;
- goto end;
- }
- sequence_result = PyObject_IsInstance(o, sequence_type);
+ sequence_result = __Pyx_MatchCase_InitAndIsInstanceAbc(o, abc_module, &sequence_type, PYIDENT("Sequence"));
}
if (sequence_result < 0) {
result = __PYX_SEQUENCE_MAPPING_ERROR;
@@ -167,7 +152,7 @@ static unsigned int __Pyx_MatchCase_ABCCheck(PyObject *o, int sequence_first, in
if (!mro) {
PyErr_Clear();
goto end;
- }
+ }
if (!PyTuple_Check(mro)) {
Py_DECREF(mro);
goto end;
@@ -322,7 +307,7 @@ static PyObject *__Pyx_MatchCase_OtherSequenceSliceToList(PyObject *x, Py_ssize_
PyObject *list;
ssizeargfunc slot;
PyTypeObject *type = Py_TYPE(x);
-
+
list = PyList_New(total);
if (!list) {
return NULL;