summaryrefslogtreecommitdiff
path: root/Modules/_sqlite/connection.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r--Modules/_sqlite/connection.c107
1 files changed, 91 insertions, 16 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index d52bea49c3..8cf2d6aa80 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -109,7 +109,10 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
Py_INCREF(isolation_level);
}
self->isolation_level = NULL;
- pysqlite_connection_set_isolation_level(self, isolation_level);
+ if (pysqlite_connection_set_isolation_level(self, isolation_level) < 0) {
+ Py_DECREF(isolation_level);
+ return -1;
+ }
Py_DECREF(isolation_level);
self->statement_cache = (pysqlite_Cache*)PyObject_CallFunction((PyObject*)&pysqlite_CacheType, "Oi", self, cached_statements);
@@ -677,7 +680,7 @@ void _pysqlite_final_callback(sqlite3_context* context)
{
PyObject* function_result;
PyObject** aggregate_instance;
- PyObject* aggregate_class;
+ _Py_IDENTIFIER(finalize);
int ok;
#ifdef WITH_THREAD
@@ -686,8 +689,6 @@ void _pysqlite_final_callback(sqlite3_context* context)
threadstate = PyGILState_Ensure();
#endif
- aggregate_class = (PyObject*)sqlite3_user_data(context);
-
aggregate_instance = (PyObject**)sqlite3_aggregate_context(context, sizeof(PyObject*));
if (!*aggregate_instance) {
/* this branch is executed if there was an exception in the aggregate's
@@ -696,7 +697,7 @@ void _pysqlite_final_callback(sqlite3_context* context)
goto error;
}
- function_result = PyObject_CallMethod(*aggregate_instance, "finalize", "");
+ function_result = _PyObject_CallMethodId(*aggregate_instance, &PyId_finalize, "");
Py_DECREF(*aggregate_instance);
ok = 0;
@@ -717,6 +718,9 @@ error:
#ifdef WITH_THREAD
PyGILState_Release(threadstate);
#endif
+ /* explicit return to avoid a compilation error if WITH_THREAD
+ is not defined */
+ return;
}
static void _pysqlite_drop_unused_statement_references(pysqlite_Connection* self)
@@ -916,6 +920,38 @@ static int _progress_handler(void* user_arg)
return rc;
}
+static void _trace_callback(void* user_arg, const char* statement_string)
+{
+ PyObject *py_statement = NULL;
+ PyObject *ret = NULL;
+
+#ifdef WITH_THREAD
+ PyGILState_STATE gilstate;
+
+ gilstate = PyGILState_Ensure();
+#endif
+ py_statement = PyUnicode_DecodeUTF8(statement_string,
+ strlen(statement_string), "replace");
+ if (py_statement) {
+ ret = PyObject_CallFunctionObjArgs((PyObject*)user_arg, py_statement, NULL);
+ Py_DECREF(py_statement);
+ }
+
+ if (ret) {
+ Py_DECREF(ret);
+ } else {
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+ }
+
+#ifdef WITH_THREAD
+ PyGILState_Release(gilstate);
+#endif
+}
+
static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
PyObject* authorizer_cb;
@@ -975,6 +1011,34 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
return Py_None;
}
+static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
+{
+ PyObject* trace_callback;
+
+ static char *kwlist[] = { "trace_callback", NULL };
+
+ if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
+ return NULL;
+ }
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_trace_callback",
+ kwlist, &trace_callback)) {
+ return NULL;
+ }
+
+ if (trace_callback == Py_None) {
+ /* None clears the trace callback previously set */
+ sqlite3_trace(self->db, 0, (void*)0);
+ } else {
+ if (PyDict_SetItem(self->function_pinboard, trace_callback, Py_None) == -1)
+ return NULL;
+ sqlite3_trace(self->db, _trace_callback, trace_callback);
+ }
+
+ Py_INCREF(Py_None);
+ return Py_None;
+}
+
#ifdef HAVE_LOAD_EXTENSION
static PyObject* pysqlite_enable_load_extension(pysqlite_Connection* self, PyObject* args)
{
@@ -1180,8 +1244,9 @@ PyObject* pysqlite_connection_execute(pysqlite_Connection* self, PyObject* args,
PyObject* cursor = 0;
PyObject* result = 0;
PyObject* method = 0;
+ _Py_IDENTIFIER(cursor);
- cursor = PyObject_CallMethod((PyObject*)self, "cursor", "");
+ cursor = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, "");
if (!cursor) {
goto error;
}
@@ -1209,8 +1274,9 @@ PyObject* pysqlite_connection_executemany(pysqlite_Connection* self, PyObject* a
PyObject* cursor = 0;
PyObject* result = 0;
PyObject* method = 0;
+ _Py_IDENTIFIER(cursor);
- cursor = PyObject_CallMethod((PyObject*)self, "cursor", "");
+ cursor = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, "");
if (!cursor) {
goto error;
}
@@ -1238,8 +1304,9 @@ PyObject* pysqlite_connection_executescript(pysqlite_Connection* self, PyObject*
PyObject* cursor = 0;
PyObject* result = 0;
PyObject* method = 0;
+ _Py_IDENTIFIER(cursor);
- cursor = PyObject_CallMethod((PyObject*)self, "cursor", "");
+ cursor = _PyObject_CallMethodId((PyObject*)self, &PyId_cursor, "");
if (!cursor) {
goto error;
}
@@ -1394,10 +1461,12 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
PyObject* uppercase_name = 0;
PyObject* name;
PyObject* retval;
- Py_UNICODE* chk;
Py_ssize_t i, len;
+ _Py_IDENTIFIER(upper);
char *uppercase_name_str;
int rc;
+ unsigned int kind;
+ void *data;
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
goto finally;
@@ -1407,17 +1476,21 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
goto finally;
}
- uppercase_name = PyObject_CallMethod(name, "upper", "");
+ uppercase_name = _PyObject_CallMethodId(name, &PyId_upper, "");
if (!uppercase_name) {
goto finally;
}
- len = PyUnicode_GET_SIZE(uppercase_name);
- chk = PyUnicode_AS_UNICODE(uppercase_name);
- for (i=0; i<len; i++, chk++) {
- if ((*chk >= '0' && *chk <= '9')
- || (*chk >= 'A' && *chk <= 'Z')
- || (*chk == '_'))
+ if (PyUnicode_READY(uppercase_name))
+ goto finally;
+ len = PyUnicode_GET_LENGTH(uppercase_name);
+ kind = PyUnicode_KIND(uppercase_name);
+ data = PyUnicode_DATA(uppercase_name);
+ for (i=0; i<len; i++) {
+ Py_UCS4 ch = PyUnicode_READ(kind, data, i);
+ if ((ch >= '0' && ch <= '9')
+ || (ch >= 'A' && ch <= 'Z')
+ || (ch == '_'))
{
continue;
} else {
@@ -1536,6 +1609,8 @@ static PyMethodDef connection_methods[] = {
#endif
{"set_progress_handler", (PyCFunction)pysqlite_connection_set_progress_handler, METH_VARARGS|METH_KEYWORDS,
PyDoc_STR("Sets progress handler callback. Non-standard.")},
+ {"set_trace_callback", (PyCFunction)pysqlite_connection_set_trace_callback, METH_VARARGS|METH_KEYWORDS,
+ PyDoc_STR("Sets a trace callback called for each SQL statement (passed as unicode). Non-standard.")},
{"execute", (PyCFunction)pysqlite_connection_execute, METH_VARARGS,
PyDoc_STR("Executes a SQL statement. Non-standard.")},
{"executemany", (PyCFunction)pysqlite_connection_executemany, METH_VARARGS,