diff options
Diffstat (limited to 'MySQLdb/cursors.py')
-rw-r--r-- | MySQLdb/cursors.py | 533 |
1 files changed, 533 insertions, 0 deletions
diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py new file mode 100644 index 0000000..a8cfa3e --- /dev/null +++ b/MySQLdb/cursors.py @@ -0,0 +1,533 @@ +"""MySQLdb Cursors + +This module implements Cursors of various types for MySQLdb. By +default, MySQLdb uses the Cursor class. + +""" + +import re +import sys +try: + from types import ListType, TupleType, UnicodeType +except ImportError: + # Python 3 + ListType = list + TupleType = tuple + UnicodeType = str + +restr = r""" + \s + values + \s* + ( + \( + [^()']* + (?: + (?: + (?:\( + # ( - editor hightlighting helper + [^)]* + \)) + | + ' + [^\\']* + (?:\\.[^\\']*)* + ' + ) + [^()']* + )* + \) + ) +""" + +insert_values = re.compile(restr, re.S | re.I | re.X) + +from _mysql_exceptions import Warning, Error, InterfaceError, DataError, \ + DatabaseError, OperationalError, IntegrityError, InternalError, \ + NotSupportedError, ProgrammingError + + +class BaseCursor(object): + + """A base for Cursor classes. Useful attributes: + + description + A tuple of DB API 7-tuples describing the columns in + the last executed query; see PEP-249 for details. + + description_flags + Tuple of column flags for last query, one entry per column + in the result set. Values correspond to those in + MySQLdb.constants.FLAG. See MySQL documentation (C API) + for more information. Non-standard extension. + + arraysize + default number of rows fetchmany() will fetch + + """ + + from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \ + DatabaseError, DataError, OperationalError, IntegrityError, \ + InternalError, ProgrammingError, NotSupportedError + + _defer_warnings = False + + def __init__(self, connection): + from weakref import proxy + + self.connection = proxy(connection) + self.description = None + self.description_flags = None + self.rowcount = -1 + self.arraysize = 1 + self._executed = None + self.lastrowid = None + self.messages = [] + self.errorhandler = connection.errorhandler + self._result = None + self._warnings = 0 + self._info = None + self.rownumber = None + + def __del__(self): + self.close() + self.errorhandler = None + self._result = None + + def close(self): + """Close the cursor. No further queries will be possible.""" + if not self.connection: return + while self.nextset(): pass + self.connection = None + + def _check_executed(self): + if not self._executed: + self.errorhandler(self, ProgrammingError, "execute() first") + + def _warning_check(self): + from warnings import warn + if self._warnings: + warnings = self._get_db().show_warnings() + if warnings: + # This is done in two loops in case + # Warnings are set to raise exceptions. + for w in warnings: + self.messages.append((self.Warning, w)) + for w in warnings: + warn(w[-1], self.Warning, 3) + elif self._info: + self.messages.append((self.Warning, self._info)) + warn(self._info, self.Warning, 3) + + def nextset(self): + """Advance to the next result set. + + Returns None if there are no more result sets. + """ + if self._executed: + self.fetchall() + del self.messages[:] + + db = self._get_db() + nr = db.next_result() + if nr == -1: + return None + self._do_get_result() + self._post_get_result() + self._warning_check() + return 1 + + def _post_get_result(self): pass + + def _do_get_result(self): + db = self._get_db() + self._result = self._get_result() + self.rowcount = db.affected_rows() + self.rownumber = 0 + self.description = self._result and self._result.describe() or None + self.description_flags = self._result and self._result.field_flags() or None + self.lastrowid = db.insert_id() + self._warnings = db.warning_count() + self._info = db.info() + + def setinputsizes(self, *args): + """Does nothing, required by DB API.""" + + def setoutputsizes(self, *args): + """Does nothing, required by DB API.""" + + def _get_db(self): + if not self.connection: + self.errorhandler(self, ProgrammingError, "cursor closed") + return self.connection + + def execute(self, query, args=None): + + """Execute a query. + + query -- string, query to execute on server + args -- optional sequence or mapping, parameters to use with query. + + Note: If args is a sequence, then %s must be used as the + parameter placeholder in the query. If a mapping is used, + %(key)s must be used as the placeholder. + + Returns long integer rows affected, if any + + """ + del self.messages[:] + db = self._get_db() + charset = db.character_set_name() + if isinstance(query, unicode): + query = query.encode(charset) + if args is not None: + query = query % db.literal(args) + try: + r = None + r = self._query(query) + except TypeError, m: + if m.args[0] in ("not enough arguments for format string", + "not all arguments converted"): + self.messages.append((ProgrammingError, m.args[0])) + self.errorhandler(self, ProgrammingError, m.args[0]) + else: + self.messages.append((TypeError, m)) + self.errorhandler(self, TypeError, m) + except (SystemExit, KeyboardInterrupt): + raise + except: + exc, value, tb = sys.exc_info() + del tb + self.messages.append((exc, value)) + self.errorhandler(self, exc, value) + self._executed = query + if not self._defer_warnings: self._warning_check() + return r + + def executemany(self, query, args): + + """Execute a multi-row query. + + query -- string, query to execute on server + + args + + Sequence of sequences or mappings, parameters to use with + query. + + Returns long integer rows affected, if any. + + This method improves performance on multiple-row INSERT and + REPLACE. Otherwise it is equivalent to looping over args with + execute(). + + """ + del self.messages[:] + db = self._get_db() + if not args: return + charset = db.character_set_name() + if isinstance(query, unicode): query = query.encode(charset) + m = insert_values.search(query) + if not m: + r = 0 + for a in args: + r = r + self.execute(query, a) + return r + p = m.start(1) + e = m.end(1) + qv = m.group(1) + try: + q = [ qv % db.literal(a) for a in args ] + except TypeError, msg: + if msg.args[0] in ("not enough arguments for format string", + "not all arguments converted"): + self.errorhandler(self, ProgrammingError, msg.args[0]) + else: + self.errorhandler(self, TypeError, msg) + except (SystemExit, KeyboardInterrupt): + raise + except: + exc, value, tb = sys.exc_info() + del tb + self.errorhandler(self, exc, value) + r = self._query('\n'.join([query[:p], ',\n'.join(q), query[e:]])) + if not self._defer_warnings: self._warning_check() + return r + + def callproc(self, procname, args=()): + + """Execute stored procedure procname with args + + procname -- string, name of procedure to execute on server + + args -- Sequence of parameters to use with procedure + + Returns the original args. + + Compatibility warning: PEP-249 specifies that any modified + parameters must be returned. This is currently impossible + as they are only available by storing them in a server + variable and then retrieved by a query. Since stored + procedures return zero or more result sets, there is no + reliable way to get at OUT or INOUT parameters via callproc. + The server variables are named @_procname_n, where procname + is the parameter above and n is the position of the parameter + (from zero). Once all result sets generated by the procedure + have been fetched, you can issue a SELECT @_procname_0, ... + query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. + """ + + db = self._get_db() + charset = db.character_set_name() + for index, arg in enumerate(args): + q = "SET @_%s_%d=%s" % (procname, index, + db.literal(arg)) + if isinstance(q, unicode): + q = q.encode(charset) + self._query(q) + self.nextset() + + q = "CALL %s(%s)" % (procname, + ','.join(['@_%s_%d' % (procname, i) + for i in range(len(args))])) + if type(q) is UnicodeType: + q = q.encode(charset) + self._query(q) + self._executed = q + if not self._defer_warnings: self._warning_check() + return args + + def _do_query(self, q): + db = self._get_db() + self._last_executed = q + db.query(q) + self._do_get_result() + return self.rowcount + + def _query(self, q): return self._do_query(q) + + def _fetch_row(self, size=1): + if not self._result: + return () + return self._result.fetch_row(size, self._fetch_type) + + def __iter__(self): + return iter(self.fetchone, None) + + Warning = Warning + Error = Error + InterfaceError = InterfaceError + DatabaseError = DatabaseError + DataError = DataError + OperationalError = OperationalError + IntegrityError = IntegrityError + InternalError = InternalError + ProgrammingError = ProgrammingError + NotSupportedError = NotSupportedError + + +class CursorStoreResultMixIn(object): + + """This is a MixIn class which causes the entire result set to be + stored on the client side, i.e. it uses mysql_store_result(). If the + result set can be very large, consider adding a LIMIT clause to your + query, or using CursorUseResultMixIn instead.""" + + def _get_result(self): return self._get_db().store_result() + + def _query(self, q): + rowcount = self._do_query(q) + self._post_get_result() + return rowcount + + def _post_get_result(self): + self._rows = self._fetch_row(0) + self._result = None + + def fetchone(self): + """Fetches a single row from the cursor. None indicates that + no more rows are available.""" + self._check_executed() + if self.rownumber >= len(self._rows): return None + result = self._rows[self.rownumber] + self.rownumber = self.rownumber+1 + return result + + def fetchmany(self, size=None): + """Fetch up to size rows from the cursor. Result set may be smaller + than size. If size is not defined, cursor.arraysize is used.""" + self._check_executed() + end = self.rownumber + (size or self.arraysize) + result = self._rows[self.rownumber:end] + self.rownumber = min(end, len(self._rows)) + return result + + def fetchall(self): + """Fetchs all available rows from the cursor.""" + self._check_executed() + if self.rownumber: + result = self._rows[self.rownumber:] + else: + result = self._rows + self.rownumber = len(self._rows) + return result + + def scroll(self, value, mode='relative'): + """Scroll the cursor in the result set to a new position according + to mode. + + If mode is 'relative' (default), value is taken as offset to + the current position in the result set, if set to 'absolute', + value states an absolute target position.""" + self._check_executed() + if mode == 'relative': + r = self.rownumber + value + elif mode == 'absolute': + r = value + else: + self.errorhandler(self, ProgrammingError, + "unknown scroll mode %s" % repr(mode)) + if r < 0 or r >= len(self._rows): + self.errorhandler(self, IndexError, "out of range") + self.rownumber = r + + def __iter__(self): + self._check_executed() + result = self.rownumber and self._rows[self.rownumber:] or self._rows + return iter(result) + + +class CursorUseResultMixIn(object): + + """This is a MixIn class which causes the result set to be stored + in the server and sent row-by-row to client side, i.e. it uses + mysql_use_result(). You MUST retrieve the entire result set and + close() the cursor before additional queries can be peformed on + the connection.""" + + _defer_warnings = True + + def _get_result(self): return self._get_db().use_result() + + def fetchone(self): + """Fetches a single row from the cursor.""" + self._check_executed() + r = self._fetch_row(1) + if not r: + self._warning_check() + return None + self.rownumber = self.rownumber + 1 + return r[0] + + def fetchmany(self, size=None): + """Fetch up to size rows from the cursor. Result set may be smaller + than size. If size is not defined, cursor.arraysize is used.""" + self._check_executed() + r = self._fetch_row(size or self.arraysize) + self.rownumber = self.rownumber + len(r) + if not r: + self._warning_check() + return r + + def fetchall(self): + """Fetchs all available rows from the cursor.""" + self._check_executed() + r = self._fetch_row(0) + self.rownumber = self.rownumber + len(r) + self._warning_check() + return r + + def __iter__(self): + return self + + def next(self): + row = self.fetchone() + if row is None: + raise StopIteration + return row + + +class CursorTupleRowsMixIn(object): + + """This is a MixIn class that causes all rows to be returned as tuples, + which is the standard form required by DB API.""" + + _fetch_type = 0 + + +class CursorDictRowsMixIn(object): + + """This is a MixIn class that causes all rows to be returned as + dictionaries. This is a non-standard feature.""" + + _fetch_type = 1 + + def fetchoneDict(self): + """Fetch a single row as a dictionary. Deprecated: + Use fetchone() instead. Will be removed in 1.3.""" + from warnings import warn + warn("fetchoneDict() is non-standard and will be removed in 1.3", + DeprecationWarning, 2) + return self.fetchone() + + def fetchmanyDict(self, size=None): + """Fetch several rows as a list of dictionaries. Deprecated: + Use fetchmany() instead. Will be removed in 1.3.""" + from warnings import warn + warn("fetchmanyDict() is non-standard and will be removed in 1.3", + DeprecationWarning, 2) + return self.fetchmany(size) + + def fetchallDict(self): + """Fetch all available rows as a list of dictionaries. Deprecated: + Use fetchall() instead. Will be removed in 1.3.""" + from warnings import warn + warn("fetchallDict() is non-standard and will be removed in 1.3", + DeprecationWarning, 2) + return self.fetchall() + + +class CursorOldDictRowsMixIn(CursorDictRowsMixIn): + + """This is a MixIn class that returns rows as dictionaries with + the same key convention as the old Mysqldb (MySQLmodule). Don't + use this.""" + + _fetch_type = 2 + + +class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn, + BaseCursor): + + """This is the standard Cursor class that returns rows as tuples + and stores the result set in the client.""" + + +class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn, + BaseCursor): + + """This is a Cursor class that returns rows as dictionaries and + stores the result set in the client.""" + + +class SSCursor(CursorUseResultMixIn, CursorTupleRowsMixIn, + BaseCursor): + + """This is a Cursor class that returns rows as tuples and stores + the result set in the server.""" + + +class SSDictCursor(CursorUseResultMixIn, CursorDictRowsMixIn, + BaseCursor): + + """This is a Cursor class that returns rows as dictionaries and + stores the result set in the server.""" + + |