diff options
author | Sylvain <syt@logilab.fr> | 2007-10-25 18:05:56 +0200 |
---|---|---|
committer | Sylvain <syt@logilab.fr> | 2007-10-25 18:05:56 +0200 |
commit | 4c9f3d2c4ceaa077b9e9f3b0e1afe765ab890bf7 (patch) | |
tree | 7fc23606fbcb19f1b591a7a36d59cda994623c04 /adbh.py | |
parent | 5b99e7b0b2c3caba7c3dfb36969df11d57669c29 (diff) | |
download | logilab-common-4c9f3d2c4ceaa077b9e9f3b0e1afe765ab890bf7.tar.gz |
more db extensions, db helpers in a separated module
Diffstat (limited to 'adbh.py')
-rw-r--r-- | adbh.py | 451 |
1 files changed, 451 insertions, 0 deletions
@@ -0,0 +1,451 @@ +# Copyright (c) 2002-2007 LOGILAB S.A. (Paris, FRANCE). +# http://www.logilab.fr/ -- mailto:contact@logilab.fr +# +# This program is free software; you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the Free Software +# Foundation; either version 2 of the License, or (at your option) any later +# version. +# +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along with +# this program; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. +"""This module contains helpers for DBMS specific (advanced or non standard) +functionalities + +Helpers are provided for postgresql, mysql and sqlite. +""" + +from logilab.common.deprecation import obsolete + +class BadQuery(Exception): pass +class UnsupportedFunction(BadQuery): pass + + +class metafunc(type): + def __new__(mcs, name, bases, dict): + dict['name'] = name + return type.__new__(mcs, name, bases, dict) + + +class FunctionDescr(object): + __metaclass__ = metafunc + + rtype = None # None <-> returned type should be the same as the first argument + aggregat = False + minargs = 1 + maxargs = 1 + + def __init__(self, name=None, rtype=rtype, aggregat=aggregat): + if name is not None: + name = name.upper() + self.name = name + self.rtype = rtype + self.aggregat = aggregat + + @classmethod + def check_nbargs(cls, nbargs): + if cls.minargs is not None and \ + nbargs < cls.minargs: + raise BadQuery('not enough argument for function %s' % cls.name) + if cls.maxargs is not None and \ + nbargs < cls.maxargs: + raise BadQuery('too many arguments for function %s' % cls.name) + + +class AggrFunctionDescr(FunctionDescr): + aggregat = True + rtype = None + +class MAX(AggrFunctionDescr): pass +class MIN(AggrFunctionDescr): pass +class SUM(AggrFunctionDescr): pass +class COUNT(AggrFunctionDescr): + rtype = 'Int' +class AVG(AggrFunctionDescr): + rtype = 'Float' + +class UPPER(FunctionDescr): + rtype = 'String' +class LOWER(FunctionDescr): + rtype = 'String' +class IN(FunctionDescr): + """this is actually a 'keyword' function...""" + maxargs = None +class LENGTH(FunctionDescr): + rtype = 'Int' + +class _GenericAdvFuncHelper: + """Generic helper, trying to provide generic way to implement + specific functionnalities from others DBMS + + An exception is raised when the functionality is not emulatable + """ + # DBMS resources descriptors and accessors + + users_support = True + groups_support = True + ilike_support = True + + FUNCTIONS = { + # aggregat functions + 'MIN': MIN, 'MAX': MAX, + 'SUM': SUM, + 'COUNT':COUNT, + 'AVG': AVG, + # transformation functions + 'UPPER': UPPER, 'LOWER': LOWER, + 'LENGTH': LENGTH, + # keyword function + 'IN': IN + } + + TYPE_MAPPING = { + 'String' : 'text', + 'Int' : 'integer', + 'Float' : 'float', + 'Boolean' : 'boolean', + 'Date' : 'date', + 'Time' : 'time', + 'Datetime' : 'timestamp', + 'Interval' : 'interval', + 'Password' : 'bytea', + 'Bytes' : 'bytea', + # FIXME: still there for use from erudi, should be moved out + 'COUNT' : 'integer', + 'MIN' : 'integer', + 'MAX' : 'integer', + 'SUM' : 'integer', + 'LOWER' : 'text', + 'UPPER' : 'text', + } + + @classmethod + def function_description(cls, funcname): + """return the description (`FunctionDescription`) for a RQL function""" + try: + return cls.FUNCTIONS[funcname.upper()] + except KeyError: + raise UnsupportedFunction(funcname) + + # @obsolete('use users_support attribute') + def support_users(self): + """return True if the DBMS support users (this is usually + not true for in memory DBMS) + """ + return self.users_support + support_user = obsolete('use users_support attribute')(support_users) + + # @obsolete('use groups_support attribute') + def support_groups(self): + """return True if the DBMS support groups""" + return self.groups_support + support_user = obsolete('use groups_support attribute')(support_groups) + + def system_database(self): + """return the system database for the given driver""" + raise NotImplementedError('not supported by this DBMS') + + def backup_command(self, dbname, dbhost, dbuser, dbpassword, backupfile, + keepownership=True): + """return a command to backup the given database""" + raise NotImplementedError('not supported by this DBMS') + + def restore_commands(self, dbname, dbhost, dbuser, backupfile, + encoding='utf-8', keepownership=True, drop=True): + """return a list of commands to restore a backup the given database""" + raise NotImplementedError('not supported by this DBMS') + + # helpers to standardize SQL according to the database + + def sql_current_date(self): + return 'CURRENT_DATE' + + def sql_current_time(self): + return 'CURRENT_TIME' + + def sql_current_timestamp(self): + return 'CURRENT_TIMESTAMP' + + def sql_create_sequence(self, seq_name): + return '''CREATE TABLE %s (last INTEGER); +INSERT INTO %s VALUES (0);''' % (seq_name, seq_name) + + def sql_drop_sequence(self, seq_name): + return 'DROP TABLE %s;' % seq_name + + def sqls_increment_sequence(self, seq_name): + return ('UPDATE %s SET last=last+1;' % seq_name, + 'SELECT last FROM %s;' % seq_name) + + def sql_temporary_table(self, table_name, table_schema, + drop_on_commit=True): + return "CREATE TEMPORARY TABLE %s (%s);" % (table_name, + table_schema) + + def sql_drop_unique_constraint(self, table, column): + # XXX postgres specific ? + return 'ALTER TABLE %s DROP CONSTRAINT %s_%s_key' % ( + table, table, column) + + def boolean_value(self, value): + if value: + return 'TRUE' + else: + return 'FALSE' + + def increment_sequence(self, cursor, seq_name): + for sql in self.sqls_increment_sequence(seq_name): + cursor.execute(sql) + return cursor.fetchone()[0] + + def create_user(self, cursor, user, password): + """create a new database user""" + if not self.users_support: + raise NotImplementedError('not supported by this DBMS') + cursor.execute("CREATE USER %(user)s " + "WITH PASSWORD '%(password)s'" % locals()) + + def user_exists(self, cursor, username): + """return True if a user with the given username exists""" + return username in self.list_users(cursor) + + def list_users(self, cursor): + """return the list of existing database users""" + raise NotImplementedError('not supported by this DBMS') + + def create_database(self, cursor, dbname, owner=None, encoding='utf-8'): + """create a new database""" + raise NotImplementedError('not supported by this DBMS') + + def list_databases(self): + """return the list of existing databases""" + raise NotImplementedError('not supported by this DBMS') + + def list_tables(self, cursor): + """return the list of tables of a database""" + raise NotImplementedError('not supported by this DBMS') + + + +def pgdbcmd(cmd, dbhost, dbuser): + cmd = [cmd] + if dbhost: + cmd.append('--host=%s' % dbhost) + if dbuser: + cmd.append('--username=%s' % dbuser) + return cmd + + +class _PGAdvFuncHelper(_GenericAdvFuncHelper): + """Postgres helper, taking advantage of postgres SEQUENCE support + """ + # modifiable but should not be shared + FUNCTIONS = _GenericAdvFuncHelper.FUNCTIONS.copy() + + def system_database(self): + """return the system database for the given driver""" + return 'template1' + + def backup_command(self, dbname, dbhost, dbuser, backupfile, + keepownership=True): + """return a command to backup the given database""" + cmd = ['pg_dump -Fc'] + if dbhost: + cmd.append('--host=%s' % dbhost) + if dbuser: + cmd.append('--username=%s' % dbuser) + if not keepownership: + cmd.append('--no-owner') + cmd.append('--file=%s' % backupfile) + cmd.append(dbname) + return ' '.join(cmd) + + def restore_commands(self, dbname, dbhost, dbuser, backupfile, + encoding='utf-8', keepownership=True, drop=True): + """return a list of commands to restore a backup the given database""" + cmds = [] + if drop: + cmd = pgdbcmd('dropdb', dbhost, dbuser) + cmd.append(dbname) + cmds.append(' '.join(cmd)) + cmd = pgdbcmd('createdb -T template0 -E %s' % encoding, dbhost, dbuser) + cmd.append(dbname) + cmds.append(' '.join(cmd)) + cmd = pgdbcmd('pg_restore -Fc', dbhost, dbuser) + cmd.append('--dbname %s' % dbname) + if not keepownership: + cmd.append('--no-owner') + cmd.append(backupfile) + cmds.append(' '.join(cmd)) + return cmds + + def sql_create_sequence(self, seq_name): + return 'CREATE SEQUENCE %s;' % seq_name + + def sql_drop_sequence(self, seq_name): + return 'DROP SEQUENCE %s;' % seq_name + + def sqls_increment_sequence(self, seq_name): + return ("SELECT nextval('%s');" % seq_name,) + + def sql_temporary_table(self, table_name, table_schema, + drop_on_commit=True): + if not drop_on_commit: + return "CREATE TEMPORARY TABLE %s (%s);" % (table_name, + table_schema) + return "CREATE TEMPORARY TABLE %s (%s) ON COMMIT DROP;" % (table_name, + table_schema) + + def list_users(self, cursor, username=None): + """return the list of existing database users""" + if username: + warn('username argument is deprecated, use user_exists method', + DeprecationWarning, stacklevel=2) + return self.user_exists(cursor, username) + cursor.execute("SELECT usename FROM pg_user") + return [r[0] for r in cursor.fetchall()] + + def create_database(self, cursor, dbname, owner=None, encoding='utf-8'): + """create a new database""" + sql = "CREATE DATABASE %(dbname)s" + if owner: + sql += " WITH OWNER=%(owner)s" + if encoding: + sql += " ENCODING='%(encoding)s'" + cursor.execute(sql % locals()) + + def list_databases(self, cursor): + """return the list of existing databases""" + cursor.execute('SELECT datname FROM pg_database') + return [r[0] for r in cursor.fetchall()] + + def list_tables(self, cursor): + """return the list of tables of a database""" + cursor.execute("SELECT tablename FROM pg_tables") + return cursor.fetchall() + + def create_language(self, cursor, extlang): + """postgres specific method to install a procedural language on a database""" + # make sure plpythonu is not directly in template1 + cursor.execute("SELECT * FROM pg_language WHERE lanname='%s';" % extlang) + if cursor.fetchall(): + print '%s language already installed' % extlang + else: + cursor.execute('CREATE LANGUAGE %s' % extlang) + print '%s language installed' % extlang + + +class _SqliteAdvFuncHelper(_GenericAdvFuncHelper): + """Generic helper, trying to provide generic way to implement + specific functionnalities from others DBMS + + An exception is raised when the functionality is not emulatable + """ + # modifiable but should not be shared + FUNCTIONS = _GenericAdvFuncHelper.FUNCTIONS.copy() + users_support = groups_support = False + ilike_support = False + + +class _MyAdvFuncHelper(_GenericAdvFuncHelper): + """Postgres helper, taking advantage of postgres SEQUENCE support + """ + # modifiable but should not be shared + FUNCTIONS = _GenericAdvFuncHelper.FUNCTIONS.copy() + TYPE_MAPPING = _GenericAdvFuncHelper.TYPE_MAPPING.copy() + TYPE_MAPPING['Password'] = 'tinyblob' + TYPE_MAPPING['String'] = 'mediumtext' + TYPE_MAPPING['Bytes'] = 'longblob' + + def system_database(self): + """return the system database for the given driver""" + return '' + + def backup_command(self, dbname, dbhost, dbuser, backupfile, + keepownership=True): + """return a command to backup the given database""" + # XXX compress + return 'mysqldump -h %s -u %s -r %s %s' % (dbhost, dbuser, backupfile, dbname) + + def restore_commands(self, dbname, dbhost, dbuser, backupfile, + encoding='utf-8', keepownership=True, drop=True): + """return a list of commands to restore a backup the given database""" + cmds = [] + if drop: + cmd = 'echo "DROP DATABASE %s;" | mysql -h %s -u %s' % (dbname, dbhost, dbuser) + cmds.append(cmd) + cmd = 'echo "%s;" | mysql -h %s -u %s' % (self.sql_create_database(dbname, encoding), + dbhost, dbuser) + cmds.append(cmd) + cmd = pgdbcmd('mysql -h %s -u %s < %s' % (dbname, dbhost, dbuser, backupfile)) + cmds.append(cmd) + return cmds + + def sql_temporary_table(self, table_name, table_schema, + drop_on_commit=True): + if not drop_on_commit: + return "CREATE TEMPORARY TABLE %s (%s);" % (table_name, + table_schema) + return "CREATE TEMPORARY TABLE %s (%s) ON COMMIT DROP;" % (table_name, + table_schema) + + + def boolean_value(self, value): + if value: + return True + else: + return False + + def list_users(self, cursor): + """return the list of existing database users""" + # Host, Password + cursor.execute("SELECT User FROM mysql.user") + return [r[0] for r in cursor.fetchall()] + + def list_databases(self, cursor): + """return the list of existing databases""" + cursor.execute('SHOW DATABASES') + return [r[0] for r in cursor.fetchall()] + + def sql_create_database(self, dbname, encoding='utf-8'): + sql = "CREATE DATABASE %(dbname)s" + if encoding: + sql += " CHARACTER SET %(encoding)s" + return sql % locals() + + def create_database(self, cursor, dbname, owner=None, encoding='utf-8'): + """create a new database""" + cursor.execute(self.sql_create_database(dbname, encoding)) + if owner: + cursor.execute('GRANT ALL ON `%s`.* to %s' % (dbname, owner)) + + def list_tables(self, cursor): + """return the list of tables of a database""" + cursor.execute("SHOW TABLES") + return [r[0] for r in cursor.fetchall()] + + + +ADV_FUNC_HELPER_DIRECTORY = {'postgres': _PGAdvFuncHelper(), + 'sqlite': _SqliteAdvFuncHelper(), + 'mysql': _MyAdvFuncHelper(), + None: _GenericAdvFuncHelper()} + + + +def register_function(driver, funcdef): + if isinstance(funcdef, basestring) : + funcdef = FunctionDescr(funcdef.upper()) + assert not funcdef.name in FUNCTIONS, \ + '%s is already registered' % funcdef.name + ADV_FUNC_HELPER_DIRECTORY[driver].FUNCTIONS[funcdef.name] = funcdef + + +def get_adv_func_helper(driver): + """returns an advanced function helper for the given driver""" + return ADV_FUNC_HELPER_DIRECTORY.get(driver, + ADV_FUNC_HELPER_DIRECTORY[None]) |