# postgres.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php import datetime, sys, StringIO, string, types, re import sqlalchemy.util as util import sqlalchemy.sql as sql import sqlalchemy.engine as engine import sqlalchemy.engine.default as default import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions import information_schema as ischema from sqlalchemy import * import re try: import mx.DateTime.DateTime as mxDateTime except: mxDateTime = None try: import psycopg2 as psycopg #import psycopg2.psycopg1 as psycopg except: try: import psycopg except: psycopg = None class PGNumeric(sqltypes.Numeric): def get_col_spec(self): return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} class PGFloat(sqltypes.Float): def get_col_spec(self): return "FLOAT(%(precision)s)" % {'precision': self.precision} class PGInteger(sqltypes.Integer): def get_col_spec(self): return "INTEGER" class PGSmallInteger(sqltypes.Smallinteger): def get_col_spec(self): return "SMALLINT" class PG2DateTime(sqltypes.DateTime): def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" class PG1DateTime(sqltypes.DateTime): def convert_bind_param(self, value, dialect): if value is not None: if isinstance(value, datetime.datetime): seconds = float(str(value.second) + "." + str(value.microsecond)) mx_datetime = mxDateTime(value.year, value.month, value.day, value.hour, value.minute, seconds) return psycopg.TimestampFromMx(mx_datetime) return psycopg.TimestampFromMx(value) else: return None def convert_result_value(self, value, dialect): if value is None: return None second_parts = str(value.second).split(".") seconds = int(second_parts[0]) microseconds = int(second_parts[1]) return datetime.datetime(value.year, value.month, value.day, value.hour, value.minute, seconds, microseconds) def get_col_spec(self): return "TIMESTAMP " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" class PG2Date(sqltypes.Date): def get_col_spec(self): return "DATE" class PG1Date(sqltypes.Date): def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: return psycopg.DateFromMx(value) else: return None def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value def get_col_spec(self): return "DATE" class PG2Time(sqltypes.Time): def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" class PG1Time(sqltypes.Time): def convert_bind_param(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime # this one doesnt seem to work with the "emulation" mode if value is not None: return psycopg.TimeFromMx(value) else: return None def convert_result_value(self, value, dialect): # TODO: perform appropriate postgres1 conversion between Python DateTime/MXDateTime return value def get_col_spec(self): return "TIME " + (self.timezone and "WITH" or "WITHOUT") + " TIME ZONE" class PGText(sqltypes.TEXT): def get_col_spec(self): return "TEXT" class PGString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} class PGChar(sqltypes.CHAR): def get_col_spec(self): return "CHAR(%(length)s)" % {'length' : self.length} class PGBinary(sqltypes.Binary): def get_col_spec(self): return "BYTEA" class PGBoolean(sqltypes.Boolean): def get_col_spec(self): return "BOOLEAN" pg2_colspecs = { sqltypes.Integer : PGInteger, sqltypes.Smallinteger : PGSmallInteger, sqltypes.Numeric : PGNumeric, sqltypes.Float : PGFloat, sqltypes.DateTime : PG2DateTime, sqltypes.Date : PG2Date, sqltypes.Time : PG2Time, sqltypes.String : PGString, sqltypes.Binary : PGBinary, sqltypes.Boolean : PGBoolean, sqltypes.TEXT : PGText, sqltypes.CHAR: PGChar, } pg1_colspecs = pg2_colspecs.copy() pg1_colspecs.update({ sqltypes.DateTime : PG1DateTime, sqltypes.Date : PG1Date, sqltypes.Time : PG1Time }) pg2_ischema_names = { 'integer' : PGInteger, 'bigint' : PGInteger, 'smallint' : PGSmallInteger, 'character varying' : PGString, 'character' : PGChar, 'text' : PGText, 'numeric' : PGNumeric, 'float' : PGFloat, 'real' : PGFloat, 'double precision' : PGFloat, 'timestamp' : PG2DateTime, 'timestamp with time zone' : PG2DateTime, 'timestamp without time zone' : PG2DateTime, 'time with time zone' : PG2Time, 'time without time zone' : PG2Time, 'date' : PG2Date, 'time': PG2Time, 'bytea' : PGBinary, 'boolean' : PGBoolean, } pg1_ischema_names = pg2_ischema_names.copy() pg1_ischema_names.update({ 'timestamp with time zone' : PG1DateTime, 'timestamp without time zone' : PG1DateTime, 'date' : PG1Date, 'time' : PG1Time }) def engine(opts, **params): return PGSQLEngine(opts, **params) def descriptor(): return {'name':'postgres', 'description':'PostGres', 'arguments':[ ('username',"Database Username",None), ('password',"Database Password",None), ('database',"Database Name",None), ('host',"Hostname", None), ]} class PGExecutionContext(default.DefaultExecutionContext): def post_exec(self, engine, proxy, compiled, parameters, **kwargs): if getattr(compiled, "isinsert", False) and self.last_inserted_ids is None: if not engine.dialect.use_oids: pass # will raise invalid error when they go to get them else: table = compiled.statement.table cursor = proxy() if cursor.lastrowid is not None and table is not None and len(table.primary_key): s = sql.select(table.primary_key, table.oid_column == cursor.lastrowid) c = s.compile(engine=engine) cursor = proxy(str(c), c.get_params()) row = cursor.fetchone() self._last_inserted_ids = [v for v in row] class PGDialect(ansisql.ANSIDialect): def __init__(self, module=None, use_oids=False, use_information_schema=False, **params): self.use_oids = use_oids if module is None: #if psycopg is None: # raise exceptions.ArgumentError("Couldnt locate psycopg1 or psycopg2: specify postgres module argument") self.module = psycopg else: self.module = module # figure psycopg version 1 or 2 try: if self.module.__version__.startswith('2'): self.version = 2 else: self.version = 1 except: self.version = 1 ansisql.ANSIDialect.__init__(self, **params) self.use_information_schema = use_information_schema # produce consistent paramstyle even if psycopg2 module not present if self.module is None: self.paramstyle = 'pyformat' def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.has_key('port'): if self.version == 2: opts['port'] = int(opts['port']) else: opts['port'] = str(opts['port']) opts.update(url.query) return ([], opts) def create_execution_context(self): return PGExecutionContext(self) def type_descriptor(self, typeobj): if self.version == 2: return sqltypes.adapt_type(typeobj, pg2_colspecs) else: return sqltypes.adapt_type(typeobj, pg1_colspecs) def compiler(self, statement, bindparams, **kwargs): return PGCompiler(self, statement, bindparams, **kwargs) def schemagenerator(self, *args, **kwargs): return PGSchemaGenerator(*args, **kwargs) def schemadropper(self, *args, **kwargs): return PGSchemaDropper(*args, **kwargs) def defaultrunner(self, engine, proxy): return PGDefaultRunner(engine, proxy) def preparer(self): return PGIdentifierPreparer(self) def get_default_schema_name(self, connection): if not hasattr(self, '_default_schema_name'): self._default_schema_name = connection.scalar("select current_schema()", None) return self._default_schema_name def last_inserted_ids(self): if self.context.last_inserted_ids is None: raise exceptions.InvalidRequestError("no INSERT executed, or cant use cursor.lastrowid without Postgres OIDs enabled") else: return self.context.last_inserted_ids def oid_column_name(self): if self.use_oids: return "oid" else: return None def do_executemany(self, c, statement, parameters, context=None): """we need accurate rowcounts for updates, inserts and deletes. psycopg2 is not nice enough to produce this correctly for an executemany, so we do our own executemany here.""" rowcount = 0 for param in parameters: c.execute(statement, param) rowcount += c.rowcount if context is not None: context._rowcount = rowcount def dbapi(self): return self.module def has_table(self, connection, table_name): # TODO: why are we case folding here ? cursor = connection.execute("""select relname from pg_class where lower(relname) = %(name)s""", {'name':table_name.lower()}) return bool( not not cursor.rowcount ) def has_sequence(self, connection, sequence_name): cursor = connection.execute('''SELECT relname FROM pg_class WHERE relkind = 'S' AND relnamespace IN ( SELECT oid FROM pg_namespace WHERE nspname NOT LIKE 'pg_%%' AND nspname != 'information_schema' AND relname = %(seqname)s);''', {'seqname': sequence_name}) return bool(not not cursor.rowcount) def reflecttable(self, connection, table): if self.version == 2: ischema_names = pg2_ischema_names else: ischema_names = pg1_ischema_names if self.use_information_schema: ischema.reflecttable(connection, table, ischema_names) else: preparer = self.identifier_preparer if table.schema is not None: current_schema = table.schema else: current_schema = connection.default_schema_name() ## information schema in pg suffers from too many permissions' restrictions ## let us find out at the pg way what is needed... SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum AND a.atthasdef) AS DEFAULT, a.attnotnull, a.attnum FROM pg_catalog.pg_attribute a WHERE a.attrelid = ( SELECT c.oid FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE (n.nspname = :schema OR pg_catalog.pg_table_is_visible(c.oid)) AND c.relname = :table_name AND (c.relkind = 'r' OR c.relkind = 'v') ) AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum """ s = text(SQL_COLS ) c = connection.execute(s, table_name=table.name, schema=current_schema) found_table = False while True: row = c.fetchone() if row is None: break found_table = True name = row['attname'] ## strip (30) from character varying(30) attype = re.search('([^\(]+)', row['format_type']).group(1) nullable = row['attnotnull'] == False try: charlen = re.search('\(([\d,]+)\)',row['format_type']).group(1) except: charlen = None numericprec = None numericscale = None default = row['default'] if attype == 'numeric': numericprec, numericscale = charlen.split(',') charlen = None if attype == 'double precision': numericprec, numericscale = (53, None) charlen = None if attype == 'integer': numericprec, numericscale = (32, 0) charlen = None args = [] for a in (charlen, numericprec, numericscale): if a is not None: args.append(int(a)) coltype = ischema_names[attype] coltype = coltype(*args) colargs= [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) table.append_item(schema.Column(name, coltype, nullable=nullable, *colargs)) if not found_table: raise exceptions.NoSuchTableError(table.name) # Primary keys PK_SQL = """ SELECT attname FROM pg_attribute WHERE attrelid = ( SELECT indexrelid FROM pg_index i, pg_class c, pg_namespace n WHERE n.nspname = :schema AND c.relname = :table_name AND c.oid = i.indrelid AND n.oid = c.relnamespace AND i.indisprimary = 't' ) ; """ t = text(PK_SQL) c = connection.execute(t, table_name=table.name, schema=current_schema) while True: row = c.fetchone() if row is None: break pk = row[0] table.c[pk]._set_primary_key() # Foreign keys FK_SQL = """ SELECT conname, pg_catalog.pg_get_constraintdef(oid, true) as condef FROM pg_catalog.pg_constraint r WHERE r.conrelid = ( SELECT c.oid FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relname = :table_name AND pg_catalog.pg_table_is_visible(c.oid)) AND r.contype = 'f' ORDER BY 1 """ t = text(FK_SQL) c = connection.execute(t, table_name=table.name) while True: row = c.fetchone() if row is None: break identifier = '(?:[a-z_][a-z0-9_$]+|"(?:[^"]|"")+")' identifier_group = '%s(?:, %s)*' % (identifier, identifier) identifiers = '(%s)(?:, (%s))*' % (identifier, identifier) f = re.compile(identifiers) # FOREIGN KEY (mail_user_id,"Mail_User_ID2") REFERENCES "mYschema".euro_user(user_id,"User_ID2") foreign_key_pattern = 'FOREIGN KEY \((%s)\) REFERENCES (?:(%s)\.)?(%s)\((%s)\)' % (identifier_group, identifier, identifier, identifier_group) p = re.compile(foreign_key_pattern) m = p.search(row['condef']) (constrained_columns, referred_schema, referred_table, referred_columns) = m.groups() constrained_columns = [preparer._unquote_identifier(x) for x in f.search(constrained_columns).groups() if x] if referred_schema: referred_schema = preparer._unquote_identifier(referred_schema) referred_table = preparer._unquote_identifier(referred_table) referred_columns = [preparer._unquote_identifier(x) for x in f.search(referred_columns).groups() if x] refspec = [] if referred_schema is not None: schema.Table(referred_table, table.metadata, autoload=True, schema=referred_schema, autoload_with=connection) for column in referred_columns: refspec.append(".".join([referred_schema, referred_table, column])) else: schema.Table(referred_table, table.metadata, autoload=True, autoload_with=connection) for column in referred_columns: refspec.append(".".join([referred_table, column])) table.append_item(ForeignKeyConstraint(constrained_columns, refspec, row['conname'])) class PGCompiler(ansisql.ANSICompiler): def visit_insert_column(self, column, parameters): # Postgres advises against OID usage and turns it off in 8.1, # effectively making cursor.lastrowid # useless, effectively making reliance upon SERIAL useless. # so all column primary key inserts must be explicitly present if column.primary_key: parameters[column.key] = None def limit_clause(self, select): text = "" if select.limit is not None: text += " \n LIMIT " + str(select.limit) if select.offset is not None: if select.limit is None: text += " \n LIMIT ALL" text += " OFFSET " + str(select.offset) return text def visit_select_precolumns(self, select): if select.distinct: if type(select.distinct) == bool: return "DISTINCT " if type(select.distinct) == list: dist_set = "DISTINCT ON (" for col in select.distinct: dist_set += self.strings[col] + ", " dist_set = dist_set[:-2] + ") " return dist_set return "DISTINCT ON (" + str(select.distinct) + ") " else: return "" def binary_operator_string(self, binary): if isinstance(binary.type, sqltypes.String) and binary.operator == '+': return '||' else: return ansisql.ANSICompiler.binary_operator_string(self, binary) class PGSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) if column.primary_key and len(column.foreign_keys)==0 and column.autoincrement and isinstance(column.type, sqltypes.Integer) and not isinstance(column.type, sqltypes.SmallInteger) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): colspec += " SERIAL" else: colspec += " " + column.type.engine_impl(self.engine).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default if not column.nullable: colspec += " NOT NULL" return colspec def visit_sequence(self, sequence): if not sequence.optional and not self.engine.dialect.has_sequence(self.connection, sequence.name): self.append("CREATE SEQUENCE %s" % self.preparer.format_sequence(sequence)) self.execute() class PGSchemaDropper(ansisql.ANSISchemaDropper): def visit_sequence(self, sequence): if not sequence.optional and self.engine.dialect.has_sequence(self.connection, sequence.name): self.append("DROP SEQUENCE %s" % sequence.name) self.execute() class PGDefaultRunner(ansisql.ANSIDefaultRunner): def get_column_default(self, column, isinsert=True): if column.primary_key: # passive defaults on primary keys have to be overridden if isinstance(column.default, schema.PassiveDefault): c = self.proxy("select %s" % column.default.arg) return c.fetchone()[0] elif (isinstance(column.type, sqltypes.Integer) and column.autoincrement) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)): sch = column.table.schema # TODO: this has to build into the Sequence object so we can get the quoting # logic from it if sch is not None: exc = "select nextval('\"%s.%s_%s_seq\"')" % (sch, column.table.name, column.name) else: exc = "select nextval('\"%s_%s_seq\"')" % (column.table.name, column.name) c = self.proxy(exc) return c.fetchone()[0] else: return ansisql.ANSIDefaultRunner.get_column_default(self, column) else: return ansisql.ANSIDefaultRunner.get_column_default(self, column) def visit_sequence(self, seq): if not seq.optional: c = self.proxy("select nextval('%s')" % seq.name) #TODO: self.dialect.preparer.format_sequence(seq)) return c.fetchone()[0] else: return None class PGIdentifierPreparer(ansisql.ANSIIdentifierPreparer): def _fold_identifier_case(self, value): return value.lower() def _unquote_identifier(self, value): if value[0] == self.initial_quote: value = value[1:-1].replace('""','"') return value dialect = PGDialect