diff options
43 files changed, 2993 insertions, 2999 deletions
diff --git a/.github/workflows/create-wheels.yaml b/.github/workflows/create-wheels.yaml index e90a9f3a2..c6eab3148 100644 --- a/.github/workflows/create-wheels.yaml +++ b/.github/workflows/create-wheels.yaml @@ -61,15 +61,13 @@ jobs: (cat setup.cfg) | %{$_ -replace "tag_build.?=.?dev",""} | set-content setup.cfg - name: Create wheel - # create the wheel using --no-use-pep517 since locally we have pyproject - # this flag should be removed once sqlalchemy supports pep517 # `--no-deps` is used to only generate the wheel for the current library. Redundant in sqlalchemy since it has no dependencies run: | python -m pip install --upgrade pip pip --version pip install 'setuptools>=44' 'wheel>=0.34' pip list - pip wheel -w dist --no-use-pep517 -v --no-deps . + pip wheel -w dist -v --no-deps . - name: Install wheel # install the created wheel without using the pypi index @@ -164,11 +162,9 @@ jobs: with: # python-versions is the output of the previous step and is in the form <python tag>-<abi tag>. Eg cp27-cp27mu python-versions: ${{ matrix.python-version }} - build-requirements: "setuptools>=44 wheel>=0.34" - # Create the wheel using --no-use-pep517 since locally we have pyproject - # This flag should be removed once sqlalchemy supports pep517 + build-requirements: "setuptools>=47 wheel>=0.34 cython>=0.29.24" # `--no-deps` is used to only generate the wheel for the current library. Redundant in sqlalchemy since it has no dependencies - pip-wheel-args: "-w ./dist --no-use-pep517 -v --no-deps" + pip-wheel-args: "-w ./dist -v --no-deps" - name: Create wheel for manylinux2014 for py3 # this step uses the image provided by pypa here https://github.com/pypa/manylinux to generate the wheels on linux @@ -179,11 +175,9 @@ jobs: with: # python-versions is the output of the previous step and is in the form <python tag>-<abi tag>. 
Eg cp27-cp27mu python-versions: ${{ matrix.python-version }} - build-requirements: "setuptools>=44 wheel>=0.34" - # Create the wheel using --no-use-pep517 since locally we have pyproject - # This flag should be removed once sqlalchemy supports pep517 + build-requirements: "setuptools>=47 wheel>=0.34 cython>=0.29.24" # `--no-deps` is used to only generate the wheel for the current library. Redundant in sqlalchemy since it has no dependencies - pip-wheel-args: "-w ./dist --no-use-pep517 -v --no-deps" + pip-wheel-args: "-w ./dist -v --no-deps" - name: Set up Python uses: actions/setup-python@v2 @@ -278,11 +272,9 @@ jobs: with: # python-versions is the output of the previous step and is in the form <python tag>-<abi tag>. Eg cp37-cp37mu python-versions: ${{ matrix.python-version }} - build-requirements: "setuptools>=44 wheel>=0.34" - # Create the wheel using --no-use-pep517 since locally we have pyproject - # This flag should be removed once sqlalchemy supports pep517 + build-requirements: "setuptools>=47 wheel>=0.34 cython>=0.29.24" # `--no-deps` is used to only generate the wheel for the current library. Redundant in sqlalchemy since it has no dependencies - pip-wheel-args: "-w ./dist --no-use-pep517 -v --no-deps" + pip-wheel-args: "-w ./dist -v --no-deps" - name: Check created wheel # check that the wheel is compatible with the current installation. diff --git a/MANIFEST.in b/MANIFEST.in index 6d04f593c..0a2c923f1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,13 +5,9 @@ recursive-include doc *.html *.css *.txt *.js *.png *.py Makefile *.rst *.sty recursive-include examples *.py *.xml recursive-include test *.py *.dat *.testpatch -# include the c extensions, which otherwise +# include the pyx and pxd extensions, which otherwise # don't come in if --with-cextensions isn't specified. 
-recursive-include lib *.c *.txt +recursive-include lib *.pyx *.pxd *.txt include README* AUTHORS LICENSE CHANGES* tox.ini prune doc/build/output - -# don't include pyproject.toml until we -# have explicitly built a pep-517 backend -exclude pyproject.toml diff --git a/lib/sqlalchemy/cextension/immutabledict.c b/lib/sqlalchemy/cextension/immutabledict.c deleted file mode 100644 index 1188dcd2b..000000000 --- a/lib/sqlalchemy/cextension/immutabledict.c +++ /dev/null @@ -1,526 +0,0 @@ -/* -immuatbledict.c -Copyright (C) 2005-2021 the SQLAlchemy authors and contributors <see AUTHORS file> - -This module is part of SQLAlchemy and is released under -the MIT License: https://www.opensource.org/licenses/mit-license.php -*/ - -#include <Python.h> - -#define MODULE_NAME "cimmutabledict" -#define MODULE_DOC "immutable dictionary implementation" - - -typedef struct { - PyObject_HEAD - PyObject *dict; -} ImmutableDict; - -static PyTypeObject ImmutableDictType; - - -#if PY_MAJOR_VERSION < 3 -/* For Python 2.7, VENDORED from cPython: https://github.com/python/cpython/commit/1c496178d2c863f135bd4a43e32e0f099480cd06 - This function was added to Python 2.7.12 as an underscore function. - - Variant of PyDict_GetItem() that doesn't suppress exceptions. - This returns NULL *with* an exception set if an exception occurred. - It returns NULL *without* an exception set if the key wasn't present. 
-*/ -PyObject * -PyDict_GetItemWithError(PyObject *op, PyObject *key) -{ - long hash; - PyDictObject *mp = (PyDictObject *)op; - PyDictEntry *ep; - if (!PyDict_Check(op)) { - PyErr_BadInternalCall(); - return NULL; - } - if (!PyString_CheckExact(key) || - (hash = ((PyStringObject *) key)->ob_shash) == -1) - { - hash = PyObject_Hash(key); - if (hash == -1) { - return NULL; - } - } - - ep = (mp->ma_lookup)(mp, key, hash); - if (ep == NULL) { - return NULL; - } - return ep->me_value; -} -#endif - -static PyObject * - -ImmutableDict_new(PyTypeObject *type, PyObject *args, PyObject *kw) - -{ - ImmutableDict *new_obj; - PyObject *arg_dict = NULL; - PyObject *our_dict; - - if (!PyArg_UnpackTuple(args, "ImmutableDict", 0, 1, &arg_dict)) { - return NULL; - } - - if (arg_dict != NULL && PyDict_CheckExact(arg_dict)) { - // going on the unproven theory that doing PyDict_New + PyDict_Update - // is faster than just calling CallObject, as we do below to - // accommodate for other dictionary argument forms - our_dict = PyDict_New(); - if (our_dict == NULL) { - return NULL; - } - - if (PyDict_Update(our_dict, arg_dict) == -1) { - Py_DECREF(our_dict); - return NULL; - } - } - else { - // for other calling styles, let PyDict figure it out - our_dict = PyObject_Call((PyObject *) &PyDict_Type, args, kw); - } - - new_obj = PyObject_GC_New(ImmutableDict, &ImmutableDictType); - if (new_obj == NULL) { - Py_DECREF(our_dict); - return NULL; - } - new_obj->dict = our_dict; - PyObject_GC_Track(new_obj); - - return (PyObject *)new_obj; - -} - - -Py_ssize_t -ImmutableDict_length(ImmutableDict *self) -{ - return PyDict_Size(self->dict); -} - -static PyObject * -ImmutableDict_subscript(ImmutableDict *self, PyObject *key) -{ - PyObject *value; -#if PY_MAJOR_VERSION >= 3 - PyObject *err_bytes; -#endif - - value = PyDict_GetItemWithError(self->dict, key); - - if (value == NULL) { - if (PyErr_Occurred() != NULL) { - // there was an error while hashing the key - return NULL; - } -#if PY_MAJOR_VERSION 
>= 3 - err_bytes = PyUnicode_AsUTF8String(key); - if (err_bytes == NULL) - return NULL; - PyErr_Format(PyExc_KeyError, "%s", PyBytes_AS_STRING(err_bytes)); -#else - PyErr_Format(PyExc_KeyError, "%s", PyString_AsString(key)); -#endif - return NULL; - } - - Py_INCREF(value); - - return value; -} - - -static void -ImmutableDict_dealloc(ImmutableDict *self) -{ - PyObject_GC_UnTrack(self); - Py_XDECREF(self->dict); - PyObject_GC_Del(self); -} - - -static PyObject * -ImmutableDict_reduce(ImmutableDict *self) -{ - return Py_BuildValue("O(O)", Py_TYPE(self), self->dict); -} - - -static PyObject * -ImmutableDict_repr(ImmutableDict *self) -{ - return PyUnicode_FromFormat("immutabledict(%R)", self->dict); -} - - -static PyObject * -ImmutableDict_union(PyObject *self, PyObject *args, PyObject *kw) -{ - PyObject *arg_dict, *new_dict; - - ImmutableDict *new_obj; - - if (!PyArg_UnpackTuple(args, "ImmutableDict", 0, 1, &arg_dict)) { - return NULL; - } - - if (!PyDict_CheckExact(arg_dict)) { - // if we didn't get a dict, and got lists of tuples or - // keyword args, make a dict - arg_dict = PyObject_Call((PyObject *) &PyDict_Type, args, kw); - if (arg_dict == NULL) { - return NULL; - } - } - else { - // otherwise we will use the dict as is - Py_INCREF(arg_dict); - } - - if (PyDict_Size(arg_dict) == 0) { - Py_DECREF(arg_dict); - Py_INCREF(self); - return self; - } - - new_dict = PyDict_New(); - if (new_dict == NULL) { - Py_DECREF(arg_dict); - return NULL; - } - - if (PyDict_Update(new_dict, ((ImmutableDict *)self)->dict) == -1) { - Py_DECREF(arg_dict); - Py_DECREF(new_dict); - return NULL; - } - - if (PyDict_Update(new_dict, arg_dict) == -1) { - Py_DECREF(arg_dict); - Py_DECREF(new_dict); - return NULL; - } - - Py_DECREF(arg_dict); - - new_obj = PyObject_GC_New(ImmutableDict, Py_TYPE(self)); - if (new_obj == NULL) { - Py_DECREF(new_dict); - return NULL; - } - - new_obj->dict = new_dict; - - PyObject_GC_Track(new_obj); - - return (PyObject *)new_obj; -} - - -static PyObject * 
-ImmutableDict_merge_with(PyObject *self, PyObject *args) -{ - PyObject *element, *arg, *new_dict = NULL; - - ImmutableDict *new_obj; - - Py_ssize_t num_args = PyTuple_Size(args); - Py_ssize_t i; - - for (i=0; i<num_args; i++) { - element = PyTuple_GetItem(args, i); - - if (element == NULL) { - Py_XDECREF(new_dict); - return NULL; - } - else if (element == Py_None) { - // none was passed, skip it - continue; - } - - if (!PyDict_CheckExact(element)) { - // not a dict, try to make a dict - - arg = PyTuple_Pack(1, element); - - element = PyObject_CallObject((PyObject *) &PyDict_Type, arg); - - Py_DECREF(arg); - - if (element == NULL) { - Py_XDECREF(new_dict); - return NULL; - } - } - else { - Py_INCREF(element); - if (PyDict_Size(element) == 0) { - continue; - } - } - - // initialize a new dictionary only if we receive data that - // is not empty. otherwise we return self at the end. - if (new_dict == NULL) { - - new_dict = PyDict_New(); - if (new_dict == NULL) { - Py_DECREF(element); - return NULL; - } - - if (PyDict_Update(new_dict, ((ImmutableDict *)self)->dict) == -1) { - Py_DECREF(element); - Py_DECREF(new_dict); - return NULL; - } - } - - if (PyDict_Update(new_dict, element) == -1) { - Py_DECREF(element); - Py_DECREF(new_dict); - return NULL; - } - - Py_DECREF(element); - } - - - if (new_dict != NULL) { - new_obj = PyObject_GC_New(ImmutableDict, Py_TYPE(self)); - if (new_obj == NULL) { - Py_DECREF(new_dict); - return NULL; - } - - new_obj->dict = new_dict; - PyObject_GC_Track(new_obj); - return (PyObject *)new_obj; - } - else { - Py_INCREF(self); - return self; - } - -} - - -static PyObject * -ImmutableDict_get(ImmutableDict *self, PyObject *args) -{ - PyObject *key; - PyObject *value; - PyObject *default_value = Py_None; - - if (!PyArg_UnpackTuple(args, "key", 1, 2, &key, &default_value)) { - return NULL; - } - - value = PyDict_GetItemWithError(self->dict, key); - - if (value == NULL) { - if (PyErr_Occurred() != NULL) { - // there was an error while hashing the 
key - return NULL; - } else { - // return default - Py_INCREF(default_value); - return default_value; - } - } else { - Py_INCREF(value); - return value; - } -} - -static PyObject * -ImmutableDict_keys(ImmutableDict *self) -{ - return PyDict_Keys(self->dict); -} - -static int -ImmutableDict_traverse(ImmutableDict *self, visitproc visit, void *arg) -{ - Py_VISIT(self->dict); - return 0; -} - -static PyObject * -ImmutableDict_richcompare(ImmutableDict *self, PyObject *other, int op) -{ - return PyObject_RichCompare(self->dict, other, op); -} - -static PyObject * -ImmutableDict_iter(ImmutableDict *self) -{ - return PyObject_GetIter(self->dict); -} - -static PyObject * -ImmutableDict_items(ImmutableDict *self) -{ - return PyDict_Items(self->dict); -} - -static PyObject * -ImmutableDict_values(ImmutableDict *self) -{ - return PyDict_Values(self->dict); -} - -static PyObject * -ImmutableDict_contains(ImmutableDict *self, PyObject *key) -{ - int ret; - - ret = PyDict_Contains(self->dict, key); - - if (ret == 1) Py_RETURN_TRUE; - else if (ret == 0) Py_RETURN_FALSE; - else return NULL; -} - -static PyMethodDef ImmutableDict_methods[] = { - {"union", (PyCFunction) ImmutableDict_union, METH_VARARGS | METH_KEYWORDS, - "provide a union of this dictionary with the given dictionary-like arguments"}, - {"merge_with", (PyCFunction) ImmutableDict_merge_with, METH_VARARGS, - "provide a union of this dictionary with those given"}, - {"keys", (PyCFunction) ImmutableDict_keys, METH_NOARGS, - "return dictionary keys"}, - - {"__contains__",(PyCFunction)ImmutableDict_contains, METH_O, - "test a member for containment"}, - - {"items", (PyCFunction) ImmutableDict_items, METH_NOARGS, - "return dictionary items"}, - {"values", (PyCFunction) ImmutableDict_values, METH_NOARGS, - "return dictionary values"}, - {"get", (PyCFunction) ImmutableDict_get, METH_VARARGS, - "get a value"}, - {"__reduce__", (PyCFunction)ImmutableDict_reduce, METH_NOARGS, - "Pickle support method."}, - {NULL}, -}; - - 
-static PyMappingMethods ImmutableDict_as_mapping = { - (lenfunc)ImmutableDict_length, /* mp_length */ - (binaryfunc)ImmutableDict_subscript, /* mp_subscript */ - 0 /* mp_ass_subscript */ -}; - - - - -static PyTypeObject ImmutableDictType = { - PyVarObject_HEAD_INIT(NULL, 0) - "sqlalchemy.cimmutabledict.immutabledict", /* tp_name */ - sizeof(ImmutableDict), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)ImmutableDict_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - (reprfunc)ImmutableDict_repr, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - &ImmutableDict_as_mapping, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC , /* tp_flags */ - "immutable dictionary", /* tp_doc */ - (traverseproc)ImmutableDict_traverse, /* tp_traverse */ - 0, /* tp_clear */ - (richcmpfunc)ImmutableDict_richcompare, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - (getiterfunc)ImmutableDict_iter, /* tp_iter */ - 0, /* tp_iternext */ - ImmutableDict_methods, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - ImmutableDict_new, /* tp_new */ - 0, /* tp_free */ -}; - - - - - -static PyMethodDef module_methods[] = { - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ -#define PyMODINIT_FUNC void -#endif - - -#if PY_MAJOR_VERSION >= 3 - -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - MODULE_NAME, - MODULE_DOC, - -1, - module_methods -}; - -#define INITERROR return NULL - -PyMODINIT_FUNC -PyInit_cimmutabledict(void) - -#else - -#define INITERROR return - -PyMODINIT_FUNC -initcimmutabledict(void) - -#endif - -{ - PyObject *m; - - if 
(PyType_Ready(&ImmutableDictType) < 0) - INITERROR; - - -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&module_def); -#else - m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); -#endif - if (m == NULL) - INITERROR; - - Py_INCREF(&ImmutableDictType); - PyModule_AddObject(m, "immutabledict", (PyObject *)&ImmutableDictType); - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} diff --git a/lib/sqlalchemy/cextension/processors.c b/lib/sqlalchemy/cextension/processors.c deleted file mode 100644 index 8c031b70a..000000000 --- a/lib/sqlalchemy/cextension/processors.c +++ /dev/null @@ -1,508 +0,0 @@ -/* -processors.c -Copyright (C) 2010-2021 the SQLAlchemy authors and contributors <see AUTHORS file> -Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com - -This module is part of SQLAlchemy and is released under -the MIT License: https://www.opensource.org/licenses/mit-license.php -*/ - -#include <Python.h> -#include <datetime.h> - -#define MODULE_NAME "cprocessors" -#define MODULE_DOC "Module containing C versions of data processing functions." 
- -#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) -typedef int Py_ssize_t; -#define PY_SSIZE_T_MAX INT_MAX -#define PY_SSIZE_T_MIN INT_MIN -#endif - -static PyObject * -int_to_boolean(PyObject *self, PyObject *arg) -{ - int l = 0; - PyObject *res; - - if (arg == Py_None) - Py_RETURN_NONE; - - l = PyObject_IsTrue(arg); - if (l == 0) { - res = Py_False; - } else if (l == 1) { - res = Py_True; - } else { - return NULL; - } - - Py_INCREF(res); - return res; -} - -static PyObject * -to_str(PyObject *self, PyObject *arg) -{ - if (arg == Py_None) - Py_RETURN_NONE; - - return PyObject_Str(arg); -} - -static PyObject * -to_float(PyObject *self, PyObject *arg) -{ - if (arg == Py_None) - Py_RETURN_NONE; - - return PyNumber_Float(arg); -} - -static PyObject * -str_to_datetime(PyObject *self, PyObject *arg) -{ -#if PY_MAJOR_VERSION >= 3 - PyObject *bytes; - PyObject *err_bytes; -#endif - const char *str; - int numparsed; - unsigned int year, month, day, hour, minute, second, microsecond = 0; - PyObject *err_repr; - - if (arg == Py_None) - Py_RETURN_NONE; - -#if PY_MAJOR_VERSION >= 3 - bytes = PyUnicode_AsASCIIString(arg); - if (bytes == NULL) - str = NULL; - else - str = PyBytes_AS_STRING(bytes); -#else - str = PyString_AsString(arg); -#endif - if (str == NULL) { - err_repr = PyObject_Repr(arg); - if (err_repr == NULL) - return NULL; -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(err_repr); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_ValueError, - "Couldn't parse datetime string '%.200s' " - "- value is not a string.", - PyBytes_AS_STRING(err_bytes)); - Py_DECREF(err_bytes); -#else - PyErr_Format( - PyExc_ValueError, - "Couldn't parse datetime string '%.200s' " - "- value is not a string.", - PyString_AsString(err_repr)); -#endif - Py_DECREF(err_repr); - return NULL; - } - - /* microseconds are optional */ - /* - TODO: this is slightly less picky than the Python version which would - not accept "2000-01-01 00:00:00.". 
I don't know which is better, but they - should be coherent. - */ - numparsed = sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day, - &hour, &minute, &second, &microsecond); -#if PY_MAJOR_VERSION >= 3 - Py_DECREF(bytes); -#endif - if (numparsed < 6) { - err_repr = PyObject_Repr(arg); - if (err_repr == NULL) - return NULL; -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(err_repr); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_ValueError, - "Couldn't parse datetime string: %.200s", - PyBytes_AS_STRING(err_bytes)); - Py_DECREF(err_bytes); -#else - PyErr_Format( - PyExc_ValueError, - "Couldn't parse datetime string: %.200s", - PyString_AsString(err_repr)); -#endif - Py_DECREF(err_repr); - return NULL; - } - return PyDateTime_FromDateAndTime(year, month, day, - hour, minute, second, microsecond); -} - -static PyObject * -str_to_time(PyObject *self, PyObject *arg) -{ -#if PY_MAJOR_VERSION >= 3 - PyObject *bytes; - PyObject *err_bytes; -#endif - const char *str; - int numparsed; - unsigned int hour, minute, second, microsecond = 0; - PyObject *err_repr; - - if (arg == Py_None) - Py_RETURN_NONE; - -#if PY_MAJOR_VERSION >= 3 - bytes = PyUnicode_AsASCIIString(arg); - if (bytes == NULL) - str = NULL; - else - str = PyBytes_AS_STRING(bytes); -#else - str = PyString_AsString(arg); -#endif - if (str == NULL) { - err_repr = PyObject_Repr(arg); - if (err_repr == NULL) - return NULL; - -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(err_repr); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_ValueError, - "Couldn't parse time string '%.200s' - value is not a string.", - PyBytes_AS_STRING(err_bytes)); - Py_DECREF(err_bytes); -#else - PyErr_Format( - PyExc_ValueError, - "Couldn't parse time string '%.200s' - value is not a string.", - PyString_AsString(err_repr)); -#endif - Py_DECREF(err_repr); - return NULL; - } - - /* microseconds are optional */ - /* - TODO: this is slightly less picky than the Python
version which would - not accept "00:00:00.". I don't know which is better, but they should be - coherent. - */ - numparsed = sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second, - &microsecond); -#if PY_MAJOR_VERSION >= 3 - Py_DECREF(bytes); -#endif - if (numparsed < 3) { - err_repr = PyObject_Repr(arg); - if (err_repr == NULL) - return NULL; -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(err_repr); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_ValueError, - "Couldn't parse time string: %.200s", - PyBytes_AS_STRING(err_bytes)); - Py_DECREF(err_bytes); -#else - PyErr_Format( - PyExc_ValueError, - "Couldn't parse time string: %.200s", - PyString_AsString(err_repr)); -#endif - Py_DECREF(err_repr); - return NULL; - } - return PyTime_FromTime(hour, minute, second, microsecond); -} - -static PyObject * -str_to_date(PyObject *self, PyObject *arg) -{ -#if PY_MAJOR_VERSION >= 3 - PyObject *bytes; - PyObject *err_bytes; -#endif - const char *str; - int numparsed; - unsigned int year, month, day; - PyObject *err_repr; - - if (arg == Py_None) - Py_RETURN_NONE; - -#if PY_MAJOR_VERSION >= 3 - bytes = PyUnicode_AsASCIIString(arg); - if (bytes == NULL) - str = NULL; - else - str = PyBytes_AS_STRING(bytes); -#else - str = PyString_AsString(arg); -#endif - if (str == NULL) { - err_repr = PyObject_Repr(arg); - if (err_repr == NULL) - return NULL; -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(err_repr); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_ValueError, - "Couldn't parse date string '%.200s' - value is not a string.", - PyBytes_AS_STRING(err_bytes)); - Py_DECREF(err_bytes); -#else - PyErr_Format( - PyExc_ValueError, - "Couldn't parse date string '%.200s' - value is not a string.", - PyString_AsString(err_repr)); -#endif - Py_DECREF(err_repr); - return NULL; - } - - numparsed = sscanf(str, "%4u-%2u-%2u", &year, &month, &day); -#if PY_MAJOR_VERSION >= 3 - Py_DECREF(bytes); -#endif - if (numparsed != 3) { -
err_repr = PyObject_Repr(arg); - if (err_repr == NULL) - return NULL; -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(err_repr); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_ValueError, - "Couldn't parse date string: %.200s", - PyBytes_AS_STRING(err_bytes)); - Py_DECREF(err_bytes); -#else - PyErr_Format( - PyExc_ValueError, - "Couldn't parse date string: %.200s", - PyString_AsString(err_repr)); -#endif - Py_DECREF(err_repr); - return NULL; - } - return PyDate_FromDate(year, month, day); -} - - -/*********** - * Structs * - ***********/ - -typedef struct { - PyObject_HEAD - PyObject *type; - PyObject *format; -} DecimalResultProcessor; - - - - -/************************** - * DecimalResultProcessor * - **************************/ - -static int -DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args, - PyObject *kwds) -{ - PyObject *type, *format; - -#if PY_MAJOR_VERSION >= 3 - if (!PyArg_ParseTuple(args, "OU", &type, &format)) -#else - if (!PyArg_ParseTuple(args, "OS", &type, &format)) -#endif - return -1; - - Py_INCREF(type); - self->type = type; - - Py_INCREF(format); - self->format = format; - - return 0; -} - -static PyObject * -DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value) -{ - PyObject *str, *result, *args; - - if (value == Py_None) - Py_RETURN_NONE; - - /* Decimal does not accept float values directly */ - /* SQLite can also give us an integer here (see [ticket:2432]) */ - /* XXX: starting with Python 3.1, we could use Decimal.from_float(f), - but the result wouldn't be the same */ - - args = PyTuple_Pack(1, value); - if (args == NULL) - return NULL; - -#if PY_MAJOR_VERSION >= 3 - str = PyUnicode_Format(self->format, args); -#else - str = PyString_Format(self->format, args); -#endif - - Py_DECREF(args); - if (str == NULL) - return NULL; - - result = PyObject_CallFunctionObjArgs(self->type, str, NULL); - Py_DECREF(str); - return result; -} - -static void 
-DecimalResultProcessor_dealloc(DecimalResultProcessor *self) -{ - Py_XDECREF(self->type); - Py_XDECREF(self->format); -#if PY_MAJOR_VERSION >= 3 - Py_TYPE(self)->tp_free((PyObject*)self); -#else - self->ob_type->tp_free((PyObject*)self); -#endif -} - -static PyMethodDef DecimalResultProcessor_methods[] = { - {"process", (PyCFunction)DecimalResultProcessor_process, METH_O, - "The value processor itself."}, - {NULL} /* Sentinel */ -}; - -static PyTypeObject DecimalResultProcessorType = { - PyVarObject_HEAD_INIT(NULL, 0) - "sqlalchemy.DecimalResultProcessor", /* tp_name */ - sizeof(DecimalResultProcessor), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)DecimalResultProcessor_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ - "DecimalResultProcessor objects", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - DecimalResultProcessor_methods, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)DecimalResultProcessor_init, /* tp_init */ - 0, /* tp_alloc */ - 0, /* tp_new */ -}; - -static PyMethodDef module_methods[] = { - {"int_to_boolean", int_to_boolean, METH_O, - "Convert an integer to a boolean."}, - {"to_str", to_str, METH_O, - "Convert any value to its string representation."}, - {"to_float", to_float, METH_O, - "Convert any value to its floating point representation."}, - {"str_to_datetime", str_to_datetime, METH_O, - "Convert an ISO string to a datetime.datetime object."}, - 
{"str_to_time", str_to_time, METH_O, - "Convert an ISO string to a datetime.time object."}, - {"str_to_date", str_to_date, METH_O, - "Convert an ISO string to a datetime.date object."}, - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ -#define PyMODINIT_FUNC void -#endif - - -#if PY_MAJOR_VERSION >= 3 - -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - MODULE_NAME, - MODULE_DOC, - -1, - module_methods -}; - -#define INITERROR return NULL - -PyMODINIT_FUNC -PyInit_cprocessors(void) - -#else - -#define INITERROR return - -PyMODINIT_FUNC -initcprocessors(void) - -#endif - -{ - PyObject *m; - - DecimalResultProcessorType.tp_new = PyType_GenericNew; - if (PyType_Ready(&DecimalResultProcessorType) < 0) - INITERROR; - -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&module_def); -#else - m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); -#endif - if (m == NULL) - INITERROR; - - PyDateTime_IMPORT; - - Py_INCREF(&DecimalResultProcessorType); - PyModule_AddObject(m, "DecimalResultProcessor", - (PyObject *)&DecimalResultProcessorType); - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c deleted file mode 100644 index 99b2d36f3..000000000 --- a/lib/sqlalchemy/cextension/resultproxy.c +++ /dev/null @@ -1,1033 +0,0 @@ -/* -resultproxy.c -Copyright (C) 2010-2021 the SQLAlchemy authors and contributors <see AUTHORS file> -Copyright (C) 2010-2011 Gaetan de Menten gdementen@gmail.com - -This module is part of SQLAlchemy and is released under -the MIT License: https://www.opensource.org/licenses/mit-license.php -*/ - -#include <Python.h> - -#define MODULE_NAME "cresultproxy" -#define MODULE_DOC "Module containing C versions of core ResultProxy classes." 
- -#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) -typedef int Py_ssize_t; -#define PY_SSIZE_T_MAX INT_MAX -#define PY_SSIZE_T_MIN INT_MIN -typedef Py_ssize_t (*lenfunc)(PyObject *); -#define PyInt_FromSsize_t(x) PyInt_FromLong(x) -typedef intargfunc ssizeargfunc; -#endif - -#if PY_MAJOR_VERSION < 3 - -// new typedef in Python 3 -typedef long Py_hash_t; - -// from pymacro.h, new in Python 3.2 -#if defined(__GNUC__) || defined(__clang__) -# define Py_UNUSED(name) _unused_ ## name __attribute__((unused)) -#else -# define Py_UNUSED(name) _unused_ ## name -#endif - -#endif - - -/*********** - * Structs * - ***********/ - -typedef struct { - PyObject_HEAD - PyObject *parent; - PyObject *row; - PyObject *keymap; - long key_style; -} BaseRow; - - -static PyObject *sqlalchemy_engine_row = NULL; -static PyObject *sqlalchemy_engine_result = NULL; - - -static int KEY_INTEGER_ONLY = 0; -static int KEY_OBJECTS_ONLY = 1; - -/**************** - * BaseRow * - ****************/ - -static PyObject * -safe_rowproxy_reconstructor(PyObject *self, PyObject *args) -{ - PyObject *cls, *state, *tmp; - BaseRow *obj; - - if (!PyArg_ParseTuple(args, "OO", &cls, &state)) - return NULL; - - obj = (BaseRow *)PyObject_CallMethod(cls, "__new__", "O", cls); - if (obj == NULL) - return NULL; - - tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state); - if (tmp == NULL) { - Py_DECREF(obj); - return NULL; - } - Py_DECREF(tmp); - - if (obj->parent == NULL || obj->row == NULL || - obj->keymap == NULL) { - PyErr_SetString(PyExc_RuntimeError, - "__setstate__ for BaseRow subclasses must set values " - "for parent, row, and keymap"); - Py_DECREF(obj); - return NULL; - } - - return (PyObject *)obj; -} - -static int -BaseRow_init(BaseRow *self, PyObject *args, PyObject *kwds) -{ - PyObject *parent, *keymap, *row, *processors, *key_style; - Py_ssize_t num_values, num_processors; - PyObject **valueptr, **funcptr, **resultptr; - PyObject *func, *result, *processed_value, 
*values_fastseq; - - if (!PyArg_UnpackTuple(args, "BaseRow", 5, 5, - &parent, &processors, &keymap, &key_style, &row)) - return -1; - - Py_INCREF(parent); - self->parent = parent; - - values_fastseq = PySequence_Fast(row, "row must be a sequence"); - if (values_fastseq == NULL) - return -1; - - num_values = PySequence_Length(values_fastseq); - - - if (processors != Py_None) { - num_processors = PySequence_Size(processors); - if (num_values != num_processors) { - PyErr_Format(PyExc_RuntimeError, - "number of values in row (%d) differ from number of column " - "processors (%d)", - (int)num_values, (int)num_processors); - return -1; - } - - } else { - num_processors = -1; - } - - result = PyTuple_New(num_values); - if (result == NULL) - return -1; - - if (num_processors != -1) { - valueptr = PySequence_Fast_ITEMS(values_fastseq); - funcptr = PySequence_Fast_ITEMS(processors); - resultptr = PySequence_Fast_ITEMS(result); - while (--num_values >= 0) { - func = *funcptr; - if (func != Py_None) { - processed_value = PyObject_CallFunctionObjArgs( - func, *valueptr, NULL); - if (processed_value == NULL) { - Py_DECREF(values_fastseq); - Py_DECREF(result); - return -1; - } - *resultptr = processed_value; - } else { - Py_INCREF(*valueptr); - *resultptr = *valueptr; - } - valueptr++; - funcptr++; - resultptr++; - } - } else { - valueptr = PySequence_Fast_ITEMS(values_fastseq); - resultptr = PySequence_Fast_ITEMS(result); - while (--num_values >= 0) { - Py_INCREF(*valueptr); - *resultptr = *valueptr; - valueptr++; - resultptr++; - } - } - - Py_DECREF(values_fastseq); - self->row = result; - - if (!PyDict_CheckExact(keymap)) { - PyErr_SetString(PyExc_TypeError, "keymap must be a dict"); - return -1; - } - Py_INCREF(keymap); - self->keymap = keymap; - self->key_style = PyLong_AsLong(key_style); - - // observation: because we have not implemented our own new method, - // cPython is apparently already calling PyObject_GC_Track for us. 
- // We assume it also called PyObject_GC_New since prior to #5348 we - // were already relying upon it to call PyObject_New, and we have now - // set Py_TPFLAGS_HAVE_GC. - - return 0; -} - -static int -BaseRow_traverse(BaseRow *self, visitproc visit, void *arg) -{ - Py_VISIT(self->parent); - Py_VISIT(self->row); - Py_VISIT(self->keymap); - return 0; -} - -/* We need the reduce method because otherwise the default implementation - * does very weird stuff for pickle protocol 0 and 1. It calls - * BaseRow.__new__(Row_instance) upon *pickling*. - */ -static PyObject * -BaseRow_reduce(PyObject *self) -{ - PyObject *method, *state; - PyObject *module, *reconstructor, *cls; - - method = PyObject_GetAttrString(self, "__getstate__"); - if (method == NULL) - return NULL; - - state = PyObject_CallObject(method, NULL); - Py_DECREF(method); - if (state == NULL) - return NULL; - - if (sqlalchemy_engine_row == NULL) { - module = PyImport_ImportModule("sqlalchemy.engine.row"); - if (module == NULL) - return NULL; - sqlalchemy_engine_row = module; - } - - reconstructor = PyObject_GetAttrString(sqlalchemy_engine_row, "rowproxy_reconstructor"); - if (reconstructor == NULL) { - Py_DECREF(state); - return NULL; - } - - cls = PyObject_GetAttrString(self, "__class__"); - if (cls == NULL) { - Py_DECREF(reconstructor); - Py_DECREF(state); - return NULL; - } - - return Py_BuildValue("(N(NN))", reconstructor, cls, state); -} - -static PyObject * -BaseRow_filter_on_values(BaseRow *self, PyObject *filters) -{ - PyObject *module, *row_class, *new_obj, *key_style; - - if (sqlalchemy_engine_row == NULL) { - module = PyImport_ImportModule("sqlalchemy.engine.row"); - if (module == NULL) - return NULL; - sqlalchemy_engine_row = module; - } - - // TODO: do we want to get self.__class__ instead here? 
I'm not sure - // how to use METH_VARARGS and then also get the BaseRow struct - // at the same time - row_class = PyObject_GetAttrString(sqlalchemy_engine_row, "Row"); - - key_style = PyLong_FromLong(self->key_style); - - new_obj = PyObject_CallFunction( - row_class, "OOOOO", self->parent, filters, self->keymap, - key_style, self->row); - Py_DECREF(key_style); - Py_DECREF(row_class); - if (new_obj == NULL) { - return NULL; - } - - return new_obj; - -} - -static void -BaseRow_dealloc(BaseRow *self) -{ - PyObject_GC_UnTrack(self); - Py_XDECREF(self->parent); - Py_XDECREF(self->row); - Py_XDECREF(self->keymap); - PyObject_GC_Del(self); - -} - -static PyObject * -BaseRow_valuescollection(PyObject *values, int astuple) -{ - PyObject *result; - - if (astuple) { - result = PySequence_Tuple(values); - } else { - result = PySequence_List(values); - } - if (result == NULL) - return NULL; - - return result; -} - -static PyListObject * -BaseRow_values_impl(BaseRow *self) -{ - return (PyListObject *)BaseRow_valuescollection(self->row, 0); -} - -static Py_hash_t -BaseRow_hash(BaseRow *self) -{ - return PyObject_Hash(self->row); -} - -static PyObject * -BaseRow_iter(BaseRow *self) -{ - PyObject *values, *result; - - values = BaseRow_valuescollection(self->row, 1); - if (values == NULL) - return NULL; - - result = PyObject_GetIter(values); - Py_DECREF(values); - if (result == NULL) - return NULL; - - return result; -} - -static Py_ssize_t -BaseRow_length(BaseRow *self) -{ - return PySequence_Length(self->row); -} - -static PyObject * -BaseRow_getitem(BaseRow *self, Py_ssize_t i) -{ - PyObject *value; - PyObject *row; - - row = self->row; - - // row is a Tuple - value = PyTuple_GetItem(row, i); - - if (value == NULL) - return NULL; - - Py_INCREF(value); - - return value; -} - -static PyObject * -BaseRow_getitem_by_object(BaseRow *self, PyObject *key, int asmapping) -{ - PyObject *record, *indexobject; - long index; - int key_fallback = 0; - - // we want to raise TypeError for 
slice access on a mapping. - // Py3 will do this with PyDict_GetItemWithError, Py2 will do it - // with PyObject_GetItem. However in the Python2 case the object - // protocol gets in the way for reasons not entirely clear, so - // detect slice we have a key error and raise directly. - - record = PyDict_GetItem((PyObject *)self->keymap, key); - - if (record == NULL) { - if (PySlice_Check(key)) { - PyErr_Format(PyExc_TypeError, "can't use slices for mapping access"); - return NULL; - } - record = PyObject_CallMethod(self->parent, "_key_fallback", - "OO", key, Py_None); - if (record == NULL) - return NULL; - - key_fallback = 1; // boolean to indicate record is a new reference - } - - indexobject = PyTuple_GetItem(record, 0); - if (indexobject == NULL) - return NULL; - - if (key_fallback) { - Py_DECREF(record); - } - - if (indexobject == Py_None) { - PyObject *tmp; - - tmp = PyObject_CallMethod(self->parent, "_raise_for_ambiguous_column_name", "(O)", record); - if (tmp == NULL) { - return NULL; - } - Py_DECREF(tmp); - - return NULL; - } - -#if PY_MAJOR_VERSION >= 3 - index = PyLong_AsLong(indexobject); -#else - index = PyInt_AsLong(indexobject); -#endif - if ((index == -1) && PyErr_Occurred()) - /* -1 can be either the actual value, or an error flag. */ - return NULL; - - return BaseRow_getitem(self, index); - -} - -static PyObject * -BaseRow_subscript_impl(BaseRow *self, PyObject *key, int asmapping) -{ - PyObject *values; - PyObject *result; - long index; - PyObject *tmp; - -#if PY_MAJOR_VERSION < 3 - if (PyInt_CheckExact(key)) { - if (self->key_style == KEY_OBJECTS_ONLY) { - // TODO: being very lazy with error catching here - PyErr_Format(PyExc_KeyError, "%s", PyString_AsString(PyObject_Repr(key))); - return NULL; - } - index = PyInt_AS_LONG(key); - - // support negative indexes. We can also call PySequence_GetItem, - // but here we can stay with the simpler tuple protocol - // rather than the sequence protocol which has to check for - // __getitem__ methods etc. 
- if (index < 0) - index += (long)BaseRow_length(self); - return BaseRow_getitem(self, index); - } else -#endif - - if (PyLong_CheckExact(key)) { - if (self->key_style == KEY_OBJECTS_ONLY) { -#if PY_MAJOR_VERSION < 3 - // TODO: being very lazy with error catching here - PyErr_Format(PyExc_KeyError, "%s", PyString_AsString(PyObject_Repr(key))); -#else - PyErr_Format(PyExc_KeyError, "%R", key); -#endif - return NULL; - } - index = PyLong_AsLong(key); - if ((index == -1) && PyErr_Occurred() != NULL) - /* -1 can be either the actual value, or an error flag. */ - return NULL; - - // support negative indexes. We can also call PySequence_GetItem, - // but here we can stay with the simpler tuple protocol - // rather than the sequence protocol which has to check for - // __getitem__ methods etc. - if (index < 0) - index += (long)BaseRow_length(self); - return BaseRow_getitem(self, index); - - } else if (PySlice_Check(key) && self->key_style != KEY_OBJECTS_ONLY) { - values = PyObject_GetItem(self->row, key); - if (values == NULL) - return NULL; - - result = BaseRow_valuescollection(values, 1); - Py_DECREF(values); - return result; - } - else if (!asmapping && self->key_style == KEY_INTEGER_ONLY) { - tmp = PyObject_CallMethod(self->parent, "_raise_for_nonint", "(O)", key); - if (tmp == NULL) { - return NULL; - } - Py_DECREF(tmp); - return NULL; - } else { - return BaseRow_getitem_by_object(self, key, asmapping); - } -} - -static PyObject * -BaseRow_subscript(BaseRow *self, PyObject *key) -{ - return BaseRow_subscript_impl(self, key, 0); -} - -static PyObject * -BaseRow_subscript_mapping(BaseRow *self, PyObject *key) -{ - if (self->key_style == KEY_INTEGER_ONLY) { - return BaseRow_subscript_impl(self, key, 0); - } - else { - return BaseRow_subscript_impl(self, key, 1); - } -} - - -static PyObject * -BaseRow_getattro(BaseRow *self, PyObject *name) -{ - PyObject *tmp; -#if PY_MAJOR_VERSION >= 3 - PyObject *err_bytes; -#endif - - if (!(tmp = PyObject_GenericGetAttr((PyObject 
*)self, name))) { - if (!PyErr_ExceptionMatches(PyExc_AttributeError)) - return NULL; - PyErr_Clear(); - } - else - return tmp; - - tmp = BaseRow_subscript_impl(self, name, 1); - - if (tmp == NULL && PyErr_ExceptionMatches(PyExc_KeyError)) { - -#if PY_MAJOR_VERSION >= 3 - err_bytes = PyUnicode_AsASCIIString(name); - if (err_bytes == NULL) - return NULL; - PyErr_Format( - PyExc_AttributeError, - "Could not locate column in row for column '%.200s'", - PyBytes_AS_STRING(err_bytes) - ); -#else - PyErr_Format( - PyExc_AttributeError, - "Could not locate column in row for column '%.200s'", - PyString_AsString(name) - ); -#endif - return NULL; - } - return tmp; -} - -/*********************** - * getters and setters * - ***********************/ - -static PyObject * -BaseRow_getparent(BaseRow *self, void *closure) -{ - Py_INCREF(self->parent); - return self->parent; -} - -static int -BaseRow_setparent(BaseRow *self, PyObject *value, void *closure) -{ - PyObject *module, *cls; - - if (value == NULL) { - PyErr_SetString(PyExc_TypeError, - "Cannot delete the 'parent' attribute"); - return -1; - } - - if (sqlalchemy_engine_result == NULL) { - module = PyImport_ImportModule("sqlalchemy.engine.result"); - if (module == NULL) - return -1; - sqlalchemy_engine_result = module; - } - - cls = PyObject_GetAttrString(sqlalchemy_engine_result, "ResultMetaData"); - if (cls == NULL) - return -1; - - if (PyObject_IsInstance(value, cls) != 1) { - PyErr_SetString(PyExc_TypeError, - "The 'parent' attribute value must be an instance of " - "ResultMetaData"); - return -1; - } - Py_DECREF(cls); - Py_XDECREF(self->parent); - Py_INCREF(value); - self->parent = value; - - return 0; -} - -static PyObject * -BaseRow_getrow(BaseRow *self, void *closure) -{ - Py_INCREF(self->row); - return self->row; -} - -static int -BaseRow_setrow(BaseRow *self, PyObject *value, void *closure) -{ - if (value == NULL) { - PyErr_SetString(PyExc_TypeError, - "Cannot delete the 'row' attribute"); - return -1; - } - - if 
(!PySequence_Check(value)) { - PyErr_SetString(PyExc_TypeError, - "The 'row' attribute value must be a sequence"); - return -1; - } - - Py_XDECREF(self->row); - Py_INCREF(value); - self->row = value; - - return 0; -} - - - -static PyObject * -BaseRow_getkeymap(BaseRow *self, void *closure) -{ - Py_INCREF(self->keymap); - return self->keymap; -} - -static int -BaseRow_setkeymap(BaseRow *self, PyObject *value, void *closure) -{ - if (value == NULL) { - PyErr_SetString(PyExc_TypeError, - "Cannot delete the 'keymap' attribute"); - return -1; - } - - if (!PyDict_CheckExact(value)) { - PyErr_SetString(PyExc_TypeError, - "The 'keymap' attribute value must be a dict"); - return -1; - } - - Py_XDECREF(self->keymap); - Py_INCREF(value); - self->keymap = value; - - return 0; -} - -static PyObject * -BaseRow_getkeystyle(BaseRow *self, void *closure) -{ - PyObject *result; - - result = PyLong_FromLong(self->key_style); - Py_INCREF(result); - return result; -} - - -static int -BaseRow_setkeystyle(BaseRow *self, PyObject *value, void *closure) -{ - if (value == NULL) { - - PyErr_SetString( - PyExc_TypeError, - "Cannot delete the 'key_style' attribute"); - return -1; - } - - if (!PyLong_CheckExact(value)) { - PyErr_SetString( - PyExc_TypeError, - "The 'key_style' attribute value must be an integer"); - return -1; - } - - self->key_style = PyLong_AsLong(value); - - return 0; -} - -static PyGetSetDef BaseRow_getseters[] = { - {"_parent", - (getter)BaseRow_getparent, (setter)BaseRow_setparent, - "ResultMetaData", - NULL}, - {"_data", - (getter)BaseRow_getrow, (setter)BaseRow_setrow, - "processed data list", - NULL}, - {"_keymap", - (getter)BaseRow_getkeymap, (setter)BaseRow_setkeymap, - "Key to (obj, index) dict", - NULL}, - {"_key_style", - (getter)BaseRow_getkeystyle, (setter)BaseRow_setkeystyle, - "Return the key style", - NULL}, - {NULL} -}; - -static PyMethodDef BaseRow_methods[] = { - {"_values_impl", (PyCFunction)BaseRow_values_impl, METH_NOARGS, - "Return the values 
represented by this BaseRow as a list."}, - {"__reduce__", (PyCFunction)BaseRow_reduce, METH_NOARGS, - "Pickle support method."}, - {"_get_by_key_impl", (PyCFunction)BaseRow_subscript, METH_O, - "implement mapping-like getitem as well as sequence getitem"}, - {"_get_by_key_impl_mapping", (PyCFunction)BaseRow_subscript_mapping, METH_O, - "implement mapping-like getitem as well as sequence getitem"}, - {"_filter_on_values", (PyCFunction)BaseRow_filter_on_values, METH_O, - "return a new Row with per-value filters applied to columns"}, - - {NULL} /* Sentinel */ -}; - -// currently, the sq_item hook is not used by Python except for slices, -// because we also implement subscript_mapping which seems to intercept -// integers. Ideally, when there -// is a complete separation of "row" from "mapping", we can make -// two separate types here so that one has only sq_item and the other -// has only mp_subscript. -static PySequenceMethods BaseRow_as_sequence = { - (lenfunc)BaseRow_length, /* sq_length */ - 0, /* sq_concat */ - 0, /* sq_repeat */ - (ssizeargfunc)BaseRow_getitem, /* sq_item */ - 0, /* sq_slice */ - 0, /* sq_ass_item */ - 0, /* sq_ass_slice */ - 0, /* sq_contains */ - 0, /* sq_inplace_concat */ - 0, /* sq_inplace_repeat */ -}; - -static PyMappingMethods BaseRow_as_mapping = { - (lenfunc)BaseRow_length, /* mp_length */ - (binaryfunc)BaseRow_subscript_mapping, /* mp_subscript */ - 0 /* mp_ass_subscript */ -}; - -static PyTypeObject BaseRowType = { - PyVarObject_HEAD_INIT(NULL, 0) - "sqlalchemy.cresultproxy.BaseRow", /* tp_name */ - sizeof(BaseRow), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)BaseRow_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_compare */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - &BaseRow_as_sequence, /* tp_as_sequence */ - &BaseRow_as_mapping, /* tp_as_mapping */ - (hashfunc)BaseRow_hash, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - (getattrofunc)BaseRow_getattro,/* 
tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC, /* tp_flags */ - "BaseRow is a abstract base class for Row", /* tp_doc */ - (traverseproc)BaseRow_traverse, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - (getiterfunc)BaseRow_iter, /* tp_iter */ - 0, /* tp_iternext */ - BaseRow_methods, /* tp_methods */ - 0, /* tp_members */ - BaseRow_getseters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)BaseRow_init, /* tp_init */ - 0, /* tp_alloc */ - 0 /* tp_new */ -}; - - - -/* _tuplegetter function ************************************************/ -/* -retrieves segments of a row as tuples. - -mostly like operator.itemgetter but calls a fixed method instead, -returns tuple every time. - -*/ - -typedef struct { - PyObject_HEAD - Py_ssize_t nitems; - PyObject *item; -} tuplegetterobject; - -static PyTypeObject tuplegetter_type; - -static PyObject * -tuplegetter_new(PyTypeObject *type, PyObject *args, PyObject *kwds) -{ - tuplegetterobject *tg; - PyObject *item; - Py_ssize_t nitems; - - if (!_PyArg_NoKeywords("tuplegetter", kwds)) - return NULL; - - nitems = PyTuple_GET_SIZE(args); - item = args; - - tg = PyObject_GC_New(tuplegetterobject, &tuplegetter_type); - if (tg == NULL) - return NULL; - - Py_INCREF(item); - tg->item = item; - tg->nitems = nitems; - PyObject_GC_Track(tg); - return (PyObject *)tg; -} - -static void -tuplegetter_dealloc(tuplegetterobject *tg) -{ - PyObject_GC_UnTrack(tg); - Py_XDECREF(tg->item); - PyObject_GC_Del(tg); -} - -static int -tuplegetter_traverse(tuplegetterobject *tg, visitproc visit, void *arg) -{ - Py_VISIT(tg->item); - return 0; -} - -static PyObject * -tuplegetter_call(tuplegetterobject *tg, PyObject *args, PyObject *kw) -{ - PyObject *row_or_tuple, *result; - Py_ssize_t i, nitems=tg->nitems; - int has_row_method; - - 
assert(PyTuple_CheckExact(args)); - - // this is a tuple, however if its a BaseRow subclass we want to - // call specific methods to bypass the pure python LegacyRow.__getitem__ - // method for now - row_or_tuple = PyTuple_GET_ITEM(args, 0); - - has_row_method = PyObject_HasAttrString(row_or_tuple, "_get_by_key_impl_mapping"); - - assert(PyTuple_Check(tg->item)); - assert(PyTuple_GET_SIZE(tg->item) == nitems); - - result = PyTuple_New(nitems); - if (result == NULL) - return NULL; - - for (i=0 ; i < nitems ; i++) { - PyObject *item, *val; - item = PyTuple_GET_ITEM(tg->item, i); - - if (has_row_method) { - val = PyObject_CallMethod(row_or_tuple, "_get_by_key_impl_mapping", "O", item); - } - else { - val = PyObject_GetItem(row_or_tuple, item); - } - - if (val == NULL) { - Py_DECREF(result); - return NULL; - } - PyTuple_SET_ITEM(result, i, val); - } - return result; -} - -static PyObject * -tuplegetter_repr(tuplegetterobject *tg) -{ - PyObject *repr; - const char *reprfmt; - - int status = Py_ReprEnter((PyObject *)tg); - if (status != 0) { - if (status < 0) - return NULL; - return PyUnicode_FromFormat("%s(...)", Py_TYPE(tg)->tp_name); - } - - reprfmt = tg->nitems == 1 ? "%s(%R)" : "%s%R"; - repr = PyUnicode_FromFormat(reprfmt, Py_TYPE(tg)->tp_name, tg->item); - Py_ReprLeave((PyObject *)tg); - return repr; -} - -static PyObject * -tuplegetter_reduce(tuplegetterobject *tg, PyObject *Py_UNUSED(ignored)) -{ - return PyTuple_Pack(2, Py_TYPE(tg), tg->item); -} - -PyDoc_STRVAR(reduce_doc, "Return state information for pickling"); - -static PyMethodDef tuplegetter_methods[] = { - {"__reduce__", (PyCFunction)tuplegetter_reduce, METH_NOARGS, - reduce_doc}, - {NULL} -}; - -PyDoc_STRVAR(tuplegetter_doc, -"tuplegetter(item, ...) 
--> tuplegetter object\n\ -\n\ -Return a callable object that fetches the given item(s) from its operand\n\ -and returns them as a tuple.\n"); - -static PyTypeObject tuplegetter_type = { - PyVarObject_HEAD_INIT(NULL, 0) - "sqlalchemy.engine.util.tuplegetter", /* tp_name */ - sizeof(tuplegetterobject), /* tp_basicsize */ - 0, /* tp_itemsize */ - /* methods */ - (destructor)tuplegetter_dealloc, /* tp_dealloc */ - 0, /* tp_vectorcall_offset */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_as_async */ - (reprfunc)tuplegetter_repr, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - (ternaryfunc)tuplegetter_call, /* tp_call */ - 0, /* tp_str */ - PyObject_GenericGetAttr, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */ - tuplegetter_doc, /* tp_doc */ - (traverseproc)tuplegetter_traverse, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - tuplegetter_methods, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - 0, /* tp_init */ - 0, /* tp_alloc */ - tuplegetter_new, /* tp_new */ - 0, /* tp_free */ -}; - - - -static PyMethodDef module_methods[] = { - {"safe_rowproxy_reconstructor", safe_rowproxy_reconstructor, METH_VARARGS, - "reconstruct a Row instance from its pickled form."}, - {NULL, NULL, 0, NULL} /* Sentinel */ -}; - -#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ -#define PyMODINIT_FUNC void -#endif - - -#if PY_MAJOR_VERSION >= 3 - -static struct PyModuleDef module_def = { - PyModuleDef_HEAD_INIT, - MODULE_NAME, - MODULE_DOC, - -1, - module_methods -}; - -#define INITERROR return NULL - -PyMODINIT_FUNC -PyInit_cresultproxy(void) - -#else - -#define INITERROR return - -PyMODINIT_FUNC -initcresultproxy(void) - 
-#endif - -{ - PyObject *m; - - BaseRowType.tp_new = PyType_GenericNew; - if (PyType_Ready(&BaseRowType) < 0) - INITERROR; - - if (PyType_Ready(&tuplegetter_type) < 0) - INITERROR; - -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&module_def); -#else - m = Py_InitModule3(MODULE_NAME, module_methods, MODULE_DOC); -#endif - if (m == NULL) - INITERROR; - - Py_INCREF(&BaseRowType); - PyModule_AddObject(m, "BaseRow", (PyObject *)&BaseRowType); - - Py_INCREF(&tuplegetter_type); - PyModule_AddObject(m, "tuplegetter", (PyObject *)&tuplegetter_type); - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif -} diff --git a/lib/sqlalchemy/cyextension/.gitignore b/lib/sqlalchemy/cyextension/.gitignore new file mode 100644 index 000000000..dfc107eaf --- /dev/null +++ b/lib/sqlalchemy/cyextension/.gitignore @@ -0,0 +1,5 @@ +# cython compiled files +*.c +*.o +# cython annotated output +*.html
\ No newline at end of file diff --git a/lib/sqlalchemy/cyextension/__init__.py b/lib/sqlalchemy/cyextension/__init__.py new file mode 100644 index 000000000..e69de29bb --- /dev/null +++ b/lib/sqlalchemy/cyextension/__init__.py diff --git a/lib/sqlalchemy/cyextension/collections.pyx b/lib/sqlalchemy/cyextension/collections.pyx new file mode 100644 index 000000000..e695d4c62 --- /dev/null +++ b/lib/sqlalchemy/cyextension/collections.pyx @@ -0,0 +1,393 @@ +from cpython.dict cimport PyDict_Merge, PyDict_Update +from cpython.long cimport PyLong_FromLong +from cpython.set cimport PySet_Add + +from itertools import filterfalse + +cdef bint add_not_present(set seen, object item, hashfunc): + hash_value = hashfunc(item) + if hash_value not in seen: + PySet_Add(seen, hash_value) + return True + else: + return False + +cdef list cunique_list(seq, hashfunc=None): + cdef set seen = set() + if not hashfunc: + return [x for x in seq if x not in seen and not PySet_Add(seen, x)] + else: + return [x for x in seq if add_not_present(seen, x, hashfunc)] + +def unique_list(seq, hashfunc=None): + return cunique_list(seq, hashfunc) + +cdef class OrderedSet(set): + + cdef list _list + + def __init__(self, d=None): + set.__init__(self) + if d is not None: + self._list = cunique_list(d) + set.update(self, self._list) + else: + self._list = [] + + cdef OrderedSet _copy(self): + cdef OrderedSet cp = OrderedSet.__new__(OrderedSet) + cp._list = list(self._list) + set.update(cp, cp._list) + return cp + + cdef OrderedSet _from_list(self, list new_list): + cdef OrderedSet new = OrderedSet.__new__(OrderedSet) + new._list = new_list + set.update(new, new_list) + return new + + def add(self, element): + if element not in self: + self._list.append(element) + PySet_Add(self, element) + + def remove(self, element): + # set.remove will raise if element is not in self + set.remove(self, element) + self._list.remove(element) + + def insert(self, Py_ssize_t pos, element): + if element not in self: + 
self._list.insert(pos, element) + PySet_Add(self, element) + + def discard(self, element): + if element in self: + set.remove(self, element) + self._list.remove(element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self + + def __ior__(self, iterable): + return self.update(iterable) + + def union(self, other): + result = self._copy() + result.update(other) + return result + + def __or__(self, other): + return self.union(other) + + cdef set _to_set(self, other): + cdef set other_set + if isinstance(other, set): + other_set = <set> other + else: + other_set = set(other) + return other_set + + def intersection(self, other): + cdef set other_set = self._to_set(other) + return self._from_list([a for a in self._list if a in other_set]) + + def __and__(self, other): + return self.intersection(other) + + def symmetric_difference(self, other): + cdef set other_set = self._to_set(other) + result = self._from_list([a for a in self._list if a not in other_set]) + # use other here to keep the order + result.update(a for a in other if a not in self) + return result + + def __xor__(self, other): + return self.symmetric_difference(other) + + def difference(self, other): + cdef set other_set = self._to_set(other) + return self._from_list([a for a in self._list if a not in other_set]) + + def __sub__(self, other): + return self.difference(other) + + def intersection_update(self, other): + cdef set other_set = self._to_set(other) + set.intersection_update(self, other_set) + self._list = [a for a in self._list if a in other_set] + return self + + def __iand__(self, other): + return 
self.intersection_update(other) + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + self._list = [a for a in self._list if a in self] + self._list += [a for a in other if a in self] + return self + + def __ixor__(self, other): + return self.symmetric_difference_update(other) + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [a for a in self._list if a in self] + return self + + def __isub__(self, other): + return self.difference_update(other) + + +cdef object cy_id(object item): + return PyLong_FromLong(<long> (<void *>item)) + +# NOTE: cython 0.x will call __add__, __sub__, etc with the parameter swapped +# instead of the __rmeth__, so they need to check that also self is of the +# correct type. This is fixed in cython 3.x. See: +# https://docs.cython.org/en/latest/src/userguide/special_methods.html#arithmetic-methods + +cdef class IdentitySet: + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. 
+ + """ + + cdef dict _members + + def __init__(self, iterable=None): + self._members = {} + if iterable: + self.update(iterable) + + def add(self, value): + self._members[cy_id(value)] = value + + def __contains__(self, value): + return cy_id(value) in self._members + + cpdef remove(self, value): + del self._members[cy_id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + cdef tuple pair + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError("cannot compare sets using cmp()") + + def __eq__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members == other_._members + else: + return False + + def __ne__(self, other): + cdef IdentitySet other_ + if isinstance(other, IdentitySet): + other_ = other + return self._members != other_._members + else: + return True + + cpdef issubset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) > len(other): + return False + for m in filterfalse(other._members.__contains__, self._members): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) < len(other) and self.issubset(other) + + cpdef issuperset(self, iterable): + cdef IdentitySet other + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) < len(other): + return False + for m in filterfalse(self._members.__contains__, other._members): + return False + return True + + def __ge__(self, other): + if not isinstance(other, 
IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + cpdef IdentitySet union(self, iterable): + cdef IdentitySet result = self.__class__() + result._members.update(self._members) + result.update(iterable) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.union(other) + + cpdef update(self, iterable): + for obj in iterable: + self._members[cy_id(obj)] = obj + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + cpdef difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (<IdentitySet>iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k:v for k, v in self._members.items() if k not in other} + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.difference(other) + + cpdef difference_update(self, iterable): + cdef IdentitySet other = self.difference(iterable) + self._members = other._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + cpdef intersection(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = (<IdentitySet>iterable)._members + else: + other = {cy_id(obj) for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k in other} + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + 
return self.intersection(other) + + cpdef intersection_update(self, iterable): + cdef IdentitySet other = self.intersection(iterable) + self._members = other._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + cpdef symmetric_difference(self, iterable): + cdef IdentitySet result = self.__new__(self.__class__) + cdef dict other + if isinstance(iterable, self.__class__): + other = (<IdentitySet>iterable)._members + else: + other = {cy_id(obj): obj for obj in iterable} + result._members = {k: v for k, v in self._members.items() if k not in other} + result._members.update( + [(k, v) for k, v in other.items() if k not in self._members] + ) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet) or not isinstance(self, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + cpdef symmetric_difference_update(self, iterable): + cdef IdentitySet other = self.symmetric_difference(iterable) + self._members = other._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + cpdef copy(self): + cdef IdentitySet cp = self.__new__(self.__class__) + cp._members = self._members.copy() + return cp + + def __copy__(self): + return self.copy() + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError("set objects are unhashable") + + def __repr__(self): + return "%s(%r)" % (type(self).__name__, list(self._members.values())) diff --git a/lib/sqlalchemy/cyextension/immutabledict.pxd b/lib/sqlalchemy/cyextension/immutabledict.pxd new file mode 100644 index 000000000..fe7ad6a81 --- /dev/null +++ b/lib/sqlalchemy/cyextension/immutabledict.pxd @@ -0,0 +1,2 @@ +cdef class immutabledict(dict): + pass diff --git 
a/lib/sqlalchemy/cyextension/immutabledict.pyx b/lib/sqlalchemy/cyextension/immutabledict.pyx new file mode 100644 index 000000000..89bcf3ed6 --- /dev/null +++ b/lib/sqlalchemy/cyextension/immutabledict.pyx @@ -0,0 +1,100 @@ +from cpython.dict cimport PyDict_New, PyDict_Update, PyDict_Size + + +def _immutable_fn(obj): + raise TypeError("%s object is immutable" % obj.__class__.__name__) + + +class ImmutableContainer: + def _immutable(self, *a,**kw): + _immutable_fn(self) + + __delitem__ = __setitem__ = __setattr__ = _immutable + + +cdef class immutabledict(dict): + def __repr__(self): + return f"immutabledict({dict.__repr__(self)})" + + def union(self, *args, **kw): + cdef dict to_merge = None + cdef immutabledict result + cdef Py_ssize_t args_len = len(args) + if args_len > 1: + raise TypeError( + f'union expected at most 1 argument, got {args_len}' + ) + if args_len == 1: + attribute = args[0] + if isinstance(attribute, dict): + to_merge = <dict> attribute + if to_merge is None: + to_merge = dict(*args, **kw) + + if PyDict_Size(to_merge) == 0: + return self + + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update(result, to_merge) + return result + + def merge_with(self, *other): + cdef immutabledict result = None + cdef object d + cdef bint update = False + if not other: + return self + for d in other: + if d: + if update == False: + update = True + # new + update is faster than immutabledict(self) + result = immutabledict() + PyDict_Update(result, self) + PyDict_Update( + result, <dict>(d if isinstance(d, dict) else dict(d)) + ) + + return self if update == False else result + + def copy(self): + return self + + def __reduce__(self): + return immutabledict, (dict(self), ) + + def __delitem__(self, k): + _immutable_fn(self) + + def __setitem__(self, k, v): + _immutable_fn(self) + + def __setattr__(self, k, v): + _immutable_fn(self) + + def clear(self, *args, **kw): + _immutable_fn(self) + + def 
pop(self, *args, **kw): + _immutable_fn(self) + + def popitem(self, *args, **kw): + _immutable_fn(self) + + def setdefault(self, *args, **kw): + _immutable_fn(self) + + def update(self, *args, **kw): + _immutable_fn(self) + + # PEP 584 + def __ior__(self, other): + _immutable_fn(self) + + def __or__(self, other): + return immutabledict(super().__or__(other)) + + def __ror__(self, other): + return immutabledict(super().__ror__(other)) diff --git a/lib/sqlalchemy/cyextension/processors.pyx b/lib/sqlalchemy/cyextension/processors.pyx new file mode 100644 index 000000000..9f23e73b1 --- /dev/null +++ b/lib/sqlalchemy/cyextension/processors.pyx @@ -0,0 +1,91 @@ +import datetime +import re + +from cpython.datetime cimport date_new, datetime_new, import_datetime, time_new +from cpython.object cimport PyObject_Str +from cpython.unicode cimport PyUnicode_AsASCIIString, PyUnicode_Check, PyUnicode_Decode +from libc.stdio cimport sscanf + + +def int_to_boolean(value): + if value is None: + return None + return True if value else False + +def to_str(value): + return PyObject_Str(value) if value is not None else None + +def to_float(value): + return float(value) if value is not None else None + +cdef inline bytes to_bytes(object value, str type_name): + try: + return PyUnicode_AsASCIIString(value) + except Exception as e: + raise ValueError( + f"Couldn't parse {type_name} string '{value!r}' " + "- value is not a string." 
+ ) from e + +import_datetime() # required to call datetime_new/date_new/time_new + +def str_to_datetime(value): + if value is None: + return None + cdef int numparsed + cdef unsigned int year, month, day, hour, minute, second, microsecond = 0 + cdef bytes value_b = to_bytes(value, 'datetime') + cdef const char * string = value_b + + numparsed = sscanf(string, "%4u-%2u-%2u %2u:%2u:%2u.%6u", + &year, &month, &day, &hour, &minute, &second, &microsecond) + if numparsed < 6: + raise ValueError( + "Couldn't parse datetime string: '%s'" % (value) + ) + return datetime_new(year, month, day, hour, minute, second, microsecond, None) + +def str_to_date(value): + if value is None: + return None + cdef int numparsed + cdef unsigned int year, month, day + cdef bytes value_b = to_bytes(value, 'date') + cdef const char * string = value_b + + numparsed = sscanf(string, "%4u-%2u-%2u", &year, &month, &day) + if numparsed != 3: + raise ValueError( + "Couldn't parse date string: '%s'" % (value) + ) + return date_new(year, month, day) + +def str_to_time(value): + if value is None: + return None + cdef int numparsed + cdef unsigned int hour, minute, second, microsecond = 0 + cdef bytes value_b = to_bytes(value, 'time') + cdef const char * string = value_b + + numparsed = sscanf(string, "%2u:%2u:%2u.%6u", &hour, &minute, &second, &microsecond) + if numparsed < 3: + raise ValueError( + "Couldn't parse time string: '%s'" % (value) + ) + return time_new(hour, minute, second, microsecond, None) + + +cdef class DecimalResultProcessor: + cdef object type_ + cdef str format_ + + def __cinit__(self, type_, format_): + self.type_ = type_ + self.format_ = format_ + + def process(self, object value): + if value is None: + return None + else: + return self.type_(self.format_ % value) diff --git a/lib/sqlalchemy/cyextension/resultproxy.pyx b/lib/sqlalchemy/cyextension/resultproxy.pyx new file mode 100644 index 000000000..daf5cc940 --- /dev/null +++ b/lib/sqlalchemy/cyextension/resultproxy.pyx @@ -0,0 +1,130 @@ 
+# TODO: this is mostly just copied over from the python implementation +# more improvements are likely possible +import operator + +cdef int MD_INDEX = 0 # integer index in cursor.description + +KEY_INTEGER_ONLY = 0 +KEY_OBJECTS_ONLY = 1 + +sqlalchemy_engine_row = None + +cdef class BaseRow: + cdef readonly object _parent + cdef readonly tuple _data + cdef readonly dict _keymap + cdef readonly int _key_style + + def __init__(self, object parent, object processors, dict keymap, int key_style, object data): + """Row objects are constructed by CursorResult objects.""" + + self._parent = parent + + if processors: + self._data = tuple( + [ + proc(value) if proc else value + for proc, value in zip(processors, data) + ] + ) + else: + self._data = tuple(data) + + self._keymap = keymap + + self._key_style = key_style + + def __reduce__(self): + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self): + return { + "_parent": self._parent, + "_data": self._data, + "_key_style": self._key_style, + } + + def __setstate__(self, dict state): + self._parent = state["_parent"] + self._data = state["_data"] + self._keymap = self._parent._keymap + self._key_style = state["_key_style"] + + def _filter_on_values(self, filters): + global sqlalchemy_engine_row + if sqlalchemy_engine_row is None: + from sqlalchemy.engine.row import Row as sqlalchemy_engine_row + + return sqlalchemy_engine_row( + self._parent, + filters, + self._keymap, + self._key_style, + self._data, + ) + + def _values_impl(self): + return list(self) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __hash__(self): + return hash(self._data) + + def _get_by_int_impl(self, key): + return self._data[key] + + cpdef _get_by_key_impl(self, key): + # keep two isinstance since it's noticeably faster in the int case + if isinstance(key, int) or isinstance(key, slice): + return self._data[key] + + 
self._parent._raise_for_nonint(key) + + def __getitem__(self, key): + return self._get_by_key_impl(key) + + cpdef _get_by_key_impl_mapping(self, key): + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) + + mdindex = rec[MD_INDEX] + if mdindex is None: + self._parent._raise_for_ambiguous_column_name(rec) + elif ( + self._key_style == KEY_OBJECTS_ONLY + and isinstance(key, int) + ): + raise KeyError(key) + + return self._data[mdindex] + + def __getattr__(self, name): + try: + return self._get_by_key_impl_mapping(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +def tuplegetter(*indexes): + it = operator.itemgetter(*indexes) + + if len(indexes) > 1: + return it + else: + return lambda row: (it(row),) diff --git a/lib/sqlalchemy/cyextension/util.pyx b/lib/sqlalchemy/cyextension/util.pyx new file mode 100644 index 000000000..ac15ff9de --- /dev/null +++ b/lib/sqlalchemy/cyextension/util.pyx @@ -0,0 +1,43 @@ +from collections.abc import Mapping + +from sqlalchemy import exc + +cdef tuple _Empty_Tuple = () + +cdef inline bint _mapping_or_tuple(object value): + return isinstance(value, dict) or isinstance(value, tuple) or isinstance(value, Mapping) + +cdef inline bint _check_item(object params) except 0: + cdef object item + cdef bint ret = 1 + if params: + item = params[0] + if not _mapping_or_tuple(item): + ret = 0 + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + return ret + +def _distill_params_20(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list) or isinstance(params, tuple): + _check_item(params) + return params + elif isinstance(params, dict) or isinstance(params, Mapping): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def 
_distill_raw_params(object params): + if params is None: + return _Empty_Tuple + elif isinstance(params, list): + _check_item(params) + return params + elif _mapping_or_tuple(params): + return [params] + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index a9dc97d54..367771987 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -43,9 +43,9 @@ import re from .base import MSDialect from .base import MSIdentifierPreparer -from ... import processors from ... import types as sqltypes from ... import util +from ...engine import processors class _MSNumeric_pymssql(sqltypes.Numeric): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index eecf8567c..9b097fb0e 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -439,11 +439,11 @@ from .base import OracleCompiler from .base import OracleDialect from .base import OracleExecutionContext from ... import exc -from ... import processors from ... import types as sqltypes from ... import util from ...engine import cursor as _cursor from ...engine import interfaces +from ...engine import processors class _OracleInteger(sqltypes.Integer): diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index a3a378947..265d8617e 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -9,9 +9,9 @@ from .base import PGExecutionContext from .base import UUID from .hstore import HSTORE from ... import exc -from ... import processors from ... import types as sqltypes from ... 
import util +from ...engine import processors _server_side_id = util.counter() diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 4951107bc..a9d6727c9 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -120,9 +120,9 @@ from .base import REGCLASS from .base import UUID from ... import exc from ... import pool -from ... import processors from ... import util from ...engine import AdaptedConnection +from ...engine import processors from ...sql import sqltypes from ...util.concurrency import asyncio from ...util.concurrency import await_fallback diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 1904a1ae1..3c0d2de64 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -108,9 +108,9 @@ from .json import JSON from .json import JSONB from .json import JSONPathType from ... import exc -from ... import processors from ... import types as sqltypes from ... import util +from ...engine import processors from ...sql.elements import quoted_name diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 0c7f8d839..43883c4b7 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -818,12 +818,12 @@ from .json import JSON from .json import JSONIndexType from .json import JSONPathType from ... import exc -from ... import processors from ... import schema as sa_schema from ... import sql from ... import types as sqltypes from ... 
import util from ...engine import default +from ...engine import processors from ...engine import reflection from ...sql import coercions from ...sql import ColumnElement diff --git a/lib/sqlalchemy/engine/_py_processors.py b/lib/sqlalchemy/engine/_py_processors.py new file mode 100644 index 000000000..db722a978 --- /dev/null +++ b/lib/sqlalchemy/engine/_py_processors.py @@ -0,0 +1,106 @@ +# sqlalchemy/processors.py +# Copyright (C) 2010-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php + +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" + +import datetime +import re + +from .. import util + + +def str_to_datetime_processor_factory(regexp, type_): + rmatch = regexp.match + # Even on python2.6 datetime.strptime is both slower than this code + # and it does not support microseconds. + has_named_groups = bool(regexp.groupindex) + + def process(value): + if value is None: + return None + else: + try: + m = rmatch(value) + except TypeError as err: + util.raise_( + ValueError( + "Couldn't parse %s string '%r' " + "- value is not a string." 
% (type_.__name__, value) + ), + from_=err, + ) + if m is None: + raise ValueError( + "Couldn't parse %s string: " + "'%s'" % (type_.__name__, value) + ) + if has_named_groups: + groups = m.groupdict(0) + return type_( + **dict( + list( + zip( + iter(groups.keys()), + list(map(int, iter(groups.values()))), + ) + ) + ) + ) + else: + return type_(*list(map(int, m.groups(0)))) + + return process + + +def to_decimal_processor_factory(target_class, scale): + fstring = "%%.%df" % scale + + def process(value): + if value is None: + return None + else: + return target_class(fstring % value) + + return process + + +def to_float(value): + if value is None: + return None + else: + return float(value) + + +def to_str(value): + if value is None: + return None + else: + return str(value) + + +def int_to_boolean(value): + if value is None: + return None + else: + return bool(value) + + +DATETIME_RE = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") +TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") +DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)") + +str_to_datetime = str_to_datetime_processor_factory( + DATETIME_RE, datetime.datetime +) +str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) +str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py new file mode 100644 index 000000000..981b6e0b2 --- /dev/null +++ b/lib/sqlalchemy/engine/_py_row.py @@ -0,0 +1,138 @@ +import operator + +MD_INDEX = 0 # integer index in cursor.description + +KEY_INTEGER_ONLY = 0 +"""__getitem__ only allows integer values and slices, raises TypeError + otherwise""" + +KEY_OBJECTS_ONLY = 1 +"""__getitem__ only allows string/object values, raises TypeError otherwise""" + +sqlalchemy_engine_row = None + + +class BaseRow: + Row = None + __slots__ = ("_parent", "_data", "_keymap", "_key_style") + + def __init__(self, parent, processors, keymap, key_style, data): + """Row objects 
are constructed by CursorResult objects.""" + + object.__setattr__(self, "_parent", parent) + + if processors: + object.__setattr__( + self, + "_data", + tuple( + [ + proc(value) if proc else value + for proc, value in zip(processors, data) + ] + ), + ) + else: + object.__setattr__(self, "_data", tuple(data)) + + object.__setattr__(self, "_keymap", keymap) + + object.__setattr__(self, "_key_style", key_style) + + def __reduce__(self): + return ( + rowproxy_reconstructor, + (self.__class__, self.__getstate__()), + ) + + def __getstate__(self): + return { + "_parent": self._parent, + "_data": self._data, + "_key_style": self._key_style, + } + + def __setstate__(self, state): + parent = state["_parent"] + object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_data", state["_data"]) + object.__setattr__(self, "_keymap", parent._keymap) + object.__setattr__(self, "_key_style", state["_key_style"]) + + def _filter_on_values(self, filters): + global sqlalchemy_engine_row + if sqlalchemy_engine_row is None: + from sqlalchemy.engine.row import Row as sqlalchemy_engine_row + + return sqlalchemy_engine_row( + self._parent, + filters, + self._keymap, + self._key_style, + self._data, + ) + + def _values_impl(self): + return list(self) + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __hash__(self): + return hash(self._data) + + def _get_by_int_impl(self, key): + return self._data[key] + + def _get_by_key_impl(self, key): + # keep two isinstance since it's noticeably faster in the int case + if isinstance(key, int) or isinstance(key, slice): + return self._data[key] + + self._parent._raise_for_nonint(key) + + # The original 1.4 plan was that Row would not allow row["str"] + # access, however as the C extensions were inadvertently allowing + # this coupled with the fact that orm Session sets future=True, + # this allows a softer upgrade path. 
see #6218 + __getitem__ = _get_by_key_impl + + def _get_by_key_impl_mapping(self, key): + try: + rec = self._keymap[key] + except KeyError as ke: + rec = self._parent._key_fallback(key, ke) + + mdindex = rec[MD_INDEX] + if mdindex is None: + self._parent._raise_for_ambiguous_column_name(rec) + elif self._key_style == KEY_OBJECTS_ONLY and isinstance(key, int): + raise KeyError(key) + + return self._data[mdindex] + + def __getattr__(self, name): + try: + return self._get_by_key_impl_mapping(name) + except KeyError as e: + raise AttributeError(e.args[0]) from e + + +# This reconstructor is necessary so that pickles with the Cy extension or +# without use the same Binary format. +def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + +def tuplegetter(*indexes): + it = operator.itemgetter(*indexes) + + if len(indexes) > 1: + return it + else: + return lambda row: (it(row),) diff --git a/lib/sqlalchemy/engine/_py_util.py b/lib/sqlalchemy/engine/_py_util.py new file mode 100644 index 000000000..2db6c049b --- /dev/null +++ b/lib/sqlalchemy/engine/_py_util.py @@ -0,0 +1,54 @@ +from collections import abc as collections_abc + +from .. 
import exc + +_no_tuple = () + + +def _distill_params_20(params): + if params is None: + return _no_tuple + # Assume list is more likely than tuple + elif isinstance(params, list) or isinstance(params, tuple): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance( + params[0], (tuple, collections_abc.Mapping) + ): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return params + elif isinstance(params, dict) or isinstance( + # only do immutabledict or abc.__instancecheck__ for Mapping after + # we've checked for plain dictionaries and would otherwise raise + params, + collections_abc.Mapping, + ): + return [params] + else: + raise exc.ArgumentError("mapping or list expected for parameters") + + +def _distill_raw_params(params): + if params is None: + return _no_tuple + elif isinstance(params, list): + # collections_abc.MutableSequence): # avoid abc.__instancecheck__ + if params and not isinstance( + params[0], (tuple, collections_abc.Mapping) + ): + raise exc.ArgumentError( + "List argument must consist only of tuples or dictionaries" + ) + + return params + elif isinstance(params, (tuple, dict)) or isinstance( + # only do abc.__instancecheck__ for Mapping after we've checked + # for plain dictionaries and would otherwise raise + params, + collections_abc.Mapping, + ): + return [params] + else: + raise exc.ArgumentError("mapping or sequence expected for parameters") diff --git a/lib/sqlalchemy/engine/processors.py b/lib/sqlalchemy/engine/processors.py new file mode 100644 index 000000000..023444d10 --- /dev/null +++ b/lib/sqlalchemy/engine/processors.py @@ -0,0 +1,44 @@ +# sqlalchemy/processors.py +# Copyright (C) 2010-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php 
+ +"""defines generic type conversion functions, as used in bind and result +processors. + +They all share one common characteristic: None is passed through unchanged. + +""" +from ._py_processors import str_to_datetime_processor_factory # noqa + +try: + from sqlalchemy.cyextension.processors import ( + DecimalResultProcessor, + ) # noqa + from sqlalchemy.cyextension.processors import int_to_boolean # noqa + from sqlalchemy.cyextension.processors import str_to_date # noqa + from sqlalchemy.cyextension.processors import str_to_datetime # noqa + from sqlalchemy.cyextension.processors import str_to_time # noqa + from sqlalchemy.cyextension.processors import to_float # noqa + from sqlalchemy.cyextension.processors import to_str # noqa + + def to_decimal_processor_factory(target_class, scale): + # Note that the scale argument is not taken into account for integer + # values in the C implementation while it is in the Python one. + # For example, the Python implementation might return + # Decimal('5.00000') whereas the C implementation will + # return Decimal('5'). These are equivalent of course. + return DecimalResultProcessor(target_class, "%%.%df" % scale).process + + +except ImportError: + from ._py_processors import int_to_boolean # noqa + from ._py_processors import str_to_date # noqa + from ._py_processors import str_to_datetime # noqa + from ._py_processors import str_to_time # noqa + from ._py_processors import to_decimal_processor_factory # noqa + from ._py_processors import to_float # noqa + from ._py_processors import to_str # noqa diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index e2f4033e0..ff292c7d7 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -13,7 +13,6 @@ import functools import itertools import operator -from .row import _baserow_usecext from .row import Row from .. import exc from .. 
import util @@ -21,18 +20,10 @@ from ..sql.base import _generative from ..sql.base import HasMemoized from ..sql.base import InPlaceGenerative - -if _baserow_usecext: - from sqlalchemy.cresultproxy import tuplegetter -else: - - def tuplegetter(*indexes): - it = operator.itemgetter(*indexes) - - if len(indexes) > 1: - return it - else: - return lambda row: (it(row),) +try: + from sqlalchemy.cyextension.resultproxy import tuplegetter +except ImportError: + from ._py_row import tuplegetter class ResultMetaData: @@ -104,7 +95,7 @@ class RMKeyView(collections_abc.KeysView): return iter(self._keys) def __contains__(self, item): - if not _baserow_usecext and isinstance(item, int): + if isinstance(item, int): return False # note this also includes special key fallback behaviors diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 43ef093d6..47f5ac1cd 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -11,143 +11,17 @@ import collections.abc as collections_abc import operator -from .. import util from ..sql import util as sql_util -MD_INDEX = 0 # integer index in cursor.description -# This reconstructor is necessary so that pickles with the C extension or -# without use the same Binary format. try: - # We need a different reconstructor on the C extension so that we can - # add extra checks that fields have correctly been initialized by - # __setstate__. - from sqlalchemy.cresultproxy import safe_rowproxy_reconstructor - - # The extra function embedding is needed so that the - # reconstructor function has the same signature whether or not - # the extension is present. 
- def rowproxy_reconstructor(cls, state): - return safe_rowproxy_reconstructor(cls, state) - - + from sqlalchemy.cyextension.resultproxy import BaseRow + from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY + from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY except ImportError: - - def rowproxy_reconstructor(cls, state): - obj = cls.__new__(cls) - obj.__setstate__(state) - return obj - - -KEY_INTEGER_ONLY = 0 -"""__getitem__ only allows integer values and slices, raises TypeError - otherwise""" - -KEY_OBJECTS_ONLY = 1 -"""__getitem__ only allows string/object values, raises TypeError otherwise""" - -try: - from sqlalchemy.cresultproxy import BaseRow - - _baserow_usecext = True -except ImportError: - _baserow_usecext = False - - class BaseRow: - __slots__ = ("_parent", "_data", "_keymap", "_key_style") - - def __init__(self, parent, processors, keymap, key_style, data): - """Row objects are constructed by CursorResult objects.""" - - object.__setattr__(self, "_parent", parent) - - if processors: - object.__setattr__( - self, - "_data", - tuple( - [ - proc(value) if proc else value - for proc, value in zip(processors, data) - ] - ), - ) - else: - object.__setattr__(self, "_data", tuple(data)) - - object.__setattr__(self, "_keymap", keymap) - - object.__setattr__(self, "_key_style", key_style) - - def __reduce__(self): - return ( - rowproxy_reconstructor, - (self.__class__, self.__getstate__()), - ) - - def _filter_on_values(self, filters): - return Row( - self._parent, - filters, - self._keymap, - self._key_style, - self._data, - ) - - def _values_impl(self): - return list(self) - - def __iter__(self): - return iter(self._data) - - def __len__(self): - return len(self._data) - - def __hash__(self): - return hash(self._data) - - def _get_by_int_impl(self, key): - return self._data[key] - - def _get_by_key_impl(self, key): - if int in key.__class__.__mro__: - return self._data[key] - - assert self._key_style == KEY_INTEGER_ONLY - - if 
isinstance(key, slice): - return tuple(self._data[key]) - - self._parent._raise_for_nonint(key) - - # The original 1.4 plan was that Row would not allow row["str"] - # access, however as the C extensions were inadvertently allowing - # this coupled with the fact that orm Session sets future=True, - # this allows a softer upgrade path. see #6218 - __getitem__ = _get_by_key_impl - - def _get_by_key_impl_mapping(self, key): - try: - rec = self._keymap[key] - except KeyError as ke: - rec = self._parent._key_fallback(key, ke) - - mdindex = rec[MD_INDEX] - if mdindex is None: - self._parent._raise_for_ambiguous_column_name(rec) - elif ( - self._key_style == KEY_OBJECTS_ONLY - and int in key.__class__.__mro__ - ): - raise KeyError(key) - - return self._data[mdindex] - - def __getattr__(self, name): - try: - return self._get_by_key_impl_mapping(name) - except KeyError as e: - util.raise_(AttributeError(e.args[0]), replace_context=e) + from ._py_row import BaseRow + from ._py_row import KEY_INTEGER_ONLY + from ._py_row import KEY_OBJECTS_ONLY class Row(BaseRow, collections_abc.Sequence): @@ -235,20 +109,6 @@ class Row(BaseRow, collections_abc.Sequence): def __contains__(self, key): return key in self._data - def __getstate__(self): - return { - "_parent": self._parent, - "_data": self._data, - "_key_style": self._key_style, - } - - def __setstate__(self, state): - parent = state["_parent"] - object.__setattr__(self, "_parent", parent) - object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_keymap", parent._keymap) - object.__setattr__(self, "_key_style", state["_key_style"]) - def _op(self, other, op): return ( op(tuple(self), tuple(other)) @@ -392,12 +252,10 @@ class RowMapping(BaseRow, collections_abc.Mapping): _default_key_style = KEY_OBJECTS_ONLY - if not _baserow_usecext: - - __getitem__ = BaseRow._get_by_key_impl_mapping + __getitem__ = BaseRow._get_by_key_impl_mapping - def _values_impl(self): - return list(self._data) + def 
_values_impl(self): + return list(self._data) def __iter__(self): return (k for k in self._parent.keys if k is not None) diff --git a/lib/sqlalchemy/engine/util.py b/lib/sqlalchemy/engine/util.py index e88b9ebf3..4cc7df790 100644 --- a/lib/sqlalchemy/engine/util.py +++ b/lib/sqlalchemy/engine/util.py @@ -5,11 +5,15 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -import collections.abc as collections_abc - from .. import exc from .. import util -from ..util import immutabledict + +try: + from sqlalchemy.cyextension.util import _distill_params_20 # noqa + from sqlalchemy.cyextension.util import _distill_raw_params # noqa +except ImportError: + from ._py_util import _distill_params_20 # noqa + from ._py_util import _distill_raw_params # noqa def connection_memoize(key): @@ -31,57 +35,6 @@ def connection_memoize(key): return decorated -_no_tuple = () - - -def _distill_params_20(params): - if params is None: - return _no_tuple - elif isinstance(params, (list, tuple)): - # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance( - params[0], (collections_abc.Mapping, tuple) - ): - raise exc.ArgumentError( - "List argument must consist only of tuples or dictionaries" - ) - - return params - elif isinstance( - params, - (dict, immutabledict), - # only do abc.__instancecheck__ for Mapping after we've checked - # for plain dictionaries and would otherwise raise - ) or isinstance(params, collections_abc.Mapping): - return [params] - else: - raise exc.ArgumentError("mapping or sequence expected for parameters") - - -def _distill_raw_params(params): - if params is None: - return _no_tuple - elif isinstance(params, (list,)): - # collections_abc.MutableSequence): # avoid abc.__instancecheck__ - if params and not isinstance( - params[0], (collections_abc.Mapping, tuple) - ): - raise exc.ArgumentError( - "List argument must consist only of tuples or 
dictionaries" - ) - - return params - elif isinstance( - params, - (tuple, dict, immutabledict), - # only do abc.__instancecheck__ for Mapping after we've checked - # for plain dictionaries and would otherwise raise - ) or isinstance(params, collections_abc.Mapping): - return [params] - else: - raise exc.ArgumentError("mapping or sequence expected for parameters") - - class TransactionalContext: """Apply Python context manager behavior to transaction objects. diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py deleted file mode 100644 index 156005c6a..000000000 --- a/lib/sqlalchemy/processors.py +++ /dev/null @@ -1,132 +0,0 @@ -# sqlalchemy/processors.py -# Copyright (C) 2010-2021 the SQLAlchemy authors and contributors -# <see AUTHORS file> -# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com -# -# This module is part of SQLAlchemy and is released under -# the MIT License: https://www.opensource.org/licenses/mit-license.php - -"""defines generic type conversion functions, as used in bind and result -processors. - -They all share one common characteristic: None is passed through unchanged. - -""" - -import datetime -import re - -from . import util - - -def str_to_datetime_processor_factory(regexp, type_): - rmatch = regexp.match - # Even on python2.6 datetime.strptime is both slower than this code - # and it does not support microseconds. - has_named_groups = bool(regexp.groupindex) - - def process(value): - if value is None: - return None - else: - try: - m = rmatch(value) - except TypeError as err: - util.raise_( - ValueError( - "Couldn't parse %s string '%r' " - "- value is not a string." 
% (type_.__name__, value) - ), - from_=err, - ) - if m is None: - raise ValueError( - "Couldn't parse %s string: " - "'%s'" % (type_.__name__, value) - ) - if has_named_groups: - groups = m.groupdict(0) - return type_( - **dict( - list( - zip( - iter(groups.keys()), - list(map(int, iter(groups.values()))), - ) - ) - ) - ) - else: - return type_(*list(map(int, m.groups(0)))) - - return process - - -def py_fallback(): - def to_decimal_processor_factory(target_class, scale): - fstring = "%%.%df" % scale - - def process(value): - if value is None: - return None - else: - return target_class(fstring % value) - - return process - - def to_float(value): # noqa - if value is None: - return None - else: - return float(value) - - def to_str(value): # noqa - if value is None: - return None - else: - return str(value) - - def int_to_boolean(value): # noqa - if value is None: - return None - else: - return bool(value) - - DATETIME_RE = re.compile( - r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?" - ) - TIME_RE = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") - DATE_RE = re.compile(r"(\d+)-(\d+)-(\d+)") - - str_to_datetime = str_to_datetime_processor_factory( # noqa - DATETIME_RE, datetime.datetime - ) - str_to_time = str_to_datetime_processor_factory( # noqa - TIME_RE, datetime.time - ) # noqa - str_to_date = str_to_datetime_processor_factory( # noqa - DATE_RE, datetime.date - ) # noqa - return locals() - - -try: - from sqlalchemy.cprocessors import DecimalResultProcessor # noqa - from sqlalchemy.cprocessors import int_to_boolean # noqa - from sqlalchemy.cprocessors import str_to_date # noqa - from sqlalchemy.cprocessors import str_to_datetime # noqa - from sqlalchemy.cprocessors import str_to_time # noqa - from sqlalchemy.cprocessors import to_float # noqa - from sqlalchemy.cprocessors import to_str # noqa - - def to_decimal_processor_factory(target_class, scale): - # Note that the scale argument is not taken into account for integer - # values in the C implementation while 
it is in the Python one. - # For example, the Python implementation might return - # Decimal('5.00000') whereas the C implementation will - # return Decimal('5'). These are equivalent of course. - return DecimalResultProcessor(target_class, "%%.%df" % scale).process - - -except ImportError: - globals().update(py_fallback()) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index d141c8c68..e65fa3c14 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -37,8 +37,8 @@ from .type_api import Variant from .. import event from .. import exc from .. import inspection -from .. import processors from .. import util +from ..engine import processors from ..util import langhelpers from ..util import OrderedDict diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index f811be657..5cf80a1fb 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -1349,7 +1349,8 @@ class SuiteRequirements(Requirements): @property def cextensions(self): return exclusions.skip_if( - lambda: not util.has_compiled_ext(), "C extensions not installed" + lambda: not util.has_compiled_ext(), + "Cython extensions not installed", ) def _has_sqlite(self): diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index 32e989fca..774b57934 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -7,92 +7,27 @@ """Collection classes and helpers.""" import collections.abc as collections_abc -from itertools import filterfalse import operator import types import weakref from .compat import threading -EMPTY_SET = frozenset() - - -class ImmutableContainer: - def _immutable(self, *arg, **kw): - raise TypeError("%s object is immutable" % self.__class__.__name__) - - __delitem__ = __setitem__ = __setattr__ = _immutable - - -def _immutabledict_py_fallback(): - class immutabledict(ImmutableContainer, dict): - - 
clear = ( - pop - ) = popitem = setdefault = update = ImmutableContainer._immutable - - def __new__(cls, *args): - new = dict.__new__(cls) - dict.__init__(new, *args) - return new - - def __init__(self, *args): - pass - - def __reduce__(self): - return _immutabledict_reconstructor, (dict(self),) - - def union(self, __d=None): - if not __d: - return self - - new = dict.__new__(self.__class__) - dict.__init__(new, self) - dict.update(new, __d) - return new - - def _union_w_kw(self, __d=None, **kw): - # not sure if C version works correctly w/ this yet - if not __d and not kw: - return self - - new = dict.__new__(self.__class__) - dict.__init__(new, self) - if __d: - dict.update(new, __d) - dict.update(new, kw) - return new - - def merge_with(self, *dicts): - new = None - for d in dicts: - if d: - if new is None: - new = dict.__new__(self.__class__) - dict.__init__(new, self) - dict.update(new, d) - if new is None: - return self - - return new - - def __repr__(self): - return "immutabledict(%s)" % dict.__repr__(self) - - return immutabledict - - try: - from sqlalchemy.cimmutabledict import immutabledict - - collections_abc.Mapping.register(immutabledict) - + from sqlalchemy.cyextension.immutabledict import ImmutableContainer + from sqlalchemy.cyextension.immutabledict import immutabledict + from sqlalchemy.cyextension.collections import IdentitySet + from sqlalchemy.cyextension.collections import OrderedSet + from sqlalchemy.cyextension.collections import unique_list # noqa except ImportError: - immutabledict = _immutabledict_py_fallback() + from ._py_collections import immutabledict + from ._py_collections import IdentitySet + from ._py_collections import ImmutableContainer + from ._py_collections import OrderedSet + from ._py_collections import unique_list # noqa - def _immutabledict_reconstructor(*arg): - """do the pickle dance""" - return immutabledict(*arg) + +EMPTY_SET = frozenset() def coerce_to_immutabledict(d): @@ -242,334 +177,6 @@ OrderedDict = dict 
sort_dictionary = _ordered_dictionary_sort -class OrderedSet(set): - def __init__(self, d=None): - set.__init__(self) - if d is not None: - self._list = unique_list(d) - set.update(self, self._list) - else: - self._list = [] - - def add(self, element): - if element not in self: - self._list.append(element) - set.add(self, element) - - def remove(self, element): - set.remove(self, element) - self._list.remove(element) - - def insert(self, pos, element): - if element not in self: - self._list.insert(pos, element) - set.add(self, element) - - def discard(self, element): - if element in self: - self._list.remove(element) - set.remove(self, element) - - def clear(self): - set.clear(self) - self._list = [] - - def __getitem__(self, key): - return self._list[key] - - def __iter__(self): - return iter(self._list) - - def __add__(self, other): - return self.union(other) - - def __repr__(self): - return "%s(%r)" % (self.__class__.__name__, self._list) - - __str__ = __repr__ - - def update(self, iterable): - for e in iterable: - if e not in self: - self._list.append(e) - set.add(self, e) - return self - - __ior__ = update - - def union(self, other): - result = self.__class__(self) - result.update(other) - return result - - __or__ = union - - def intersection(self, other): - other = set(other) - return self.__class__(a for a in self if a in other) - - __and__ = intersection - - def symmetric_difference(self, other): - other = set(other) - result = self.__class__(a for a in self if a not in other) - result.update(a for a in other if a not in self) - return result - - __xor__ = symmetric_difference - - def difference(self, other): - other = set(other) - return self.__class__(a for a in self if a not in other) - - __sub__ = difference - - def intersection_update(self, other): - other = set(other) - set.intersection_update(self, other) - self._list = [a for a in self._list if a in other] - return self - - __iand__ = intersection_update - - def symmetric_difference_update(self, 
other): - set.symmetric_difference_update(self, other) - self._list = [a for a in self._list if a in self] - self._list += [a for a in other._list if a in self] - return self - - __ixor__ = symmetric_difference_update - - def difference_update(self, other): - set.difference_update(self, other) - self._list = [a for a in self._list if a in self] - return self - - __isub__ = difference_update - - -class IdentitySet: - """A set that considers only object id() for uniqueness. - - This strategy has edge cases for builtin types- it's possible to have - two 'foo' strings in one of these sets, for example. Use sparingly. - - """ - - def __init__(self, iterable=None): - self._members = dict() - if iterable: - self.update(iterable) - - def add(self, value): - self._members[id(value)] = value - - def __contains__(self, value): - return id(value) in self._members - - def remove(self, value): - del self._members[id(value)] - - def discard(self, value): - try: - self.remove(value) - except KeyError: - pass - - def pop(self): - try: - pair = self._members.popitem() - return pair[1] - except KeyError: - raise KeyError("pop from an empty set") - - def clear(self): - self._members.clear() - - def __cmp__(self, other): - raise TypeError("cannot compare sets using cmp()") - - def __eq__(self, other): - if isinstance(other, IdentitySet): - return self._members == other._members - else: - return False - - def __ne__(self, other): - if isinstance(other, IdentitySet): - return self._members != other._members - else: - return True - - def issubset(self, iterable): - if isinstance(iterable, self.__class__): - other = iterable - else: - other = self.__class__(iterable) - - if len(self) > len(other): - return False - for m in filterfalse( - other._members.__contains__, iter(self._members.keys()) - ): - return False - return True - - def __le__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.issubset(other) - - def __lt__(self, other): - if not 
isinstance(other, IdentitySet): - return NotImplemented - return len(self) < len(other) and self.issubset(other) - - def issuperset(self, iterable): - if isinstance(iterable, self.__class__): - other = iterable - else: - other = self.__class__(iterable) - - if len(self) < len(other): - return False - - for m in filterfalse( - self._members.__contains__, iter(other._members.keys()) - ): - return False - return True - - def __ge__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.issuperset(other) - - def __gt__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return len(self) > len(other) and self.issuperset(other) - - def union(self, iterable): - result = self.__class__() - members = self._members - result._members.update(members) - result._members.update((id(obj), obj) for obj in iterable) - return result - - def __or__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.union(other) - - def update(self, iterable): - self._members.update((id(obj), obj) for obj in iterable) - - def __ior__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.update(other) - return self - - def difference(self, iterable): - result = self.__class__() - members = self._members - if isinstance(iterable, self.__class__): - other = set(iterable._members.keys()) - else: - other = {id(obj) for obj in iterable} - result._members.update( - ((k, v) for k, v in members.items() if k not in other) - ) - return result - - def __sub__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.difference(other) - - def difference_update(self, iterable): - self._members = self.difference(iterable)._members - - def __isub__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.difference_update(other) - return self - - def intersection(self, iterable): - result = self.__class__() - members 
= self._members - if isinstance(iterable, self.__class__): - other = set(iterable._members.keys()) - else: - other = {id(obj) for obj in iterable} - result._members.update( - (k, v) for k, v in members.items() if k in other - ) - return result - - def __and__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.intersection(other) - - def intersection_update(self, iterable): - self._members = self.intersection(iterable)._members - - def __iand__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.intersection_update(other) - return self - - def symmetric_difference(self, iterable): - result = self.__class__() - members = self._members - if isinstance(iterable, self.__class__): - other = iterable._members - else: - other = {id(obj): obj for obj in iterable} - result._members.update( - ((k, v) for k, v in members.items() if k not in other) - ) - result._members.update( - ((k, v) for k, v in other.items() if k not in members) - ) - return result - - def __xor__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - return self.symmetric_difference(other) - - def symmetric_difference_update(self, iterable): - self._members = self.symmetric_difference(iterable)._members - - def __ixor__(self, other): - if not isinstance(other, IdentitySet): - return NotImplemented - self.symmetric_difference(other) - return self - - def copy(self): - return type(self)(iter(self._members.values())) - - __copy__ = copy - - def __len__(self): - return len(self._members) - - def __iter__(self): - return iter(self._members.values()) - - def __hash__(self): - raise TypeError("set objects are unhashable") - - def __repr__(self): - return "%s(%r)" % (type(self).__name__, list(self._members.values())) - - class WeakSequence: def __init__(self, __elements=()): # adapted from weakref.WeakKeyDictionary, prevent reference @@ -661,19 +268,6 @@ _property_getters = PopulateDict( ) -def unique_list(seq, 
hashfunc=None): - seen = set() - seen_add = seen.add - if not hashfunc: - return [x for x in seq if x not in seen and not seen_add(x)] - else: - return [ - x - for x in seq - if hashfunc(x) not in seen and not seen_add(hashfunc(x)) - ] - - class UniqueAppender: """Appends items to a collection ensuring uniqueness. diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py new file mode 100644 index 000000000..ff61f6ca9 --- /dev/null +++ b/lib/sqlalchemy/util/_py_collections.py @@ -0,0 +1,401 @@ +from itertools import filterfalse + + +class ImmutableContainer: + def _immutable(self, *arg, **kw): + raise TypeError("%s object is immutable" % self.__class__.__name__) + + __delitem__ = __setitem__ = __setattr__ = _immutable + + +class immutabledict(ImmutableContainer, dict): + + clear = pop = popitem = setdefault = update = ImmutableContainer._immutable + + def __new__(cls, *args): + new = dict.__new__(cls) + dict.__init__(new, *args) + return new + + def __init__(self, *args): + pass + + def __reduce__(self): + return immutabledict, (dict(self),) + + def union(self, __d=None): + if not __d: + return self + + new = dict.__new__(self.__class__) + dict.__init__(new, self) + dict.update(new, __d) + return new + + def _union_w_kw(self, __d=None, **kw): + # not sure if C version works correctly w/ this yet + if not __d and not kw: + return self + + new = dict.__new__(self.__class__) + dict.__init__(new, self) + if __d: + dict.update(new, __d) + dict.update(new, kw) + return new + + def merge_with(self, *dicts): + new = None + for d in dicts: + if d: + if new is None: + new = dict.__new__(self.__class__) + dict.__init__(new, self) + dict.update(new, d) + if new is None: + return self + + return new + + def __repr__(self): + return "immutabledict(%s)" % dict.__repr__(self) + + +class OrderedSet(set): + def __init__(self, d=None): + set.__init__(self) + if d is not None: + self._list = unique_list(d) + set.update(self, self._list) + else: + 
self._list = [] + + def add(self, element): + if element not in self: + self._list.append(element) + set.add(self, element) + + def remove(self, element): + set.remove(self, element) + self._list.remove(element) + + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + set.add(self, element) + + def discard(self, element): + if element in self: + self._list.remove(element) + set.remove(self, element) + + def clear(self): + set.clear(self) + self._list = [] + + def __getitem__(self, key): + return self._list[key] + + def __iter__(self): + return iter(self._list) + + def __add__(self, other): + return self.union(other) + + def __repr__(self): + return "%s(%r)" % (self.__class__.__name__, self._list) + + __str__ = __repr__ + + def update(self, iterable): + for e in iterable: + if e not in self: + self._list.append(e) + set.add(self, e) + return self + + __ior__ = update + + def union(self, other): + result = self.__class__(self) + result.update(other) + return result + + __or__ = union + + def intersection(self, other): + other = other if isinstance(other, set) else set(other) + return self.__class__(a for a in self if a in other) + + __and__ = intersection + + def symmetric_difference(self, other): + other_set = other if isinstance(other, set) else set(other) + result = self.__class__(a for a in self if a not in other_set) + result.update(a for a in other if a not in self) + return result + + __xor__ = symmetric_difference + + def difference(self, other): + other = other if isinstance(other, set) else set(other) + return self.__class__(a for a in self if a not in other) + + __sub__ = difference + + def intersection_update(self, other): + other = other if isinstance(other, set) else set(other) + set.intersection_update(self, other) + self._list = [a for a in self._list if a in other] + return self + + __iand__ = intersection_update + + def symmetric_difference_update(self, other): + set.symmetric_difference_update(self, other) + 
self._list = [a for a in self._list if a in self] + self._list += [a for a in other if a in self] + return self + + __ixor__ = symmetric_difference_update + + def difference_update(self, other): + set.difference_update(self, other) + self._list = [a for a in self._list if a in self] + return self + + __isub__ = difference_update + + +class IdentitySet: + """A set that considers only object id() for uniqueness. + + This strategy has edge cases for builtin types- it's possible to have + two 'foo' strings in one of these sets, for example. Use sparingly. + + """ + + def __init__(self, iterable=None): + self._members = dict() + if iterable: + self.update(iterable) + + def add(self, value): + self._members[id(value)] = value + + def __contains__(self, value): + return id(value) in self._members + + def remove(self, value): + del self._members[id(value)] + + def discard(self, value): + try: + self.remove(value) + except KeyError: + pass + + def pop(self): + try: + pair = self._members.popitem() + return pair[1] + except KeyError: + raise KeyError("pop from an empty set") + + def clear(self): + self._members.clear() + + def __cmp__(self, other): + raise TypeError("cannot compare sets using cmp()") + + def __eq__(self, other): + if isinstance(other, IdentitySet): + return self._members == other._members + else: + return False + + def __ne__(self, other): + if isinstance(other, IdentitySet): + return self._members != other._members + else: + return True + + def issubset(self, iterable): + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) > len(other): + return False + for m in filterfalse( + other._members.__contains__, iter(self._members.keys()) + ): + return False + return True + + def __le__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issubset(other) + + def __lt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return 
len(self) < len(other) and self.issubset(other) + + def issuperset(self, iterable): + if isinstance(iterable, self.__class__): + other = iterable + else: + other = self.__class__(iterable) + + if len(self) < len(other): + return False + + for m in filterfalse( + self._members.__contains__, iter(other._members.keys()) + ): + return False + return True + + def __ge__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.issuperset(other) + + def __gt__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return len(self) > len(other) and self.issuperset(other) + + def union(self, iterable): + result = self.__class__() + members = self._members + result._members.update(members) + result._members.update((id(obj), obj) for obj in iterable) + return result + + def __or__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.union(other) + + def update(self, iterable): + self._members.update((id(obj), obj) for obj in iterable) + + def __ior__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.update(other) + return self + + def difference(self, iterable): + result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = iterable._members + else: + other = {id(obj) for obj in iterable} + result._members = { + k: v for k, v in self._members.items() if k not in other + } + return result + + def __sub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.difference(other) + + def difference_update(self, iterable): + self._members = self.difference(iterable)._members + + def __isub__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.difference_update(other) + return self + + def intersection(self, iterable): + result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = iterable._members + else: + other = 
{id(obj) for obj in iterable} + result._members = { + k: v for k, v in self._members.items() if k in other + } + return result + + def __and__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.intersection(other) + + def intersection_update(self, iterable): + self._members = self.intersection(iterable)._members + + def __iand__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.intersection_update(other) + return self + + def symmetric_difference(self, iterable): + result = self.__new__(self.__class__) + if isinstance(iterable, self.__class__): + other = iterable._members + else: + other = {id(obj): obj for obj in iterable} + result._members = { + k: v for k, v in self._members.items() if k not in other + } + result._members.update( + (k, v) for k, v in other.items() if k not in self._members + ) + return result + + def __xor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + return self.symmetric_difference(other) + + def symmetric_difference_update(self, iterable): + self._members = self.symmetric_difference(iterable)._members + + def __ixor__(self, other): + if not isinstance(other, IdentitySet): + return NotImplemented + self.symmetric_difference(other) + return self + + def copy(self): + result = self.__new__(self.__class__) + result._members = self._members.copy() + return result + + __copy__ = copy + + def __len__(self): + return len(self._members) + + def __iter__(self): + return iter(self._members.values()) + + def __hash__(self): + raise TypeError("set objects are unhashable") + + def __repr__(self): + return "%s(%r)" % (type(self).__name__, list(self._members.values())) + + +def unique_list(seq, hashfunc=None): + seen = set() + seen_add = seen.add + if not hashfunc: + return [x for x in seq if x not in seen and not seen_add(x)] + else: + return [ + x + for x in seq + if hashfunc(x) not in seen and not seen_add(hashfunc(x)) + ] diff --git 
a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 621941b43..b759490c5 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1899,9 +1899,11 @@ def repr_tuple_names(names): def has_compiled_ext(): try: - from sqlalchemy import cimmutabledict # noqa F401 - from sqlalchemy import cprocessors # noqa F401 - from sqlalchemy import cresultproxy # noqa F401 + from sqlalchemy.cyextension import collections # noqa F401 + from sqlalchemy.cyextension import immutabledict # noqa F401 + from sqlalchemy.cyextension import processors # noqa F401 + from sqlalchemy.cyextension import resultproxy # noqa F401 + from sqlalchemy.cyextension import util # noqa F401 return True except ImportError: diff --git a/pyproject.toml b/pyproject.toml index 0f7257892..7523d67e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,11 @@ +[build-system] + build-backend = "setuptools.build_meta" + requires = [ + "setuptools>=47", + "wheel>=0.34", + "cython>=0.29.24; python_implementation == 'CPython'", # Skip cython when using pypy + ] + [tool.black] line-length = 79 target-version = ['py27', 'py36'] @@ -1,9 +1,6 @@ [metadata] name = SQLAlchemy -# version comes from setup.py; setuptools -# can't read the "attr:" here without importing -# until version 47.0.0 which is too recent - +version = attr: sqlalchemy.__version__ description = Database Abstraction Library long_description = file: README.rst long_description_content_type = text/x-rst @@ -1,13 +1,13 @@ import os import platform -import re import sys +from setuptools import __version__ from setuptools import Distribution as _Distribution -from setuptools import Extension from setuptools import setup -from setuptools.command.build_ext import build_ext -from setuptools.command.test import test as TestCommand + +if not int(__version__.partition(".")[0]) >= 47: + raise RuntimeError(f"Setuptools >= 47 required. 
Found {__version__}") # attempt to use pep-632 imports for setuptools symbols; however, # since these symbols were only added to setuptools as of 59.0.1, @@ -21,67 +21,79 @@ except ImportError: from distutils.errors import DistutilsExecError from distutils.errors import DistutilsPlatformError +try: + from Cython.Distutils.old_build_ext import old_build_ext + from Cython.Distutils.extension import Extension + + CYTHON = True +except ImportError: + CYTHON = False + cmdclass = {} cpython = platform.python_implementation() == "CPython" ext_errors = (CCompilerError, DistutilsExecError, DistutilsPlatformError) +extra_compile_args = [] if sys.platform == "win32": # Work around issue https://github.com/pypa/setuptools/issues/1902 ext_errors += (IOError, TypeError) - extra_compile_args = [] -elif sys.platform in ("linux", "linux2"): - # warn for undefined symbols in .c files - extra_compile_args = ["-Wundef", "-Werror=implicit-function-declaration"] -else: - extra_compile_args = [] - -ext_modules = [ - Extension( - "sqlalchemy.cprocessors", - sources=["lib/sqlalchemy/cextension/processors.c"], - extra_compile_args=extra_compile_args, - ), - Extension( - "sqlalchemy.cresultproxy", - sources=["lib/sqlalchemy/cextension/resultproxy.c"], - extra_compile_args=extra_compile_args, - ), - Extension( - "sqlalchemy.cimmutabledict", - sources=["lib/sqlalchemy/cextension/immutabledict.c"], - extra_compile_args=extra_compile_args, - ), -] - - -class BuildFailed(Exception): - def __init__(self): - self.cause = sys.exc_info()[1] # work around py 2/3 different syntax +cython_files = [ + "collections.pyx", + "immutabledict.pyx", + "processors.pyx", + "resultproxy.pyx", + "util.pyx", +] +cython_directives = {"language_level": "3"} + +if CYTHON: + + def get_ext_modules(): + module_prefix = "sqlalchemy.cyextension." 
+ source_prefix = "lib/sqlalchemy/cyextension/" + + ext_modules = [] + for file in cython_files: + name, _ = os.path.splitext(file) + ext_modules.append( + Extension( + module_prefix + name, + sources=[source_prefix + file], + extra_compile_args=extra_compile_args, + cython_directives=cython_directives, + ) + ) + return ext_modules -class ve_build_ext(build_ext): - # This class allows C extension building to fail. + class BuildFailed(Exception): + pass - def run(self): - try: - build_ext.run(self) - except DistutilsPlatformError: - raise BuildFailed() + class ve_build_ext(old_build_ext): + # This class allows Cython building to fail. - def build_extension(self, ext): - try: - build_ext.build_extension(self, ext) - except ext_errors: - raise BuildFailed() - except ValueError: - # this can happen on Windows 64 bit, see Python issue 7511 - if "'path'" in str(sys.exc_info()[1]): # works with both py 2/3 + def run(self): + try: + super().run() + except DistutilsPlatformError: raise BuildFailed() - raise - -cmdclass["build_ext"] = ve_build_ext + def build_extension(self, ext): + try: + super().build_extension(ext) + except ext_errors as e: + raise BuildFailed() from e + except ValueError as e: + # this can happen on Windows 64 bit, see Python issue 7511 + if "'path'" in str(e): + raise BuildFailed() from e + raise + + cmdclass["build_ext"] = ve_build_ext + ext_modules = get_ext_modules() +else: + ext_modules = [] class Distribution(_Distribution): @@ -95,24 +107,6 @@ class Distribution(_Distribution): return True -class UseTox(TestCommand): - RED = 31 - RESET_SEQ = "\033[0m" - BOLD_SEQ = "\033[1m" - COLOR_SEQ = "\033[1;%dm" - - def run_tests(self): - sys.stderr.write( - "%s%spython setup.py test is deprecated by pypa. 
Please invoke " - "'tox' with no arguments for a basic test run.\n%s" - % (self.COLOR_SEQ % self.RED, self.BOLD_SEQ, self.RESET_SEQ) - ) - sys.exit(1) - - -cmdclass["test"] = UseTox - - def status_msgs(*msgs): print("*" * 75) for msg in msgs: @@ -120,16 +114,6 @@ def status_msgs(*msgs): print("*" * 75) -with open( - os.path.join(os.path.dirname(__file__), "lib", "sqlalchemy", "__init__.py") -) as v_file: - VERSION = ( - re.compile(r""".*__version__ = ["'](.*?)['"]""", re.S) - .match(v_file.read()) - .group(1) - ) - - def run_setup(with_cext): kwargs = {} if with_cext: @@ -137,30 +121,36 @@ def run_setup(with_cext): else: if os.environ.get("REQUIRE_SQLALCHEMY_CEXT"): raise AssertionError( - "Can't build on this platform with " - "REQUIRE_SQLALCHEMY_CEXT set." + "Can't build on this platform with REQUIRE_SQLALCHEMY_CEXT" + " set. Cython is required to build compiled extensions" ) kwargs["ext_modules"] = [] - setup(version=VERSION, cmdclass=cmdclass, distclass=Distribution, **kwargs) + setup(cmdclass=cmdclass, distclass=Distribution, **kwargs) if not cpython: run_setup(False) status_msgs( - "WARNING: C extensions are not supported on " - + "this Python platform, speedups are not enabled.", + "WARNING: Cython extensions are not supported on " + "this Python platform, speedups are not enabled.", + "Plain-Python build succeeded.", + ) +elif not CYTHON: + run_setup(False) + status_msgs( + "WARNING: Cython is required to build the compiled " + "extensions, speedups are not enabled.", "Plain-Python build succeeded.", ) elif os.environ.get("DISABLE_SQLALCHEMY_CEXT"): run_setup(False) status_msgs( "DISABLE_SQLALCHEMY_CEXT is set; " - + "not attempting to build C extensions.", + "not attempting to build Cython extensions.", "Plain-Python build succeeded.", ) - else: try: run_setup(True) @@ -168,7 +158,7 @@ else: if os.environ.get("REQUIRE_SQLALCHEMY_CEXT"): status_msgs( - "NOTE: C extension build is required because " + "NOTE: Cython extension build is required because " 
"REQUIRE_SQLALCHEMY_CEXT is set, and the build has failed; " "will not degrade to non-C extensions" ) @@ -176,8 +166,8 @@ else: status_msgs( exc.cause, - "WARNING: The C extension could not be compiled, " - + "speedups are not enabled.", + "WARNING: The Cython extension could not be compiled, " + "speedups are not enabled.", "Failure information, if any, is above.", "Retrying the build without the C extension now.", ) @@ -185,7 +175,7 @@ else: run_setup(False) status_msgs( - "WARNING: The C extension could not be compiled, " - + "speedups are not enabled.", + "WARNING: The Cython extension could not be compiled, " + "speedups are not enabled.", "Plain-Python build succeeded.", ) diff --git a/test/aaa_profiling/test_memusage.py b/test/aaa_profiling/test_memusage.py index a876108a6..8d6e4f500 100644 --- a/test/aaa_profiling/test_memusage.py +++ b/test/aaa_profiling/test_memusage.py @@ -15,6 +15,7 @@ from sqlalchemy import testing from sqlalchemy import Unicode from sqlalchemy import util from sqlalchemy.engine import result +from sqlalchemy.engine.processors import to_decimal_processor_factory from sqlalchemy.orm import aliased from sqlalchemy.orm import clear_mappers from sqlalchemy.orm import configure_mappers @@ -28,7 +29,6 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import subqueryload from sqlalchemy.orm.session import _sessions -from sqlalchemy.processors import to_decimal_processor_factory from sqlalchemy.sql import column from sqlalchemy.sql import util as sql_util from sqlalchemy.sql.visitors import cloned_traverse diff --git a/test/base/test_result.py b/test/base/test_result.py index 8c9eb398e..b31e886da 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -195,6 +195,37 @@ class ResultTupleTest(fixtures.TestBase): eq_(kt._fields, ("a", "b")) eq_(kt._asdict(), {"a": 1, "b": 3}) + @testing.requires.cextensions + def test_serialize_cy_py_cy(self): + from sqlalchemy.engine._py_row import 
BaseRow as _PyRow + from sqlalchemy.cyextension.resultproxy import BaseRow as _CyRow + + global Row + + p = result.SimpleResultMetaData(["a", None, "b"]) + + for loads, dumps in picklers(): + + class Row(_CyRow): + pass + + row = Row(p, p._processors, p._keymap, 0, (1, 2, 3)) + + state = dumps(row) + + class Row(_PyRow): + pass + + row2 = loads(state) + is_true(isinstance(row2, _PyRow)) + state2 = dumps(row2) + + class Row(_CyRow): + pass + + row3 = loads(state2) + is_true(isinstance(row3, _CyRow)) + class ResultTest(fixtures.TestBase): def _fixture( diff --git a/test/base/test_utils.py b/test/base/test_utils.py index 3bbcbe3fb..836778bc9 100644 --- a/test/base/test_utils.py +++ b/test/base/test_utils.py @@ -13,7 +13,9 @@ from sqlalchemy.sql import column from sqlalchemy.sql.base import DedupeColumnCollection from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message +from sqlalchemy.testing import combinations from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ @@ -162,6 +164,12 @@ class OrderedSetTest(fixtures.TestBase): eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4])) eq_(o.union(iter([3, 4, 6])), util.OrderedSet([2, 3, 4, 5, 6])) + def test_repr(self): + o = util.OrderedSet([]) + eq_(str(o), "OrderedSet([])") + o = util.OrderedSet([3, 2, 4, 5]) + eq_(str(o), "OrderedSet([3, 2, 4, 5])") + class ImmutableDictTest(fixtures.TestBase): def test_union_no_change(self): @@ -267,6 +275,42 @@ class ImmutableDictTest(fixtures.TestBase): assert isinstance(d2, util.immutabledict) + def test_repr(self): + # this is used by the stub generator in alembic + i = util.immutabledict() + eq_(str(i), "immutabledict({})") + i2 = util.immutabledict({"a": 42, 42: "a"}) + eq_(str(i2), "immutabledict({'a': 42, 42: 'a'})") + + +class ImmutableTest(fixtures.TestBase): + 
@combinations(util.immutabledict({1: 2, 3: 4}), util.FacadeDict({2: 3})) + def test_immutable(self, d): + calls = ( + lambda: d.__delitem__(1), + lambda: d.__setitem__(2, 3), + lambda: d.__setattr__(2, 3), + d.clear, + lambda: d.setdefault(1, 3), + lambda: d.update({2: 4}), + ) + if hasattr(d, "pop"): + calls += (d.pop, d.popitem) + for m in calls: + with expect_raises_message(TypeError, "object is immutable"): + m() + + def test_immutable_properties(self): + d = util.ImmutableProperties({3: 4}) + calls = ( + lambda: d.__delitem__(1), + lambda: d.__setitem__(2, 3), + lambda: d.__setattr__(2, 3), + ) + for m in calls: + with expect_raises_message(TypeError, "object is immutable"): + m() + class MemoizedAttrTest(fixtures.TestBase): def test_memoized_property(self): @@ -1811,6 +1855,12 @@ class IdentitySetTest(fixtures.TestBase): assert_raises(TypeError, util.cmp, ids) assert_raises(TypeError, hash, ids) + def test_repr(self): + i = util.IdentitySet([]) + eq_(str(i), "IdentitySet([])") + i = util.IdentitySet([1, 2, 3]) + eq_(str(i), "IdentitySet([1, 2, 3])") + class NoHashIdentitySetTest(IdentitySetTest): obj_type = NoHash diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index fb4fd02a1..59ebc87e2 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -119,7 +119,7 @@ class ExecuteTest(fixtures.TablesTest): tsa.exc.ArgumentError, "List argument must consist only of tuples or dictionaries", connection.exec_driver_sql, - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", [2, "fred"], ) @@ -127,7 +127,7 @@ class ExecuteTest(fixtures.TablesTest): tsa.exc.ArgumentError, "List argument must consist only of tuples or dictionaries", connection.exec_driver_sql, - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", [[3, "ed"], [4, "horse"]], ) @@ -159,23 +159,23 @@ class ExecuteTest(fixtures.TablesTest): 
def test_raw_qmark(self, connection): conn = connection conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", (1, "jack"), ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", (2, "fred"), ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", [(3, "ed"), (4, "horse")], ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", [(5, "barney"), (6, "donkey")], ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (?, ?)", + "insert into users (user_id, user_name) values (?, ?)", (7, "sally"), ) res = conn.exec_driver_sql("select * from users order by user_id") @@ -198,15 +198,15 @@ class ExecuteTest(fixtures.TablesTest): def test_raw_sprintf(self, connection): conn = connection conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (%s, %s)", + "insert into users (user_id, user_name) values (%s, %s)", (1, "jack"), ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (%s, %s)", + "insert into users (user_id, user_name) values (%s, %s)", [(2, "ed"), (3, "horse")], ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (%s, %s)", + "insert into users (user_id, user_name) values (%s, %s)", (4, "sally"), ) conn.exec_driver_sql("insert into users (user_id) values (%s)", (5,)) @@ -254,15 +254,15 @@ class ExecuteTest(fixtures.TablesTest): def test_raw_named(self, connection): conn = connection conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (:id, :name)", + "insert into users (user_id, user_name) values (:id, :name)", {"id": 1, "name": "jack"}, ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values 
(:id, :name)", + "insert into users (user_id, user_name) values (:id, :name)", [{"id": 2, "name": "ed"}, {"id": 3, "name": "horse"}], ) conn.exec_driver_sql( - "insert into users (user_id, user_name) " "values (:id, :name)", + "insert into users (user_id, user_name) values (:id, :name)", {"id": 4, "name": "sally"}, ) res = conn.exec_driver_sql("select * from users order by user_id") @@ -518,7 +518,7 @@ class ExecuteTest(fixtures.TablesTest): ) @testing.fails_on( "oracle+cx_oracle", - "cx_oracle exception seems to be having " "some issue with pickling", + "cx_oracle exception seems to be having some issue with pickling", ) def test_stmt_exception_pickleable_plus_dbapi(self): raw = testing.db.raw_connection() diff --git a/test/engine/test_processors.py b/test/engine/test_processors.py index 943ae32f0..392632327 100644 --- a/test/engine/test_processors.py +++ b/test/engine/test_processors.py @@ -1,6 +1,11 @@ +from types import MappingProxyType + +from sqlalchemy import exc from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.util import immutabledict class _BooleanProcessorTest(fixtures.TestBase): @@ -20,14 +25,14 @@ class _BooleanProcessorTest(fixtures.TestBase): eq_(self.module.int_to_boolean(-4), True) -class CBooleanProcessorTest(_BooleanProcessorTest): +class CyBooleanProcessorTest(_BooleanProcessorTest): __requires__ = ("cextensions",) @classmethod def setup_test_class(cls): - from sqlalchemy import cprocessors + from sqlalchemy.cyextension import processors - cls.module = cprocessors + cls.module = processors class _DateProcessorTest(fixtures.TestBase): @@ -83,23 +88,168 @@ class _DateProcessorTest(fixtures.TestBase): class PyDateProcessorTest(_DateProcessorTest): @classmethod def setup_test_class(cls): - from sqlalchemy import processors - - cls.module = type( - "util", - (object,), - dict( - (k, staticmethod(v)) - 
for k, v in list(processors.py_fallback().items()) - ), + from sqlalchemy.engine import _py_processors + + cls.module = _py_processors + + +class CyDateProcessorTest(_DateProcessorTest): + __requires__ = ("cextensions",) + + @classmethod + def setup_test_class(cls): + from sqlalchemy.cyextension import processors + + cls.module = processors + + +class _DistillArgsTest(fixtures.TestBase): + def test_distill_20_none(self): + eq_(self.module._distill_params_20(None), ()) + + def test_distill_20_empty_sequence(self): + eq_(self.module._distill_params_20(()), ()) + eq_(self.module._distill_params_20([]), []) + + def test_distill_20_sequence_sequence(self): + eq_(self.module._distill_params_20(((1, 2, 3),)), ((1, 2, 3),)) + eq_(self.module._distill_params_20([(1, 2, 3)]), [(1, 2, 3)]) + + eq_(self.module._distill_params_20(((1, 2), (2, 3))), ((1, 2), (2, 3))) + eq_(self.module._distill_params_20([(1, 2), (2, 3)]), [(1, 2), (2, 3)]) + + def test_distill_20_sequence_dict(self): + eq_(self.module._distill_params_20(({"a": 1},)), ({"a": 1},)) + eq_( + self.module._distill_params_20([{"a": 1}, {"a": 2}]), + [{"a": 1}, {"a": 2}], ) + eq_( + self.module._distill_params_20((MappingProxyType({"a": 1}),)), + (MappingProxyType({"a": 1}),), + ) + + def test_distill_20_sequence_error(self): + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_params_20((1, 2, 3)) + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_params_20(([1, 2, 3],)) + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_params_20([1, 2, 3]) + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_params_20(["a", "b"]) + + def test_distill_20_dict(self): + 
eq_(self.module._distill_params_20({"foo": "bar"}), [{"foo": "bar"}]) + eq_( + self.module._distill_params_20(immutabledict({"foo": "bar"})), + [immutabledict({"foo": "bar"})], + ) + eq_( + self.module._distill_params_20(MappingProxyType({"foo": "bar"})), + [MappingProxyType({"foo": "bar"})], + ) + + def test_distill_20_error(self): + with expect_raises_message( + exc.ArgumentError, "mapping or list expected for parameters" + ): + self.module._distill_params_20("foo") + with expect_raises_message( + exc.ArgumentError, "mapping or list expected for parameters" + ): + self.module._distill_params_20(1) + + def test_distill_raw_none(self): + eq_(self.module._distill_raw_params(None), ()) + + def test_distill_raw_empty_list(self): + eq_(self.module._distill_raw_params([]), []) + + def test_distill_raw_list_sequence(self): + eq_(self.module._distill_raw_params([(1, 2, 3)]), [(1, 2, 3)]) + eq_( + self.module._distill_raw_params([(1, 2), (2, 3)]), [(1, 2), (2, 3)] + ) + + def test_distill_raw_list_dict(self): + eq_( + self.module._distill_raw_params([{"a": 1}, {"a": 2}]), + [{"a": 1}, {"a": 2}], + ) + eq_( + self.module._distill_raw_params([MappingProxyType({"a": 1})]), + [MappingProxyType({"a": 1})], + ) + + def test_distill_raw_sequence_error(self): + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_raw_params([1, 2, 3]) + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_raw_params([[1, 2, 3]]) + with expect_raises_message( + exc.ArgumentError, + "List argument must consist only of tuples or dictionaries", + ): + self.module._distill_raw_params(["a", "b"]) + + def test_distill_raw_tuple(self): + eq_(self.module._distill_raw_params(()), [()]) + eq_(self.module._distill_raw_params((1, 2, 3)), [(1, 2, 3)]) + + def test_distill_raw_dict(self): + eq_(self.module._distill_raw_params({"foo": 
"bar"}), [{"foo": "bar"}]) + eq_( + self.module._distill_raw_params(immutabledict({"foo": "bar"})), + [immutabledict({"foo": "bar"})], + ) + eq_( + self.module._distill_raw_params(MappingProxyType({"foo": "bar"})), + [MappingProxyType({"foo": "bar"})], + ) + + def test_distill_raw_error(self): + with expect_raises_message( + exc.ArgumentError, "mapping or sequence expected for parameters" + ): + self.module._distill_raw_params("foo") + with expect_raises_message( + exc.ArgumentError, "mapping or sequence expected for parameters" + ): + self.module._distill_raw_params(1) + + +class PyDistillArgsTest(_DistillArgsTest): + @classmethod + def setup_test_class(cls): + from sqlalchemy.engine import _py_util + + cls.module = _py_util -class CDateProcessorTest(_DateProcessorTest): +class CyDistillArgsTest(_DistillArgsTest): __requires__ = ("cextensions",) @classmethod def setup_test_class(cls): - from sqlalchemy import cprocessors + from sqlalchemy.cyextension import util - cls.module = cprocessors + cls.module = util diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 8866223cc..0f29cfc56 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -31,10 +31,20 @@ from sqlalchemy.testing import not_in from sqlalchemy.testing.fixtures import fixture_session from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table +from sqlalchemy.util import has_compiled_ext from sqlalchemy.util import OrderedSet from test.orm import _fixtures +if has_compiled_ext(): + # cython ordered set is immutable, subclass it with a python + # class so that its method can be replaced + _OrderedSet = OrderedSet + + class OrderedSet(_OrderedSet): + pass + + class MergeTest(_fixtures.FixtureTest): """Session.merge() functionality""" diff --git a/test/perf/compiled_extensions.py b/test/perf/compiled_extensions.py new file mode 100644 index 000000000..b9bb9ebd2 --- /dev/null +++ b/test/perf/compiled_extensions.py @@ -0,0 +1,1063 @@ +from collections 
import defaultdict +from decimal import Decimal +import re +from secrets import token_urlsafe +from textwrap import wrap +from timeit import timeit +from types import MappingProxyType + + +def test_case(fn): + fn.__test_case__ = True + return fn + + +class Case: + """Base test case. Mark test cases with ``test_case``""" + + IMPLEMENTATIONS = {} + "Keys are the impl name, values are callable to load it" + NUMBER = 1_000_000 + + _CASES = [] + + def __init__(self, impl): + self.impl = impl + self.init_objects() + + def __init_subclass__(cls): + if not cls.__name__.startswith("_"): + Case._CASES.append(cls) + + def init_objects(self): + pass + + @classmethod + def _load(cls, fn): + try: + return fn() + except Exception as e: + print(f"Error loading {fn}: {e}") + + @classmethod + def import_object(cls): + impl = [] + for name, fn in cls.IMPLEMENTATIONS.items(): + obj = cls._load(fn) + if obj: + impl.append((name, obj)) + return impl + + @classmethod + def _divide_results(cls, results, num, div, name): + "utility method to create ratios of two implementation" + if div in results and num in results: + results[name] = { + m: results[num][m] / results[div][m] for m in results[div] + } + + @classmethod + def update_results(cls, results): + pass + + @classmethod + def run_case(cls, factor, filter_): + objects = cls.import_object() + number = max(1, int(cls.NUMBER * factor)) + + stack = [c for c in cls.mro() if c not in {object, Case}] + methods = [] + while stack: + curr = stack.pop(0) + # dict keeps the definition order, dir is instead sorted + methods += [ + m + for m, fn in curr.__dict__.items() + if hasattr(fn, "__test_case__") + ] + + if filter_: + methods = [m for m in methods if re.search(filter_, m)] + + results = defaultdict(dict) + for name, impl in objects: + print(f"Running {name} ", end="", flush=True) + impl_case = cls(impl) + fails = [] + for m in methods: + call = getattr(impl_case, m) + try: + value = timeit(call, number=number) + print(".", end="", 
flush=True) + except Exception as e: + fails.append(f"{name}::{m} error: {e}") + print("x", end="", flush=True) + value = float("nan") + + results[name][m] = value + print(" Done") + for f in fails: + print("\t", f) + + cls.update_results(results) + return results + + +class ImmutableDict(Case): + @staticmethod + def python(): + from sqlalchemy.util._py_collections import immutabledict + + return immutabledict + + @staticmethod + def c(): + from sqlalchemy.cimmutabledict import immutabledict + + return immutabledict + + @staticmethod + def cython(): + from sqlalchemy.cyextension.immutabledict import immutabledict + + return immutabledict + + IMPLEMENTATIONS = { + "python": python.__func__, + "c": c.__func__, + "cython": cython.__func__, + } + + def init_objects(self): + self.small = {"a": 5, "b": 4} + self.large = {f"k{i}": f"v{i}" for i in range(50)} + self.d1 = self.impl({"x": 5, "y": 4}) + self.d2 = self.impl({f"key{i}": f"value{i}" for i in range(50)}) + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "c", "python", "c / py") + cls._divide_results(results, "cython", "python", "cy / py") + cls._divide_results(results, "cython", "c", "cy / c") + + @test_case + def init_empty(self): + self.impl() + + @test_case + def init(self): + self.impl(self.small) + + @test_case + def init_large(self): + self.impl(self.large) + + @test_case + def len(self): + len(self.d1) + len(self.d2) + + @test_case + def getitem(self): + self.d1["x"] + self.d2["key42"] + + @test_case + def union(self): + self.d1.union(self.small) + + @test_case + def union_large(self): + self.d2.union(self.large) + + @test_case + def merge_with(self): + self.d1.merge_with(self.small) + + @test_case + def merge_with_large(self): + self.d2.merge_with(self.large) + + @test_case + def get(self): + self.d1.get("x") + self.d2.get("key42") + + @test_case + def get_miss(self): + self.d1.get("xxx") + self.d2.get("xxx") + + @test_case + def keys(self): + self.d1.keys() + 
self.d2.keys() + + @test_case + def items(self): + self.d1.items() + self.d2.items() + + @test_case + def values(self): + self.d1.values() + self.d2.values() + + @test_case + def iter(self): + list(self.d1) + list(self.d2) + + @test_case + def in_case(self): + "x" in self.d1 + "key42" in self.d1 + + @test_case + def in_miss(self): + "xx" in self.d1 + "xx" in self.d1 + + @test_case + def eq(self): + self.d1 == self.d1 + self.d2 == self.d2 + + @test_case + def eq_dict(self): + self.d1 == dict(self.d1) + self.d2 == dict(self.d2) + + @test_case + def eq_other(self): + self.d1 == self.d2 + self.d1 == "foo" + + @test_case + def ne(self): + self.d1 != self.d1 + self.d2 != self.d2 + + @test_case + def ne_dict(self): + self.d1 != dict(self.d1) + self.d2 != dict(self.d2) + + @test_case + def ne_other(self): + self.d1 != self.d2 + self.d1 != "foo" + + +class Processor(Case): + @staticmethod + def python(): + from sqlalchemy.engine import processors + + return processors + + @staticmethod + def c(): + from sqlalchemy import cprocessors as mod + + mod.to_decimal_processor_factory = ( + lambda t, s: mod.DecimalResultProcessor(t, "%%.%df" % s).process + ) + + return mod + + @staticmethod + def cython(): + from sqlalchemy.cyextension import processors as mod + + mod.to_decimal_processor_factory = ( + lambda t, s: mod.DecimalResultProcessor(t, "%%.%df" % s).process + ) + + return mod + + IMPLEMENTATIONS = { + "python": python.__func__, + "c": c.__func__, + "cython": cython.__func__, + } + NUMBER = 500_000 + + def init_objects(self): + self.to_dec = self.impl.to_decimal_processor_factory(Decimal, 10) + + self.bytes = token_urlsafe(2048).encode() + self.text = token_urlsafe(2048) + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "c", "python", "c / py") + cls._divide_results(results, "cython", "python", "cy / py") + cls._divide_results(results, "cython", "c", "cy / c") + + @test_case + def int_to_boolean(self): + self.impl.int_to_boolean(None) + 
self.impl.int_to_boolean(10) + self.impl.int_to_boolean(1) + self.impl.int_to_boolean(-10) + self.impl.int_to_boolean(0) + + @test_case + def to_str(self): + self.impl.to_str(None) + self.impl.to_str(123) + self.impl.to_str(True) + self.impl.to_str(self) + + @test_case + def to_float(self): + self.impl.to_float(None) + self.impl.to_float(123) + self.impl.to_float(True) + self.impl.to_float(42) + self.impl.to_float(0) + self.impl.to_float(42.0) + + @test_case + def str_to_datetime(self): + self.impl.str_to_datetime(None) + self.impl.str_to_datetime("2020-01-01 20:10:34") + self.impl.str_to_datetime("2030-11-21 01:04:34.123456") + + @test_case + def str_to_time(self): + self.impl.str_to_time(None) + self.impl.str_to_time("20:10:34") + self.impl.str_to_time("01:04:34.123456") + + @test_case + def str_to_date(self): + self.impl.str_to_date(None) + self.impl.str_to_date("2020-01-01") + + @test_case + def to_decimal(self): + self.to_dec(None) is None + self.to_dec(123.44) + self.to_dec(99) + self.to_dec(99) + + +class DistillParam(Case): + NUMBER = 2_000_000 + + @staticmethod + def python(): + from sqlalchemy.engine import _py_util + + return _py_util + + @staticmethod + def cython(): + from sqlalchemy.cyextension import util as mod + + return mod + + IMPLEMENTATIONS = { + "python": python.__func__, + "cython": cython.__func__, + } + + def init_objects(self): + self.tup_tup = tuple(tuple(range(10)) for _ in range(100)) + self.list_tup = list(self.tup_tup) + self.dict = {f"c{i}": i for i in range(100)} + self.mapping = MappingProxyType(self.dict) + self.tup_dic = (self.dict, self.dict) + self.list_dic = [self.dict, self.dict] + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "c", "python", "c / py") + cls._divide_results(results, "cython", "python", "cy / py") + cls._divide_results(results, "cython", "c", "cy / c") + + @test_case + def none_20(self): + self.impl._distill_params_20(None) + + @test_case + def empty_sequence_20(self): + 
self.impl._distill_params_20(()) + self.impl._distill_params_20([]) + + @test_case + def list_20(self): + self.impl._distill_params_20(self.list_tup) + + @test_case + def tuple_20(self): + self.impl._distill_params_20(self.tup_tup) + + @test_case + def list_dict_20(self): + self.impl._distill_params_20(self.list_tup) + + @test_case + def tuple_dict_20(self): + self.impl._distill_params_20(self.dict) + + @test_case + def mapping_20(self): + self.impl._distill_params_20(self.mapping) + + @test_case + def raw_none(self): + self.impl._distill_raw_params(None) + + @test_case + def raw_empty_sequence(self): + self.impl._distill_raw_params(()) + self.impl._distill_raw_params([]) + + @test_case + def raw_list(self): + self.impl._distill_raw_params(self.list_tup) + + @test_case + def raw_tuple(self): + self.impl._distill_raw_params(self.tup_tup) + + @test_case + def raw_list_dict(self): + self.impl._distill_raw_params(self.list_tup) + + @test_case + def raw_tuple_dict(self): + self.impl._distill_raw_params(self.dict) + + @test_case + def raw_mapping(self): + self.impl._distill_raw_params(self.mapping) + + +class IdentitySet(Case): + @staticmethod + def set_fn(): + return set + + @staticmethod + def python(): + from sqlalchemy.util._py_collections import IdentitySet + + return IdentitySet + + @staticmethod + def cython(): + from sqlalchemy.cyextension import collections + + return collections.IdentitySet + + IMPLEMENTATIONS = { + "set": set_fn.__func__, + "python": python.__func__, + "cython": cython.__func__, + } + NUMBER = 10 + + def init_objects(self): + self.val1 = list(range(10)) + self.val2 = list(wrap(token_urlsafe(4 * 2048), 4)) + + self.imp_1 = self.impl(self.val1) + self.imp_2 = self.impl(self.val2) + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "python", "set", "py / set") + cls._divide_results(results, "cython", "python", "cy / py") + cls._divide_results(results, "cython", "set", "cy / set") + + @test_case + def 
init_empty(self): + i = self.impl + for _ in range(10000): + i() + + @test_case + def init(self): + i, v = self.impl, self.val2 + for _ in range(500): + i(v) + + @test_case + def init_from_impl(self): + for _ in range(500): + self.impl(self.imp_2) + + @test_case + def add(self): + ii = self.impl() + for _ in range(10): + for i in range(1000): + ii.add(str(i)) + + @test_case + def contains(self): + ii = self.impl(self.val2) + for _ in range(500): + for x in self.val1 + self.val2: + x in ii + + @test_case + def remove(self): + v = [str(i) for i in range(7500)] + ii = self.impl(v) + for x in v[:5000]: + ii.remove(x) + + @test_case + def discard(self): + v = [str(i) for i in range(7500)] + ii = self.impl(v) + for x in v[:5000]: + ii.discard(x) + + @test_case + def pop(self): + for x in range(1000): + ii = self.impl(self.val1) + for x in self.val1: + ii.pop() + + @test_case + def clear(self): + i, v = self.impl, self.val1 + for _ in range(5000): + ii = i(v) + ii.clear() + + @test_case + def eq(self): + for x in range(1000): + self.imp_1 == self.imp_1 + self.imp_1 == self.imp_2 + self.imp_1 == self.val2 + + @test_case + def ne(self): + for x in range(1000): + self.imp_1 != self.imp_1 + self.imp_1 != self.imp_2 + self.imp_1 != self.val2 + + @test_case + def issubset(self): + for _ in range(250): + self.imp_1.issubset(self.imp_1) + self.imp_1.issubset(self.imp_2) + self.imp_1.issubset(self.val1) + self.imp_1.issubset(self.val2) + + @test_case + def le(self): + for x in range(1000): + self.imp_1 <= self.imp_1 + self.imp_1 <= self.imp_2 + self.imp_2 <= self.imp_1 + self.imp_2 <= self.imp_2 + + @test_case + def lt(self): + for x in range(2500): + self.imp_1 < self.imp_1 + self.imp_1 < self.imp_2 + self.imp_2 < self.imp_1 + self.imp_2 < self.imp_2 + + @test_case + def issuperset(self): + for _ in range(250): + self.imp_1.issuperset(self.imp_1) + self.imp_1.issuperset(self.imp_2) + self.imp_1.issubset(self.val1) + self.imp_1.issubset(self.val2) + + @test_case + def ge(self): + 
for x in range(1000): + self.imp_1 >= self.imp_1 + self.imp_1 >= self.imp_2 + self.imp_2 >= self.imp_1 + self.imp_2 >= self.imp_2 + + @test_case + def gt(self): + for x in range(2500): + self.imp_1 > self.imp_1 + self.imp_2 > self.imp_2 + self.imp_2 > self.imp_1 + self.imp_2 > self.imp_2 + + @test_case + def union(self): + for _ in range(250): + self.imp_1.union(self.imp_2) + + @test_case + def or_test(self): + for _ in range(250): + self.imp_1 | self.imp_2 + + @test_case + def update(self): + ii = self.impl(self.val1) + for _ in range(250): + ii.update(self.imp_2) + + @test_case + def ior(self): + ii = self.impl(self.val1) + for _ in range(250): + ii |= self.imp_2 + + @test_case + def difference(self): + for _ in range(250): + self.imp_1.difference(self.imp_2) + self.imp_1.difference(self.val2) + + @test_case + def sub(self): + for _ in range(500): + self.imp_1 - self.imp_2 + + @test_case + def difference_update(self): + ii = self.impl(self.val1) + for _ in range(250): + ii.difference_update(self.imp_2) + ii.difference_update(self.val2) + + @test_case + def isub(self): + ii = self.impl(self.val1) + for _ in range(500): + ii -= self.imp_2 + + @test_case + def intersection(self): + for _ in range(250): + self.imp_1.intersection(self.imp_2) + self.imp_1.intersection(self.val2) + + @test_case + def and_test(self): + for _ in range(500): + self.imp_1 & self.imp_2 + + @test_case + def intersection_up(self): + ii = self.impl(self.val1) + for _ in range(250): + ii.intersection_update(self.imp_2) + ii.intersection_update(self.val2) + + @test_case + def iand(self): + ii = self.impl(self.val1) + for _ in range(500): + ii &= self.imp_2 + + @test_case + def symmetric_diff(self): + for _ in range(125): + self.imp_1.symmetric_difference(self.imp_2) + self.imp_1.symmetric_difference(self.val2) + + @test_case + def xor(self): + for _ in range(250): + self.imp_1 ^ self.imp_2 + + @test_case + def symmetric_diff_up(self): + ii = self.impl(self.val1) + for _ in range(125): + 
ii.symmetric_difference_update(self.imp_2) + ii.symmetric_difference_update(self.val2) + + @test_case + def ixor(self): + ii = self.impl(self.val1) + for _ in range(250): + ii ^= self.imp_2 + + @test_case + def copy(self): + for _ in range(250): + self.imp_1.copy() + self.imp_2.copy() + + @test_case + def len(self): + for x in range(5000): + len(self.imp_1) + len(self.imp_2) + + @test_case + def iter(self): + for _ in range(2000): + list(self.imp_1) + list(self.imp_2) + + @test_case + def repr(self): + for _ in range(250): + str(self.imp_1) + str(self.imp_2) + + +class OrderedSet(IdentitySet): + @staticmethod + def set_fn(): + return set + + @staticmethod + def python(): + from sqlalchemy.util._py_collections import OrderedSet + + return OrderedSet + + @staticmethod + def cython(): + from sqlalchemy.cyextension import collections + + return collections.OrderedSet + + @staticmethod + def ordered_lib(): + from orderedset import OrderedSet + + return OrderedSet + + IMPLEMENTATIONS = { + "set": set_fn.__func__, + "python": python.__func__, + "cython": cython.__func__, + "ordsetlib": ordered_lib.__func__, + } + + @classmethod + def update_results(cls, results): + super().update_results(results) + cls._divide_results(results, "ordsetlib", "set", "ordlib/set") + cls._divide_results(results, "cython", "ordsetlib", "cy / ordlib") + + @test_case + def add_op(self): + ii = self.impl(self.val1) + v2 = self.impl(self.val2) + for _ in range(1000): + ii + v2 + + @test_case + def getitem(self): + ii = self.impl(self.val1) + for _ in range(1000): + for i in range(len(self.val1)): + ii[i] + + @test_case + def insert(self): + ii = self.impl(self.val1) + for _ in range(5): + for i in range(1000): + ii.insert(-i % 2, 1) + + +class TupleGetter(Case): + @staticmethod + def python(): + from sqlalchemy.engine._py_row import tuplegetter + + return tuplegetter + + @staticmethod + def c(): + from sqlalchemy import cresultproxy + + return cresultproxy.tuplegetter + + @staticmethod + def 
cython(): + from sqlalchemy.cyextension import resultproxy + + return resultproxy.tuplegetter + + IMPLEMENTATIONS = { + "python": python.__func__, + "c": c.__func__, + "cython": cython.__func__, + } + + def init_objects(self): + self.impl_tg = self.impl + + self.tuple = tuple(range(1000)) + self.tg_inst = self.impl_tg(42) + self.tg_inst_m = self.impl_tg(42, 420, 99, 9, 1) + + class MockRow: + def __init__(self, data): + self.data = data + + def _get_by_int_impl(self, index): + # called by python + return self.data[index] + + def _get_by_key_impl_mapping(self, index): + # called by c + return self.data[index] + + self.row = MockRow(self.tuple) + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "c", "python", "c / py") + cls._divide_results(results, "cython", "python", "cy / py") + cls._divide_results(results, "cython", "c", "cy / c") + + @test_case + def tuplegetter_one(self): + self.tg_inst(self.tuple) + + @test_case + def tuplegetter_many(self): + self.tg_inst_m(self.tuple) + + @test_case + def tuplegetter_new_one(self): + self.impl_tg(42)(self.tuple) + + @test_case + def tuplegetter_new_many(self): + self.impl_tg(42, 420, 99, 9, 1)(self.tuple) + + +class BaseRow(Case): + @staticmethod + def python(): + from sqlalchemy.engine._py_row import BaseRow + + return BaseRow + + @staticmethod + def c(): + from sqlalchemy.cresultproxy import BaseRow + + return BaseRow + + @staticmethod + def cython(): + from sqlalchemy.cyextension import resultproxy + + return resultproxy.BaseRow + + IMPLEMENTATIONS = { + "python": python.__func__, + # "c": c.__func__, + "cython": cython.__func__, + } + + def init_objects(self): + from sqlalchemy.engine.result import SimpleResultMetaData + from string import ascii_letters + + self.parent = SimpleResultMetaData(("a", "b", "c")) + self.row_args = ( + self.parent, + self.parent._processors, + self.parent._keymap, + 0, + (1, 2, 3), + ) + self.parent_long = SimpleResultMetaData(tuple(ascii_letters)) + 
self.row_long_args = ( + self.parent_long, + self.parent_long._processors, + self.parent_long._keymap, + 0, + tuple(range(len(ascii_letters))), + ) + self.row = self.impl(*self.row_args) + self.row_long = self.impl(*self.row_long_args) + assert isinstance(self.row, self.impl), type(self.row) + + class Row(self.impl): + pass + + self.Row = Row + self.row_sub = Row(*self.row_args) + + self.row_state = self.row.__getstate__() + self.row_long_state = self.row_long.__getstate__() + + @classmethod + def update_results(cls, results): + cls._divide_results(results, "c", "python", "c / py") + cls._divide_results(results, "cython", "python", "cy / py") + cls._divide_results(results, "cython", "c", "cy / c") + + @test_case + def base_row_new(self): + self.impl(*self.row_args) + self.impl(*self.row_long_args) + + @test_case + def row_new(self): + self.Row(*self.row_args) + self.Row(*self.row_long_args) + + @test_case + def row_dumps(self): + self.row.__getstate__() + self.row_long.__getstate__() + + @test_case + def row_loads(self): + self.impl.__new__(self.impl).__setstate__(self.row_state) + self.impl.__new__(self.impl).__setstate__(self.row_long_state) + + @test_case + def row_filter(self): + self.row._filter_on_values(None) + self.row_long._filter_on_values(None) + + @test_case + def row_values_impl(self): + self.row._values_impl() + self.row_long._values_impl() + + @test_case + def row_iter(self): + list(self.row) + list(self.row_long) + + @test_case + def row_len(self): + len(self.row) + len(self.row_long) + + @test_case + def row_hash(self): + hash(self.row) + hash(self.row_long) + + @test_case + def getitem(self): + self.row[0] + self.row[1] + self.row[-1] + self.row_long[0] + self.row_long[1] + self.row_long[-1] + + @test_case + def getitem_slice(self): + self.row[0:1] + self.row[1:-1] + self.row_long[0:1] + self.row_long[1:-1] + + @test_case + def get_by_int(self): + self.row._get_by_int_impl(0) + self.row._get_by_int_impl(1) + self.row_long._get_by_int_impl(0) + 
self.row_long._get_by_int_impl(1) + + @test_case + def get_by_key(self): + self.row._get_by_key_impl(0) + self.row._get_by_key_impl(1) + self.row_long._get_by_key_impl(0) + self.row_long._get_by_key_impl(1) + + @test_case + def get_by_key_slice(self): + self.row._get_by_key_impl(slice(0, 1)) + self.row._get_by_key_impl(slice(1, -1)) + self.row_long._get_by_key_impl(slice(0, 1)) + self.row_long._get_by_key_impl(slice(1, -1)) + + @test_case + def getattr(self): + self.row.a + self.row.b + self.row_long.x + self.row_long.y + + +def tabulate(results, inverse): + dim = 11 + header = "{:<20}|" + (" {:<%s} |" % dim) * len(results) + num_format = "{:<%s.9f}" % dim + row = "{:<20}|" + " {} |" * len(results) + names = list(results) + print(header.format("", *names)) + + for meth in inverse: + strings = [ + num_format.format(inverse[meth][name])[:dim] for name in names + ] + print(row.format(meth, *strings)) + + +def main(): + import argparse + + cases = Case._CASES + + parser = argparse.ArgumentParser( + description="Compare implementation between them" + ) + parser.add_argument( + "case", + help="Case to run", + nargs="+", + choices=["all"] + [c.__name__ for c in cases], + ) + parser.add_argument("--filter", help="filter the test for this regexp") + parser.add_argument( + "--factor", help="scale number passed to timeit", type=float, default=1 + ) + parser.add_argument("--csv", help="save to csv", action="store_true") + + args = parser.parse_args() + + if "all" in args.case: + to_run = cases + else: + to_run = [c for c in cases if c.__name__ in args.case] + + for case in to_run: + print("Running case", case.__name__) + result = case.run_case(args.factor, args.filter) + + inverse = defaultdict(dict) + for name in result: + for meth in result[name]: + inverse[meth][name] = result[name][meth] + + tabulate(result, inverse) + + if args.csv: + import csv + + file_name = f"{case.__name__}.csv" + with open(file_name, "w", newline="") as f: + w = csv.DictWriter(f, ["", *result]) + 
w.writeheader() + for n in inverse: + w.writerow({"": n, **inverse[n]}) + print("Wrote file", file_name) + + +if __name__ == "__main__": + main() diff --git a/test/profiles.txt b/test/profiles.txt index 7e23025e4..50907e984 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -164,7 +164,7 @@ test.aaa_profiling.test_orm.BranchedOptionTest.test_query_opts_unbound_branching # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 15261 -test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 26278 +test.aaa_profiling.test_orm.DeferOptionsTest.test_baseline x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 34368 # TEST: test.aaa_profiling.test_orm.DeferOptionsTest.test_defer_many_cols @@ -188,12 +188,12 @@ test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_b_plain x86_64_linux_cpy # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 103539 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 97989 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 103689 # TEST: test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased -test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 101889 +test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 96369 test.aaa_profiling.test_orm.JoinConditionTest.test_a_to_d_aliased x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 102039 # TEST: 
test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_build_query @@ -234,7 +234,7 @@ test.aaa_profiling.test_orm.MergeTest.test_merge_no_load x86_64_linux_cpython_3. # TEST: test.aaa_profiling.test_orm.QueryTest.test_query_cols test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 6142 -test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 6932 +test.aaa_profiling.test_orm.QueryTest.test_query_cols x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 7361 # TEST: test.aaa_profiling.test_orm.SelectInEagerLoadTest.test_round_trip_results @@ -312,9 +312,9 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_6 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 2637 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 15641 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 2592 -test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 15596 +test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 25595 test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 2547 -test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 15551 +test.aaa_profiling.test_resultset.ResultSetTest.test_fetch_by_key_mappings x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 
25550 # TEST: test.aaa_profiling.test_resultset.ResultSetTest.test_one_or_none[False-0] |