summaryrefslogtreecommitdiff
path: root/adbh.py
diff options
context:
space:
mode:
authorSylvain <syt@logilab.fr>2007-10-25 18:05:56 +0200
committerSylvain <syt@logilab.fr>2007-10-25 18:05:56 +0200
commit4c9f3d2c4ceaa077b9e9f3b0e1afe765ab890bf7 (patch)
tree7fc23606fbcb19f1b591a7a36d59cda994623c04 /adbh.py
parent5b99e7b0b2c3caba7c3dfb36969df11d57669c29 (diff)
downloadlogilab-common-4c9f3d2c4ceaa077b9e9f3b0e1afe765ab890bf7.tar.gz
more db extensions, db helpers in a separated module
Diffstat (limited to 'adbh.py')
-rw-r--r--adbh.py451
1 files changed, 451 insertions, 0 deletions
diff --git a/adbh.py b/adbh.py
new file mode 100644
index 0000000..470f772
--- /dev/null
+++ b/adbh.py
@@ -0,0 +1,451 @@
+# Copyright (c) 2002-2007 LOGILAB S.A. (Paris, FRANCE).
+# http://www.logilab.fr/ -- mailto:contact@logilab.fr
+#
+# This program is free software; you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the Free Software
+# Foundation; either version 2 of the License, or (at your option) any later
+# version.
+#
+# This program is distributed in the hope that it will be useful, but WITHOUT
+# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along with
+# this program; if not, write to the Free Software Foundation, Inc.,
+# 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
+"""This module contains helpers for DBMS specific (advanced or non standard)
+functionalities
+
+Helpers are provided for postgresql, mysql and sqlite.
+"""
+
+from logilab.common.deprecation import obsolete
+
+class BadQuery(Exception): pass
+class UnsupportedFunction(BadQuery): pass
+
+
+class metafunc(type):
+ def __new__(mcs, name, bases, dict):
+ dict['name'] = name
+ return type.__new__(mcs, name, bases, dict)
+
+
+class FunctionDescr(object):
+ __metaclass__ = metafunc
+
+ rtype = None # None <-> returned type should be the same as the first argument
+ aggregat = False
+ minargs = 1
+ maxargs = 1
+
+ def __init__(self, name=None, rtype=rtype, aggregat=aggregat):
+ if name is not None:
+ name = name.upper()
+ self.name = name
+ self.rtype = rtype
+ self.aggregat = aggregat
+
+ @classmethod
+ def check_nbargs(cls, nbargs):
+ if cls.minargs is not None and \
+ nbargs < cls.minargs:
+ raise BadQuery('not enough argument for function %s' % cls.name)
+ if cls.maxargs is not None and \
+ nbargs < cls.maxargs:
+ raise BadQuery('too many arguments for function %s' % cls.name)
+
+
+class AggrFunctionDescr(FunctionDescr):
+ aggregat = True
+ rtype = None
+
+class MAX(AggrFunctionDescr): pass
+class MIN(AggrFunctionDescr): pass
+class SUM(AggrFunctionDescr): pass
+class COUNT(AggrFunctionDescr):
+ rtype = 'Int'
+class AVG(AggrFunctionDescr):
+ rtype = 'Float'
+
+class UPPER(FunctionDescr):
+ rtype = 'String'
+class LOWER(FunctionDescr):
+ rtype = 'String'
+class IN(FunctionDescr):
+ """this is actually a 'keyword' function..."""
+ maxargs = None
+class LENGTH(FunctionDescr):
+ rtype = 'Int'
+
+class _GenericAdvFuncHelper:
+ """Generic helper, trying to provide generic way to implement
+ specific functionnalities from others DBMS
+
+ An exception is raised when the functionality is not emulatable
+ """
+ # DBMS resources descriptors and accessors
+
+ users_support = True
+ groups_support = True
+ ilike_support = True
+
+ FUNCTIONS = {
+ # aggregat functions
+ 'MIN': MIN, 'MAX': MAX,
+ 'SUM': SUM,
+ 'COUNT':COUNT,
+ 'AVG': AVG,
+ # transformation functions
+ 'UPPER': UPPER, 'LOWER': LOWER,
+ 'LENGTH': LENGTH,
+ # keyword function
+ 'IN': IN
+ }
+
+ TYPE_MAPPING = {
+ 'String' : 'text',
+ 'Int' : 'integer',
+ 'Float' : 'float',
+ 'Boolean' : 'boolean',
+ 'Date' : 'date',
+ 'Time' : 'time',
+ 'Datetime' : 'timestamp',
+ 'Interval' : 'interval',
+ 'Password' : 'bytea',
+ 'Bytes' : 'bytea',
+ # FIXME: still there for use from erudi, should be moved out
+ 'COUNT' : 'integer',
+ 'MIN' : 'integer',
+ 'MAX' : 'integer',
+ 'SUM' : 'integer',
+ 'LOWER' : 'text',
+ 'UPPER' : 'text',
+ }
+
+ @classmethod
+ def function_description(cls, funcname):
+ """return the description (`FunctionDescription`) for a RQL function"""
+ try:
+ return cls.FUNCTIONS[funcname.upper()]
+ except KeyError:
+ raise UnsupportedFunction(funcname)
+
+ # @obsolete('use users_support attribute')
+ def support_users(self):
+ """return True if the DBMS support users (this is usually
+ not true for in memory DBMS)
+ """
+ return self.users_support
+ support_user = obsolete('use users_support attribute')(support_users)
+
+ # @obsolete('use groups_support attribute')
+ def support_groups(self):
+ """return True if the DBMS support groups"""
+ return self.groups_support
+ support_user = obsolete('use groups_support attribute')(support_groups)
+
+ def system_database(self):
+ """return the system database for the given driver"""
+ raise NotImplementedError('not supported by this DBMS')
+
+ def backup_command(self, dbname, dbhost, dbuser, dbpassword, backupfile,
+ keepownership=True):
+ """return a command to backup the given database"""
+ raise NotImplementedError('not supported by this DBMS')
+
+ def restore_commands(self, dbname, dbhost, dbuser, backupfile,
+ encoding='utf-8', keepownership=True, drop=True):
+ """return a list of commands to restore a backup the given database"""
+ raise NotImplementedError('not supported by this DBMS')
+
+ # helpers to standardize SQL according to the database
+
+ def sql_current_date(self):
+ return 'CURRENT_DATE'
+
+ def sql_current_time(self):
+ return 'CURRENT_TIME'
+
+ def sql_current_timestamp(self):
+ return 'CURRENT_TIMESTAMP'
+
+ def sql_create_sequence(self, seq_name):
+ return '''CREATE TABLE %s (last INTEGER);
+INSERT INTO %s VALUES (0);''' % (seq_name, seq_name)
+
+ def sql_drop_sequence(self, seq_name):
+ return 'DROP TABLE %s;' % seq_name
+
+ def sqls_increment_sequence(self, seq_name):
+ return ('UPDATE %s SET last=last+1;' % seq_name,
+ 'SELECT last FROM %s;' % seq_name)
+
+ def sql_temporary_table(self, table_name, table_schema,
+ drop_on_commit=True):
+ return "CREATE TEMPORARY TABLE %s (%s);" % (table_name,
+ table_schema)
+
+ def sql_drop_unique_constraint(self, table, column):
+ # XXX postgres specific ?
+ return 'ALTER TABLE %s DROP CONSTRAINT %s_%s_key' % (
+ table, table, column)
+
+ def boolean_value(self, value):
+ if value:
+ return 'TRUE'
+ else:
+ return 'FALSE'
+
+ def increment_sequence(self, cursor, seq_name):
+ for sql in self.sqls_increment_sequence(seq_name):
+ cursor.execute(sql)
+ return cursor.fetchone()[0]
+
+ def create_user(self, cursor, user, password):
+ """create a new database user"""
+ if not self.users_support:
+ raise NotImplementedError('not supported by this DBMS')
+ cursor.execute("CREATE USER %(user)s "
+ "WITH PASSWORD '%(password)s'" % locals())
+
+ def user_exists(self, cursor, username):
+ """return True if a user with the given username exists"""
+ return username in self.list_users(cursor)
+
+ def list_users(self, cursor):
+ """return the list of existing database users"""
+ raise NotImplementedError('not supported by this DBMS')
+
+ def create_database(self, cursor, dbname, owner=None, encoding='utf-8'):
+ """create a new database"""
+ raise NotImplementedError('not supported by this DBMS')
+
+ def list_databases(self):
+ """return the list of existing databases"""
+ raise NotImplementedError('not supported by this DBMS')
+
+ def list_tables(self, cursor):
+ """return the list of tables of a database"""
+ raise NotImplementedError('not supported by this DBMS')
+
+
+
+def pgdbcmd(cmd, dbhost, dbuser):
+ cmd = [cmd]
+ if dbhost:
+ cmd.append('--host=%s' % dbhost)
+ if dbuser:
+ cmd.append('--username=%s' % dbuser)
+ return cmd
+
+
+class _PGAdvFuncHelper(_GenericAdvFuncHelper):
+ """Postgres helper, taking advantage of postgres SEQUENCE support
+ """
+ # modifiable but should not be shared
+ FUNCTIONS = _GenericAdvFuncHelper.FUNCTIONS.copy()
+
+ def system_database(self):
+ """return the system database for the given driver"""
+ return 'template1'
+
+ def backup_command(self, dbname, dbhost, dbuser, backupfile,
+ keepownership=True):
+ """return a command to backup the given database"""
+ cmd = ['pg_dump -Fc']
+ if dbhost:
+ cmd.append('--host=%s' % dbhost)
+ if dbuser:
+ cmd.append('--username=%s' % dbuser)
+ if not keepownership:
+ cmd.append('--no-owner')
+ cmd.append('--file=%s' % backupfile)
+ cmd.append(dbname)
+ return ' '.join(cmd)
+
+ def restore_commands(self, dbname, dbhost, dbuser, backupfile,
+ encoding='utf-8', keepownership=True, drop=True):
+ """return a list of commands to restore a backup the given database"""
+ cmds = []
+ if drop:
+ cmd = pgdbcmd('dropdb', dbhost, dbuser)
+ cmd.append(dbname)
+ cmds.append(' '.join(cmd))
+ cmd = pgdbcmd('createdb -T template0 -E %s' % encoding, dbhost, dbuser)
+ cmd.append(dbname)
+ cmds.append(' '.join(cmd))
+ cmd = pgdbcmd('pg_restore -Fc', dbhost, dbuser)
+ cmd.append('--dbname %s' % dbname)
+ if not keepownership:
+ cmd.append('--no-owner')
+ cmd.append(backupfile)
+ cmds.append(' '.join(cmd))
+ return cmds
+
+ def sql_create_sequence(self, seq_name):
+ return 'CREATE SEQUENCE %s;' % seq_name
+
+ def sql_drop_sequence(self, seq_name):
+ return 'DROP SEQUENCE %s;' % seq_name
+
+ def sqls_increment_sequence(self, seq_name):
+ return ("SELECT nextval('%s');" % seq_name,)
+
+ def sql_temporary_table(self, table_name, table_schema,
+ drop_on_commit=True):
+ if not drop_on_commit:
+ return "CREATE TEMPORARY TABLE %s (%s);" % (table_name,
+ table_schema)
+ return "CREATE TEMPORARY TABLE %s (%s) ON COMMIT DROP;" % (table_name,
+ table_schema)
+
+ def list_users(self, cursor, username=None):
+ """return the list of existing database users"""
+ if username:
+ warn('username argument is deprecated, use user_exists method',
+ DeprecationWarning, stacklevel=2)
+ return self.user_exists(cursor, username)
+ cursor.execute("SELECT usename FROM pg_user")
+ return [r[0] for r in cursor.fetchall()]
+
+ def create_database(self, cursor, dbname, owner=None, encoding='utf-8'):
+ """create a new database"""
+ sql = "CREATE DATABASE %(dbname)s"
+ if owner:
+ sql += " WITH OWNER=%(owner)s"
+ if encoding:
+ sql += " ENCODING='%(encoding)s'"
+ cursor.execute(sql % locals())
+
+ def list_databases(self, cursor):
+ """return the list of existing databases"""
+ cursor.execute('SELECT datname FROM pg_database')
+ return [r[0] for r in cursor.fetchall()]
+
+ def list_tables(self, cursor):
+ """return the list of tables of a database"""
+ cursor.execute("SELECT tablename FROM pg_tables")
+ return cursor.fetchall()
+
+ def create_language(self, cursor, extlang):
+ """postgres specific method to install a procedural language on a database"""
+ # make sure plpythonu is not directly in template1
+ cursor.execute("SELECT * FROM pg_language WHERE lanname='%s';" % extlang)
+ if cursor.fetchall():
+ print '%s language already installed' % extlang
+ else:
+ cursor.execute('CREATE LANGUAGE %s' % extlang)
+ print '%s language installed' % extlang
+
+
+class _SqliteAdvFuncHelper(_GenericAdvFuncHelper):
+ """Generic helper, trying to provide generic way to implement
+ specific functionnalities from others DBMS
+
+ An exception is raised when the functionality is not emulatable
+ """
+ # modifiable but should not be shared
+ FUNCTIONS = _GenericAdvFuncHelper.FUNCTIONS.copy()
+ users_support = groups_support = False
+ ilike_support = False
+
+
+class _MyAdvFuncHelper(_GenericAdvFuncHelper):
+ """Postgres helper, taking advantage of postgres SEQUENCE support
+ """
+ # modifiable but should not be shared
+ FUNCTIONS = _GenericAdvFuncHelper.FUNCTIONS.copy()
+ TYPE_MAPPING = _GenericAdvFuncHelper.TYPE_MAPPING.copy()
+ TYPE_MAPPING['Password'] = 'tinyblob'
+ TYPE_MAPPING['String'] = 'mediumtext'
+ TYPE_MAPPING['Bytes'] = 'longblob'
+
+ def system_database(self):
+ """return the system database for the given driver"""
+ return ''
+
+ def backup_command(self, dbname, dbhost, dbuser, backupfile,
+ keepownership=True):
+ """return a command to backup the given database"""
+ # XXX compress
+ return 'mysqldump -h %s -u %s -r %s %s' % (dbhost, dbuser, backupfile, dbname)
+
+ def restore_commands(self, dbname, dbhost, dbuser, backupfile,
+ encoding='utf-8', keepownership=True, drop=True):
+ """return a list of commands to restore a backup the given database"""
+ cmds = []
+ if drop:
+ cmd = 'echo "DROP DATABASE %s;" | mysql -h %s -u %s' % (dbname, dbhost, dbuser)
+ cmds.append(cmd)
+ cmd = 'echo "%s;" | mysql -h %s -u %s' % (self.sql_create_database(dbname, encoding),
+ dbhost, dbuser)
+ cmds.append(cmd)
+ cmd = pgdbcmd('mysql -h %s -u %s < %s' % (dbname, dbhost, dbuser, backupfile))
+ cmds.append(cmd)
+ return cmds
+
+ def sql_temporary_table(self, table_name, table_schema,
+ drop_on_commit=True):
+ if not drop_on_commit:
+ return "CREATE TEMPORARY TABLE %s (%s);" % (table_name,
+ table_schema)
+ return "CREATE TEMPORARY TABLE %s (%s) ON COMMIT DROP;" % (table_name,
+ table_schema)
+
+
+ def boolean_value(self, value):
+ if value:
+ return True
+ else:
+ return False
+
+ def list_users(self, cursor):
+ """return the list of existing database users"""
+ # Host, Password
+ cursor.execute("SELECT User FROM mysql.user")
+ return [r[0] for r in cursor.fetchall()]
+
+ def list_databases(self, cursor):
+ """return the list of existing databases"""
+ cursor.execute('SHOW DATABASES')
+ return [r[0] for r in cursor.fetchall()]
+
+ def sql_create_database(self, dbname, encoding='utf-8'):
+ sql = "CREATE DATABASE %(dbname)s"
+ if encoding:
+ sql += " CHARACTER SET %(encoding)s"
+ return sql % locals()
+
+ def create_database(self, cursor, dbname, owner=None, encoding='utf-8'):
+ """create a new database"""
+ cursor.execute(self.sql_create_database(dbname, encoding))
+ if owner:
+ cursor.execute('GRANT ALL ON `%s`.* to %s' % (dbname, owner))
+
+ def list_tables(self, cursor):
+ """return the list of tables of a database"""
+ cursor.execute("SHOW TABLES")
+ return [r[0] for r in cursor.fetchall()]
+
+
+
+ADV_FUNC_HELPER_DIRECTORY = {'postgres': _PGAdvFuncHelper(),
+ 'sqlite': _SqliteAdvFuncHelper(),
+ 'mysql': _MyAdvFuncHelper(),
+ None: _GenericAdvFuncHelper()}
+
+
+
+def register_function(driver, funcdef):
+ if isinstance(funcdef, basestring) :
+ funcdef = FunctionDescr(funcdef.upper())
+ assert not funcdef.name in FUNCTIONS, \
+ '%s is already registered' % funcdef.name
+ ADV_FUNC_HELPER_DIRECTORY[driver].FUNCTIONS[funcdef.name] = funcdef
+
+
+def get_adv_func_helper(driver):
+ """returns an advanced function helper for the given driver"""
+ return ADV_FUNC_HELPER_DIRECTORY.get(driver,
+ ADV_FUNC_HELPER_DIRECTORY[None])