27 files changed, 1880 insertions, 402 deletions
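A note on terminology for the hunks that follow: a "result processor" is the per-column conversion callable a type returns from result_processor(); the ResultProxy applies it to each raw value coming back from the DBAPI. Every processor introduced by this patch shares one contract: None is passed through unchanged. A minimal pure-Python sketch of the two simplest ones (the C versions below implement the same semantics)::

    def to_str(value):
        # None passes through unchanged; anything else is stringified.
        return None if value is None else str(value)

    def to_float(value):
        return None if value is None else float(value)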
@@ -59,6 +59,20 @@ CHANGES
 [ticket:1689]
 
 - sql
+    - Added an optional C extension to speed up the sql layer by
+      reimplementing RowProxy and the most common result processors.
+      The actual speedups will depend heavily on your DBAPI and
+      the mix of datatypes used in your tables, and can vary from
+      a 50% improvement to more than 200%. It also provides a modest
+      (~20%) indirect improvement to ORM speed for large queries.
+      Note that it is *not* built/installed by default.
+      See README for installation instructions.
+
+    - The most common result processor conversion functions were
+      moved to the new "processors" module. Dialect authors are
+      encouraged to use those functions whenever they fit their
+      needs instead of implementing custom ones.
+
     - Added math negation operator support, -x.
 
     - FunctionElement subclasses are now directly executable the
@@ -35,6 +35,15 @@ To install::
 
 To use without installation, include the ``lib`` directory in your Python path.
 
+Installing the C extension
+--------------------------
+
+Edit "setup.py" and set ``BUILD_CEXTENSIONS`` to ``True``, then install it as
+above. If you only want to build the extension without installing it, you can
+do so with::
+
+    python setup.py build
+
 Running Tests
 -------------
 
diff --git a/lib/sqlalchemy/cextension/processors.c b/lib/sqlalchemy/cextension/processors.c
new file mode 100644
index 000000000..23b7be4f2
--- /dev/null
+++ b/lib/sqlalchemy/cextension/processors.c
@@ -0,0 +1,384 @@
+/*
+processors.c
+Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+
+This module is part of SQLAlchemy and is released under
+the MIT License: http://www.opensource.org/licenses/mit-license.php
+*/
+
+#include <Python.h>
+#include <datetime.h>
+
+static PyObject *
+int_to_boolean(PyObject *self, PyObject *arg)
+{
+    long l = 0;
+    PyObject *res;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    l = PyInt_AsLong(arg);
+    if (l == 0) {
+        res = Py_False;
+    } else if (l == 1) {
+        res = Py_True;
+    } else if ((l == -1) && PyErr_Occurred()) {
+        /* -1 can be either the actual value, or an error flag. */
+        return NULL;
+    } else {
+        PyErr_SetString(PyExc_ValueError,
+                        "int_to_boolean only accepts None, 0 or 1");
+        return NULL;
+    }
+
+    Py_INCREF(res);
+    return res;
+}
+
+static PyObject *
+to_str(PyObject *self, PyObject *arg)
+{
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    return PyObject_Str(arg);
+}
+
+static PyObject *
+to_float(PyObject *self, PyObject *arg)
+{
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    return PyNumber_Float(arg);
+}
+
+static PyObject *
+str_to_datetime(PyObject *self, PyObject *arg)
+{
+    const char *str;
+    unsigned int year, month, day, hour, minute, second, microsecond = 0;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    str = PyString_AsString(arg);
+    if (str == NULL)
+        return NULL;
+
+    /* microseconds are optional */
+    /*
+    TODO: this is slightly less picky than the Python version, which would
+    not accept "2000-01-01 00:00:00.". I don't know which is better, but they
+    should be consistent.
+    */
+    if (sscanf(str, "%4u-%2u-%2u %2u:%2u:%2u.%6u", &year, &month, &day,
+               &hour, &minute, &second, &microsecond) < 6) {
+        PyErr_SetString(PyExc_ValueError, "Couldn't parse datetime string.");
+        return NULL;
+    }
+    return PyDateTime_FromDateAndTime(year, month, day,
+                                      hour, minute, second, microsecond);
+}
+
+static PyObject *
+str_to_time(PyObject *self, PyObject *arg)
+{
+    const char *str;
+    unsigned int hour, minute, second, microsecond = 0;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    str = PyString_AsString(arg);
+    if (str == NULL)
+        return NULL;
+
+    /* microseconds are optional */
+    /*
+    TODO: this is slightly less picky than the Python version, which would
+    not accept "00:00:00.". I don't know which is better, but they should be
+    consistent.
+    */
+    if (sscanf(str, "%2u:%2u:%2u.%6u", &hour, &minute, &second,
+               &microsecond) < 3) {
+        PyErr_SetString(PyExc_ValueError, "Couldn't parse time string.");
+        return NULL;
+    }
+    return PyTime_FromTime(hour, minute, second, microsecond);
+}
+
+static PyObject *
+str_to_date(PyObject *self, PyObject *arg)
+{
+    const char *str;
+    unsigned int year, month, day;
+
+    if (arg == Py_None)
+        Py_RETURN_NONE;
+
+    str = PyString_AsString(arg);
+    if (str == NULL)
+        return NULL;
+
+    if (sscanf(str, "%4u-%2u-%2u", &year, &month, &day) != 3) {
+        PyErr_SetString(PyExc_ValueError, "Couldn't parse date string.");
+        return NULL;
+    }
+    return PyDate_FromDate(year, month, day);
+}
+
+
+/***********
+ * Structs *
+ ***********/
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *encoding;
+    PyObject *errors;
+} UnicodeResultProcessor;
+
+typedef struct {
+    PyObject_HEAD
+    PyObject *type;
+} DecimalResultProcessor;
+
+
+
+/**************************
+ * UnicodeResultProcessor *
+ **************************/
+
+static int
+UnicodeResultProcessor_init(UnicodeResultProcessor *self, PyObject *args,
+                            PyObject *kwds)
+{
+    PyObject *encoding, *errors = NULL;
+    static char *kwlist[] = {"encoding", "errors", NULL};
+
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "S|S:init", kwlist,
+                                     &encoding, &errors))
+        return -1;
+
+    Py_INCREF(encoding);
+    self->encoding = encoding;
+
+    if (errors) {
+        Py_INCREF(errors);
+    } else {
+        errors = PyString_FromString("strict");
+        if (errors == NULL)
+            return -1;
+    }
+    self->errors = errors;
+
+    return 0;
+}
+
+static PyObject *
+UnicodeResultProcessor_process(UnicodeResultProcessor *self, PyObject *value)
+{
+    const char *encoding, *errors;
+    char *str;
+    Py_ssize_t len;
+
+    if (value == Py_None)
+        Py_RETURN_NONE;
+
+    if (PyString_AsStringAndSize(value, &str, &len))
+        return NULL;
+
+    encoding = PyString_AS_STRING(self->encoding);
+    errors = PyString_AS_STRING(self->errors);
+
+    return PyUnicode_Decode(str, len, encoding, errors);
+}
+
+static PyMethodDef UnicodeResultProcessor_methods[] = {
+    {"process", (PyCFunction)UnicodeResultProcessor_process, METH_O,
+     "The value processor itself."},
+    {NULL}  /* Sentinel */
+};
+
+static PyTypeObject UnicodeResultProcessorType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                          /* ob_size */
+    "sqlalchemy.cprocessors.UnicodeResultProcessor", /* tp_name */
+    sizeof(UnicodeResultProcessor),             /* tp_basicsize */
+    0,                                          /* tp_itemsize */
+    0,                                          /* tp_dealloc */
+    0,                                          /* tp_print */
+    0,                                          /* tp_getattr */
+    0,                                          /* tp_setattr */
+    0,                                          /* tp_compare */
+    0,                                          /* tp_repr */
+    0,                                          /* tp_as_number */
+    0,                                          /* tp_as_sequence */
+    0,                                          /* tp_as_mapping */
+    0,                                          /* tp_hash */
+    0,                                          /* tp_call */
+    0,                                          /* tp_str */
+    0,                                          /* tp_getattro */
+    0,                                          /* tp_setattro */
+    0,                                          /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
+    "UnicodeResultProcessor 
objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + UnicodeResultProcessor_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)UnicodeResultProcessor_init, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +/************************** + * DecimalResultProcessor * + **************************/ + +static int +DecimalResultProcessor_init(DecimalResultProcessor *self, PyObject *args, + PyObject *kwds) +{ + PyObject *type; + + if (!PyArg_ParseTuple(args, "O", &type)) + return -1; + + Py_INCREF(type); + self->type = type; + + return 0; +} + +static PyObject * +DecimalResultProcessor_process(DecimalResultProcessor *self, PyObject *value) +{ + PyObject *str, *result; + + if (value == Py_None) + Py_RETURN_NONE; + + if (PyFloat_CheckExact(value)) { + /* Decimal does not accept float values directly */ + str = PyObject_Str(value); + if (str == NULL) + return NULL; + result = PyObject_CallFunctionObjArgs(self->type, str, NULL); + Py_DECREF(str); + return result; + } else { + return PyObject_CallFunctionObjArgs(self->type, value, NULL); + } +} + +static PyMethodDef DecimalResultProcessor_methods[] = { + {"process", (PyCFunction)DecimalResultProcessor_process, METH_O, + "The value processor itself."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject DecimalResultProcessorType = { + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ + "sqlalchemy.DecimalResultProcessor", /* tp_name */ + sizeof(DecimalResultProcessor), /* tp_basicsize */ + 0, /* tp_itemsize */ + 0, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */ + "DecimalResultProcessor objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + DecimalResultProcessor_methods, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)DecimalResultProcessor_init, /* tp_init */ + 0, /* tp_alloc */ + 0, /* tp_new */ +}; + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + + +static PyMethodDef module_methods[] = { + {"int_to_boolean", int_to_boolean, METH_O, + "Convert an integer to a boolean."}, + {"to_str", to_str, METH_O, + "Convert any value to its string representation."}, + {"to_float", to_float, METH_O, + "Convert any value to its floating point representation."}, + {"str_to_datetime", str_to_datetime, METH_O, + "Convert an ISO string to a datetime.datetime object."}, + {"str_to_time", str_to_time, METH_O, + "Convert an ISO string to a datetime.time object."}, + {"str_to_date", str_to_date, METH_O, + "Convert an ISO string to a datetime.date object."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initcprocessors(void) +{ + PyObject *m; + + UnicodeResultProcessorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&UnicodeResultProcessorType) < 0) + return; + + 
DecimalResultProcessorType.tp_new = PyType_GenericNew; + if (PyType_Ready(&DecimalResultProcessorType) < 0) + return; + + m = Py_InitModule3("cprocessors", module_methods, + "Module containing C versions of data processing functions."); + if (m == NULL) + return; + + PyDateTime_IMPORT; + + Py_INCREF(&UnicodeResultProcessorType); + PyModule_AddObject(m, "UnicodeResultProcessor", + (PyObject *)&UnicodeResultProcessorType); + + Py_INCREF(&DecimalResultProcessorType); + PyModule_AddObject(m, "DecimalResultProcessor", + (PyObject *)&DecimalResultProcessorType); +} + diff --git a/lib/sqlalchemy/cextension/resultproxy.c b/lib/sqlalchemy/cextension/resultproxy.c new file mode 100644 index 000000000..14ea1828e --- /dev/null +++ b/lib/sqlalchemy/cextension/resultproxy.c @@ -0,0 +1,586 @@ +/* +resultproxy.c +Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com + +This module is part of SQLAlchemy and is released under +the MIT License: http://www.opensource.org/licenses/mit-license.php +*/ + +#include <Python.h> + + +/*********** + * Structs * + ***********/ + +typedef struct { + PyObject_HEAD + PyObject *parent; + PyObject *row; + PyObject *processors; + PyObject *keymap; +} BaseRowProxy; + +/**************** + * BaseRowProxy * + ****************/ + +static PyObject * +rowproxy_reconstructor(PyObject *self, PyObject *args) +{ + PyObject *cls, *state, *tmp; + BaseRowProxy *obj; + + if (!PyArg_ParseTuple(args, "OO", &cls, &state)) + return NULL; + + obj = (BaseRowProxy *)PyObject_CallMethod(cls, "__new__", "O", cls); + if (obj == NULL) + return NULL; + + tmp = PyObject_CallMethod((PyObject *)obj, "__setstate__", "O", state); + if (tmp == NULL) { + Py_DECREF(obj); + return NULL; + } + Py_DECREF(tmp); + + if (obj->parent == NULL || obj->row == NULL || + obj->processors == NULL || obj->keymap == NULL) { + PyErr_SetString(PyExc_RuntimeError, + "__setstate__ for BaseRowProxy subclasses must set values " + "for parent, row, processors and keymap"); + Py_DECREF(obj); + return NULL; + } + + return (PyObject *)obj; +} + +static int +BaseRowProxy_init(BaseRowProxy *self, PyObject *args, PyObject *kwds) +{ + PyObject *parent, *row, *processors, *keymap; + + if (!PyArg_UnpackTuple(args, "BaseRowProxy", 4, 4, + &parent, &row, &processors, &keymap)) + return -1; + + Py_INCREF(parent); + self->parent = parent; + + if (!PyTuple_CheckExact(row)) { + PyErr_SetString(PyExc_TypeError, "row must be a tuple"); + return -1; + } + Py_INCREF(row); + self->row = row; + + if (!PyList_CheckExact(processors)) { + PyErr_SetString(PyExc_TypeError, "processors must be a list"); + return -1; + } + Py_INCREF(processors); + self->processors = processors; + + if (!PyDict_CheckExact(keymap)) { + PyErr_SetString(PyExc_TypeError, "keymap must be a dict"); + return -1; + } + Py_INCREF(keymap); + self->keymap = keymap; + + return 0; +} + +/* We need the reduce method because otherwise the default implementation + * does very weird stuff for pickle protocol 0 and 1. It calls + * BaseRowProxy.__new__(RowProxy_instance) upon *pickling*. 
+ */
+static PyObject *
+BaseRowProxy_reduce(PyObject *self)
+{
+    PyObject *method, *state;
+    PyObject *module, *reconstructor, *cls;
+
+    method = PyObject_GetAttrString(self, "__getstate__");
+    if (method == NULL)
+        return NULL;
+
+    state = PyObject_CallObject(method, NULL);
+    Py_DECREF(method);
+    if (state == NULL)
+        return NULL;
+
+    module = PyImport_ImportModule("sqlalchemy.engine.base");
+    if (module == NULL)
+        return NULL;
+
+    reconstructor = PyObject_GetAttrString(module, "rowproxy_reconstructor");
+    Py_DECREF(module);
+    if (reconstructor == NULL) {
+        Py_DECREF(state);
+        return NULL;
+    }
+
+    cls = PyObject_GetAttrString(self, "__class__");
+    if (cls == NULL) {
+        Py_DECREF(reconstructor);
+        Py_DECREF(state);
+        return NULL;
+    }
+
+    return Py_BuildValue("(N(NN))", reconstructor, cls, state);
+}
+
+static void
+BaseRowProxy_dealloc(BaseRowProxy *self)
+{
+    Py_XDECREF(self->parent);
+    Py_XDECREF(self->row);
+    Py_XDECREF(self->processors);
+    Py_XDECREF(self->keymap);
+    self->ob_type->tp_free((PyObject *)self);
+}
+
+static PyObject *
+BaseRowProxy_processvalues(PyObject *values, PyObject *processors, int astuple)
+{
+    Py_ssize_t num_values, num_processors;
+    PyObject **valueptr, **funcptr, **resultptr;
+    PyObject *func, *result, *processed_value;
+
+    num_values = Py_SIZE(values);
+    num_processors = Py_SIZE(processors);
+    if (num_values != num_processors) {
+        PyErr_SetString(PyExc_RuntimeError,
+            "number of values in row differs from number of column processors");
+        return NULL;
+    }
+
+    if (astuple) {
+        result = PyTuple_New(num_values);
+    } else {
+        result = PyList_New(num_values);
+    }
+    if (result == NULL)
+        return NULL;
+
+    /* we don't need to use PySequence_Fast as long as values, processors and
+     * result are simple tuples or lists. */
+    valueptr = PySequence_Fast_ITEMS(values);
+    funcptr = PySequence_Fast_ITEMS(processors);
+    resultptr = PySequence_Fast_ITEMS(result);
+    while (--num_values >= 0) {
+        func = *funcptr;
+        if (func != Py_None) {
+            processed_value = PyObject_CallFunctionObjArgs(func, *valueptr,
+                                                           NULL);
+            if (processed_value == NULL) {
+                Py_DECREF(result);
+                return NULL;
+            }
+            *resultptr = processed_value;
+        } else {
+            Py_INCREF(*valueptr);
+            *resultptr = *valueptr;
+        }
+        valueptr++;
+        funcptr++;
+        resultptr++;
+    }
+    return result;
+}
+
+static PyListObject *
+BaseRowProxy_values(BaseRowProxy *self)
+{
+    return (PyListObject *)BaseRowProxy_processvalues(self->row,
+                                                      self->processors, 0);
+}
+
+static PyTupleObject *
+BaseRowProxy_tuplevalues(BaseRowProxy *self)
+{
+    return (PyTupleObject *)BaseRowProxy_processvalues(self->row,
+                                                       self->processors, 1);
+}
+
+static PyObject *
+BaseRowProxy_iter(BaseRowProxy *self)
+{
+    PyObject *values, *result;
+
+    values = (PyObject *)BaseRowProxy_tuplevalues(self);
+    if (values == NULL)
+        return NULL;
+
+    result = PyObject_GetIter(values);
+    Py_DECREF(values);
+    if (result == NULL)
+        return NULL;
+
+    return result;
+}
+
+static Py_ssize_t
+BaseRowProxy_length(BaseRowProxy *self)
+{
+    return Py_SIZE(self->row);
+}
+
+static PyObject *
+BaseRowProxy_subscript(BaseRowProxy *self, PyObject *key)
+{
+    PyObject *processors, *values;
+    PyObject *processor, *value;
+    PyObject *record, *result, *indexobject;
+    PyObject *exc_module, *exception;
+    char *cstr_key;
+    long index;
+
+    if (PyInt_CheckExact(key)) {
+        index = PyInt_AS_LONG(key);
+    } else if (PyLong_CheckExact(key)) {
+        index = PyLong_AsLong(key);
+        if ((index == -1) && PyErr_Occurred())
+            /* -1 can be either the actual value, or an error flag. 
*/ + return NULL; + } else if (PySlice_Check(key)) { + values = PyObject_GetItem(self->row, key); + if (values == NULL) + return NULL; + + processors = PyObject_GetItem(self->processors, key); + if (processors == NULL) { + Py_DECREF(values); + return NULL; + } + + result = BaseRowProxy_processvalues(values, processors, 1); + Py_DECREF(values); + Py_DECREF(processors); + return result; + } else { + record = PyDict_GetItem((PyObject *)self->keymap, key); + if (record == NULL) { + record = PyObject_CallMethod(self->parent, "_key_fallback", + "O", key); + if (record == NULL) + return NULL; + } + + indexobject = PyTuple_GetItem(record, 1); + if (indexobject == NULL) + return NULL; + + if (indexobject == Py_None) { + exc_module = PyImport_ImportModule("sqlalchemy.exc"); + if (exc_module == NULL) + return NULL; + + exception = PyObject_GetAttrString(exc_module, + "InvalidRequestError"); + Py_DECREF(exc_module); + if (exception == NULL) + return NULL; + + cstr_key = PyString_AsString(key); + if (cstr_key == NULL) + return NULL; + + PyErr_Format(exception, + "Ambiguous column name '%s' in result set! " + "try 'use_labels' option on select statement.", cstr_key); + return NULL; + } + + index = PyInt_AsLong(indexobject); + if ((index == -1) && PyErr_Occurred()) + /* -1 can be either the actual value, or an error flag. */ + return NULL; + } + processor = PyList_GetItem(self->processors, index); + if (processor == NULL) + return NULL; + + value = PyTuple_GetItem(self->row, index); + if (value == NULL) + return NULL; + + if (processor != Py_None) { + return PyObject_CallFunctionObjArgs(processor, value, NULL); + } else { + Py_INCREF(value); + return value; + } +} + +static PyObject * +BaseRowProxy_getattro(BaseRowProxy *self, PyObject *name) +{ + PyObject *tmp; + + if (!(tmp = PyObject_GenericGetAttr((PyObject *)self, name))) { + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) + return NULL; + PyErr_Clear(); + } + else + return tmp; + + return BaseRowProxy_subscript(self, name); +} + +/*********************** + * getters and setters * + ***********************/ + +static PyObject * +BaseRowProxy_getparent(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->parent); + return self->parent; +} + +static int +BaseRowProxy_setparent(BaseRowProxy *self, PyObject *value, void *closure) +{ + PyObject *module, *cls; + + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'parent' attribute"); + return -1; + } + + module = PyImport_ImportModule("sqlalchemy.engine.base"); + if (module == NULL) + return -1; + + cls = PyObject_GetAttrString(module, "ResultMetaData"); + Py_DECREF(module); + if (cls == NULL) + return -1; + + if (PyObject_IsInstance(value, cls) != 1) { + PyErr_SetString(PyExc_TypeError, + "The 'parent' attribute value must be an instance of " + "ResultMetaData"); + return -1; + } + Py_DECREF(cls); + Py_XDECREF(self->parent); + Py_INCREF(value); + self->parent = value; + + return 0; +} + +static PyObject * +BaseRowProxy_getrow(BaseRowProxy *self, void *closure) +{ + Py_INCREF(self->row); + return self->row; +} + +static int +BaseRowProxy_setrow(BaseRowProxy *self, PyObject *value, void *closure) +{ + if (value == NULL) { + PyErr_SetString(PyExc_TypeError, + "Cannot delete the 'row' attribute"); + return -1; + } + + if (!PyTuple_CheckExact(value)) { + PyErr_SetString(PyExc_TypeError, + "The 'row' attribute value must be a tuple"); + return -1; + } + + Py_XDECREF(self->row); + Py_INCREF(value); + self->row = value; + + return 0; +} + +static PyObject * 
+BaseRowProxy_getprocessors(BaseRowProxy *self, void *closure)
+{
+    Py_INCREF(self->processors);
+    return self->processors;
+}
+
+static int
+BaseRowProxy_setprocessors(BaseRowProxy *self, PyObject *value, void *closure)
+{
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "Cannot delete the 'processors' attribute");
+        return -1;
+    }
+
+    if (!PyList_CheckExact(value)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "The 'processors' attribute value must be a list");
+        return -1;
+    }
+
+    Py_XDECREF(self->processors);
+    Py_INCREF(value);
+    self->processors = value;
+
+    return 0;
+}
+
+static PyObject *
+BaseRowProxy_getkeymap(BaseRowProxy *self, void *closure)
+{
+    Py_INCREF(self->keymap);
+    return self->keymap;
+}
+
+static int
+BaseRowProxy_setkeymap(BaseRowProxy *self, PyObject *value, void *closure)
+{
+    if (value == NULL) {
+        PyErr_SetString(PyExc_TypeError,
+                        "Cannot delete the 'keymap' attribute");
+        return -1;
+    }
+
+    if (!PyDict_CheckExact(value)) {
+        PyErr_SetString(PyExc_TypeError,
+                        "The 'keymap' attribute value must be a dict");
+        return -1;
+    }
+
+    Py_XDECREF(self->keymap);
+    Py_INCREF(value);
+    self->keymap = value;
+
+    return 0;
+}
+
+static PyGetSetDef BaseRowProxy_getseters[] = {
+    {"_parent",
+     (getter)BaseRowProxy_getparent, (setter)BaseRowProxy_setparent,
+     "ResultMetaData",
+     NULL},
+    {"_row",
+     (getter)BaseRowProxy_getrow, (setter)BaseRowProxy_setrow,
+     "Original row tuple",
+     NULL},
+    {"_processors",
+     (getter)BaseRowProxy_getprocessors, (setter)BaseRowProxy_setprocessors,
+     "list of type processors",
+     NULL},
+    {"_keymap",
+     (getter)BaseRowProxy_getkeymap, (setter)BaseRowProxy_setkeymap,
+     "Key to (processor, index) dict",
+     NULL},
+    {NULL}
+};
+
+static PyMethodDef BaseRowProxy_methods[] = {
+    {"values", (PyCFunction)BaseRowProxy_values, METH_NOARGS,
+     "Return the values represented by this BaseRowProxy as a list."},
+    {"__reduce__", (PyCFunction)BaseRowProxy_reduce, METH_NOARGS,
+     "Pickle support method."},
+    {NULL}  /* Sentinel */
+};
+
+static PySequenceMethods BaseRowProxy_as_sequence = {
+    (lenfunc)BaseRowProxy_length,   /* sq_length */
+    0,                              /* sq_concat */
+    0,                              /* sq_repeat */
+    0,                              /* sq_item */
+    0,                              /* sq_slice */
+    0,                              /* sq_ass_item */
+    0,                              /* sq_ass_slice */
+    0,                              /* sq_contains */
+    0,                              /* sq_inplace_concat */
+    0,                              /* sq_inplace_repeat */
+};
+
+static PyMappingMethods BaseRowProxy_as_mapping = {
+    (lenfunc)BaseRowProxy_length,       /* mp_length */
+    (binaryfunc)BaseRowProxy_subscript, /* mp_subscript */
+    0                                   /* mp_ass_subscript */
+};
+
+static PyTypeObject BaseRowProxyType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                          /* ob_size */
+    "sqlalchemy.cresultproxy.BaseRowProxy",     /* tp_name */
+    sizeof(BaseRowProxy),                       /* tp_basicsize */
+    0,                                          /* tp_itemsize */
+    (destructor)BaseRowProxy_dealloc,           /* tp_dealloc */
+    0,                                          /* tp_print */
+    0,                                          /* tp_getattr */
+    0,                                          /* tp_setattr */
+    0,                                          /* tp_compare */
+    0,                                          /* tp_repr */
+    0,                                          /* tp_as_number */
+    &BaseRowProxy_as_sequence,                  /* tp_as_sequence */
+    &BaseRowProxy_as_mapping,                   /* tp_as_mapping */
+    0,                                          /* tp_hash */
+    0,                                          /* tp_call */
+    0,                                          /* tp_str */
+    (getattrofunc)BaseRowProxy_getattro,        /* tp_getattro */
+    0,                                          /* tp_setattro */
+    0,                                          /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,   /* tp_flags */
+    "BaseRowProxy is an abstract base class for RowProxy",  /* tp_doc */
+    0,                                          /* tp_traverse */
+    0,                                          /* tp_clear */
+    0,                                          /* tp_richcompare */
+    0,                                          /* tp_weaklistoffset */
+    (getiterfunc)BaseRowProxy_iter,             /* tp_iter */
+    0,                                          /* tp_iternext */
+    BaseRowProxy_methods,                       /* tp_methods */
+    0,                                          /* tp_members */
+    BaseRowProxy_getseters,                     /* tp_getset 
*/ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)BaseRowProxy_init, /* tp_init */ + 0, /* tp_alloc */ + 0 /* tp_new */ +}; + + +#ifndef PyMODINIT_FUNC /* declarations for DLL import/export */ +#define PyMODINIT_FUNC void +#endif + + +static PyMethodDef module_methods[] = { + {"rowproxy_reconstructor", rowproxy_reconstructor, METH_VARARGS, + "reconstruct a RowProxy instance from its pickled form."}, + {NULL, NULL, 0, NULL} /* Sentinel */ +}; + +PyMODINIT_FUNC +initcresultproxy(void) +{ + PyObject *m; + + BaseRowProxyType.tp_new = PyType_GenericNew; + if (PyType_Ready(&BaseRowProxyType) < 0) + return; + + m = Py_InitModule3("cresultproxy", module_methods, + "Module containing C versions of core ResultProxy classes."); + if (m == NULL) + return; + + Py_INCREF(&BaseRowProxyType); + PyModule_AddObject(m, "BaseRowProxy", (PyObject *)&BaseRowProxyType); + +} + diff --git a/lib/sqlalchemy/dialects/access/base.py b/lib/sqlalchemy/dialects/access/base.py index a46ad247a..c10e77011 100644 --- a/lib/sqlalchemy/dialects/access/base.py +++ b/lib/sqlalchemy/dialects/access/base.py @@ -17,23 +17,17 @@ This dialect is *not* tested on SQLAlchemy 0.6. from sqlalchemy import sql, schema, types, exc, pool from sqlalchemy.sql import compiler, expression from sqlalchemy.engine import default, base - +from sqlalchemy import processors class AcNumeric(types.Numeric): - def result_processor(self, dialect, coltype): - return None + def get_col_spec(self): + return "NUMERIC" def bind_processor(self, dialect): - def process(value): - if value is None: - # Not sure that this exception is needed - return value - else: - return str(value) - return process + return processors.to_str - def get_col_spec(self): - return "NUMERIC" + def result_processor(self, dialect, coltype): + return None class AcFloat(types.Float): def get_col_spec(self): @@ -41,11 +35,7 @@ class AcFloat(types.Float): def bind_processor(self, dialect): """By converting to string, we can use Decimal types round-trip.""" - def process(value): - if not value is None: - return str(value) - return None - return process + return processors.to_str class AcInteger(types.Integer): def get_col_spec(self): @@ -103,25 +93,6 @@ class AcBoolean(types.Boolean): def get_col_spec(self): return "YESNO" - def result_processor(self, dialect, coltype): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process - class AcTimeStamp(types.TIMESTAMP): def get_col_spec(self): return "TIMESTAMP" @@ -443,4 +414,4 @@ dialect.poolclass = pool.SingletonThreadPool dialect.statement_compiler = AccessCompiler dialect.ddlcompiler = AccessDDLCompiler dialect.preparer = AccessIdentifierPreparer -dialect.execution_ctx_cls = AccessExecutionContext
\ No newline at end of file +dialect.execution_ctx_cls = AccessExecutionContext diff --git a/lib/sqlalchemy/dialects/informix/base.py b/lib/sqlalchemy/dialects/informix/base.py index 2802d493a..54aae6eb3 100644 --- a/lib/sqlalchemy/dialects/informix/base.py +++ b/lib/sqlalchemy/dialects/informix/base.py @@ -302,4 +302,4 @@ class InformixDialect(default.DefaultDialect): @reflection.cache def get_indexes(self, connection, table_name, schema, **kw): # TODO - return []
\ No newline at end of file + return [] diff --git a/lib/sqlalchemy/dialects/maxdb/base.py b/lib/sqlalchemy/dialects/maxdb/base.py index 2e0b9518b..f409f3213 100644 --- a/lib/sqlalchemy/dialects/maxdb/base.py +++ b/lib/sqlalchemy/dialects/maxdb/base.py @@ -60,7 +60,7 @@ this. """ import datetime, itertools, re -from sqlalchemy import exc, schema, sql, util +from sqlalchemy import exc, schema, sql, util, processors from sqlalchemy.sql import operators as sql_operators, expression as sql_expr from sqlalchemy.sql import compiler, visitors from sqlalchemy.engine import base as engine_base, default @@ -86,6 +86,12 @@ class _StringType(sqltypes.String): return process def result_processor(self, dialect, coltype): + #XXX: this code is probably very slow and one should try (if at all + # possible) to determine the correct code path on a per-connection + # basis (ie, here in result_processor, instead of inside the processor + # function itself) and probably also use a few generic + # processors, or possibly per query (though there is no mechanism + # for that yet). def process(value): while True: if value is None: @@ -152,6 +158,7 @@ class MaxNumeric(sqltypes.Numeric): def bind_processor(self, dialect): return None + class MaxTimestamp(sqltypes.DateTime): def bind_processor(self, dialect): def process(value): @@ -172,25 +179,30 @@ class MaxTimestamp(sqltypes.DateTime): return process def result_processor(self, dialect, coltype): - def process(value): - if value is None: - return None - elif dialect.datetimeformat == 'internal': - return datetime.datetime( - *[int(v) - for v in (value[0:4], value[4:6], value[6:8], - value[8:10], value[10:12], value[12:14], - value[14:])]) - elif dialect.datetimeformat == 'iso': - return datetime.datetime( - *[int(v) - for v in (value[0:4], value[5:7], value[8:10], - value[11:13], value[14:16], value[17:19], - value[20:])]) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) + if dialect.datetimeformat == 'internal': + def process(value): + if value is None: + return None + else: + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[4:6], value[6:8], + value[8:10], value[10:12], value[12:14], + value[14:])]) + elif dialect.datetimeformat == 'iso': + def process(value): + if value is None: + return None + else: + return datetime.datetime( + *[int(v) + for v in (value[0:4], value[5:7], value[8:10], + value[11:13], value[14:16], value[17:19], + value[20:])]) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % + dialect.datetimeformat) return process @@ -212,19 +224,24 @@ class MaxDate(sqltypes.Date): return process def result_processor(self, dialect, coltype): - def process(value): - if value is None: - return None - elif dialect.datetimeformat == 'internal': - return datetime.date( - *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) - elif dialect.datetimeformat == 'iso': - return datetime.date( - *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." 
% ( - dialect.datetimeformat,)) + if dialect.datetimeformat == 'internal': + def process(value): + if value is None: + return None + else: + return datetime.date(int(value[0:4]), int(value[4:6]), + int(value[6:8])) + elif dialect.datetimeformat == 'iso': + def process(value): + if value is None: + return None + else: + return datetime.date(int(value[0:4]), int(value[5:7]), + int(value[8:10])) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % + dialect.datetimeformat) return process @@ -246,31 +263,30 @@ class MaxTime(sqltypes.Time): return process def result_processor(self, dialect, coltype): - def process(value): - if value is None: - return None - elif dialect.datetimeformat == 'internal': - t = datetime.time( - *[int(v) for v in (value[0:4], value[4:6], value[6:8])]) - return t - elif dialect.datetimeformat == 'iso': - return datetime.time( - *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) - else: - raise exc.InvalidRequestError( - "datetimeformat '%s' is not supported." % ( - dialect.datetimeformat,)) + if dialect.datetimeformat == 'internal': + def process(value): + if value is None: + return None + else: + return datetime.time(int(value[0:4]), int(value[4:6]), + int(value[6:8])) + elif dialect.datetimeformat == 'iso': + def process(value): + if value is None: + return None + else: + return datetime.time(int(value[0:4]), int(value[5:7]), + int(value[8:10])) + else: + raise exc.InvalidRequestError( + "datetimeformat '%s' is not supported." % + dialect.datetimeformat) return process class MaxBlob(sqltypes.LargeBinary): def bind_processor(self, dialect): - def process(value): - if value is None: - return None - else: - return str(value) - return process + return processors.to_str def result_processor(self, dialect, coltype): def process(value): diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 4e58d64b3..3f4e0b9f3 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -233,11 +233,10 @@ from sqlalchemy.sql import select, compiler, expression, \ functions as sql_functions, util as sql_util from sqlalchemy.engine import default, base, reflection from sqlalchemy import types as sqltypes -from decimal import Decimal as _python_Decimal +from sqlalchemy import processors from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ FLOAT, TIMESTAMP, DATETIME, DATE, BINARY,\ VARBINARY, BLOB - from sqlalchemy.dialects.mssql import information_schema as ischema @@ -280,22 +279,12 @@ RESERVED_WORDS = set( class _MSNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: - def process(value): - if value is not None: - return _python_Decimal(str(value)) - else: - return value - return process + return processors.to_decimal_processor_factory(decimal.Decimal) else: #XXX: if the DBAPI returns a float (this is likely, given the # processor when asdecimal is True), this should be a None # processor instead. - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float def bind_processor(self, dialect): def process(value): diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index eb348f1a1..82a4af941 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -351,7 +351,8 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL): numeric. 
""" - super(DECIMAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + super(DECIMAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) class DOUBLE(_FloatType): @@ -375,7 +376,8 @@ class DOUBLE(_FloatType): numeric. """ - super(DOUBLE, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + super(DOUBLE, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) class REAL(_FloatType): """MySQL REAL type.""" @@ -398,7 +400,8 @@ class REAL(_FloatType): numeric. """ - super(REAL, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + super(REAL, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) class FLOAT(_FloatType, sqltypes.FLOAT): """MySQL FLOAT type.""" @@ -421,7 +424,8 @@ class FLOAT(_FloatType, sqltypes.FLOAT): numeric. """ - super(FLOAT, self).__init__(precision=precision, scale=scale, asdecimal=asdecimal, **kw) + super(FLOAT, self).__init__(precision=precision, scale=scale, + asdecimal=asdecimal, **kw) def bind_processor(self, dialect): return None @@ -2459,6 +2463,7 @@ class _DecodingRowProxy(object): def __init__(self, rowproxy, charset): self.rowproxy = rowproxy self.charset = charset + def __getitem__(self, index): item = self.rowproxy[index] if isinstance(item, _array): @@ -2467,6 +2472,7 @@ class _DecodingRowProxy(object): return item.decode(self.charset) else: return item + def __getattr__(self, attr): item = getattr(self.rowproxy, attr) if isinstance(item, _array): diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index c07ed8713..8cfd5930f 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -28,6 +28,7 @@ from sqlalchemy.dialects.mysql.base import (DECIMAL, MySQLDialect, MySQLExecutio from sqlalchemy.engine import base as engine_base, default from sqlalchemy.sql import operators as sql_operators from sqlalchemy import exc, log, schema, sql, types as sqltypes, util +from sqlalchemy import processors class MySQL_mysqldbExecutionContext(MySQLExecutionContext): @@ -51,12 +52,7 @@ class _DecimalType(_NumericType): def result_processor(self, dialect, coltype): if self.asdecimal: return None - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float class _MySQLdbNumeric(_DecimalType, NUMERIC): diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index a03aa988e..1fca6850a 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -29,18 +29,14 @@ from sqlalchemy.dialects.mysql.base import (BIT, MySQLDialect, MySQLExecutionCon from sqlalchemy.engine import base as engine_base, default from sqlalchemy.sql import operators as sql_operators from sqlalchemy import exc, log, schema, sql, types as sqltypes, util +from sqlalchemy import processors class _oursqlNumeric(NUMERIC): def result_processor(self, dialect, coltype): if self.asdecimal: return None - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float class _oursqlBIT(BIT): diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index 22a1f443c..fba16288a 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -32,6 +32,9 @@ class _ZxJDBCDate(sqltypes.Date): 
class _ZxJDBCNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): + #XXX: does the dialect return Decimal or not??? + # if it does (in all cases), we could use a None processor as well as + # the to_float generic processor if self.asdecimal: def process(value): if isinstance(value, decimal.Decimal): diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index e90bebb6b..079b05530 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -19,31 +19,23 @@ Interval Passing data from/to the Interval type is not supported as of yet. """ -from sqlalchemy.engine import default import decimal + +from sqlalchemy.engine import default from sqlalchemy import util, exc +from sqlalchemy import processors from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql.base import PGDialect, \ PGCompiler, PGIdentifierPreparer, PGExecutionContext class _PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float def result_processor(self, dialect, coltype): if self.asdecimal: if coltype in (700, 701): - def process(value): - if value is not None: - return decimal.Decimal(str(value)) - else: - return value - return process + return processors.to_decimal_processor_factory(decimal.Decimal) elif coltype == 1700: # pg8000 returns Decimal natively for 1700 return None @@ -54,12 +46,7 @@ class _PGNumeric(sqltypes.Numeric): # pg8000 returns float natively for 701 return None elif coltype == 1700: - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float else: raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index bb6562dea..712124288 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -46,8 +46,11 @@ The following per-statement execution options are respected: """ -import decimal, random, re +import random, re +import decimal + from sqlalchemy import util +from sqlalchemy import processors from sqlalchemy.engine import base, default from sqlalchemy.sql import expression from sqlalchemy.sql import operators as sql_operators @@ -63,12 +66,7 @@ class _PGNumeric(sqltypes.Numeric): def result_processor(self, dialect, coltype): if self.asdecimal: if coltype in (700, 701): - def process(value): - if value is not None: - return decimal.Decimal(str(value)) - else: - return value - return process + return processors.to_decimal_processor_factory(decimal.Decimal) elif coltype == 1700: # pg8000 returns Decimal natively for 1700 return None @@ -79,12 +77,7 @@ class _PGNumeric(sqltypes.Numeric): # pg8000 returns float natively for 701 return None elif coltype == 1700: - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float else: raise exc.InvalidRequestError("Unknown PG numeric type: %d" % coltype) diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index 77ed44512..88f1acde7 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -12,6 +12,7 @@ import decimal from sqlalchemy 
import util from sqlalchemy import types as sqltypes from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext +from sqlalchemy import processors class PGNumeric(sqltypes.Numeric): def bind_processor(self, dialect): @@ -21,12 +22,7 @@ class PGNumeric(sqltypes.Numeric): if self.asdecimal: return None else: - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float class PostgreSQL_pypostgresqlExecutionContext(PGExecutionContext): pass diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 696f65a6c..e987439c5 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -54,6 +54,7 @@ from sqlalchemy import types as sqltypes from sqlalchemy import util from sqlalchemy.sql import compiler, functions as sql_functions from sqlalchemy.util import NoneType +from sqlalchemy import processors from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\ FLOAT, INTEGER, NUMERIC, SMALLINT, TEXT, TIME,\ @@ -62,13 +63,10 @@ from sqlalchemy.types import BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL,\ class _NumericMixin(object): def bind_processor(self, dialect): - type_ = self.asdecimal and str or float - def process(value): - if value is not None: - return type_(value) - else: - return value - return process + if self.asdecimal: + return processors.to_str + else: + return processors.to_float class _SLNumeric(_NumericMixin, sqltypes.Numeric): pass @@ -86,19 +84,7 @@ class _DateTimeMixin(object): if storage_format is not None: self._storage_format = storage_format - def _result_processor(self, fn): - rmatch = self._reg.match - # Even on python2.6 datetime.strptime is both slower than this code - # and it does not support microseconds. 
- def process(value): - if value is not None: - return fn(*map(int, rmatch(value).groups(0))) - else: - return None - return process - class DATETIME(_DateTimeMixin, sqltypes.DateTime): - _reg = re.compile(r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") _storage_format = "%04d-%02d-%02d %02d:%02d:%02d.%06d" def bind_processor(self, dialect): @@ -121,10 +107,13 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): return process def result_processor(self, dialect, coltype): - return self._result_processor(datetime.datetime) + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.datetime) + else: + return processors.str_to_datetime class DATE(_DateTimeMixin, sqltypes.Date): - _reg = re.compile(r"(\d+)-(\d+)-(\d+)") _storage_format = "%04d-%02d-%02d" def bind_processor(self, dialect): @@ -141,10 +130,13 @@ class DATE(_DateTimeMixin, sqltypes.Date): return process def result_processor(self, dialect, coltype): - return self._result_processor(datetime.date) + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.date) + else: + return processors.str_to_date class TIME(_DateTimeMixin, sqltypes.Time): - _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d+))?") _storage_format = "%02d:%02d:%02d.%06d" def bind_processor(self, dialect): @@ -162,7 +154,11 @@ class TIME(_DateTimeMixin, sqltypes.Time): return process def result_processor(self, dialect, coltype): - return self._result_processor(datetime.time) + if self._reg: + return processors.str_to_datetime_processor_factory( + self._reg, datetime.time) + else: + return processors.str_to_time colspecs = { sqltypes.Date: DATE, diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index cdbf6138d..886a773d8 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -115,24 +115,7 @@ class SybaseUniqueIdentifier(sqltypes.TypeEngine): __visit_name__ = "UNIQUEIDENTIFIER" class SybaseBoolean(sqltypes.Boolean): - def result_processor(self, dialect, coltype): - def process(value): - if value is None: - return None - return value and True or False - return process - - def bind_processor(self, dialect): - def process(value): - if value is True: - return 1 - elif value is False: - return 0 - elif value is None: - return None - else: - return value and True or False - return process + pass class SybaseTypeCompiler(compiler.GenericTypeCompiler): def visit_large_binary(self, type_): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 844183628..4dc9665c0 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -20,6 +20,7 @@ __all__ = [ 'connection_memoize'] import inspect, StringIO, sys, operator +from itertools import izip from sqlalchemy import exc, schema, util, types, log from sqlalchemy.sql import expression @@ -1536,16 +1537,20 @@ class Engine(Connectable): def _proxy_connection_cls(cls, proxy): class ProxyConnection(cls): def execute(self, object, *multiparams, **params): - return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params) + return proxy.execute(self, super(ProxyConnection, self).execute, + object, *multiparams, **params) def _execute_clauseelement(self, elem, multiparams=None, params=None): - return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {})) + return proxy.execute(self, super(ProxyConnection, self).execute, + elem, *(multiparams or []), **(params 
or {})) def _cursor_execute(self, cursor, statement, parameters, context=None): - return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False) + return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, + cursor, statement, parameters, context, False) def _cursor_executemany(self, cursor, statement, parameters, context=None): - return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True) + return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, + cursor, statement, parameters, context, True) def _begin_impl(self): return proxy.begin(self, super(ProxyConnection, self)._begin_impl) @@ -1560,27 +1565,125 @@ def _proxy_connection_cls(cls, proxy): return proxy.savepoint(self, super(ProxyConnection, self)._savepoint_impl, name=name) def _rollback_to_savepoint_impl(self, name, context): - return proxy.rollback_savepoint(self, super(ProxyConnection, self)._rollback_to_savepoint_impl, name, context) + return proxy.rollback_savepoint(self, + super(ProxyConnection, self)._rollback_to_savepoint_impl, + name, context) def _release_savepoint_impl(self, name, context): - return proxy.release_savepoint(self, super(ProxyConnection, self)._release_savepoint_impl, name, context) + return proxy.release_savepoint(self, + super(ProxyConnection, self)._release_savepoint_impl, + name, context) def _begin_twophase_impl(self, xid): - return proxy.begin_twophase(self, super(ProxyConnection, self)._begin_twophase_impl, xid) + return proxy.begin_twophase(self, + super(ProxyConnection, self)._begin_twophase_impl, xid) def _prepare_twophase_impl(self, xid): - return proxy.prepare_twophase(self, super(ProxyConnection, self)._prepare_twophase_impl, xid) + return proxy.prepare_twophase(self, + super(ProxyConnection, self)._prepare_twophase_impl, xid) def _rollback_twophase_impl(self, xid, is_prepared): - return proxy.rollback_twophase(self, super(ProxyConnection, self)._rollback_twophase_impl, xid, is_prepared) + return proxy.rollback_twophase(self, + super(ProxyConnection, self)._rollback_twophase_impl, + xid, is_prepared) def _commit_twophase_impl(self, xid, is_prepared): - return proxy.commit_twophase(self, super(ProxyConnection, self)._commit_twophase_impl, xid, is_prepared) + return proxy.commit_twophase(self, + super(ProxyConnection, self)._commit_twophase_impl, + xid, is_prepared) return ProxyConnection +# This reconstructor is necessary so that pickles with the C extension or +# without use the same Binary format. +# We need a different reconstructor on the C extension so that we can +# add extra checks that fields have correctly been initialized by +# __setstate__. +try: + from sqlalchemy.cresultproxy import rowproxy_reconstructor + + # this is a hack so that the reconstructor function is pickled with the + # same name as without the C extension. + # BUG: It fails for me if I run the "python" interpreter and + # then say "import sqlalchemy": + # TypeError: 'builtin_function_or_method' object has only read-only attributes (assign to .__module__) + # However, if I run the tests with nosetests, it succeeds ! + # I've verified with pdb etc. that this is the case. 
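Whether the import above succeeds or the pure-Python fallback below is used, __reduce__ emits the same payload, (rowproxy_reconstructor, (cls, state)), so pickles created with the C extension load without it and vice versa. A round-trip usage sketch, where row stands for any RowProxy instance::

    import pickle

    payload = pickle.dumps(row)       # calls RowProxy.__reduce__
    same_row = pickle.loads(payload)  # calls rowproxy_reconstructor, which
                                      # rebuilds the row via __setstate__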
+ #rowproxy_reconstructor.__module__ = 'sqlalchemy.engine.base' + +except ImportError: + def rowproxy_reconstructor(cls, state): + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + +try: + from sqlalchemy.cresultproxy import BaseRowProxy +except ImportError: + class BaseRowProxy(object): + __slots__ = ('_parent', '_row', '_processors', '_keymap') + + def __init__(self, parent, row, processors, keymap): + """RowProxy objects are constructed by ResultProxy objects.""" + + self._parent = parent + self._row = row + self._processors = processors + self._keymap = keymap + + def __reduce__(self): + return (rowproxy_reconstructor, + (self.__class__, self.__getstate__())) + + def values(self): + """Return the values represented by this RowProxy as a list.""" + return list(self) + + def __iter__(self): + for processor, value in izip(self._processors, self._row): + if processor is None: + yield value + else: + yield processor(value) + + def __len__(self): + return len(self._row) -class RowProxy(object): + def __getitem__(self, key): + try: + processor, index = self._keymap[key] + except KeyError: + processor, index = self._parent._key_fallback(key) + except TypeError: + if isinstance(key, slice): + l = [] + for processor, value in izip(self._processors[key], + self._row[key]): + if processor is None: + l.append(value) + else: + l.append(processor(value)) + return tuple(l) + else: + raise + if index is None: + raise exc.InvalidRequestError( + "Ambiguous column name '%s' in result set! " + "try 'use_labels' option on select statement." % key) + if processor is not None: + return processor(self._row[index]) + else: + return self._row[index] + + def __getattr__(self, name): + try: + # TODO: no test coverage here + return self[name] + except KeyError, e: + raise AttributeError(e.args[0]) + + +class RowProxy(BaseRowProxy): """Proxy values from a single cursor row. Mostly follows "ordered dictionary" behavior, mapping result @@ -1589,38 +1692,22 @@ class RowProxy(object): mapped to the original Columns that produced this result set (for results that correspond to constructed SQL expressions). 
""" + __slots__ = () - __slots__ = ['__parent', '__row', '__colfuncs'] - - def __init__(self, parent, row): - - self.__parent = parent - self.__row = row - self.__colfuncs = parent._colfuncs - if self.__parent._echo: - self.__parent.logger.debug("Row %r", row) - def __contains__(self, key): - return self.__parent._has_key(self.__row, key) + return self._parent._has_key(self._row, key) - def __len__(self): - return len(self.__row) - def __getstate__(self): return { - '__row':[self.__colfuncs[i][0](self.__row) for i in xrange(len(self.__row))], - '__parent':self.__parent + '_parent': self._parent, + '_row': tuple(self) } - - def __setstate__(self, d): - self.__row = d['__row'] - self.__parent = d['__parent'] - self.__colfuncs = self.__parent._colfuncs - - def __iter__(self): - row = self.__row - for func in self.__parent._colfunc_list: - yield func(row) + + def __setstate__(self, state): + self._parent = parent = state['_parent'] + self._row = state['_row'] + self._processors = parent._processors + self._keymap = parent._keymap __hash__ = None @@ -1636,33 +1723,7 @@ class RowProxy(object): def has_key(self, key): """Return True if this RowProxy contains the given key.""" - return self.__parent._has_key(self.__row, key) - - def __getitem__(self, key): - # the fallback and slices are only useful for __getitem__ anyway - try: - return self.__colfuncs[key][0](self.__row) - except KeyError: - k = self.__parent._key_fallback(key) - if k is None: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" % key) - else: - # save on KeyError + _key_fallback() lookup next time around - self.__colfuncs[key] = k - return k[0](self.__row) - except TypeError: - if isinstance(key, slice): - return tuple(func(self.__row) for func in self.__parent._colfunc_list[key]) - else: - raise - - def __getattr__(self, name): - try: - # TODO: no test coverage here - return self[name] - except KeyError, e: - raise AttributeError(e.args[0]) + return self._parent._has_key(self._row, key) def items(self): """Return a list of tuples, each tuple containing a key/value pair.""" @@ -1672,24 +1733,25 @@ class RowProxy(object): def keys(self): """Return the list of keys as strings represented by this RowProxy.""" - return self.__parent.keys + return self._parent.keys def iterkeys(self): - return iter(self.__parent.keys) - - def values(self): - """Return the values represented by this RowProxy as a list.""" - - return list(self) + return iter(self._parent.keys) def itervalues(self): return iter(self) + class ResultMetaData(object): """Handle cursor.description, applying additional info from an execution context.""" def __init__(self, parent, metadata): - self._colfuncs = colfuncs = {} + self._processors = processors = [] + + # We do not strictly need to store the processor in the key mapping, + # though it is faster in the Python version (probably because of the + # saved attribute lookup self._processors) + self._keymap = keymap = {} self.keys = [] self._echo = parent._echo context = parent.context @@ -1720,29 +1782,25 @@ class ResultMetaData(object): processor = type_.dialect_impl(dialect).\ result_processor(dialect, coltype) - if processor: - def make_colfunc(processor, index): - def getcol(row): - return processor(row[index]) - return getcol - rec = (make_colfunc(processor, i), i, "colfunc") - else: - rec = (operator.itemgetter(i), i, "itemgetter") + processors.append(processor) + rec = (processor, i) - # indexes as keys - colfuncs[i] = rec + # indexes as keys. 
This is only needed for the Python version of + # RowProxy (the C version uses a faster path for integer indexes). + keymap[i] = rec # Column names as keys - if colfuncs.setdefault(name.lower(), rec) is not rec: - #XXX: why not raise directly? because several columns colliding - #by name is not a problem as long as the user don't use them (ie - #use the more precise ColumnElement - colfuncs[name.lower()] = (self._ambiguous_processor(name), i, "ambiguous") - + if keymap.setdefault(name.lower(), rec) is not rec: + # We do not raise an exception directly because several + # columns colliding by name is not a problem as long as the + # user does not try to access them (ie use an index directly, + # or the more precise ColumnElement) + keymap[name.lower()] = (processor, None) + # store the "origname" if we truncated (sqlite only) if origname and \ - colfuncs.setdefault(origname.lower(), rec) is not rec: - colfuncs[origname.lower()] = (self._ambiguous_processor(origname), i, "ambiguous") + keymap.setdefault(origname.lower(), rec) is not rec: + keymap[origname.lower()] = (processor, None) if dialect.requires_name_normalize: colname = dialect.normalize_name(colname) @@ -1750,76 +1808,67 @@ class ResultMetaData(object): self.keys.append(colname) if obj: for o in obj: - colfuncs[o] = rec + keymap[o] = rec if self._echo: self.logger = context.engine.logger self.logger.debug( "Col %r", tuple(x[0] for x in metadata)) - @util.memoized_property - def _colfunc_list(self): - funcs = self._colfuncs - return [funcs[i][0] for i in xrange(len(self.keys))] - def _key_fallback(self, key): - funcs = self._colfuncs - + map = self._keymap + result = None if isinstance(key, basestring): - key = key.lower() - if key in funcs: - return funcs[key] - + result = map.get(key.lower()) # fallback for targeting a ColumnElement to a textual expression # this is a rare use case which only occurs when matching text() - # constructs to ColumnElements - if isinstance(key, expression.ColumnElement): - if key._label and key._label.lower() in funcs: - return funcs[key._label.lower()] - elif hasattr(key, 'name') and key.name.lower() in funcs: - return funcs[key.name.lower()] - - return None + # constructs to ColumnElements, and after a pickle/unpickle roundtrip + elif isinstance(key, expression.ColumnElement): + if key._label and key._label.lower() in map: + result = map[key._label.lower()] + elif hasattr(key, 'name') and key.name.lower() in map: + result = map[key.name.lower()] + if result is None: + raise exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" % key) + else: + map[key] = result + return result def _has_key(self, row, key): - if key in self._colfuncs: + if key in self._keymap: return True else: - key = self._key_fallback(key) - return key is not None + try: + self._key_fallback(key) + return True + except exc.NoSuchColumnError: + return False - @classmethod - def _ambiguous_processor(cls, colname): - def process(value): - raise exc.InvalidRequestError( - "Ambiguous column name '%s' in result set! " - "try 'use_labels' option on select statement." 
% colname) - return process - def __len__(self): return len(self.keys) def __getstate__(self): return { - '_pickled_colfuncs':dict( - (key, (i, type_)) - for key, (fn, i, type_) in self._colfuncs.iteritems() + '_pickled_keymap': dict( + (key, index) + for key, (processor, index) in self._keymap.iteritems() if isinstance(key, (basestring, int)) ), - 'keys':self.keys + 'keys': self.keys } def __setstate__(self, state): - pickled_colfuncs = state['_pickled_colfuncs'] - self._colfuncs = d = {} - for key, (index, type_) in pickled_colfuncs.iteritems(): - if type_ == 'ambiguous': - d[key] = (self._ambiguous_processor(key), index, type_) - else: - d[key] = (operator.itemgetter(index), index, "itemgetter") + # the row has been processed at pickling time so we don't need any + # processor anymore + self._processors = [None for _ in xrange(len(state['keys']))] + self._keymap = keymap = {} + for key, index in state['_pickled_keymap'].iteritems(): + keymap[key] = (None, index) self.keys = state['keys'] self._echo = False - + + class ResultProxy(object): """Wraps a DB-API cursor object to provide easier access to row columns. @@ -2031,13 +2080,27 @@ class ResultProxy(object): def _fetchall_impl(self): return self.cursor.fetchall() + def process_rows(self, rows): + process_row = self._process_row + metadata = self._metadata + keymap = metadata._keymap + processors = metadata._processors + if self._echo: + log = self.context.engine.logger.debug + l = [] + for row in rows: + log("Row %r", row) + l.append(process_row(metadata, row, processors, keymap)) + return l + else: + return [process_row(metadata, row, processors, keymap) + for row in rows] + def fetchall(self): """Fetch all rows, just like DB-API ``cursor.fetchall()``.""" try: - process_row = self._process_row - metadata = self._metadata - l = [process_row(metadata, row) for row in self._fetchall_impl()] + l = self.process_rows(self._fetchall_impl()) self.close() return l except Exception, e: @@ -2053,9 +2116,7 @@ class ResultProxy(object): """ try: - process_row = self._process_row - metadata = self._metadata - l = [process_row(metadata, row) for row in self._fetchmany_impl(size)] + l = self.process_rows(self._fetchmany_impl(size)) if len(l) == 0: self.close() return l @@ -2074,7 +2135,7 @@ class ResultProxy(object): try: row = self._fetchone_impl() if row is not None: - return self._process_row(self._metadata, row) + return self.process_rows([row])[0] else: self.close() return None @@ -2096,13 +2157,12 @@ class ResultProxy(object): try: if row is not None: - return self._process_row(self._metadata, row) + return self.process_rows([row])[0] else: return None finally: self.close() - def scalar(self): """Fetch the first column of the first row, and close the result set. @@ -2210,9 +2270,18 @@ class FullyBufferedResultProxy(ResultProxy): return ret class BufferedColumnRow(RowProxy): - def __init__(self, parent, row): - row = [parent._orig_colfuncs[i][0](row) for i in xrange(len(row))] - super(BufferedColumnRow, self).__init__(parent, row) + def __init__(self, parent, row, processors, keymap): + # preprocess row + row = list(row) + # this is a tad faster than using enumerate + index = 0 + for processor in parent._orig_processors: + if processor is not None: + row[index] = processor(row[index]) + index += 1 + row = tuple(row) + super(BufferedColumnRow, self).__init__(parent, row, + processors, keymap) class BufferedColumnResultProxy(ResultProxy): """A ResultProxy with column buffering behavior. 
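
One subtlety worth spelling out: ``RowProxy.__getstate__`` above stores ``tuple(self)``, that is, values that have already been run through their result processors, which is why ``ResultMetaData.__setstate__`` installs ``None`` processors; applying a processor a second time would corrupt the values. A small illustration, using a hypothetical str-to-float processor rather than a real dialect processor::

    import pickle

    processor = float                  # hypothetical result processor
    raw_row = ('1.5', '2.5')           # as returned by the DBAPI
    processed = tuple(map(processor, raw_row))   # what __getstate__ stores

    restored = pickle.loads(pickle.dumps(processed))
    assert restored == (1.5, 2.5)
    # After the roundtrip the values are already converted; a processor
    # such as str -> datetime would fail if run again, hence the
    # (None, index) records on the unpickled side.
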
@@ -2221,7 +2290,7 @@ class BufferedColumnResultProxy(ResultProxy):
     fetchone() is called.  If fetchmany() or fetchall() are called,
     the full grid of results is fetched.  This is to operate with
     databases where result rows contain "live" results that fall out
-    of scope unless explicitly fetched.  Currently this includes 
+    of scope unless explicitly fetched.  Currently this includes
     cx_Oracle LOB objects.
 
     """
@@ -2230,17 +2299,16 @@ class BufferedColumnResultProxy(ResultProxy):
 
     def _init_metadata(self):
         super(BufferedColumnResultProxy, self)._init_metadata()
-        self._metadata._orig_colfuncs = self._metadata._colfuncs
-        self._metadata._colfuncs = colfuncs = {}
-        # replace the parent's _colfuncs dict, replacing
-        # column processors with straight itemgetters.
-        # the original _colfuncs dict is used when each row
-        # is constructed.
-        for k, (colfunc, index, type_) in self._metadata._orig_colfuncs.iteritems():
-            if type_ == "colfunc":
-                colfuncs[k] = (operator.itemgetter(index), index, "itemgetter")
-            else:
-                colfuncs[k] = (colfunc, index, type_)
+        metadata = self._metadata
+        # orig_processors will be used to preprocess each row when it is
+        # constructed.
+        metadata._orig_processors = metadata._processors
+        # replace all the type processors with None processors.
+        metadata._processors = [None for _ in xrange(len(metadata.keys))]
+        keymap = {}
+        for k, (func, index) in metadata._keymap.iteritems():
+            keymap[k] = (None, index)
+        self._metadata._keymap = keymap
 
     def fetchall(self):
         # can't call cursor.fetchall(), since rows must be
diff --git a/lib/sqlalchemy/processors.py b/lib/sqlalchemy/processors.py
new file mode 100644
index 000000000..cb4b72545
--- /dev/null
+++ b/lib/sqlalchemy/processors.py
@@ -0,0 +1,90 @@
+# processors.py
+# Copyright (C) 2010 Gaetan de Menten gdementen@gmail.com
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: http://www.opensource.org/licenses/mit-license.php
+
+"""defines generic type conversion functions, as used in result processors.
+
+They all share one common characteristic: None is passed through unchanged.
+
+"""
+
+import codecs
+import re
+import datetime
+
+def str_to_datetime_processor_factory(regexp, type_):
+    rmatch = regexp.match
+    # Even on python2.6, datetime.strptime is both slower than this code
+    # and does not support microseconds.
+    def process(value):
+        if value is None:
+            return None
+        else:
+            return type_(*map(int, rmatch(value).groups(0)))
+    return process
+
+try:
+    from sqlalchemy.cprocessors import UnicodeResultProcessor, \
+                                       DecimalResultProcessor, \
+                                       to_float, to_str, int_to_boolean, \
+                                       str_to_datetime, str_to_time, \
+                                       str_to_date
+
+    def to_unicode_processor_factory(encoding):
+        return UnicodeResultProcessor(encoding).process
+
+    def to_decimal_processor_factory(target_class):
+        return DecimalResultProcessor(target_class).process
+
+except ImportError:
+    def to_unicode_processor_factory(encoding):
+        decoder = codecs.getdecoder(encoding)
+
+        def process(value):
+            if value is None:
+                return None
+            else:
+                # decoder returns a tuple: (value, len). Simply dropping the
+                # len part is safe: it is done that way in the normal
+                # 'xx'.decode(encoding) code path.
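
As a quick check of the factory defined above, this standalone sketch reproduces its behavior with the module's ``DATETIME_RE`` (the assertions are illustrative, not part of the commit). Note how ``groups(0)`` turns an unmatched optional microseconds group into ``0`` instead of ``None``::

    import re
    import datetime

    DATETIME_RE = re.compile(
        r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?")

    rmatch = DATETIME_RE.match

    def str_to_datetime(value):
        if value is None:
            return None
        # groups(0) substitutes 0 for the missing microseconds group
        return datetime.datetime(*map(int, rmatch(value).groups(0)))

    assert str_to_datetime("2010-01-02 03:04:05") == \
           datetime.datetime(2010, 1, 2, 3, 4, 5)
    assert str_to_datetime("2010-01-02 03:04:05.000042").microsecond == 42
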
+ # cfr python-source/Python/codecs.c:PyCodec_Decode + return decoder(value)[0] + return process + + def to_decimal_processor_factory(target_class): + def process(value): + if value is None: + return None + else: + return target_class(str(value)) + return process + + def to_float(value): + if value is None: + return None + else: + return float(value) + + def to_str(value): + if value is None: + return None + else: + return str(value) + + def int_to_boolean(value): + if value is None: + return None + else: + return value and True or False + + DATETIME_RE = re.compile("(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)(?:\.(\d+))?") + TIME_RE = re.compile("(\d+):(\d+):(\d+)(?:\.(\d+))?") + DATE_RE = re.compile("(\d+)-(\d+)-(\d+)") + + str_to_datetime = str_to_datetime_processor_factory(DATETIME_RE, + datetime.datetime) + str_to_time = str_to_datetime_processor_factory(TIME_RE, datetime.time) + str_to_date = str_to_datetime_processor_factory(DATE_RE, datetime.date) + diff --git a/lib/sqlalchemy/test/profiling.py b/lib/sqlalchemy/test/profiling.py index 8cab6ceba..c5256affa 100644 --- a/lib/sqlalchemy/test/profiling.py +++ b/lib/sqlalchemy/test/profiling.py @@ -93,9 +93,16 @@ def function_call_count(count=None, versions={}, variance=0.05): version_info = list(sys.version_info) py_version = '.'.join([str(v) for v in sys.version_info]) - + try: + from sqlalchemy.cprocessors import to_float + cextension = True + except ImportError: + cextension = False + while version_info: version = '.'.join([str(v) for v in version_info]) + if cextension: + version += "+cextension" if version in versions: count = versions[version] break diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 465454df9..36302cae3 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -32,6 +32,7 @@ schema.types = expression.sqltypes =sys.modules['sqlalchemy.types'] from sqlalchemy.util import pickle from sqlalchemy.sql.visitors import Visitable from sqlalchemy import util +from sqlalchemy import processors NoneType = type(None) if util.jython: @@ -608,14 +609,7 @@ class String(Concatenable, TypeEngine): if needs_convert: # note we *assume* that we do not have a unicode object # here, instead of an expensive isinstance() check. 
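
The profiling change above deserves a short walkthrough: the decorated call counts are keyed by a version string, and when the C extension is importable every candidate key gets a "+cextension" suffix, so as the loop is written, plain keys like '2.4' can no longer match on an extension build. A simplified sketch of the resolution loop; ``resolve_count`` is a hypothetical name and the real code walks the full ``sys.version_info`` tuple::

    import sys

    def resolve_count(default, versions):
        version_info = list(sys.version_info[:3])
        try:
            from sqlalchemy.cprocessors import to_float  # presence check only
            cextension = True
        except ImportError:
            cextension = False
        while version_info:
            version = '.'.join(str(v) for v in version_info)
            if cextension:
                version += "+cextension"
            if version in versions:
                return versions[version]
            version_info.pop()
        return default

    # With the extension built, under Python 2.6:
    # resolve_count(14416, {'2.4': 13214, '2.6+cextension': 409}) -> 409
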
- decoder = codecs.getdecoder(dialect.encoding) - def process(value): - if value is not None: - # decoder returns a tuple: (value, len) - return decoder(value)[0] - else: - return value - return process + return processors.to_unicode_processor_factory(dialect.encoding) else: return None @@ -810,21 +804,15 @@ class Numeric(_DateAffinity, TypeEngine): return dbapi.NUMBER def bind_processor(self, dialect): - def process(value): - if value is not None: - return float(value) - else: - return value - return process + return processors.to_float def result_processor(self, dialect, coltype): if self.asdecimal: - def process(value): - if value is not None: - return _python_Decimal(str(value)) - else: - return value - return process + #XXX: use decimal from http://www.bytereef.org/libmpdec.html +# try: +# from fastdec import mpd as Decimal +# except ImportError: + return processors.to_decimal_processor_factory(_python_Decimal) else: return None @@ -991,11 +979,7 @@ class _Binary(TypeEngine): else: return None else: - def process(value): - if value is not None: - return str(value) - else: - return None + process = processors.to_str return process # end Py2K @@ -1349,11 +1333,7 @@ class Boolean(TypeEngine, SchemaType): if dialect.supports_native_boolean: return None else: - def process(value): - if value is None: - return None - return value and True or False - return process + return processors.int_to_boolean class Interval(_DateAffinity, TypeDecorator): """A type for ``datetime.timedelta()`` objects. @@ -1419,7 +1399,7 @@ class Interval(_DateAffinity, TypeDecorator): if impl_processor: def process(value): value = impl_processor(value) - if value is None: + if value is None: return None return value - epoch else: @@ -15,9 +15,11 @@ if sys.version_info >= (3, 0): ) try: - from setuptools import setup + from setuptools import setup, Extension except ImportError: - from distutils.core import setup + from distutils.core import setup, Extension + +BUILD_CEXTENSIONS = False def find_packages(dir_): packages = [] @@ -46,6 +48,12 @@ setup(name = "SQLAlchemy", license = "MIT License", tests_require = ['nose >= 0.10'], test_suite = "nose.collector", + ext_modules = (BUILD_CEXTENSIONS and + [Extension('sqlalchemy.cprocessors', + sources=['lib/sqlalchemy/cextension/processors.c']), + Extension('sqlalchemy.cresultproxy', + sources=['lib/sqlalchemy/cextension/resultproxy.c']) + ]), entry_points = { 'nose.plugins.0.10': [ 'sqlalchemy = sqlalchemy.test.noseplugin:NoseSQLAlchemy', diff --git a/test/aaa_profiling/test_resultset.py b/test/aaa_profiling/test_resultset.py index 83901b7f7..459a8e4c4 100644 --- a/test/aaa_profiling/test_resultset.py +++ b/test/aaa_profiling/test_resultset.py @@ -29,13 +29,13 @@ class ResultSetTest(TestBase, AssertsExecutionResults): def teardown(self): metadata.drop_all() - @profiling.function_call_count(14416, versions={'2.4':13214}) + @profiling.function_call_count(14416, versions={'2.4':13214, '2.6+cextension':409}) def test_string(self): [tuple(row) for row in t.select().execute().fetchall()] # sqlite3 returns native unicode. so shouldn't be an # increase here. 
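
A side note on the setup.py hunk above: ``ext_modules = (BUILD_CEXTENSIONS and [...])`` evaluates to ``False`` when the flag is off, which distutils happens to accept as "no extensions". If that idiom reads as too clever, an equivalent and arguably clearer form would be the following; this is a sketch, not the committed code::

    try:
        from setuptools import setup, Extension
    except ImportError:
        from distutils.core import setup, Extension

    BUILD_CEXTENSIONS = False

    if BUILD_CEXTENSIONS:
        ext_modules = [
            Extension('sqlalchemy.cprocessors',
                      sources=['lib/sqlalchemy/cextension/processors.c']),
            Extension('sqlalchemy.cresultproxy',
                      sources=['lib/sqlalchemy/cextension/resultproxy.c']),
        ]
    else:
        ext_modules = []

    # then: setup(..., ext_modules=ext_modules)
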
- @profiling.function_call_count(14396, versions={'2.4':13214}) + @profiling.function_call_count(14396, versions={'2.4':13214, '2.6+cextension':409}) def test_unicode(self): [tuple(row) for row in t2.select().execute().fetchall()] diff --git a/test/aaa_profiling/test_zoomark.py b/test/aaa_profiling/test_zoomark.py index 66bb45f31..706f8e470 100644 --- a/test/aaa_profiling/test_zoomark.py +++ b/test/aaa_profiling/test_zoomark.py @@ -339,7 +339,7 @@ class ZooMarkTest(TestBase): def test_profile_3_properties(self): self.test_baseline_3_properties() - @profiling.function_call_count(13341, {'2.4': 7963}) + @profiling.function_call_count(13341, {'2.4': 7963, '2.6+cextension':12447}) def test_profile_4_expressions(self): self.test_baseline_4_expressions() @@ -351,7 +351,7 @@ class ZooMarkTest(TestBase): def test_profile_6_editing(self): self.test_baseline_6_editing() - @profiling.function_call_count(2641, {'2.4': 1673}) + @profiling.function_call_count(2641, {'2.4': 1673, '2.6+cextension':2502}) def test_profile_7_multiview(self): self.test_baseline_7_multiview() diff --git a/test/perf/stress_all.py b/test/perf/stress_all.py new file mode 100644 index 000000000..ad074ee53 --- /dev/null +++ b/test/perf/stress_all.py @@ -0,0 +1,226 @@ +# -*- encoding: utf8 -*- +from datetime import * +from decimal import Decimal +#from fastdec import mpd as Decimal +from cPickle import dumps, loads + +#from sqlalchemy.dialects.postgresql.base import ARRAY + +from stresstest import * + +# --- +test_types = False +test_methods = True +test_pickle = False +test_orm = False +# --- +verbose = True + +def values_results(raw_results): + return [tuple(r.values()) for r in raw_results] + +def getitem_str_results(raw_results): + return [ + (r['id'], + r['field0'], r['field1'], r['field2'], r['field3'], r['field4'], + r['field5'], r['field6'], r['field7'], r['field8'], r['field9']) + for r in raw_results] + +def getitem_fallback_results(raw_results): + return [ + (r['ID'], + r['FIELD0'], r['FIELD1'], r['FIELD2'], r['FIELD3'], r['FIELD4'], + r['FIELD5'], r['FIELD6'], r['FIELD7'], r['FIELD8'], r['FIELD9']) + for r in raw_results] + +def getitem_int_results(raw_results): + return [ + (r[0], + r[1], r[2], r[3], r[4], r[5], + r[6], r[7], r[8], r[9], r[10]) + for r in raw_results] + +def getitem_long_results(raw_results): + return [ + (r[0L], + r[1L], r[2L], r[3L], r[4L], r[5L], + r[6L], r[7L], r[8L], r[9L], r[10L]) + for r in raw_results] + +def getitem_obj_results(raw_results): + c = test_table.c + fid, f0, f1, f2, f3, f4, f5, f6, f7, f8, f9 = ( + c.id, c.field0, c.field1, c.field2, c.field3, c.field4, + c.field5, c.field6, c.field7, c.field8, c.field9) + return [ + (r[fid], + r[f0], r[f1], r[f2], r[f3], r[f4], + r[f5], r[f6], r[f7], r[f8], r[f9]) + for r in raw_results] + +def slice_results(raw_results): + return [row[0:6] + row[6:11] for row in raw_results] + +# ---------- # +# Test types # +# ---------- # + +# Array +#def genarrayvalue(rnum, fnum): +# return [fnum, fnum + 1, fnum + 2] +#arraytest = (ARRAY(Integer), genarrayvalue, +# dict(num_fields=100, num_records=1000, +# engineurl='postgresql:///test')) + +# Boolean +def genbooleanvalue(rnum, fnum): + if rnum % 4: + return bool(fnum % 2) + else: + return None +booleantest = (Boolean, genbooleanvalue, dict(num_records=100000)) + +# Datetime +def gendatetimevalue(rnum, fnum): + return (rnum % 4) and datetime(2005, 3, 3) or None +datetimetest = (DateTime, gendatetimevalue, dict(num_records=10000)) + +# Decimal +def gendecimalvalue(rnum, fnum): + if rnum % 4: + return 
Decimal(str(0.25 * fnum)) + else: + return None +decimaltest = (Numeric(10, 2), gendecimalvalue, dict(num_records=10000)) + +# Interval + +# no microseconds because Postgres does not seem to support it +from_epoch = timedelta(14643, 70235) +def genintervalvalue(rnum, fnum): + return from_epoch +intervaltest = (Interval, genintervalvalue, + dict(num_fields=2, num_records=100000)) + +# PickleType +def genpicklevalue(rnum, fnum): + return (rnum % 4) and {'str': "value%d" % fnum, 'int': rnum} or None +pickletypetest = (PickleType, genpicklevalue, + dict(num_fields=1, num_records=100000)) + +# TypeDecorator +class MyIntType(TypeDecorator): + impl = Integer + + def process_bind_param(self, value, dialect): + return value * 10 + + def process_result_value(self, value, dialect): + return value / 10 + + def copy(self): + return MyIntType() + +def genmyintvalue(rnum, fnum): + return rnum + fnum +typedecoratortest = (MyIntType, genmyintvalue, + dict(num_records=100000)) + +# Unicode +def genunicodevalue(rnum, fnum): + return (rnum % 4) and (u"value%d" % fnum) or None +unicodetest = (Unicode(20, assert_unicode=False), genunicodevalue, + dict(num_records=100000)) +# dict(engineurl='mysql:///test', freshdata=False)) + +# do the tests +if test_types: + tests = [booleantest, datetimetest, decimaltest, intervaltest, + pickletypetest, typedecoratortest, unicodetest] + for engineurl in ('postgresql://scott:tiger@localhost/test', + 'sqlite://', 'mysql://scott:tiger@localhost/test'): + print "\n%s\n" % engineurl + for datatype, genvalue, kwargs in tests: + print "%s:" % getattr(datatype, '__name__', + datatype.__class__.__name__), + profile_and_time_dbfunc(iter_results, datatype, genvalue, + profile=False, engineurl=engineurl, + verbose=verbose, **kwargs) + +# ---------------------- # +# test row proxy methods # +# ---------------------- # + +if test_methods: + methods = [iter_results, values_results, getattr_results, + getitem_str_results, getitem_fallback_results, + getitem_int_results, getitem_long_results, getitem_obj_results, + slice_results] + for engineurl in ('postgresql://scott:tiger@localhost/test', + 'sqlite://', 'mysql://scott:tiger@localhost/test'): + print "\n%s\n" % engineurl + test_table = prepare(Unicode(20, assert_unicode=False), + genunicodevalue, + num_fields=10, num_records=100000, + verbose=verbose, engineurl=engineurl) + for method in methods: + print "%s:" % method.__name__, + time_dbfunc(test_table, method, genunicodevalue, + num_fields=10, num_records=100000, profile=False, + verbose=verbose) + +# -------------------------------- +# test pickling Rowproxy instances +# -------------------------------- + +def pickletofile_results(raw_results): + from cPickle import dump, load + for protocol in (0, 1, 2): + print "dumping protocol %d..." 
% protocol + f = file('noext.pickle%d' % protocol, 'wb') + dump(raw_results, f, protocol) + f.close() + return raw_results + +def pickle_results(raw_results): + return loads(dumps(raw_results, 2)) + +def pickle_meta(raw_results): + pickled = dumps(raw_results[0]._parent, 2) + metadata = loads(pickled) + return raw_results + +def pickle_rows(raw_results): + return [loads(dumps(row, 2)) for row in raw_results] + +if test_pickle: + test_table = prepare(Unicode, genunicodevalue, + num_fields=10, num_records=10000) + funcs = [pickle_rows, pickle_results] + for func in funcs: + print "%s:" % func.__name__, + time_dbfunc(test_table, func, genunicodevalue, + num_records=10000, profile=False, verbose=verbose) + +# -------------------------------- +# test ORM +# -------------------------------- + +if test_orm: + from sqlalchemy.orm import * + + class Test(object): + pass + + Session = sessionmaker() + session = Session() + + def get_results(): + return session.query(Test).all() + print "ORM:", + for engineurl in ('postgresql:///test', 'sqlite://', 'mysql:///test'): + print "\n%s\n" % engineurl + profile_and_time_dbfunc(getattr_results, Unicode(20), genunicodevalue, + class_=Test, getresults_func=get_results, + engineurl=engineurl, #freshdata=False, + num_records=10000, verbose=verbose) diff --git a/test/perf/stresstest.py b/test/perf/stresstest.py new file mode 100644 index 000000000..cf9404f53 --- /dev/null +++ b/test/perf/stresstest.py @@ -0,0 +1,174 @@ +import gc +import sys +import timeit +import cProfile + +from sqlalchemy import MetaData, Table, Column +from sqlalchemy.types import * +from sqlalchemy.orm import mapper, clear_mappers + +metadata = MetaData() + +def gen_table(num_fields, field_type, metadata): + return Table('test', metadata, + Column('id', Integer, primary_key=True), + *[Column("field%d" % fnum, field_type) + for fnum in range(num_fields)]) + +def insert(test_table, num_fields, num_records, genvalue, verbose=True): + if verbose: + print "building insert values...", + sys.stdout.flush() + values = [dict(("field%d" % fnum, genvalue(rnum, fnum)) + for fnum in range(num_fields)) + for rnum in range(num_records)] + if verbose: + print "inserting...", + sys.stdout.flush() + def db_insert(): + test_table.insert().execute(values) + sys.modules['__main__'].db_insert = db_insert + timing = timeit.timeit("db_insert()", + "from __main__ import db_insert", + number=1) + if verbose: + print "%s" % round(timing, 3) + +def check_result(results, num_fields, genvalue, verbose=True): + if verbose: + print "checking...", + sys.stdout.flush() + for rnum, row in enumerate(results): + expected = tuple([rnum + 1] + + [genvalue(rnum, fnum) for fnum in range(num_fields)]) + assert row == expected, "got: %s\nexpected: %s" % (row, expected) + return True + +def avgdev(values, comparison): + return sum(value - comparison for value in values) / len(values) + +def nicer_res(values, printvalues=False): + if printvalues: + print values + min_time = min(values) + return round(min_time, 3), round(avgdev(values, min_time), 2) + +def profile_func(func_name, verbose=True): + if verbose: + print "profiling...", + sys.stdout.flush() + cProfile.run('%s()' % func_name, 'prof') + +def time_func(func_name, num_tests=1, verbose=True): + if verbose: + print "timing...", + sys.stdout.flush() + timings = timeit.repeat('%s()' % func_name, + "from __main__ import %s" % func_name, + number=num_tests, repeat=5) + avg, dev = nicer_res(timings) + if verbose: + print "%s (%s)" % (avg, dev) + else: + print avg + +def 
profile_and_time(func_name, num_tests=1): + profile_func(func_name) + time_func(func_name, num_tests) + +def iter_results(raw_results): + return [tuple(row) for row in raw_results] + +def getattr_results(raw_results): + return [ + (r.id, + r.field0, r.field1, r.field2, r.field3, r.field4, + r.field5, r.field6, r.field7, r.field8, r.field9) + for r in raw_results] + +def fetchall(test_table): + def results(): + return test_table.select().order_by(test_table.c.id).execute() \ + .fetchall() + return results + +def hashable_set(l): + hashables = [] + for o in l: + try: + hash(o) + hashables.append(o) + except: + pass + return set(hashables) + +def prepare(field_type, genvalue, engineurl='sqlite://', + num_fields=10, num_records=1000, freshdata=True, verbose=True): + global metadata + metadata.clear() + metadata.bind = engineurl + test_table = gen_table(num_fields, field_type, metadata) + if freshdata: + metadata.drop_all() + metadata.create_all() + insert(test_table, num_fields, num_records, genvalue, verbose) + return test_table + +def time_dbfunc(test_table, test_func, genvalue, + class_=None, + getresults_func=None, + num_fields=10, num_records=1000, num_tests=1, + check_results=check_result, profile=True, + check_leaks=True, print_leaks=False, verbose=True): + if verbose: + print "testing '%s'..." % test_func.__name__, + sys.stdout.flush() + if class_ is not None: + clear_mappers() + mapper(class_, test_table) + if getresults_func is None: + getresults_func = fetchall(test_table) + def test(): + return test_func(getresults_func()) + sys.modules['__main__'].test = test + if check_leaks: + gc.collect() + objects_before = gc.get_objects() + num_objects_before = len(objects_before) + hashable_objects_before = hashable_set(objects_before) +# gc.set_debug(gc.DEBUG_LEAK) + if check_results: + check_results(test(), num_fields, genvalue, verbose) + if check_leaks: + gc.collect() + objects_after = gc.get_objects() + num_objects_after = len(objects_after) + num_leaks = num_objects_after - num_objects_before + hashable_objects_after = hashable_set(objects_after) + diff = hashable_objects_after - hashable_objects_before + ldiff = len(diff) + if print_leaks and ldiff < num_records: + print "\n*** hashable objects leaked (%d) ***" % ldiff + print '\n'.join(map(str, diff)) + print "***\n" + + if num_leaks > num_records: + print "(leaked: %d !)" % num_leaks, + if profile: + profile_func('test', verbose) + time_func('test', num_tests, verbose) + +def profile_and_time_dbfunc(test_func, field_type, genvalue, + class_=None, + getresults_func=None, + engineurl='sqlite://', freshdata=True, + num_fields=10, num_records=1000, num_tests=1, + check_results=check_result, profile=True, + check_leaks=True, print_leaks=False, verbose=True): + test_table = prepare(field_type, genvalue, engineurl, + num_fields, num_records, freshdata, verbose) + time_dbfunc(test_table, test_func, genvalue, class_, + getresults_func, + num_fields, num_records, num_tests, + check_results, profile, + check_leaks, print_leaks, verbose) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 345ecef67..5433cb92f 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -701,21 +701,21 @@ class QueryTest(TestBase): Column('shadow_name', VARCHAR(20)), Column('parent', VARCHAR(20)), Column('row', VARCHAR(40)), - Column('__parent', VARCHAR(20)), - Column('__row', VARCHAR(20)), + Column('_parent', VARCHAR(20)), + Column('_row', VARCHAR(20)), ) shadowed.create(checkfirst=True) try: - shadowed.insert().execute(shadow_id=1, 
shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', __parent='Hidden parent', __row='Hidden row')
+            shadowed.insert().execute(shadow_id=1, shadow_name='The Shadow', parent='The Light', row='Without light there is no shadow', _parent='Hidden parent', _row='Hidden row')
             r = shadowed.select(shadowed.c.shadow_id==1).execute().first()
             self.assert_(r.shadow_id == r['shadow_id'] == r[shadowed.c.shadow_id] == 1)
             self.assert_(r.shadow_name == r['shadow_name'] == r[shadowed.c.shadow_name] == 'The Shadow')
             self.assert_(r.parent == r['parent'] == r[shadowed.c.parent] == 'The Light')
             self.assert_(r.row == r['row'] == r[shadowed.c.row] == 'Without light there is no shadow')
-            self.assert_(r['__parent'] == 'Hidden parent')
-            self.assert_(r['__row'] == 'Hidden row')
+            self.assert_(r['_parent'] == 'Hidden parent')
+            self.assert_(r['_row'] == 'Hidden row')
             try:
-                print r.__parent, r.__row
+                print r._parent, r._row
                 self.fail('Should not allow access to private attributes')
             except AttributeError:
                 pass # expected
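
Finally, a note on the leak check in stresstest.py's ``time_dbfunc`` above: comparing the size of ``gc.get_objects()`` before and after a run is a coarse but useful smoke test, and the hashable-set diff then narrows down what actually leaked. A minimal standalone version of the counting half; ``count_leaked_objects`` is a hypothetical helper name::

    import gc

    def count_leaked_objects(func):
        """Run func once and report how many tracked objects it left behind."""
        gc.collect()
        before = len(gc.get_objects())
        func()
        gc.collect()
        return len(gc.get_objects()) - before

    # Should report (approximately) zero leaked objects:
    print count_leaked_objects(lambda: [tuple([i]) for i in range(1000)])
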