summaryrefslogtreecommitdiff
path: root/Modules/_sqlite/connection.c
diff options
context:
space:
mode:
Diffstat (limited to 'Modules/_sqlite/connection.c')
-rw-r--r--Modules/_sqlite/connection.c134
1 files changed, 122 insertions, 12 deletions
diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c
index 64e43ebcbc..703af15fa9 100644
--- a/Modules/_sqlite/connection.c
+++ b/Modules/_sqlite/connection.c
@@ -34,6 +34,19 @@
static int connection_set_isolation_level(Connection* self, PyObject* isolation_level);
+
+void _sqlite3_result_error(sqlite3_context* ctx, const char* errmsg, int len)
+{
+ /* in older SQLite versions, calling sqlite3_result_error in callbacks
+ * triggers a bug in SQLite that leads either to irritating results or
+ * segfaults, depending on the SQLite version */
+#if SQLITE_VERSION_NUMBER >= 3003003
+ sqlite3_result_error(ctx, errmsg, len);
+#else
+ PyErr_SetString(OperationalError, errmsg);
+#endif
+}
+
int connection_init(Connection* self, PyObject* args, PyObject* kwargs)
{
static char *kwlist[] = {"database", "timeout", "detect_types", "isolation_level", "check_same_thread", "factory", "cached_statements", NULL, NULL};
@@ -405,8 +418,6 @@ void _set_result(sqlite3_context* context, PyObject* py_val)
PyObject* stringval;
if ((!py_val) || PyErr_Occurred()) {
- /* Errors in callbacks are ignored, and we return NULL */
- PyErr_Clear();
sqlite3_result_null(context);
} else if (py_val == Py_None) {
sqlite3_result_null(context);
@@ -519,8 +530,17 @@ void _func_callback(sqlite3_context* context, int argc, sqlite3_value** argv)
Py_DECREF(args);
}
- _set_result(context, py_retval);
- Py_XDECREF(py_retval);
+ if (py_retval) {
+ _set_result(context, py_retval);
+ Py_DECREF(py_retval);
+ } else {
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+ _sqlite3_result_error(context, "user-defined function raised exception", -1);
+ }
PyGILState_Release(threadstate);
}
@@ -545,8 +565,13 @@ static void _step_callback(sqlite3_context *context, int argc, sqlite3_value** p
*aggregate_instance = PyObject_CallFunction(aggregate_class, "");
if (PyErr_Occurred()) {
- PyErr_Clear();
*aggregate_instance = 0;
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+ _sqlite3_result_error(context, "user-defined aggregate's '__init__' method raised error", -1);
goto error;
}
}
@@ -565,7 +590,12 @@ static void _step_callback(sqlite3_context *context, int argc, sqlite3_value** p
Py_DECREF(args);
if (!function_result) {
- PyErr_Clear();
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+ _sqlite3_result_error(context, "user-defined aggregate's 'step' method raised error", -1);
}
error:
@@ -597,13 +627,16 @@ void _final_callback(sqlite3_context* context)
function_result = PyObject_CallMethod(*aggregate_instance, "finalize", "");
if (!function_result) {
- PyErr_Clear();
- Py_INCREF(Py_None);
- function_result = Py_None;
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+ _sqlite3_result_error(context, "user-defined aggregate's 'finalize' method raised error", -1);
+ } else {
+ _set_result(context, function_result);
}
- _set_result(context, function_result);
-
error:
Py_XDECREF(*aggregate_instance);
Py_XDECREF(function_result);
@@ -631,7 +664,7 @@ void _drop_unused_statement_references(Connection* self)
for (i = 0; i < PyList_Size(self->statements); i++) {
weakref = PyList_GetItem(self->statements, i);
- if (weakref != Py_None) {
+ if (PyWeakref_GetObject(weakref) != Py_None) {
if (PyList_Append(new_list, weakref) != 0) {
Py_DECREF(new_list);
return;
@@ -699,6 +732,61 @@ PyObject* connection_create_aggregate(Connection* self, PyObject* args, PyObject
}
}
+int _authorizer_callback(void* user_arg, int action, const char* arg1, const char* arg2 , const char* dbname, const char* access_attempt_source)
+{
+ PyObject *ret;
+ int rc;
+ PyGILState_STATE gilstate;
+
+ gilstate = PyGILState_Ensure();
+ ret = PyObject_CallFunction((PyObject*)user_arg, "issss", action, arg1, arg2, dbname, access_attempt_source);
+
+ if (!ret) {
+ if (_enable_callback_tracebacks) {
+ PyErr_Print();
+ } else {
+ PyErr_Clear();
+ }
+
+ rc = SQLITE_DENY;
+ } else {
+ if (PyInt_Check(ret)) {
+ rc = (int)PyInt_AsLong(ret);
+ } else {
+ rc = SQLITE_DENY;
+ }
+ Py_DECREF(ret);
+ }
+
+ PyGILState_Release(gilstate);
+ return rc;
+}
+
+PyObject* connection_set_authorizer(Connection* self, PyObject* args, PyObject* kwargs)
+{
+ PyObject* authorizer_cb;
+
+ static char *kwlist[] = { "authorizer_callback", NULL };
+ int rc;
+
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_authorizer",
+ kwlist, &authorizer_cb)) {
+ return NULL;
+ }
+
+ rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb);
+
+ if (rc != SQLITE_OK) {
+ PyErr_SetString(OperationalError, "Error setting authorizer callback");
+ return NULL;
+ } else {
+ PyDict_SetItem(self->function_pinboard, authorizer_cb, Py_None);
+
+ Py_INCREF(Py_None);
+ return Py_None;
+ }
+}
+
int check_thread(Connection* self)
{
if (self->check_same_thread) {
@@ -975,6 +1063,24 @@ finally:
}
static PyObject *
+connection_interrupt(Connection* self, PyObject* args)
+{
+ PyObject* retval = NULL;
+
+ if (!check_connection(self)) {
+ goto finally;
+ }
+
+ sqlite3_interrupt(self->db);
+
+ Py_INCREF(Py_None);
+ retval = Py_None;
+
+finally:
+ return retval;
+}
+
+static PyObject *
connection_create_collation(Connection* self, PyObject* args)
{
PyObject* callable;
@@ -1067,6 +1173,8 @@ static PyMethodDef connection_methods[] = {
PyDoc_STR("Creates a new function. Non-standard.")},
{"create_aggregate", (PyCFunction)connection_create_aggregate, METH_VARARGS|METH_KEYWORDS,
PyDoc_STR("Creates a new aggregate. Non-standard.")},
+ {"set_authorizer", (PyCFunction)connection_set_authorizer, METH_VARARGS|METH_KEYWORDS,
+ PyDoc_STR("Sets authorizer callback. Non-standard.")},
{"execute", (PyCFunction)connection_execute, METH_VARARGS,
PyDoc_STR("Executes a SQL statement. Non-standard.")},
{"executemany", (PyCFunction)connection_executemany, METH_VARARGS,
@@ -1075,6 +1183,8 @@ static PyMethodDef connection_methods[] = {
PyDoc_STR("Executes a multiple SQL statements at once. Non-standard.")},
{"create_collation", (PyCFunction)connection_create_collation, METH_VARARGS,
PyDoc_STR("Creates a collation function. Non-standard.")},
+ {"interrupt", (PyCFunction)connection_interrupt, METH_NOARGS,
+ PyDoc_STR("Abort any pending database operation. Non-standard.")},
{NULL, NULL}
};