diff options
Diffstat (limited to 'lib/sqlalchemy')
61 files changed, 6335 insertions, 5548 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 343a0cac8..e33451611 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -5,6 +5,11 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php import inspect +import sys + +import sqlalchemy.exc as exceptions +sys.modules['sqlalchemy.exceptions'] = exceptions + from sqlalchemy.types import \ BLOB, BOOLEAN, CHAR, CLOB, DATE, DATETIME, DECIMAL, FLOAT, INT, \ NCHAR, NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR, \ @@ -32,3 +37,5 @@ __all__ = [ name for name, obj in locals().items() if not (name.startswith('_') or inspect.ismodule(obj)) ] __version__ = 'svn' + +del inspect, sys diff --git a/lib/sqlalchemy/databases/access.py b/lib/sqlalchemy/databases/access.py index 38dba17a5..aa65985d4 100644 --- a/lib/sqlalchemy/databases/access.py +++ b/lib/sqlalchemy/databases/access.py @@ -5,7 +5,7 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, schema, types, exceptions, pool +from sqlalchemy import sql, schema, types, exc, pool from sqlalchemy.sql import compiler, expression from sqlalchemy.engine import default, base @@ -202,7 +202,7 @@ class AccessDialect(default.DefaultDialect): except pythoncom.com_error: pass else: - raise exceptions.InvalidRequestError("Can't find a DB engine. Check http://support.microsoft.com/kb/239114 for details.") + raise exc.InvalidRequestError("Can't find a DB engine. 
Check http://support.microsoft.com/kb/239114 for details.") import pyodbc as module return module @@ -236,7 +236,7 @@ class AccessDialect(default.DefaultDialect): c.execute(statement, parameters) self.context.rowcount = c.rowcount except Exception, e: - raise exceptions.DBAPIError.instance(statement, parameters, e) + raise exc.DBAPIError.instance(statement, parameters, e) def has_table(self, connection, tablename, schema=None): # This approach seems to be more reliable that using DAO @@ -272,7 +272,7 @@ class AccessDialect(default.DefaultDialect): if tbl.Name.lower() == table.name.lower(): break else: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) for col in tbl.Fields: coltype = self.ischema_names[col.Type] @@ -333,7 +333,7 @@ class AccessDialect(default.DefaultDialect): # This is necessary, so we get the latest updates dtbs = daoEngine.OpenDatabase(connection.engine.url.database) - names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] <> "~TMP"] + names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] != "~TMP"] dtbs.Close() return names @@ -345,7 +345,7 @@ class AccessCompiler(compiler.DefaultCompiler): if select.limit: s += "TOP %s " % (select.limit) if select.offset: - raise exceptions.InvalidRequestError('Access does not support LIMIT with an offset') + raise exc.InvalidRequestError('Access does not support LIMIT with an offset') return s def limit_clause(self, select): @@ -378,14 +378,14 @@ class AccessCompiler(compiler.DefaultCompiler): # Strip schema def visit_table(self, table, asfrom=False, **kwargs): if asfrom: - return self.preparer.quote(table, table.name) + return self.preparer.quote(table.name, table.quote) else: return "" class AccessSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + 
colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() # install a sequence if we have an implicit IDENTITY column if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 5e1dd72bb..098759d18 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -89,7 +89,7 @@ connections are active, the following setting may alleviate the problem:: import datetime -from sqlalchemy import exceptions, schema, types as sqltypes, sql, util +from sqlalchemy import exc, schema, types as sqltypes, sql, util from sqlalchemy.engine import base, default @@ -272,7 +272,7 @@ class FBDialect(default.DefaultDialect): default.DefaultDialect.__init__(self, **kwargs) self.type_conv = type_conv - self.concurrency_level= concurrency_level + self.concurrency_level = concurrency_level def dbapi(cls): import kinterbasdb @@ -320,7 +320,7 @@ class FBDialect(default.DefaultDialect): version = fbconn.server_version m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) if not m: - raise exceptions.AssertionError("Could not determine version from string '%s'" % version) + raise AssertionError("Could not determine version from string '%s'" % version) return tuple([int(x) for x in m.group(5, 6, 4)]) def _normalize_name(self, name): @@ -455,7 +455,7 @@ class FBDialect(default.DefaultDialect): # get primary key fields c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields =[self._normalize_name(r['fname']) for r in c.fetchall()] + pkfields = [self._normalize_name(r['fname']) for r in c.fetchall()] # get all of the fields for this table c = connection.execute(tblqry, [tablename]) @@ -509,14 +509,15 @@ class FBDialect(default.DefaultDialect): table.append_column(col) if not found_table: - raise exceptions.NoSuchTableError(table.name) + raise 
exc.NoSuchTableError(table.name) # get the foreign keys c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) fks = {} while True: row = c.fetchone() - if not row: break + if not row: + break cname = self._normalize_name(row['cname']) try: @@ -530,7 +531,7 @@ class FBDialect(default.DefaultDialect): fk[0].append(fname) fk[1].append(refspec) - for name,value in fks.iteritems(): + for name, value in fks.iteritems(): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) def do_execute(self, cursor, statement, parameters, **kwargs): @@ -626,7 +627,7 @@ class FBSchemaGenerator(sql.compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: @@ -711,7 +712,7 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS def __init__(self, dialect): - super(FBIdentifierPreparer,self).__init__(dialect, omit_schema=True) + super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) dialect = FBDialect diff --git a/lib/sqlalchemy/databases/information_schema.py b/lib/sqlalchemy/databases/information_schema.py index 1b3b3838a..20929cf1e 100644 --- a/lib/sqlalchemy/databases/information_schema.py +++ b/lib/sqlalchemy/databases/information_schema.py @@ -1,5 +1,5 @@ import sqlalchemy.sql as sql -import sqlalchemy.exceptions as exceptions +import sqlalchemy.exc as exc from sqlalchemy import select, MetaData, Table, Column, String, Integer from sqlalchemy.schema import PassiveDefault, ForeignKeyConstraint @@ -124,13 +124,13 @@ def reflecttable(connection, table, include_columns, ischema_names): coltype = ischema_names[type] #print "coltype " + repr(coltype) + " args " + repr(args) coltype = coltype(*args) - 
colargs= [] + colargs = [] if default is not None: colargs.append(PassiveDefault(sql.text(default))) table.append_column(Column(name, coltype, nullable=nullable, *colargs)) if not found_table: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) # we are relying on the natural ordering of the constraint_column_usage table to return the referenced columns # in an order that corresponds to the ordinal_position in the key_constraints table, otherwise composite foreign keys @@ -157,13 +157,13 @@ def reflecttable(connection, table, include_columns, ischema_names): row[colmap[6]] ) #print "type %s on column %s to remote %s.%s.%s" % (type, constrained_column, referred_schema, referred_table, referred_column) - if type=='PRIMARY KEY': + if type == 'PRIMARY KEY': table.primary_key.add(table.c[constrained_column]) - elif type=='FOREIGN KEY': + elif type == 'FOREIGN KEY': try: fk = fks[constraint_name] except KeyError: - fk = ([],[]) + fk = ([], []) fks[constraint_name] = fk if current_schema == referred_schema: referred_schema = table.schema diff --git a/lib/sqlalchemy/databases/informix.py b/lib/sqlalchemy/databases/informix.py index 2e1f19de9..c7bc49dbe 100644 --- a/lib/sqlalchemy/databases/informix.py +++ b/lib/sqlalchemy/databases/informix.py @@ -8,7 +8,7 @@ import datetime -from sqlalchemy import sql, schema, exceptions, pool, util +from sqlalchemy import sql, schema, exc, pool, util from sqlalchemy.sql import compiler from sqlalchemy.engine import default from sqlalchemy import types as sqltypes @@ -197,7 +197,7 @@ class InfoExecutionContext(default.DefaultExecutionContext): # 5 - rowid after insert def post_exec(self): if getattr(self.compiled, "isinsert", False) and self.last_inserted_ids() is None: - self._last_inserted_ids = [self.cursor.sqlerrd[1],] + self._last_inserted_ids = [self.cursor.sqlerrd[1]] elif hasattr( self.compiled , 'offset' ): self.cursor.offset( self.compiled.offset ) super(InfoExecutionContext, 
self).post_exec() @@ -210,7 +210,7 @@ class InfoDialect(default.DefaultDialect): # for informix 7.31 max_identifier_length = 18 - def __init__(self, use_ansi=True,**kwargs): + def __init__(self, use_ansi=True, **kwargs): self.use_ansi = use_ansi default.DefaultDialect.__init__(self, **kwargs) @@ -244,19 +244,19 @@ class InfoDialect(default.DefaultDialect): else: opt = {} - return ([dsn,], opt ) + return ([dsn], opt) def create_execution_context(self , *args, **kwargs): return InfoExecutionContext(self, *args, **kwargs) - def oid_column_name(self,column): + def oid_column_name(self, column): return "rowid" def table_names(self, connection, schema): s = "select tabname from systables" return [row[0] for row in connection.execute(s)] - def has_table(self, connection, table_name,schema=None): + def has_table(self, connection, table_name, schema=None): cursor = connection.execute("""select tabname from systables where tabname=?""", table_name.lower() ) return bool( cursor.fetchone() is not None ) @@ -264,18 +264,18 @@ class InfoDialect(default.DefaultDialect): c = connection.execute ("select distinct OWNER from systables where tabname=?", table.name.lower() ) rows = c.fetchall() if not rows : - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) else: if table.owner is not None: if table.owner.lower() in [r[0] for r in rows]: owner = table.owner.lower() else: - raise exceptions.AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name)) + raise AssertionError("Specified owner %s does not own table %s"%(table.owner, table.name)) else: if len(rows)==1: owner = rows[0][0] else: - raise exceptions.AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name) + raise AssertionError("There are multiple tables with name %s in the schema, you must specifie owner"%table.name) c = connection.execute ("""select colname , coltype , collength , t3.default , t1.colno from 
syscolumns as t1 , systables as t2 , OUTER sysdefaults as t3 where t1.tabid = t2.tabid and t2.tabname=? and t2.owner=? @@ -284,7 +284,7 @@ class InfoDialect(default.DefaultDialect): rows = c.fetchall() if not rows: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) for name , colattr , collength , default , colno in rows: name = name.lower() @@ -341,8 +341,8 @@ class InfoDialect(default.DefaultDialect): try: fk = fks[cons_name] except KeyError: - fk = ([], []) - fks[cons_name] = fk + fk = ([], []) + fks[cons_name] = fk refspec = ".".join([remote_table, remote_column]) schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection) if local_column not in fk[0]: @@ -436,7 +436,7 @@ class InfoSchemaGenerator(compiler.SchemaGenerator): colspec += " SERIAL" self.has_serial = True else: - colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/maxdb.py b/lib/sqlalchemy/databases/maxdb.py index 23ff1f4a0..392cde61f 100644 --- a/lib/sqlalchemy/databases/maxdb.py +++ b/lib/sqlalchemy/databases/maxdb.py @@ -58,7 +58,7 @@ this. import datetime, itertools, re -from sqlalchemy import exceptions, schema, sql, util +from sqlalchemy import exc, schema, sql, util from sqlalchemy.sql import operators as sql_operators, expression as sql_expr from sqlalchemy.sql import compiler, visitors from sqlalchemy.engine import base as engine_base, default @@ -213,7 +213,7 @@ class MaxTimestamp(sqltypes.DateTime): ms = getattr(value, 'microsecond', 0) return value.strftime("%Y-%m-%d %H:%M:%S." + ("%06u" % ms)) else: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "datetimeformat '%s' is not supported." 
% ( dialect.datetimeformat,)) return process @@ -235,7 +235,7 @@ class MaxTimestamp(sqltypes.DateTime): value[11:13], value[14:16], value[17:19], value[20:])]) else: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "datetimeformat '%s' is not supported." % ( dialect.datetimeformat,)) return process @@ -256,7 +256,7 @@ class MaxDate(sqltypes.Date): elif dialect.datetimeformat == 'iso': return value.strftime("%Y-%m-%d") else: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "datetimeformat '%s' is not supported." % ( dialect.datetimeformat,)) return process @@ -272,7 +272,7 @@ class MaxDate(sqltypes.Date): return datetime.date( *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) else: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "datetimeformat '%s' is not supported." % ( dialect.datetimeformat,)) return process @@ -293,7 +293,7 @@ class MaxTime(sqltypes.Time): elif dialect.datetimeformat == 'iso': return value.strftime("%H-%M-%S") else: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "datetimeformat '%s' is not supported." % ( dialect.datetimeformat,)) return process @@ -310,7 +310,7 @@ class MaxTime(sqltypes.Time): return datetime.time( *[int(v) for v in (value[0:4], value[5:7], value[8:10])]) else: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "datetimeformat '%s' is not supported." % ( dialect.datetimeformat,)) return process @@ -599,7 +599,7 @@ class MaxDBDialect(default.DefaultDialect): rows = connection.execute(st, params).fetchall() if not rows: - raise exceptions.NoSuchTableError(table.fullname) + raise exc.NoSuchTableError(table.fullname) include_columns = util.Set(include_columns or []) @@ -833,7 +833,7 @@ class MaxDBCompiler(compiler.DefaultCompiler): # LIMIT. Right? Other dialects seem to get away with # dropping order. 
if select._limit: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "MaxDB does not support ORDER BY in subqueries") else: return "" @@ -846,7 +846,7 @@ class MaxDBCompiler(compiler.DefaultCompiler): sql = select._distinct and 'DISTINCT ' or '' if self.is_subquery(select) and select._limit: if select._offset: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( 'MaxDB does not support LIMIT with an offset.') sql += 'TOP %s ' % select._limit return sql @@ -858,7 +858,7 @@ class MaxDBCompiler(compiler.DefaultCompiler): # sub queries need TOP return '' elif select._offset: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( 'MaxDB does not support LIMIT with an offset.') else: return ' \n LIMIT %s' % (select._limit,) @@ -952,7 +952,7 @@ class MaxDBIdentifierPreparer(compiler.IdentifierPreparer): class MaxDBSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kw): colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec()] + column.type.dialect_impl(self.dialect).get_col_spec()] if not column.nullable: colspec.append('NOT NULL') diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index ab5a96871..4e129952f 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -40,7 +40,7 @@ Known issues / TODO: import datetime, operator, re, sys -from sqlalchemy import sql, schema, exceptions, util +from sqlalchemy import sql, schema, exc, util from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions from sqlalchemy.engine import default, base from sqlalchemy import types as sqltypes @@ -440,7 +440,7 @@ class MSSQLDialect(default.DefaultDialect): dialect_cls = dialect_mapping[module_name] return dialect_cls.import_dbapi() except KeyError: - raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must 
be adodbpi, pymssql or pyodbc)" % module_name) + raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) else: for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]: try: @@ -512,7 +512,7 @@ class MSSQLDialect(default.DefaultDialect): self.context.rowcount = c.rowcount c.DBPROP_COMMITPRESERVE = "Y" except Exception, e: - raise exceptions.DBAPIError.instance(statement, parameters, e) + raise exc.DBAPIError.instance(statement, parameters, e) def table_names(self, connection, schema): from sqlalchemy.databases import information_schema as ischema @@ -602,14 +602,14 @@ class MSSQLDialect(default.DefaultDialect): elif coltype in (MSNVarchar, AdoMSNVarchar) and charlen == -1: args[0] = None coltype = coltype(*args) - colargs= [] + colargs = [] if default is not None: colargs.append(schema.PassiveDefault(sql.text(default))) table.append_column(schema.Column(name, coltype, nullable=nullable, autoincrement=False, *colargs)) if not found_table: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) # We also run an sp_columns to check for identity columns: cursor = connection.execute("sp_columns @table_name = '%s', @table_owner = '%s'" % (table.name, current_schema)) @@ -633,8 +633,8 @@ class MSSQLDialect(default.DefaultDialect): row = cursor.fetchone() cursor.close() if not row is None: - ic.sequence.start=int(row[0]) - ic.sequence.increment=int(row[1]) + ic.sequence.start = int(row[0]) + ic.sequence.increment = int(row[1]) except: # ignoring it, works just like before pass @@ -684,13 +684,15 @@ class MSSQLDialect(default.DefaultDialect): if rfknm != fknm: if fknm: - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm)) fknm, scols, rcols = (rfknm, [], []) - if (not 
scol in scols): scols.append(scol) - if (not (rschema, rtbl, rcol) in rcols): rcols.append((rschema, rtbl, rcol)) + if not scol in scols: + scols.append(scol) + if not (rschema, rtbl, rcol) in rcols: + rcols.append((rschema, rtbl, rcol)) if fknm and scols: - table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table,s,t,c) for s,t,c in rcols], fknm)) + table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm)) class MSSQLDialect_pymssql(MSSQLDialect): @@ -895,7 +897,7 @@ class MSSQLCompiler(compiler.DefaultCompiler): if select._limit: s += "TOP %s " % (select._limit,) if select._offset: - raise exceptions.InvalidRequestError('MSSQL does not support LIMIT with an offset') + raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') return s return compiler.DefaultCompiler.get_select_precolumns(self, select) @@ -1005,7 +1007,7 @@ class MSSQLCompiler(compiler.DefaultCompiler): class MSSQLSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() # install a IDENTITY Sequence if we have an implicit IDENTITY column if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \ diff --git a/lib/sqlalchemy/databases/mxODBC.py b/lib/sqlalchemy/databases/mxODBC.py index a3acac587..92f533633 100644 --- a/lib/sqlalchemy/databases/mxODBC.py +++ b/lib/sqlalchemy/databases/mxODBC.py @@ -53,8 +53,8 @@ class Connection: # override 'connect' call def connect(*args, **kwargs): - import mx.ODBC.Windows - conn = mx.ODBC.Windows.Connect(*args, **kwargs) - conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT - return Connection(conn) + import mx.ODBC.Windows + conn = 
mx.ODBC.Windows.Connect(*args, **kwargs) + conn.datetimeformat = mx.ODBC.Windows.PYDATETIME_DATETIMEFORMAT + return Connection(conn) Connect = connect diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index a86035be5..9cc5c38a6 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -156,7 +156,7 @@ timely information affecting MySQL in SQLAlchemy. import datetime, inspect, re, sys from array import array as _array -from sqlalchemy import exceptions, logging, schema, sql, util +from sqlalchemy import exc, log, schema, sql, util from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql import functions as sql_functions from sqlalchemy.sql import compiler @@ -404,7 +404,7 @@ class MSDouble(sqltypes.Float, _NumericType): if ((precision is None and length is not None) or (precision is not None and length is None)): - raise exceptions.ArgumentError( + raise exc.ArgumentError( "You must specify both precision and length or omit " "both altogether.") @@ -1188,7 +1188,7 @@ class MSEnum(MSString): super_convert = super(MSEnum, self).bind_processor(dialect) def process(value): if self.strict and value is not None and value not in self.enums: - raise exceptions.InvalidRequestError('"%s" not a valid value for ' + raise exc.InvalidRequestError('"%s" not a valid value for ' 'this enum' % value) if super_convert: return super_convert(value) @@ -1588,7 +1588,7 @@ class MySQLDialect(default.DefaultDialect): have = rs.rowcount > 0 rs.close() return have - except exceptions.SQLError, e: + except exc.SQLError, e: if e.orig.args[0] == 1146: return False raise @@ -1823,14 +1823,14 @@ class MySQLDialect(default.DefaultDialect): try: try: rp = connection.execute(st) - except exceptions.SQLError, e: + except exc.SQLError, e: if e.orig.args[0] == 1146: - raise exceptions.NoSuchTableError(full_name) + raise exc.NoSuchTableError(full_name) else: raise row = _compat_fetchone(rp, charset=charset) if not row: - 
raise exceptions.NoSuchTableError(full_name) + raise exc.NoSuchTableError(full_name) return row[1].strip() finally: if rp: @@ -1850,9 +1850,9 @@ class MySQLDialect(default.DefaultDialect): try: try: rp = connection.execute(st) - except exceptions.SQLError, e: + except exc.SQLError, e: if e.orig.args[0] == 1146: - raise exceptions.NoSuchTableError(full_name) + raise exc.NoSuchTableError(full_name) else: raise rows = _compat_fetchall(rp, charset=charset) @@ -1966,7 +1966,7 @@ class MySQLCompiler(compiler.DefaultCompiler): def for_update_clause(self, select): if select.for_update == 'read': - return ' LOCK IN SHARE MODE' + return ' LOCK IN SHARE MODE' else: return super(MySQLCompiler, self).for_update_clause(select) @@ -2022,8 +2022,7 @@ class MySQLSchemaGenerator(compiler.SchemaGenerator): """Builds column DDL.""" colspec = [self.preparer.format_column(column), - column.type.dialect_impl(self.dialect, - _for_ddl=column).get_col_spec()] + column.type.dialect_impl(self.dialect).get_col_spec()] default = self.get_column_default_string(column) if default is not None: @@ -2308,7 +2307,7 @@ class MySQLSchemaReflector(object): ref_names = spec['foreign'] if not util.Set(ref_names).issubset( util.Set([c.name for c in ref_table.c])): - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "Foreign key columns (%s) are not present on " "foreign table %s" % (', '.join(ref_names), ref_table.fullname())) @@ -2643,7 +2642,7 @@ class MySQLSchemaReflector(object): return self._re_keyexprs.findall(identifiers) -MySQLSchemaReflector.logger = logging.class_logger(MySQLSchemaReflector) +MySQLSchemaReflector.logger = log.class_logger(MySQLSchemaReflector) class _MySQLIdentifierPreparer(compiler.IdentifierPreparer): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 734ad58d1..5bc8a186f 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -7,7 +7,7 @@ import datetime, random, re -from 
sqlalchemy import util, sql, schema, exceptions, logging +from sqlalchemy import util, sql, schema, log from sqlalchemy.engine import default, base from sqlalchemy.sql import compiler, visitors from sqlalchemy.sql import operators as sql_operators, functions as sql_functions @@ -49,11 +49,11 @@ class OracleDateTime(sqltypes.DateTime): def result_processor(self, dialect): def process(value): - if value is None or isinstance(value,datetime.datetime): + if value is None or isinstance(value, datetime.datetime): return value else: # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, + return datetime.datetime(value.year, value.month, value.day,value.hour, value.minute, value.second) return process @@ -72,11 +72,11 @@ class OracleTimestamp(sqltypes.TIMESTAMP): def result_processor(self, dialect): def process(value): - if value is None or isinstance(value,datetime.datetime): + if value is None or isinstance(value, datetime.datetime): return value else: # convert cx_oracle datetime object returned pre-python 2.4 - return datetime.datetime(value.year,value.month, + return datetime.datetime(value.year, value.month, value.day,value.hour, value.minute, value.second) return process @@ -216,13 +216,13 @@ class OracleExecutionContext(default.DefaultExecutionContext): def get_result_proxy(self): if hasattr(self, 'out_parameters'): if self.compiled_parameters is not None and len(self.compiled_parameters) == 1: - for bind, name in self.compiled.bind_names.iteritems(): - if name in self.out_parameters: - type = bind.type - self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue()) + for bind, name in self.compiled.bind_names.iteritems(): + if name in self.out_parameters: + type = bind.type + self.out_parameters[name] = type.dialect_impl(self.dialect).result_processor(self.dialect)(self.out_parameters[name].getvalue()) else: - for k in self.out_parameters: - 
self.out_parameters[k] = self.out_parameters[k].getvalue() + for k in self.out_parameters: + self.out_parameters[k] = self.out_parameters[k].getvalue() if self.cursor.description is not None: for column in self.cursor.description: @@ -331,7 +331,7 @@ class OracleDialect(default.DefaultDialect): this id will be passed to do_begin_twophase(), do_rollback_twophase(), do_commit_twophase(). its format is unspecified.""" - id = random.randint(0,2**128) + id = random.randint(0, 2 ** 128) return (0x1234, "%032x" % 9, "%032x" % id) def do_release_savepoint(self, connection, name): @@ -392,7 +392,7 @@ class OracleDialect(default.DefaultDialect): cursor = connection.execute(s) else: s = "select table_name from all_tables where tablespace_name NOT IN ('SYSTEM','SYSAUX') AND OWNER = :owner" - cursor = connection.execute(s,{'owner':self._denormalize_name(schema)}) + cursor = connection.execute(s, {'owner': self._denormalize_name(schema)}) return [self._normalize_name(row[0]) for row in cursor] def _resolve_synonym(self, connection, desired_owner=None, desired_synonym=None, desired_table=None): @@ -400,11 +400,11 @@ class OracleDialect(default.DefaultDialect): if desired_owner is None, attempts to locate a distinct owner. - returns the actual name, owner, dblink name, and synonym name if found. + returns the actual name, owner, dblink name, and synonym name if found. 
""" - sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME - from ALL_SYNONYMS WHERE """ + sql = """select OWNER, TABLE_OWNER, TABLE_NAME, DB_LINK, SYNONYM_NAME + from ALL_SYNONYMS WHERE """ clauses = [] params = {} @@ -418,9 +418,9 @@ class OracleDialect(default.DefaultDialect): clauses.append("TABLE_NAME=:tname") params['tname'] = desired_table - sql += " AND ".join(clauses) + sql += " AND ".join(clauses) - result = connection.execute(sql, **params) + result = connection.execute(sql, **params) if desired_owner: row = result.fetchone() if row: @@ -430,7 +430,7 @@ class OracleDialect(default.DefaultDialect): else: rows = result.fetchall() if len(rows) > 1: - raise exceptions.AssertionError("There are multiple tables visible to the schema, you must specify owner") + raise AssertionError("There are multiple tables visible to the schema, you must specify owner") elif len(rows) == 1: row = rows[0] return row['TABLE_NAME'], row['TABLE_OWNER'], row['DB_LINK'], row['SYNONYM_NAME'] @@ -442,7 +442,7 @@ class OracleDialect(default.DefaultDialect): resolve_synonyms = table.kwargs.get('oracle_resolve_synonyms', False) - if resolve_synonyms: + if resolve_synonyms: actual_name, owner, dblink, synonym = self._resolve_synonym(connection, desired_owner=self._denormalize_name(table.schema), desired_synonym=self._denormalize_name(table.name)) else: actual_name, owner, dblink, synonym = None, None, None, None @@ -473,7 +473,7 @@ class OracleDialect(default.DefaultDialect): # NUMBER(9,2) if the precision is 9 and the scale is 2 # NUMBER(3) if the precision is 3 and scale is 0 #length is ignored except for CHAR and VARCHAR2 - if coltype=='NUMBER' : + if coltype == 'NUMBER' : if precision is None and scale is None: coltype = OracleNumeric elif precision is None and scale == 0 : @@ -498,7 +498,7 @@ class OracleDialect(default.DefaultDialect): table.append_column(schema.Column(colname, coltype, nullable=nullable, *colargs)) if not table.columns: - raise 
exceptions.AssertionError("Couldn't find any column information for table %s" % actual_name) + raise AssertionError("Couldn't find any column information for table %s" % actual_name) c = connection.execute("""SELECT ac.constraint_name, @@ -534,8 +534,8 @@ class OracleDialect(default.DefaultDialect): try: fk = fks[cons_name] except KeyError: - fk = ([], []) - fks[cons_name] = fk + fk = ([], []) + fks[cons_name] = fk if remote_table is None: # ticket 363 util.warn( @@ -551,7 +551,7 @@ class OracleDialect(default.DefaultDialect): remote_owner = self._normalize_name(ref_remote_owner) if not table.schema and self._denormalize_name(remote_owner) == owner: - refspec = ".".join([remote_table, remote_column]) + refspec = ".".join([remote_table, remote_column]) t = schema.Table(remote_table, table.metadata, autoload=True, autoload_with=connection, oracle_resolve_synonyms=resolve_synonyms, useexisting=True) else: refspec = ".".join([x for x in [remote_owner, remote_table, remote_column] if x]) @@ -566,7 +566,7 @@ class OracleDialect(default.DefaultDialect): table.append_constraint(schema.ForeignKeyConstraint(value[0], value[1], name=name)) -OracleDialect.logger = logging.class_logger(OracleDialect) +OracleDialect.logger = log.class_logger(OracleDialect) class _OuterJoinColumn(sql.ClauseElement): __visit_name__ = 'outer_join_column' @@ -574,7 +574,7 @@ class _OuterJoinColumn(sql.ClauseElement): self.column = column def _get_from_objects(self, **kwargs): return [] - + class OracleCompiler(compiler.DefaultCompiler): """Oracle compiler modifies the lexical structure of Select statements to work under non-ANSI configured Oracle databases, if @@ -615,10 +615,10 @@ class OracleCompiler(compiler.DefaultCompiler): return compiler.DefaultCompiler.visit_join(self, join, **kwargs) else: return self.process(join.left, asfrom=True) + ", " + self.process(join.right, asfrom=True) - + def _get_nonansi_join_whereclause(self, froms): clauses = [] - + def visit_join(join): if join.isouter: def 
visit_binary(binary): @@ -627,14 +627,14 @@ class OracleCompiler(compiler.DefaultCompiler): binary.left = _OuterJoinColumn(binary.left) elif binary.right.table is join.right: binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.traverse(join.onclause, visit_binary=visit_binary, clone=True)) + clauses.append(visitors.cloned_traverse(join.onclause, {}, {'binary':visit_binary})) else: clauses.append(join.onclause) - + for f in froms: - visitors.traverse(f, visit_join=visit_join) + visitors.traverse(f, {}, {'join':visit_join}) return sql.and_(*clauses) - + def visit_outer_join_column(self, vc): return self.process(vc.column) + "(+)" @@ -670,7 +670,7 @@ class OracleCompiler(compiler.DefaultCompiler): if whereclause: select = select.where(whereclause) select._oracle_visit = True - + if select._limit is not None or select._offset is not None: # to use ROW_NUMBER(), an ORDER BY is required. orderby = self.process(select._order_by_clause) @@ -680,11 +680,11 @@ class OracleCompiler(compiler.DefaultCompiler): select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("ora_rn")).order_by(None) select._oracle_visit = True - + limitselect = sql.select([c for c in select.c if c.key!='ora_rn']) limitselect._oracle_visit = True limitselect._is_wrapper = True - + if select._offset is not None: limitselect.append_whereclause("ora_rn>%d" % select._offset) if select._limit is not None: @@ -692,7 +692,7 @@ class OracleCompiler(compiler.DefaultCompiler): else: limitselect.append_whereclause("ora_rn<=%d" % select._limit) select = limitselect - + kwargs['iswrapper'] = getattr(select, '_is_wrapper', False) return compiler.DefaultCompiler.visit_select(self, select, **kwargs) @@ -700,7 +700,7 @@ class OracleCompiler(compiler.DefaultCompiler): return "" def for_update_clause(self, select): - if select.for_update=="nowait": + if select.for_update == "nowait": return " FOR UPDATE NOWAIT" else: return super(OracleCompiler, 
self).for_update_clause(select) @@ -709,7 +709,7 @@ class OracleCompiler(compiler.DefaultCompiler): class OracleSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) - colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 605ce7272..23b0a273e 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -21,7 +21,7 @@ parameter when creating the queries:: import random, re, string -from sqlalchemy import sql, schema, exceptions, util +from sqlalchemy import sql, schema, exc, util from sqlalchemy.engine import base, default from sqlalchemy.sql import compiler, expression from sqlalchemy.sql import operators as sql_operators @@ -99,11 +99,17 @@ class PGText(sqltypes.Text): class PGString(sqltypes.String): def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} + if self.length: + return "VARCHAR(%(length)d)" % {'length' : self.length} + else: + return "VARCHAR" class PGChar(sqltypes.CHAR): def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} + if self.length: + return "CHAR(%(length)d)" % {'length' : self.length} + else: + return "CHAR" class PGBinary(sqltypes.Binary): def get_col_spec(self): @@ -146,7 +152,7 @@ class PGArray(sqltypes.MutableType, sqltypes.Concatenable, sqltypes.TypeEngine): if value is None: return value def convert_item(item): - if isinstance(item, (list,tuple)): + if isinstance(item, (list, tuple)): return [convert_item(child) for child in item] else: if item_proc: @@ -373,7 +379,7 @@ class PGDialect(default.DefaultDialect): def last_inserted_ids(self): if 
self.context.last_inserted_ids is None: - raise exceptions.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled") + raise exc.InvalidRequestError("no INSERT executed, or can't use cursor.lastrowid without Postgres OIDs enabled") else: return self.context.last_inserted_ids @@ -419,7 +425,7 @@ class PGDialect(default.DefaultDialect): v = connection.execute("select version()").scalar() m = re.match('PostgreSQL (\d+)\.(\d+)\.(\d+)', v) if not m: - raise exceptions.AssertionError("Could not determine version from string '%s'" % v) + raise AssertionError("Could not determine version from string '%s'" % v) return tuple([int(x) for x in m.group(1, 2, 3)]) def reflecttable(self, connection, table, include_columns): @@ -459,7 +465,7 @@ class PGDialect(default.DefaultDialect): rows = c.fetchall() if not rows: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) domains = self._load_domains(connection) @@ -519,7 +525,7 @@ class PGDialect(default.DefaultDialect): default = domain['default'] coltype = ischema_names[domain['attype']] else: - coltype=None + coltype = None if coltype: coltype = coltype(*args, **kwargs) @@ -530,7 +536,7 @@ class PGDialect(default.DefaultDialect): (attype, name)) coltype = sqltypes.NULLTYPE - colargs= [] + colargs = [] if default is not None: match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) if match is not None: @@ -560,7 +566,7 @@ class PGDialect(default.DefaultDialect): col = table.c[pk] table.primary_key.add(col) if col.default is None: - col.autoincrement=False + col.autoincrement = False # Foreign keys FK_SQL = """ @@ -697,7 +703,7 @@ class PGCompiler(compiler.DefaultCompiler): yield co else: yield c - columns = [self.process(c) for c in flatten_columnlist(returning_cols)] + columns = [self.process(c, render_labels=True) for c in flatten_columnlist(returning_cols)] text += ' RETURNING ' + string.join(columns, ', ') return text @@ -724,7 +730,7 @@ 
class PGSchemaGenerator(compiler.SchemaGenerator): else: colspec += " SERIAL" else: - colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index f8bea90eb..a63741cf7 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -7,7 +7,7 @@ import datetime, re, time -from sqlalchemy import schema, exceptions, pool, PassiveDefault +from sqlalchemy import schema, exc, pool, PassiveDefault from sqlalchemy.engine import default import sqlalchemy.types as sqltypes import sqlalchemy.util as util @@ -67,7 +67,7 @@ class DateTimeMixin(object): microsecond = 0 return time.strptime(value, self.__format__)[0:6] + (microsecond,) -class SLDateTime(DateTimeMixin,sqltypes.DateTime): +class SLDateTime(DateTimeMixin, sqltypes.DateTime): __format__ = "%Y-%m-%d %H:%M:%S" __microsecond__ = True @@ -112,11 +112,11 @@ class SLText(sqltypes.Text): class SLString(sqltypes.String): def get_col_spec(self): - return "VARCHAR(%(length)s)" % {'length' : self.length} + return "VARCHAR" + (self.length and "(%d)" % self.length or "") class SLChar(sqltypes.CHAR): def get_col_spec(self): - return "CHAR(%(length)s)" % {'length' : self.length} + return "CHAR" + (self.length and "(%d)" % self.length or "") class SLBinary(sqltypes.Binary): def get_col_spec(self): @@ -203,7 +203,7 @@ class SQLiteDialect(default.DefaultDialect): return tuple([int(x) for x in num.split('.')]) if self.dbapi is not None: sqlite_ver = self.dbapi.version_info - if sqlite_ver < (2,1,'3'): + if sqlite_ver < (2, 1, '3'): util.warn( ("The installed version of pysqlite2 (%s) is out-dated " "and will cause errors in some cases. 
Version 2.1.3 " @@ -227,7 +227,7 @@ class SQLiteDialect(default.DefaultDialect): def create_connect_args(self, url): if url.username or url.password or url.host or url.port: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Invalid SQLite URL: %s\n" "Valid SQLite URL forms are:\n" " sqlite:///:memory: (or, sqlite://)\n" @@ -270,7 +270,7 @@ class SQLiteDialect(default.DefaultDialect): " SELECT * FROM sqlite_temp_master) " "WHERE type='table' ORDER BY name") rs = connection.execute(s) - except exceptions.DBAPIError: + except exc.DBAPIError: raise s = ("SELECT name FROM sqlite_master " "WHERE type='table' ORDER BY name") @@ -334,13 +334,13 @@ class SQLiteDialect(default.DefaultDialect): args = re.findall(r'(\d+)', args) coltype = coltype(*[int(a) for a in args]) - colargs= [] + colargs = [] if has_default: colargs.append(PassiveDefault('?')) table.append_column(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable, *colargs)) if not found_table: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) c = connection.execute("%sforeign_key_list(%s)" % (pragma, qtable)) fks = {} @@ -355,7 +355,7 @@ class SQLiteDialect(default.DefaultDialect): try: fk = fks[constraint_name] except KeyError: - fk = ([],[]) + fk = ([], []) fks[constraint_name] = fk # look up the table based on the given table's engine, not 'self', @@ -438,7 +438,7 @@ class SQLiteCompiler(compiler.DefaultCompiler): class SQLiteSchemaGenerator(compiler.SchemaGenerator): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default diff --git a/lib/sqlalchemy/databases/sybase.py b/lib/sqlalchemy/databases/sybase.py 
index 2551e90c5..14734c6e0 100644 --- a/lib/sqlalchemy/databases/sybase.py +++ b/lib/sqlalchemy/databases/sybase.py @@ -24,7 +24,7 @@ Known issues / TODO: import datetime, operator -from sqlalchemy import util, sql, schema, exceptions +from sqlalchemy import util, sql, schema, exc from sqlalchemy.sql import compiler, expression from sqlalchemy.engine import default, base from sqlalchemy import types as sqltypes @@ -160,11 +160,11 @@ class SybaseTypeError(sqltypes.TypeEngine): def bind_processor(self, dialect): def process(value): - raise exceptions.NotSupportedError("Data type not supported", [value]) + raise exc.NotSupportedError("Data type not supported", [value]) return process def get_col_spec(self): - raise exceptions.NotSupportedError("Data type not supported") + raise exc.NotSupportedError("Data type not supported") class SybaseNumeric(sqltypes.Numeric): def get_col_spec(self): @@ -487,7 +487,7 @@ class SybaseSQLDialect(default.DefaultDialect): dialect_cls = dialect_mapping[module_name] return dialect_cls.import_dbapi() except KeyError: - raise exceptions.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) + raise exc.InvalidRequestError("Unsupported SybaseSQL module '%s' requested (must be " + " or ".join([x for x in dialect_mapping.keys()]) + ")" % module_name) else: for dialect_cls in dialect_mapping.values(): try: @@ -527,7 +527,7 @@ class SybaseSQLDialect(default.DefaultDialect): self.context.rowcount = c.rowcount c.DBPROP_COMMITPRESERVE = "Y" except Exception, e: - raise exceptions.DBAPIError.instance(statement, parameters, e) + raise exc.DBAPIError.instance(statement, parameters, e) def table_names(self, connection, schema): """Ignore the schema and the charset for now.""" @@ -597,7 +597,7 @@ class SybaseSQLDialect(default.DefaultDialect): (type, name)) coltype = sqltypes.NULLTYPE coltype = coltype(*args) - colargs= [] + colargs = [] if default is not None: 
colargs.append(schema.PassiveDefault(sql.text(default))) @@ -624,16 +624,16 @@ class SybaseSQLDialect(default.DefaultDialect): row[0], row[1], row[2], row[3], ) if not primary_table in foreignKeys.keys(): - foreignKeys[primary_table] = [['%s'%(foreign_column)], ['%s.%s'%(primary_table,primary_column)]] + foreignKeys[primary_table] = [['%s' % (foreign_column)], ['%s.%s'%(primary_table, primary_column)]] else: foreignKeys[primary_table][0].append('%s'%(foreign_column)) - foreignKeys[primary_table][1].append('%s.%s'%(primary_table,primary_column)) + foreignKeys[primary_table][1].append('%s.%s'%(primary_table, primary_column)) for primary_table in foreignKeys.keys(): #table.append_constraint(schema.ForeignKeyConstraint(['%s.%s'%(foreign_table, foreign_column)], ['%s.%s'%(primary_table,primary_column)])) table.append_constraint(schema.ForeignKeyConstraint(foreignKeys[primary_table][0], foreignKeys[primary_table][1])) if not found_table: - raise exceptions.NoSuchTableError(table.name) + raise exc.NoSuchTableError(table.name) class SybaseSQLDialect_mxodbc(SybaseSQLDialect): @@ -749,7 +749,7 @@ class SybaseSQLCompiler(compiler.DefaultCompiler): def bindparam_string(self, name): res = super(SybaseSQLCompiler, self).bindparam_string(name) if name.lower().startswith('literal'): - res = 'STRING(%s)'%res + res = 'STRING(%s)' % res return res def get_select_precolumns(self, select): @@ -828,7 +828,7 @@ class SybaseSQLSchemaGenerator(compiler.SchemaGenerator): #colspec += " numeric(30,0) IDENTITY" colspec += " Integer IDENTITY" else: - colspec += " " + column.type.dialect_impl(self.dialect, _for_ddl=column).get_col_spec() + colspec += " " + column.type.dialect_impl(self.dialect).get_col_spec() if not column.nullable: colspec += " NOT NULL" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 583a02763..2ca2ac5f7 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -13,7 +13,7 @@ and result contexts. 
""" import inspect, StringIO, sys -from sqlalchemy import exceptions, schema, util, types, logging +from sqlalchemy import exc, schema, util, types, log from sqlalchemy.sql import expression @@ -451,7 +451,7 @@ class Compiled(object): self.statement = statement self.column_keys = column_keys self.bind = bind - self.can_execute = statement.supports_execution() + self.can_execute = statement.supports_execution def compile(self): """Produce the internal string representation of this element.""" @@ -482,7 +482,7 @@ class Compiled(object): e = self.bind if e is None: - raise exceptions.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.") + raise exc.UnboundExecutionError("This Compiled object is not bound to any Engine or Connection.") return e._execute_compiled(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -541,7 +541,7 @@ class Connection(Connectable): self.__savepoint_seq = 0 self.__branch = _branch self.__invalid = False - + def _branch(self): """Return a new Connection which references this Connection's engine and connection; but does not have close_with_result enabled, @@ -550,7 +550,7 @@ class Connection(Connectable): This is used to execute "sub" statements within a single execution, usually an INSERT statement. """ - return Connection(self.engine, self.__connection, _branch=True) + return self.engine.Connection(self.engine, self.__connection, _branch=True) def dialect(self): "Dialect used by this Connection." 
@@ -578,11 +578,11 @@ class Connection(Connectable): except AttributeError: if self.__invalid: if self.__transaction is not None: - raise exceptions.InvalidRequestError("Can't reconnect until invalid transaction is rolled back") + raise exc.InvalidRequestError("Can't reconnect until invalid transaction is rolled back") self.__connection = self.engine.raw_connection() self.__invalid = False return self.__connection - raise exceptions.InvalidRequestError("This Connection is closed") + raise exc.InvalidRequestError("This Connection is closed") connection = property(connection) def should_close_with_result(self): @@ -702,7 +702,7 @@ class Connection(Connectable): """ if self.__transaction is not None: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "Cannot start a two phase transaction when a transaction " "is already in progress.") if xid is None: @@ -843,7 +843,7 @@ class Connection(Connectable): if c in Connection.executors: return Connection.executors[c](self, object, multiparams, params) else: - raise exceptions.InvalidRequestError("Unexecutable object type: " + str(type(object))) + raise exc.InvalidRequestError("Unexecutable object type: " + str(type(object))) def _execute_default(self, default, multiparams=None, params=None): return self.engine.dialect.defaultrunner(self.__create_execution_context()).traverse_single(default) @@ -862,7 +862,7 @@ class Connection(Connectable): in the case of 'raw' execution which accepts positional parameters, it may be a list of tuples or lists.""" - if multiparams is None or len(multiparams) == 0: + if not multiparams: if params: return [params] else: @@ -897,7 +897,7 @@ class Connection(Connectable): def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None): """Execute a sql.Compiled object.""" if not compiled.can_execute: - raise exceptions.ArgumentError("Not an executable clause: %s" % (str(compiled))) + raise exc.ArgumentError("Not an executable clause: %s" % 
(str(compiled))) if distilled_params is None: distilled_params = self.__distill_params(multiparams, params) @@ -924,7 +924,7 @@ class Connection(Connectable): def _handle_dbapi_exception(self, e, statement, parameters, cursor): if getattr(self, '_reentrant_error', False): - raise exceptions.DBAPIError.instance(None, None, e) + raise exc.DBAPIError.instance(None, None, e) self._reentrant_error = True try: if not isinstance(e, self.dialect.dbapi.Error): @@ -939,7 +939,7 @@ class Connection(Connectable): self._autorollback() if self.__close_with_result: self.close() - raise exceptions.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) + raise exc.DBAPIError.instance(statement, parameters, e, connection_invalidated=is_disconnect) finally: del self._reentrant_error @@ -1047,7 +1047,7 @@ class Transaction(object): def commit(self): if not self._parent._is_active: - raise exceptions.InvalidRequestError("This transaction is inactive") + raise exc.InvalidRequestError("This transaction is inactive") self._do_commit() self._is_active = False @@ -1094,7 +1094,7 @@ class TwoPhaseTransaction(Transaction): def prepare(self): if not self._parent._is_active: - raise exceptions.InvalidRequestError("This transaction is inactive") + raise exc.InvalidRequestError("This transaction is inactive") self._connection._prepare_twophase_impl(self.xid) self._is_prepared = True @@ -1110,13 +1110,17 @@ class Engine(Connectable): provide a default implementation of SchemaEngine. 
""" - def __init__(self, pool, dialect, url, echo=None): + def __init__(self, pool, dialect, url, echo=None, proxy=None): self.pool = pool self.url = url - self.dialect=dialect + self.dialect = dialect self.echo = echo self.engine = self - self.logger = logging.instance_logger(self, echoflag=echo) + self.logger = log.instance_logger(self, echoflag=echo) + if proxy: + self.Connection = _proxy_connection_cls(Connection, proxy) + else: + self.Connection = Connection def name(self): "String name of the [sqlalchemy.engine#Dialect] in use by this ``Engine``." @@ -1124,7 +1128,7 @@ class Engine(Connectable): return sys.modules[self.dialect.__module__].descriptor()['name'] name = property(name) - echo = logging.echo_property() + echo = log.echo_property() def __repr__(self): return 'Engine(%s)' % str(self.url) @@ -1228,7 +1232,7 @@ class Engine(Connectable): def connect(self, **kwargs): """Return a newly allocated Connection object.""" - return Connection(self, **kwargs) + return self.Connection(self, **kwargs) def contextual_connect(self, close_with_result=False, **kwargs): """Return a Connection object which may be newly allocated, or may be part of some ongoing context. @@ -1236,7 +1240,7 @@ class Engine(Connectable): This Connection is meant to be used by the various "auto-connecting" operations. """ - return Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs) + return self.Connection(self, self.pool.connect(), close_with_result=close_with_result, **kwargs) def table_names(self, schema=None, connection=None): """Return a list of all table names available in the database. 
@@ -1286,6 +1290,22 @@ class Engine(Connectable): return self.pool.unique_connection() +def _proxy_connection_cls(cls, proxy): + class ProxyConnection(cls): + def execute(self, object, *multiparams, **params): + return proxy.execute(self, super(ProxyConnection, self).execute, object, *multiparams, **params) + + def execute_clauseelement(self, elem, multiparams=None, params=None): + return proxy.execute(self, super(ProxyConnection, self).execute, elem, *(multiparams or []), **(params or {})) + + def _cursor_execute(self, cursor, statement, parameters, context=None): + return proxy.cursor_execute(super(ProxyConnection, self)._cursor_execute, cursor, statement, parameters, context, False) + + def _cursor_executemany(self, cursor, statement, parameters, context=None): + return proxy.cursor_execute(super(ProxyConnection, self)._cursor_executemany, cursor, statement, parameters, context, True) + + return ProxyConnection + class RowProxy(object): """Proxy a single cursor row for a parent ResultProxy. @@ -1296,6 +1316,8 @@ class RowProxy(object): results that correspond to constructed SQL expressions). """ + __slots__ = ['__parent', '__row'] + def __init__(self, parent, row): """RowProxy objects are constructed by ResultProxy objects.""" @@ -1488,14 +1510,14 @@ class ResultProxy(object): return props[key._label.lower()] elif hasattr(key, 'name') and key.name.lower() in props: return props[key.name.lower()] - raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) + raise exc.NoSuchColumnError("Could not locate column in row for column '%s'" % (str(key))) return rec return util.PopulateDict(lookup_key) def __ambiguous_processor(self, colname): def process(value): - raise exceptions.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." % colname) + raise exc.InvalidRequestError("Ambiguous column name '%s' in result set! try 'use_labels' option on select statement." 
% colname) return process def close(self): diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 3c1721f9d..e39cbdd39 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -12,7 +12,6 @@ as the base class for their own corresponding classes. """ - import re, random from sqlalchemy.engine import base from sqlalchemy.sql import compiler, expression @@ -112,7 +111,7 @@ class DefaultDialect(base.Dialect): This id will be passed to do_begin_twophase(), do_rollback_twophase(), do_commit_twophase(). Its format is unspecified.""" - return "_sa_%032x" % random.randint(0,2**128) + return "_sa_%032x" % random.randint(0, 2 ** 128) def do_savepoint(self, connection, name): connection.execute(expression.SavepointClause(name)) @@ -331,9 +330,9 @@ class DefaultExecutionContext(base.ExecutionContext): if self.dialect.positional: inputsizes = [] for key in self.compiled.positiontup: - typeengine = types[key] - dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) - if dbtype is not None: + typeengine = types[key] + dbtype = typeengine.dialect_impl(self.dialect).get_dbapi_type(self.dialect.dbapi) + if dbtype is not None: inputsizes.append(dbtype) try: self.cursor.setinputsizes(*inputsizes) @@ -395,4 +394,4 @@ class DefaultExecutionContext(base.ExecutionContext): self._last_updated_params = compiled_parameters self.postfetch_cols = self.compiled.postfetch - self.prefetch_cols = self.compiled.prefetch
\ No newline at end of file + self.prefetch_cols = self.compiled.prefetch diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index d4a0ad841..aab191231 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -12,7 +12,7 @@ classes. from sqlalchemy.engine import base, threadlocal, url -from sqlalchemy import util, exceptions +from sqlalchemy import util, exc from sqlalchemy import pool as poollib strategies = {} @@ -77,7 +77,7 @@ class DefaultEngineStrategy(EngineStrategy): try: return dbapi.connect(*cargs, **cparams) except Exception, e: - raise exceptions.DBAPIError.instance(None, None, e) + raise exc.DBAPIError.instance(None, None, e) creator = kwargs.pop('creator', connect) poolclass = (kwargs.pop('poolclass', None) or @@ -200,7 +200,7 @@ class MockEngineStrategy(EngineStrategy): def create(self, entity, **kwargs): kwargs['checkfirst'] = False - self.dialect.schemagenerator(self.dialect ,self, **kwargs).traverse(entity) + self.dialect.schemagenerator(self.dialect, self, **kwargs).traverse(entity) def drop(self, entity, **kwargs): kwargs['checkfirst'] = False diff --git a/lib/sqlalchemy/engine/threadlocal.py b/lib/sqlalchemy/engine/threadlocal.py index e4b2859dc..91b16ed5f 100644 --- a/lib/sqlalchemy/engine/threadlocal.py +++ b/lib/sqlalchemy/engine/threadlocal.py @@ -17,7 +17,7 @@ class TLSession(object): try: return self.__transaction._increment_connect() except AttributeError: - return TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result) + return self.engine.TLConnection(self, self.engine.pool.connect(), close_with_result=close_with_result) def reset(self): try: @@ -81,11 +81,14 @@ class TLSession(object): class TLConnection(base.Connection): - def __init__(self, session, connection, close_with_result): - base.Connection.__init__(self, session.engine, connection, close_with_result=close_with_result) + def __init__(self, session, connection, **kwargs): + 
base.Connection.__init__(self, session.engine, connection, **kwargs) self.__session = session self.__opencount = 1 + def _branch(self): + return self.engine.Connection(self.engine, self.connection, _branch=True) + def session(self): return self.__session session = property(session) @@ -168,6 +171,12 @@ class TLEngine(base.Engine): super(TLEngine, self).__init__(*args, **kwargs) self.context = util.ThreadLocal() + proxy = kwargs.get('proxy') + if proxy: + self.TLConnection = base._proxy_connection_cls(TLConnection, proxy) + else: + self.TLConnection = TLConnection + def session(self): "Returns the current thread's TLSession" if not hasattr(self.context, 'session'): diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 7364f0227..72d09bf85 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -7,7 +7,7 @@ be used directly and is also accepted directly by ``create_engine()``. """ import re, cgi, sys, urllib -from sqlalchemy import exceptions +from sqlalchemy import exc class URL(object): @@ -53,7 +53,7 @@ class URL(object): self.port = int(port) else: self.port = None - self.database= database + self.database = database self.query = query or {} def __str__(self): @@ -180,7 +180,7 @@ def _parse_rfc1738_args(name): name = components.pop('name') return URL(name, **components) else: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Could not parse rfc1738 URL from string '%s'" % name) def _parse_keyvalue_args(name): diff --git a/lib/sqlalchemy/exceptions.py b/lib/sqlalchemy/exc.py index 43623df93..71b46ca11 100644 --- a/lib/sqlalchemy/exceptions.py +++ b/lib/sqlalchemy/exc.py @@ -3,78 +3,78 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php + """Exceptions used with SQLAlchemy. -The base exception class is SQLAlchemyError. 
Exceptions which are raised as a result -of DBAPI exceptions are all subclasses of [sqlalchemy.exceptions#DBAPIError].""" +The base exception class is SQLAlchemyError. Exceptions which are raised as a +result of DBAPI exceptions are all subclasses of +[sqlalchemy.exceptions#DBAPIError]. + +""" + class SQLAlchemyError(Exception): """Generic error class.""" class ArgumentError(SQLAlchemyError): - """Raised for all those conditions where invalid arguments are - sent to constructed objects. This error generally corresponds to - construction time state errors. + """Raised when an invalid or conflicting function argument is supplied. + + This error generally corresponds to construction time state errors. + """ +class CircularDependencyError(SQLAlchemyError): + """Raised by topological sorts when a circular dependency is detected""" + + class CompileError(SQLAlchemyError): """Raised when an error occurs during SQL compilation""" -class TimeoutError(SQLAlchemyError): - """Raised when a connection pool times out on getting a connection.""" +# Moved to orm.exc; compatability definition installed by orm import until 0.6 +ConcurrentModificationError = None +class DisconnectionError(SQLAlchemyError): + """A disconnect is detected on a raw DB-API connection. -class ConcurrentModificationError(SQLAlchemyError): - """Raised when a concurrent modification condition is detected.""" + This error is raised and consumed internally by a connection pool. It can + be raised by a ``PoolListener`` so that the host pool forces a disconnect. 
+ """ -class CircularDependencyError(SQLAlchemyError): - """Raised by topological sorts when a circular dependency is detected""" +# Moved to orm.exc; compatability definition installed by orm import until 0.6 +FlushError = None -class FlushError(SQLAlchemyError): - """Raised when an invalid condition is detected upon a ``flush()``.""" +class TimeoutError(SQLAlchemyError): + """Raised when a connection pool times out on getting a connection.""" class InvalidRequestError(SQLAlchemyError): - """SQLAlchemy was asked to do something it can't do, return - nonexistent data, etc. + """SQLAlchemy was asked to do something it can't do. This error generally corresponds to runtime state errors. - """ - -class UnmappedColumnError(InvalidRequestError): - """A mapper was asked to return mapped information about a column - which it does not map""" -class NoSuchTableError(InvalidRequestError): - """SQLAlchemy was asked to load a table's definition from the - database, but the table doesn't exist. """ -class UnboundExecutionError(InvalidRequestError): - """SQL was attempted without a database connection to execute it on.""" +class NoSuchColumnError(KeyError, InvalidRequestError): + """A nonexistent column is requested from a ``RowProxy``.""" -class AssertionError(SQLAlchemyError): - """Corresponds to internal state being detected in an invalid state.""" - - -class NoSuchColumnError(KeyError, SQLAlchemyError): - """Raised by ``RowProxy`` when a nonexistent column is requested from a row.""" - class NoReferencedTableError(InvalidRequestError): """Raised by ``ForeignKey`` when the referred ``Table`` cannot be located.""" -class DisconnectionError(SQLAlchemyError): - """Raised within ``Pool`` when a disconnect is detected on a raw DB-API connection. +class NoSuchTableError(InvalidRequestError): + """Table does not exist or is not visible to a connection.""" - This error is consumed internally by a connection pool. 
It can be raised by - a ``PoolListener`` so that the host pool forces a disconnect. - """ +class UnboundExecutionError(InvalidRequestError): + """SQL was attempted without a database connection to execute it on.""" + + +# Moved to orm.exc; compatability definition installed by orm import until 0.6 +UnmappedColumnError = None class DBAPIError(SQLAlchemyError): """Raised when the execution of a database operation fails. @@ -93,6 +93,7 @@ class DBAPIError(SQLAlchemyError): The wrapped exception object is available in the ``orig`` attribute. Its type and properties are DB-API implementation specific. + """ def instance(cls, statement, params, orig, connection_invalidated=False): @@ -117,7 +118,7 @@ class DBAPIError(SQLAlchemyError): except Exception, e: text = 'Error in str() of DB-API-generated exception: ' + str(e) SQLAlchemyError.__init__( - self, "(%s) %s" % (orig.__class__.__name__, text)) + self, '(%s) %s' % (orig.__class__.__name__, text)) self.statement = statement self.params = params self.orig = orig @@ -128,39 +129,51 @@ class DBAPIError(SQLAlchemyError): repr(self.statement), repr(self.params)]) -# As of 0.4, SQLError is now DBAPIError +# As of 0.4, SQLError is now DBAPIError. +# SQLError alias will be removed in 0.6. 
SQLError = DBAPIError class InterfaceError(DBAPIError): """Wraps a DB-API InterfaceError.""" + class DatabaseError(DBAPIError): """Wraps a DB-API DatabaseError.""" + class DataError(DatabaseError): """Wraps a DB-API DataError.""" + class OperationalError(DatabaseError): """Wraps a DB-API OperationalError.""" + class IntegrityError(DatabaseError): """Wraps a DB-API IntegrityError.""" + class InternalError(DatabaseError): """Wraps a DB-API InternalError.""" + class ProgrammingError(DatabaseError): """Wraps a DB-API ProgrammingError.""" + class NotSupportedError(DatabaseError): """Wraps a DB-API NotSupportedError.""" + # Warnings + class SADeprecationWarning(DeprecationWarning): """Issued once per usage of a deprecated API.""" + class SAPendingDeprecationWarning(PendingDeprecationWarning): """Issued once per usage of a deprecated API.""" + class SAWarning(RuntimeWarning): """Issued at runtime.""" diff --git a/lib/sqlalchemy/ext/activemapper.py b/lib/sqlalchemy/ext/activemapper.py deleted file mode 100644 index 02f4b5b35..000000000 --- a/lib/sqlalchemy/ext/activemapper.py +++ /dev/null @@ -1,298 +0,0 @@ -from sqlalchemy import ThreadLocalMetaData, util, Integer -from sqlalchemy import Table, Column, ForeignKey -from sqlalchemy.orm import class_mapper, relation, scoped_session -from sqlalchemy.orm import sessionmaker - -from sqlalchemy.orm import backref as create_backref - -import inspect -import sys - -# -# the "proxy" to the database engine... 
this can be swapped out at runtime -# -metadata = ThreadLocalMetaData() -Objectstore = scoped_session -objectstore = scoped_session(sessionmaker(autoflush=True, transactional=False)) - -# -# declarative column declaration - this is so that we can infer the colname -# -class column(object): - def __init__(self, coltype, colname=None, foreign_key=None, - primary_key=False, *args, **kwargs): - if isinstance(foreign_key, basestring): - foreign_key = ForeignKey(foreign_key) - - self.coltype = coltype - self.colname = colname - self.foreign_key = foreign_key - self.primary_key = primary_key - self.kwargs = kwargs - self.args = args - -# -# declarative relationship declaration -# -class relationship(object): - def __init__(self, classname, colname=None, backref=None, private=False, - lazy=True, uselist=True, secondary=None, order_by=False, viewonly=False): - self.classname = classname - self.colname = colname - self.backref = backref - self.private = private - self.lazy = lazy - self.uselist = uselist - self.secondary = secondary - self.order_by = order_by - self.viewonly = viewonly - - def process(self, klass, propname, relations): - relclass = ActiveMapperMeta.classes[self.classname] - - if isinstance(self.order_by, str): - self.order_by = [ self.order_by ] - - if isinstance(self.order_by, list): - for itemno in range(len(self.order_by)): - if isinstance(self.order_by[itemno], str): - self.order_by[itemno] = \ - getattr(relclass.c, self.order_by[itemno]) - - backref = self.create_backref(klass) - relations[propname] = relation(relclass.mapper, - secondary=self.secondary, - backref=backref, - private=self.private, - lazy=self.lazy, - uselist=self.uselist, - order_by=self.order_by, - viewonly=self.viewonly) - - def create_backref(self, klass): - if self.backref is None: - return None - - relclass = ActiveMapperMeta.classes[self.classname] - - if klass.__name__ == self.classname: - class_mapper(relclass).compile() - br_fkey = relclass.c[self.colname] - else: - br_fkey = 
None - - return create_backref(self.backref, remote_side=br_fkey) - - -class one_to_many(relationship): - def __init__(self, *args, **kwargs): - kwargs['uselist'] = True - relationship.__init__(self, *args, **kwargs) - -class one_to_one(relationship): - def __init__(self, *args, **kwargs): - kwargs['uselist'] = False - relationship.__init__(self, *args, **kwargs) - - def create_backref(self, klass): - if self.backref is None: - return None - - relclass = ActiveMapperMeta.classes[self.classname] - - if klass.__name__ == self.classname: - br_fkey = getattr(relclass.c, self.colname) - else: - br_fkey = None - - return create_backref(self.backref, foreignkey=br_fkey, uselist=False) - - -class many_to_many(relationship): - def __init__(self, classname, secondary, backref=None, lazy=True, - order_by=False): - relationship.__init__(self, classname, None, backref, False, lazy, - uselist=True, secondary=secondary, - order_by=order_by) - - -# -# SQLAlchemy metaclass and superclass that can be used to do SQLAlchemy -# mapping in a declarative way, along with a function to process the -# relationships between dependent objects as they come in, without blowing -# up if the classes aren't specified in a proper order -# - -__deferred_classes__ = {} -__processed_classes__ = {} -def process_relationships(klass, was_deferred=False): - # first, we loop through all of the relationships defined on the - # class, and make sure that the related class already has been - # completely processed and defer processing if it has not - defer = False - for propname, reldesc in klass.relations.items(): - found = (reldesc.classname == klass.__name__ or reldesc.classname in __processed_classes__) - if not found: - defer = True - break - - # next, we loop through all the columns looking for foreign keys - # and make sure that we can find the related tables (they do not - # have to be processed yet, just defined), and we defer if we are - # not able to find any of the related tables - if not defer: - 
for col in klass.columns: - if col.foreign_keys: - found = False - cn = col.foreign_keys[0]._colspec - table_name = cn[:cn.rindex('.')] - for other_klass in ActiveMapperMeta.classes.values(): - if other_klass.table.fullname.lower() == table_name.lower(): - found = True - - if not found: - defer = True - break - - if defer and not was_deferred: - __deferred_classes__[klass.__name__] = klass - - # if we are able to find all related and referred to tables, then - # we can go ahead and assign the relationships to the class - if not defer: - relations = {} - for propname, reldesc in klass.relations.items(): - reldesc.process(klass, propname, relations) - - class_mapper(klass).add_properties(relations) - if klass.__name__ in __deferred_classes__: - del __deferred_classes__[klass.__name__] - __processed_classes__[klass.__name__] = klass - - # finally, loop through the deferred classes and attempt to process - # relationships for them - if not was_deferred: - # loop through the list of deferred classes, processing the - # relationships, until we can make no more progress - last_count = len(__deferred_classes__) + 1 - while last_count > len(__deferred_classes__): - last_count = len(__deferred_classes__) - deferred = __deferred_classes__.copy() - for deferred_class in deferred.values(): - process_relationships(deferred_class, was_deferred=True) - - -class ActiveMapperMeta(type): - classes = {} - metadatas = util.Set() - def __init__(cls, clsname, bases, dict): - table_name = clsname.lower() - columns = [] - relations = {} - autoload = False - _metadata = getattr(sys.modules[cls.__module__], - "__metadata__", metadata) - version_id_col = None - version_id_col_object = None - table_opts = {} - - if 'mapping' in dict: - found_pk = False - - members = inspect.getmembers(dict.get('mapping')) - for name, value in members: - if name == '__table__': - table_name = value - continue - - if '__metadata__' == name: - _metadata= value - continue - - if '__autoload__' == name: - autoload 
= True - continue - - if '__version_id_col__' == name: - version_id_col = value - - if '__table_opts__' == name: - table_opts = value - - if name.startswith('__'): continue - - if isinstance(value, column): - if value.primary_key == True: found_pk = True - - if value.foreign_key: - col = Column(value.colname or name, - value.coltype, - value.foreign_key, - primary_key=value.primary_key, - *value.args, **value.kwargs) - else: - col = Column(value.colname or name, - value.coltype, - primary_key=value.primary_key, - *value.args, **value.kwargs) - columns.append(col) - continue - - if isinstance(value, relationship): - relations[name] = value - - if not found_pk and not autoload: - col = Column('id', Integer, primary_key=True) - cls.mapping.id = col - columns.append(col) - - assert _metadata is not None, "No MetaData specified" - - ActiveMapperMeta.metadatas.add(_metadata) - - if not autoload: - cls.table = Table(table_name, _metadata, *columns, **table_opts) - cls.columns = columns - else: - cls.table = Table(table_name, _metadata, autoload=True, **table_opts) - cls.columns = cls.table._columns - - if version_id_col is not None: - version_id_col_object = getattr(cls.table.c, version_id_col, None) - assert(version_id_col_object is not None, "version_id_col (%s) does not exist." 
% version_id_col) - - # check for inheritence - if hasattr(bases[0], "mapping"): - cls._base_mapper= bases[0].mapper - cls.mapper = objectstore.mapper(cls, cls.table, - inherits=cls._base_mapper, version_id_col=version_id_col_object) - else: - cls.mapper = objectstore.mapper(cls, cls.table, version_id_col=version_id_col_object) - cls.relations = relations - ActiveMapperMeta.classes[clsname] = cls - - process_relationships(cls) - - super(ActiveMapperMeta, cls).__init__(clsname, bases, dict) - - - -class ActiveMapper(object): - __metaclass__ = ActiveMapperMeta - - def set(self, **kwargs): - for key, value in kwargs.items(): - setattr(self, key, value) - - -# -# a utility function to create all tables for all ActiveMapper classes -# - -def create_tables(): - for metadata in ActiveMapperMeta.metadatas: - metadata.create_all() - -def drop_tables(): - for metadata in ActiveMapperMeta.metadatas: - metadata.drop_all() diff --git a/lib/sqlalchemy/ext/assignmapper.py b/lib/sqlalchemy/ext/assignmapper.py deleted file mode 100644 index 5a28fbe68..000000000 --- a/lib/sqlalchemy/ext/assignmapper.py +++ /dev/null @@ -1,72 +0,0 @@ -from sqlalchemy import util, exceptions -import types -from sqlalchemy.orm import mapper, Query - -def _monkeypatch_query_method(name, ctx, class_): - def do(self, *args, **kwargs): - query = Query(class_, session=ctx.current) - util.warn_deprecated('Query methods on the class are deprecated; use %s.query.%s instead' % (class_.__name__, name)) - return getattr(query, name)(*args, **kwargs) - try: - do.__name__ = name - except: - pass - if not hasattr(class_, name): - setattr(class_, name, classmethod(do)) - -def _monkeypatch_session_method(name, ctx, class_): - def do(self, *args, **kwargs): - session = ctx.current - return getattr(session, name)(self, *args, **kwargs) - try: - do.__name__ = name - except: - pass - if not hasattr(class_, name): - setattr(class_, name, do) - -def assign_mapper(ctx, class_, *args, **kwargs): - extension = 
kwargs.pop('extension', None) - if extension is not None: - extension = util.to_list(extension) - extension.append(ctx.mapper_extension) - else: - extension = ctx.mapper_extension - - validate = kwargs.pop('validate', False) - - if not isinstance(getattr(class_, '__init__'), types.MethodType): - def __init__(self, **kwargs): - for key, value in kwargs.items(): - if validate: - if not self.mapper.get_property(key, - resolve_synonyms=False, - raiseerr=False): - raise exceptions.ArgumentError( - "Invalid __init__ argument: '%s'" % key) - setattr(self, key, value) - class_.__init__ = __init__ - - class query(object): - def __getattr__(self, key): - return getattr(ctx.current.query(class_), key) - def __call__(self): - return ctx.current.query(class_) - - if not hasattr(class_, 'query'): - class_.query = query() - - for name in ('get', 'filter', 'filter_by', 'select', 'select_by', - 'selectfirst', 'selectfirst_by', 'selectone', 'selectone_by', - 'get_by', 'join_to', 'join_via', 'count', 'count_by', - 'options', 'instances'): - _monkeypatch_query_method(name, ctx, class_) - for name in ('refresh', 'expire', 'delete', 'expunge', 'update'): - _monkeypatch_session_method(name, ctx, class_) - - m = mapper(class_, extension=extension, *args, **kwargs) - class_.mapper = m - return m - -assign_mapper = util.deprecated( - "assign_mapper is deprecated. 
Use scoped_session() instead.")(assign_mapper) diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index d878f7b9b..4d54f6072 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -406,13 +406,26 @@ class _AssociationList(object): def clear(self): del self.col[0:len(self.col)] - def __eq__(self, other): return list(self) == other - def __ne__(self, other): return list(self) != other - def __lt__(self, other): return list(self) < other - def __le__(self, other): return list(self) <= other - def __gt__(self, other): return list(self) > other - def __ge__(self, other): return list(self) >= other - def __cmp__(self, other): return cmp(list(self), other) + def __eq__(self, other): + return list(self) == other + + def __ne__(self, other): + return list(self) != other + + def __lt__(self, other): + return list(self) < other + + def __le__(self, other): + return list(self) <= other + + def __gt__(self, other): + return list(self) > other + + def __ge__(self, other): + return list(self) >= other + + def __cmp__(self, other): + return cmp(list(self), other) def __add__(self, iterable): try: @@ -534,13 +547,26 @@ class _AssociationDict(object): def clear(self): self.col.clear() - def __eq__(self, other): return dict(self) == other - def __ne__(self, other): return dict(self) != other - def __lt__(self, other): return dict(self) < other - def __le__(self, other): return dict(self) <= other - def __gt__(self, other): return dict(self) > other - def __ge__(self, other): return dict(self) >= other - def __cmp__(self, other): return cmp(dict(self), other) + def __eq__(self, other): + return dict(self) == other + + def __ne__(self, other): + return dict(self) != other + + def __lt__(self, other): + return dict(self) < other + + def __le__(self, other): + return dict(self) <= other + + def __gt__(self, other): + return dict(self) > other + + def __ge__(self, other): + return dict(self) >= other + 
+ def __cmp__(self, other): + return cmp(dict(self), other) def __repr__(self): return repr(dict(self.items())) @@ -802,12 +828,23 @@ class _AssociationSet(object): def copy(self): return util.Set(self) - def __eq__(self, other): return util.Set(self) == other - def __ne__(self, other): return util.Set(self) != other - def __lt__(self, other): return util.Set(self) < other - def __le__(self, other): return util.Set(self) <= other - def __gt__(self, other): return util.Set(self) > other - def __ge__(self, other): return util.Set(self) >= other + def __eq__(self, other): + return util.Set(self) == other + + def __ne__(self, other): + return util.Set(self) != other + + def __lt__(self, other): + return util.Set(self) < other + + def __le__(self, other): + return util.Set(self) <= other + + def __gt__(self, other): + return util.Set(self) > other + + def __ge__(self, other): + return util.Set(self) >= other def __repr__(self): return repr(util.Set(self)) diff --git a/lib/sqlalchemy/ext/declarative.py b/lib/sqlalchemy/ext/declarative.py index d736736e9..f06f16059 100644 --- a/lib/sqlalchemy/ext/declarative.py +++ b/lib/sqlalchemy/ext/declarative.py @@ -213,6 +213,9 @@ class DeclarativeMeta(type): continue prop = _deferred_relation(cls, value) our_stuff[k] = prop + + # set up attributes in the order they were created + our_stuff.sort(lambda x, y: cmp(our_stuff[x]._creation_order, our_stuff[y]._creation_order)) table = None if '__table__' not in cls.__dict__: @@ -254,6 +257,7 @@ class DeclarativeMeta(type): mapper_cls = util.unbound_method_to_callable(cls.__mapper_cls__) else: mapper_cls = mapper + cls.__mapper__ = mapper_cls(cls, table, properties=our_stuff, **mapper_args) return type.__init__(cls, classname, bases, dict_) diff --git a/lib/sqlalchemy/ext/orderinglist.py b/lib/sqlalchemy/ext/orderinglist.py index e7464b0bd..21adc85a8 100644 --- a/lib/sqlalchemy/ext/orderinglist.py +++ b/lib/sqlalchemy/ext/orderinglist.py @@ -34,7 +34,7 @@ which have a user-defined, 
serialized order:: u = User() u.topten.append(Blurb('Number one!')) u.topten.append(Blurb('Number two!')) - + # Like magic. assert [blurb.position for blurb in u.topten] == [0, 1] @@ -60,7 +60,7 @@ __all__ = [ 'ordering_list' ] def ordering_list(attr, count_from=None, **kw): """Prepares an OrderingList factory for use in mapper definitions. - + Returns an object suitable for use as an argument to a Mapper relation's ``collection_class`` option. Arguments are: @@ -73,7 +73,7 @@ def ordering_list(attr, count_from=None, **kw): example, ``ordering_list('pos', count_from=1)`` would create a 1-based list in SQL, storing the value in the 'pos' column. Ignored if ``ordering_func`` is supplied. - + Passes along any keyword arguments to ``OrderingList`` constructor. """ @@ -108,7 +108,7 @@ def _unsugar_count_from(**kw): Keyword argument filter, prepares a simple ``ordering_func`` from a ``count_from`` argument, otherwise passes ``ordering_func`` on unchanged. """ - + count_from = kw.pop('count_from', None) if kw.get('ordering_func', None) is None and count_from is not None: if count_from == 0: @@ -126,11 +126,11 @@ class OrderingList(list): ``ordering_list`` function is used to configure ``OrderingList`` collections in ``mapper`` relation definitions. """ - + def __init__(self, ordering_attr=None, ordering_func=None, reorder_on_append=False): """A custom list that manages position information for its children. - + ``OrderingList`` is a ``collection_class`` list implementation that syncs position in a Python list with a position attribute on the mapped objects. @@ -148,7 +148,7 @@ class OrderingList(list): An ``ordering_func`` is called with two positional parameters: the index of the element in the list, and the list itself. - + If omitted, Python list indexes are used for the attribute values. Two basic pre-built numbering functions are provided in this module: ``count_from_0`` and ``count_from_1``. 
For more exotic examples @@ -194,7 +194,7 @@ class OrderingList(list): def _reorder(self): """Sweep through the list and ensure that each object has accurate ordering information set.""" - + for index, entity in enumerate(self): self._order_entity(index, entity, True) @@ -206,7 +206,7 @@ class OrderingList(list): return should_be = self.ordering_func(index, self) - if have <> should_be: + if have != should_be: self._set_order_value(entity, should_be) def append(self, entity): @@ -229,7 +229,7 @@ class OrderingList(list): entity = super(OrderingList, self).pop(index) self._reorder() return entity - + def __setitem__(self, index, entity): if isinstance(index, slice): for i in range(index.start or 0, index.stop or 0, index.step or 1): @@ -237,7 +237,7 @@ class OrderingList(list): else: self._order_entity(index, entity, True) super(OrderingList, self).__setitem__(index, entity) - + def __delitem__(self, index): super(OrderingList, self).__delitem__(index) self._reorder() diff --git a/lib/sqlalchemy/ext/selectresults.py b/lib/sqlalchemy/ext/selectresults.py deleted file mode 100644 index 446228254..000000000 --- a/lib/sqlalchemy/ext/selectresults.py +++ /dev/null @@ -1,28 +0,0 @@ -"""SelectResults has been rolled into Query. 
This class is now just a placeholder.""" - -import sqlalchemy.sql as sql -import sqlalchemy.orm as orm - -class SelectResultsExt(orm.MapperExtension): - """a MapperExtension that provides SelectResults functionality for the - results of query.select_by() and query.select()""" - - def select_by(self, query, *args, **params): - q = query - for a in args: - q = q.filter(a) - return q.filter_by(**params) - - def select(self, query, arg=None, **kwargs): - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - return orm.EXT_CONTINUE - else: - if arg is not None: - query = query.filter(arg) - return query._legacy_select_kwargs(**kwargs) - -def SelectResults(query, clause=None, ops={}): - if clause is not None: - query = query.filter(clause) - query = query.options(orm.extension(SelectResultsExt())) - return query._legacy_select_kwargs(**ops) diff --git a/lib/sqlalchemy/ext/sessioncontext.py b/lib/sqlalchemy/ext/sessioncontext.py deleted file mode 100644 index 5ac8acb40..000000000 --- a/lib/sqlalchemy/ext/sessioncontext.py +++ /dev/null @@ -1,50 +0,0 @@ -from sqlalchemy.orm.scoping import ScopedSession, _ScopedExt -from sqlalchemy.util import warn_deprecated -from sqlalchemy.orm import create_session - -__all__ = ['SessionContext', 'SessionContextExt'] - - -class SessionContext(ScopedSession): - """Provides thread-local management of Sessions. - - Usage:: - - context = SessionContext(sessionmaker(autoflush=True)) - - """ - - def __init__(self, session_factory=None, scopefunc=None): - warn_deprecated("SessionContext is deprecated. 
Use scoped_session().") - if session_factory is None: - session_factory=create_session - super(SessionContext, self).__init__(session_factory, scopefunc=scopefunc) - - def get_current(self): - return self.registry() - - def set_current(self, session): - self.registry.set(session) - - def del_current(self): - self.registry.clear() - - current = property(get_current, set_current, del_current, - """Property used to get/set/del the session in the current scope.""") - - def _get_mapper_extension(self): - try: - return self._extension - except AttributeError: - self._extension = ext = SessionContextExt(self) - return ext - - mapper_extension = property(_get_mapper_extension, - doc="""Get a mapper extension that implements `get_session` using this context. Deprecated.""") - - -class SessionContextExt(_ScopedExt): - def __init__(self, *args, **kwargs): - warn_deprecated("SessionContextExt is deprecated. Use ScopedSession(enhance_classes=True)") - super(SessionContextExt, self).__init__(*args, **kwargs) - diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py index bad9ba5a8..95971f787 100644 --- a/lib/sqlalchemy/ext/sqlsoup.py +++ b/lib/sqlalchemy/ext/sqlsoup.py @@ -210,7 +210,7 @@ Advanced Use Accessing the Session --------------------- -SqlSoup uses a SessionContext to provide thread-local sessions. You +SqlSoup uses a ScopedSession to provide thread-local sessions. You can get a reference to the current one like this:: >>> from sqlalchemy.ext.sqlsoup import objectstore @@ -325,7 +325,7 @@ Boring tests here. Nothing of real expository value. 
from sqlalchemy import * from sqlalchemy import schema, sql from sqlalchemy.orm import * -from sqlalchemy.ext.sessioncontext import SessionContext +from sqlalchemy.orm.scoping import ScopedSession from sqlalchemy.exceptions import * from sqlalchemy.sql import expression @@ -379,15 +379,24 @@ __all__ = ['PKNotFoundError', 'SqlSoup'] # # thread local SessionContext # -class Objectstore(SessionContext): +class Objectstore(ScopedSession): def __getattr__(self, key): - return getattr(self.current, key) + if key.startswith('__'): # dont trip the registry for module-level sweeps of things + # like '__bases__'. the session gets bound to the + # module which is interfered with by other unit tests. + # (removal of mapper.get_session() revealed the issue) + raise AttributeError() + return getattr(self.registry(), key) + def current(self): + return self.registry() + current = property(current) def get_session(self): - return self.current + return self.registry() objectstore = Objectstore(create_session) -class PKNotFoundError(SQLAlchemyError): pass +class PKNotFoundError(SQLAlchemyError): + pass def _ddl_error(cls): msg = 'SQLSoup can only modify mapped Tables (found: %s)' \ @@ -439,7 +448,7 @@ def _is_outer_join(selectable): def _selectable_name(selectable): if isinstance(selectable, sql.Alias): - return _selectable_name(selectable.selectable) + return _selectable_name(selectable.element) elif isinstance(selectable, sql.Select): return ''.join([_selectable_name(s) for s in selectable.froms]) elif isinstance(selectable, schema.Table): @@ -457,7 +466,7 @@ def class_for_table(selectable, **mapper_kwargs): klass = TableClassType(mapname, (object,), {}) else: klass = SelectableClassType(mapname, (object,), {}) - + def __cmp__(self, o): L = self.__class__.c.keys() L.sort() @@ -482,12 +491,17 @@ def class_for_table(selectable, **mapper_kwargs): for m in ['__cmp__', '__repr__']: setattr(klass, m, eval(m)) klass._table = selectable + klass.c = expression.ColumnCollection() mappr = 
mapper(klass, selectable, - extension=objectstore.mapper_extension, + extension=objectstore.extension, allow_null_pks=_is_outer_join(selectable), **mapper_kwargs) - klass._query = Query(mappr) + + for k in mappr.iterate_properties: + klass.c[k.key] = k.columns[0] + + klass._query = objectstore.query_property() return klass class SqlSoup: diff --git a/lib/sqlalchemy/interfaces.py b/lib/sqlalchemy/interfaces.py index eaad25769..959989662 100644 --- a/lib/sqlalchemy/interfaces.py +++ b/lib/sqlalchemy/interfaces.py @@ -67,7 +67,7 @@ class PoolListener(object): The ``_ConnectionFairy`` which manages the connection for the span of the current checkout. - If you raise an ``exceptions.DisconnectionError``, the current + If you raise an ``exc.DisconnectionError``, the current connection will be disposed and a fresh connection retrieved. Processing of all checkout listeners will abort and restart using the new connection. @@ -87,3 +87,24 @@ class PoolListener(object): The ``_ConnectionRecord`` that persistently manages the connection """ + +class ConnectionProxy(object): + """Allows interception of statement execution by Connections. + + Subclass ``ConnectionProxy``, overriding either or both of + ``execute()`` and ``cursor_execute()`` The default behavior is provided, + which is to call the given executor function with the remaining + arguments. The proxy is then connected to an engine via + ``create_engine(url, proxy=MyProxy())`` where ``MyProxy`` is + the user-defined ``ConnectionProxy`` class. 
+ + """ + def execute(self, conn, execute, clauseelement, *multiparams, **params): + """""" + return execute(clauseelement, *multiparams, **params) + + def cursor_execute(self, execute, cursor, statement, parameters, context, executemany): + """""" + return execute(cursor, statement, parameters, context) + + diff --git a/lib/sqlalchemy/logging.py b/lib/sqlalchemy/log.py index 13872caa3..65100d469 100644 --- a/lib/sqlalchemy/logging.py +++ b/lib/sqlalchemy/log.py @@ -1,4 +1,4 @@ -# logging.py - adapt python logging module to SQLAlchemy +# log.py - adapt python logging module to SQLAlchemy # Copyright (C) 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -28,25 +28,19 @@ is equivalent to:: logger.setLevel(logging.DEBUG) """ -import sys, warnings -import sqlalchemy.exceptions as sa_exc +import logging +import sys -# py2.5 absolute imports will fix.... -logging = __import__('logging') - -# moved to sqlalchemy.exceptions. this alias will be removed in 0.5. 
-SADeprecationWarning = sa_exc.SADeprecationWarning rootlogger = logging.getLogger('sqlalchemy') if rootlogger.level == logging.NOTSET: rootlogger.setLevel(logging.WARN) -warnings.filterwarnings("once", category=sa_exc.SADeprecationWarning) default_enabled = False def default_logging(name): global default_enabled if logging.getLogger(name).getEffectiveLevel() < logging.WARN: - default_enabled=True + default_enabled = True if not default_enabled: default_enabled = True handler = logging.StreamHandler(sys.stdout) diff --git a/lib/sqlalchemy/mods/__init__.py b/lib/sqlalchemy/mods/__init__.py deleted file mode 100644 index e69de29bb..000000000 --- a/lib/sqlalchemy/mods/__init__.py +++ /dev/null diff --git a/lib/sqlalchemy/mods/selectresults.py b/lib/sqlalchemy/mods/selectresults.py deleted file mode 100644 index 25bfa2840..000000000 --- a/lib/sqlalchemy/mods/selectresults.py +++ /dev/null @@ -1,7 +0,0 @@ -from sqlalchemy.ext.selectresults import SelectResultsExt -from sqlalchemy.orm.mapper import global_extensions - -def install_plugin(): - global_extensions.append(SelectResultsExt) - -install_plugin() diff --git a/lib/sqlalchemy/orm/__init__.py b/lib/sqlalchemy/orm/__init__.py index 2466a2763..9c23fd409 100644 --- a/lib/sqlalchemy/orm/__init__.py +++ b/lib/sqlalchemy/orm/__init__.py @@ -9,63 +9,98 @@ Functional constructs for ORM configuration. See the SQLAlchemy object relational tutorial and mapper configuration documentation for an overview of how this module is used. 
+ """ -from sqlalchemy.orm.mapper import Mapper, object_mapper, class_mapper, _mapper_registry -from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, EXT_STOP, EXT_PASS, ExtensionOption, PropComparator -from sqlalchemy.orm.properties import SynonymProperty, ComparableProperty, PropertyLoader, ColumnProperty, CompositeProperty, BackRef +from sqlalchemy.orm import exc +from sqlalchemy.orm.mapper import \ + Mapper, _mapper_registry, class_mapper, object_mapper +from sqlalchemy.orm.interfaces import \ + EXT_CONTINUE, EXT_STOP, ExtensionOption, InstrumentationManager, \ + MapperExtension, PropComparator, SessionExtension +from sqlalchemy.orm.properties import \ + BackRef, ColumnProperty, ComparableProperty, CompositeProperty, \ + PropertyLoader, SynonymProperty from sqlalchemy.orm import mapper as mapperlib from sqlalchemy.orm import strategies -from sqlalchemy.orm.query import Query, aliased -from sqlalchemy.orm.util import polymorphic_union, create_row_adapter +from sqlalchemy.orm.query import AliasOption, Query +from sqlalchemy.orm.util import \ + AliasedClass as aliased, join, outerjoin, polymorphic_union, with_parent +from sqlalchemy.sql import util as sql_util from sqlalchemy.orm.session import Session as _Session from sqlalchemy.orm.session import object_session, sessionmaker from sqlalchemy.orm.scoping import ScopedSession - - -__all__ = [ 'relation', 'column_property', 'composite', 'backref', 'eagerload', - 'eagerload_all', 'lazyload', 'noload', 'deferred', 'defer', - 'undefer', 'undefer_group', 'extension', 'mapper', 'clear_mappers', - 'compile_mappers', 'class_mapper', 'object_mapper', 'sessionmaker', - 'scoped_session', 'dynamic_loader', 'MapperExtension', - 'polymorphic_union', 'comparable_property', - 'create_session', 'synonym', 'contains_alias', 'Query', 'aliased', - 'contains_eager', 'EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', - 'object_session', 'PropComparator' ] +from sqlalchemy import util as sa_util + +__all__ = ( + 'EXT_CONTINUE', + 
'EXT_STOP', + 'InstrumentationManager', + 'MapperExtension', + 'PropComparator', + 'Query', + 'aliased', + 'backref', + 'class_mapper', + 'clear_mappers', + 'column_property', + 'comparable_property', + 'compile_mappers', + 'composite', + 'contains_alias', + 'contains_eager', + 'create_session', + 'defer', + 'deferred', + 'dynamic_loader', + 'eagerload', + 'eagerload_all', + 'extension', + 'lazyload', + 'mapper', + 'noload', + 'object_mapper', + 'object_session', + 'polymorphic_union', + 'relation', + 'scoped_session', + 'sessionmaker', + 'synonym', + 'undefer', + 'undefer_group', + ) def scoped_session(session_factory, scopefunc=None): - """Provides thread-local management of Sessions. - - This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession] - class. + """Provides thread-local management of Sessions. - Usage:: + This is a front-end function to the [sqlalchemy.orm.scoping#ScopedSession] + class. - Session = scoped_session(sessionmaker(autoflush=True)) + Usage:: - To instantiate a Session object which is part of the scoped - context, instantiate normally:: + Session = scoped_session(sessionmaker(autoflush=True)) - session = Session() + To instantiate a Session object which is part of the scoped context, + instantiate normally:: - Most session methods are available as classmethods from - the scoped session:: + session = Session() - Session.commit() - Session.close() + Most session methods are available as classmethods from the scoped + session:: - To map classes so that new instances are saved in the current - Session automatically, as well as to provide session-aware - class attributes such as "query", use the `mapper` classmethod - from the scoped session:: + Session.commit() + Session.close() - mapper = Session.mapper - mapper(Class, table, ...) 
+ To map classes so that new instances are saved in the current Session + automatically, as well as to provide session-aware class attributes such + as "query", use the `mapper` classmethod from the scoped session:: - """ + mapper = Session.mapper + mapper(Class, table, ...) - return ScopedSession(session_factory, scopefunc=scopefunc) + """ + return ScopedSession(session_factory, scopefunc=scopefunc) def create_session(bind=None, **kwargs): """create a new [sqlalchemy.orm.session#Session]. @@ -76,26 +111,36 @@ def create_session(bind=None, **kwargs): It is recommended to use the [sqlalchemy.orm#sessionmaker()] function instead of create_session(). """ + + if 'transactional' in kwargs: + sa_util.warn_deprecated( + "The 'transactional' argument to sessionmaker() is deprecated; " + "use autocommit=True|False instead.") + if 'autocommit' in kwargs: + raise TypeError('Specify autocommit *or* transactional, not both.') + kwargs['autocommit'] = not kwargs.pop('transactional') + kwargs.setdefault('autoflush', False) - kwargs.setdefault('transactional', False) + kwargs.setdefault('autocommit', True) + kwargs.setdefault('autoexpire', False) return _Session(bind=bind, **kwargs) def relation(argument, secondary=None, **kwargs): """Provide a relationship of a primary Mapper to a secondary Mapper. - This corresponds to a parent-child or associative table relationship. - The constructed class is an instance of [sqlalchemy.orm.properties#PropertyLoader]. + This corresponds to a parent-child or associative table relationship. The + constructed class is an instance of + [sqlalchemy.orm.properties#PropertyLoader]. argument a class or Mapper instance, representing the target of the relation. secondary for a many-to-many relationship, specifies the intermediary table. The - ``secondary`` keyword argument should generally only be used for a table - that is not otherwise expressed in any class mapping. 
In particular, - using the Association Object Pattern is - generally mutually exclusive against using the ``secondary`` keyword - argument. + ``secondary`` keyword argument should generally only be used for a + table that is not otherwise expressed in any class mapping. In + particular, using the Association Object Pattern is generally mutually + exclusive against using the ``secondary`` keyword argument. \**kwargs follow: @@ -482,8 +527,8 @@ def mapper(class_, local_table=None, *args, **params): which will identify the class/mapper combination to be used with a particular row. Requires the ``polymorphic_identity`` value to be set for all mappers in the inheritance - hierarchy. The column specified by ``polymorphic_on`` is - usually a column that resides directly within the base + hierarchy. The column specified by ``polymorphic_on`` is + usually a column that resides directly within the base mapper's mapped table; alternatively, it may be a column that is only present within the <selectable> portion of the ``with_polymorphic`` argument. @@ -532,7 +577,7 @@ def mapper(class_, local_table=None, *args, **params): to be used against this mapper's selectable unit. This is normally simply the primary key of the `local_table`, but can be overridden here. - + with_polymorphic A tuple in the form ``(<classes>, <selectable>)`` indicating the default style of "polymorphic" loading, that is, which tables @@ -549,9 +594,9 @@ def mapper(class_, local_table=None, *args, **params): which load from a "concrete" inheriting table, the <selectable> argument is required, since it usually requires more complex UNION queries. - + select_table - Deprecated. Synonymous with + Deprecated. Synonymous with ``with_polymorphic=('*', <selectable>)``. version_id_col @@ -677,15 +722,16 @@ def extension(ext): return ExtensionOption(ext) -def eagerload(name, mapper=None): +def eagerload(*keys): """Return a ``MapperOption`` that will convert the property of the given name into an eager load. 
Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=False, mapper=mapper) + return strategies.EagerLazyOption(keys, lazy=False) +eagerload = sa_util.array_as_starargs_fn_decorator(eagerload) -def eagerload_all(name, mapper=None): +def eagerload_all(*keys): """Return a ``MapperOption`` that will convert all properties along the given dot-separated path into an eager load. For example, this:: @@ -698,25 +744,27 @@ def eagerload_all(name, mapper=None): Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=False, chained=True, mapper=mapper) + return strategies.EagerLazyOption(keys, lazy=False, chained=True) +eagerload_all = sa_util.array_as_starargs_fn_decorator(eagerload_all) -def lazyload(name, mapper=None): +def lazyload(*keys): """Return a ``MapperOption`` that will convert the property of the given name into a lazy load. Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=True, mapper=mapper) + return strategies.EagerLazyOption(keys, lazy=True) +lazyload = sa_util.array_as_starargs_fn_decorator(lazyload) -def noload(name): +def noload(*keys): """Return a ``MapperOption`` that will convert the property of the given name into a non-load. Used with ``query.options()``. """ - return strategies.EagerLazyOption(name, lazy=None) + return strategies.EagerLazyOption(keys, lazy=None) def contains_alias(alias): """Return a ``MapperOption`` that will indicate to the query that @@ -726,22 +774,9 @@ def contains_alias(alias): alias. 
""" - class AliasedRow(MapperExtension): - def __init__(self, alias): - self.alias = alias - if isinstance(self.alias, basestring): - self.translator = None - else: - self.translator = create_row_adapter(alias) - - def translate_row(self, mapper, context, row): - if not self.translator: - self.translator = create_row_adapter(mapper.mapped_table.alias(self.alias)) - return self.translator(row) - - return ExtensionOption(AliasedRow(alias)) + return AliasOption(alias) -def contains_eager(key, alias=None, decorator=None): +def contains_eager(*keys, **kwargs): """Return a ``MapperOption`` that will indicate to the query that the given attribute will be eagerly loaded. @@ -752,30 +787,31 @@ def contains_eager(key, alias=None, decorator=None): `alias` is the string name of an alias, **or** an ``sql.Alias`` object, which represents the aliased columns in the query. This argument is optional. - - `decorator` is mutually exclusive of `alias` and is a - row-processing function which will be applied to the incoming row - before sending to the eager load handler. use this for more - sophisticated row adjustments beyond a straight alias. """ + alias = kwargs.pop('alias', None) + if kwargs: + raise exceptions.ArgumentError("Invalid kwargs for contains_eager: %r" % kwargs.keys()) + + return (strategies.EagerLazyOption(keys, lazy=False), strategies.LoadEagerFromAliasOption(keys, alias=alias)) +contains_eager = sa_util.array_as_starargs_fn_decorator(contains_eager) - return (strategies.EagerLazyOption(key, lazy=False), strategies.RowDecorateOption(key, alias=alias, decorator=decorator)) - -def defer(name): +def defer(*keys): """Return a ``MapperOption`` that will convert the column property of the given name into a deferred load. 
Used with ``query.options()``""" - return strategies.DeferredOption(name, defer=True) + return strategies.DeferredOption(keys, defer=True) +defer = sa_util.array_as_starargs_fn_decorator(defer) -def undefer(name): +def undefer(*keys): """Return a ``MapperOption`` that will convert the column property of the given name into a non-deferred (regular column) load. Used with ``query.options()``. """ - return strategies.DeferredOption(name, defer=False) + return strategies.DeferredOption(keys, defer=False) +undefer = sa_util.array_as_starargs_fn_decorator(undefer) def undefer_group(name): """Return a ``MapperOption`` that will convert the given diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index fb0621a70..7ce825c9d 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -4,26 +4,69 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -import operator, weakref -from itertools import chain -import UserDict +import operator +import weakref + from sqlalchemy import util +from sqlalchemy.util import attrgetter, itemgetter, EMPTY_SET from sqlalchemy.orm import interfaces, collections -from sqlalchemy.orm.util import identity_equal -from sqlalchemy import exceptions +import sqlalchemy.exceptions as sa_exc + +# lazy imports +_entity_info = None +identity_equal = None PASSIVE_NORESULT = util.symbol('PASSIVE_NORESULT') ATTR_WAS_SET = util.symbol('ATTR_WAS_SET') NO_VALUE = util.symbol('NO_VALUE') NEVER_SET = util.symbol('NEVER_SET') +NO_ENTITY_NAME = util.symbol('NO_ENTITY_NAME') -class InstrumentedAttribute(interfaces.PropComparator): - """public-facing instrumented attribute, placed in the - class dictionary. +INSTRUMENTATION_MANAGER = '__sa_instrumentation_manager__' +"""Attribute, elects custom instrumentation when present on a mapped class. 
- """ +Allows a class to specify a slightly or wildly different technique for +tracking changes made to mapped attributes and collections. + +Only one instrumentation implementation is allowed in a given object +inheritance hierarchy. + +The value of this attribute must be a callable and will be passed a class +object. The callable must return one of: + + - An instance of an interfaces.InstrumentationManager or subclass + - An object implementing all or some of InstrumentationManager (todo) + - A dictionary of callables, implementing all or some of the above (todo) + - An instance of a ClassManager or subclass + +interfaces.InstrumentationManager is public API and will remain stable +between releases. ClassManager is not public and no guarantees are made +about stability. Caveat emptor. + +This attribute is consulted by the default SQLAlchemy instrumentation +resultion code. If custom finders are installed in the global +instrumentation_finders list, they may or may not choose to honor this +attribute. + +""" - def __init__(self, impl, comparator=None): +instrumentation_finders = [] +"""An extensible sequence of instrumentation implementation finding callables. + +Finders callables will be passed a class object. If None is returned, the +next finder in the sequence is consulted. Otherwise the return must be an +instrumentation factory that follows the same guidelines as +INSTRUMENTATION_MANAGER. + +By default, the only finder is find_native_user_instrumentation_hook, which +searches for INSTRUMENTATION_MANAGER. If all finders return None, standard +ClassManager instrumentation is used. + +""" + +class QueryableAttribute(interfaces.PropComparator): + + def __init__(self, impl, comparator=None, parententity=None): """Construct an InstrumentedAttribute. 
comparator a sql.Comparator to which class-level compare/math events will be sent @@ -31,76 +74,52 @@ class InstrumentedAttribute(interfaces.PropComparator): self.impl = impl self.comparator = comparator + self.parententity = parententity - def __set__(self, instance, value): - self.impl.set(instance._state, value, None) - - def __delete__(self, instance): - self.impl.delete(instance._state) - - def __get__(self, instance, owner): - if instance is None: - return self - return self.impl.get(instance._state) + if parententity: + mapper, selectable, is_aliased_class = _entity_info(parententity, compile=False) + self.property = mapper._get_property(self.impl.key) + else: + self.property = None def get_history(self, instance, **kwargs): - return self.impl.get_history(instance._state, **kwargs) - - def clause_element(self): - return self.comparator.clause_element() - - def expression_element(self): - return self.comparator.expression_element() - + return self.impl.get_history(instance_state(instance), **kwargs) + + def __selectable__(self): + # TODO: conditionally attach this method based on clause_element ? + return self + + def __clause_element__(self): + return self.comparator.__clause_element__() + + def label(self, name): + return self.__clause_element__().label(name) + def operate(self, op, *other, **kwargs): return op(self.comparator, *other, **kwargs) def reverse_operate(self, op, other, **kwargs): return op(other, self.comparator, **kwargs) - def hasparent(self, instance, optimistic=False): - return self.impl.hasparent(instance._state, optimistic=optimistic) - - def _property(self): - from sqlalchemy.orm.mapper import class_mapper - return class_mapper(self.impl.class_).get_property(self.impl.key) - property = property(_property, doc="the MapperProperty object associated with this attribute") - -class ProxiedAttribute(InstrumentedAttribute): - """Adds InstrumentedAttribute class-level behavior to a regular descriptor. - - Obsoleted by proxied_attribute_factory. 
- """ + def hasparent(self, state, optimistic=False): + return self.impl.hasparent(state, optimistic=optimistic) - class ProxyImpl(object): - accepts_scalar_loader = False + def __str__(self): + return repr(self.parententity) + "." + self.property.key - def __init__(self, key): - self.key = key +class InstrumentedAttribute(QueryableAttribute): + """Public-facing descriptor, placed in the mapped class dictionary.""" - def __init__(self, key, user_prop, comparator=None): - self.user_prop = user_prop - self._comparator = comparator - self.key = key - self.impl = ProxiedAttribute.ProxyImpl(key) + def __set__(self, instance, value): + self.impl.set(instance_state(instance), value, None) - def comparator(self): - if callable(self._comparator): - self._comparator = self._comparator() - return self._comparator - comparator = property(comparator) + def __delete__(self, instance): + self.impl.delete(instance_state(instance)) def __get__(self, instance, owner): if instance is None: - self.user_prop.__get__(instance, owner) return self - return self.user_prop.__get__(instance, owner) - - def __set__(self, instance, value): - return self.user_prop.__set__(instance, value) - - def __delete__(self, instance): - return self.user_prop.__delete__(instance) + return self.impl.get(instance_state(instance)) def proxied_attribute_factory(descriptor): """Create an InstrumentedAttribute / user descriptor hybrid. @@ -111,17 +130,19 @@ def proxied_attribute_factory(descriptor): class ProxyImpl(object): accepts_scalar_loader = False + def __init__(self, key): self.key = key - + class Proxy(InstrumentedAttribute): """A combination of InsturmentedAttribute and a regular descriptor.""" - def __init__(self, key, descriptor, comparator): + def __init__(self, key, descriptor, comparator, parententity): self.key = key # maintain ProxiedAttribute.user_prop compatability. 
self.descriptor = self.user_prop = descriptor self._comparator = comparator + self._parententity = parententity self.impl = ProxyImpl(key) def comparator(self): @@ -148,6 +169,11 @@ def proxied_attribute_factory(descriptor): def __getattr__(self, attribute): """Delegate __getattr__ to the original descriptor.""" return getattr(descriptor, attribute) + + def _property(self): + return self._parententity.get_property(self.key, resolve_synonyms=True) + property = property(_property) + Proxy.__name__ = type(descriptor).__name__ + 'Proxy' util.monkeypatch_proxied_specials(Proxy, type(descriptor), @@ -158,7 +184,7 @@ def proxied_attribute_factory(descriptor): class AttributeImpl(object): """internal implementation for instrumented attributes.""" - def __init__(self, class_, key, callable_, trackparent=False, extension=None, compare_function=None, **kwargs): + def __init__(self, class_, key, callable_, class_manager, trackparent=False, extension=None, compare_function=None, **kwargs): """Construct an AttributeImpl. class_ @@ -190,6 +216,7 @@ class AttributeImpl(object): self.class_ = class_ self.key = key self.callable_ = callable_ + self.class_manager = class_manager self.trackparent = trackparent if compare_function is None: self.is_equal = operator.eq @@ -210,16 +237,16 @@ class AttributeImpl(object): An instance attribute that is loaded by a callable function will also not have a `hasparent` flag. - """ + """ return state.parents.get(id(self), optimistic) def sethasparent(self, state, value): """Set a boolean flag on the given item corresponding to whether or not it is attached to a parent object via the attribute represented by this ``InstrumentedAttribute``. - """ + """ state.parents[id(self)] = value def set_callable(self, state, callable_): @@ -235,8 +262,8 @@ class AttributeImpl(object): The callable overrides the class level callable set in the ``InstrumentedAttribute` constructor. 
- """ + """ if callable_ is None: self.initialize(state) else: @@ -249,7 +276,7 @@ class AttributeImpl(object): if self.key in state.callables: return state.callables[self.key] elif self.callable_ is not None: - return self.callable_(state.obj()) + return self.callable_(state) else: return None @@ -271,7 +298,7 @@ class AttributeImpl(object): return state.dict[self.key] except KeyError: # if no history, check for lazy callables, etc. - if self.key not in state.committed_state: + if state.committed_state.get(self.key, NEVER_SET) is NEVER_SET: callable_ = self._get_callable(state) if callable_ is not None: if passive: @@ -310,34 +337,54 @@ class AttributeImpl(object): def set_committed_value(self, state, value): """set an attribute value on the given instance and 'commit' it.""" - state.commit_attr(self, value) + state.commit([self.key]) + + state.callables.pop(self.key, None) + state.dict[self.key] = value + return value class ScalarAttributeImpl(AttributeImpl): """represents a scalar value-holding InstrumentedAttribute.""" accepts_scalar_loader = True + uses_objects = False def delete(self, state): - if self.key not in state.committed_state: - state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE) + state.modified_event(self, False, state.dict.get(self.key, NO_VALUE)) # TODO: catch key errors, convert to attributeerror? 
- del state.dict[self.key] - state.modified=True + if self.extensions: + old = self.get(state) + del state.dict[self.key] + self.fire_remove_event(state, old, None) + else: + del state.dict[self.key] def get_history(self, state, passive=False): - return _create_history(self, state, state.dict.get(self.key, NO_VALUE)) + return History.from_attribute( + self, state, state.dict.get(self.key, NO_VALUE)) def set(self, state, value, initiator): if initiator is self: return - if self.key not in state.committed_state: - state.committed_state[self.key] = state.dict.get(self.key, NO_VALUE) + state.modified_event(self, False, state.dict.get(self.key, NO_VALUE)) - state.dict[self.key] = value - state.modified=True + if self.extensions: + old = self.get(state) + state.dict[self.key] = value + self.fire_replace_event(state, value, old, initiator) + else: + state.dict[self.key] = value + + def fire_replace_event(self, state, value, previous, initiator): + for ext in self.extensions: + ext.set(state, value, previous, initiator or self) + + def fire_remove_event(self, state, value, initiator): + for ext in self.extensions: + ext.remove(state, value, initiator or self) def type(self): self.property.columns[0].type @@ -348,39 +395,38 @@ class MutableScalarAttributeImpl(ScalarAttributeImpl): changes within the value itself. 
""" - def __init__(self, class_, key, callable_, copy_function=None, compare_function=None, **kwargs): - super(ScalarAttributeImpl, self).__init__(class_, key, callable_, compare_function=compare_function, **kwargs) - class_._class_state.has_mutable_scalars = True + uses_objects = False + + def __init__(self, class_, key, callable_, class_manager, copy_function=None, compare_function=None, **kwargs): + super(ScalarAttributeImpl, self).__init__(class_, key, callable_, class_manager, compare_function=compare_function, **kwargs) + class_manager.mutable_attributes.add(key) if copy_function is None: - raise exceptions.ArgumentError("MutableScalarAttributeImpl requires a copy function") + raise sa_exc.ArgumentError("MutableScalarAttributeImpl requires a copy function") self.copy = copy_function def get_history(self, state, passive=False): - return _create_history(self, state, state.dict.get(self.key, NO_VALUE)) + return History.from_attribute( + self, state, state.dict.get(self.key, NO_VALUE)) - def commit_to_state(self, state, value): - state.committed_state[self.key] = self.copy(value) + def commit_to_state(self, state, dest): + dest[self.key] = self.copy(state.dict[self.key]) def check_mutable_modified(self, state): (added, unchanged, deleted) = self.get_history(state, passive=True) - if added or deleted: - state.modified = True - return True - else: - return False + return bool(added or deleted) def set(self, state, value, initiator): if initiator is self: return - if self.key not in state.committed_state: - if self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - else: - state.committed_state[self.key] = NO_VALUE + state.modified_event(self, True, NEVER_SET) - state.dict[self.key] = value - state.modified=True + if self.extensions: + old = self.get(state) + state.dict[self.key] = value + self.fire_replace_event(state, value, old, initiator) + else: + state.dict[self.key] = value class 
ScalarObjectAttributeImpl(ScalarAttributeImpl): @@ -390,10 +436,11 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): """ accepts_scalar_loader = False + uses_objects = True - def __init__(self, class_, key, callable_, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): + def __init__(self, class_, key, callable_, class_manager, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): super(ScalarObjectAttributeImpl, self).__init__(class_, key, - callable_, trackparent=trackparent, extension=extension, + callable_, class_manager, trackparent=trackparent, extension=extension, compare_function=compare_function, **kwargs) if compare_function is None: self.is_equal = identity_equal @@ -406,13 +453,13 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): def get_history(self, state, passive=False): if self.key in state.dict: - return _create_history(self, state, state.dict[self.key]) + return History.from_attribute(self, state, state.dict[self.key]) else: current = self.get(state, passive=passive) if current is PASSIVE_NORESULT: return (None, None, None) else: - return _create_history(self, state, current) + return History.from_attribute(self, state, current) def set(self, state, value, initiator): """Set a value on the given InstanceState. @@ -424,43 +471,33 @@ class ScalarObjectAttributeImpl(ScalarAttributeImpl): if initiator is self: return - - if value is not None and not hasattr(value, '_state'): - raise TypeError("Can not assign %s instance to %s's %r attribute, " - "a mapped instance was expected." 
% ( - type(value).__name__, type(state.obj()).__name__, self.key)) - - # TODO: add options to allow the get() to be passive + + # may want to add options to allow the get() here to be passive old = self.get(state) state.dict[self.key] = value self.fire_replace_event(state, value, old, initiator) def fire_remove_event(self, state, value, initiator): - if self.key not in state.committed_state: - state.committed_state[self.key] = value - state.modified = True + state.modified_event(self, False, value) if self.trackparent and value is not None: - self.sethasparent(value._state, False) + self.sethasparent(instance_state(value), False) - instance = state.obj() for ext in self.extensions: - ext.remove(instance, value, initiator or self) + ext.remove(state, value, initiator or self) def fire_replace_event(self, state, value, previous, initiator): - if self.key not in state.committed_state: - state.committed_state[self.key] = previous - state.modified = True + state.modified_event(self, False, previous) if self.trackparent: if value is not None: - self.sethasparent(value._state, True) + self.sethasparent(instance_state(value), True) if previous is not value and previous is not None: - self.sethasparent(previous._state, False) + self.sethasparent(instance_state(previous), False) - instance = state.obj() for ext in self.extensions: - ext.set(instance, value, previous, initiator or self) + ext.set(state, value, previous, initiator or self) + class CollectionAttributeImpl(AttributeImpl): """A collection-holding attribute that instruments changes in membership. @@ -471,22 +508,21 @@ class CollectionAttributeImpl(AttributeImpl): container object (defaulting to a list) and brokers access to the CollectionAdapter, a "view" onto that object that presents consistent bag semantics to the orm layer independent of the user data implementation. 
+ """ accepts_scalar_loader = False + uses_objects = True - def __init__(self, class_, key, callable_, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): + def __init__(self, class_, key, callable_, class_manager, typecallable=None, trackparent=False, extension=None, copy_function=None, compare_function=None, **kwargs): super(CollectionAttributeImpl, self).__init__(class_, - key, callable_, trackparent=trackparent, extension=extension, - compare_function=compare_function, **kwargs) + key, callable_, class_manager, trackparent=trackparent, + extension=extension, compare_function=compare_function, **kwargs) if copy_function is None: copy_function = self.__copy self.copy = copy_function - if typecallable is None: - typecallable = list - self.collection_factory = \ - collections._prepare_instrumentation(typecallable) + self.collection_factory = typecallable # may be removed in 0.5: self.collection_interface = \ util.duck_type_collection(self.collection_factory()) @@ -499,42 +535,34 @@ class CollectionAttributeImpl(AttributeImpl): if current is PASSIVE_NORESULT: return (None, None, None) else: - return _create_history(self, state, current) + return History.from_attribute(self, state, current) def fire_append_event(self, state, value, initiator): - if self.key not in state.committed_state and self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - - state.modified = True + state.modified_event(self, True, NEVER_SET, passive=True) if self.trackparent and value is not None: - self.sethasparent(value._state, True) - instance = state.obj() + self.sethasparent(instance_state(value), True) + for ext in self.extensions: - ext.append(instance, value, initiator or self) + ext.append(state, value, initiator or self) def fire_pre_remove_event(self, state, initiator): - if self.key not in state.committed_state and self.key in state.dict: - state.committed_state[self.key] = 
self.copy(state.dict[self.key]) + state.modified_event(self, True, NEVER_SET, passive=True) def fire_remove_event(self, state, value, initiator): - if self.key not in state.committed_state and self.key in state.dict: - state.committed_state[self.key] = self.copy(state.dict[self.key]) - - state.modified = True + state.modified_event(self, True, NEVER_SET, passive=True) if self.trackparent and value is not None: - self.sethasparent(value._state, False) + self.sethasparent(instance_state(value), False) - instance = state.obj() for ext in self.extensions: - ext.remove(instance, value, initiator or self) + ext.remove(state, value, initiator or self) def delete(self, state): if self.key not in state.dict: return - state.modified = True + state.modified_event(self, True, NEVER_SET) collection = self.get_collection(state) collection.clear_with_event() @@ -544,10 +572,14 @@ class CollectionAttributeImpl(AttributeImpl): def initialize(self, state): """Initialize this attribute on the given object instance with an empty collection.""" - _, user_data = self._build_collection(state) + _, user_data = self._initialize_collection(state) state.dict[self.key] = user_data return user_data + def _initialize_collection(self, state): + return state.manager.initialize_collection( + self.key, state, self.collection_factory) + def append(self, state, value, initiator, passive=False): if initiator is self: return @@ -597,7 +629,7 @@ class CollectionAttributeImpl(AttributeImpl): """ # pulling a new collection first so that an adaptation exception does # not trigger a lazy load of the old collection. 
- new_collection, user_data = self._build_collection(state) + new_collection, user_data = self._initialize_collection(state) if adapter: new_values = list(adapter(new_collection, iterable)) else: @@ -610,25 +642,20 @@ class CollectionAttributeImpl(AttributeImpl): if old is iterable: return - if self.key not in state.committed_state: - state.committed_state[self.key] = self.copy(old) + state.modified_event(self, True, old) old_collection = self.get_collection(state, old) state.dict[self.key] = user_data - state.modified = True collections.bulk_replace(new_values, old_collection, new_collection) old_collection.unlink(old) def set_committed_value(self, state, value): - """Set an attribute value on the given instance and 'commit' it. - - Loads the existing collection from lazy callables in all cases. - """ + """Set an attribute value on the given instance and 'commit' it.""" - collection, user_data = self._build_collection(state) + collection, user_data = self._initialize_collection(state) if value: for item in value: @@ -637,30 +664,23 @@ class CollectionAttributeImpl(AttributeImpl): state.callables.pop(self.key, None) state.dict[self.key] = user_data + state.commit([self.key]) + if self.key in state.pending: - # pending items. commit loaded data, add/remove new data - state.committed_state[self.key] = list(value or []) - added = state.pending[self.key].added_items - removed = state.pending[self.key].deleted_items + # pending items exist. issue a modified event, + # add/remove new items. + state.modified_event(self, True, user_data) + + pending = state.pending.pop(self.key) + added = pending.added_items + removed = pending.deleted_items for item in added: collection.append_without_event(item) for item in removed: collection.remove_without_event(item) - del state.pending[self.key] - elif self.key in state.committed_state: - # no pending items. remove committed state if any. 
- # (this can occur with an expired attribute) - del state.committed_state[self.key] return user_data - def _build_collection(self, state): - """build a new, blank collection and return it wrapped in a CollectionAdapter.""" - - user_data = self.collection_factory() - collection = collections.CollectionAdapter(self, state, user_data) - return collection, user_data - def get_collection(self, state, user_data=None, passive=False): """retrieve the CollectionAdapter associated with the given state. @@ -672,13 +692,8 @@ class CollectionAttributeImpl(AttributeImpl): user_data = self.get(state, passive=passive) if user_data is PASSIVE_NORESULT: return user_data - try: - return getattr(user_data, '_sa_adapter') - except AttributeError: - # TODO: this codepath never occurs, and this - # except/initialize should be removed - collections.CollectionAdapter(self, state, user_data) - return getattr(user_data, '_sa_adapter') + + return getattr(user_data, '_sa_adapter') class GenericBackrefExtension(interfaces.AttributeExtension): """An extension which synchronizes a two-way relationship. @@ -692,134 +707,150 @@ class GenericBackrefExtension(interfaces.AttributeExtension): def __init__(self, key): self.key = key - def set(self, instance, child, oldchild, initiator): + def set(self, state, child, oldchild, initiator): if oldchild is child: return if oldchild is not None: # With lazy=None, there's no guarantee that the full collection is # present when updating via a backref. 
- impl = getattr(oldchild.__class__, self.key).impl + old_state = instance_state(oldchild) + impl = old_state.get_impl(self.key) try: - impl.remove(oldchild._state, instance, initiator, passive=True) + impl.remove(old_state, state.obj(), initiator, passive=True) except (ValueError, KeyError, IndexError): pass if child is not None: - getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True) + new_state = instance_state(child) + new_state.get_impl(self.key).append(new_state, state.obj(), initiator, passive=True) - def append(self, instance, child, initiator): - getattr(child.__class__, self.key).impl.append(child._state, instance, initiator, passive=True) + def append(self, state, child, initiator): + child_state = instance_state(child) + child_state.get_impl(self.key).append(child_state, state.obj(), initiator, passive=True) - def remove(self, instance, child, initiator): + def remove(self, state, child, initiator): if child is not None: - getattr(child.__class__, self.key).impl.remove(child._state, instance, initiator, passive=True) - -class ClassState(object): - """tracks state information at the class level.""" - def __init__(self): - self.mappers = {} - self.attrs = {} - self.has_mutable_scalars = False + child_state = instance_state(child) + child_state.get_impl(self.key).remove(child_state, state.obj(), initiator, passive=True) -import sets -_empty_set = sets.ImmutableSet() class InstanceState(object): """tracks state information at the instance level.""" - def __init__(self, obj): + _cleanup = None + session_id = None + key = None + runid = None + entity_name = NO_ENTITY_NAME + expired_attributes = EMPTY_SET + + def __init__(self, obj, manager): self.class_ = obj.__class__ - self.obj = weakref.ref(obj, self.__cleanup) + self.manager = manager + self.obj = weakref.ref(obj, self._cleanup) self.dict = obj.__dict__ self.committed_state = {} self.modified = False self.callables = {} self.parents = {} self.pending = {} - 
self.appenders = {} - self.instance_dict = None - self.runid = None - self.expired_attributes = _empty_set - - def __cleanup(self, ref): - # tiptoe around Python GC unpredictableness - instance_dict = self.instance_dict - if instance_dict is None: - return + self.expired = False + + def dispose(self): + del self.session_id + + def check_modified(self): + if self.modified: + return True + else: + for key in self.manager.mutable_attributes: + if self.manager[key].impl.check_mutable_modified(self): + return True + else: + return False - instance_dict = instance_dict() - if instance_dict is None or instance_dict._mutex is None: - return + def initialize_instance(*mixed, **kwargs): + self, instance, args = mixed[0], mixed[1], mixed[2:] + manager = self.manager - # the mutexing here is based on the assumption that gc.collect() - # may be firing off cleanup handlers in a different thread than that - # which is normally operating upon the instance dict. - instance_dict._mutex.acquire() + for fn in manager.events.on_init: + fn(self, instance, args, kwargs) try: - try: - self.__resurrect(instance_dict) - except: - # catch app cleanup exceptions. 
no other way around this - # without warnings being produced - pass - finally: - instance_dict._mutex.release() + return manager.events.original_init(*mixed[1:], **kwargs) + except: + for fn in manager.events.on_init_failure: + fn(self, instance, args, kwargs) + raise - def _check_resurrect(self, instance_dict): - instance_dict._mutex.acquire() - try: - return self.obj() or self.__resurrect(instance_dict) - finally: - instance_dict._mutex.release() + def get_history(self, key, **kwargs): + return self.manager.get_impl(key).get_history(self, **kwargs) + + def get_impl(self, key): + return self.manager.get_impl(key) + + def get_inst(self, key): + return self.manager.get_inst(key) def get_pending(self, key): if key not in self.pending: self.pending[key] = PendingCollection() return self.pending[key] - def is_modified(self): - if self.modified: - return True - elif self.class_._class_state.has_mutable_scalars: - for attr in _managed_attributes(self.class_): - if hasattr(attr.impl, 'check_mutable_modified') and attr.impl.check_mutable_modified(self): - return True - else: - return False - else: - return False + def value_as_iterable(self, key, passive=False): + """return an InstanceState attribute as a list, + regardless of it being a scalar or collection-based + attribute. - def __resurrect(self, instance_dict): - if self.is_modified(): - # store strong ref'ed version of the object; will revert - # to weakref when changes are persisted - obj = new_instance(self.class_, state=self) - self.obj = weakref.ref(obj, self.__cleanup) - self._strong_obj = obj - obj.__dict__.update(self.dict) - self.dict = obj.__dict__ - return obj - else: - del instance_dict[self.dict['_instance_key']] + returns None if passive=True and the getter returns + PASSIVE_NORESULT. 
+ """ + + impl = self.get_impl(key) + x = impl.get(self, passive=passive) + if x is PASSIVE_NORESULT: return None + elif hasattr(impl, 'get_collection'): + return impl.get_collection(self, x, passive=passive) + elif isinstance(x, list): + return x + else: + return [x] + + def _run_on_load(self, instance=None): + if instance is None: + instance = self.obj() + self.manager.events.run('on_load', instance) def __getstate__(self): - return {'committed_state':self.committed_state, 'pending':self.pending, 'parents':self.parents, 'modified':self.modified, 'instance':self.obj(), 'expired_attributes':self.expired_attributes, 'callables':self.callables} + return {'key': self.key, + 'entity_name': self.entity_name, + 'committed_state': self.committed_state, + 'pending': self.pending, + 'parents': self.parents, + 'modified': self.modified, + 'expired':self.expired, + 'instance': self.obj(), + 'expired_attributes':self.expired_attributes, + 'callables': self.callables} def __setstate__(self, state): self.committed_state = state['committed_state'] self.parents = state['parents'] + self.key = state['key'] + self.session_id = None + self.entity_name = state['entity_name'] self.pending = state['pending'] self.modified = state['modified'] self.obj = weakref.ref(state['instance']) self.class_ = self.obj().__class__ + self.manager = manager_of_class(self.class_) self.dict = self.obj().__dict__ self.callables = state['callables'] self.runid = None - self.appenders = {} + self.expired = state['expired'] self.expired_attributes = state['expired_attributes'] def initialize(self, key): - getattr(self.class_, key).impl.initialize(self) + self.manager.get_impl(key).initialize(self) def set_callable(self, key, callable_): self.dict.pop(key, None) @@ -829,70 +860,70 @@ class InstanceState(object): """__call__ allows the InstanceState to act as a deferred callable for loading expired attributes, which is also serializable. 
+ """ - instance = self.obj() unmodified = self.unmodified - self.class_._class_state.deferred_scalar_loader(instance, [ - attr.impl.key for attr in _managed_attributes(self.class_) if + class_manager = self.manager + class_manager.deferred_scalar_loader(self, [ + attr.impl.key for attr in class_manager.attributes if attr.impl.accepts_scalar_loader and attr.impl.key in self.expired_attributes and attr.impl.key in unmodified ]) for k in self.expired_attributes: self.callables.pop(k, None) - self.expired_attributes.clear() + del self.expired_attributes return ATTR_WAS_SET def unmodified(self): """a set of keys which have no uncommitted changes""" return util.Set([ - attr.impl.key for attr in _managed_attributes(self.class_) if - attr.impl.key not in self.committed_state - and (not hasattr(attr.impl, 'commit_to_state') or not attr.impl.check_mutable_modified(self)) - ]) + key for key in self.manager.keys() if + key not in self.committed_state + or (key in self.manager.mutable_attributes and not self.manager[key].impl.check_mutable_modified(self)) + ]) unmodified = property(unmodified) def expire_attributes(self, attribute_names): self.expired_attributes = util.Set(self.expired_attributes) if attribute_names is None: - for attr in _managed_attributes(self.class_): - self.dict.pop(attr.impl.key, None) - self.expired_attributes.add(attr.impl.key) - if attr.impl.accepts_scalar_loader: - self.callables[attr.impl.key] = self - - self.committed_state = {} - else: - for key in attribute_names: - self.dict.pop(key, None) - self.committed_state.pop(key, None) - self.expired_attributes.add(key) - if getattr(self.class_, key).impl.accepts_scalar_loader: - self.callables[key] = self + attribute_names = self.manager.keys() + self.expired = True + self.modified = False + for key in attribute_names: + self.dict.pop(key, None) + self.committed_state.pop(key, None) + self.expired_attributes.add(key) + if self.manager.get_impl(key).accepts_scalar_loader: + self.callables[key] = self def 
reset(self, key): """remove the given attribute and any callables associated with it.""" + self.dict.pop(key, None) self.callables.pop(key, None) - - def commit_attr(self, attr, value): - """set the value of an attribute and mark it 'committed'.""" - - if hasattr(attr, 'commit_to_state'): - attr.commit_to_state(self, value) - else: - self.committed_state.pop(attr.key, None) - self.dict[attr.key] = value - self.pending.pop(attr.key, None) - self.appenders.pop(attr.key, None) - - # we have a value so we can also unexpire it - self.callables.pop(attr.key, None) - if attr.key in self.expired_attributes: - self.expired_attributes.remove(attr.key) - + + def modified_event(self, attr, should_copy, previous, passive=False): + needs_committed = attr.key not in self.committed_state + + if needs_committed: + if previous is NEVER_SET: + if passive: + if attr.key in self.dict: + previous = self.dict[attr.key] + else: + previous = attr.get(self) + + if should_copy and previous not in (None, NO_VALUE, NEVER_SET): + previous = attr.copy(previous) + + if needs_committed: + self.committed_state[attr.key] = previous + + self.modified = True + def commit(self, keys): """commit all attributes named in the given list of key names. @@ -903,219 +934,405 @@ class InstanceState(object): if a value was not populated in state.dict. 
""" - if self.class_._class_state.has_mutable_scalars: - for key in keys: - attr = getattr(self.class_, key).impl - if hasattr(attr, 'commit_to_state') and attr.key in self.dict: - attr.commit_to_state(self, self.dict[attr.key]) - else: - self.committed_state.pop(attr.key, None) - self.pending.pop(key, None) - self.appenders.pop(key, None) - else: - for key in keys: + class_manager = self.manager + for key in keys: + if key in self.dict and key in class_manager.mutable_attributes: + class_manager[key].impl.commit_to_state(self, self.committed_state) + else: self.committed_state.pop(key, None) - self.pending.pop(key, None) - self.appenders.pop(key, None) + self.expired = False # unexpire attributes which have loaded for key in self.expired_attributes.intersection(keys): if key in self.dict: self.expired_attributes.remove(key) self.callables.pop(key, None) - def commit_all(self): """commit all attributes unconditionally. - This is used after a flush() or a regular instance load or refresh operation - to mark committed all populated attributes. + This is used after a flush() or a full load/refresh + to remove all pending state from the instance. + + - all attributes are marked as "committed" + - the "strong dirty reference" is removed + - the "modified" flag is set to False + - any "expired" markers/callables are removed. Attributes marked as "expired" can potentially remain "expired" after this step if a value was not populated in state.dict. 
+ """ - self.committed_state = {} - self.modified = False - self.pending = {} - self.appenders = {} - + # unexpire attributes which have loaded - for key in list(self.expired_attributes): - if key in self.dict: - self.expired_attributes.remove(key) + if self.expired_attributes: + for key in self.expired_attributes.intersection(self.dict): self.callables.pop(key, None) + self.expired_attributes.difference_update(self.dict) + + for key in self.manager.mutable_attributes: + if key in self.dict: + self.manager[key].impl.commit_to_state(self, self.committed_state) - if self.class_._class_state.has_mutable_scalars: - for attr in _managed_attributes(self.class_): - if hasattr(attr.impl, 'commit_to_state') and attr.impl.key in self.dict: - attr.impl.commit_to_state(self, self.dict[attr.impl.key]) - - # remove strong ref + self.modified = self.expired = False self._strong_obj = None -class WeakInstanceDict(UserDict.UserDict): - """similar to WeakValueDictionary, but wired towards 'state' objects.""" +class Events(object): + def __init__(self): + self.original_init = object.__init__ + self.on_init = () + self.on_init_failure = () + self.on_load = () + + def run(self, event, *args, **kwargs): + for fn in getattr(self, event): + fn(*args, **kwargs) + + def add_listener(self, event, listener): + # not thread safe... problem? 
+ bucket = getattr(self, event) + if bucket == (): + setattr(self, event, [listener]) + else: + bucket.append(listener) - def __init__(self, *args, **kw): - self._wr = weakref.ref(self) - # RLock because the mutex is used by a cleanup handler, which can be - # called at any time (including within an already mutexed block) - self._mutex = util.threading.RLock() - UserDict.UserDict.__init__(self, *args, **kw) + def remove_listener(self, event, listener): + bucket = getattr(self, event) + bucket.remove(listener) - def __getitem__(self, key): - state = self.data[key] - o = state.obj() - if o is None: - o = state._check_resurrect(self) - if o is None: - raise KeyError, key - return o - def __contains__(self, key): - try: - state = self.data[key] - o = state.obj() - if o is None: - o = state._check_resurrect(self) - except KeyError: - return False - return o is not None +class ClassManager(dict): + """tracks state information at the class level.""" - def has_key(self, key): - return key in self + MANAGER_ATTR = '_fooclass_manager' + STATE_ATTR = '_foostate' - def __repr__(self): - return "<InstanceDict at %s>" % id(self) + event_registry_factory = Events + instance_state_factory = InstanceState - def __setitem__(self, key, value): - if key in self.data: - self._mutex.acquire() - try: - if key in self.data: - self.data[key].instance_dict = None - finally: - self._mutex.release() - self.data[key] = value._state - value._state.instance_dict = self._wr - - def __delitem__(self, key): - state = self.data[key] - state.instance_dict = None - del self.data[key] - - def get(self, key, default=None): - try: - state = self.data[key] - except KeyError: - return default + def __init__(self, class_): + self.class_ = class_ + self.factory = None # where we came from, for inheritance bookkeeping + self.info = {} + self.mappers = {} + self.mutable_attributes = util.Set() + self.local_attrs = {} + self.originals = {} + for base in class_.__mro__[-2:0:-1]: # reverse, skipping 1st and last 
+ cls_state = manager_of_class(base) + if cls_state: + self.update(cls_state) + self.registered = False + self._instantiable = False + self.events = self.event_registry_factory() + + def instantiable(self, boolean): + # experiment, probably won't stay in this form + assert boolean ^ self._instantiable, (boolean, self._instantiable) + if boolean: + self.events.original_init = self.class_.__init__ + new_init = _generate_init(self.class_, self) + self.install_member('__init__', new_init) else: - o = state.obj() - if o is None: - # This should only happen - return default - else: - return o - - def items(self): - L = [] - for key, state in self.data.items(): - o = state.obj() - if o is not None: - L.append((key, o)) - return L - - def iteritems(self): - for state in self.data.itervalues(): - value = state.obj() - if value is not None: - yield value._instance_key, value + self.uninstall_member('__init__') + self._instantiable = bool(boolean) + instantiable = property(lambda s: s._instantiable, instantiable) - def iterkeys(self): - return self.data.iterkeys() + def manage(self): + """Mark this instance as the manager for its class.""" + setattr(self.class_, self.MANAGER_ATTR, self) - def __iter__(self): - return self.data.iterkeys() + def dispose(self): + """Dissasociate this instance from its class.""" + delattr(self.class_, self.MANAGER_ATTR) - def __len__(self): - return len(self.values()) + def manager_getter(self): + return attrgetter(self.MANAGER_ATTR) - def itervalues(self): - for state in self.data.itervalues(): - instance = state.obj() - if instance is not None: - yield instance + def instrument_attribute(self, key, inst, propagated=False): + if propagated: + if key in self.local_attrs: + return # don't override local attr with inherited attr + else: + self.local_attrs[key] = inst + self.install_descriptor(key, inst) + self[key] = inst + for cls in self.class_.__subclasses__(): + manager = manager_of_class(cls) + if manager is None: + manager = 
create_manager_for_cls(cls) + manager.instrument_attribute(key, inst, True) + + def uninstrument_attribute(self, key, propagated=False): + if key not in self: + return + if propagated: + if key in self.local_attrs: + return # don't get rid of local attr + else: + del self.local_attrs[key] + self.uninstall_descriptor(key) + del self[key] + if key in self.mutable_attributes: + self.mutable_attributes.remove(key) + for cls in self.class_.__subclasses__(): + manager = manager_of_class(cls) + if manager is None: + manager = create_manager_for_cls(cls) + manager.uninstrument_attribute(key, True) + + def unregister(self): + for key in list(self): + if key in self.local_attrs: + self.uninstrument_attribute(key) + self.registered = False + + def install_descriptor(self, key, inst): + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError("%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." % key) + setattr(self.class_, key, inst) + + def uninstall_descriptor(self, key): + delattr(self.class_, key) + + def install_member(self, key, implementation): + if key in (self.STATE_ATTR, self.MANAGER_ATTR): + raise KeyError("%r: requested attribute name conflicts with " + "instrumentation attribute of the same name." 
% key) + self.originals.setdefault(key, getattr(self.class_, key, None)) + setattr(self.class_, key, implementation) + + def uninstall_member(self, key): + original = self.originals.pop(key, None) + if original is not None: + setattr(self.class_, key, original) + + def instrument_collection_class(self, key, collection_class): + return collections.prepare_instrumentation(collection_class) + + def initialize_collection(self, key, state, factory): + user_data = factory() + adapter = collections.CollectionAdapter( + self.get_impl(key), state, user_data) + return adapter, user_data + + def is_instrumented(self, key, search=False): + if search: + return key in self + else: + return key in self.local_attrs - def values(self): - L = [] - for state in self.data.values(): - o = state.obj() - if o is not None: - L.append(o) - return L + def get_impl(self, key): + return self[key].impl - def popitem(self): - raise NotImplementedError() + get_inst = dict.__getitem__ - def pop(self, key, *args): - raise NotImplementedError() + def attributes(self): + return self.itervalues() + attributes = property(attributes) - def setdefault(self, key, default=None): - raise NotImplementedError() + def deferred_scalar_loader(cls, state, keys): + """TODO""" + deferred_scalar_loader = classmethod(deferred_scalar_loader) - def update(self, dict=None, **kwargs): - raise NotImplementedError() + ## InstanceState management - def copy(self): - raise NotImplementedError() + def new_instance(self, state=None): + instance = self.class_.__new__(self.class_) + self.setup_instance(instance, state) + return instance + + def setup_instance(self, instance, with_state=None): + """Register an InstanceState with an instance.""" + if self.has_state(instance): + state = self.state_of(instance) + if with_state: + assert state is with_state + return state + if with_state is None: + with_state = self.instance_state_factory(instance, self) + self.install_state(instance, with_state) + return with_state + + def 
install_state(self, instance, state): + setattr(instance, self.STATE_ATTR, state) + + def has_state(self, instance): + """True if an InstanceState is installed on the instance.""" + return bool(getattr(instance, self.STATE_ATTR, False)) + + def state_of(self, instance): + """Retrieve the InstanceState of an instance. + + May raise KeyError or AttributeError if no state is available. + """ + return getattr(instance, self.STATE_ATTR) - def all_states(self): - return self.data.values() + def state_getter(self): + """Return a (instance) -> InstanceState callable. -class StrongInstanceDict(dict): - def all_states(self): - return [o._state for o in self.values()] + "state getter" callables should raise either KeyError or + AttributeError if no InstanceState could be found for the + instance. + """ + return attrgetter(self.STATE_ATTR) -def _create_history(attr, state, current): - original = state.committed_state.get(attr.key, NEVER_SET) + def _new_state_if_none(self, instance): + """Install a default InstanceState if none is present. - if hasattr(attr, 'get_collection'): - current = attr.get_collection(state, current) - if original is NO_VALUE: - return (list(current), [], []) - elif original is NEVER_SET: - return ([], list(current), []) + A private convenience method used by the __init__ decorator. 
+ """ + if self.has_state(instance): + return False else: - collection = util.OrderedIdentitySet(current) - s = util.OrderedIdentitySet(original) - return (list(collection.difference(s)), list(collection.intersection(s)), list(s.difference(collection))) - else: - if current is NO_VALUE: - if original not in [None, NEVER_SET, NO_VALUE]: - deleted = [original] + new_state = self.instance_state_factory(instance, self) + self.install_state(instance, new_state) + return new_state + + def has_parent(self, state, key, optimistic=False): + """TODO""" + return self.get_impl(key).hasparent(state, optimistic=optimistic) + + def __nonzero__(self): + """All ClassManagers are non-zero regardless of attribute state.""" + return True + + def __repr__(self): + return '<%s of %r at %x>' % ( + self.__class__.__name__, self.class_, id(self)) + +class _ClassInstrumentationAdapter(ClassManager): + """Adapts a user-defined InstrumentationManager to a ClassManager.""" + + def __init__(self, class_, override): + ClassManager.__init__(self, class_) + self._adapted = override + + def manage(self): + self._adapted.manage(self.class_, self) + + def dispose(self): + self._adapted.dispose(self.class_) + + def manager_getter(self): + return self._adapted.manager_getter(self.class_) + + def instrument_attribute(self, key, inst, propagated=False): + ClassManager.instrument_attribute(self, key, inst, propagated) + if not propagated: + self._adapted.instrument_attribute(self.class_, key, inst) + + def install_descriptor(self, key, inst): + self._adapted.install_descriptor(self.class_, key, inst) + + def uninstall_descriptor(self, key): + self._adapted.uninstall_descriptor(self.class_, key) + + def install_member(self, key, implementation): + self._adapted.install_member(self.class_, key, implementation) + + def uninstall_member(self, key): + self._adapted.uninstall_member(self.class_, key) + + def instrument_collection_class(self, key, collection_class): + return 
self._adapted.instrument_collection_class( + self.class_, key, collection_class) + + def initialize_collection(self, key, state, factory): + delegate = getattr(self._adapted, 'initialize_collection', None) + if delegate: + return delegate(key, state, factory) + else: + return ClassManager.initialize_collection(self, key, state, factory) + + def setup_instance(self, instance, state=None): + self._adapted.initialize_instance_dict(self.class_, instance) + state = ClassManager.setup_instance(self, instance, with_state=state) + state.dict = self._adapted.get_instance_dict(self.class_, instance) + return state + + def install_state(self, instance, state): + self._adapted.install_state(self.class_, instance, state) + + def state_of(self, instance): + if hasattr(self._adapted, 'state_of'): + return self._adapted.state_of(self.class_, instance) + else: + getter = self._adapted.state_getter(self.class_) + return getter(instance) + + def has_state(self, instance): + if hasattr(self._adapted, 'has_state'): + return self._adapted.has_state(self.class_, instance) + else: + try: + state = self.state_of(instance) + return True + except (KeyError, AttributeError): + return False + + def state_getter(self): + return self._adapted.state_getter(self.class_) + + +class History(tuple): + # TODO: migrate [] marker for empty slots to () + __slots__ = () + + added = property(itemgetter(0)) + unchanged = property(itemgetter(1)) + deleted = property(itemgetter(2)) + + def __new__(cls, added, unchanged, deleted): + return tuple.__new__(cls, (added, unchanged, deleted)) + + def from_attribute(cls, attribute, state, current): + original = state.committed_state.get(attribute.key, NEVER_SET) + + if hasattr(attribute, 'get_collection'): + current = attribute.get_collection(state, current) + if original is NO_VALUE: + return cls(list(current), [], []) + elif original is NEVER_SET: + return cls([], list(current), []) else: - deleted = [] - return ([], [], deleted) - elif original is NO_VALUE: - 
return ([current], [], []) - elif original is NEVER_SET or attr.is_equal(current, original) is True: # dont let ClauseElement expressions here trip things up - return ([], [current], []) + collection = util.OrderedIdentitySet(current) + s = util.OrderedIdentitySet(original) + return cls(list(collection.difference(s)), + list(collection.intersection(s)), + list(s.difference(collection))) else: - if original is not None: - deleted = [original] + if current is NO_VALUE: + if original not in [None, NEVER_SET, NO_VALUE]: + deleted = [original] + else: + deleted = [] + return cls([], [], deleted) + elif original is NO_VALUE: + return cls([current], [], []) + elif (original is NEVER_SET or + attribute.is_equal(current, original) is True): + # dont let ClauseElement expressions here trip things up + return cls([], [current], []) else: - deleted = [] - return ([current], [], deleted) + if original is not None: + deleted = [original] + else: + deleted = [] + return cls([current], [], deleted) + from_attribute = classmethod(from_attribute) + class PendingCollection(object): """stores items appended and removed from a collection that has not been loaded yet. When the collection is loaded, the changes present in PendingCollection are applied to produce the final result. 
+ """ - def __init__(self): self.deleted_items = util.IdentitySet() self.added_items = util.OrderedIdentitySet() @@ -1130,166 +1347,280 @@ class PendingCollection(object): self.added_items.remove(value) self.deleted_items.add(value) -def _managed_attributes(class_): - """return all InstrumentedAttributes associated with the given class_ and its superclasses.""" - - return chain(*[cl._class_state.attrs.values() for cl in class_.__mro__[:-1] if hasattr(cl, '_class_state')]) def get_history(state, key, **kwargs): - return getattr(state.class_, key).impl.get_history(state, **kwargs) + return state.get_history(key, **kwargs) -def get_as_list(state, key, passive=False): - """return an InstanceState attribute as a list, - regardless of it being a scalar or collection-based - attribute. - returns None if passive=True and the getter returns - PASSIVE_NORESULT. - """ +def has_parent(cls, obj, key, optimistic=False): + """TODO""" + manager = manager_of_class(cls) + state = instance_state(obj) + return manager.has_parent(state, key, optimistic) - attr = getattr(state.class_, key).impl - x = attr.get(state, passive=passive) - if x is PASSIVE_NORESULT: - return None - elif hasattr(attr, 'get_collection'): - return attr.get_collection(state, x, passive=passive) - elif isinstance(x, list): - return x - else: - return [x] - -def has_parent(class_, instance, key, optimistic=False): - return getattr(class_, key).impl.hasparent(instance._state, optimistic=optimistic) +def register_class(class_): + """TODO""" + + # TODO: what's this function for ? why would I call this and not create_manager_for_cls ? 
+ + manager = manager_of_class(class_) + if manager is None: + manager = create_manager_for_cls(class_) + if not manager.instantiable: + manager.instantiable = True -def _create_prop(class_, key, uselist, callable_, typecallable, useobject, mutable_scalars, impl_class, **kwargs): - if impl_class: - return impl_class(class_, key, typecallable, **kwargs) - elif uselist: - return CollectionAttributeImpl(class_, key, callable_, typecallable, **kwargs) - elif useobject: - return ScalarObjectAttributeImpl(class_, key, callable_,**kwargs) - elif mutable_scalars: - return MutableScalarAttributeImpl(class_, key, callable_, **kwargs) - else: - return ScalarAttributeImpl(class_, key, callable_, **kwargs) +def unregister_class(class_): + """TODO""" + manager = manager_of_class(class_) + assert manager + assert manager.instantiable + manager.instantiable = False + manager.unregister() -def manage(instance): - """initialize an InstanceState on the given instance.""" +def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs): - if not hasattr(instance, '_state'): - instance._state = InstanceState(instance) + manager = manager_of_class(class_) + if manager.is_instrumented(key): + # this currently only occurs if two primary mappers are made for the + # same class. TODO: possibly have InstrumentedAttribute check + # "entity_name" when searching for impl. raise an error if two + # attrs attached simultaneously otherwise + return -def new_instance(class_, state=None): - """create a new instance of class_ without its __init__() method being called. + if uselist: + factory = kwargs.pop('typecallable', None) + typecallable = manager.instrument_collection_class( + key, factory or list) + else: + typecallable = kwargs.pop('typecallable', None) - Also initializes an InstanceState on the new instance. 
- """ + comparator = kwargs.pop('comparator', None) + parententity = kwargs.pop('parententity', None) - s = class_.__new__(class_) - if state: - s._state = state + if proxy_property: + proxy_type = proxied_attribute_factory(proxy_property) + descriptor = proxy_type(key, proxy_property, comparator, parententity) else: - s._state = InstanceState(s) - return s + descriptor = InstrumentedAttribute( + _create_prop(class_, key, uselist, callable_, + class_manager=manager, + useobject=useobject, + typecallable=typecallable, + mutable_scalars=mutable_scalars, + impl_class=impl_class, + **kwargs), + comparator=comparator, parententity=parententity) + + manager.instrument_attribute(key, descriptor) -def _init_class_state(class_): - if not '_class_state' in class_.__dict__: - class_._class_state = ClassState() +def unregister_attribute(class_, key): + manager_of_class(class_).uninstrument_attribute(key) -def register_class(class_, extra_init=None, on_exception=None, deferred_scalar_loader=None): - _init_class_state(class_) - class_._class_state.deferred_scalar_loader=deferred_scalar_loader +def init_collection(state, key): + """Initialize a collection attribute and return the collection adapter.""" + attr = state.get_impl(key) + user_data = attr.initialize(state) + return attr.get_collection(state, user_data) - oldinit = None - doinit = False +def set_attribute(instance, key, value): + state = instance_state(instance) + state.get_impl(key).set(state, value, None) - def init(instance, *args, **kwargs): - if not hasattr(instance, '_state'): - instance._state = InstanceState(instance) +def get_attribute(instance, key): + state = instance_state(instance) + return state.get_impl(key).get(state) - if extra_init: - extra_init(class_, oldinit, instance, args, kwargs) +def del_attribute(instance, key): + state = instance_state(instance) + state.get_impl(key).delete(state) - try: - if doinit: - oldinit(instance, *args, **kwargs) - elif args or kwargs: - # simulate error message raised 
by object(), but don't copy - # the text verbatim - raise TypeError("default constructor for object() takes no parameters") - except: - if on_exception: - on_exception(class_, oldinit, instance, args, kwargs) - raise +def is_instrumented(instance, key): + return manager_of_class(instance.__class__).is_instrumented(key, search=True) + +class InstrumentationRegistry(object): + """Private instrumentation registration singleton.""" + manager_finders = weakref.WeakKeyDictionary() + state_finders = util.WeakIdentityMapping() + extended = False - # override oldinit - oldinit = class_.__init__ - if oldinit is None or not hasattr(oldinit, '_oldinit'): - init._oldinit = oldinit - class_.__init__ = init - # if oldinit is already one of our 'init' methods, replace it - elif hasattr(oldinit, '_oldinit'): - init._oldinit = oldinit._oldinit - class_.__init = init - oldinit = oldinit._oldinit + def create_manager_for_cls(self, class_): + assert class_ is not None + assert manager_of_class(class_) is None - if oldinit is not None: - doinit = oldinit is not object.__init__ + for finder in instrumentation_finders: + factory = finder(class_) + if factory is not None: + break + else: + factory = ClassManager + + existing_factories = collect_management_factories_for(class_) + existing_factories.add(factory) + if len(existing_factories) > 1: + raise TypeError( + "multiple instrumentation implementations specified " + "in %s inheritance hierarchy: %r" % ( + class_.__name__, list(existing_factories))) + + manager = factory(class_) + if not isinstance(manager, ClassManager): + manager = _ClassInstrumentationAdapter(class_, manager) + if factory != ClassManager and not self.extended: + self.extended = True + _install_lookup_strategy(self) + + manager.factory = factory + manager.manage() + self.manager_finders[class_] = manager.manager_getter() + self.state_finders[class_] = manager.state_getter() + return manager + + def manager_of_class(self, cls): + if cls is None: + return None try: - 
init.__name__ = oldinit.__name__ - init.__doc__ = oldinit.__doc__ - except: - # cant set __name__ in py 2.3 ! - pass + finder = self.manager_finders[cls] + except KeyError: + return None + else: + return finder(cls) -def unregister_class(class_): - if hasattr(class_, '__init__') and hasattr(class_.__init__, '_oldinit'): - if class_.__init__._oldinit is not None: - class_.__init__ = class_.__init__._oldinit + def state_of(self, instance): + if instance is None: + raise AttributeError("None has no persistent state.") + return self.state_finders[instance.__class__](instance) + + def state_or_default(self, instance, default=None): + if instance is None: + return default + try: + finder = self.state_finders[instance.__class__] + except KeyError: + return default else: - delattr(class_, '__init__') + try: + return finder(instance) + except (KeyError, AttributeError): + return default + except: + raise + + def unregister(self, class_): + if class_ in self.manager_finders: + manager = self.manager_of_class(class_) + manager.dispose() + del self.manager_finders[class_] + del self.state_finders[class_] + +# Create a registry singleton and prepare placeholders for lookup functions. + +instrumentation_registry = InstrumentationRegistry() +create_manager_for_cls = None +manager_of_class = None +instance_state = None +_lookup_strategy = None + +def _install_lookup_strategy(implementation): + """Switch between native and extended instrumentation modes. - if '_class_state' in class_.__dict__: - _class_state = class_.__dict__['_class_state'] - for key, attr in _class_state.attrs.iteritems(): - if key in class_.__dict__: - delattr(class_, attr.impl.key) - delattr(class_, '_class_state') + Completely private. Use the instrumentation_finders interface to + inject global instrumentation behavior. 
-def register_attribute(class_, key, uselist, useobject, callable_=None, proxy_property=None, mutable_scalars=False, impl_class=None, **kwargs): - _init_class_state(class_) + """ + global manager_of_class, instance_state, create_manager_for_cls + global _lookup_strategy + + # Using a symbol here to make debugging a little friendlier. + if implementation is not util.symbol('native'): + manager_of_class = implementation.manager_of_class + instance_state = implementation.state_of + create_manager_for_cls = implementation.create_manager_for_cls + else: + def manager_of_class(class_): + return getattr(class_, ClassManager.MANAGER_ATTR, None) + manager_of_class = instrumentation_registry.manager_of_class + instance_state = attrgetter(ClassManager.STATE_ATTR) + create_manager_for_cls = instrumentation_registry.create_manager_for_cls + # TODO: maybe log an event when setting a strategy. + _lookup_strategy = implementation - typecallable = kwargs.pop('typecallable', None) - if isinstance(typecallable, InstrumentedAttribute): - typecallable = None - comparator = kwargs.pop('comparator', None) +_install_lookup_strategy(util.symbol('native')) - if key in class_.__dict__ and isinstance(class_.__dict__[key], InstrumentedAttribute): - # this currently only occurs if two primary mappers are made for the same class. - # TODO: possibly have InstrumentedAttribute check "entity_name" when searching for impl. 
- # raise an error if two attrs attached simultaneously otherwise - return +def find_native_user_instrumentation_hook(cls): + """Find user-specified instrumentation management for a class.""" + return getattr(cls, INSTRUMENTATION_MANAGER, None) +instrumentation_finders.append(find_native_user_instrumentation_hook) - if proxy_property: - proxy_type = proxied_attribute_factory(proxy_property) - inst = proxy_type(key, proxy_property, comparator) - else: - inst = InstrumentedAttribute(_create_prop(class_, key, uselist, callable_, useobject=useobject, - typecallable=typecallable, mutable_scalars=mutable_scalars, impl_class=impl_class, **kwargs), comparator=comparator) +def collect_management_factories_for(cls): + """Return a collection of factories in play or specified for a hierarchy. - setattr(class_, key, inst) - class_._class_state.attrs[key] = inst + Traverses the entire inheritance graph of a cls and returns a collection + of instrumentation factories for those classes. Factories are extracted + from active ClassManagers, if available, otherwise + instrumentation_finders is consulted. 
-def unregister_attribute(class_, key): - class_state = class_._class_state - if key in class_state.attrs: - del class_._class_state.attrs[key] - delattr(class_, key) + """ + hierarchy = util.class_hierarchy(cls) + factories = util.Set() + for member in hierarchy: + manager = manager_of_class(member) + if manager is not None: + factories.add(manager.factory) + else: + for finder in instrumentation_finders: + factory = finder(member) + if factory is not None: + break + else: + factory = None + factories.add(factory) + factories.discard(None) + return factories + + +def _create_prop(class_, key, uselist, callable_, class_manager, typecallable, useobject, mutable_scalars, impl_class, **kwargs): + if impl_class: + return impl_class(class_, key, typecallable, class_manager=class_manager, **kwargs) + elif uselist: + return CollectionAttributeImpl(class_, key, callable_, + typecallable=typecallable, + class_manager=class_manager, **kwargs) + elif useobject: + return ScalarObjectAttributeImpl(class_, key, callable_, + class_manager=class_manager, **kwargs) + elif mutable_scalars: + return MutableScalarAttributeImpl(class_, key, callable_, + class_manager=class_manager, **kwargs) + else: + return ScalarAttributeImpl(class_, key, callable_, + class_manager=class_manager, **kwargs) + +def _generate_init(class_, class_manager): + """Build an __init__ decorator that triggers ClassManager events.""" + + original__init__ = class_.__init__ + assert original__init__ + + # Go through some effort here and don't change the user's __init__ + # calling signature. + # FIXME: need to juggle local names to avoid constructor argument + # clashes. 
+ func_body = """\ +def __init__(%(args)s): + new_state = class_manager._new_state_if_none(%(self_arg)s) + if new_state: + return new_state.initialize_instance(%(apply_kw)s) + else: + return original__init__(%(apply_kw)s) +""" + func_vars = util.format_argspec_init(original__init__, grouped=False) + func_text = func_body % func_vars + #TODO: log debug #print func_text + + env = locals().copy() + exec func_text in env + __init__ = env['__init__'] + __init__.__doc__ = original__init__.__doc__ + return __init__ -def init_collection(instance, key): - """Initialize a collection attribute and return the collection adapter.""" - attr = getattr(instance.__class__, key).impl - state = instance._state - user_data = attr.initialize(state) - return attr.get_collection(state, user_data) diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index c8fc2f189..13204e8c1 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -93,6 +93,7 @@ explicit control over triggering append and remove events. The owning object and InstrumentedCollectionAttribute are also reachable through the adapter, allowing for some very sophisticated behavior. + """ import copy @@ -101,7 +102,9 @@ import sets import sys import weakref -from sqlalchemy import exceptions, schema, util as sautil +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import schema +import sqlalchemy.util as sautil from sqlalchemy.util import attrgetter, Set @@ -109,6 +112,9 @@ __all__ = ['collection', 'collection_adapter', 'mapped_collection', 'column_mapped_collection', 'attribute_mapped_collection'] +__instrumentation_mutex = sautil.threading.Lock() + + def column_mapped_collection(mapping_spec): """A dictionary-based collection type with column-based keying. @@ -119,25 +125,29 @@ def column_mapped_collection(mapping_spec): can not, for example, map on foreign key values if those key values will change during the session, i.e. 
from None to a database-assigned integer after a session flush. - """ - from sqlalchemy.orm import object_mapper + """ + from sqlalchemy.orm.util import _state_mapper + from sqlalchemy.orm.attributes import instance_state if isinstance(mapping_spec, schema.Column): def keyfunc(value): - m = object_mapper(value) - return m._get_attr_by_column(value, mapping_spec) + state = instance_state(value) + m = _state_mapper(state) + return m._get_state_attr_by_column(state, mapping_spec) else: cols = [] for c in mapping_spec: if not isinstance(c, schema.Column): - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "mapping_spec tuple may only contain columns") cols.append(c) mapping_spec = tuple(cols) def keyfunc(value): - m = object_mapper(value) - return tuple([m._get_attr_by_column(value, c) for c in mapping_spec]) + state = instance_state(value) + m = _state_mapper(state) + return tuple([m._get_state_attr_by_column(state, c) + for c in mapping_spec]) return lambda: MappedCollection(keyfunc) def attribute_mapped_collection(attr_name): @@ -150,8 +160,8 @@ def attribute_mapped_collection(attr_name): can not, for example, map on foreign key values if those key values will change during the session, i.e. from None to a database-assigned integer after a session flush. - """ + """ return lambda: MappedCollection(attrgetter(attr_name)) @@ -165,8 +175,8 @@ def mapped_collection(keyfunc): can not, for example, map on foreign key values if those key values will change during the session, i.e. from None to a database-assigned integer after a session flush. - """ + """ return lambda: MappedCollection(keyfunc) class collection(object): @@ -193,8 +203,8 @@ class collection(object): Decorators can be specified in long-hand for Python 2.3, or with the class-level dict attribute '__instrumentation__'- see the source for details. - """ + """ # Bundled as a class solely for ease of use: packaging, doc strings, # importability. 
@@ -236,8 +246,8 @@ class collection(object): If the appender method is internally instrumented, you must also receive the keyword argument '_sa_initiator' and ensure its promulgation to collection events. - """ + """ setattr(fn, '_sa_instrument_role', 'appender') return fn appender = classmethod(appender) @@ -263,8 +273,8 @@ class collection(object): If the remove method is internally instrumented, you must also receive the keyword argument '_sa_initiator' and ensure its promulgation to collection events. - """ + """ setattr(fn, '_sa_instrument_role', 'remover') return fn remover = classmethod(remover) @@ -277,8 +287,8 @@ class collection(object): @collection.iterator def __iter__(self): ... - """ + """ setattr(fn, '_sa_instrument_role', 'iterator') return fn iterator = classmethod(iterator) @@ -297,8 +307,8 @@ class collection(object): # never be called, unless: @collection.internally_instrumented def extend(self, items): ... - """ + """ setattr(fn, '_sa_instrumented', True) return fn internally_instrumented = classmethod(internally_instrumented) @@ -311,8 +321,8 @@ class collection(object): invoked immediately after the '_sa_adapter' property is set on the instance. A single argument is passed: the collection adapter that has been linked, or None if unlinking. - """ + """ setattr(fn, '_sa_instrument_role', 'on_link') return fn on_link = classmethod(on_link) @@ -344,8 +354,8 @@ class collection(object): Supply an implementation of this method if you want to expand the range of possible types that can be assigned in bulk or perform validation on the values about to be assigned. - """ + """ setattr(fn, '_sa_instrument_role', 'converter') return fn converter = classmethod(converter) @@ -362,8 +372,8 @@ class collection(object): @collection.adds('entity') def do_stuff(self, thing, entity=None): ... 
- """ + """ def decorator(fn): setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) return fn @@ -382,8 +392,8 @@ class collection(object): @collection.replaces(2) def __setitem__(self, index, item): ... - """ + """ def decorator(fn): setattr(fn, '_sa_instrument_before', ('fire_append_event', arg)) setattr(fn, '_sa_instrument_after', 'fire_remove_event') @@ -404,8 +414,8 @@ class collection(object): For methods where the value to remove is not known at call-time, use collection.removes_return. - """ + """ def decorator(fn): setattr(fn, '_sa_instrument_before', ('fire_remove_event', arg)) return fn @@ -424,8 +434,8 @@ class collection(object): For methods where the value to remove is known at call-time, use collection.remove. - """ + """ def decorator(fn): setattr(fn, '_sa_instrument_after', 'fire_remove_event') return fn @@ -437,7 +447,6 @@ class collection(object): # implementations def collection_adapter(collection): """Fetch the CollectionAdapter for a collection.""" - return getattr(collection, '_sa_adapter', None) def collection_iter(collection): @@ -445,8 +454,8 @@ def collection_iter(collection): If the collection is an ORM collection, it need not be attached to an object to be iterable. - """ + """ try: return getattr(collection, '_sa_iterator', getattr(collection, '__iter__'))() @@ -464,8 +473,8 @@ class CollectionAdapter(object): The ORM uses an CollectionAdapter exclusively for interaction with entity collections. 
- """ + """ def __init__(self, attr, owner_state, data): self.attr = attr self._data = weakref.ref(data) @@ -477,14 +486,12 @@ class CollectionAdapter(object): def link_to_self(self, data): """Link a collection to this adapter, and fire a link event.""" - setattr(data, '_sa_adapter', self) if hasattr(data, '_sa_on_link'): getattr(data, '_sa_on_link')(self) def unlink(self, data): """Unlink a collection from any adapter, and fire a link event.""" - setattr(data, '_sa_adapter', None) if hasattr(data, '_sa_on_link'): getattr(data, '_sa_on_link')(None) @@ -501,8 +508,8 @@ class CollectionAdapter(object): If a converter implementation is not supplied on the collection, a default duck-typing-based implementation is used. - """ + """ converter = getattr(self._data(), '_sa_converter', None) if converter is not None: return converter(obj) @@ -531,44 +538,36 @@ class CollectionAdapter(object): def append_with_event(self, item, initiator=None): """Add an entity to the collection, firing mutation events.""" - getattr(self._data(), '_sa_appender')(item, _sa_initiator=initiator) def append_without_event(self, item): """Add or restore an entity to the collection, firing no events.""" - getattr(self._data(), '_sa_appender')(item, _sa_initiator=False) def remove_with_event(self, item, initiator=None): """Remove an entity from the collection, firing mutation events.""" - getattr(self._data(), '_sa_remover')(item, _sa_initiator=initiator) def remove_without_event(self, item): """Remove an entity from the collection, firing no events.""" - getattr(self._data(), '_sa_remover')(item, _sa_initiator=False) def clear_with_event(self, initiator=None): """Empty the collection, firing a mutation event for each entity.""" - for item in list(self): self.remove_with_event(item, initiator) def clear_without_event(self): """Empty the collection, firing no events.""" - for item in list(self): self.remove_without_event(item) def __iter__(self): """Iterate over entities in the collection.""" - return 
getattr(self._data(), '_sa_iterator')() def __len__(self): """Count entities in the collection.""" - return len(list(getattr(self._data(), '_sa_iterator')())) def __nonzero__(self): @@ -580,8 +579,8 @@ class CollectionAdapter(object): Initiator is the InstrumentedAttribute that initiated the membership mutation, and should be left as None unless you are passing along an initiator value from a chained operation. - """ + """ if initiator is not False and item is not None: self.attr.fire_append_event(self.owner_state, item, initiator) @@ -591,8 +590,8 @@ class CollectionAdapter(object): Initiator is the InstrumentedAttribute that initiated the membership mutation, and should be left as None unless you are passing along an initiator value from a chained operation. - """ + """ if initiator is not False and item is not None: self.attr.fire_remove_event(self.owner_state, item, initiator) @@ -601,8 +600,8 @@ class CollectionAdapter(object): Only called if the entity cannot be removed after calling fire_remove_event(). - """ + """ self.attr.fire_pre_remove_event(self.owner_state, initiator=initiator) def __getstate__(self): @@ -653,8 +652,7 @@ def bulk_replace(values, existing_adapter, new_adapter): for member in removals: existing_adapter.remove_with_event(member) -__instrumentation_mutex = sautil.threading.Lock() -def _prepare_instrumentation(factory): +def prepare_instrumentation(factory): """Prepare a callable for future use as a collection class factory. Given a collection class factory (either a type or no-arg callable), @@ -663,8 +661,8 @@ def _prepare_instrumentation(factory): This function is responsible for converting collection_class=list into the run-time behavior of collection_class=InstrumentedList. - """ + """ # Convert a builtin to 'Instrumented*' if factory in __canned_instrumentation: factory = __canned_instrumentation[factory] @@ -694,8 +692,8 @@ def __converting_factory(original_factory): Given a collection factory that returns a builtin type (e.g. 
a list), return a wrapped function that converts that type to one of our instrumented types. - """ + """ def wrapper(): collection = original_factory() type_ = type(collection) @@ -704,7 +702,7 @@ def __converting_factory(original_factory): # collection return __canned_instrumentation[type_](collection) else: - raise exceptions.InvalidRequestError( + raise sa_exc.InvalidRequestError( "Collection class factories must produce instances of a " "single class.") try: @@ -717,7 +715,6 @@ def __converting_factory(original_factory): def _instrument_class(cls): """Modify methods in a class and install instrumentation.""" - # FIXME: more formally document this as a decoratorless/Python 2.3 # option for specifying instrumentation. (likely doc'd here in code only, # not in online docs.) @@ -737,7 +734,7 @@ def _instrument_class(cls): # types is transformed into one of our trivial subclasses # (e.g. InstrumentedList). Catch anything else that sneaks in here... if cls.__module__ == '__builtin__': - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Can not instrument a built-in type. 
Use a " "subclass, even a trivial one.") @@ -790,7 +787,7 @@ def _instrument_class(cls): # ensure all roles are present, and apply implicit instrumentation if # needed if 'appender' not in roles or not hasattr(cls, roles['appender']): - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Type %s must elect an appender method to be " "a collection class" % cls.__name__) elif (roles['appender'] not in methods and @@ -798,7 +795,7 @@ def _instrument_class(cls): methods[roles['appender']] = ('fire_append_event', 1, None) if 'remover' not in roles or not hasattr(cls, roles['remover']): - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Type %s must elect a remover method to be " "a collection class" % cls.__name__) elif (roles['remover'] not in methods and @@ -806,7 +803,7 @@ def _instrument_class(cls): methods[roles['remover']] = ('fire_remove_event', 1, None) if 'iterator' not in roles or not hasattr(cls, roles['iterator']): - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Type %s must elect an iterator method to be " "a collection class" % cls.__name__) @@ -824,7 +821,6 @@ def _instrument_class(cls): def _instrument_membership_mutator(method, before, argument, after): """Route method args and/or return value through the collection adapter.""" - # This isn't smart enough to handle @adds(1) for 'def fn(self, (a, b))' if before: fn_args = list(sautil.flatten_iterator(inspect.getargspec(method)[0])) @@ -843,7 +839,7 @@ def _instrument_membership_mutator(method, before, argument, after): if before: if pos_arg is None: if named_arg not in kw: - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Missing argument %s" % argument) value = kw[named_arg] else: @@ -852,7 +848,7 @@ def _instrument_membership_mutator(method, before, argument, after): elif named_arg in kw: value = kw[named_arg] else: - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Missing argument %s" % argument) initiator = 
kw.pop('_sa_initiator', None) @@ -881,7 +877,6 @@ def _instrument_membership_mutator(method, before, argument, after): def __set(collection, item, _sa_initiator=None): """Run set events, may eventually be inlined into decorators.""" - if _sa_initiator is not False and item is not None: executor = getattr(collection, '_sa_adapter', None) if executor: @@ -889,7 +884,6 @@ def __set(collection, item, _sa_initiator=None): def __del(collection, item, _sa_initiator=None): """Run del events, may eventually be inlined into decorators.""" - if _sa_initiator is not False and item is not None: executor = getattr(collection, '_sa_adapter', None) if executor: @@ -897,14 +891,12 @@ def __del(collection, item, _sa_initiator=None): def __before_delete(collection, _sa_initiator=None): """Special method to run 'commit existing value' methods""" - executor = getattr(collection, '_sa_adapter', None) if executor: getattr(executor, 'fire_pre_remove_event')(_sa_initiator) def _list_decorators(): - """Hand-turned instrumentation wrappers that can decorate any list-like - class.""" + """Tailored instrumentation wrappers for any list-like class.""" def _tidy(fn): setattr(fn, '_sa_instrumented', True) @@ -1045,14 +1037,13 @@ def _list_decorators(): return l def _dict_decorators(): - """Hand-turned instrumentation wrappers that can decorate any dict-like - mapping class.""" + """Tailored instrumentation wrappers for any dict-like mapping class.""" def _tidy(fn): setattr(fn, '_sa_instrumented', True) fn.__doc__ = getattr(getattr(dict, fn.__name__), '__doc__') - Unspecified=sautil.symbol('Unspecified') + Unspecified = sautil.symbol('Unspecified') def __setitem__(fn): def __setitem__(self, key, value, _sa_initiator=None): @@ -1157,14 +1148,13 @@ def _set_binops_check_loose(self, obj): def _set_decorators(): - """Hand-turned instrumentation wrappers that can decorate any set-like - sequence class.""" + """Tailored instrumentation wrappers for any set-like class.""" def _tidy(fn): setattr(fn, 
'_sa_instrumented', True) fn.__doc__ = getattr(getattr(Set, fn.__name__), '__doc__') - Unspecified=sautil.symbol('Unspecified') + Unspecified = sautil.symbol('Unspecified') def add(fn): def add(self, value, _sa_initiator=None): @@ -1365,6 +1355,7 @@ class MappedCollection(dict): ``set`` and ``remove`` are implemented in terms of a keying function: any callable that takes an object and returns an object for use as a dictionary key. + """ def __init__(self, keyfunc): @@ -1374,16 +1365,17 @@ class MappedCollection(dict): returns an object for use as a dictionary key. The keyfunc will be called every time the ORM needs to add a member by - value-only (such as when loading instances from the database) or remove - a member. The usual cautions about dictionary keying apply- + value-only (such as when loading instances from the database) or + remove a member. The usual cautions about dictionary keying apply- ``keyfunc(object)`` should return the same output for the life of the collection. Keying based on mutable properties can result in unreachable instances "lost" in the collection. + """ self.keyfunc = keyfunc def set(self, value, _sa_initiator=None): - """Add an item to the collection, with a key provided by this instance's keyfunc.""" + """Add an item by value, consulting the keyfunc for the key.""" key = self.keyfunc(value) self.__setitem__(key, value, _sa_initiator) @@ -1391,13 +1383,13 @@ class MappedCollection(dict): set = collection.appender(set) def remove(self, value, _sa_initiator=None): - """Remove an item from the collection by value, consulting this instance's keyfunc for the key.""" + """Remove an item by value, consulting the keyfunc for the key.""" key = self.keyfunc(value) # Let self[key] raise if key is not in this collection # testlib.pragma exempt:__ne__ if self[key] != value: - raise exceptions.InvalidRequestError( + raise sa_exc.InvalidRequestError( "Can not remove '%s': collection holds '%s' for key '%s'. 
" "Possible cause: is the MappedCollection key function " "based on mutable properties or properties that only obtain " @@ -1418,8 +1410,8 @@ class MappedCollection(dict): Raises a TypeError if the key in any (key, value) pair in the dictlike object does not match the key that this collection's keyfunc would have assigned for that value. - """ + """ for incoming_key, value in sautil.dictlike_iteritems(dictlike): new_key = self.keyfunc(value) if incoming_key != new_key: diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index c667460a7..24bbdadce 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -4,14 +4,17 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +"""Relationship dependencies. -"""Bridge the ``PropertyLoader`` (i.e. a ``relation()``) and the +Bridges the ``PropertyLoader`` (i.e. a ``relation()``) and the ``UOWTransaction`` together to allow processing of relation()-based - dependencies at flush time. +dependencies at flush time. 
+ """ -from sqlalchemy.orm import sync -from sqlalchemy import sql, util, exceptions +from sqlalchemy import sql, util +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.orm import attributes, exc, sync from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY @@ -21,10 +24,7 @@ def create_dependency_processor(prop): MANYTOONE: ManyToOneDP, MANYTOMANY : ManyToManyDP, } - if prop.association is not None: - return AssociationDP(prop) - else: - return types[prop.direction](prop) + return types[prop.direction](prop) class DependencyProcessor(object): no_dependencies = False @@ -36,7 +36,7 @@ class DependencyProcessor(object): self.parent = prop.parent self.secondary = prop.secondary self.direction = prop.direction - self.is_backref = prop.is_backref + self.is_backref = prop._is_backref self.post_update = prop.post_update self.foreign_keys = prop.foreign_keys self.passive_deletes = prop.passive_deletes @@ -44,21 +44,21 @@ class DependencyProcessor(object): self.enable_typechecks = prop.enable_typechecks self.key = prop.key if not self.prop.synchronize_pairs: - raise exceptions.ArgumentError("Can't build a DependencyProcessor for relation %s. No target attributes to populate between parent and child are present" % self.prop) + raise sa_exc.ArgumentError("Can't build a DependencyProcessor for relation %s. No target attributes to populate between parent and child are present" % self.prop) def _get_instrumented_attribute(self): """Return the ``InstrumentedAttribute`` handled by this ``DependencyProecssor``. 
""" - return getattr(self.parent.class_, self.key) + return self.parent.class_manager.get_impl(self.key) def hasparent(self, state): """return True if the given object instance has a parent, according to the ``InstrumentedAttribute`` handled by this ``DependencyProcessor``.""" # TODO: use correct API for this - return self._get_instrumented_attribute().impl.hasparent(state) + return self._get_instrumented_attribute().hasparent(state) def register_dependencies(self, uowcommit): """Tell a ``UOWTransaction`` what mappers are dependent on @@ -111,7 +111,7 @@ class DependencyProcessor(object): if not self.enable_typechecks: return if state is not None and not self.mapper._canload(state): - raise exceptions.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper)) + raise exc.FlushError("Attempting to flush an item of type %s on collection '%s', which is handled by mapper '%s' and does not load items of that type. Did you mean to use a polymorphic mapper for this relationship ? Set 'enable_typechecks=False' on the relation() to disable this exception. Mismatched typeloading may cause bi-directional relationships (backrefs) to not function properly." % (state.class_, self.prop, self.mapper)) def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): """Called during a flush to synchronize primary key identifier @@ -167,9 +167,9 @@ class OneToManyDP(DependencyProcessor): # the child objects have to have their foreign key to the parent set to NULL # this phase can be called safely for any cascade but is unnecessary if delete cascade # is on. 
- if self.post_update or not self.passive_deletes=='all': + if self.post_update or not self.passive_deletes == 'all': for state in deplist: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) if unchanged or deleted: for child in deleted: if child is not None and self.hasparent(child) is False: @@ -204,9 +204,9 @@ class OneToManyDP(DependencyProcessor): # head object is being deleted, and we manage its list of child objects # the child objects have to have their foreign key to the parent set to NULL if not self.post_update: - should_null_fks = not self.cascade.delete and not self.passive_deletes=='all' + should_null_fks = not self.cascade.delete and not self.passive_deletes == 'all' for state in deplist: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) if unchanged or deleted: for child in deleted: if child is not None and self.hasparent(child) is False: @@ -220,7 +220,7 @@ class OneToManyDP(DependencyProcessor): uowcommit.register_object(child) else: for state in deplist: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True) if added or deleted: for child in added: if child is not None: @@ -231,7 +231,9 @@ class OneToManyDP(DependencyProcessor): elif self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c._state, isdelete=True) + uowcommit.register_object( + attributes.instance_state(c), + isdelete=True) if not self.passive_updates and self._pks_changed(uowcommit, state): if not 
unchanged: (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=False) @@ -287,10 +289,10 @@ class DetectKeySwitch(DependencyProcessor): for s in [elem for elem in uowcommit.session.identity_map.all_states() if issubclass(elem.class_, self.parent.class_) and self.key in elem.dict and - elem.dict[self.key]._state in switchers + attributes.instance_state(elem.dict[self.key]) in switchers ]: uowcommit.register_object(s, listonly=self.passive_updates) - sync.populate(s.dict[self.key]._state, self.mapper, s, self.parent, self.prop.synchronize_pairs) + sync.populate(attributes.instance_state(s.dict[self.key]), self.mapper, s, self.parent, self.prop.synchronize_pairs) #self.syncrules.execute(s.dict[self.key]._state, s, None, None, False) def _pks_changed(self, uowcommit, state): @@ -316,17 +318,17 @@ class ManyToOneDP(DependencyProcessor): def process_dependencies(self, task, deplist, uowcommit, delete = False): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " process_dep isdelete " + repr(delete) + " direction " + repr(self.direction) if delete: - if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes=='all': + if self.post_update and not self.cascade.delete_orphan and not self.passive_deletes == 'all': # post_update means we have to update our row to not reference the child object # before we can DELETE the row for state in deplist: self._synchronize(state, None, None, True, uowcommit) - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) if added or unchanged or deleted: self._conditional_post_update(state, uowcommit, deleted + unchanged + added) else: for state in deplist: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=True) + (added, unchanged, deleted) = 
uowcommit.get_attribute_history(state, self.key, passive=True) if added or deleted or unchanged: for child in added: self._synchronize(state, child, None, False, uowcommit) @@ -339,7 +341,7 @@ class ManyToOneDP(DependencyProcessor): if delete: if self.cascade.delete or self.cascade.delete_orphan: for state in deplist: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) if self.cascade.delete_orphan: todelete = added + unchanged + deleted else: @@ -349,18 +351,21 @@ class ManyToOneDP(DependencyProcessor): continue uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c._state, isdelete=True) + uowcommit.register_object( + attributes.instance_state(c), isdelete=True) else: for state in deplist: uowcommit.register_object(state) if self.cascade.delete_orphan: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) if deleted: for child in deleted: if self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c._state, isdelete=True) + uowcommit.register_object( + attributes.instance_state(c), + isdelete=True) def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): @@ -400,7 +405,7 @@ class ManyToManyDP(DependencyProcessor): if delete: for state in deplist: - (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key,passive=self.passive_deletes) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=self.passive_deletes) if deleted or unchanged: for child in deleted + unchanged: if 
child is None or (reverse_dep and (reverse_dep, "manytomany", child, state) in uowcommit.attributes): @@ -443,13 +448,13 @@ class ManyToManyDP(DependencyProcessor): statement = self.secondary.delete(sql.and_(*[c == sql.bindparam(c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) result = connection.execute(statement, secondary_delete) if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_delete): - raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete))) + raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of secondary table rows deleted from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_delete))) if secondary_update: statement = self.secondary.update(sql.and_(*[c == sql.bindparam("old_" + c.key, type_=c.type) for c in self.secondary.c if c.key in associationrow])) result = connection.execute(statement, secondary_update) if result.supports_sane_multi_rowcount() and result.rowcount != len(secondary_update): - raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update))) + raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of secondary table rows updated from table '%s': %d" % (result.rowcount, self.secondary.description, len(secondary_update))) if secondary_insert: statement = self.secondary.insert() @@ -459,13 +464,14 @@ class ManyToManyDP(DependencyProcessor): #print self.mapper.mapped_table.name + " " + self.key + " " + repr(len(deplist)) + " preprocess_dep isdelete " + repr(delete) + " direction " + repr(self.direction) if not delete: for state in deplist: - (added, unchanged, deleted) = 
uowcommit.get_attribute_history(state, self.key,passive=True) + (added, unchanged, deleted) = uowcommit.get_attribute_history(state, self.key, passive=True) if deleted: for child in deleted: if self.cascade.delete_orphan and self.hasparent(child) is False: uowcommit.register_object(child, isdelete=True) for c, m in self.mapper.cascade_iterator('delete', child): - uowcommit.register_object(c._state, isdelete=True) + uowcommit.register_object( + attributes.instance_state(c), isdelete=True) def _synchronize(self, state, child, associationrow, clearkeys, uowcommit): if associationrow is None: @@ -478,12 +484,6 @@ class ManyToManyDP(DependencyProcessor): def _pks_changed(self, uowcommit, state): return sync.source_changes(uowcommit, state, self.parent, self.prop.synchronize_pairs) -class AssociationDP(OneToManyDP): - def __init__(self, *args, **kwargs): - super(AssociationDP, self).__init__(*args, **kwargs) - self.cascade.delete = True - self.cascade.delete_orphan = True - class MapperStub(object): """Pose as a Mapper representing the association table in a many-to-many join, when performing a ``flush()``. diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index 133ad99c8..08e6a57f4 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -1,8 +1,21 @@ -"""'dynamic' collection API. returns Query() objects on the 'read' side, alters -a special AttributeHistory on the 'write' side.""" +# dynamic.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import exceptions, util, logging -from sqlalchemy.orm import attributes, object_session, util as mapperutil, strategies +"""Dynamic collection API. + +Dynamic collections act like Query() objects for read operations and support +basic add/delete mutation. 
+ +""" + +from sqlalchemy import log, util +import sqlalchemy.exceptions as sa_exc + +from sqlalchemy.orm import attributes, object_session, \ + util as mapperutil, strategies from sqlalchemy.orm.query import Query from sqlalchemy.orm.mapper import has_identity, object_mapper @@ -12,16 +25,19 @@ class DynaLoader(strategies.AbstractRelationLoader): self.is_class_level = True self._register_attribute(self.parent.class_, impl_class=DynamicAttributeImpl, target_mapper=self.parent_property.mapper, order_by=self.parent_property.order_by) - def create_row_processor(self, selectcontext, mapper, row): - return (None, None, None) + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return (None, None) -DynaLoader.logger = logging.class_logger(DynaLoader) +DynaLoader.logger = log.class_logger(DynaLoader) class DynamicAttributeImpl(attributes.AttributeImpl): - def __init__(self, class_, key, typecallable, target_mapper, order_by, **kwargs): - super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, **kwargs) + uses_objects = True + accepts_scalar_loader = False + + def __init__(self, class_, key, typecallable, class_manager, target_mapper, order_by, **kwargs): + super(DynamicAttributeImpl, self).__init__(class_, key, typecallable, class_manager, **kwargs) self.target_mapper = target_mapper - self.order_by=order_by + self.order_by = order_by self.query_class = AppenderQuery def get(self, state, passive=False): @@ -41,20 +57,18 @@ class DynamicAttributeImpl(attributes.AttributeImpl): state.modified = True if self.trackparent and value is not None: - self.sethasparent(value._state, True) - instance = state.obj() + self.sethasparent(attributes.instance_state(value), True) for ext in self.extensions: - ext.append(instance, value, initiator or self) + ext.append(state, value, initiator or self) def fire_remove_event(self, state, value, initiator): state.modified = True if self.trackparent and value is not None: - self.sethasparent(value._state, 
False) + self.sethasparent(attributes.instance_state(value), False) - instance = state.obj() for ext in self.extensions: - ext.remove(instance, value, initiator or self) + ext.remove(state, value, initiator or self) def set(self, state, value, initiator): if initiator is self: @@ -111,26 +125,32 @@ class AppenderQuery(Query): def session(self): return self.__session() - session = property(session) + session = property(session, lambda s, x:None) def __iter__(self): sess = self.__session() if sess is None: - return iter(self.attr._get_collection_history(self.instance._state, passive=True).added_items) + return iter(self.attr._get_collection_history( + attributes.instance_state(self.instance), + passive=True).added_items) else: return iter(self._clone(sess)) def __getitem__(self, index): sess = self.__session() if sess is None: - return self.attr._get_collection_history(self.instance._state, passive=True).added_items.__getitem__(index) + return self.attr._get_collection_history( + attributes.instance_state(self.instance), + passive=True).added_items.__getitem__(index) else: return self._clone(sess).__getitem__(index) def count(self): sess = self.__session() if sess is None: - return len(self.attr._get_collection_history(self.instance._state, passive=True).added_items) + return len(self.attr._get_collection_history( + attributes.instance_state(self.instance), + passive=True).added_items) else: return self._clone(sess).count() @@ -142,10 +162,7 @@ class AppenderQuery(Query): if sess is None: sess = object_session(instance) if sess is None: - try: - sess = object_mapper(instance).get_session() - except exceptions.InvalidRequestError: - raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key)) + raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual 
session is established; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.instance_str(instance), self.attr.key)) q = sess.query(self.attr.target_mapper).with_parent(instance, self.attr.key) if self.attr.order_by: @@ -158,14 +175,14 @@ class AppenderQuery(Query): oldlist = list(self) else: oldlist = [] - self.attr._get_collection_history(self.instance._state, passive=True).replace(oldlist, collection) + self.attr._get_collection_history(attributes.instance_state(self.instance), passive=True).replace(oldlist, collection) return oldlist def append(self, item): - self.attr.append(self.instance._state, item, None) + self.attr.append(attributes.instance_state(self.instance), item, None) def remove(self, item): - self.attr.remove(self.instance._state, item, None) + self.attr.remove(attributes.instance_state(self.instance), item, None) class CollectionHistory(object): diff --git a/lib/sqlalchemy/orm/exc.py b/lib/sqlalchemy/orm/exc.py new file mode 100644 index 000000000..2d1d2b108 --- /dev/null +++ b/lib/sqlalchemy/orm/exc.py @@ -0,0 +1,31 @@ +# exc.py - ORM exceptions +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +"""SQLAlchemy ORM exceptions.""" + +import sqlalchemy.exceptions as sa_exc + + +class ConcurrentModificationError(sa_exc.SQLAlchemyError): + """Rows have been modified outside of the unit of work.""" + + +class FlushError(sa_exc.SQLAlchemyError): + """A invalid condition was detected during flush().""" + + +class ObjectDeletedError(sa_exc.InvalidRequestError): + """An refresh() operation failed to re-retrieve an object's row.""" + + +class UnmappedColumnError(sa_exc.InvalidRequestError): + """Mapping operation was requested on an unknown column.""" + + +# Legacy compat until 0.6. 
+sa_exc.ConcurrentModificationError = ConcurrentModificationError +sa_exc.FlushError = FlushError +sa_exc.UnmappedColumnError diff --git a/lib/sqlalchemy/orm/identity.py b/lib/sqlalchemy/orm/identity.py new file mode 100644 index 000000000..4487e21dc --- /dev/null +++ b/lib/sqlalchemy/orm/identity.py @@ -0,0 +1,250 @@ +# identity.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import weakref + +from sqlalchemy import util as base_util +from sqlalchemy.orm import attributes + + +class IdentityMap(dict): + def __init__(self): + self._mutable_attrs = weakref.WeakKeyDictionary() + self.modified = False + + def add(self, state): + raise NotImplementedError() + + def remove(self, state): + raise NotImplementedError() + + def update(self, dict): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def clear(self): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def _manage_incoming_state(self, state): + if state.modified: + self.modified = True + if state.manager.mutable_attributes: + self._mutable_attrs[state] = True + + def _manage_removed_state(self, state): + if state in self._mutable_attrs: + del self._mutable_attrs[state] + + def check_modified(self): + """return True if any InstanceStates present have been marked as 'modified'.""" + + if not self.modified: + for state in self._mutable_attrs: + if state.check_modified(): + return True + else: + return False + else: + return True + + def has_key(self, key): + return key in self + + def popitem(self): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def pop(self, key, *args): + raise NotImplementedError("IdentityMap uses remove() to remove data") + + def setdefault(self, key, default=None): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def copy(self): + raise 
NotImplementedError() + + def __setitem__(self, key, value): + raise NotImplementedError("IdentityMap uses add() to insert data") + + def __delitem__(self, key): + raise NotImplementedError("IdentityMap uses remove() to remove data") + +class WeakInstanceDict(IdentityMap): + + def __init__(self): + IdentityMap.__init__(self) + self._wr = weakref.ref(self) + # RLock because the mutex is used by a cleanup + # handler, which can be called at any time (including within an already mutexed block) + self._mutex = base_util.threading.RLock() + + def __getitem__(self, key): + state = dict.__getitem__(self, key) + o = state.obj() + if o is None: + o = state._check_resurrect(self) + if o is None: + raise KeyError, key + return o + + def __contains__(self, key): + try: + state = dict.__getitem__(self, key) + o = state.obj() + if o is None: + o = state._check_resurrect(self) + except KeyError: + return False + return o is not None + + def contains_state(self, state): + return dict.get(self, state.key) is state + + def add(self, state): + if state.key in self: + if dict.__getitem__(self, state.key) is not state: + raise AssertionError("A conflicting state is already present in the identity map for key %r" % state.key) + else: + dict.__setitem__(self, state.key, state) + state._instance_dict = self._wr + self._manage_incoming_state(state) + + def remove_key(self, key): + state = dict.__getitem__(self, key) + self.remove(state) + + def remove(self, state): + if not self.contains_state(state): + raise AssertionError("State %s is not present in this identity map" % state) + dict.__delitem__(self, state.key) + del state._instance_dict + self._manage_removed_state(state) + + def discard(self, state): + if self.contains_state(state): + dict.__delitem__(self, state.key) + del state._instance_dict + self._manage_removed_state(state) + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def items(self): + return list(self.iteritems()) + + 
def iteritems(self): + for state in dict.itervalues(self): + value = state.obj() + if value is not None: + yield state.key, value + + def itervalues(self): + for state in dict.itervalues(self): + instance = state.obj() + if instance is not None: + yield instance + + def values(self): + return list(self.itervalues()) + + def all_states(self): + return dict.values(self) + + def prune(self): + return 0 + +class StrongInstanceDict(IdentityMap): + def all_states(self): + return [attributes.instance_state(o) for o in self.values()] + + def contains_state(self, state): + return state.key in self and attributes.instance_state(self[state.key]) is state + + def add(self, state): + dict.__setitem__(self, state.key, state.obj()) + self._manage_incoming_state(state) + + def remove(self, state): + if not self.contains_state(state): + raise AssertionError("State %s is not present in this identity map" % state) + dict.__delitem__(self, state.key) + self._manage_removed_state(state) + + def discard(self, state): + if self.contains_state(state): + dict.__delitem__(self, state.key) + self._manage_removed_state(state) + + def remove_key(self, key): + state = dict.__getitem__(self, key) + self.remove(state) + + def prune(self): + """prune unreferenced, non-dirty states.""" + + ref_count = len(self) + dirty = [s.obj() for s in self.all_states() if s.check_modified()] + keepers = weakref.WeakValueDictionary(self) + dict.clear(self) + dict.update(self, keepers) + self.modified = bool(dirty) + return ref_count - len(self) + +class IdentityManagedState(attributes.InstanceState): + def _instance_dict(self): + return None + + def _check_resurrect(self, instance_dict): + instance_dict._mutex.acquire() + try: + return self.obj() or self.__resurrect(instance_dict) + finally: + instance_dict._mutex.release() + + def modified_event(self, attr, should_copy, previous, passive=False): + attributes.InstanceState.modified_event(self, attr, should_copy, previous, passive) + + instance_dict = 
self._instance_dict() + if instance_dict: + instance_dict.modified = True + + def _cleanup(self, ref): + # tiptoe around Python GC unpredictableness + try: + instance_dict = self._instance_dict() + instance_dict._mutex.acquire() + except: + return + # the mutexing here is based on the assumption that gc.collect() + # may be firing off cleanup handlers in a different thread than that + # which is normally operating upon the instance dict. + try: + try: + self.__resurrect(instance_dict) + except: + # catch app cleanup exceptions. no other way around this + # without warnings being produced + pass + finally: + instance_dict._mutex.release() + + def __resurrect(self, instance_dict): + if self.check_modified(): + # store strong ref'ed version of the object; will revert + # to weakref when changes are persisted + obj = self.manager.new_instance(state=self) + self.obj = weakref.ref(obj, self._cleanup) + self._strong_obj = obj + # todo: revisit this wrt user-defined-state + obj.__dict__.update(self.dict) + self.dict = obj.__dict__ + self._run_on_load(obj) + return obj + else: + instance_dict.remove(self) + self.dispose() + return None diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index d61ebe960..6c9fe7753 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -4,27 +4,45 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Semi-private implementation objects which form the basis -of ORM-mapped attributes, query options and mapper extension. +""" + +Semi-private implementation objects which form the basis of ORM-mapped +attributes, query options and mapper extension. + +Defines the [sqlalchemy.orm.interfaces#MapperExtension] class, which can be +end-user subclassed to add event-based functionality to mappers. The +remainder of this module is generally private to the ORM. 
-Defines the [sqlalchemy.orm.interfaces#MapperExtension] class, -which can be end-user subclassed to add event-based functionality -to mappers. The remainder of this module is generally private to the -ORM. """ from itertools import chain -from sqlalchemy import exceptions, logging, util + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import log, util from sqlalchemy.sql import expression -class_mapper = None -__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension', - 'MapperProperty', 'PropComparator', 'StrategizedProperty', - 'build_path', 'MapperOption', - 'ExtensionOption', 'PropertyOption', - 'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ] +class_mapper = None +collections = None + +__all__ = ( + 'AttributeExtension', + 'EXT_CONTINUE', + 'EXT_STOP', + 'ExtensionOption', + 'InstrumentationManager', + 'LoaderStrategy', + 'MapperExtension', + 'MapperOption', + 'MapperProperty', + 'PropComparator', + 'PropertyOption', + 'SessionExtension', + 'StrategizedOption', + 'StrategizedProperty', + 'build_path', + ) -EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE') +EXT_CONTINUE = util.symbol('EXT_CONTINUE') EXT_STOP = util.symbol('EXT_STOP') ONETOMANY = util.symbol('ONETOMANY') @@ -44,10 +62,7 @@ class MapperExtension(object): these exception cases, any return value other than EXT_CONTINUE or EXT_STOP will be interpreted as equivalent to EXT_STOP. - EXT_PASS is a synonym for EXT_CONTINUE and is provided for backward - compatibility. """ - def instrument_class(self, mapper, class_): return EXT_CONTINUE @@ -57,16 +72,6 @@ class MapperExtension(object): def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): return EXT_CONTINUE - def get_session(self): - """Retrieve a contextual Session instance with which to - register a new object. - - Note: this is not called if a session is provided with the - `__init__` params (i.e. `_sa_session`). 
- """ - - return EXT_CONTINUE - def load(self, query, *args, **kwargs): """Override the `load` method of the Query object. @@ -85,43 +90,6 @@ class MapperExtension(object): return EXT_CONTINUE - def get_by(self, query, *args, **kwargs): - """Override the `get_by` method of the Query object. - - The return value of this method is used as the result of - ``query.get_by()`` if the value is anything other than - EXT_CONTINUE. - - DEPRECATED. - """ - - return EXT_CONTINUE - - def select_by(self, query, *args, **kwargs): - """Override the `select_by` method of the Query object. - - The return value of this method is used as the result of - ``query.select_by()`` if the value is anything other than - EXT_CONTINUE. - - DEPRECATED. - """ - - return EXT_CONTINUE - - def select(self, query, *args, **kwargs): - """Override the `select` method of the Query object. - - The return value of this method is used as the result of - ``query.select()`` if the value is anything other than - EXT_CONTINUE. - - DEPRECATED. - """ - - return EXT_CONTINUE - - def translate_row(self, mapper, context, row): """Perform pre-processing on the given result row and return a new row instance. @@ -276,6 +244,56 @@ class MapperExtension(object): return EXT_CONTINUE +class SessionExtension(object): + """An extension hook object for Sessions. Subclasses may be installed into a Session + (or sessionmaker) using the ``extension`` keyword argument. + """ + + def before_commit(self, session): + """Execute right before commit is called. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def after_commit(self, session): + """Execute after a commit has occured. + + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def after_rollback(self, session): + """Execute after a rollback has occured. 
+ + Note that this may not be per-flush if a longer running transaction is ongoing.""" + + def before_flush(self, session, flush_context, instances): + """Execute before flush process has started. + + `instances` is an optional list of objects which were passed to the ``flush()`` + method. + """ + + def after_flush(self, session, flush_context): + """Execute after flush has completed, but before commit has been called. + + Note that the session's state is still in pre-flush, i.e. 'new', 'dirty', + and 'deleted' lists still show pre-flush state as well as the history + settings on instance attributes.""" + + def after_flush_postexec(self, session, flush_context): + """Execute after flush has completed, and after the post-exec state occurs. + + This will be when the 'new', 'dirty', and 'deleted' lists are in their final + state. An actual commit() may or may not have occured, depending on whether or not + the flush started its own transaction or participated in a larger transaction. + """ + + def after_begin(self, session, transaction, connection): + """Execute after a transaction is begun on a connection + + `transaction` is the SessionTransaction. This method is called after an + engine level transaction is begun on a connection. + """ + + class MapperProperty(object): """Manage the relationship of a ``Mapper`` to a single class attribute, as well as that attribute as it appears on individual @@ -283,7 +301,7 @@ class MapperProperty(object): attribute access, loading behavior, and dependency calculations. """ - def setup(self, querycontext, **kwargs): + def setup(self, context, entity, path, adapter, **kwargs): """Called by Query for the purposes of constructing a SQL statement. Each MapperProperty associated with the target mapper processes the @@ -293,8 +311,8 @@ class MapperProperty(object): pass - def create_row_processor(self, selectcontext, mapper, row): - """Return a 3-tuple consiting of two row processing functions and an instance post-processing function. 
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter): + """Return a 2-tuple consiting of two row processing functions and an instance post-processing function. Input arguments are the query.SelectionContext and the *first* applicable row of a result set obtained within @@ -305,32 +323,24 @@ class MapperProperty(object): columns present in the row (which will be the same columns present in all rows) are used to determine the presence and behavior of the returned callables. The callables will then be used to process all - rows and to post-process all instances, respectively. + rows and instances. Callables are of the following form:: - def new_execute(instance, row, **flags): - # process incoming instance and given row. the instance is + def new_execute(state, row, **flags): + # process incoming instance state and given row. the instance is # "new" and was just created upon receipt of this row. # flags is a dictionary containing at least the following # attributes: # isnew - indicates if the instance was newly created as a # result of reading this row # instancekey - identity key of the instance - # optional attribute: - # ispostselect - indicates if this row resulted from a - # 'post' select of additional tables/columns - def existing_execute(instance, row, **flags): - # process incoming instance and given row. the instance is + def existing_execute(state, row, **flags): + # process incoming instance state and given row. the instance is # "existing" and was created based on a previous row. - def post_execute(instance, **flags): - # process instance after all result rows have been processed. - # this function should be used to issue additional selections - # in order to eagerly load additional properties. - - return (new_execute, existing_execute, post_execute) + return (new_execute, existing_execute) Either of the three tuples can be ``None`` in which case no function is called. 
@@ -347,20 +357,6 @@ class MapperProperty(object): return iter([]) - def get_criterion(self, query, key, value): - """Return a ``WHERE`` clause suitable for this - ``MapperProperty`` corresponding to the given key/value pair, - where the key is a column or object property name, and value - is a value to be matched. This is only picked up by - ``PropertyLoaders``. - - This is called by a ``Query``'s ``join_by`` method to formulate a set - of key/value pairs into a ``WHERE`` criterion that spans multiple - tables if needed. - """ - - return None - def set_parent(self, parent): self.parent = parent @@ -427,10 +423,10 @@ class PropComparator(expression.ColumnOperators): which returns the MapperProperty associated with this PropComparator. """ - - def expression_element(self): - return self.clause_element() - + + def __clause_element__(self): + raise NotImplementedError("%r" % self) + def contains_op(a, b): return a.contains(b) contains_op = staticmethod(contains_op) @@ -511,37 +507,44 @@ class StrategizedProperty(MapperProperty): ``StrategizedOption`` objects via the Query.options() method. 
""" - def _get_context_strategy(self, context): - path = context.path - return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__)) - + def __get_context_strategy(self, context, path): + cls = context.attributes.get(("loaderstrategy", path), None) + if cls: + try: + return self.__all_strategies[cls] + except KeyError: + return self.__init_strategy(cls) + else: + return self.strategy + def _get_strategy(self, cls): try: - return self._all_strategies[cls] + return self.__all_strategies[cls] except KeyError: - # cache the located strategy per class for faster re-lookup - strategy = cls(self) - strategy.init() - self._all_strategies[cls] = strategy - return strategy - - def setup(self, querycontext, **kwargs): - self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs) + return self.__init_strategy(cls) + + def __init_strategy(self, cls): + self.__all_strategies[cls] = strategy = cls(self) + strategy.init() + return strategy + + def setup(self, context, entity, path, adapter, **kwargs): + self.__get_context_strategy(context, path + (self.key,)).setup_query(context, entity, path, adapter, **kwargs) - def create_row_processor(self, selectcontext, mapper, row): - return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row) + def create_row_processor(self, context, path, mapper, row, adapter): + return self.__get_context_strategy(context, path + (self.key,)).create_row_processor(context, path, mapper, row, adapter) def do_init(self): - self._all_strategies = {} - self.strategy = self._get_strategy(self.strategy_class) + self.__all_strategies = {} + self.strategy = self.__init_strategy(self.strategy_class) if self.is_primary(): self.strategy.init_class_attribute() -def build_path(mapper, key, prev=None): +def build_path(entity, key, prev=None): if prev: - return prev + (mapper.base_mapper, key) + return prev + (entity, key) else: - return (mapper.base_mapper, key) + return 
(entity, key) def serialize_path(path): if path is None: @@ -585,9 +588,9 @@ class ExtensionOption(MapperOption): self.ext = ext def process_query(self, query): - query._extension = query._extension.copy() - query._extension.insert(self.ext) - + entity = query._generate_mapper_zero() + entity.extension = entity.extension.copy() + entity.extension.push(self.ext) class PropertyOption(MapperOption): """A MapperOption that is applied to a property off the mapper or @@ -607,60 +610,86 @@ class PropertyOption(MapperOption): def _process(self, query, raiseerr): if self._should_log_debug: self.logger.debug("applying option to Query, property key '%s'" % self.key) - paths = self._get_paths(query, raiseerr) + paths = self.__get_paths(query, raiseerr) if paths: self.process_query_property(query, paths) def process_query_property(self, query, paths): pass + + def __find_entity(self, query, mapper, raiseerr): + from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class + + if _is_aliased_class(mapper): + searchfor = mapper + else: + searchfor = _class_to_mapper(mapper).base_mapper - def _get_paths(self, query, raiseerr): + for ent in query._mapper_entities: + if ent.path_entity is searchfor: + return ent + else: + if raiseerr: + raise sa_exc.ArgumentError("Can't find entity %s in Query. Current list: %r" % (searchfor, [str(m.path_entity) for m in query._entities])) + else: + return None + + def __get_paths(self, query, raiseerr): path = None + entity = None l = [] + current_path = list(query._current_path) - + if self.mapper: - global class_mapper - if class_mapper is None: - from sqlalchemy.orm import class_mapper - mapper = self.mapper - if isinstance(self.mapper, type): - mapper = class_mapper(mapper) - if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]: - raise exceptions.ArgumentError("Can't find entity %s in Query. 
Current list: %r" % (str(mapper), [str(m) for m in query._entities])) - else: - mapper = query.mapper - if isinstance(self.key, basestring): - tokens = self.key.split('.') - else: - tokens = util.to_list(self.key) + entity = self.__find_entity(query, self.mapper, raiseerr) + mapper = entity.mapper + path_element = entity.path_entity - for token in tokens: - if isinstance(token, basestring): - prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr) - elif isinstance(token, PropComparator): - prop = token.property - token = prop.key - + for key in util.to_list(self.key): + if isinstance(key, basestring): + tokens = key.split('.') else: - raise exceptions.ArgumentError("mapper option expects string key or list of attributes") - - if current_path and token == current_path[1]: - current_path = current_path[2:] - continue + tokens = [key] + for token in tokens: + if isinstance(token, basestring): + if not entity: + entity = query._entity_zero() + path_element = entity.path_entity + mapper = entity.mapper + prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr) + key = token + elif isinstance(token, PropComparator): + prop = token.property + if not entity: + entity = self.__find_entity(query, token.parententity, raiseerr) + if not entity: + return [] + path_element = entity.path_entity + key = prop.key + else: + raise sa_exc.ArgumentError("mapper option expects string key or list of attributes") + + if current_path and key == current_path[1]: + current_path = current_path[2:] + continue - if prop is None: - return [] - path = build_path(mapper, prop.key, path) - l.append(path) - if getattr(token, '_of_type', None): - mapper = token._of_type - else: - mapper = getattr(prop, 'mapper', None) + if prop is None: + return [] + + path = build_path(path_element, prop.key, path) + l.append(path) + if getattr(token, '_of_type', None): + path_element = mapper = token._of_type + else: + path_element = mapper = getattr(prop, 'mapper', None) + 
if path_element: + path_element = path_element.base_mapper + return l -PropertyOption.logger = logging.class_logger(PropertyOption) -PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger) +PropertyOption.logger = log.class_logger(PropertyOption) +PropertyOption._should_log_debug = log.is_debug_enabled(PropertyOption.logger) class AttributeExtension(object): """An abstract class which specifies `append`, `delete`, and `set` @@ -732,10 +761,10 @@ class LoaderStrategy(object): def init_class_attribute(self): pass - def setup_query(self, context, **kwargs): + def setup_query(self, context, entity, path, adapter, **kwargs): pass - def create_row_processor(self, selectcontext, mapper, row): + def create_row_processor(self, selectcontext, path, mapper, row, adapter): """Return row processing functions which fulfill the contract specified by MapperProperty.create_row_processor. @@ -744,3 +773,71 @@ class LoaderStrategy(object): """ raise NotImplementedError() + + def __str__(self): + return str(self.parent_property) + + def debug_callable(self, fn, logger, announcement, logfn): + if announcement: + logger.debug(announcement) + if logfn: + def call(*args, **kwargs): + logger.debug(logfn(*args, **kwargs)) + return fn(*args, **kwargs) + return call + else: + return fn + +class InstrumentationManager(object): + """User-defined class instrumentation extension.""" + + # r4361 added a mandatory (cls) constructor to this interface. + # given that, perhaps class_ should be dropped from all of these + # signatures. 
+ + def __init__(self, class_): + pass + + def manage(self, class_, manager): + setattr(class_, '_default_class_manager', manager) + + def dispose(self, class_, manager): + delattr(class_, '_default_class_manager') + + def manager_getter(self, class_): + def get(cls): + return cls._default_class_manager + return get + + def instrument_attribute(self, class_, key, inst): + pass + + def install_descriptor(self, class_, key, inst): + setattr(class_, key, inst) + + def uninstall_descriptor(self, class_, key): + delattr(class_, key) + + def install_member(self, class_, key, implementation): + setattr(class_, key, implementation) + + def uninstall_member(self, class_, key): + delattr(class_, key) + + def instrument_collection_class(self, class_, key, collection_class): + global collections + if collections is None: + from sqlalchemy.orm import collections + return collections.prepare_instrumentation(collection_class) + + def get_instance_dict(self, class_, instance): + return instance.__dict__ + + def initialize_instance_dict(self, class_, instance): + pass + + def install_state(self, class_, instance, state): + setattr(instance, '_default_state', state) + + def state_getter(self, class_): + return lambda instance: getattr(instance, '_default_state') diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index ba0644758..6d79f6cd5 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1,27 +1,44 @@ -# orm/mapper.py +# mapper.py # Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -"""Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational +"""Logic to map Python classes to and from selectables. + +Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational unit which associates a class with a database table. 
This is a semi-private module; the main configurational API of the ORM is available in [sqlalchemy.orm#]. + """ import weakref from itertools import chain -from sqlalchemy import sql, util, exceptions, logging -from sqlalchemy.sql import expression, visitors, operators, util as sqlutil -from sqlalchemy.orm import sync, attributes -from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator -from sqlalchemy.orm.util import has_identity, _state_has_identity, _is_mapped_class, has_mapper, \ - _state_mapper, class_mapper, object_mapper, _class_to_mapper,\ - ExtensionCarrier, state_str, instance_str - -__all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry'] + +from sqlalchemy import sql, util, log +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.sql import expression, visitors, operators +import sqlalchemy.sql.util as sqlutil +from sqlalchemy.orm import attributes +from sqlalchemy.orm import exc +from sqlalchemy.orm import sync +from sqlalchemy.orm.identity import IdentityManagedState +from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, \ + PropComparator +from sqlalchemy.orm.util import \ + ExtensionCarrier, _INSTRUMENTOR, _class_to_mapper, _is_mapped_class, \ + _state_has_identity, _state_mapper, class_mapper, has_identity, \ + has_mapper, instance_str, object_mapper, state_str + + +__all__ = ( + 'Mapper', + '_mapper_registry', + 'class_mapper', + 'object_mapper', + ) _mapper_registry = weakref.WeakKeyDictionary() _new_mappers = False @@ -43,6 +60,7 @@ ColumnProperty = None SynonymProperty = None ComparableProperty = None _expire_state = None +_state_session = None class Mapper(object): @@ -85,10 +103,11 @@ class Mapper(object): Mappers are normally constructed via the [sqlalchemy.orm#mapper()] function. See for details. 
- + """ self.class_ = class_ + self.class_manager = None self.entity_name = entity_name self.primary_key_argument = primary_key self.non_primary = non_primary @@ -110,19 +129,18 @@ class Mapper(object): self.eager_defaults = eager_defaults self.column_prefix = column_prefix self.polymorphic_on = polymorphic_on - self._eager_loaders = util.Set() self._dependency_processors = [] self._clause_adapter = None self._requires_row_aliasing = False self.__inherits_equated_pairs = None - + if not issubclass(class_, object): - raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) + raise sa_exc.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) self.select_table = select_table if select_table: if with_polymorphic: - raise exceptions.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)") + raise sa_exc.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)") self.with_polymorphic = ('*', select_table) else: if with_polymorphic == '*': @@ -133,14 +151,14 @@ class Mapper(object): else: self.with_polymorphic = (with_polymorphic, None) elif with_polymorphic is not None: - raise exceptions.ArgumentError("Invalid setting for with_polymorphic") + raise sa_exc.ArgumentError("Invalid setting for with_polymorphic") else: self.with_polymorphic = None - + if isinstance(self.local_table, expression._SelectBaseMixin): util.warn("mapper %s creating an alias for the given selectable - use Class attributes for queries." % self) self.local_table = self.local_table.alias() - + if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin): self.with_polymorphic[1] = self.with_polymorphic[1].alias() @@ -148,12 +166,8 @@ class Mapper(object): # indicates this Mapper should be used to construct the object instance for that row. 
self.polymorphic_identity = polymorphic_identity - if polymorphic_fetch not in (None, 'union', 'select', 'deferred'): - raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch) - if polymorphic_fetch is None: - self.polymorphic_fetch = (self.with_polymorphic is None) and 'select' or 'union' - else: - self.polymorphic_fetch = polymorphic_fetch + if polymorphic_fetch: + util.warn_deprecated('polymorphic_fetch option is deprecated. Unloaded columns load as deferred in all cases; loading can be controlled using the "with_polymorphic" option.') # a dictionary of 'polymorphic identity' names, associating those names with # Mappers that will be used to construct object instances upon a select operation. @@ -170,14 +184,14 @@ class Mapper(object): # a set of all mappers which inherit from this one. self._inheriting_mappers = util.Set() - self.__props_init = False + self.compiled = False - self.__should_log_info = logging.is_info_enabled(self.logger) - self.__should_log_debug = logging.is_debug_enabled(self.logger) + self.__should_log_info = log.is_info_enabled(self.logger) + self.__should_log_debug = log.is_debug_enabled(self.logger) - self.__compile_class() self.__compile_inheritance() self.__compile_extensions() + self.__compile_class() self.__compile_properties() self.__compile_pks() global _new_mappers @@ -192,11 +206,12 @@ class Mapper(object): if self.__should_log_debug: self.logger.debug("(" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") + ") " + msg) - def _is_orphan(self, obj): + def _is_orphan(self, state): o = False for mapper in self.iterate_to_root(): - for (key,klass) in mapper.delete_orphans: - if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)): + for (key, cls) in mapper.delete_orphans: + if 
attributes.manager_of_class(cls).has_parent( + state, key, optimistic=_state_has_identity(state)): return False o = o or bool(mapper.delete_orphans) return o @@ -208,41 +223,26 @@ class Mapper(object): return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr) def _get_property(self, key, resolve_synonyms=False, raiseerr=True): - """private in-compilation version of get_property().""" - prop = self.__props.get(key, None) if resolve_synonyms: while isinstance(prop, SynonymProperty): prop = self.__props.get(prop.name, None) if prop is None and raiseerr: - raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key)) + raise sa_exc.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key)) return prop def iterate_properties(self): + """return an iterator of all MapperProperty objects.""" self.compile() return self.__props.itervalues() - iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.") + iterate_properties = property(iterate_properties) - def __adjust_wp_selectable(self, spec=None, selectable=False): - """given a with_polymorphic() argument, resolve it against this mapper's with_polymorphic setting""" - - isdefault = False - if self.with_polymorphic: - isdefault = not spec and selectable is False - - if not spec: - spec = self.with_polymorphic[0] - if selectable is False: - selectable = self.with_polymorphic[1] - - return spec, selectable, isdefault - def __mappers_from_spec(self, spec, selectable): """given a with_polymorphic() argument, return the set of mappers it represents. - + Trims the list of mappers to just those represented within the given selectable, if present. This helps some more legacy-ish mappings. 
- + """ if spec == '*': mappers = list(self.polymorphic_iterator()) @@ -250,86 +250,98 @@ class Mapper(object): mappers = [_class_to_mapper(m) for m in util.to_list(spec)] else: mappers = [] - + if selectable: - tables = util.Set(sqlutil.find_tables(selectable)) + tables = util.Set(sqlutil.find_tables(selectable, include_aliases=True)) mappers = [m for m in mappers if m.local_table in tables] - + return mappers - __mappers_from_spec = util.conditional_cache_decorator(__mappers_from_spec) - + def __selectable_from_mappers(self, mappers): """given a list of mappers (assumed to be within this mapper's inheritance hierarchy), construct an outerjoin amongst those mapper's mapped tables. - + """ from_obj = self.mapped_table for m in mappers: if m is self: continue if m.concrete: - raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") + raise sa_exc.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") elif not m.single: from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition) - + return from_obj - __selectable_from_mappers = util.conditional_cache_decorator(__selectable_from_mappers) - - def _with_polymorphic_mappers(self, spec=None, selectable=False): - spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) - return self.__mappers_from_spec(spec, selectable, cache=isdefault) - - def _with_polymorphic_selectable(self, spec=None, selectable=False): - spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) + + def _with_polymorphic_mappers(self): + if not self.with_polymorphic: + return [self] + return self.__mappers_from_spec(*self.with_polymorphic) + _with_polymorphic_mappers = property(util.cache_decorator(_with_polymorphic_mappers)) + + def _with_polymorphic_selectable(self): + if not self.with_polymorphic: + return self.mapped_table + + spec, selectable = self.with_polymorphic 
if selectable: return selectable else: - return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable, cache=isdefault), cache=isdefault) - + return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable)) + _with_polymorphic_selectable = property(util.cache_decorator(_with_polymorphic_selectable)) + def _with_polymorphic_args(self, spec=None, selectable=False): - spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) - mappers = self.__mappers_from_spec(spec, selectable, cache=isdefault) + if self.with_polymorphic: + if not spec: + spec = self.with_polymorphic[0] + if selectable is False: + selectable = self.with_polymorphic[1] + + mappers = self.__mappers_from_spec(spec, selectable) if selectable: return mappers, selectable else: - return mappers, self.__selectable_from_mappers(mappers, cache=isdefault) - - def _iterate_polymorphic_properties(self, spec=None, selectable=False): + return mappers, self.__selectable_from_mappers(mappers) + + def _iterate_polymorphic_properties(self, mappers=None): + if mappers is None: + mappers = self._with_polymorphic_mappers return iter(util.OrderedSet( - chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers(spec, selectable)]) + chain(*[list(mapper.iterate_properties) for mapper in [self] + mappers]) )) def properties(self): raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.") properties = property(properties) - def compiled(self): - """return True if this mapper is compiled""" - return self.__props_init - compiled = property(compiled) - def dispose(self): - # disaable any attribute-based compilation - self.__props_init = True - try: - del self.class_.c - except AttributeError: - pass - if not self.non_primary and self.entity_name in self._class_state.mappers: - del self._class_state.mappers[self.entity_name] - if not self._class_state.mappers: + 
# Disable any attribute-based compilation. + self.compiled = True + + manager = self.class_manager + mappers = manager.mappers + + if not self.non_primary and self.entity_name in mappers: + del mappers[self.entity_name] + if not mappers and manager.info.get(_INSTRUMENTOR, False): + for legacy in _legacy_descriptors.keys(): + manager.uninstall_member(legacy) + manager.events.remove_listener('on_init', _event_on_init) + manager.events.remove_listener('on_init_failure', + _event_on_init_failure) + manager.uninstall_member('__init__') + del manager.info[_INSTRUMENTOR] attributes.unregister_class(self.class_) def compile(self): """Compile this mapper and all other non-compiled mappers. - + This method checks the local compiled status as well as for - any new mappers that have been defined, and is safe to call + any new mappers that have been defined, and is safe to call repeatedly. """ - global _new_mappers - if self.__props_init and not _new_mappers: + if self.compiled and not _new_mappers: return self _COMPILE_MUTEX.acquire() @@ -341,12 +353,12 @@ class Mapper(object): try: # double-check inside mutex - if self.__props_init and not _new_mappers: + if self.compiled and not _new_mappers: return self # initialize properties on all mappers for mapper in list(_mapper_registry): - if not mapper.__props_init: + if not mapper.compiled: mapper.__initialize_properties() _new_mappers = False @@ -358,7 +370,7 @@ class Mapper(object): def __initialize_properties(self): """Call the ``init()`` method on all ``MapperProperties`` attached to this mapper. - + This is a deferred configuration step which is intended to execute once all mappers have been constructed. 
""" @@ -370,8 +382,7 @@ class Mapper(object): if getattr(prop, 'key', None) is None: prop.init(key, self) self.__log("__initialize_properties() complete") - self.__props_init = True - + self.compiled = True def __compile_extensions(self): """Go through the global_extensions list as well as the list @@ -391,14 +402,12 @@ class Mapper(object): for ext in self.inherits.extension: if ext not in extlist: extlist.add(ext) - ext.instrument_class(self, self.class_) else: for ext in global_extensions: if isinstance(ext, type): ext = ext() if ext not in extlist: extlist.add(ext) - ext.instrument_class(self, self.class_) self.extension = ExtensionCarrier() for ext in extlist: @@ -410,13 +419,11 @@ class Mapper(object): if self.inherits: if isinstance(self.inherits, type): self.inherits = class_mapper(self.inherits, compile=False) - else: - self.inherits = self.inherits if not issubclass(self.class_, self.inherits.class_): - raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__)) + raise sa_exc.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__)) if self.non_primary != self.inherits.non_primary: np = not self.non_primary and "primary" or "non-primary" - raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np)) + raise sa_exc.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np)) # inherit_condition is optional. 
if self.local_table is None: self.local_table = self.inherits.local_table @@ -428,29 +435,17 @@ class Mapper(object): if mapper.polymorphic_on: mapper._requires_row_aliasing = True else: - if self.inherit_condition is None: + if not self.inherit_condition: # figure out inherit condition from our table to the immediate table # of the inherited mapper, not its full table which could pull in other # stuff we dont want (allows test/inheritance.InheritTest4 to pass) self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table) self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition) - + fks = util.to_set(self.inherit_foreign_keys) self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks) else: self.mapped_table = self.local_table - if self.polymorphic_identity is not None: - self.inherits.polymorphic_map[self.polymorphic_identity] = self - if self.polymorphic_on is None: - for mapper in self.iterate_to_root(): - # try to set up polymorphic on using correesponding_column(); else leave - # as None - if mapper.polymorphic_on: - self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on) - break - else: - # TODO: this exception not covered - raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) if self.polymorphic_identity and not self.concrete: self._identity_class = self.inherits._identity_class @@ -470,25 +465,38 @@ class Mapper(object): self.inherits._inheriting_mappers.add(self) self.base_mapper = self.inherits.base_mapper self._all_tables = self.inherits._all_tables + + if self.polymorphic_identity is not None: + self.polymorphic_map[self.polymorphic_identity] = self + if not self.polymorphic_on: + for mapper in self.iterate_to_root(): + # try to set up 
polymorphic on using corresponding_column(); else leave + # as None + if mapper.polymorphic_on: + self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on) + break + else: + # TODO: this exception not covered + raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) else: self._all_tables = util.Set() self.base_mapper = self self.mapped_table = self.local_table if self.polymorphic_identity: if self.polymorphic_on is None: - raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) + raise sa_exc.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity)) self.polymorphic_map[self.polymorphic_identity] = self self._identity_class = self.class_ - + if self.mapped_table is None: - raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self)) + raise sa_exc.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? 
It no longer has a return value.)" % str(self)) def __compile_pks(self): self.tables = sqlutil.find_tables(self.mapped_table) if not self.tables: - raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) + raise sa_exc.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table)) self._pks_by_table = {} self._cols_by_table = {} @@ -512,7 +520,7 @@ class Mapper(object): self._pks_by_table[k.table].add(k) if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0: - raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) + raise sa_exc.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) if self.inherits and not self.concrete and not self.primary_key_argument: # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit) @@ -525,7 +533,7 @@ class Mapper(object): primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table]) if len(primary_key) == 0: - raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) + raise sa_exc.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description)) self.primary_key = primary_key self.__log("Identified primary key columns: " + str(primary_key)) @@ -534,25 +542,18 @@ class Mapper(object): """create a "get clause" based on the primary key. this is used by query.get() and many-to-one lazyloads to load this item by primary key. 
- + """ params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key] return sql.and_(*[k==v for (k, v) in params]), dict(params) _get_clause = property(util.cache_decorator(_get_clause)) - + def _equivalent_columns(self): """Create a map of all *equivalent* columns, based on the determination of column pairs that are equated to one another either by an established foreign key relationship or by a joined-table inheritance join. - This is used to determine the minimal set of primary key - columns for the mapper, as well as when relating - columns to those of a polymorphic selectable (i.e. a UNION of - several mapped tables), as that selectable usually only contains - one column in its columns clause out of a group of several which - are equated to each other. - The resulting structure is a dictionary of columns mapped to lists of equivalent columns, i.e. @@ -578,7 +579,7 @@ class Mapper(object): result[binary.right] = util.Set([binary.left]) for mapper in self.base_mapper.polymorphic_iterator(): if mapper.inherit_condition: - visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary) + visitors.traverse(mapper.inherit_condition, {}, {'binary':visit_binary}) # TODO: matching of cols to foreign keys might better be generalized # into general column translation (i.e. 
corresponding_column) @@ -619,7 +620,7 @@ class Mapper(object): cls = object.__getattribute__(self, 'class_') clskey = object.__getattribute__(self, 'key') - if key.startswith('__'): + if key.startswith('__') and key != '__clause_element__': return object.__getattribute__(self, key) class_mapper(cls) @@ -676,13 +677,13 @@ class Mapper(object): column_key = (self.column_prefix or '') + column.key self._compile_property(column_key, column, init=False, setparent=True) - + # do a special check for the "discriminiator" column, as it may only be present # in the 'with_polymorphic' selectable but we need it for the base mapper - if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty: - col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on - self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True) - + if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty: + col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on + self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True) + def _adapt_inherited_property(self, key, prop): if not self.concrete: self._compile_property(key, prop, init=False, setparent=False) @@ -696,7 +697,7 @@ class Mapper(object): columns = util.to_list(prop) column = columns[0] if not expression.is_column(column): - raise exceptions.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop)) + raise sa_exc.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop)) prop = self.__props.get(key, None) @@ -715,12 +716,12 @@ class Mapper(object): for c in columns: mc = self.mapped_table.corresponding_column(c) if not mc: - raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." 
% str(c)) + raise sa_exc.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c)) mapped_column.append(mc) prop = ColumnProperty(*mapped_column) else: if not self.allow_column_override: - raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop))) + raise sa_exc.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop))) else: return @@ -731,7 +732,7 @@ class Mapper(object): if col is None: col = prop.columns[0] else: - # if column is coming in after _cols_by_table was initialized, ensure the col is in the + # if column is coming in after _cols_by_table was initialized, ensure the col is in the # right set if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]: self._cols_by_table[col.table].add(col) @@ -740,35 +741,28 @@ class Mapper(object): for col in prop.columns: for col in col.proxy_set: self._columntoproperty[col] = prop - - - elif isinstance(prop, SynonymProperty) and setparent: + + elif isinstance(prop, (ComparableProperty, SynonymProperty)) and setparent: if prop.descriptor is None: prop.descriptor = getattr(self.class_, key, None) if isinstance(prop.descriptor, Mapper._CompileOnAttr): prop.descriptor = object.__getattribute__(prop.descriptor, 'existing_prop') - if prop.map_column: - if not key in self.mapped_table.c: - raise exceptions.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'" % (prop.name, self.mapped_table.description, key)) + if getattr(prop, 'map_column', False): + if key not in self.mapped_table.c: + raise sa_exc.ArgumentError("Can't compile synonym '%s': no column on table '%s' named 
'%s'" % (prop.name, self.mapped_table.description, key)) self._compile_property(prop.name, ColumnProperty(self.mapped_table.c[key]), init=init, setparent=setparent) - elif isinstance(prop, ComparableProperty) and setparent: - # refactor me - if prop.descriptor is None: - prop.descriptor = getattr(self.class_, key, None) - if isinstance(prop.descriptor, Mapper._CompileOnAttr): - prop.descriptor = object.__getattribute__(prop.descriptor, - 'existing_prop') + self.__props[key] = prop if setparent: prop.set_parent(self) if not self.non_primary: - setattr(self.class_, key, Mapper._CompileOnAttr(self.class_, key)) - + self.class_manager.install_descriptor( + key, Mapper._CompileOnAttr(self.class_, key)) if init: prop.init(key, self) - + for mapper in self._inheriting_mappers: mapper._adapt_inherited_property(key, prop) @@ -783,49 +777,78 @@ class Mapper(object): auto-session attachment logic. """ + manager = attributes.manager_of_class(self.class_) + if self.non_primary: - if not hasattr(self.class_, '_class_state'): - raise exceptions.InvalidRequestError("Class %s has no primary mapper configured. Configure a primary mapper first before setting up a non primary Mapper.") - self._class_state = self.class_._class_state + if not manager or None not in manager.mappers: + raise sa_exc.InvalidRequestError( + "Class %s has no primary mapper configured. Configure " + "a primary mapper first before setting up a non primary " + "Mapper.") + self.class_manager = manager _mapper_registry[self] = True return - if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers): - raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper. clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name)) + if manager is not None: + if manager.class_ is not self.class_: + # An inherited manager. 
Install one for this subclass. + manager = None + elif self.entity_name in manager.mappers: + raise sa_exc.ArgumentError( + "Class '%s' already has a primary mapper defined " + "with entity name '%s'. Use non_primary=True to " + "create a non primary Mapper. clear_mappers() will " + "remove *all* current mappers from all classes." % + (self.class_, self.entity_name)) - def extra_init(class_, oldinit, instance, args, kwargs): - self.compile() - if 'init_instance' in self.extension.methods: - self.extension.init_instance(self, class_, oldinit, instance, args, kwargs) + _mapper_registry[self] = True - def on_exception(class_, oldinit, instance, args, kwargs): - util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs) + if manager is None: + manager = attributes.create_manager_for_cls(self.class_) - attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes) + self.class_manager = manager - self._class_state = self.class_._class_state - _mapper_registry[self] = True + has_been_initialized = bool(manager.info.get(_INSTRUMENTOR, False)) + manager.mappers[self.entity_name] = self - self.class_._class_state.mappers[self.entity_name] = self + # The remaining members can be added by any mapper, e_name None or not. 
+ if has_been_initialized: + return - for ext in util.to_list(self.extension, []): - ext.instrument_class(self, self.class_) + self.extension.instrument_class(self, self.class_) - if self.entity_name is None: - self.class_.c = self.c + manager.instantiable = True + manager.instance_state_factory = IdentityManagedState + manager.deferred_scalar_loader = _load_scalar_attributes + + event_registry = manager.events + event_registry.add_listener('on_init', _event_on_init) + event_registry.add_listener('on_init_failure', _event_on_init_failure) + + for key, impl in _legacy_descriptors.items(): + manager.install_member(key, impl) + + manager.info[_INSTRUMENTOR] = self def common_parent(self, other): """Return true if the given mapper shares a common inherited parent as this mapper.""" return self.base_mapper is other.base_mapper + def _canload(self, state): + s = self.primary_mapper() + if s.polymorphic_on: + return _state_mapper(state).isa(s) + else: + return _state_mapper(state) is s + def isa(self, other): - """Return True if the given mapper inherits from this mapper.""" + """Return True if the this mapper inherits from the given mapper.""" - m = other - while m is not self and m.inherits: + m = self + while m and m is not other: m = m.inherits - return m is self + return bool(m) def iterate_to_root(self): m = self @@ -867,42 +890,20 @@ class Mapper(object): """ self._init_properties[key] = prop - self._compile_property(key, prop, init=self.__props_init) + self._compile_property(key, prop, init=self.compiled) + + def __repr__(self): + return '<Mapper at 0x%x; %s>' % ( + id(self), self.class_.__name__) def __str__(self): return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "") def primary_mapper(self): """Return the primary mapper corresponding to this mapper's class key (class + 
entity_name).""" - return self._class_state.mappers[self.entity_name] - - def get_session(self): - """Return the contextual session provided by the mapper - extension chain, if any. - - Raise ``InvalidRequestError`` if a session cannot be retrieved - from the extension chain. - """ - - if 'get_session' in self.extension.methods: - s = self.extension.get_session() - if s is not EXT_CONTINUE: - return s - - raise exceptions.InvalidRequestError("No contextual Session is established.") - - def instances(self, cursor, session, *mappers, **kwargs): - """Return a list of mapped instances corresponding to the rows - in a given ResultProxy. - - DEPRECATED. - """ + return self.class_manager.mappers[self.entity_name] - import sqlalchemy.orm.query - return sqlalchemy.orm.Query(self, session).instances(cursor, *mappers, **kwargs) - instances = util.deprecated(None, False)(instances) - - def identity_key_from_row(self, row): + def identity_key_from_row(self, row, adapter=None): """Return an identity-map key for use in storing/retrieving an item from the identity map. @@ -911,7 +912,12 @@ class Mapper(object): dictionary corresponding result-set ``ColumnElement`` instances to their values within a row. """ - return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name) + + pk_cols = self.primary_key + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + + return (self._identity_class, tuple([row[column] for column in pk_cols]), self.entity_name) def identity_key_from_primary_key(self, primary_key): """Return an identity-map key for use in storing/retrieving an @@ -926,8 +932,9 @@ class Mapper(object): """Return the identity key for the given instance, based on its primary key attributes. - This value is typically also found on the instance itself - under the attribute name `_instance_key`. + This value is typically also found on the instance state under the + attribute name `key`. 
+ """ return self.identity_key_from_primary_key(self.primary_key_from_instance(instance)) @@ -938,17 +945,12 @@ class Mapper(object): """Return the list of primary key values for the given instance. """ - - return [self._get_state_attr_by_column(instance._state, column) for column in self.primary_key] + state = attributes.instance_state(instance) + return self._primary_key_from_state(state) def _primary_key_from_state(self, state): return [self._get_state_attr_by_column(state, column) for column in self.primary_key] - def _canload(self, state): - if self.polymorphic_on: - return issubclass(state.class_, self.class_) - else: - return state.class_ is self.class_ def _get_col_to_prop(self, column): try: @@ -956,24 +958,23 @@ class Mapper(object): except KeyError: prop = self.__props.get(column.key, None) if prop: - raise exceptions.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) + raise exc.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) else: - raise exceptions.UnmappedColumnError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self))) + raise exc.UnmappedColumnError("No column %s.%s is configured on mapper %s..." 
% (column.table.name, column.name, str(self))) + # TODO: improve names def _get_state_attr_by_column(self, state, column): return self._get_col_to_prop(column).getattr(state, column) def _set_state_attr_by_column(self, state, column, value): return self._get_col_to_prop(column).setattr(state, value, column) - def _get_attr_by_column(self, obj, column): - return self._get_col_to_prop(column).getattr(obj._state, column) - def _get_committed_attr_by_column(self, obj, column): - return self._get_col_to_prop(column).getcommitted(obj._state, column) + state = attributes.instance_state(obj) + return self._get_committed_state_attr_by_column(state, column) - def _set_attr_by_column(self, obj, column, value): - self._get_col_to_prop(column).setattr(obj._state, column, value) + def _get_committed_state_attr_by_column(self, state, column): + return self._get_col_to_prop(column).getcommitted(state, column) def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. 
@@ -1002,15 +1003,14 @@ class Mapper(object): # organize individual states with the connection to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, connection_callable(self, state.obj()), _state_has_identity(state)) for state in states] + tups = [(state, _state_mapper(state), connection_callable(self, state.obj()), _state_has_identity(state)) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, connection, _state_has_identity(state)) for state in states] + tups = [(state, _state_mapper(state), connection, _state_has_identity(state)) for state in states] if not postupdate: # call before_XXX extensions - for state, connection, has_identity in tups: - mapper = _state_mapper(state) + for state, mapper, connection, has_identity in tups: if not has_identity: if 'before_insert' in mapper.extension.methods: mapper.extension.before_insert(mapper, connection, state.obj()) @@ -1018,16 +1018,16 @@ class Mapper(object): if 'before_update' in mapper.extension.methods: mapper.extension.before_update(mapper, connection, state.obj()) - for state, connection, has_identity in tups: + for state, mapper, connection, has_identity in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. 
- mapper = _state_mapper(state) instance_key = mapper._identity_key_from_state(state) - if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map: - existing = uowtransaction.uow.identity_map[instance_key]._state + if not postupdate and not has_identity and instance_key in uowtransaction.session.identity_map: + instance = uowtransaction.session.identity_map[instance_key] + existing = attributes.instance_state(instance) if not uowtransaction.is_deleted(existing): - raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) + raise exc.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) if self.__should_log_debug: self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) uowtransaction.set_row_switch(existing) @@ -1044,8 +1044,7 @@ class Mapper(object): insert = [] update = [] - for state, connection, has_identity in tups: - mapper = _state_mapper(state) + for state, mapper, connection, has_identity in tups: if table not in mapper._pks_by_table: continue pks = mapper._pks_by_table[table] @@ -1054,7 +1053,7 @@ class Mapper(object): if self.__should_log_debug: self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) - isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity + isinsert = not instance_key in uowtransaction.session.identity_map and not postupdate and not has_identity params = {} value_params = {} hasdata = False @@ -1131,7 +1130,7 @@ class Mapper(object): pks = mapper._pks_by_table[table] def comparator(a, b): for col in pks: - x = cmp(a[1][col._label],b[1][col._label]) + x = cmp(a[1][col._label], b[1][col._label]) if 
x != 0: return x return 0 @@ -1148,7 +1147,7 @@ class Mapper(object): rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): - raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) + raise exc.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) if insert: statement = table.insert() @@ -1179,8 +1178,7 @@ class Mapper(object): if not postupdate: # call after_XXX extensions - for state, connection, has_identity in tups: - mapper = _state_mapper(state) + for state, mapper, connection, has_identity in tups: if not has_identity: if 'after_insert' in mapper.extension.methods: mapper.extension.after_insert(mapper, connection, state.obj()) @@ -1216,9 +1214,10 @@ class Mapper(object): if deferred_props: if self.eager_defaults: - _instance_key = self._identity_key_from_state(state) - state.dict['_instance_key'] = _instance_key - uowtransaction.session.query(self)._get(_instance_key, refresh_instance=state, only_load_props=deferred_props) + state.key = self._identity_key_from_state(state) + uowtransaction.session.query(self)._get( + state.key, refresh_instance=state, + only_load_props=deferred_props) else: _expire_state(state, deferred_props) @@ -1234,17 +1233,15 @@ class Mapper(object): if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] - tups = [(state, connection_callable(self, state.obj())) for state in states] + tups = [(state, _state_mapper(state), connection_callable(self, state.obj())) for state in states] else: connection = uowtransaction.transaction.connection(self) - tups = [(state, connection) for state in states] + tups = [(state, _state_mapper(state), connection) for state in states] - for (state, connection) in tups: - mapper = _state_mapper(state) + for state, mapper, connection in tups: if 'before_delete' in 
mapper.extension.methods: mapper.extension.before_delete(mapper, connection, state.obj()) - deleted_objects = util.Set() table_to_mapper = {} for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: @@ -1252,8 +1249,7 @@ class Mapper(object): for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True): delete = {} - for (state, connection) in tups: - mapper = _state_mapper(state) + for state, mapper, connection in tups: if table not in mapper._pks_by_table: continue @@ -1266,13 +1262,12 @@ class Mapper(object): params[col.key] = mapper._get_state_attr_by_column(state, col) if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) - # testlib.pragma exempt:__hash__ - deleted_objects.add((state, connection)) + for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): for col in mapper._pks_by_table[table]: - x = cmp(a[col.key],b[col.key]) + x = cmp(a[col.key], b[col.key]) if x != 0: return x return 0 @@ -1285,10 +1280,9 @@ class Mapper(object): statement = table.delete(clause) c = connection.execute(statement, del_objects) if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): - raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) + raise exc.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) - for state, connection in deleted_objects: - mapper = _state_mapper(state) + for state, mapper, connection in tups: if 'after_delete' in mapper.extension.methods: mapper.extension.after_delete(mapper, connection, state.obj()) @@ -1325,7 +1319,7 @@ class Mapper(object): visitables = [(self.__props.itervalues(), 'property', state)] while visitables: - iterator,item_type,parent_state = visitables[-1] + 
iterator, item_type, parent_state = visitables[-1] try: if item_type == 'property': prop = iterator.next() @@ -1337,291 +1331,315 @@ class Mapper(object): except StopIteration: visitables.pop() - def _instance(self, context, row, result=None, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None): - if not extension: - extension = self.extension + def _instance_processor(self, context, path, adapter, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None): + pk_cols = self.primary_key - if 'translate_row' in extension.methods: - ret = extension.translate_row(self, context, row) - if ret is not EXT_CONTINUE: - row = ret - - if polymorphic_from: - # if we are called from a base mapper doing a polymorphic load, figure out what tables, - # if any, will need to be "post-fetched" based on the tables present in the row, - # or from the options set up on the query - if ('polymorphic_fetch', self) not in context.attributes: - if self in context.query._with_polymorphic: - context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, []) - else: - context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [t for t in self.tables if t not in polymorphic_from.tables]) - - elif not refresh_instance and self.polymorphic_on: - discriminator = row[self.polymorphic_on] - if discriminator is not None: - try: - mapper = self.polymorphic_map[discriminator] - except KeyError: - raise exceptions.AssertionError("No such polymorphic_identity %r is defined" % discriminator) - if mapper is not self: - return mapper._instance(context, row, result=result, polymorphic_from=self) - - # determine identity key - if refresh_instance: - try: - identitykey = refresh_instance.dict['_instance_key'] - except KeyError: - # super-rare condition; a refresh is being called - # on a non-instance-key instance; this is meant to only - # occur wihtin a flush() - identitykey = self._identity_key_from_state(refresh_instance) + if 
polymorphic_from or refresh_instance: + polymorphic_on = None else: - identitykey = self.identity_key_from_row(row) - - session_identity_map = context.session.identity_map + polymorphic_on = self.polymorphic_on + polymorphic_instances = util.PopulateDict(self.__configure_subclass_mapper(context, path, adapter)) - if identitykey in session_identity_map: - instance = session_identity_map[identitykey] - state = instance._state - - if self.__should_log_debug: - self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey))) - - isnew = state.runid != context.runid - currentload = not isnew - - if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: - raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) - elif refresh_instance: - # out of band refresh_instance detected (i.e. its not in the session.identity_map) - # honor it anyway. this can happen if a _get() occurs within save_obj(), such as - # when eager_defaults is True. 
- state = refresh_instance - instance = state.obj() - isnew = state.runid != context.runid - currentload = True - else: - if self.__should_log_debug: - self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) + version_id_col = self.version_id_col - if self.allow_null_pks: - for x in identitykey[1]: - if x is not None: - break - else: - return None - else: - if None in identitykey[1]: - return None - isnew = True - currentload = True - - if 'create_instance' in extension.methods: - instance = extension.create_instance(self, context, row, self.class_) - if instance is EXT_CONTINUE: - instance = attributes.new_instance(self.class_) - else: - attributes.manage(instance) - else: - instance = attributes.new_instance(self.class_) + if adapter: + pk_cols = [adapter.columns[c] for c in pk_cols] + if polymorphic_on: + polymorphic_on = adapter.columns[polymorphic_on] + if version_id_col: + version_id_col = adapter.columns[version_id_col] + + identity_class, entity_name = self._identity_class, self.entity_name + def identity_key(row): + return (identity_class, tuple([row[column] for column in pk_cols]), entity_name) - if self.__should_log_debug: - self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) + new_populators = [] + existing_populators = [] - state = instance._state - instance._entity_name = self.entity_name - instance._instance_key = identitykey - instance._sa_session_id = context.session.hash_key - session_identity_map[identitykey] = instance + def populate_state(state, row, isnew, only_load_props, **flags): + if not new_populators: + new_populators[:], existing_populators[:] = self.__populators(context, path, row, adapter) - if currentload or context.populate_existing or self.always_refresh: if isnew: - state.runid = context.runid - context.progress.add(state) + populators = new_populators + else: + populators = existing_populators - if 'populate_instance' not in extension.methods 
or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) - - else: - # populate attributes on non-loading instances which have been expired - # TODO: also support deferred attributes here [ticket:870] - if state.expired_attributes: - if state in context.partials: - isnew = False - attrs = context.partials[state] - else: - isnew = True - attrs = state.expired_attributes.intersection(state.unmodified) - context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs + if only_load_props: + populators = [p for p in populators if p[0] in only_load_props] - if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: - self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew) - - if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): - result.append(instance) - - return instance - - def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags): - """populate an instance from a result row.""" - - snapshot = selectcontext.path + (self,) - # retrieve a set of "row population" functions derived from the MapperProperties attached - # to this Mapper. These are keyed in the select context based primarily off the - # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one, - # including relation() names. 
the key also includes "self", and allows us to distinguish between - # other mappers within our inheritance hierarchy - (new_populators, existing_populators) = selectcontext.attributes.get(('populators', self, snapshot, ispostselect), (None, None)) - if new_populators is None: - # no populators; therefore this is the first time we are receiving a row for - # this result set. issue create_row_processor() on all MapperProperty objects - # and cache in the select context. - new_populators = [] - existing_populators = [] - post_processors = [] - for prop in self.__props.values(): - (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row) - if newpop: - new_populators.append((prop.key, newpop)) - if existingpop: - existing_populators.append((prop.key, existingpop)) - if post_proc: - post_processors.append(post_proc) - - # install a post processor for immediate post-load of joined-table inheriting mappers - poly_select_loader = self._get_poly_select_loader(selectcontext, row) - if poly_select_loader: - post_processors.append(poly_select_loader) - - selectcontext.attributes[('populators', self, snapshot, ispostselect)] = (new_populators, existing_populators) - selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors - - if isnew or ispostselect: - populators = new_populators - else: - populators = existing_populators + for key, populator in populators: + populator(state, row, isnew=isnew, **flags) - if only_load_props: - populators = [p for p in populators if p[0] in only_load_props] + session_identity_map = context.session.identity_map - for (key, populator) in populators: - selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags) + if not extension: + extension = self.extension - if self.non_primary: - selectcontext.attributes[('populating_mapper', instance._state)] = self + translate_row = 'translate_row' in 
extension.methods + create_instance = 'create_instance' in extension.methods + populate_instance = 'populate_instance' in extension.methods + append_result = 'append_result' in extension.methods + populate_existing = context.populate_existing or self.always_refresh + + def _instance(row, result): + if translate_row: + ret = extension.translate_row(self, context, row) + if ret is not EXT_CONTINUE: + row = ret + + if polymorphic_on: + discriminator = row[polymorphic_on] + if discriminator is not None: + _instance = polymorphic_instances[discriminator] + if _instance: + return _instance(row, result) + + # determine identity key + if refresh_instance: + # TODO: refresh_instance seems to be named wrongly -- it is always an instance state. + refresh_state = refresh_instance + identitykey = refresh_state.key + if identitykey is None: + # super-rare condition; a refresh is being called + # on a non-instance-key instance; this is meant to only + # occur within a flush() + identitykey = self._identity_key_from_state(refresh_state) + else: + identitykey = identity_key(row) - def _post_instance(self, selectcontext, state, **kwargs): - post_processors = selectcontext.attributes[('post_processors', self, None)] - for p in post_processors: - p(state.obj(), **kwargs) + if identitykey in session_identity_map: + instance = session_identity_map[identitykey] + state = attributes.instance_state(instance) - def _get_poly_select_loader(self, selectcontext, row): - """set up attribute loaders for 'select' and 'deferred' polymorphic loading. 
+ if self.__should_log_debug: + self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey))) + + isnew = state.runid != context.runid + currentload = not isnew + loaded_instance = False + + if not currentload and version_id_col and context.version_check and self._get_state_attr_by_column(state, self.version_id_col) != row[version_id_col]: + raise exc.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (state_str(state), self._get_state_attr_by_column(state, self.version_id_col), row[version_id_col])) + elif refresh_instance: + # out of band refresh_instance detected (i.e. its not in the session.identity_map) + # honor it anyway. this can happen if a _get() occurs within save_obj(), such as + # when eager_defaults is True. + state = refresh_instance + instance = state.obj() + isnew = state.runid != context.runid + currentload = True + loaded_instance = False + else: + if self.__should_log_debug: + self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) - this loading uses a second SELECT statement to load additional tables, - either immediately after loading the main table or via a deferred attribute trigger. 
- """ + if self.allow_null_pks: + for x in identitykey[1]: + if x is not None: + break + else: + return None + else: + if None in identitykey[1]: + return None + isnew = True + currentload = True + loaded_instance = True + + if create_instance: + instance = extension.create_instance(self, context, row, self.class_) + if instance is EXT_CONTINUE: + instance = self.class_manager.new_instance() + else: + manager = attributes.manager_for_cls(instance.__class__) + # TODO: if manager is None, raise a friendly error about + # returning instances of unmapped types + manager.setup_instance(instance) + else: + instance = self.class_manager.new_instance() - (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None)) + if self.__should_log_debug: + self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) - if hosted_mapper is None or not needs_tables: - return + state = attributes.instance_state(instance) + state.entity_name = self.entity_name + state.key = identitykey + # manually adding instance to session. for a complete add, + # session._finalize_loaded() must be called. 
+ state.session_id = context.session.hash_key + session_identity_map.add(state) - cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables) - statement = sql.select(needs_tables, cond, use_labels=True) + if currentload or populate_existing: + if isnew: + state.runid = context.runid + context.progress.add(state) - if hosted_mapper.polymorphic_fetch == 'select': - def post_execute(instance, **flags): - if self.__should_log_debug: - self.__log_debug("Post query loading instance " + instance_str(instance)) + if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + populate_state(state, row, isnew, only_load_props) - identitykey = self.identity_key_from_instance(instance) - - only_load_props = flags.get('only_load_props', None) + else: + # populate attributes on non-loading instances which have been expired + # TODO: also support deferred attributes here [ticket:870] + # TODO: apply eager loads to un-lazy loaded collections ? 
+ # we might want to create an expanded form of 'state.expired_attributes' which includes deferred/un-lazy loaded + if state.expired_attributes: + if state in context.partials: + isnew = False + attrs = context.partials[state] + else: + isnew = True + attrs = state.expired_attributes.intersection(state.unmodified) + context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs - params = {} - for c, bind in param_names: - params[bind] = self._get_attr_by_column(instance, c) - row = selectcontext.session.connection(self).execute(statement, params).fetchone() - self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props) - return post_execute - elif hosted_mapper.polymorphic_fetch == 'deferred': - from sqlalchemy.orm.strategies import DeferredColumnLoader - - def post_execute(instance, **flags): - def create_statement(instance): - params = {} - for (c, bind) in param_names: - # use the "committed" (database) version to get query column values - params[bind] = self._get_committed_attr_by_column(instance, c) - return (statement, params) - - props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__] - keys = [p.key for p in props] - - only_load_props = flags.get('only_load_props', None) - if only_load_props: - keys = util.Set(keys).difference(only_load_props) - props = [p for p in props if p.key in only_load_props] - - for prop in props: - strategy = prop._get_strategy(DeferredColumnLoader) - instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement)) - return post_execute - else: - return None + if not populate_instance or extension.populate_instance(self, context, row, instance, only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: + populate_state(state, row, isnew, attrs, instancekey=identitykey) + + if result 
is not None and (not append_result or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): + result.append(instance) - def _deferred_inheritance_condition(self, base_mapper, needs_tables): - base_mapper = base_mapper.primary_mapper() + if loaded_instance: + state._run_on_load(instance) + + return instance + return _instance + + def __populators(self, context, path, row, adapter): + new_populators, existing_populators = [], [] + for prop in self.__props.values(): + newpop, existingpop = prop.create_row_processor(context, path, self, row, adapter) + if newpop: + new_populators.append((prop.key, newpop)) + if existingpop: + existing_populators.append((prop.key, existingpop)) + return new_populators, existing_populators + + def __configure_subclass_mapper(self, context, path, adapter): + def configure_subclass_mapper(discriminator): + try: + mapper = self.polymorphic_map[discriminator] + except KeyError: + raise AssertionError("No such polymorphic_identity %r is defined" % discriminator) + if mapper is self: + return None + return mapper._instance_processor(context, path, adapter, polymorphic_from=self) + return configure_subclass_mapper + + def _optimized_get_statement(self, state, attribute_names): + props = self.__props + tables = util.Set([props[key].parent.local_table for key in attribute_names]) + if self.base_mapper.local_table in tables: + return None def visit_binary(binary): leftcol = binary.left rightcol = binary.right if leftcol is None or rightcol is None: return - if leftcol.table not in needs_tables: - binary.left = sql.bindparam(None, None, type_=binary.right.type) - param_names.append((leftcol, binary.left)) - elif rightcol not in needs_tables: - binary.right = sql.bindparam(None, None, type_=binary.right.type) - param_names.append((rightcol, binary.right)) + + if leftcol.table not in tables: + binary.left = sql.bindparam(None, self._get_committed_state_attr_by_column(state, leftcol), 
type_=binary.right.type) + elif rightcol.table not in tables: + binary.right = sql.bindparam(None, self._get_committed_state_attr_by_column(state, rightcol), type_=binary.right.type) allconds = [] - param_names = [] - for mapper in self.iterate_to_root(): - if mapper is base_mapper: - break - allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary)) + start = False + for mapper in util.reversed(list(self.iterate_to_root())): + if mapper.local_table in tables: + start = True + if start: + allconds.append(visitors.cloned_traverse(mapper.inherit_condition, {}, {'binary':visit_binary})) + + cond = sql.and_(*allconds) + return sql.select(tables, cond, use_labels=True) + +Mapper.logger = log.class_logger(Mapper) + - return sql.and_(*allconds), param_names +def _event_on_init(state, instance, args, kwargs): + """Trigger mapper compilation and run init_instance hooks.""" -Mapper.logger = logging.class_logger(Mapper) + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + # compile() always compiles all mappers + instrumenting_mapper.compile() + if 'init_instance' in instrumenting_mapper.extension.methods: + instrumenting_mapper.extension.init_instance( + instrumenting_mapper, instrumenting_mapper.class_, + state.manager.events.original_init, + instance, args, kwargs) +def _event_on_init_failure(state, instance, args, kwargs): + """Run init_failed hooks.""" -object_session = None + instrumenting_mapper = state.manager.info[_INSTRUMENTOR] + if 'init_failed' in instrumenting_mapper.extension.methods: + util.warn_exception( + instrumenting_mapper.extension.init_failed, + instrumenting_mapper, instrumenting_mapper.class_, + state.manager.events.original_init, instance, args, kwargs) -def _load_scalar_attributes(instance, attribute_names): - mapper = object_mapper(instance) - global object_session - if not object_session: - from sqlalchemy.orm.session import object_session - session = object_session(instance) +def 
_legacy_descriptors(): + """Build compatibility descriptors mapping legacy to InstanceState. + + These are slated for removal in 0.5. They were never part of the + official public API but were suggested as temporary workarounds in a + number of mailing list posts. Permanent and public solutions for those + needs should be available now. Consult the applicable mailing list + threads for details. + + """ + def _instance_key(self): + state = attributes.instance_state(self) + if state.key is not None: + return state.key + else: + raise AttributeError("_instance_key") + _instance_key = util.deprecated(None, False)(_instance_key) + _instance_key = property(_instance_key) + + def _sa_session_id(self): + state = attributes.instance_state(self) + if state.session_id is not None: + return state.session_id + else: + raise AttributeError("_sa_session_id") + _sa_session_id = util.deprecated(None, False)(_sa_session_id) + _sa_session_id = property(_sa_session_id) + + def _entity_name(self): + state = attributes.instance_state(self) + if state.entity_name is attributes.NO_ENTITY_NAME: + return None + else: + return state.entity_name + _entity_name = util.deprecated(None, False)(_entity_name) + _entity_name = property(_entity_name) + + return dict(locals()) +_legacy_descriptors = _legacy_descriptors() + +def _load_scalar_attributes(state, attribute_names): + mapper = _state_mapper(state) + session = _state_session(state) if not session: - try: - session = mapper.get_session() - except exceptions.InvalidRequestError: - raise exceptions.UnboundExecutionError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (instance.__class__)) - - state = instance._state - if '_instance_key' in state.dict: - identity_key = state.dict['_instance_key'] - shouldraise = True - else: - # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned) - shouldraise = False - identity_key = 
mapper._identity_key_from_state(state) - - if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None and shouldraise: - raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance)) + raise sa_exc.UnboundExecutionError("Instance %s is not bound to a Session; attribute refresh operation cannot proceed" % (state_str(state))) + + has_key = _state_has_identity(state) + + result = False + if mapper.inherits and not mapper.concrete: + statement = mapper._optimized_get_statement(state, attribute_names) + if statement: + result = session.query(mapper).from_statement(statement)._get(None, only_load_props=attribute_names, refresh_instance=state) + + if result is False: + if has_key: + identity_key = state.key + else: + identity_key = mapper._identity_key_from_state(state) + result = session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) + # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned) + if has_key and result is None: + raise exc.ObjectDeletedError("Instance '%s' has been deleted." % state_str(state)) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 33a0ff432..fc2e90189 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -6,19 +6,20 @@ """MapperProperty implementations. -This is a private module which defines the behavior of -invidual ORM-mapped attributes. +This is a private module which defines the behavior of invidual ORM-mapped +attributes. 
+ """ -from sqlalchemy import sql, schema, util, exceptions, logging -from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs, find_columns -from sqlalchemy.sql import visitors, operators, ColumnElement -from sqlalchemy.orm import mapper, sync, strategies, attributes, dependency, object_mapper -from sqlalchemy.orm import session as sessionlib -from sqlalchemy.orm.mapper import _class_to_mapper -from sqlalchemy.orm.util import CascadeOptions, PropertyAliasedClauses -from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY -from sqlalchemy.exceptions import ArgumentError +from sqlalchemy import sql, util, log +import sqlalchemy.exceptions as sa_exc +from sqlalchemy.sql.util import ClauseAdapter, criterion_as_pairs +from sqlalchemy.sql import operators, ColumnElement, expression +from sqlalchemy.orm import mapper, strategies, attributes, dependency, \ + object_mapper, session as sessionlib +from sqlalchemy.orm.util import CascadeOptions, _class_to_mapper, _orm_annotate +from sqlalchemy.orm.interfaces import StrategizedProperty, PropComparator, \ + MapperProperty, ONETOMANY, MANYTOONE, MANYTOMANY __all__ = ('ColumnProperty', 'CompositeProperty', 'SynonymProperty', 'ComparableProperty', 'PropertyLoader', 'BackRef') @@ -34,18 +35,15 @@ class ColumnProperty(StrategizedProperty): appears across each table. """ - self.columns = list(columns) + self.columns = [expression._labeled(c) for c in columns] self.group = kwargs.pop('group', None) self.deferred = kwargs.pop('deferred', False) self.comparator = ColumnProperty.ColumnComparator(self) + util.set_creation_order(self) if self.deferred: self.strategy_class = strategies.DeferredColumnLoader else: self.strategy_class = strategies.ColumnLoader - # sanity check - for col in columns: - if not isinstance(col, ColumnElement): - raise ArgumentError('column_property() must be given a ColumnElement as its argument. 
Try .label() or .as_scalar() for Selectables to fix this.') def do_init(self): super(ColumnProperty, self).do_init() @@ -61,37 +59,41 @@ class ColumnProperty(StrategizedProperty): return ColumnProperty(deferred=self.deferred, group=self.group, *self.columns) def getattr(self, state, column): - return getattr(state.class_, self.key).impl.get(state) + return state.get_impl(self.key).get(state) def getcommitted(self, state, column): - return getattr(state.class_, self.key).impl.get_committed_value(state) + return state.get_impl(self.key).get_committed_value(state) def setattr(self, state, value, column): - getattr(state.class_, self.key).impl.set(state, value, None) + state.get_impl(self.key).set(state, value, None) def merge(self, session, source, dest, dont_load, _recursive): - value = attributes.get_as_list(source._state, self.key, passive=True) + value = attributes.instance_state(source).value_as_iterable( + self.key, passive=True) if value: setattr(dest, self.key, value[0]) else: - # TODO: lazy callable should merge to the new instance - dest._state.expire_attributes([self.key]) + attributes.instance_state(dest).expire_attributes([self.key]) def get_col_value(self, column, value): return value class ColumnComparator(PropComparator): - def clause_element(self): - return self.prop.columns[0] - + def __clause_element__(self): + return self.prop.columns[0]._annotate({"parententity": self.prop.parent}) + __clause_element__ = util.cache_decorator(__clause_element__) + def operate(self, op, *other, **kwargs): - return op(self.prop.columns[0], *other, **kwargs) + return op(self.__clause_element__(), *other, **kwargs) def reverse_operate(self, op, other, **kwargs): - col = self.prop.columns[0] + col = self.__clause_element__() return op(col._bind_param(other), col, **kwargs) -ColumnProperty.logger = logging.class_logger(ColumnProperty) + def __str__(self): + return str(self.parent.class_.__name__) + "." 
+ self.key + +ColumnProperty.logger = log.class_logger(ColumnProperty) class CompositeProperty(ColumnProperty): """subclasses ColumnProperty to provide composite type support.""" @@ -100,6 +102,7 @@ class CompositeProperty(ColumnProperty): super(CompositeProperty, self).__init__(*columns, **kwargs) self.composite_class = class_ self.comparator = kwargs.pop('comparator', CompositeProperty.Comparator)(self) + self.strategy_class = strategies.CompositeColumnLoader def do_init(self): super(ColumnProperty, self).do_init() @@ -109,19 +112,19 @@ class CompositeProperty(ColumnProperty): return CompositeProperty(deferred=self.deferred, group=self.group, composite_class=self.composite_class, *self.columns) def getattr(self, state, column): - obj = getattr(state.class_, self.key).impl.get(state) + obj = state.get_impl(self.key).get(state) return self.get_col_value(column, obj) def getcommitted(self, state, column): - obj = getattr(state.class_, self.key).impl.get_committed_value(state) + obj = state.get_impl(self.key).get_committed_value(state) return self.get_col_value(column, obj) def setattr(self, state, value, column): # TODO: test coverage for this method - obj = getattr(state.class_, self.key).impl.get(state) + obj = state.get_impl(self.key).get(state) if obj is None: obj = self.composite_class(*[None for c in self.columns]) - getattr(state.class_, self.key).impl.set(state, obj, None) + state.get_impl(self.key).set(state, obj, None) for a, b in zip(self.columns, value.__composite_values__()): if a is column: @@ -133,6 +136,9 @@ class CompositeProperty(ColumnProperty): return b class Comparator(PropComparator): + def __clause_element__(self): + return expression.ClauseList(*self.prop.columns) + def __eq__(self, other): if other is None: return sql.and_(*[a==None for a in self.prop.columns]) @@ -146,17 +152,21 @@ class CompositeProperty(ColumnProperty): zip(self.prop.columns, other.__composite_values__())]) + def __str__(self): + return str(self.parent.class_.__name__) + 
"." + self.key + class SynonymProperty(MapperProperty): def __init__(self, name, map_column=None, descriptor=None): self.name = name - self.map_column=map_column + self.map_column = map_column self.descriptor = descriptor + util.set_creation_order(self) - def setup(self, querycontext, **kwargs): + def setup(self, context, entity, path, adapter, **kwargs): pass - def create_row_processor(self, selectcontext, mapper, row): - return (None, None, None) + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + return (None, None) def do_init(self): class_ = self.parent.class_ @@ -174,12 +184,11 @@ class SynonymProperty(MapperProperty): return s return getattr(obj, self.name) self.descriptor = SynonymProp() - sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator) + sessionlib.register_attribute(class_, self.key, uselist=False, proxy_property=self.descriptor, useobject=False, comparator=comparator, parententity=self.parent) def merge(self, session, source, dest, _recursive): pass -SynonymProperty.logger = logging.class_logger(SynonymProperty) - +SynonymProperty.logger = log.class_logger(SynonymProperty) class ComparableProperty(MapperProperty): """Instruments a Python property for use in query expressions.""" @@ -187,6 +196,7 @@ class ComparableProperty(MapperProperty): def __init__(self, comparator_factory, descriptor=None): self.descriptor = descriptor self.comparator = comparator_factory(self) + util.set_creation_order(self) def do_init(self): """Set up a proxy to the unmanaged descriptor.""" @@ -198,11 +208,11 @@ class ComparableProperty(MapperProperty): useobject=False, comparator=self.comparator) - def setup(self, querycontext, **kwargs): + def setup(self, context, entity, path, adapter, **kwargs): pass - def create_row_processor(self, selectcontext, mapper, row): - return (None, None, None) + def create_row_processor(self, selectcontext, path, mapper, row, adapter): 
+ return (None, None) class PropertyLoader(StrategizedProperty): @@ -210,7 +220,22 @@ class PropertyLoader(StrategizedProperty): of items that correspond to a related database table. """ - def __init__(self, argument, secondary=None, primaryjoin=None, secondaryjoin=None, entity_name=None, foreign_keys=None, foreignkey=None, uselist=None, private=False, association=None, order_by=False, attributeext=None, backref=None, is_backref=False, post_update=False, cascade=None, viewonly=False, lazy=True, collection_class=None, passive_deletes=False, passive_updates=True, remote_side=None, enable_typechecks=True, join_depth=None, strategy_class=None, _local_remote_pairs=None): + def __init__(self, argument, + secondary=None, primaryjoin=None, + secondaryjoin=None, entity_name=None, + foreign_keys=None, + uselist=None, + order_by=False, + backref=None, + _is_backref=False, + post_update=False, + cascade=None, + viewonly=False, lazy=True, + collection_class=None, passive_deletes=False, + passive_updates=True, remote_side=None, + enable_typechecks=True, join_depth=None, + strategy_class=None, _local_remote_pairs=None): + self.uselist = uselist self.argument = argument self.entity_name = entity_name @@ -222,9 +247,6 @@ class PropertyLoader(StrategizedProperty): self.viewonly = viewonly self.lazy = lazy self.foreign_keys = util.to_set(foreign_keys) - self._legacy_foreignkey = util.to_set(foreignkey) - if foreignkey: - util.warn_deprecated('foreignkey option is deprecated; see docs for details') self.collection_class = collection_class self.passive_deletes = passive_deletes self.passive_updates = passive_updates @@ -233,6 +255,8 @@ class PropertyLoader(StrategizedProperty): self.comparator = PropertyLoader.Comparator(self) self.join_depth = join_depth self._arg_local_remote_pairs = _local_remote_pairs + self.__join_cache = {} + util.set_creation_order(self) if strategy_class: self.strategy_class = strategy_class @@ -251,20 +275,13 @@ class PropertyLoader(StrategizedProperty): if 
cascade is not None: self.cascade = CascadeOptions(cascade) else: - if private: - util.warn_deprecated('private option is deprecated; see docs for details') - self.cascade = CascadeOptions("all, delete-orphan") - else: - self.cascade = CascadeOptions("save-update, merge") + self.cascade = CascadeOptions("save-update, merge") if self.passive_deletes == 'all' and ("delete" in self.cascade or "delete-orphan" in self.cascade): - raise exceptions.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade") + raise sa_exc.ArgumentError("Can't set passive_deletes='all' in conjunction with 'delete' or 'delete-orphan' cascade") - self.association = association - if association: - util.warn_deprecated('association option is deprecated; see docs for details') self.order_by = order_by - self.attributeext=attributeext + if isinstance(backref, str): # propigate explicitly sent primary/secondary join conditions to the BackRef object if # just a string was sent @@ -275,14 +292,21 @@ class PropertyLoader(StrategizedProperty): self.backref = BackRef(backref, primaryjoin=primaryjoin, secondaryjoin=secondaryjoin, passive_updates=self.passive_updates) else: self.backref = backref - self.is_backref = is_backref - + self._is_backref = _is_backref + class Comparator(PropComparator): def __init__(self, prop, of_type=None): self.prop = self.property = prop if of_type: self._of_type = _class_to_mapper(of_type) + def parententity(self): + return self.prop.parent + parententity = property(parententity) + + def __clause_element__(self): + return self.prop.parent._with_polymorphic_selectable + def of_type(self, cls): return PropertyLoader.Comparator(self.prop, cls) @@ -294,7 +318,7 @@ class PropertyLoader(StrategizedProperty): return self.prop._optimized_compare(None) elif self.prop.uselist: if not hasattr(other, '__iter__'): - raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object. 
Use contains().") + raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object. Use contains().") else: j = self.prop.primaryjoin if self.prop.secondaryjoin: @@ -308,60 +332,62 @@ class PropertyLoader(StrategizedProperty): else: return self.prop._optimized_compare(other) - def _join_and_criterion(self, criterion=None, **kwargs): + def __criterion_exists(self, criterion=None, **kwargs): if getattr(self, '_of_type', None): target_mapper = self._of_type - to_selectable = target_mapper._with_polymorphic_selectable() #mapped_table + to_selectable = target_mapper._with_polymorphic_selectable else: to_selectable = None - pj, sj, source, dest, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable) + pj, sj, source, dest, secondary, target_adapter = self.prop._create_joins(dest_polymorphic=True, dest_selectable=to_selectable) for k in kwargs: - crit = (getattr(self.prop.mapper.class_, k) == kwargs[k]) + crit = self.prop.mapper.class_manager.get_inst(k) == kwargs[k] if criterion is None: criterion = crit else: criterion = criterion & crit if sj: - j = pj & sj + j = _orm_annotate(pj) & sj else: - j = pj + j = _orm_annotate(pj, exclude=self.prop.remote_side) if criterion and target_adapter: + # limit this adapter to annotated only? criterion = target_adapter.traverse(criterion) - return j, criterion, dest + # only have the "joined left side" of what we return be subject to Query adaption. The right + # side of it is used for an exists() subquery and should not correlate or otherwise reach out + # to anything in the enclosing query. + if criterion: + criterion = criterion._annotate({'_halt_adapt': True}) + return sql.exists([1], j & criterion, from_obj=dest).correlate(source) def any(self, criterion=None, **kwargs): if not self.prop.uselist: - raise exceptions.InvalidRequestError("'any()' not implemented for scalar attributes. 
Use has().") - j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs) + raise sa_exc.InvalidRequestError("'any()' not implemented for scalar attributes. Use has().") - return sql.exists([1], j & criterion, from_obj=from_obj) + return self.__criterion_exists(criterion, **kwargs) def has(self, criterion=None, **kwargs): if self.prop.uselist: - raise exceptions.InvalidRequestError("'has()' not implemented for collections. Use any().") - j, criterion, from_obj = self._join_and_criterion(criterion, **kwargs) - - return sql.exists([1], j & criterion, from_obj=from_obj) + raise sa_exc.InvalidRequestError("'has()' not implemented for collections. Use any().") + return self.__criterion_exists(criterion, **kwargs) def contains(self, other): if not self.prop.uselist: - raise exceptions.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") + raise sa_exc.InvalidRequestError("'contains' not implemented for scalar attributes. Use ==") clause = self.prop._optimized_compare(other) if self.prop.secondaryjoin: - clause.negation_clause = self._negated_contains_or_equals(other) + clause.negation_clause = self.__negated_contains_or_equals(other) return clause - def _negated_contains_or_equals(self, other): + def __negated_contains_or_equals(self, other): criterion = sql.and_(*[x==y for (x, y) in zip(self.prop.mapper.primary_key, self.prop.mapper.primary_key_from_instance(other))]) - j, criterion, from_obj = self._join_and_criterion(criterion) - return ~sql.exists([1], j & criterion, from_obj=from_obj) + return ~self.__criterion_exists(criterion) def __ne__(self, other): if other is None: @@ -373,9 +399,9 @@ class PropertyLoader(StrategizedProperty): return self.has() if self.prop.uselist and not hasattr(other, '__iter__'): - raise exceptions.InvalidRequestError("Can only compare a collection to an iterable object") + raise sa_exc.InvalidRequestError("Can only compare a collection to an iterable object") - return 
self._negated_contains_or_equals(other) + return self.__negated_contains_or_equals(other) def compare(self, op, value, value_is_parent=False): if op == operators.eq: @@ -390,27 +416,29 @@ class PropertyLoader(StrategizedProperty): return op(self.comparator, value) def _optimized_compare(self, value, value_is_parent=False): + if value is not None: + value = attributes.instance_state(value) return self._get_strategy(strategies.LazyLoader).lazy_clause(value, reverse_direction=not value_is_parent) - def private(self): - return self.cascade.delete_orphan - private = property(private) - def __str__(self): - return str(self.parent.class_.__name__) + "." + self.key + " (" + str(self.mapper.class_.__name__) + ")" + return str(self.parent.class_.__name__) + "." + self.key def merge(self, session, source, dest, dont_load, _recursive): if not dont_load and self._reverse_property and (source, self._reverse_property) in _recursive: return - + + source_state = attributes.instance_state(source) + dest_state = attributes.instance_state(dest) + if not "merge" in self.cascade: - dest._state.expire_attributes([self.key]) + dest_state.expire_attributes([self.key]) return - instances = attributes.get_as_list(source._state, self.key, passive=True) + instances = source_state.value_as_iterable(self.key, passive=True) + if not instances: return - + if self.uselist: dest_list = [] for current in instances: @@ -419,11 +447,11 @@ class PropertyLoader(StrategizedProperty): if obj is not None: dest_list.append(obj) if dont_load: - coll = attributes.init_collection(dest, self.key) + coll = attributes.init_collection(dest_state, self.key) for c in dest_list: coll.append_without_event(c) else: - getattr(dest.__class__, self.key).impl._set_iterable(dest._state, dest_list) + getattr(dest.__class__, self.key).impl._set_iterable(dest_state, dest_list) else: current = instances[0] if current is not None: @@ -440,17 +468,17 @@ class PropertyLoader(StrategizedProperty): return passive = type_ != 'delete' 
or self.passive_deletes mapper = self.mapper.primary_mapper() - instances = attributes.get_as_list(state, self.key, passive=passive) + instances = state.value_as_iterable(self.key, passive=passive) if instances: for c in instances: if c is not None and c not in visited_instances and (halt_on is None or not halt_on(c)): if not isinstance(c, self.mapper.class_): - raise exceptions.AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) + raise AssertionError("Attribute '%s' on class '%s' doesn't handle objects of type '%s'" % (self.key, str(self.parent.class_), str(c.__class__))) visited_instances.add(c) # cascade using the mapper local to this object, so that its individual properties are located instance_mapper = object_mapper(c, entity_name=mapper.entity_name) - yield (c, instance_mapper, c._state) + yield (c, instance_mapper, attributes.instance_state(c)) def _get_target_class(self): """Return the target class of the relation, even if the @@ -479,7 +507,8 @@ class PropertyLoader(StrategizedProperty): # accept a callable to suit various deferred-configurational schemes self.mapper = mapper.class_mapper(self.argument(), entity_name=self.entity_name, compile=False) else: - raise exceptions.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) + raise sa_exc.ArgumentError("relation '%s' expects a class or a mapper argument (received: %s)" % (self.key, type(self.argument))) + assert isinstance(self.mapper, mapper.Mapper), self.mapper if not self.parent.concrete: for inheriting in self.parent.iterate_to_root(): @@ -495,14 +524,14 @@ class PropertyLoader(StrategizedProperty): if self.cascade.delete_orphan: if self.parent.class_ is self.mapper.class_: - raise exceptions.ArgumentError("In relationship '%s', can't establish 'delete-orphan' cascade " + raise sa_exc.ArgumentError("In relationship '%s', can't establish 
'delete-orphan' cascade " "rule on a self-referential relationship. " "You probably want cascade='all', which includes delete cascading but not orphan detection." %(str(self))) self.mapper.primary_mapper().delete_orphans.append((self.key, self.parent.class_)) def __determine_joins(self): if self.secondaryjoin is not None and self.secondary is None: - raise exceptions.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") + raise sa_exc.ArgumentError("Property '" + self.key + "' specified with secondary join condition but no secondary argument") # if join conditions were not specified, figure them out based on foreign keys def _search_for_join(mapper, table): @@ -512,7 +541,7 @@ class PropertyLoader(StrategizedProperty): is a join.""" try: return sql.join(mapper.local_table, table) - except exceptions.ArgumentError, e: + except sa_exc.ArgumentError, e: return sql.join(mapper.mapped_table, table) try: @@ -524,8 +553,8 @@ class PropertyLoader(StrategizedProperty): else: if self.primaryjoin is None: self.primaryjoin = _search_for_join(self.parent, self.target).onclause - except exceptions.ArgumentError, e: - raise exceptions.ArgumentError("Could not determine join condition between parent/child tables on relation %s. " + except sa_exc.ArgumentError, e: + raise sa_exc.ArgumentError("Could not determine join condition between parent/child tables on relation %s. " "Specify a 'primaryjoin' expression. If this is a many-to-many relation, 'secondaryjoin' is needed as well." 
% (self)) @@ -540,14 +569,11 @@ class PropertyLoader(StrategizedProperty): def __determine_fks(self): - if self._legacy_foreignkey and not self._refers_to_parent_table(): - self.foreign_keys = self._legacy_foreignkey - arg_foreign_keys = self.foreign_keys if self._arg_local_remote_pairs: if not arg_foreign_keys: - raise exceptions.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument") + raise sa_exc.ArgumentError("foreign_keys argument is required with _local_remote_pairs argument") self.foreign_keys = util.OrderedSet(arg_foreign_keys) self._opposite_side = util.OrderedSet() for l, r in self._arg_local_remote_pairs: @@ -562,15 +588,15 @@ class PropertyLoader(StrategizedProperty): if not eq_pairs: if not self.viewonly and criterion_as_pairs(self.primaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): - raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. " + raise sa_exc.ArgumentError("Could not locate any equated, locally mapped column pairs for primaryjoin condition '%s' on relation %s. " "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.primaryjoin, self) ) else: if arg_foreign_keys: - raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " + raise sa_exc.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " "Specify _local_remote_pairs=[(local, remote), (local, remote), ...] to explicitly establish the local/remote column pairs." % (self.primaryjoin, self)) else: - raise exceptions.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. " + raise sa_exc.ArgumentError("Could not determine relation direction for primaryjoin condition '%s', on relation %s. 
" "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.primaryjoin, self)) self.foreign_keys = util.OrderedSet([r for l, r in eq_pairs]) @@ -583,11 +609,11 @@ class PropertyLoader(StrategizedProperty): if not sq_pairs: if not self.viewonly and criterion_as_pairs(self.secondaryjoin, consider_as_foreign_keys=arg_foreign_keys, any_operator=True): - raise exceptions.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. " + raise sa_exc.ArgumentError("Could not locate any equated, locally mapped column pairs for secondaryjoin condition '%s' on relation %s. " "For more relaxed rules on join conditions, the relation may be marked as viewonly=True." % (self.secondaryjoin, self) ) else: - raise exceptions.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. " + raise sa_exc.ArgumentError("Could not determine relation direction for secondaryjoin condition '%s', on relation %s. " "Specify the foreign_keys argument to indicate which columns on the relation are foreign." % (self.secondaryjoin, self)) self.foreign_keys.update([r for l, r in sq_pairs]) @@ -599,7 +625,7 @@ class PropertyLoader(StrategizedProperty): def __determine_remote_side(self): if self._arg_local_remote_pairs: if self.remote_side: - raise exceptions.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.") + raise sa_exc.ArgumentError("remote_side argument is redundant against more detailed _local_remote_side argument.") if self.direction is MANYTOONE: eq_pairs = [(r, l) for l, r in self._arg_local_remote_pairs] else: @@ -629,11 +655,11 @@ class PropertyLoader(StrategizedProperty): if self.direction is ONETOMANY: for l in self.local_side: if not self.__col_is_part_of_mappings(l): - raise exceptions.ArgumentError("Local column '%s' is not part of mapping %s. 
Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent)) + raise sa_exc.ArgumentError("Local column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should compare against." % (l, self.parent)) elif self.direction is MANYTOONE: for r in self.remote_side: if not self.__col_is_part_of_mappings(r): - raise exceptions.ArgumentError("Remote column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper)) + raise sa_exc.ArgumentError("Remote column '%s' is not part of mapping %s. Specify remote_side argument to indicate which column lazy join condition should bind." % (r, self.mapper)) def __determine_direction(self): """Determine our *direction*, i.e. do we represent one to @@ -646,13 +672,7 @@ class PropertyLoader(StrategizedProperty): # for a self referential mapper, if the "foreignkey" is a single or composite primary key, # then we are "many to one", since the remote site of the relationship identifies a singular entity. # otherwise we are "one to many". 
- if self._legacy_foreignkey: - for f in self._legacy_foreignkey: - if not f.primary_key: - self.direction = ONETOMANY - else: - self.direction = MANYTOONE - elif self._arg_local_remote_pairs: + if self._arg_local_remote_pairs: remote = util.Set([r for l, r in self._arg_local_remote_pairs]) if self.foreign_keys.intersection(remote): self.direction = ONETOMANY @@ -671,7 +691,7 @@ class PropertyLoader(StrategizedProperty): manytoone = [c for c in self.foreign_keys if parenttable.c.contains_column(c)] if not onetomany and not manytoone: - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Can't determine relation direction for relationship '%s' " "- foreign key columns are present in neither the " "parent nor the child's mapped tables" %(str(self))) @@ -684,14 +704,14 @@ class PropertyLoader(StrategizedProperty): self.direction = MANYTOONE break else: - raise exceptions.ArgumentError( + raise sa_exc.ArgumentError( "Can't determine relation direction for relationship '%s' " "- foreign key columns are present in both the parent and " "the child's mapped tables. Specify 'foreign_keys' " "argument." 
% (str(self))) def _post_init(self): - if logging.is_info_enabled(self.logger): + if log.is_info_enabled(self.logger): self.logger.info(str(self) + " setup primary join %s" % self.primaryjoin) self.logger.info(str(self) + " setup secondary join %s" % self.secondaryjoin) self.logger.info(str(self) + " synchronize pairs [%s]" % ",".join(["(%s => %s)" % (l, r) for l, r in self.synchronize_pairs])) @@ -710,15 +730,10 @@ class PropertyLoader(StrategizedProperty): # primary property handler, set up class attributes if self.is_primary(): - # if a backref name is defined, set up an extension to populate - # attributes in the other direction - if self.backref is not None: - self.attributeext = self.backref.get_extension() - if self.backref is not None: self.backref.compile(self) elif not mapper.class_mapper(self.parent.class_, compile=False)._get_property(self.key, raiseerr=False): - raise exceptions.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) + raise sa_exc.ArgumentError("Attempting to assign a new relation '%s' to a non-primary mapper on class '%s'. New relations can only be added to the primary mapper, i.e. 
the very first mapper created for class '%s' " % (self.key, self.parent.class_.__name__, self.parent.class_.__name__)) super(PropertyLoader, self).do_init() @@ -729,50 +744,69 @@ class PropertyLoader(StrategizedProperty): return self.mapper.common_parent(self.parent) def _create_joins(self, source_polymorphic=False, source_selectable=None, dest_polymorphic=False, dest_selectable=None): + key = util.WeakCompositeKey(source_polymorphic, source_selectable, dest_polymorphic, dest_selectable) + try: + return self.__join_cache[key] + except KeyError: + pass + if source_selectable is None: if source_polymorphic and self.parent.with_polymorphic: - source_selectable = self.parent._with_polymorphic_selectable() - else: - source_selectable = None + source_selectable = self.parent._with_polymorphic_selectable + + aliased = False if dest_selectable is None: if dest_polymorphic and self.mapper.with_polymorphic: - dest_selectable = self.mapper._with_polymorphic_selectable() + dest_selectable = self.mapper._with_polymorphic_selectable + aliased = True else: dest_selectable = self.mapper.mapped_table - if self._is_self_referential(): + + if self._is_self_referential() and source_selectable is None: + dest_selectable = dest_selectable.alias() + aliased = True + else: + aliased = True + + aliased = aliased or bool(source_selectable) + + primaryjoin, secondaryjoin, secondary = self.primaryjoin, self.secondaryjoin, self.secondary + if aliased: + if secondary: + secondary = secondary.alias() + primary_aliasizer = ClauseAdapter(secondary) if dest_selectable: - dest_selectable = dest_selectable.alias() + secondary_aliasizer = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns).chain(primary_aliasizer) else: - dest_selectable = self.mapper.mapped_table.alias() - - primaryjoin = self.primaryjoin - if source_selectable: - if self.direction in (ONETOMANY, MANYTOMANY): - primaryjoin = ClauseAdapter(source_selectable, exclude=self.foreign_keys, 
equivalents=self.parent._equivalent_columns).traverse(primaryjoin) + secondary_aliasizer = primary_aliasizer + + if source_selectable: + primary_aliasizer = ClauseAdapter(secondary).chain(ClauseAdapter(source_selectable, equivalents=self.parent._equivalent_columns)) + + secondaryjoin = secondary_aliasizer.traverse(secondaryjoin) else: - primaryjoin = ClauseAdapter(source_selectable, include=self.foreign_keys, equivalents=self.parent._equivalent_columns).traverse(primaryjoin) + if dest_selectable: + primary_aliasizer = ClauseAdapter(dest_selectable, exclude=self.local_side, equivalents=self.mapper._equivalent_columns) + if source_selectable: + primary_aliasizer.chain(ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns)) + elif source_selectable: + primary_aliasizer = ClauseAdapter(source_selectable, exclude=self.remote_side, equivalents=self.parent._equivalent_columns) + + secondary_aliasizer = None - secondaryjoin = self.secondaryjoin - target_adapter = None - if dest_selectable: - if self.direction == ONETOMANY: - target_adapter = ClauseAdapter(dest_selectable, include=self.foreign_keys, equivalents=self.mapper._equivalent_columns) - elif self.direction == MANYTOMANY: - target_adapter = ClauseAdapter(dest_selectable, equivalents=self.mapper._equivalent_columns) - else: - target_adapter = ClauseAdapter(dest_selectable, exclude=self.foreign_keys, equivalents=self.mapper._equivalent_columns) - if secondaryjoin: - secondaryjoin = target_adapter.traverse(secondaryjoin) - else: - primaryjoin = target_adapter.traverse(primaryjoin) + primaryjoin = primary_aliasizer.traverse(primaryjoin) + target_adapter = secondary_aliasizer or primary_aliasizer target_adapter.include = target_adapter.exclude = None - - return primaryjoin, secondaryjoin, source_selectable or self.parent.local_table, dest_selectable or self.mapper.local_table, target_adapter + else: + target_adapter = None + + self.__join_cache[key] = ret = (primaryjoin, 
secondaryjoin, (source_selectable or self.parent.local_table), (dest_selectable or self.mapper.local_table), secondary, target_adapter) + return ret def _get_join(self, parent, primary=True, secondary=True, polymorphic_parent=True): """deprecated. use primary_join_against(), secondary_join_against(), full_join_against()""" - pj, sj, source, dest, adapter = self._create_joins(source_polymorphic=polymorphic_parent) + pj, sj, source, dest, secondarytable, adapter = self._create_joins(source_polymorphic=polymorphic_parent) if primary and secondary: return pj & sj @@ -788,7 +822,7 @@ class PropertyLoader(StrategizedProperty): if not self.viewonly: self._dependency_processor.register_dependencies(uowcommit) -PropertyLoader.logger = logging.class_logger(PropertyLoader) +PropertyLoader.logger = log.class_logger(PropertyLoader) class BackRef(object): """Attached to a PropertyLoader to indicate a complementary reverse relationship. @@ -799,7 +833,8 @@ class BackRef(object): self.key = key self.kwargs = kwargs self.prop = _prop - + self.extension = attributes.GenericBackrefExtension(self.key) + def compile(self, prop): if self.prop: return @@ -817,7 +852,7 @@ class BackRef(object): relation = PropertyLoader(parent, prop.secondary, pj, sj, backref=BackRef(prop.key, _prop=prop), - is_backref=True, + _is_backref=True, **self.kwargs) mapper._compile_property(self.key, relation); @@ -826,12 +861,7 @@ class BackRef(object): mapper._get_property(self.key)._reverse_property = prop else: - raise exceptions.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper)) - - def get_extension(self): - """Return an attribute extension to use with this backreference.""" - - return attributes.GenericBackrefExtension(self.key) + raise sa_exc.ArgumentError("Error creating backref '%s' on relation '%s': property of that name exists on mapper '%s'" % (self.key, prop, mapper)) mapper.ColumnProperty = ColumnProperty 
mapper.SynonymProperty = SynonymProperty diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 8996a758e..dfa24efee 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -15,35 +15,54 @@ operations at the SQL (non-ORM) level. ``Query`` differs from ``Select`` in that it returns ORM-mapped objects and interacts with an ORM session, whereas the ``Select`` construct interacts directly with the database to return iterable result sets. + """ from itertools import chain -from sqlalchemy import sql, util, exceptions, logging + +from sqlalchemy import sql, util, log +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import exc as orm_exc from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import expression, visitors, operators -from sqlalchemy.orm import mapper, object_mapper +from sqlalchemy.orm import attributes, interfaces, mapper, object_mapper +from sqlalchemy.orm.util import _state_mapper, _is_mapped_class, \ + _is_aliased_class, _entity_descriptor, _entity_info, _class_to_mapper, \ + _orm_columns, AliasedClass, _orm_selectable, join as orm_join, ORMAdapter -from sqlalchemy.orm.util import _state_mapper, _class_to_mapper, _is_mapped_class, _is_aliased_class -from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm import interfaces -from sqlalchemy.orm import attributes -from sqlalchemy.orm.util import AliasedClass +__all__ = ['Query', 'QueryContext', 'aliased'] -aliased = AliasedClass -__all__ = ['Query', 'QueryContext', 'aliased'] +aliased = AliasedClass +def _generative(*assertions): + """mark a method as generative.""" + + def decorate(fn): + argspec = util.format_argspec_plus(fn) + run_assertions = assertions + code = "\n".join([ + "def %s%s:", + " %r", + " self = self._clone()", + " for a in run_assertions:", + " a(self, %r)", + " fn%s", + " return self" + ]) % (fn.__name__, argspec['args'], fn.__doc__, fn.__name__, argspec['apply_pos']) + env = locals().copy() + exec code in env + 
return env[fn.__name__] + return decorate class Query(object): """Encapsulates the object-fetching operations provided by Mappers.""" - def __init__(self, class_or_mapper, session=None, entity_name=None): - self._session = session - + def __init__(self, entities, session=None, entity_name=None): + self.session = session + self._with_options = [] self._lockmode = None - - self._entities = [] self._order_by = False self._group_by = False self._distinct = False @@ -53,51 +72,53 @@ class Query(object): self._params = {} self._yield_per = None self._criterion = None + self._correlate = util.Set() + self._joinpoint = None + self._with_labels = False self.__joinable_tables = None self._having = None - self._column_aggregate = None self._populate_existing = False self._version_check = False self._autoflush = True - self._attributes = {} self._current_path = () self._only_load_props = None self._refresh_instance = None - - self.__init_mapper(_class_to_mapper(class_or_mapper, entity_name=entity_name)) - - def __init_mapper(self, mapper): - """populate all instance variables derived from this Query's mapper.""" - - self.mapper = mapper - self.table = self._from_obj = self.mapper.mapped_table - self._eager_loaders = util.Set(chain(*[mp._eager_loaders for mp in [m for m in self.mapper.iterate_to_root()]])) - self._extension = self.mapper.extension - self._aliases_head = self._aliases_tail = None - self._alias_ids = {} - self._joinpoint = self.mapper - self._entities.append(_PrimaryMapperEntity(self.mapper)) - if self.mapper.with_polymorphic: - self.__set_with_polymorphic(*self.mapper.with_polymorphic) - else: - self._with_polymorphic = [] - - def __generate_alias_ids(self): - self._alias_ids = dict([ - (k, list(v)) for k, v in self._alias_ids.iteritems() - ]) + self._from_obj = None + self._entities = [] + self._polymorphic_adapters = {} + self._filter_aliases = None + self._from_obj_alias = None + self.__currenttables = util.Set() + + for ent in util.to_list(entities): + 
_QueryEntity(self, ent, entity_name=entity_name) + + self.__setup_aliasizers(self._entities) + + def __setup_aliasizers(self, entities): + d = {} + for ent in entities: + for entity in ent.entities: + if entity not in d: + mapper, selectable, is_aliased_class = _entity_info(entity, ent.entity_name) + if not is_aliased_class and mapper.with_polymorphic: + with_polymorphic = mapper._with_polymorphic_mappers + self.__mapper_loads_polymorphically_with(mapper, sql_util.ColumnAdapter(selectable, mapper._equivalent_columns)) + adapter = None + elif is_aliased_class: + adapter = sql_util.ColumnAdapter(selectable, mapper._equivalent_columns) + with_polymorphic = None + else: + with_polymorphic = adapter = None - def __no_criterion(self, meth): - return self.__conditional_clone(meth, [self.__no_criterion_condition]) + d[entity] = (mapper, adapter, selectable, is_aliased_class, with_polymorphic) + ent.setup_entity(entity, *d[entity]) - def __no_statement(self, meth): - return self.__conditional_clone(meth, [self.__no_statement_condition]) - - def __reset_all(self, mapper, meth): - q = self.__conditional_clone(meth, [self.__no_criterion_condition]) - q.__init_mapper(mapper, mapper) - return q + def __mapper_loads_polymorphically_with(self, mapper, adapter): + for m2 in mapper._with_polymorphic_mappers: + for m in m2.iterate_to_root(): + self._polymorphic_adapters[m.mapped_table] = self._polymorphic_adapters[m.local_table] = adapter def __set_select_from(self, from_obj): if isinstance(from_obj, expression._SelectBaseMixin): @@ -105,54 +126,168 @@ class Query(object): from_obj = from_obj.alias() self._from_obj = from_obj - self._alias_ids = {} + equivs = self.__all_equivs() + + if isinstance(from_obj, expression.Alias): + # dont alias a regular join (since its not an alias itself) + self._from_obj_alias = sql_util.ColumnAdapter(self._from_obj, equivs) + + def _get_polymorphic_adapter(self, entity, selectable): + self.__mapper_loads_polymorphically_with(entity.mapper, 
sql_util.ColumnAdapter(selectable, entity.mapper._equivalent_columns)) + + def _reset_polymorphic_adapter(self, mapper): + for m2 in mapper._with_polymorphic_mappers: + for m in m2.iterate_to_root(): + self._polymorphic_adapters.pop(m.mapped_table, None) + self._polymorphic_adapters.pop(m.local_table, None) + + def __reset_joinpoint(self): + self._joinpoint = None + self._filter_aliases = None + + def __adapt_polymorphic_element(self, element): + if isinstance(element, expression.FromClause): + search = element + elif hasattr(element, 'table'): + search = element.table + else: + search = None + + if search: + alias = self._polymorphic_adapters.get(search, None) + if alias: + return alias.adapt_clause(element) + + def __replace_element(self, adapters): + def replace(elem): + if '_halt_adapt' in elem._annotations: + return elem + + for adapter in adapters: + e = adapter(elem) + if e: + return e + return replace + + def __replace_orm_element(self, adapters): + def replace(elem): + if '_halt_adapt' in elem._annotations: + return elem + + if "_orm_adapt" in elem._annotations or "parententity" in elem._annotations: + for adapter in adapters: + e = adapter(elem) + if e: + return e + return replace + + def _adapt_all_clauses(self): + self._disable_orm_filtering = True + _adapt_all_clauses = _generative()(_adapt_all_clauses) + + def _adapt_clause(self, clause, as_filter, orm_only): + adapters = [] + if as_filter and self._filter_aliases: + adapters.append(self._filter_aliases.replace) + + if self._polymorphic_adapters: + adapters.append(self.__adapt_polymorphic_element) + + if self._from_obj_alias: + adapters.append(self._from_obj_alias.replace) + + if not adapters: + return clause + + if getattr(self, '_disable_orm_filtering', not orm_only): + return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_element(adapters)) + else: + return visitors.replacement_traverse(clause, {'column_collections':False}, self.__replace_orm_element(adapters)) - 
if self.table not in self._get_joinable_tables(): - self._aliases_head = self._aliases_tail = mapperutil.AliasedClauses(self._from_obj, equivalents=self.mapper._equivalent_columns) - self._alias_ids.setdefault(self.table, []).append(self._aliases_head) + def _entity_zero(self): + return self._entities[0] + + def _mapper_zero(self): + return self._entity_zero().entity_zero + + def _extension_zero(self): + ent = self._entity_zero() + return getattr(ent, 'extension', ent.mapper.extension) + + def _mapper_entities(self): + for ent in self._entities: + if hasattr(ent, 'primary_entity'): + yield ent + _mapper_entities = property(_mapper_entities) + + def _joinpoint_zero(self): + return self._joinpoint or self._entity_zero().entity_zero + + def _mapper_zero_or_none(self): + if not getattr(self._entities[0], 'primary_entity', False): + return None + return self._entities[0].mapper + + def _only_mapper_zero(self): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError("This operation requires a Query against a single mapper.") + return self._mapper_zero() + + def _only_entity_zero(self): + if len(self._entities) > 1: + raise sa_exc.InvalidRequestError("This operation requires a Query against a single mapper.") + return self._entity_zero() + + def _generate_mapper_zero(self): + if not getattr(self._entities[0], 'primary_entity', False): + raise sa_exc.InvalidRequestError("No primary mapper set up for this Query.") + entity = self._entities[0]._clone() + self._entities = [entity] + self._entities[1:] + return entity + + def __mapper_zero_from_obj(self): + if self._from_obj: + return self._from_obj else: - self._aliases_head = self._aliases_tail = None + return self._entity_zero().selectable - def __set_with_polymorphic(self, cls_or_mappers, selectable=None): - mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable) - self._with_polymorphic = mappers - self.__set_select_from(from_obj) + def __all_equivs(self): + equivs = {} + for ent in 
self._mapper_entities: + equivs.update(ent.mapper._equivalent_columns) + return equivs - def __no_criterion_condition(self, q, meth): - if q._criterion or q._statement: + def __no_criterion_condition(self, meth): + if self._criterion or self._statement or self._from_obj: util.warn( ("Query.%s() being called on a Query with existing criterion; " - "criterion is being ignored.") % meth) - - q._joinpoint = self.mapper - q._statement = q._criterion = None - q._order_by = q._group_by = q._distinct = False - q._aliases_tail = q._aliases_head - q.table = q._from_obj = q.mapper.mapped_table - if q.mapper.with_polymorphic: - q.__set_with_polymorphic(*q.mapper.with_polymorphic) - - def __no_entities(self, meth): - q = self.__no_statement(meth) - if len(q._entities) > 1 and not isinstance(q._entities[0], _PrimaryMapperEntity): - raise exceptions.InvalidRequestError( - ("Query.%s() being called on a Query with existing " - "additional entities or columns - can't replace columns") % meth) - q._entities = [] - return q + "criterion is being ignored. This usage is deprecated.") % meth) - def __no_statement_condition(self, q, meth): - if q._statement: - raise exceptions.InvalidRequestError( + self._statement = self._criterion = self._from_obj = None + self._order_by = self._group_by = self._distinct = False + self.__joined_tables = {} + + def __no_from_condition(self, meth): + if self._from_obj: + raise sa_exc.InvalidRequestError("Query.%s() being called on a Query which already has a FROM clause established. This usage is deprecated." 
% meth) + + def __no_statement_condition(self, meth): + if self._statement: + raise sa_exc.InvalidRequestError( ("Query.%s() being called on a Query with an existing full " "statement - can't apply criterion.") % meth) - def __conditional_clone(self, methname=None, conditions=None): - q = self._clone() - if conditions: - for condition in conditions: - condition(q, methname) - return q + def __no_limit_offset(self, meth): + if self._limit or self._offset: + util.warn("Query.%s() being called on a Query which already has LIMIT or OFFSET applied. " + "This usage is deprecated. Apply filtering and joins before LIMIT or OFFSET are applied, " + "or to filter/join to the row-limited results of the query, call from_self() first." + "In release 0.5, from_self() will be called automatically in this scenario." + ) + + def __no_criterion(self): + """generate a Query with no criterion, warn if criterion was present""" + __no_criterion = _generative(__no_criterion_condition)(__no_criterion) def __get_options(self, populate_existing=None, version_check=None, only_load_props=None, refresh_instance=None): if populate_existing: @@ -170,18 +305,27 @@ class Query(object): q.__dict__ = self.__dict__.copy() return q - def session(self): - if self._session is None: - return self.mapper.get_session() - else: - return self._session - session = property(session) - def statement(self): """return the full SELECT statement represented by this Query.""" - return self._compile_context().statement + return self._compile_context(labels=self._with_labels).statement statement = property(statement) + def with_labels(self): + """Apply column labels to the return value of Query.statement. + + Indicates that this Query's `statement` accessor should return a SELECT statement + that applies labels to all columns in the form <tablename>_<columnname>; this + is commonly used to disambiguate columns from multiple tables which have the + same name. 
+ + When the `Query` actually issues SQL to load rows, it always uses + column labeling. + + """ + self._with_labels = True + with_labels = _generative()(with_labels) + + def whereclause(self): """return the WHERE criterion for this Query.""" return self._criterion @@ -189,48 +333,44 @@ class Query(object): def _with_current_path(self, path): """indicate that this query applies to objects loaded within a certain path. - - Used by deferred loaders (see strategies.py) which transfer query + + Used by deferred loaders (see strategies.py) which transfer query options from an originating query to a newly generated query intended for the deferred load. - + """ - q = self._clone() - q._current_path = path - return q + self._current_path = path + _with_current_path = _generative()(_with_current_path) def with_polymorphic(self, cls_or_mappers, selectable=None): """Load columns for descendant mappers of this Query's mapper. - + Using this method will ensure that each descendant mapper's - tables are included in the FROM clause, and will allow filter() - criterion to be used against those tables. The resulting + tables are included in the FROM clause, and will allow filter() + criterion to be used against those tables. The resulting instances will also have those columns already loaded so that no "post fetch" of those columns will be required. - + ``cls_or_mappers`` is a single class or mapper, or list of class/mappers, which inherit from this Query's mapper. Alternatively, it - may also be the string ``'*'``, in which case all descending + may also be the string ``'*'``, in which case all descending mappers will be added to the FROM clause. - - ``selectable`` is a table or select() statement that will + + ``selectable`` is a table or select() statement that will be used in place of the generated FROM clause. 
This argument - is required if any of the desired mappers use concrete table - inheritance, since SQLAlchemy currently cannot generate UNIONs - among tables automatically. If used, the ``selectable`` - argument must represent the full set of tables and columns mapped + is required if any of the desired mappers use concrete table + inheritance, since SQLAlchemy currently cannot generate UNIONs + among tables automatically. If used, the ``selectable`` + argument must represent the full set of tables and columns mapped by every desired mapper. Otherwise, the unaccounted mapped columns - will result in their table being appended directly to the FROM + will result in their table being appended directly to the FROM clause which will usually lead to incorrect results. """ - q = self.__no_criterion('with_polymorphic') - - q.__set_with_polymorphic(cls_or_mappers, selectable=selectable) + entity = self._generate_mapper_zero() + entity.set_with_polymorphic(self, cls_or_mappers, selectable=selectable) + with_polymorphic = _generative(__no_from_condition, __no_criterion_condition)(with_polymorphic) - return q - - def yield_per(self, count): """Yield only ``count`` rows at a time. @@ -242,30 +382,28 @@ class Query(object): eagerly loaded collections (i.e. any lazy=False) since those collections will be cleared for a new load when encountered in a subsequent result batch. - """ - q = self._clone() - q._yield_per = count - return q + """ + self._yield_per = count + yield_per = _generative()(yield_per) def get(self, ident, **kwargs): """Return an instance of the object based on the given identifier, or None if not found. The `ident` argument is a scalar or tuple of primary key column values in the order of the table def's primary key columns. 
+ """ - ret = self._extension.get(self, ident, **kwargs) + ret = self._extension_zero().get(self, ident, **kwargs) if ret is not mapper.EXT_CONTINUE: return ret # convert composite types to individual args - # TODO: account for the order of columns in the - # ColumnProperty it corresponds to if hasattr(ident, '__composite_values__'): ident = ident.__composite_values__() - key = self.mapper.identity_key_from_primary_key(ident) + key = self._only_mapper_zero().identity_key_from_primary_key(ident) return self._get(key, ident, **kwargs) def load(self, ident, raiseerr=True, **kwargs): @@ -275,15 +413,20 @@ class Query(object): pending changes** to the object already existing in the Session. The `ident` argument is a scalar or tuple of primary key column values in the order of the table def's primary key columns. - """ - ret = self._extension.load(self, ident, **kwargs) + """ + ret = self._extension_zero().load(self, ident, **kwargs) if ret is not mapper.EXT_CONTINUE: return ret - key = self.mapper.identity_key_from_primary_key(ident) + + # convert composite types to individual args + if hasattr(ident, '__composite_values__'): + ident = ident.__composite_values__() + + key = self._only_mapper_zero().identity_key_from_primary_key(ident) instance = self.populate_existing()._get(key, ident, **kwargs) if instance is None and raiseerr: - raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) + raise sa_exc.InvalidRequestError("No instance found for identity %s" % repr(ident)) return instance def query_from_parent(cls, instance, property, **kwargs): @@ -303,27 +446,33 @@ class Query(object): \**kwargs all extra keyword arguments are propagated to the constructor of Query. - """ + deprecated. use sqlalchemy.orm.with_parent in conjunction with + filter(). 
+ + """ mapper = object_mapper(instance) prop = mapper.get_property(property, resolve_synonyms=True) target = prop.mapper criterion = prop.compare(operators.eq, instance, value_is_parent=True) return Query(target, **kwargs).filter(criterion) - query_from_parent = classmethod(query_from_parent) + query_from_parent = classmethod(util.deprecated(None, False)(query_from_parent)) + + def correlate(self, *args): + self._correlate = self._correlate.union([_orm_selectable(s) for s in args]) + correlate = _generative()(correlate) def autoflush(self, setting): """Return a Query with a specific 'autoflush' setting. Note that a Session with autoflush=False will - not autoflush, even if this flag is set to True at the + not autoflush, even if this flag is set to True at the Query level. Therefore this flag is usually used only to disable autoflush for a specific Query. - + """ - q = self._clone() - q._autoflush = setting - return q + self._autoflush = setting + autoflush = _generative()(autoflush) def populate_existing(self): """Return a Query that will refresh all instances loaded. @@ -336,11 +485,10 @@ class Query(object): An alternative to populate_existing() is to expire the Session fully using session.expire_all(). - + """ - q = self._clone() - q._populate_existing = True - return q + self._populate_existing = True + populate_existing = _generative()(populate_existing) def with_parent(self, instance, property=None): """add a join criterion corresponding to a relationship to the given parent instance. 
@@ -361,140 +509,98 @@ class Query(object): mapper = object_mapper(instance) if property is None: for prop in mapper.iterate_properties: - if isinstance(prop, properties.PropertyLoader) and prop.mapper is self.mapper: + if isinstance(prop, properties.PropertyLoader) and prop.mapper is self._mapper_zero(): break else: - raise exceptions.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self.mapper.class_.__name__, instance.__class__.__name__)) + raise sa_exc.InvalidRequestError("Could not locate a property which relates instances of class '%s' to instances of class '%s'" % (self._mapper_zero().class_.__name__, instance.__class__.__name__)) else: prop = mapper.get_property(property, resolve_synonyms=True) return self.filter(prop.compare(operators.eq, instance, value_is_parent=True)) - def add_entity(self, entity, alias=None, id=None): - """add a mapped entity to the list of result columns to be returned. - - This will have the effect of all result-returning methods returning a tuple - of results, the first element being an instance of the primary class for this - Query, and subsequent elements matching columns or entities which were - specified via add_column or add_entity. - - When adding entities to the result, its generally desirable to add - limiting criterion to the query which can associate the primary entity - of this Query along with the additional entities. The Query selects - from all tables with no joining criterion by default. + def add_entity(self, entity, alias=None): + """add a mapped entity to the list of result columns to be returned.""" - entity - a class or mapper which will be added to the results. + if alias: + entity = aliased(entity, alias) - alias - a sqlalchemy.sql.Alias object which will be used to select rows. this - will match the usage of the given Alias in filter(), order_by(), etc. 
expressions + self._entities = list(self._entities) + m = _MapperEntity(self, entity) + self.__setup_aliasizers([m]) + add_entity = _generative()(add_entity) - id - a string ID matching that given to query.join() or query.outerjoin(); rows will be - selected from the aliased join created via those methods. + def from_self(self, *entities): + """return a Query that selects from this Query's SELECT statement. + \*entities - optional list of entities which will replace + those being selected. """ - q = self._clone() - - if not alias and _is_aliased_class(entity): - alias = entity.alias - if isinstance(entity, type): - entity = mapper.class_mapper(entity) + fromclause = self.compile().correlate(None) + self._statement = self._criterion = None + self._order_by = self._group_by = self._distinct = False + self._limit = self._offset = None + self.__set_select_from(fromclause) + if entities: + self._entities = [] + for ent in entities: + _QueryEntity(self, ent) + self.__setup_aliasizers(self._entities) - if alias is not None: - alias = mapperutil.AliasedClauses(alias) + from_self = _generative()(from_self) + _from_self = from_self - q._entities = q._entities + [_MapperEntity(mapper=entity, alias=alias, id=id)] - return q - - def _from_self(self): - """return a Query that selects from this Query's SELECT statement. - - The API for this method hasn't been decided yet and is subject to change. 
- - """ - q = self._clone() - q._eager_loaders = util.Set() - fromclause = q.compile().correlate(None) - return Query(self.mapper, self.session).select_from(fromclause) - def values(self, *columns): """Return an iterator yielding result tuples corresponding to the given list of columns""" - - q = self.__no_entities('_values') - q._only_load_props = q._eager_loaders = util.Set() - q._no_filters = True + + if not columns: + return iter(()) + q = self._clone() + q._entities = [] for column in columns: - q._entities.append(self._add_column(column, None, False)) + _ColumnEntity(q, column) + q.__setup_aliasizers(q._entities) if not q._yield_per: - q = q.yield_per(10) + q._yield_per = 10 return iter(q) _values = values - - def add_column(self, column, id=None): - """Add a SQL ColumnElement to the list of result columns to be returned. - This will have the effect of all result-returning methods returning a - tuple of results, the first element being an instance of the primary - class for this Query, and subsequent elements matching columns or - entities which were specified via add_column or add_entity. + def add_column(self, column): + """Add a SQL ColumnElement to the list of result columns to be returned.""" - When adding columns to the result, its generally desirable to add - limiting criterion to the query which can associate the primary entity - of this Query along with the additional columns, if the column is - based on a table or selectable that is not the primary mapped - selectable. The Query selects from all tables with no joining - criterion by default. + self._entities = list(self._entities) + c = _ColumnEntity(self, column) + self.__setup_aliasizers([c]) + add_column = _generative()(add_column) - column - a string column name or sql.ColumnElement to be added to the results. 
- - """ - q = self._clone() - q._entities = q._entities + [self._add_column(column, id, True)] - return q - - def _add_column(self, column, id, looks_for_aliases): - if isinstance(column, interfaces.PropComparator): - column = column.clause_element() - - elif not isinstance(column, (sql.ColumnElement, basestring)): - raise exceptions.InvalidRequestError("Invalid column expression '%r'" % column) - - return _ColumnEntity(column, id) - def options(self, *args): """Return a new Query object, applying the given list of MapperOptions. """ - return self._options(False, *args) + return self.__options(False, *args) def _conditional_options(self, *args): - return self._options(True, *args) + return self.__options(True, *args) - def _options(self, conditional, *args): - q = self._clone() + def __options(self, conditional, *args): # most MapperOptions write to the '_attributes' dictionary, # so copy that as well - q._attributes = q._attributes.copy() + self._attributes = self._attributes.copy() opts = [o for o in util.flatten_iterator(args)] - q._with_options = q._with_options + opts + self._with_options = self._with_options + opts if conditional: for opt in opts: - opt.process_query_conditionally(q) + opt.process_query_conditionally(self) else: for opt in opts: - opt.process_query(q) - return q + opt.process_query(self) + __options = _generative()(__options) def with_lockmode(self, mode): """Return a new Query object with the specified locking mode.""" - - q = self._clone() - q._lockmode = mode - return q + + self._lockmode = mode + with_lockmode = _generative()(with_lockmode) def params(self, *args, **kwargs): """add values for bind parameters which may have been specified in filter(). @@ -505,14 +611,13 @@ class Query(object): \**kwargs cannot be used. 
""" - q = self._clone() if len(args) == 1: kwargs.update(args[0]) elif len(args) > 0: - raise exceptions.ArgumentError("params() takes zero or one positional argument, which is a dictionary.") - q._params = q._params.copy() - q._params.update(kwargs) - return q + raise sa_exc.ArgumentError("params() takes zero or one positional argument, which is a dictionary.") + self._params = self._params.copy() + self._params.update(kwargs) + params = _generative()(params) def filter(self, criterion): """apply the given filtering criterion to the query and return the newly resulting ``Query`` @@ -524,22 +629,20 @@ class Query(object): criterion = sql.text(criterion) if criterion is not None and not isinstance(criterion, sql.ClauseElement): - raise exceptions.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") + raise sa_exc.ArgumentError("filter() argument must be of type sqlalchemy.sql.ClauseElement or string") - if self._aliases_tail: - criterion = self._aliases_tail.adapt_clause(criterion) + criterion = self._adapt_clause(criterion, True, True) - q = self.__no_statement("filter") - if q._criterion is not None: - q._criterion = q._criterion & criterion + if self._criterion is not None: + self._criterion = self._criterion & criterion else: - q._criterion = criterion - return q + self._criterion = criterion + filter = _generative(__no_statement_condition, __no_limit_offset)(filter) def filter_by(self, **kwargs): """apply the given filtering criterion to the query and return the newly resulting ``Query``.""" - clauses = [self._joinpoint.get_property(key, resolve_synonyms=True).compare(operators.eq, value) + clauses = [_entity_descriptor(self._joinpoint_zero(), key)[0] == value for key, value in kwargs.iteritems()] return self.filter(sql.and_(*clauses)) @@ -568,31 +671,27 @@ class Query(object): def order_by(self, *criterion): """apply one or more ORDER BY criterion to the query and return the newly resulting ``Query``""" - q = 
self.__no_statement("order_by") + criterion = [self._adapt_clause(expression._literal_as_text(o), True, True) for o in criterion] - if self._aliases_tail: - criterion = tuple(self._aliases_tail.adapt_list( - [expression._literal_as_text(o) for o in criterion] - )) - - if q._order_by is False: - q._order_by = criterion + if self._order_by is False: + self._order_by = criterion else: - q._order_by = q._order_by + criterion - return q + self._order_by = self._order_by + criterion order_by = util.array_as_starargs_decorator(order_by) - + order_by = _generative(__no_statement_condition)(order_by) + def group_by(self, *criterion): """apply one or more GROUP BY criterion to the query and return the newly resulting ``Query``""" - q = self.__no_statement("group_by") - if q._group_by is False: - q._group_by = criterion + criterion = list(chain(*[_orm_columns(c) for c in criterion])) + + if self._group_by is False: + self._group_by = criterion else: - q._group_by = q._group_by + criterion - return q + self._group_by = self._group_by + criterion group_by = util.array_as_starargs_decorator(group_by) - + group_by = _generative(__no_statement_condition)(group_by) + def having(self, criterion): """apply a HAVING criterion to the query and return the newly resulting ``Query``.""" @@ -600,190 +699,225 @@ class Query(object): criterion = sql.text(criterion) if criterion is not None and not isinstance(criterion, sql.ClauseElement): - raise exceptions.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string") + raise sa_exc.ArgumentError("having() argument must be of type sqlalchemy.sql.ClauseElement or string") - if self._aliases_tail: - criterion = self._aliases_tail.adapt_clause(criterion) + criterion = self._adapt_clause(criterion, True, True) - q = self.__no_statement("having") - if q._having is not None: - q._having = q._having & criterion + if self._having is not None: + self._having = self._having & criterion else: - q._having = criterion - return 
q + self._having = criterion + having = _generative(__no_statement_condition)(having) - def join(self, prop, id=None, aliased=False, from_joinpoint=False): + def join(self, *props, **kwargs): """Create a join against this ``Query`` object's criterion - and apply generatively, retunring the newly resulting ``Query``. - - 'prop' may be one of: - * a string property name, i.e. "rooms" - * a class-mapped attribute, i.e. Houses.rooms - * a 2-tuple containing one of the above, combined with a selectable - which derives from the properties' mapped table - * a list (not a tuple) containing a combination of any of the above. + and apply generatively, returning the newly resulting ``Query``. + each element in \*props may be: + + * a string property name, i.e. "rooms". This will join along + the relation of the same name from this Query's "primary" + mapper, if one is present. + + * a class-mapped attribute, i.e. Houses.rooms. This will create a + join from "Houses" table to that of the "rooms" relation. + + * a 2-tuple containing a target class or selectable, and + an "ON" clause. The ON clause can be the property name/ + attribute like above, or a SQL expression. 
+ + e.g.:: + # join along string attribute names session.query(Company).join('employees') - session.query(Company).join(['employees', 'tasks']) - session.query(Houses).join([Colonials.rooms, Room.closets]) - session.query(Company).join([('employees', people.join(engineers)), Engineer.computers]) + session.query(Company).join('employees', 'tasks') + + # join the Person entity to an alias of itself, + # along the "friends" relation + PAlias = aliased(Person) + session.query(Person).join((Palias, Person.friends)) + + # join from Houses to the "rooms" attribute on the + # "Colonials" subclass of Houses, then join to the + # "closets" relation on Room + session.query(Houses).join(Colonials.rooms, Room.closets) + + # join from Company entities to the "employees" collection, + # using "people JOIN engineers" as the target. Then join + # to the "computers" collection on the Engineer entity. + session.query(Company).join((people.join(engineers), 'employees'), Engineer.computers) + + # join from Articles to Keywords, using the "keywords" attribute. + # assume this is a many-to-many relation. + session.query(Article).join(Article.keywords) + + # same thing, but spelled out entirely explicitly + # including the association table. + session.query(Article).join( + (article_keywords, Articles.id==article_keywords.c.article_id), + (Keyword, Keyword.id==article_keywords.c.keyword_id) + ) + + \**kwargs include: + + aliased - when joining, create anonymous aliases of each table. This is + used for self-referential joins or multiple joins to the same table. + Consider usage of the aliased(SomeClass) construct as a more explicit + approach to this. + + from_joinpoint - when joins are specified using string property names, + locate the property from the mapper found in the most recent previous + join() call, instead of from the root entity. 
""" - return self._join(prop, id=id, outerjoin=False, aliased=aliased, from_joinpoint=from_joinpoint) + aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) + if kwargs: + raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys())) + return self.__join(props, outerjoin=False, create_aliases=aliased, from_joinpoint=from_joinpoint) + join = util.array_as_starargs_decorator(join) - def outerjoin(self, prop, id=None, aliased=False, from_joinpoint=False): + def outerjoin(self, *props, **kwargs): """Create a left outer join against this ``Query`` object's criterion and apply generatively, retunring the newly resulting ``Query``. + + Usage is the same as the ``join()`` method. - 'prop' may be one of: - * a string property name, i.e. "rooms" - * a class-mapped attribute, i.e. Houses.rooms - * a 2-tuple containing one of the above, combined with a selectable - which derives from the properties' mapped table - * a list (not a tuple) containing a combination of any of the above. 
+ """ + aliased, from_joinpoint = kwargs.pop('aliased', False), kwargs.pop('from_joinpoint', False) + if kwargs: + raise TypeError("unknown arguments: %s" % ','.join(kwargs.keys())) + return self.__join(props, outerjoin=True, create_aliases=aliased, from_joinpoint=from_joinpoint) + outerjoin = util.array_as_starargs_decorator(outerjoin) - e.g.:: + def __join(self, keys, outerjoin, create_aliases, from_joinpoint): + self.__currenttables = util.Set(self.__currenttables) + self._polymorphic_adapters = self._polymorphic_adapters.copy() - session.query(Company).outerjoin('employees') - session.query(Company).outerjoin(['employees', 'tasks']) - session.query(Houses).outerjoin([Colonials.rooms, Room.closets]) - session.query(Company).join([('employees', people.join(engineers)), Engineer.computers]) + if not from_joinpoint: + self.__reset_joinpoint() - """ - return self._join(prop, id=id, outerjoin=True, aliased=aliased, from_joinpoint=from_joinpoint) - - def _join(self, prop, id, outerjoin, aliased, from_joinpoint): - (clause, mapper, aliases) = self._join_to(prop, outerjoin=outerjoin, start=from_joinpoint and self._joinpoint or self.mapper, create_aliases=aliased) - # TODO: improve the generative check here to look for primary mapped entity, etc. 
- q = self.__no_statement("join") - q._from_obj = clause - q._joinpoint = mapper - q._aliases = aliases - q.__generate_alias_ids() - - if aliases: - q._aliases_tail = aliases - - a = aliases - while a is not None: - if isinstance(a, mapperutil.PropertyAliasedClauses): - q._alias_ids.setdefault(a.mapper, []).append(a) - q._alias_ids.setdefault(a.table, []).append(a) - a = a.parentclauses + clause = self._from_obj + right_entity = None + + for arg1 in util.to_list(keys): + prop = None + aliased_entity = False + alias_criterion = False + left_entity = right_entity + right_entity = right_mapper = None + + if isinstance(arg1, tuple): + arg1, arg2 = arg1 else: - break + arg2 = None + + if isinstance(arg2, (interfaces.PropComparator, basestring)): + onclause = arg2 + right_entity = arg1 + elif isinstance(arg1, (interfaces.PropComparator, basestring)): + onclause = arg1 + right_entity = arg2 + else: + onclause = arg2 + right_entity = arg1 - if id: - q._alias_ids[id] = [aliases] - return q + if isinstance(onclause, interfaces.PropComparator): + of_type = getattr(onclause, '_of_type', None) + prop = onclause.property + descriptor = onclause + + if not left_entity: + left_entity = onclause.parententity + + if of_type: + right_mapper = of_type + else: + right_mapper = prop.mapper + + if not right_entity: + right_entity = right_mapper + + elif isinstance(onclause, basestring): + if not left_entity: + left_entity = self._joinpoint_zero() + + descriptor, prop = _entity_descriptor(left_entity, onclause) + right_mapper = prop.mapper + if not right_entity: + right_entity = right_mapper + elif onclause is None: + if not left_entity: + left_entity = self._joinpoint_zero() + else: + if not left_entity: + left_entity = self._joinpoint_zero() + + if not clause: + if isinstance(onclause, interfaces.PropComparator): + clause = onclause.__clause_element__() - def _get_joinable_tables(self): - if not self.__joinable_tables or self.__joinable_tables[0] is not self._from_obj: - currenttables = 
[self._from_obj] - def visit_join(join): - currenttables.append(join.left) - currenttables.append(join.right) - visitors.traverse(self._from_obj, visit_join=visit_join, traverse_options={'column_collections':False, 'aliased_selectables':False}) - self.__joinable_tables = (self._from_obj, currenttables) - return currenttables - else: - return self.__joinable_tables[1] + for ent in self._mapper_entities: + if ent.corresponds_to(left_entity): + clause = ent.selectable + break - def _join_to(self, keys, outerjoin=False, start=None, create_aliases=True): - if start is None: - start = self._joinpoint + if not clause: + raise exc.InvalidRequestError("Could not find a FROM clause to join from") - clause = self._from_obj + bogus, right_selectable, is_aliased_class = _entity_info(right_entity) - currenttables = self._get_joinable_tables() + if right_mapper and not is_aliased_class: + if right_entity is right_selectable: - # determine if generated joins need to be aliased on the left - # hand side. 
- if self._aliases_head is self._aliases_tail is not None: - adapt_against = self._aliases_tail.alias - elif start is not self.mapper and self._aliases_tail: - adapt_against = self._aliases_tail.alias - else: - adapt_against = None + if not right_selectable.is_derived_from(right_mapper.mapped_table): + raise sa_exc.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (right_selectable.description, right_mapper.mapped_table.description)) - mapper = start - alias = self._aliases_tail + if not isinstance(right_selectable, expression.Alias): + right_selectable = right_selectable.alias() - if not isinstance(keys, list): - keys = [keys] - - for key in keys: - use_selectable = None - of_type = None - is_aliased_class = False - - if isinstance(key, tuple): - key, use_selectable = key - - if isinstance(key, interfaces.PropComparator): - prop = key.property - if getattr(key, '_of_type', None): - of_type = key._of_type - if not use_selectable: - use_selectable = key._of_type.mapped_table - else: - prop = mapper.get_property(key, resolve_synonyms=True) - - if use_selectable: - if _is_aliased_class(use_selectable): - use_selectable = use_selectable.alias - is_aliased_class = True - if not use_selectable.is_derived_from(prop.mapper.mapped_table): - raise exceptions.InvalidRequestError("Selectable '%s' is not derived from '%s'" % (use_selectable.description, prop.mapper.mapped_table.description)) - if not isinstance(use_selectable, expression.Alias): - use_selectable = use_selectable.alias() - elif prop.mapper.with_polymorphic: - use_selectable = prop.mapper._with_polymorphic_selectable() - if not isinstance(use_selectable, expression.Alias): - use_selectable = use_selectable.alias() - - if prop._is_self_referential() and not create_aliases and not use_selectable: - raise exceptions.InvalidRequestError("Self-referential query on '%s' property requires aliased=True argument." 
% str(prop)) - - if prop.table not in currenttables or create_aliases or use_selectable: + right_entity = aliased(right_mapper, right_selectable) + alias_criterion = True + + elif right_mapper.with_polymorphic or isinstance(right_mapper.mapped_table, expression.Join): + aliased_entity = True + right_entity = aliased(right_mapper) + alias_criterion = True - if use_selectable or create_aliases: - alias = mapperutil.PropertyAliasedClauses(prop, - prop.primaryjoin, - prop.secondaryjoin, - alias, - alias=use_selectable, - should_adapt=not is_aliased_class - ) - crit = alias.primaryjoin + elif create_aliases: + right_entity = aliased(right_mapper) + alias_criterion = True + + elif prop: + if prop.table in self.__currenttables: + if prop.secondary is not None and prop.secondary not in self.__currenttables: + # TODO: this check is not strong enough for different paths to the same endpoint which + # does not use secondary tables + raise sa_exc.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." 
% descriptor) + + continue + if prop.secondary: - clause = clause.join(alias.secondary, crit, isouter=outerjoin) - clause = clause.join(alias.alias, alias.secondaryjoin, isouter=outerjoin) - else: - clause = clause.join(alias.alias, crit, isouter=outerjoin) - else: - assert not prop.mapper.with_polymorphic - pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_against) - if sj: - clause = clause.join(prop.secondary, pj, isouter=outerjoin) - clause = clause.join(prop.table, sj, isouter=outerjoin) - else: - clause = clause.join(prop.table, pj, isouter=outerjoin) - - elif not create_aliases and prop.secondary is not None and prop.secondary not in currenttables: - # TODO: this check is not strong enough for different paths to the same endpoint which - # does not use secondary tables - raise exceptions.InvalidRequestError("Can't join to property '%s'; a path to this table along a different secondary table already exists. Use the `alias=True` argument to `join()`." % prop.key) + self.__currenttables.add(prop.secondary) + self.__currenttables.add(prop.table) - mapper = of_type or prop.mapper + right_entity = prop.mapper - if use_selectable: - adapt_against = use_selectable - - return (clause, mapper, alias) + if prop: + onclause = prop + + clause = orm_join(clause, right_entity, onclause, isouter=outerjoin) + if alias_criterion: + self._filter_aliases = ORMAdapter(right_entity, + equivalents=right_mapper._equivalent_columns, chain_to=self._filter_aliases) + + if aliased_entity: + self.__mapper_loads_polymorphically_with(right_mapper, ORMAdapter(right_entity, equivalents=right_mapper._equivalent_columns)) + + self._from_obj = clause + self._joinpoint = right_entity + __join = _generative(__no_statement_condition, __no_limit_offset)(__join) def reset_joinpoint(self): """return a new Query reset the 'joinpoint' of this Query reset @@ -794,13 +928,8 @@ class Query(object): the root. 
""" - q = self.__no_statement("reset_joinpoint") - q._joinpoint = q.mapper - if q.table not in q._get_joinable_tables(): - q._aliases_head = q._aliases_tail = mapperutil.AliasedClauses(q._from_obj, equivalents=q.mapper._equivalent_columns) - else: - q._aliases_head = q._aliases_tail = None - return q + self.__reset_joinpoint() + reset_joinpoint = _generative(__no_statement_condition)(reset_joinpoint) def select_from(self, from_obj): """Set the `from_obj` parameter of the query and return the newly @@ -811,14 +940,13 @@ class Query(object): `from_obj` is a single table or selectable. """ - new = self.__no_criterion('select_from') if isinstance(from_obj, (tuple, list)): util.warn_deprecated("select_from() now accepts a single Selectable as its argument, which replaces any existing FROM criterion.") from_obj = from_obj[-1] + + self.__set_select_from(from_obj) + select_from = _generative(__no_from_condition, __no_criterion_condition)(select_from) - new.__set_select_from(from_obj) - return new - def __getitem__(self, item): if isinstance(item, slice): start = item.start @@ -863,9 +991,8 @@ class Query(object): ``Query``. """ - new = self.__no_statement("distinct") - new._distinct = True - return new + self._distinct = True + distinct = _generative(__no_statement_condition)(distinct) def all(self): """Return the results represented by this ``Query`` as a list. @@ -875,7 +1002,6 @@ class Query(object): """ return list(self) - def from_statement(self, statement): """Execute the given SELECT statement and return results. @@ -891,9 +1017,8 @@ class Query(object): """ if isinstance(statement, basestring): statement = sql.text(statement) - q = self.__no_criterion('from_statement') - q._statement = statement - return q + self._statement = statement + from_statement = _generative(__no_criterion_condition)(from_statement) def first(self): """Return the first result of this ``Query`` or None if the result doesn't contain any row. 
@@ -901,9 +1026,6 @@ class Query(object): This results in an execution of the underlying query. """ - if self._column_aggregate is not None: - return self._col_aggregate(*self._column_aggregate) - ret = list(self[0:1]) if len(ret) > 0: return ret[0] @@ -916,17 +1038,14 @@ class Query(object): This results in an execution of the underlying query. """ - if self._column_aggregate is not None: - return self._col_aggregate(*self._column_aggregate) - ret = list(self[0:2]) if len(ret) == 1: return ret[0] elif len(ret) == 0: - raise exceptions.InvalidRequestError('No rows returned for one()') + raise sa_exc.InvalidRequestError('No rows returned for one()') else: - raise exceptions.InvalidRequestError('Multiple rows returned for one()') + raise sa_exc.InvalidRequestError('Multiple rows returned for one()') def __iter__(self): context = self._compile_context() @@ -936,37 +1055,41 @@ class Query(object): return self._execute_and_instances(context) def _execute_and_instances(self, querycontext): - result = self.session.execute(querycontext.statement, params=self._params, mapper=self.mapper, instance=self._refresh_instance) - return self.iterate_instances(result, querycontext=querycontext) + result = self.session.execute(querycontext.statement, params=self._params, mapper=self._mapper_zero_or_none(), instance=self._refresh_instance) + return self.iterate_instances(result, querycontext) - def instances(self, cursor, *mappers_or_columns, **kwargs): - return list(self.iterate_instances(cursor, *mappers_or_columns, **kwargs)) + def instances(self, cursor, __context=None): + return list(self.iterate_instances(cursor, __context)) - def iterate_instances(self, cursor, *mappers_or_columns, **kwargs): + def iterate_instances(self, cursor, __context=None): session = self.session - context = kwargs.pop('querycontext', None) + context = __context if context is None: context = QueryContext(self) context.runid = _new_runid() - entities = self._entities + [_QueryEntity.legacy_guess_type(mc) 
for mc in mappers_or_columns] - - if getattr(self, '_no_filters', False): - filter = None - single_entity = custom_rows = False - else: - single_entity = isinstance(entities[0], _PrimaryMapperEntity) and len(entities) == 1 - custom_rows = single_entity and 'append_result' in context.extension.methods - + filtered = bool(list(self._mapper_entities)) + single_entity = filtered and len(self._entities) == 1 + + if filtered: if single_entity: filter = util.OrderedIdentitySet else: filter = util.OrderedSet - - process = [query_entity.row_processor(self, context, single_entity) for query_entity in entities] + else: + filter = None + + custom_rows = single_entity and 'append_result' in self._entities[0].extension.methods + (process, labels) = zip(*[query_entity.row_processor(self, context, custom_rows) for query_entity in self._entities]) + + if not single_entity: + labels = dict([(label, property(util.itemgetter(i))) for i, label in enumerate(labels) if label]) + rowtuple = type.__new__(type, "RowTuple", (tuple,), labels) + rowtuple.keys = labels.keys + while True: context.progress = util.Set() context.partials = {} @@ -974,7 +1097,7 @@ class Query(object): if self._yield_per: fetch = cursor.fetchmany(self._yield_per) if not fetch: - return + break else: fetch = cursor.fetchall() @@ -985,23 +1108,20 @@ class Query(object): elif single_entity: rows = [process[0](context, row) for row in fetch] else: - rows = [tuple([proc(context, row) for proc in process]) for row in fetch] + rows = [rowtuple([proc(context, row) for proc in process]) for row in fetch] if filter: rows = filter(rows) - if context.refresh_instance and context.only_load_props and context.refresh_instance in context.progress: - context.refresh_instance.commit(context.only_load_props) + if context.refresh_instance and self._only_load_props and context.refresh_instance in context.progress: + context.refresh_instance.commit(self._only_load_props) context.progress.remove(context.refresh_instance) - for ii in 
context.progress: - context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii) - ii.commit_all() - + session._finalize_loaded(context.progress) + for ii, attrs in context.partials.items(): - context.attributes.get(('populating_mapper', ii), _state_mapper(ii))._post_instance(context, ii, only_load_props=attrs) ii.commit(attrs) - + for row in rows: yield row @@ -1010,9 +1130,18 @@ class Query(object): def _get(self, key=None, ident=None, refresh_instance=None, lockmode=None, only_load_props=None): lockmode = lockmode or self._lockmode - if not self._populate_existing and not refresh_instance and not self.mapper.always_refresh and lockmode is None: + if not self._populate_existing and not refresh_instance and not self._mapper_zero().always_refresh and lockmode is None: try: - return self.session.identity_map[key] + instance = self.session.identity_map[key] + state = attributes.instance_state(instance) + if state.expired: + try: + state() + except orm_exc.ObjectDeletedError: + # TODO: should we expunge ? if so, should we expunge here ? or in mapper._load_scalar_attributes ? 
+ self.session.expunge(instance) + return None + return instance except KeyError: pass @@ -1022,27 +1151,29 @@ class Query(object): else: ident = util.to_list(ident) - q = self - - # dont use 'polymorphic' mapper if we are refreshing an instance - if refresh_instance and q.mapper is not q.mapper: - q = q.__reset_all(q.mapper, '_get') + if refresh_instance is None: + q = self.__no_criterion() + else: + q = self._clone() if ident is not None: - q = q.__no_criterion('get') + mapper = q._mapper_zero() params = {} - (_get_clause, _get_params) = q.mapper._get_clause - q = q.filter(_get_clause) - for i, primary_key in enumerate(q.mapper.primary_key): + (_get_clause, _get_params) = mapper._get_clause + + _get_clause = q._adapt_clause(_get_clause, True, False) + q._criterion = _get_clause + + for i, primary_key in enumerate(mapper.primary_key): try: params[_get_params[primary_key].key] = ident[i] except IndexError: - raise exceptions.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key])) - q = q.params(params) + raise sa_exc.InvalidRequestError("Could not find enough values to formulate primary key for query.get(); primary key columns are %s" % ', '.join(["'%s'" % str(c) for c in q.mapper.primary_key])) + q._params = params if lockmode is not None: - q = q.with_lockmode(lockmode) - q = q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance) + q._lockmode = lockmode + q.__get_options(populate_existing=bool(refresh_instance), version_check=(lockmode is not None), only_load_props=only_load_props, refresh_instance=refresh_instance) q._order_by = None try: # call using all() to avoid LIMIT compilation complexity @@ -1053,41 +1184,26 @@ class Query(object): def _select_args(self): return {'limit':self._limit, 'offset':self._offset, 
'distinct':self._distinct, 'group_by':self._group_by or None, 'having':self._having or None} _select_args = property(_select_args) - + def _should_nest_selectable(self): kwargs = self._select_args return (kwargs.get('limit') is not None or kwargs.get('offset') is not None or kwargs.get('distinct', False)) _should_nest_selectable = property(_should_nest_selectable) - def count(self, whereclause=None, params=None, **kwargs): - """Apply this query's criterion to a SELECT COUNT statement. - - the whereclause, params and \**kwargs arguments are deprecated. use filter() - and other generative methods to establish modifiers. - - """ - q = self - if whereclause is not None: - q = q.filter(whereclause) - if params is not None: - q = q.params(params) - q = q._legacy_select_kwargs(**kwargs) - return q._count() - - def _count(self): + def count(self): """Apply this query's criterion to a SELECT COUNT statement. this is the purely generative version which will become the public method in version 0.5. 
""" - return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self.mapper.primary_key)) + return self._col_aggregate(sql.literal_column('1'), sql.func.count, nested_cols=list(self._mapper_zero().primary_key)) def _col_aggregate(self, col, func, nested_cols=None): whereclause = self._criterion - + context = QueryContext(self) - from_obj = self._from_obj + from_obj = self.__mapper_zero_from_obj() if self._should_nest_selectable: if not nested_cols: @@ -1097,113 +1213,97 @@ class Query(object): s = sql.select([func(s.corresponding_column(col) or col)]).select_from(s) else: s = sql.select([func(col)], whereclause, from_obj=from_obj, **self._select_args) - + if self._autoflush and not self._populate_existing: self.session._autoflush() - return self.session.scalar(s, params=self._params, mapper=self.mapper) + return self.session.scalar(s, params=self._params, mapper=self._mapper_zero()) def compile(self): """compiles and returns a SQL statement based on the criterion and conditions within this Query.""" return self._compile_context().statement - def _compile_context(self): - + def _compile_context(self, labels=True): context = QueryContext(self) - if self._statement: - self._statement.use_labels = True - context.statement = self._statement + if context.statement: return context - from_obj = self._from_obj - adapter = self._aliases_head - if self._lockmode: try: - for_update = {'read':'read','update':True,'update_nowait':'nowait',None:False}[self._lockmode] + for_update = {'read': 'read', + 'update': True, + 'update_nowait': 'nowait', + None: False}[self._lockmode] except KeyError: - raise exceptions.ArgumentError("Unknown lockmode '%s'" % self._lockmode) + raise sa_exc.ArgumentError("Unknown lockmode '%s'" % self._lockmode) else: for_update = False - - context.from_clause = from_obj - context.whereclause = self._criterion - context.order_by = self._order_by - + for entity in self._entities: entity.setup_context(self, context) - - if 
self._eager_loaders and self._should_nest_selectable: - # eager loaders are present, and the SELECT has limiting criterion - # produce a "wrapped" selectable. - + + eager_joins = context.eager_joins.values() + + if context.from_clause: + froms = [context.from_clause] # "load from a single FROM" mode, i.e. when select_from() or join() is used + else: + froms = context.froms # "load from discrete FROMs" mode, i.e. when each _MappedEntity has its own FROM + + if eager_joins and self._should_nest_selectable: + # for eager joins present and LIMIT/OFFSET/DISTINCT, wrap the query inside a select, + # then append eager joins onto that + if context.order_by: - context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []] - if adapter: - context.order_by = adapter.adapt_list(context.order_by) - # locate all embedded Column clauses so they can be added to the - # "inner" select statement where they'll be available to the enclosing - # statement's "order by" - # TODO: this likely doesn't work with very involved ORDER BY expressions, - # such as those including subqueries order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by])) else: context.order_by = None order_by_col_expr = [] - - if adapter: - context.primary_columns = adapter.adapt_list(context.primary_columns) - - inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=context.from_clause, use_labels=True, correlate=False, order_by=context.order_by, **self._select_args).alias() - local_adapter = sql_util.ClauseAdapter(inner) - context.row_adapter = mapperutil.create_row_adapter(inner, equivalent_columns=self.mapper._equivalent_columns) + inner = sql.select(context.primary_columns + order_by_col_expr, context.whereclause, from_obj=froms, use_labels=labels, correlate=False, order_by=context.order_by, **self._select_args) + + if self._correlate: + inner = inner.correlate(*self._correlate) + + inner = inner.alias() - 
statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=True) + equivs = self.__all_equivs() - if context.eager_joins: - eager_joins = local_adapter.traverse(context.eager_joins) - statement.append_from(eager_joins) + context.adapter = sql_util.ColumnAdapter(inner, equivs) + + statement = sql.select([inner] + context.secondary_columns, for_update=for_update, use_labels=labels) + + from_clause = inner + for eager_join in eager_joins: + # EagerLoader places a 'stop_on' attribute on the join, + # giving us a marker as to where the "splice point" of the join should be + from_clause = sql_util.splice_joins(from_clause, eager_join, eager_join.stop_on) + + statement.append_from(from_clause) if context.order_by: + local_adapter = sql_util.ClauseAdapter(inner) statement.append_order_by(*local_adapter.copy_and_process(context.order_by)) statement.append_order_by(*context.eager_order_by) else: - if context.order_by: - context.order_by = [expression._literal_as_text(o) for o in util.to_list(context.order_by) or []] - if adapter: - context.order_by = adapter.adapt_list(context.order_by) - else: + if not context.order_by: context.order_by = None - - if adapter: - context.primary_columns = adapter.adapt_list(context.primary_columns) - context.row_adapter = mapperutil.create_row_adapter(adapter.alias, equivalent_columns=self.mapper._equivalent_columns) - + if self._distinct and context.order_by: order_by_col_expr = list(chain(*[sql_util.find_columns(o) for o in context.order_by])) context.primary_columns += order_by_col_expr - statement = sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, order_by=context.order_by, **self._select_args) + froms += context.eager_joins.values() - if context.eager_joins: - if adapter: - context.eager_joins = adapter.adapt_clause(context.eager_joins) - statement.append_from(context.eager_joins) + statement = 
sql.select(context.primary_columns + context.secondary_columns, context.whereclause, from_obj=froms, use_labels=labels, for_update=for_update, correlate=False, order_by=context.order_by, **self._select_args) + if self._correlate: + statement = statement.correlate(*self._correlate) if context.eager_order_by: - if adapter: - context.eager_order_by = adapter.adapt_list(context.eager_order_by) statement.append_order_by(*context.eager_order_by) - # polymorphic mappers which have concrete tables in their hierarchy usually - # require row aliasing unconditionally. - if not context.row_adapter and self.mapper._requires_row_aliasing: - context.row_adapter = mapperutil.create_row_adapter(self.table, equivalent_columns=self.mapper._equivalent_columns) - - context.statement = statement + context.statement = statement._annotate({'_halt_adapt': True}) return context @@ -1213,462 +1313,257 @@ class Query(object): def __str__(self): return str(self.compile()) - # DEPRECATED LAND ! - - def _generative_col_aggregate(self, col, func): - """apply the given aggregate function to the query and return the newly - resulting ``Query``. (deprecated) - """ - if self._column_aggregate is not None: - raise exceptions.InvalidRequestError("Query already contains an aggregate column or function") - q = self.__no_statement("aggregate") - q._column_aggregate = (col, func) - return q - - def apply_min(self, col): - """apply the SQL ``min()`` function against the given column to the - query and return the newly resulting ``Query``. - - DEPRECATED. - """ - return self._generative_col_aggregate(col, sql.func.min) - - def apply_max(self, col): - """apply the SQL ``max()`` function against the given column to the - query and return the newly resulting ``Query``. - - DEPRECATED. - """ - return self._generative_col_aggregate(col, sql.func.max) - - def apply_sum(self, col): - """apply the SQL ``sum()`` function against the given column to the - query and return the newly resulting ``Query``. - - DEPRECATED. 
- """ - return self._generative_col_aggregate(col, sql.func.sum) - - def apply_avg(self, col): - """apply the SQL ``avg()`` function against the given column to the - query and return the newly resulting ``Query``. - - DEPRECATED. - """ - return self._generative_col_aggregate(col, sql.func.avg) - - def list(self): #pragma: no cover - """DEPRECATED. use all()""" - - return list(self) - - def scalar(self): #pragma: no cover - """DEPRECATED. use first()""" - - return self.first() - - def _legacy_filter_by(self, *args, **kwargs): #pragma: no cover - return self.filter(self._legacy_join_by(args, kwargs, start=self._joinpoint)) - - def count_by(self, *args, **params): #pragma: no cover - """DEPRECATED. use query.filter_by(\**params).count()""" - - return self.count(self.join_by(*args, **params)) - - def select_whereclause(self, whereclause=None, params=None, **kwargs): #pragma: no cover - """DEPRECATED. use query.filter(whereclause).all()""" - - q = self.filter(whereclause)._legacy_select_kwargs(**kwargs) - if params is not None: - q = q.params(params) - return list(q) +class _QueryEntity(object): + """represent an entity column returned within a Query result.""" - def _legacy_select_from(self, from_obj): - q = self._clone() - if len(from_obj) > 1: - raise exceptions.ArgumentError("Multiple-entry from_obj parameter no longer supported") - q._from_obj = from_obj[0] - return q + def __new__(cls, *args, **kwargs): + if cls is _QueryEntity: + entity = args[1] + if _is_mapped_class(entity): + cls = _MapperEntity + else: + cls = _ColumnEntity + return object.__new__(cls) - def _legacy_select_kwargs(self, **kwargs): #pragma: no cover - q = self - if "order_by" in kwargs and kwargs['order_by']: - q = q.order_by(kwargs['order_by']) - if "group_by" in kwargs: - q = q.group_by(kwargs['group_by']) - if "from_obj" in kwargs: - q = q._legacy_select_from(kwargs['from_obj']) - if "lockmode" in kwargs: - q = q.with_lockmode(kwargs['lockmode']) - if "distinct" in kwargs: - q = 
q.distinct() - if "limit" in kwargs: - q = q.limit(kwargs['limit']) - if "offset" in kwargs: - q = q.offset(kwargs['offset']) + def _clone(self): + q = self.__class__.__new__(self.__class__) + q.__dict__ = self.__dict__.copy() return q +class _MapperEntity(_QueryEntity): + """mapper/class/AliasedClass entity""" - def get_by(self, *args, **params): #pragma: no cover - """DEPRECATED. use query.filter_by(\**params).first()""" - - ret = self._extension.get_by(self, *args, **params) - if ret is not mapper.EXT_CONTINUE: - return ret - - return self._legacy_filter_by(*args, **params).first() - - def select_by(self, *args, **params): #pragma: no cover - """DEPRECATED. use use query.filter_by(\**params).all().""" - - ret = self._extension.select_by(self, *args, **params) - if ret is not mapper.EXT_CONTINUE: - return ret - - return self._legacy_filter_by(*args, **params).list() - - def join_by(self, *args, **params): #pragma: no cover - """DEPRECATED. use join() to construct joins based on attribute names.""" + def __init__(self, query, entity, entity_name=None): + self.primary_entity = not query._entities + query._entities.append(self) - return self._legacy_join_by(args, params, start=self._joinpoint) + self.entities = [entity] + self.entity_zero = entity + self.entity_name = entity_name - def _build_select(self, arg=None, params=None, **kwargs): #pragma: no cover - if isinstance(arg, sql.FromClause) and arg.supports_execution(): - return self.from_statement(arg) - elif arg is not None: - return self.filter(arg)._legacy_select_kwargs(**kwargs) + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): + self.mapper = mapper + self.extension = self.mapper.extension + self.adapter = adapter + self.selectable = from_obj + self._with_polymorphic = with_polymorphic + self.is_aliased_class = is_aliased_class + if is_aliased_class: + self.path_entity = self.entity = self.entity_zero = entity else: - return self._legacy_select_kwargs(**kwargs) 
- - def selectfirst(self, arg=None, **kwargs): #pragma: no cover - """DEPRECATED. use query.filter(whereclause).first()""" - - return self._build_select(arg, **kwargs).first() - - def selectone(self, arg=None, **kwargs): #pragma: no cover - """DEPRECATED. use query.filter(whereclause).one()""" - - return self._build_select(arg, **kwargs).one() - - def select(self, arg=None, **kwargs): #pragma: no cover - """DEPRECATED. use query.filter(whereclause).all(), or query.from_statement(statement).all()""" + self.path_entity = mapper.base_mapper + self.entity = self.entity_zero = mapper - ret = self._extension.select(self, arg=arg, **kwargs) - if ret is not mapper.EXT_CONTINUE: - return ret - return self._build_select(arg, **kwargs).all() - - def execute(self, clauseelement, params=None, *args, **kwargs): #pragma: no cover - """DEPRECATED. use query.from_statement().all()""" - - return self._select_statement(clauseelement, params, **kwargs) - - def select_statement(self, statement, **params): #pragma: no cover - """DEPRECATED. Use query.from_statement(statement)""" - - return self._select_statement(statement, params) - - def select_text(self, text, **params): #pragma: no cover - """DEPRECATED. Use query.from_statement(statement)""" + def set_with_polymorphic(self, query, cls_or_mappers, selectable): + if cls_or_mappers is None: + query._reset_polymorphic_adapter(self.mapper) + return - return self._select_statement(text, params) - - def _select_statement(self, statement, params=None, **kwargs): #pragma: no cover - q = self.from_statement(statement) - if params is not None: - q = q.params(params) - q.__get_options(**kwargs) - return list(q) - - def join_to(self, key): #pragma: no cover - """DEPRECATED. 
use join() to create joins based on property names.""" - - [keys, p] = self._locate_prop(key) - return self.join_via(keys) + mappers, from_obj = self.mapper._with_polymorphic_args(cls_or_mappers, selectable) + self._with_polymorphic = mappers - def join_via(self, keys): #pragma: no cover - """DEPRECATED. use join() to create joins based on property names.""" + # TODO: do the wrapped thing here too so that with_polymorphic() can be + # applied to aliases + if not self.is_aliased_class: + self.selectable = from_obj + self.adapter = query._get_polymorphic_adapter(self, from_obj) - mapper = self._joinpoint - clause = None - for key in keys: - prop = mapper.get_property(key, resolve_synonyms=True) - if clause is None: - clause = prop._get_join(mapper) - else: - clause &= prop._get_join(mapper) - mapper = prop.mapper + def corresponds_to(self, entity): + if _is_aliased_class(entity): + return entity is self.path_entity + else: + return entity.base_mapper is self.path_entity - return clause + def _get_entity_clauses(self, query, context): - def _legacy_join_by(self, args, params, start=None): #pragma: no cover - import properties + adapter = None + if not self.is_aliased_class and query._polymorphic_adapters: + for mapper in self.mapper.iterate_to_root(): + adapter = query._polymorphic_adapters.get(mapper.mapped_table, None) + if adapter: + break - clause = None - for arg in args: - if clause is None: - clause = arg - else: - clause &= arg + if not adapter and self.adapter: + adapter = self.adapter - for key, value in params.iteritems(): - (keys, prop) = self._locate_prop(key, start=start) - if isinstance(prop, properties.PropertyLoader): - c = prop.compare(operators.eq, value) & self.join_via(keys[:-1]) + if adapter: + if query._from_obj_alias: + ret = adapter.wrap(query._from_obj_alias) else: - c = prop.compare(operators.eq, value) & self.join_via(keys) - if clause is None: - clause = c - else: - clause &= c - return clause - - def _locate_prop(self, key, start=None): 
#pragma: no cover - import properties - keys = [] - seen = util.Set() - def search_for_prop(mapper_): - if mapper_ in seen: - return None - seen.add(mapper_) - - prop = mapper_.get_property(key, resolve_synonyms=True, raiseerr=False) - if prop is not None: - if isinstance(prop, properties.PropertyLoader): - keys.insert(0, prop.key) - return prop - else: - for prop in mapper_.iterate_properties: - if not isinstance(prop, properties.PropertyLoader): - continue - x = search_for_prop(prop.mapper) - if x: - keys.insert(0, prop.key) - return x - else: - return None - p = search_for_prop(start or self.mapper) - if p is None: - raise exceptions.InvalidRequestError("Can't locate property named '%s'" % key) - return [keys, p] - - def selectfirst_by(self, *args, **params): #pragma: no cover - """DEPRECATED. Use query.filter_by(\**kwargs).first()""" - - return self._legacy_filter_by(*args, **params).first() - - def selectone_by(self, *args, **params): #pragma: no cover - """DEPRECATED. Use query.filter_by(\**kwargs).one()""" - - return self._legacy_filter_by(*args, **params).one() - - for deprecated_method in ('list', 'scalar', 'count_by', - 'select_whereclause', 'get_by', 'select_by', - 'join_by', 'selectfirst', 'selectone', 'select', - 'execute', 'select_statement', 'select_text', - 'join_to', 'join_via', 'selectfirst_by', - 'selectone_by', 'apply_max', 'apply_min', - 'apply_avg', 'apply_sum'): - locals()[deprecated_method] = \ - util.deprecated(None, False)(locals()[deprecated_method]) - -class _QueryEntity(object): - """represent an entity column returned within a Query result.""" - - def legacy_guess_type(self, e): - if isinstance(e, type): - return _MapperEntity(mapper=mapper.class_mapper(e)) - elif isinstance(e, mapper.Mapper): - return _MapperEntity(mapper=e) + ret = adapter else: - return _ColumnEntity(column=e) - legacy_guess_type=classmethod(legacy_guess_type) + ret = query._from_obj_alias -class _MapperEntity(_QueryEntity): - """entity column corresponding to 
mapped ORM instances.""" - - def __init__(self, mapper, alias=None, id=None): - self.mapper = mapper - self.alias = alias - self.alias_id = id - - def _get_entity_clauses(self, query): - if self.alias: - return self.alias - elif self.alias_id: - try: - return query._alias_ids[self.alias_id][0] - except KeyError: - raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id) - - l = query._alias_ids.get(self.mapper) - if l: - if len(l) > 1: - raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_entity()" % str(self.mapper)) - return l[0] - else: - return None - - def row_processor(self, query, context, single_entity): - clauses = self._get_entity_clauses(query) - if clauses: - def proc(context, row): - return self.mapper._instance(context, clauses.row_decorator(row), None) - else: - def proc(context, row): - return self.mapper._instance(context, row, None) - - return proc - - def setup_context(self, query, context): - clauses = self._get_entity_clauses(query) - for value in self.mapper.iterate_properties: - context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses) + return ret - def __str__(self): - return str(self.mapper) + def row_processor(self, query, context, custom_rows): + adapter = self._get_entity_clauses(query, context) -class _PrimaryMapperEntity(_MapperEntity): - """entity column corresponding to the 'primary' (first) mapped ORM instance.""" + if context.adapter and adapter: + adapter = adapter.wrap(context.adapter) + elif not adapter: + adapter = context.adapter - def row_processor(self, query, context, single_entity): - if single_entity and 'append_result' in context.extension.methods: + # polymorphic mappers which have concrete tables in their hierarchy usually + # require row aliasing unconditionally. 
+ if not adapter and self.mapper._requires_row_aliasing: + adapter = sql_util.ColumnAdapter(self.selectable, self.mapper._equivalent_columns) + + if self.primary_entity: + _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter, + extension=self.extension, only_load_props=query._only_load_props, refresh_instance=context.refresh_instance + ) + else: + _instance = self.mapper._instance_processor(context, (self.path_entity,), adapter) + + if custom_rows: def main(context, row, result): - if context.row_adapter: - row = context.row_adapter(row) - self.mapper._instance(context, row, result, - extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance - ) - elif context.row_adapter: - def main(context, row): - return self.mapper._instance(context, context.row_adapter(row), None, - extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance - ) + _instance(row, result) else: def main(context, row): - return self.mapper._instance(context, row, None, - extension=context.extension, only_load_props=context.only_load_props, refresh_instance=context.refresh_instance - ) + return _instance(row, None) - return main + if self.is_aliased_class: + entname = self.entity._sa_label_name + else: + entname = self.mapper.class_.__name__ + + return main, entname def setup_context(self, query, context): # if single-table inheritance mapper, add "typecol IN (polymorphic)" criterion so # that we only load the appropriate types if self.mapper.single and self.mapper.inherits is not None and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: context.whereclause = sql.and_(context.whereclause, self.mapper.polymorphic_on.in_([m.polymorphic_identity for m in self.mapper.polymorphic_iterator()])) - - if context.order_by is False: - if self.mapper.order_by: - context.order_by = self.mapper.order_by - elif 
context.from_clause.default_order_by(): - context.order_by = context.from_clause.default_order_by() - - for value in self.mapper._iterate_polymorphic_properties(query._with_polymorphic, context.from_clause): + + context.froms.append(self.selectable) + + adapter = self._get_entity_clauses(query, context) + + if self.primary_entity: + if context.order_by is False: + # the "default" ORDER BY use case applies only to "mapper zero". the "from clause" default should + # go away in 0.5 (or...maybe 0.6). + if self.mapper.order_by: + context.order_by = self.mapper.order_by + elif context.from_clause: + context.order_by = context.from_clause.default_order_by() + else: + context.order_by = self.selectable.default_order_by() + if context.order_by and adapter: + context.order_by = adapter.adapt_list(util.to_list(context.order_by)) + + for value in self.mapper._iterate_polymorphic_properties(self._with_polymorphic): if query._only_load_props and value.key not in query._only_load_props: continue - context.exec_with_path(self.mapper, value.key, value.setup, context, only_load_props=query._only_load_props) + value.setup(context, self, (self.path_entity,), adapter, only_load_props=query._only_load_props, column_collection=context.primary_columns) + + def __str__(self): + return str(self.mapper) + class _ColumnEntity(_QueryEntity): - """entity column corresponding to Table or selectable columns.""" + """Column/expression based entity.""" + + def __init__(self, query, column, entity_name=None): + if isinstance(column, expression.FromClause) and not isinstance(column, expression.ColumnElement): + for c in column.c: + _ColumnEntity(query, c) + return + + query._entities.append(self) - def __init__(self, column, id): if isinstance(column, basestring): column = sql.literal_column(column) - - if column and isinstance(column, sql.ColumnElement) and not hasattr(column, '_label'): + elif isinstance(column, (attributes.QueryableAttribute, mapper.Mapper._CompileOnAttr)): + column = 
column.__clause_element__() + elif not isinstance(column, sql.ColumnElement): + raise sa_exc.InvalidRequestError("Invalid column expression '%r'" % column) + + if not hasattr(column, '_label'): column = column.label(None) + self.column = column - self.alias_id = id + self.entity_name = None + self.froms = util.Set() + self.entities = util.OrderedSet([elem._annotations['parententity'] for elem in visitors.iterate(column, {}) if 'parententity' in elem._annotations]) + if self.entities: + self.entity_zero = list(self.entities)[0] + else: + self.entity_zero = None + + def setup_entity(self, entity, mapper, adapter, from_obj, is_aliased_class, with_polymorphic): + self.froms.add(from_obj) def __resolve_expr_against_query_aliases(self, query, expr, context): - if not query._alias_ids: - return expr - - if ('_ColumnEntity', expr) in context.attributes: - return context.attributes[('_ColumnEntity', expr)] - - if self.alias_id: - try: - aliases = query._alias_ids[self.alias_id][0] - except KeyError: - raise exceptions.InvalidRequestError("Query has no alias identified by '%s'" % self.alias_id) + return query._adapt_clause(expr, False, True) - def _locate_aliased(element): - if element in query._alias_ids: - return aliases - else: - def _locate_aliased(element): - if element in query._alias_ids: - aliases = query._alias_ids[element] - if len(aliases) > 1: - raise exceptions.InvalidRequestError("Ambiguous join for entity '%s'; specify id=<someid> to query.join()/query.add_column(), or use the aliased() function to use explicit class aliases." 
% expr) - return aliases[0] - return None - - class Adapter(visitors.ClauseVisitor): - def before_clone(self, element): - if isinstance(element, expression.FromClause): - alias = _locate_aliased(element) - if alias: - return alias.alias - - if hasattr(element, 'table'): - alias = _locate_aliased(element.table) - if alias: - return alias.aliased_column(element) + def row_processor(self, query, context, custom_rows): + column = self.__resolve_expr_against_query_aliases(query, self.column, context) - return None + if context.adapter: + column = context.adapter.columns[column] - context.attributes[('_ColumnEntity', expr)] = ret = Adapter().traverse(expr, clone=True) - return ret - - def row_processor(self, query, context, single_entity): - column = self.__resolve_expr_against_query_aliases(query, self.column, context) def proc(context, row): return row[column] - return proc - + + return (proc, getattr(column, 'name', None)) + def setup_context(self, query, context): column = self.__resolve_expr_against_query_aliases(query, self.column, context) - context.secondary_columns.append(column) - + context.froms += list(self.froms) + context.primary_columns.append(column) + def __str__(self): return str(self.column) - -Query.logger = logging.class_logger(Query) +Query.logger = log.class_logger(Query) class QueryContext(object): def __init__(self, query): + + if query._statement: + if isinstance(query._statement, expression._SelectBaseMixin) and not query._statement.use_labels: + self.statement = query._statement.apply_labels() + else: + self.statement = query._statement + else: + self.statement = None + self.from_clause = query._from_obj + self.whereclause = query._criterion + self.order_by = query._order_by + if self.order_by: + self.order_by = [expression._literal_as_text(o) for o in util.to_list(self.order_by)] + self.query = query - self.mapper = query.mapper self.session = query.session - self.extension = query._extension - self.statement = None - self.row_adapter = None 
self.populate_existing = query._populate_existing self.version_check = query._version_check - self.only_load_props = query._only_load_props self.refresh_instance = query._refresh_instance - self.path = () self.primary_columns = [] self.secondary_columns = [] self.eager_order_by = [] - self.eager_joins = None + + self.eager_joins = {} + self.froms = [] + self.adapter = None + self.options = query._with_options self.attributes = query._attributes.copy() - def exec_with_path(self, mapper, propkey, fn, *args, **kwargs): - oldpath = self.path - self.path += (mapper.base_mapper, propkey) - try: - return fn(*args, **kwargs) - finally: - self.path = oldpath +class AliasOption(interfaces.MapperOption): + def __init__(self, alias): + self.alias = alias + def process_query(self, query): + if isinstance(self.alias, basestring): + alias = query._mapper_zero().mapped_table.alias(self.alias) + else: + alias = self.alias + query._from_obj_alias = sql_util.ColumnAdapter(alias) + _runid = 1L _id_lock = util.threading.Lock() diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 479b2f737..c1d3db9f1 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -1,8 +1,17 @@ +# scoping.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php + +import inspect +import types + +import sqlalchemy.exceptions as sa_exc from sqlalchemy.util import ScopedRegistry, to_list, get_cls_kwargs -from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, class_mapper +from sqlalchemy.orm import MapperExtension, EXT_CONTINUE, object_session, \ + class_mapper from sqlalchemy.orm.session import Session -from sqlalchemy import exceptions -import types __all__ = ['ScopedSession'] @@ -33,7 +42,7 @@ class ScopedSession(object): scope = kwargs.pop('scope', False) if scope is not None: if self.registry.has(): - 
raise exceptions.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") + raise sa_exc.InvalidRequestError("Scoped session is already present; no new arguments may be specified.") else: sess = self.session_factory(**kwargs) self.registry.set(sess) @@ -53,7 +62,7 @@ class ScopedSession(object): from sqlalchemy.orm import mapper - extension_args = dict([(arg,kwargs.pop(arg)) + extension_args = dict([(arg, kwargs.pop(arg)) for arg in get_cls_kwargs(_ScopedExt) if arg in kwargs]) @@ -110,10 +119,10 @@ for prop in ('bind', 'dirty', 'deleted', 'new', 'identity_map'): setattr(ScopedSession, prop, makeprop(prop)) def clslevel(name): - def do(cls, *args,**kwargs): + def do(cls, *args, **kwargs): return getattr(Session, name)(*args, **kwargs) return classmethod(do) -for prop in ('close_all','object_session', 'identity_key'): +for prop in ('close_all', 'object_session', 'identity_key'): setattr(ScopedSession, prop, clslevel(prop)) class _ScopedExt(MapperExtension): @@ -121,6 +130,7 @@ class _ScopedExt(MapperExtension): self.context = context self.validate = validate self.save_on_init = save_on_init + self.set_kwargs_on_init = None def validating(self): return _ScopedExt(self.context, validate=True) @@ -128,37 +138,49 @@ class _ScopedExt(MapperExtension): def configure(self, **kwargs): return _ScopedExt(self.context, **kwargs) - def get_session(self): - return self.context.registry() - def instrument_class(self, mapper, class_): class query(object): def __getattr__(s, key): return getattr(self.context.registry().query(class_), key) def __call__(s): return self.context.registry().query(class_) - + def __get__(self, instance, cls): + return self + if not 'query' in class_.__dict__: class_.query = query() - + + if self.set_kwargs_on_init is None: + self.set_kwargs_on_init = class_.__init__ is object.__init__ + if self.set_kwargs_on_init: + def __init__(self, **kwargs): + pass + class_.__init__ = __init__ + def init_instance(self, mapper, 
class_, oldinit, instance, args, kwargs): if self.save_on_init: entity_name = kwargs.pop('_sa_entity_name', None) session = kwargs.pop('_sa_session', None) - if not isinstance(oldinit, types.MethodType): + + if self.set_kwargs_on_init: for key, value in kwargs.items(): if self.validate: - if not mapper.get_property(key, resolve_synonyms=False, raiseerr=False): - raise exceptions.ArgumentError("Invalid __init__ argument: '%s'" % key) + if not mapper.get_property(key, resolve_synonyms=False, + raiseerr=False): + raise sa_exc.ArgumentError( + "Invalid __init__ argument: '%s'" % key) setattr(instance, key, value) kwargs.clear() + if self.save_on_init: session = session or self.context.registry() - session._save_impl(instance, entity_name=entity_name) + session._save_without_cascade(instance, entity_name=entity_name) return EXT_CONTINUE def init_failed(self, mapper, class_, oldinit, instance, args, kwargs): - object_session(instance).expunge(instance) + sess = object_session(instance) + if sess: + sess.expunge(instance) return EXT_CONTINUE def dispose_class(self, mapper, class_): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 57f23ace2..68a3aed68 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -6,18 +6,24 @@ """Provides the Session class and related utilities.""" - import weakref -from sqlalchemy import util, exceptions, sql, engine -from sqlalchemy.orm import unitofwork, query, attributes, util as mapperutil -from sqlalchemy.orm.mapper import object_mapper as _object_mapper -from sqlalchemy.orm.mapper import class_mapper as _class_mapper -from sqlalchemy.orm.mapper import Mapper +import sqlalchemy.exceptions as sa_exc +import sqlalchemy.orm.attributes +from sqlalchemy import util, sql, engine +from sqlalchemy.sql import util as sql_util, expression +from sqlalchemy.orm import exc, unitofwork, query, attributes, \ + util as mapperutil, SessionExtension +from sqlalchemy.orm.util import object_mapper as 
_object_mapper +from sqlalchemy.orm.util import class_mapper as _class_mapper +from sqlalchemy.orm.util import _state_mapper, _state_has_identity, _class_to_mapper +from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.orm.unitofwork import UOWTransaction +from sqlalchemy.orm import identity __all__ = ['Session', 'SessionTransaction', 'SessionExtension'] -def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **kwargs): +def sessionmaker(bind=None, class_=None, autoflush=True, autocommit=False, autoexpire=True, **kwargs): """Generate a custom-configured [sqlalchemy.orm.session#Session] class. The returned object is a subclass of ``Session``, which, when instantiated with no @@ -54,20 +60,111 @@ def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **k sess = Session() - The function features a single keyword argument of its own, `class_`, which - may be used to specify an alternate class other than ``sqlalchemy.orm.session.Session`` - which should be used by the returned class. All other keyword arguments sent to - `sessionmaker()` are passed through to the instantiated `Session()` object. - """ + Options: + + autocommit + Defaults to ``False``. When ``True``, the ``Session`` does not keep a + persistent transaction running, and will acquire connections from the engine + on an as-needed basis, returning them immediately after their use. Flushes + will begin and commit (or possibly rollback) their own transaction if no + transaction is present. When using this mode, the `session.begin()` method + may be used to begin a transaction explicitly. + + Leaving it on its default value of ``False`` means that the ``Session`` will + acquire a connection and begin a transaction the first time it is used, which + it will maintain persistently until ``rollback()``, ``commit()``, or + ``close()`` is called. 
When the transaction is released by any of these + methods, the ``Session`` is ready for the next usage, which will again acquire + and maintain a new connection/transaction. + + autoexpire + When ``True``, all instances will be fully expired after each ``rollback()`` + and after each ``commit()``, so that all attribute/object access subsequent + to a completed transaction will load from the most recent database state. + + autoflush + When ``True``, all query operations will issue a ``flush()`` call to this + ``Session`` before proceeding. This is a convenience feature so that + ``flush()`` need not be called repeatedly in order for database queries to + retrieve results. It's typical that ``autoflush`` is used in conjunction with + ``autocommit=False``. In this scenario, explicit calls to ``flush()`` are rarely + needed; you usually only need to call ``commit()`` (which flushes) to finalize + changes. + + bind + An optional ``Engine`` or ``Connection`` to which this ``Session`` should be + bound. When specified, all SQL operations performed by this session will + execute via this connectable. + + binds + An optional dictionary, which contains more granular "bind" information than + the ``bind`` parameter provides. This dictionary can map individual ``Table`` + instances as well as ``Mapper`` instances to individual ``Engine`` or + ``Connection`` objects. Operations which proceed relative to a particular + ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as + well as the mapper's ``mapped_table`` attribute in order to locate an + connectable to use. The full resolution is described in the ``get_bind()`` + method of ``Session``. Usage looks like:: + + sess = Session(binds={ + SomeMappedClass : create_engine('postgres://engine1'), + somemapper : create_engine('postgres://engine2'), + some_table : create_engine('postgres://engine3'), + }) + + Also see the ``bind_mapper()`` and ``bind_table()`` methods. 
+ + \class_ + Specify an alternate class other than ``sqlalchemy.orm.session.Session`` + which should be used by the returned class. This is the only argument + that is local to the ``sessionmaker()`` function, and is not sent + directly to the constructor for ``Session``. + echo_uow + When ``True``, configure Python logging to dump all unit-of-work + transactions. This is the equivalent of + ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``. + + extension + An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive + pre- and post- commit and flush events, as well as a post-rollback event. User- + defined code may be placed within these hooks using a user-defined subclass + of ``SessionExtension``. + + twophase + When ``True``, all transactions will be started using + [sqlalchemy.engine_TwoPhaseTransaction]. During a ``commit()``, after + ``flush()`` has been issued for all attached databases, the ``prepare()`` + method on each database's ``TwoPhaseTransaction`` will be called. This allows + each database to roll back the entire transaction, before each transaction is + committed. + + weak_identity_map + When set to the default value of ``False``, a weak-referencing map is used; + instances which are not externally referenced will be garbage collected + immediately. For dereferenced instances which have pending changes present, + the attribute management system will create a temporary strong-reference to + the object which lasts until the changes are flushed to the database, at which + point it's again dereferenced. Alternatively, when using the value ``True``, + the identity map uses a regular Python dictionary to store instances. The + session will maintain all instances present until they are removed using + expunge(), clear(), or purge(). 
+ + """ + + if 'transactional' in kwargs: + util.warn_deprecated("The 'transactional' argument to sessionmaker() is deprecated; use autocommit=True|False instead.") + autocommit = not kwargs.pop('transactional') + kwargs['bind'] = bind kwargs['autoflush'] = autoflush - kwargs['transactional'] = transactional + kwargs['autocommit'] = autocommit + kwargs['autoexpire'] = autoexpire if class_ is None: class_ = Session - class Sess(class_): + class Sess(object): def __init__(self, **local_kwargs): for k in kwargs: local_kwargs.setdefault(k, kwargs[k]) @@ -83,57 +180,9 @@ def sessionmaker(bind=None, class_=None, autoflush=True, transactional=True, **k kwargs.update(new_kwargs) configure = classmethod(configure) + s = type.__new__(type, "Session", (Sess, class_), {}) + return s - return Sess - -class SessionExtension(object): - """An extension hook object for Sessions. Subclasses may be installed into a Session - (or sessionmaker) using the ``extension`` keyword argument. - """ - - def before_commit(self, session): - """Execute right before commit is called. - - Note that this may not be per-flush if a longer running transaction is ongoing.""" - - def after_commit(self, session): - """Execute after a commit has occured. - - Note that this may not be per-flush if a longer running transaction is ongoing.""" - - def after_rollback(self, session): - """Execute after a rollback has occured. - - Note that this may not be per-flush if a longer running transaction is ongoing.""" - - def before_flush(self, session, flush_context, instances): - """Execute before flush process has started. - - `instances` is an optional list of objects which were passed to the ``flush()`` - method. - """ - - def after_flush(self, session, flush_context): - """Execute after flush has completed, but before commit has been called. - - Note that the session's state is still in pre-flush, i.e. 
'new', 'dirty', - and 'deleted' lists still show pre-flush state as well as the history - settings on instance attributes.""" - - def after_flush_postexec(self, session, flush_context): - """Execute after flush has completed, and after the post-exec state occurs. - - This will be when the 'new', 'dirty', and 'deleted' lists are in their final - state. An actual commit() may or may not have occured, depending on whether or not - the flush started its own transaction or participated in a larger transaction. - """ - - def after_begin(self, session, transaction, connection): - """Execute after a transaction is begun on a connection - - `transaction` is the SessionTransaction. This method is called after an - engine level transaction is begun on a connection. - """ class SessionTransaction(object): """Represents a Session-level Transaction. @@ -157,59 +206,100 @@ class SessionTransaction(object): self.nested = nested self._active = True self._prepared = False + if not parent and nested: + raise sa_exc.InvalidRequestError("Can't start a SAVEPOINT transaction when no existing transaction is in progress") + self._take_snapshot() - is_active = property(lambda s: s.session is not None and s._active) + def is_active(self): + return self.session is not None and self._active + is_active = property(is_active) def _assert_is_active(self): self._assert_is_open() if not self._active: - raise exceptions.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction and should be closed") + raise sa_exc.InvalidRequestError("The transaction is inactive due to a rollback in a subtransaction. 
Issue rollback() to cancel the transaction.") def _assert_is_open(self): if self.session is None: - raise exceptions.InvalidRequestError("The transaction is closed") - + raise sa_exc.InvalidRequestError("The transaction is closed") + + def _is_transaction_boundary(self): + return self.nested or not self._parent + _is_transaction_boundary = property(_is_transaction_boundary) + def connection(self, bindkey, **kwargs): self._assert_is_active() engine = self.session.get_bind(bindkey, **kwargs) - return self.get_or_add(engine) + return self._connection_for_bind(engine) - def _begin(self, **kwargs): + def _begin(self, autoflush=True, nested=False): self._assert_is_active() - return SessionTransaction(self.session, self, **kwargs) + return SessionTransaction(self.session, self, autoflush=autoflush, nested=nested) def _iterate_parents(self, upto=None): if self._parent is upto: return (self,) else: if self._parent is None: - raise exceptions.InvalidRequestError("Transaction %s is not on the active transaction list" % upto) + raise sa_exc.InvalidRequestError("Transaction %s is not on the active transaction list" % upto) return (self,) + self._parent._iterate_parents(upto) + + def _take_snapshot(self): + if not self._is_transaction_boundary: + self._new = self._parent._new + self._deleted = self._parent._deleted + return + + if self.nested: + self.session.flush() + + if self.autoflush: + # TODO: the "dirty_states" assertion is expensive, + # so consider these assertions as temporary + # during development + assert not self.session._new + assert not self.session._deleted + assert not self.session._dirty_states + + self._new = weakref.WeakKeyDictionary() + self._deleted = weakref.WeakKeyDictionary() + + def _restore_snapshot(self): + assert self._is_transaction_boundary + + for s in util.Set(self._deleted).union(self.session._deleted): + self.session._update_impl(s) + + assert not self.session._deleted + + for s in util.Set(self._new).union(self.session._new): + 
self.session._expunge_state(s) + + for s in self.session.identity_map.all_states(): + _expire_state(s, None) + + def _remove_snapshot(self): + assert self._is_transaction_boundary - def add(self, bind): - self._assert_is_active() - if self._parent is not None and not self.nested: - return self._parent.add(bind) - - if bind.engine in self._connections: - raise exceptions.InvalidRequestError("Session already has a Connection associated for the given %sEngine" % (isinstance(bind, engine.Connection) and "Connection's " or "")) - return self.get_or_add(bind) - - def get_or_add(self, bind): + if not self.nested and self.session.autoexpire: + for s in self.session.identity_map.all_states(): + _expire_state(s, None) + + def _connection_for_bind(self, bind): self._assert_is_active() if bind in self._connections: return self._connections[bind][0] - if self._parent is not None: - conn = self._parent.get_or_add(bind) + if self._parent: + conn = self._parent._connection_for_bind(bind) if not self.nested: return conn else: if isinstance(bind, engine.Connection): conn = bind if conn.engine in self._connections: - raise exceptions.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") + raise sa_exc.InvalidRequestError("Session already has a Connection associated for the given Connection's Engine") else: conn = bind.contextual_connect() @@ -227,9 +317,9 @@ class SessionTransaction(object): def prepare(self): if self._parent is not None or not self.session.twophase: - raise exceptions.InvalidRequestError("Only root two phase transactions of can be prepared") + raise sa_exc.InvalidRequestError("Only root two phase transactions of can be prepared") self._prepare_impl() - + def _prepare_impl(self): self._assert_is_active() if self.session.extension is not None and (self._parent is None or self.nested): @@ -264,10 +354,12 @@ class SessionTransaction(object): if self.session.extension is not None: 
self.session.extension.after_commit(self.session) - + + self._remove_snapshot() + self.close() return self._parent - + def rollback(self): self._assert_is_open() @@ -291,6 +383,8 @@ class SessionTransaction(object): for t in util.Set(self._connections.values()): t[1].rollback() + self._restore_snapshot() + if self.session.extension is not None: self.session.extension.after_rollback(self.session) @@ -308,7 +402,7 @@ class SessionTransaction(object): self._deactivate() self.session = None self._connections = None - + def __enter__(self): return self @@ -356,9 +450,9 @@ class Session(object): * *Transient* - an instance that's not in a session, and is not saved to the database; i.e. it has no database identity. The only relationship such an object has to the ORM - is that its class has a `mapper()` associated with it. + is that its class has a ``mapper()`` associated with it. - * *Pending* - when you `save()` a transient instance, it becomes pending. It still + * *Pending* - when you ``add()`` a transient instance, it becomes pending. It still wasn't actually flushed to the database yet, but it will be when the next flush occurs. @@ -372,108 +466,41 @@ class Session(object): they're detached, **except** they will not be able to issue any SQL in order to load collections or attributes which are not yet loaded, or were marked as "expired". - The session methods which control instance state include ``save()``, ``update()``, - ``save_or_update()``, ``delete()``, ``merge()``, and ``expunge()``. + The session methods which control instance state include ``add()``, ``delete()``, + ``merge()``, and ``expunge()``. - The Session object is **not** threadsafe, particularly during flush operations. A session - which is only read from (i.e. is never flushed) can be used by concurrent threads if it's - acceptable that some object instances may be loaded twice. + The Session object is generally **not** threadsafe. 
A session which is set to ``autocommit`` + and is only read from may be used by concurrent threads if it's acceptable that some object + instances may be loaded twice. The typical pattern to managing Sessions in a multi-threaded environment is either to use mutexes to limit concurrent access to one thread at a time, or more commonly to establish a unique session for every thread, using a threadlocal variable. SQLAlchemy provides a thread-managed Session adapter, provided by the [sqlalchemy.orm#scoped_session()] function. + """ - - def __init__(self, bind=None, autoflush=True, transactional=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None): + def __init__(self, bind=None, autoflush=True, autoexpire=True, autocommit=False, twophase=False, echo_uow=False, weak_identity_map=True, binds=None, extension=None): """Construct a new Session. - A session is usually constructed using the [sqlalchemy.orm#create_session()] function, - or its more "automated" variant [sqlalchemy.orm#sessionmaker()]. - - autoflush - When ``True``, all query operations will issue a ``flush()`` call to this - ``Session`` before proceeding. This is a convenience feature so that - ``flush()`` need not be called repeatedly in order for database queries to - retrieve results. It's typical that ``autoflush`` is used in conjunction with - ``transactional=True``, so that ``flush()`` is never called; you just call - ``commit()`` when changes are complete to finalize all changes to the - database. - - bind - An optional ``Engine`` or ``Connection`` to which this ``Session`` should be - bound. When specified, all SQL operations performed by this session will - execute via this connectable. - - binds - An optional dictionary, which contains more granular "bind" information than - the ``bind`` parameter provides. This dictionary can map individual ``Table`` - instances as well as ``Mapper`` instances to individual ``Engine`` or - ``Connection`` objects. 
Operations which proceed relative to a particular - ``Mapper`` will consult this dictionary for the direct ``Mapper`` instance as - well as the mapper's ``mapped_table`` attribute in order to locate an - connectable to use. The full resolution is described in the ``get_bind()`` - method of ``Session``. Usage looks like:: - - sess = Session(binds={ - SomeMappedClass : create_engine('postgres://engine1'), - somemapper : create_engine('postgres://engine2'), - some_table : create_engine('postgres://engine3'), - }) - - Also see the ``bind_mapper()`` and ``bind_table()`` methods. - - echo_uow - When ``True``, configure Python logging to dump all unit-of-work - transactions. This is the equivalent of - ``logging.getLogger('sqlalchemy.orm.unitofwork').setLevel(logging.DEBUG)``. - - extension - An optional [sqlalchemy.orm.session#SessionExtension] instance, which will receive - pre- and post- commit and flush events, as well as a post-rollback event. User- - defined code may be placed within these hooks using a user-defined subclass - of ``SessionExtension``. - - transactional - Set up this ``Session`` to automatically begin transactions. Setting this - flag to ``True`` is the rough equivalent of calling ``begin()`` after each - ``commit()`` operation, after each ``rollback()``, and after each - ``close()``. Basically, this has the effect that all session operations are - performed within the context of a transaction. Note that the ``begin()`` - operation does not immediately utilize any connection resources; only when - connection resources are first required do they get allocated into a - transactional context. - - twophase - When ``True``, all transactions will be started using - [sqlalchemy.engine_TwoPhaseTransaction]. During a ``commit()``, after - ``flush()`` has been issued for all attached databases, the ``prepare()`` - method on each database's ``TwoPhaseTransaction`` will be called. 
This allows - each database to roll back the entire transaction, before each transaction is - committed. - - weak_identity_map - When set to the default value of ``False``, a weak-referencing map is used; - instances which are not externally referenced will be garbage collected - immediately. For dereferenced instances which have pending changes present, - the attribute management system will create a temporary strong-reference to - the object which lasts until the changes are flushed to the database, at which - point it's again dereferenced. Alternatively, when using the value ``True``, - the identity map uses a regular Python dictionary to store instances. The - session will maintain all instances present until they are removed using - expunge(), clear(), or purge(). + Arguments to ``Session`` are described using the [sqlalchemy.orm#sessionmaker()] function. + """ self.echo_uow = echo_uow - self.weak_identity_map = weak_identity_map - self.uow = unitofwork.UnitOfWork(self) - self.identity_map = self.uow.identity_map + if weak_identity_map: + self._identity_cls = identity.WeakInstanceDict + else: + self._identity_cls = identity.StrongInstanceDict + self.identity_map = self._identity_cls() + self._new = {} # InstanceState->object, strong refs object + self._deleted = {} # same self.bind = bind self.__binds = {} self.transaction = None self.hash_key = id(self) self.autoflush = autoflush - self.transactional = transactional + self.autocommit = autocommit + self.autoexpire = autoexpire self.twophase = twophase self.extension = extension self._query_cls = query.Query @@ -488,28 +515,59 @@ class Session(object): for t in mapperortable._all_tables: self.__binds[t] = value - if self.transactional: + if not self.autocommit: self.begin() _sessions[self.hash_key] = self - def begin(self, **kwargs): - """Begin a transaction on this Session.""" - + def begin(self, subtransactions=False, nested=False, _autoflush=True): + """Begin a transaction on this Session. 
+ + If this Session is already within a transaction, + either a plain transaction or nested transaction, + an error is raised, unless ``subtransactions=True`` + or ``nested=True`` is specified. + + The ``subtransactions=True`` flag indicates that + this ``begin()`` can create a subtransaction if a + transaction is already in progress. A subtransaction + is a non-transactional, delimiting construct that + allows matching begin()/commit() pairs to be nested + together, with only the outermost begin/commit pair + actually affecting transactional state. When a rollback + is issued, the subtransaction will directly roll back + the innermost real transaction, however each subtransaction + still must be explicitly rolled back to maintain proper + stacking of subtransactions. + + If no transaction is in progress, + then a real transaction is begun. + + The ``nested`` flag begins a SAVEPOINT transaction + and is equivalent to calling ``begin_nested()``. + + """ if self.transaction is not None: - self.transaction = self.transaction._begin(**kwargs) + if subtransactions or nested: + self.transaction = self.transaction._begin(nested=nested, autoflush=_autoflush) + else: + raise sa_exc.InvalidRequestError("A transaction is already begun. Use subtransactions=True to allow subtransactions.") else: - self.transaction = SessionTransaction(self, **kwargs) - return self.transaction - - create_transaction = begin + self.transaction = SessionTransaction(self, nested=nested, autoflush=_autoflush) + return self.transaction # needed for __enter__/__exit__ hook def begin_nested(self): """Begin a `nested` transaction on this Session. This utilizes a ``SAVEPOINT`` transaction for databases which support this feature. - """ + The nested transaction is a real transation, unlike + a "subtransaction" which corresponds to multiple + ``begin()`` calls. The next ``rollback()`` or + ``commit()`` call will operate upon this nested + transaction. 
+ + """ return self.begin(nested=True) def rollback(self): @@ -517,42 +575,48 @@ class Session(object): If no transaction is in progress, this method is a pass-thru. + + This method rolls back the current transaction + or nested transaction regardless of subtransactions + being in effect. All subtrasactions up to the + first real transaction are closed. Subtransactions + occur when begin() is called mulitple times. + """ - if self.transaction is None: pass else: self.transaction.rollback() - # TODO: we can rollback attribute values. however - # we would want to expand attributes.py to be able to save *two* rollback points, one to the - # last flush() and the other to when the object first entered the transaction. - # [ticket:705] - #attributes.rollback(*self.identity_map.values()) - if self.transaction is None and self.transactional: + if self.transaction is None and not self.autocommit: self.begin() def commit(self): - """Commit the current transaction in progress. + """Flush any pending changes, and commit the current transaction + in progress, assuming no subtransactions are in effect. If no transaction is in progress, this method raises an InvalidRequestError. + + If a subtransaction is in effect (which occurs when + begin() is called multiple times), the subtransaction + will be closed, and the next call to ``commit()`` + will operate on the enclosing transaction. - If the ``begin()`` method was called on this ``Session`` - additional times subsequent to its first call, - ``commit()`` will not actually commit, and instead - pops an internal SessionTransaction off its internal stack - of transactions. Only when the "root" SessionTransaction - is reached does an actual database-level commit occur. - """ + For a session configured with autocommit=False, a new + transaction will be begun immediately after the commit, + but note that the newly begun transaction does *not* + use any connection resources until the first SQL is + actually emitted. 
+ """ if self.transaction is None: - if self.transactional: + if not self.autocommit: self.begin() else: - raise exceptions.InvalidRequestError("No transaction is begun.") + raise sa_exc.InvalidRequestError("No transaction is begun.") self.transaction.commit() - if self.transaction is None and self.transactional: + if self.transaction is None and not self.autocommit: self.begin() def prepare(self): @@ -565,10 +629,10 @@ class Session(object): not such, an InvalidRequestError is raised. """ if self.transaction is None: - if self.transactional: + if not self.autocommit: self.begin() else: - raise exceptions.InvalidRequestError("No transaction is begun.") + raise sa_exc.InvalidRequestError("No transaction is begun.") self.transaction.prepare() @@ -594,7 +658,7 @@ class Session(object): def __connection(self, engine, **kwargs): if self.transaction is not None: - return self.transaction.get_or_add(engine) + return self.transaction._connection_for_bind(engine) else: return engine.contextual_connect(**kwargs) @@ -620,6 +684,8 @@ class Session(object): the proper bind, in the case of ShardedSession. """ + clause = expression._literal_as_text(clause) + engine = self.get_bind(mapper, clause=clause, instance=instance) return self.__connection(engine, close_with_result=True).execute(clause, params or {}) @@ -646,7 +712,7 @@ class Session(object): if self.transaction is not None: for transaction in self.transaction._iterate_parents(): transaction.close() - if self.transactional: + if not self.autocommit: # note this doesnt use any connection resources self.begin() @@ -657,18 +723,24 @@ class Session(object): sess.close() close_all = classmethod(close_all) - def clear(self): + def expunge_all(self): """Remove all object instances from this ``Session``. This is equivalent to calling ``expunge()`` for all objects in this ``Session``. 
""" - for instance in self: - self._unattach(instance) - self.uow = unitofwork.UnitOfWork(self) - self.identity_map = self.uow.identity_map + for state in self.identity_map.all_states() + list(self._new): + del state.session_id + self.identity_map = self._identity_cls() + self._new = {} + self._deleted = {} + clear = expunge_all + + # TODO: deprecate + #clear = util.deprecated()(expunge_all) + # TODO: need much more test coverage for bind_mapper() and similar ! def bind_mapper(self, mapper, bind, entity_name=None): @@ -713,79 +785,49 @@ class Session(object): """ if mapper is None and clause is None: - if self.bind is not None: + if self.bind: return self.bind else: - raise exceptions.UnboundExecutionError("This session is unbound to any Engine or Connection; specify a mapper to get_bind()") + raise sa_exc.UnboundExecutionError("This session is not bound to any Engine or Connection; specify a mapper to get_bind()") - elif len(self.__binds): - if mapper is not None: - if isinstance(mapper, type): - mapper = _class_mapper(mapper) + elif self.__binds: + if mapper: + mapper = _class_to_mapper(mapper) if mapper.base_mapper in self.__binds: return self.__binds[mapper.base_mapper] - elif mapper.compile().mapped_table in self.__binds: + elif mapper.mapped_table in self.__binds: return self.__binds[mapper.mapped_table] - if clause is not None: - for t in clause._table_iterator(): + if clause: + for t in sql_util.find_tables(clause): if t in self.__binds: return self.__binds[t] - if self.bind is not None: + if self.bind: return self.bind - elif isinstance(clause, sql.expression.ClauseElement) and clause.bind is not None: + elif isinstance(clause, sql.expression.ClauseElement) and clause.bind: return clause.bind - elif mapper is None: - raise exceptions.UnboundExecutionError("Could not locate any mapper associated with SQL expression") + elif not mapper: + raise sa_exc.UnboundExecutionError("Could not locate any mapper associated with SQL expression") else: - if 
isinstance(mapper, type): - mapper = _class_mapper(mapper) - else: - mapper = mapper.compile() + mapper = _class_to_mapper(mapper) e = mapper.mapped_table.bind if e is None: - raise exceptions.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) + raise sa_exc.UnboundExecutionError("Could not locate any Engine or Connection bound to mapper '%s'" % str(mapper)) return e - def query(self, mapper_or_class, *addtl_entities, **kwargs): - """Return a new ``Query`` object corresponding to this ``Session`` and - the mapper, or the classes' primary mapper. - - """ - entity_name = kwargs.pop('entity_name', None) - - if isinstance(mapper_or_class, type): - q = self._query_cls(_class_mapper(mapper_or_class, entity_name=entity_name), self, **kwargs) - else: - q = self._query_cls(mapper_or_class, self, **kwargs) - - for ent in addtl_entities: - q = q.add_entity(ent) - return q - + def query(self, *entities, **kwargs): + """Return a new ``Query`` object corresponding to this ``Session``.""" + + return self._query_cls(entities, self, **kwargs) def _autoflush(self): if self.autoflush and (self.transaction is None or self.transaction.autoflush): self.flush() + + def _finalize_loaded(self, states): + for state in states: + state.commit_all() - def flush(self, objects=None): - """Flush all the object modifications present in this session - to the database. - - `objects` is a collection or iterator of objects specifically to be - flushed; if ``None``, all new and modified objects are flushed. - - """ - if objects is not None: - try: - if not len(objects): - return - except TypeError: - objects = list(objects) - if not objects: - return - self.uow.flush(self, objects) - - def get(self, class_, ident, **kwargs): + def get(self, class_, ident, entity_name=None): """Return an instance of the object based on the given identifier, or ``None`` if not found. @@ -798,10 +840,9 @@ class Session(object): query. 
""" - entity_name = kwargs.pop('entity_name', None) - return self.query(class_, entity_name=entity_name).get(ident, **kwargs) + return self.query(class_, entity_name=entity_name).get(ident) - def load(self, class_, ident, **kwargs): + def load(self, class_, ident, entity_name=None): """Return an instance of the object based on the given identifier. @@ -816,8 +857,7 @@ class Session(object): query. """ - entity_name = kwargs.pop('entity_name', None) - return self.query(class_, entity_name=entity_name).load(ident, **kwargs) + return self.query(class_, entity_name=entity_name).load(ident) def refresh(self, instance, attribute_names=None): """Refresh the attributes on the given instance. @@ -838,11 +878,13 @@ class Session(object): refreshed. """ - self._validate_persistent(instance) + state = attributes.instance_state(instance) + self._validate_persistent(state) + if self.query(_object_mapper(instance))._get( + state.key, refresh_instance=state, + only_load_props=attribute_names) is None: + raise sa_exc.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) - if self.query(_object_mapper(instance))._get(instance._instance_key, refresh_instance=instance._state, only_load_props=attribute_names) is None: - raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % mapperutil.instance_str(instance)) - def expire_all(self): """Expires all persistent instances within this Session. @@ -862,19 +904,17 @@ class Session(object): of attribute names indicating a subset of attributes to be expired. 
""" - + state = attributes.instance_state(instance) + self._validate_persistent(state) if attribute_names: - self._validate_persistent(instance) - _expire_state(instance._state, attribute_names=attribute_names) + _expire_state(state, attribute_names=attribute_names) else: # pre-fetch the full cascade since the expire is going to # remove associations - cascaded = list(_cascade_iterator('refresh-expire', instance)) - self._validate_persistent(instance) - _expire_state(instance._state, None) - for (c, m) in cascaded: - self._validate_persistent(c) - _expire_state(c._state, None) + cascaded = list(_cascade_state_iterator('refresh-expire', state)) + _expire_state(state, None) + for (state, m) in cascaded: + _expire_state(state, None) def prune(self): """Remove unreferenced instances cached in the identity map. @@ -887,7 +927,7 @@ class Session(object): Returns the number of objects pruned. """ - return self.uow.prune_identity_map() + return self.identity_map.prune() def expunge(self, instance): """Remove the given `instance` from this ``Session``. @@ -896,11 +936,58 @@ class Session(object): Cascading will be applied according to the *expunge* cascade rule. 
""" - self._validate_persistent(instance) - for c, m in [(instance, None)] + list(_cascade_iterator('expunge', instance)): - if c in self: - self.uow._remove_deleted(c._state) - self._unattach(c) + + state = attributes.instance_state(instance) + if state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError("Instance %s is not present in this Session" % mapperutil.state_str(state)) + for s, m in [(state, None)] + list(_cascade_state_iterator('expunge', state)): + self._expunge_state(s) + + def _expunge_state(self, state): + if state in self._new: + self._new.pop(state) + del state.session_id + elif self.identity_map.contains_state(state): + self.identity_map.discard(state) + self._deleted.pop(state, None) + del state.session_id + + def _register_newly_persistent(self, state): + mapper = _state_mapper(state) + instance_key = mapper._identity_key_from_state(state) + + if state.key is None: + state.key = instance_key + elif state.key != instance_key: + # primary key switch + self.identity_map.remove(state) + state.key = instance_key + + if hasattr(state, 'insert_order'): + delattr(state, 'insert_order') + + obj = state.obj() + # prevent against last minute dereferences of the object + # TODO: identify a code path where state.obj() is None + if obj is not None: + if state.key in self.identity_map and not self.identity_map.contains_state(state): + self.identity_map.remove_key(state.key) + self.identity_map.add(state) + state.commit_all() + + # remove from new last, might be the last strong ref + if state in self._new: + if self.transaction: + self.transaction._new[state] = True + self._new.pop(state) + + def _remove_newly_deleted(self, state): + if self.transaction: + self.transaction._deleted[state] = True + + self.identity_map.discard(state) + self._deleted.pop(state, None) + del state.session_id def save(self, instance, entity_name=None): """Add a transient (unsaved) instance to this ``Session``. 
@@ -911,10 +998,21 @@ class Session(object): The `entity_name` keyword argument will further qualify the specific ``Mapper`` used to handle this instance. + """ - self._save_impl(instance, entity_name=entity_name) - self._cascade_save_or_update(instance) - + state = _state_for_unsaved_instance(instance, entity_name) + self._save_impl(state) + self._cascade_save_or_update(state, entity_name) + + # TODO + #save = util.deprecated("Use the add() method.")(save) + + def _save_without_cascade(self, instance, entity_name=None): + """used by scoping.py to save on init without cascade.""" + + state = _state_for_unsaved_instance(instance, entity_name) + self._save_impl(state) + def update(self, instance, entity_name=None): """Bring the given detached (saved) instance into this ``Session``. @@ -926,24 +1024,42 @@ class Session(object): This operation cascades the `save_or_update` method to associated instances if the relation is mapped with ``cascade="save-update"``. + """ + state = attributes.instance_state(instance) + self._update_impl(state) + self._cascade_save_or_update(state, entity_name) + + # TODO + #update = util.deprecated("Use the add() method.")(update) + + def add(self, instance, entity_name=None): + """Add the given instance into this ``Session``. - self._update_impl(instance, entity_name=entity_name) - self._cascade_save_or_update(instance) - - def save_or_update(self, instance, entity_name=None): - """Save or update the given instance into this ``Session``. + The non-None state `key` on the instance's state determines whether + to ``save()`` or ``update()`` the instance. - The presence of an `_instance_key` attribute on the instance - determines whether to ``save()`` or ``update()`` the instance. 
""" - - self._save_or_update_impl(instance, entity_name=entity_name) - self._cascade_save_or_update(instance) - - def _cascade_save_or_update(self, instance): - for obj, mapper in _cascade_iterator('save-update', instance, halt_on=lambda c:c in self): - self._save_or_update_impl(obj, mapper.entity_name) + state = _state_for_unknown_persistence_instance(instance, entity_name) + self._save_or_update_state(state, entity_name) + + def add_all(self, instances): + """Add the given collection of instances to this ``Session``.""" + + for instance in instances: + self.add(instance) + + # TODO + # save_or_update = util.deprecated("Use the add() method.")(add) + save_or_update = add + + def _save_or_update_state(self, state, entity_name): + self._save_or_update_impl(state) + self._cascade_save_or_update(state, entity_name) + + def _cascade_save_or_update(self, state, entity_name): + for state, mapper in _cascade_unknown_state_iterator('save-update', state, halt_on=lambda c:c in self): + self._save_or_update_impl(state) def delete(self, instance): """Mark the given instance as deleted. @@ -951,9 +1067,10 @@ class Session(object): The delete operation occurs upon ``flush()``. """ - self._delete_impl(instance) - for c, m in _cascade_iterator('delete', instance): - self._delete_impl(c, ignore_transient=True) + state = attributes.instance_state(instance) + self._delete_impl(state) + for state, m in _cascade_state_iterator('delete', state): + self._delete_impl(state, ignore_transient=True) def merge(self, instance, entity_name=None, dont_load=False, _recursive=None): @@ -980,103 +1097,51 @@ class Session(object): if instance in _recursive: return _recursive[instance] - key = getattr(instance, '_instance_key', None) + new_instance = False + state = attributes.instance_state(instance) + key = state.key if key is None: if dont_load: - raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. 
flush() all changes on mapped instances before merging with dont_load=True.") - key = mapper.identity_key_from_instance(instance) + raise sa_exc.InvalidRequestError("merge() with dont_load=True option does not support objects transient (i.e. unpersisted) objects. flush() all changes on mapped instances before merging with dont_load=True.") + key = mapper._identity_key_from_state(state) merged = None if key: if key in self.identity_map: merged = self.identity_map[key] elif dont_load: - if instance._state.modified: - raise exceptions.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. flush() all changes on mapped instances before merging with dont_load=True.") - - merged = attributes.new_instance(mapper.class_) - merged._instance_key = key - merged._entity_name = entity_name - self._update_impl(merged, entity_name=mapper.entity_name) + if state.modified: + raise sa_exc.InvalidRequestError("merge() with dont_load=True option does not support objects marked as 'dirty'. 
flush() all changes on mapped instances before merging with dont_load=True.") + + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + merged_state.key = key + merged_state.entity_name = entity_name + self._update_impl(merged_state) + new_instance = True else: merged = self.get(mapper.class_, key[1]) - + if merged is None: - merged = attributes.new_instance(mapper.class_) + merged = mapper.class_manager.new_instance() + merged_state = attributes.instance_state(merged) + new_instance = True self.save(merged, entity_name=mapper.entity_name) - + _recursive[instance] = merged - + for prop in mapper.iterate_properties: prop.merge(self, instance, merged, dont_load, _recursive) - + if dont_load: - merged._state.commit_all() # remove any history + attributes.instance_state(merged).commit_all() # remove any history + if new_instance: + merged_state._run_on_load(merged) return merged def identity_key(cls, *args, **kwargs): - """Get an identity key. - - Valid call signatures: - - * ``identity_key(class, ident, entity_name=None)`` - - class - mapped class (must be a positional argument) - - ident - primary key, if the key is composite this is a tuple - - entity_name - optional entity name - - * ``identity_key(instance=instance)`` - - instance - object instance (must be given as a keyword arg) - - * ``identity_key(class, row=row, entity_name=None)`` - - class - mapped class (must be a positional argument) - - row - result proxy row (must be given as a keyword arg) - - entity_name - optional entity name (must be given as a keyword arg) - """ - - if args: - if len(args) == 1: - class_ = args[0] - try: - row = kwargs.pop("row") - except KeyError: - ident = kwargs.pop("ident") - entity_name = kwargs.pop("entity_name", None) - elif len(args) == 2: - class_, ident = args - entity_name = kwargs.pop("entity_name", None) - elif len(args) == 3: - class_, ident, entity_name = args - else: - raise exceptions.ArgumentError("expected up to three " 
- "positional arguments, got %s" % len(args)) - if kwargs: - raise exceptions.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs.keys())) - mapper = _class_mapper(class_, entity_name=entity_name) - if "ident" in locals(): - return mapper.identity_key_from_primary_key(ident) - return mapper.identity_key_from_row(row) - instance = kwargs.pop("instance") - if kwargs: - raise exceptions.ArgumentError("unknown keyword arguments: %s" - % ", ".join(kwargs.keys())) - mapper = _object_mapper(instance) - return mapper.identity_key_from_instance(instance) + return mapperutil.identity_key(*args, **kwargs) identity_key = classmethod(identity_key) def object_session(cls, instance): @@ -1085,83 +1150,164 @@ class Session(object): return object_session(instance) object_session = classmethod(object_session) - def _save_impl(self, instance, **kwargs): - if hasattr(instance, '_instance_key'): - raise exceptions.InvalidRequestError("Instance '%s' is already persistent" % mapperutil.instance_str(instance)) - else: - # TODO: consolidate the steps here - attributes.manage(instance) - instance._entity_name = kwargs.get('entity_name', None) - self._attach(instance) - self.uow.register_new(instance) - - def _update_impl(self, instance, **kwargs): - if instance in self and instance not in self.deleted: + def _validate_persistent(self, state): + if not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.state_str(state)) + + def _save_impl(self, state): + if state.key is not None: + raise sa_exc.InvalidRequestError( + "Object '%s' already has an identity - it can't be registered " + "as pending" % repr(obj)) + self._attach(state) + if state not in self._new: + self._new[state] = state.obj() + state.insert_order = len(self._new) + + def _update_impl(self, state): + if self.identity_map.contains_state(state) and state not in self._deleted: return - if not hasattr(instance, 
'_instance_key'): - raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance)) - elif self.identity_map.get(instance._instance_key, instance) is not instance: - raise exceptions.InvalidRequestError("Could not update instance '%s', identity key %s; a different instance with the same identity key already exists in this session." % (mapperutil.instance_str(instance), instance._instance_key)) - self._attach(instance) - - def _save_or_update_impl(self, instance, entity_name=None): - key = getattr(instance, '_instance_key', None) - if key is None: - self._save_impl(instance, entity_name=entity_name) + + if state.key is None: + raise sa_exc.InvalidRequestError( + "Instance '%s' is not persisted" % + mapperutil.state_str(state)) + + if state.key in self.identity_map and not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Could not update instance '%s', identity key %s; a different " + "instance with the same identity key already exists in this " + "session." 
% (mapperutil.state_str(state), state.key)) + + self._attach(state) + self._deleted.pop(state, None) + self.identity_map.add(state) + + def _save_or_update_impl(self, state): + if state.key is None: + self._save_impl(state) else: - self._update_impl(instance, entity_name=entity_name) + self._update_impl(state) - def _delete_impl(self, instance, ignore_transient=False): - if instance in self and instance in self.deleted: + def _delete_impl(self, state, ignore_transient=False): + if self.identity_map.contains_state(state) and state in self._deleted: return - if not hasattr(instance, '_instance_key'): + + if state.key is None: if ignore_transient: return else: - raise exceptions.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.instance_str(instance)) - if self.identity_map.get(instance._instance_key, instance) is not instance: - raise exceptions.InvalidRequestError("Instance '%s' is with key %s already persisted with a different identity" % (mapperutil.instance_str(instance), instance._instance_key)) - self._attach(instance) - self.uow.register_deleted(instance) - - def _attach(self, instance): - old_id = getattr(instance, '_sa_session_id', None) - if old_id != self.hash_key: - if old_id is not None and old_id in _sessions and instance in _sessions[old_id]: - raise exceptions.InvalidRequestError("Object '%s' is already attached " - "to session '%s' (this is '%s')" % - (mapperutil.instance_str(instance), old_id, id(self))) - - key = getattr(instance, '_instance_key', None) - if key is not None: - self.identity_map[key] = instance - instance._sa_session_id = self.hash_key - - def _unattach(self, instance): - if instance._sa_session_id == self.hash_key: - del instance._sa_session_id - - def _validate_persistent(self, instance): - """Validate that the given instance is persistent within this - ``Session``. 
- """ - - if instance not in self: - raise exceptions.InvalidRequestError("Instance '%s' is not persistent within this Session" % mapperutil.instance_str(instance)) + raise sa_exc.InvalidRequestError("Instance '%s' is not persisted" % mapperutil.state_str(state)) + if state.key in self.identity_map and not self.identity_map.contains_state(state): + raise sa_exc.InvalidRequestError( + "Instance '%s' is with key %s already persisted with a " + "different identity" % (mapperutil.state_str(state), + state.key)) + + self._deleted[state] = state.obj() + self._attach(state) + + def _attach(self, state): + if state.session_id and state.session_id is not self.hash_key: + raise sa_exc.InvalidRequestError( + "Object '%s' is already attached to session '%s' " + "(this is '%s')" % (mapperutil.state_str(state), + state.session_id, self.hash_key)) + if state.session_id != self.hash_key: + state.session_id = self.hash_key def __contains__(self, instance): """Return True if the given instance is associated with this session. The instance may be pending or persistent within the Session for a result of True. - """ - - return instance._state in self.uow.new or (hasattr(instance, '_instance_key') and self.identity_map.get(instance._instance_key) is instance) + """ + return self._contains_state(attributes.instance_state(instance)) + def __iter__(self): """Return an iterator of all instances which are pending or persistent within this Session.""" - return iter(list(self.uow.new.values()) + self.uow.identity_map.values()) + return iter(list(self._new.values()) + self.identity_map.values()) + + def _contains_state(self, state): + return state in self._new or self.identity_map.contains_state(state) + + + def flush(self, objects=None): + """Flush all the object modifications present in this session + to the database. + + `objects` is a list or tuple of objects specifically to be + flushed; if ``None``, all new and modified objects are flushed. 
+ + """ + if not self.identity_map.check_modified() and not self._deleted and not self._new: + return + + dirty = self._dirty_states + if not dirty and not self._deleted and not self._new: + self.identity_map.modified = False + return + + deleted = util.Set(self._deleted) + new = util.Set(self._new) + + dirty = util.Set(dirty).difference(deleted) + + flush_context = UOWTransaction(self) + + if self.extension is not None: + self.extension.before_flush(self, flush_context, objects) + + # create the set of all objects we want to operate upon + if objects: + # specific list passed in + objset = util.Set([attributes.instance_state(o) for o in objects]) + else: + # or just everything + objset = util.Set(self.identity_map.all_states()).union(new) + + # store objects whose fate has been decided + processed = util.Set() + + # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. + for state in new.union(dirty).intersection(objset).difference(deleted): + is_orphan = _state_mapper(state)._is_orphan(state) + if is_orphan and not _state_has_identity(state): + raise exc.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" % + ( + mapperutil.state_str(state), + ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans]) + )) + flush_context.register_object(state, isdelete=is_orphan) + processed.add(state) + + # put all remaining deletes into the flush context. 
+ for state in deleted.intersection(objset).difference(processed): + flush_context.register_object(state, isdelete=True) + + if len(flush_context.tasks) == 0: + return + + flush_context.transaction = transaction = self.begin(subtransactions=True, _autoflush=False) + try: + flush_context.execute() + + if self.extension is not None: + self.extension.after_flush(self, flush_context) + transaction.commit() + except: + transaction.rollback() + raise + + flush_context.finalize_flush_changes() + + if not objects: + self.identity_map.modified = False + + if self.extension is not None: + self.extension.after_flush_postexec(self, flush_context) def is_modified(self, instance, include_collections=True, passive=False): """Return True if the given instance has modified attributes. @@ -1180,7 +1326,7 @@ class Session(object): not be loaded in the course of performing this test. """ - for attr in attributes._managed_attributes(instance.__class__): + for attr in attributes.manager_of_class(instance.__class__).attributes: if not include_collections and hasattr(attr.impl, 'get_collection'): continue (added, unchanged, deleted) = attr.get_history(instance) @@ -1188,8 +1334,23 @@ class Session(object): return True return False + def _dirty_states(self): + """Return a set of all persistent states considered dirty. + + This method returns all states that were modified including those that + were possibly deleted. + + """ + return util.IdentitySet( + [state for state in self.identity_map.all_states() if state.check_modified()] + ) + _dirty_states = property(_dirty_states) + def dirty(self): - """Return a ``Set`` of all instances marked as 'dirty' within this ``Session``. + """Return a set of all persistent instances considered dirty. + + Instances are considered dirty when they were modified but not + deleted. 
Note that the 'dirty' state here is 'optimistic'; most attribute-setting or collection modification operations will mark an instance as 'dirty' and place it in this set, @@ -1200,21 +1361,25 @@ class Session(object): To check if an instance has actionable net changes to its attributes, use the is_modified() method. + """ + + return util.IdentitySet( + [state.obj() for state in self._dirty_states if state not in self._deleted] + ) - return self.uow.locate_dirty() dirty = property(dirty) def deleted(self): "Return a ``Set`` of all instances marked as 'deleted' within this ``Session``" - return util.IdentitySet(self.uow.deleted.values()) + return util.IdentitySet(self._deleted.values()) deleted = property(deleted) def new(self): "Return a ``Set`` of all instances marked as 'new' within this ``Session``." - return util.IdentitySet(self.uow.new.values()) + return util.IdentitySet(self._new.values()) new = property(new) def _expire_state(state, attribute_names): @@ -1233,22 +1398,52 @@ register_attribute = unitofwork.register_attribute _sessions = weakref.WeakValueDictionary() -def _cascade_iterator(cascade, instance, **kwargs): - mapper = _object_mapper(instance) - for (o, m) in mapper.cascade_iterator(cascade, instance._state, **kwargs): - yield o, m +def _cascade_state_iterator(cascade, state, **kwargs): + mapper = _state_mapper(state) + for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs): + yield attributes.instance_state(o), m + +def _cascade_unknown_state_iterator(cascade, state, **kwargs): + mapper = _state_mapper(state) + for (o, m) in mapper.cascade_iterator(cascade, state, **kwargs): + yield _state_for_unknown_persistence_instance(o, m.entity_name), m + +def _state_for_unsaved_instance(instance, entity_name): + manager = attributes.manager_of_class(instance.__class__) + if manager is None: + raise "FIXME unmapped instance" + if manager.has_state(instance): + state = manager.state_of(instance) + if state.key is not None: + raise 
sa_exc.InvalidRequestError( + "Instance '%s' is already persistent" % + mapperutil.state_str(state)) + else: + state = manager.setup_instance(instance) + state.entity_name = entity_name + return state + +def _state_for_unknown_persistence_instance(instance, entity_name): + state = attributes.instance_state(instance) + state.entity_name = entity_name + return state def object_session(instance): """Return the ``Session`` to which the given instance is bound, or ``None`` if none.""" - hashkey = getattr(instance, '_sa_session_id', None) - if hashkey is not None: - sess = _sessions.get(hashkey) - if sess is not None and instance in sess: - return sess + return _state_session(attributes.instance_state(instance)) + +def _state_session(state): + if state.session_id: + try: + return _sessions[state.session_id] + except KeyError: + pass return None # Lazy initialization to avoid circular imports unitofwork.object_session = object_session +unitofwork._state_session = _state_session from sqlalchemy.orm import mapper mapper._expire_state = _expire_state +mapper._state_session = _state_session diff --git a/lib/sqlalchemy/orm/shard.py b/lib/sqlalchemy/orm/shard.py index 7cf4eb2cc..6850a0bb0 100644 --- a/lib/sqlalchemy/orm/shard.py +++ b/lib/sqlalchemy/orm/shard.py @@ -1,38 +1,49 @@ -"""Defines a rudimental 'horizontal sharding' system which allows a -Session to distribute queries and persistence operations across multiple -databases. +# shard.py +# Copyright (C) the SQLAlchemy authors and contributors +# +# This module is part of SQLAlchemy and is released under +# the MIT License: http://www.opensource.org/licenses/mit-license.php -For a usage example, see the example ``examples/sharding/attribute_shard.py``. +"""Horizontal sharding support. + +Defines a rudimental 'horizontal sharding' system which allows a Session to +distribute queries and persistence operations across multiple databases. 
+ +For a usage example, see the file ``examples/sharding/attribute_shard.py`` +included in the source distrbution. """ + +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import util from sqlalchemy.orm.session import Session from sqlalchemy.orm.query import Query -from sqlalchemy import exceptions, util __all__ = ['ShardedSession', 'ShardedQuery'] + class ShardedSession(Session): def __init__(self, shard_chooser, id_chooser, query_chooser, shards=None, **kwargs): - """construct a ShardedSession. - - shard_chooser - a callable which, passed a Mapper, a mapped instance, and possibly a - SQL clause, returns a shard ID. this id may be based off of the - attributes present within the object, or on some round-robin scheme. If - the scheme is based on a selection, it should set whatever state on the - instance to mark it in the future as participating in that shard. - - id_chooser - a callable, passed a query and a tuple of identity values, - which should return a list of shard ids where the ID might - reside. The databases will be queried in the order of this - listing. - - query_chooser - for a given Query, returns the list of shard_ids where the query - should be issued. Results from all shards returned will be - combined together into a single listing. - + """Construct a ShardedSession. + + shard_chooser + A callable which, passed a Mapper, a mapped instance, and possibly a + SQL clause, returns a shard ID. This id may be based off of the + attributes present within the object, or on some round-robin + scheme. If the scheme is based on a selection, it should set + whatever state on the instance to mark it in the future as + participating in that shard. + + id_chooser + A callable, passed a query and a tuple of identity values, which + should return a list of shard ids where the ID might reside. The + databases will be queried in the order of this listing. + + query_chooser + For a given Query, returns the list of shard_ids where the query + should be issued. 
Results from all shards returned will be combined + together into a single listing. + """ super(ShardedSession, self).__init__(**kwargs) self.shard_chooser = shard_chooser @@ -87,17 +98,17 @@ class ShardedQuery(Query): def _execute_and_instances(self, context): if self._shard_id is not None: - result = self.session.connection(mapper=self.mapper, shard_id=self._shard_id).execute(context.statement, **self._params) + result = self.session.connection(mapper=self._mapper_zero(), shard_id=self._shard_id).execute(context.statement, **self._params) try: - return iter(self.instances(result, querycontext=context)) + return iter(self.instances(result, context)) finally: result.close() else: partial = [] for shard_id in self.query_chooser(self): - result = self.session.connection(mapper=self.mapper, shard_id=shard_id).execute(context.statement, **self._params) + result = self.session.connection(mapper=self._mapper_zero(), shard_id=shard_id).execute(context.statement, **self._params) try: - partial = partial + list(self.instances(result, querycontext=context)) + partial = partial + list(self.instances(result, context)) finally: result.close() # if some kind of in memory 'sorting' were done, this is where it would happen @@ -124,4 +135,4 @@ class ShardedQuery(Query): if o is not None: return o else: - raise exceptions.InvalidRequestError("No instance found for identity %s" % repr(ident)) + raise sa_exc.InvalidRequestError("No instance found for identity %s" % repr(ident)) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 65a8b019b..8ae3042a6 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -6,11 +6,13 @@ """sqlalchemy.orm.interfaces.LoaderStrategy implementations, and related MapperOptions.""" -from sqlalchemy import sql, util, exceptions, logging +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import sql, util, log from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import visitors, 
expression, operators from sqlalchemy.orm import mapper, attributes -from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, MapperOption, PropertyOption, serialize_path, deserialize_path +from sqlalchemy.orm.interfaces import LoaderStrategy, StrategizedOption, \ + MapperOption, PropertyOption, serialize_path, deserialize_path from sqlalchemy.orm import session as sessionlib from sqlalchemy.orm import util as mapperutil @@ -21,28 +23,53 @@ class ColumnLoader(LoaderStrategy): def init(self): super(ColumnLoader, self).init() self.columns = self.parent_property.columns - self._should_log_debug = logging.is_debug_enabled(self.logger) + self._should_log_debug = log.is_debug_enabled(self.logger) self.is_composite = hasattr(self.parent_property, 'composite_class') - def setup_query(self, context, parentclauses=None, **kwargs): + def setup_query(self, context, entity, path, adapter, column_collection=None, **kwargs): for c in self.columns: - if parentclauses is not None: - context.secondary_columns.append(parentclauses.aliased_column(c)) - else: - context.primary_columns.append(c) + if adapter: + c = adapter.columns[c] + column_collection.append(c) def init_class_attribute(self): self.is_class_level = True - if self.is_composite: - self._init_composite_attribute() + self.logger.info("%s register managed attribute" % self) + coltype = self.columns[0].type + sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + key, col = self.key, self.columns[0] + if adapter: + col = adapter.columns[col] + if col in row: + def new_execute(state, row, **flags): + state.dict[key] = row[col] + + if self._should_log_debug: + new_execute = 
self.debug_callable(new_execute, self.logger, + "%s returning active column fetcher" % self, + lambda state, row, **flags: "%s populating %s" % (self, mapperutil.state_attribute_str(state, key)) + ) + return (new_execute, None) else: - self._init_scalar_attribute() + def new_execute(state, row, isnew, **flags): + if isnew: + state.expire_attributes([key]) + if self._should_log_debug: + self.logger.debug("%s deferring load" % self) + return (new_execute, None) + +ColumnLoader.logger = log.class_logger(ColumnLoader) + +class CompositeColumnLoader(ColumnLoader): + def init_class_attribute(self): + self.is_class_level = True + self.logger.info("%s register managed composite attribute" % self) - def _init_composite_attribute(self): - self.logger.info("register managed composite attribute %s on class %s" % (self.key, self.parent.class_.__name__)) def copy(obj): - return self.parent_property.composite_class( - *obj.__composite_values__()) + return self.parent_property.composite_class(*obj.__composite_values__()) + def compare(a, b): for col, aprop, bprop in zip(self.columns, a.__composite_values__(), @@ -51,63 +78,56 @@ class ColumnLoader(LoaderStrategy): return False else: return True - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator) - - def _init_scalar_attribute(self): - self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) - coltype = self.columns[0].type - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=coltype.copy_value, compare_function=coltype.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) - - def create_row_processor(self, selectcontext, mapper, row): - if self.is_composite: - for c in self.columns: - if c not in row: - break - else: - def 
new_execute(instance, row, **flags): - if self._should_log_debug: - self.logger.debug("populating %s with %s/%s..." % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) - instance.__dict__[self.key] = self.parent_property.composite_class(*[row[c] for c in self.columns]) - if self._should_log_debug: - self.logger.debug("Returning active composite column fetcher for %s %s" % (mapper, self.key)) - return (new_execute, None, None) - - elif self.columns[0] in row: - def new_execute(instance, row, **flags): + sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, copy_function=copy, compare_function=compare, mutable_scalars=True, comparator=self.parent_property.comparator, parententity=self.parent) + + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + key, columns, composite_class = self.key, self.columns, self.parent_property.composite_class + if adapter: + columns = [adapter.columns[c] for c in columns] + for c in columns: + if c not in row: + def new_execute(state, row, isnew, **flags): + if isnew: + state.expire_attributes([key]) if self._should_log_debug: - self.logger.debug("populating %s with %s/%s" % (mapperutil.attribute_str(instance, self.key), row.__class__.__name__, self.columns[0].key)) - instance.__dict__[self.key] = row[self.columns[0]] - if self._should_log_debug: - self.logger.debug("Returning active column fetcher for %s %s" % (mapper, self.key)) - return (new_execute, None, None) + self.logger.debug("%s deferring load" % self) + return (new_execute, None) else: - def new_execute(instance, row, isnew, **flags): - if isnew: - instance._state.expire_attributes([self.key]) + def new_execute(state, row, **flags): + state.dict[key] = composite_class(*[row[c] for c in columns]) + if self._should_log_debug: - self.logger.debug("Deferring load for %s %s" % (mapper, self.key)) - return (new_execute, None, None) + new_execute = 
self.debug_callable(new_execute, self.logger, + "%s returning active composite column fetcher" % self, + lambda state, row, **flags: "populating %s" % (mapperutil.state_attribute_str(state, key)) + ) -ColumnLoader.logger = logging.class_logger(ColumnLoader) + return (new_execute, None) +CompositeColumnLoader.logger = log.class_logger(CompositeColumnLoader) + class DeferredColumnLoader(LoaderStrategy): """Deferred column loader, a per-column or per-column-group lazy loader.""" - def create_row_processor(self, selectcontext, mapper, row): - if self.columns[0] in row: - return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, mapper, row) + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + col = self.columns[0] + if adapter: + col = adapter.columns[col] + if col in row: + return self.parent_property._get_strategy(ColumnLoader).create_row_processor(selectcontext, path, mapper, row, adapter) + elif not self.is_class_level or len(selectcontext.options): - def new_execute(instance, row, **flags): - if self._should_log_debug: - self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - instance._state.set_callable(self.key, self.setup_loader(instance)) - return (new_execute, None, None) + def new_execute(state, row, **flags): + state.set_callable(self.key, self.setup_loader(state)) else: - def new_execute(instance, row, **flags): - if self._should_log_debug: - self.logger.debug("set deferred callable on %s" % mapperutil.attribute_str(instance, self.key)) - instance._state.reset(self.key) - return (new_execute, None, None) + def new_execute(state, row, **flags): + state.reset(self.key) + + if self._should_log_debug: + new_execute = self.debug_callable(new_execute, self.logger, None, + lambda state, row, **flags: "set deferred callable on %s" % mapperutil.state_attribute_str(state, self.key) + ) + return (new_execute, None) def init(self): super(DeferredColumnLoader, 
self).init() @@ -115,25 +135,25 @@ class DeferredColumnLoader(LoaderStrategy): raise NotImplementedError("Deferred loading for composite types not implemented yet") self.columns = self.parent_property.columns self.group = self.parent_property.group - self._should_log_debug = logging.is_debug_enabled(self.logger) + self._should_log_debug = log.is_debug_enabled(self.logger) def init_class_attribute(self): self.is_class_level = True - self.logger.info("register managed attribute %s on class %s" % (self.key, self.parent.class_.__name__)) - sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator) + self.logger.info("%s register managed attribute" % self) + sessionlib.register_attribute(self.parent.class_, self.key, uselist=False, useobject=False, callable_=self.class_level_loader, copy_function=self.columns[0].type.copy_value, compare_function=self.columns[0].type.compare_values, mutable_scalars=self.columns[0].type.is_mutable(), comparator=self.parent_property.comparator, parententity=self.parent) - def setup_query(self, context, only_load_props=None, **kwargs): + def setup_query(self, context, entity, path, adapter, only_load_props=None, **kwargs): if \ (self.group is not None and context.attributes.get(('undefer', self.group), False)) or \ (only_load_props and self.key in only_load_props): - self.parent_property._get_strategy(ColumnLoader).setup_query(context, **kwargs) + self.parent_property._get_strategy(ColumnLoader).setup_query(context, entity, path, adapter, **kwargs) - def class_level_loader(self, instance, props=None): - if not mapper.has_mapper(instance): + def class_level_loader(self, state, props=None): + if not mapperutil._state_has_mapper(state): return None - localparent = 
mapper.object_mapper(instance) + localparent = mapper._state_mapper(state) # adjust for the ColumnProperty associated with the instance # not being our own ColumnProperty. This can occur when entity_name @@ -141,38 +161,38 @@ class DeferredColumnLoader(LoaderStrategy): # to the class. prop = localparent.get_property(self.key) if prop is not self.parent_property: - return prop._get_strategy(DeferredColumnLoader).setup_loader(instance) + return prop._get_strategy(DeferredColumnLoader).setup_loader(state) - return LoadDeferredColumns(instance, self.key, props) + return LoadDeferredColumns(state, self.key, props) - def setup_loader(self, instance, props=None, create_statement=None): - return LoadDeferredColumns(instance, self.key, props, optimizing_statement=create_statement) + def setup_loader(self, state, props=None, create_statement=None): + return LoadDeferredColumns(state, self.key, props) -DeferredColumnLoader.logger = logging.class_logger(DeferredColumnLoader) +DeferredColumnLoader.logger = log.class_logger(DeferredColumnLoader) class LoadDeferredColumns(object): - """callable, serializable loader object used by DeferredColumnLoader""" + """serializable loader object used by DeferredColumnLoader""" - def __init__(self, instance, key, keys, optimizing_statement=None): - self.instance = instance + def __init__(self, state, key, keys): + self.state = state self.key = key self.keys = keys - self.optimizing_statement = optimizing_statement def __getstate__(self): - return {'instance':self.instance, 'key':self.key, 'keys':self.keys} + return {'state':self.state, 'key':self.key, 'keys':self.keys} def __setstate__(self, state): - self.instance = state['instance'] + self.state = state['state'] self.key = state['key'] self.keys = state['keys'] - self.optimizing_statement = None def __call__(self): - if not mapper.has_identity(self.instance): + state = self.state + + if not mapper._state_has_identity(state): return None - - localparent = mapper.object_mapper(self.instance, 
raiseerror=False) + + localparent = mapper._state_mapper(state) prop = localparent.get_property(self.key) strategy = prop._get_strategy(DeferredColumnLoader) @@ -185,22 +205,18 @@ class LoadDeferredColumns(object): toload = [self.key] # narrow the keys down to just those which have no history - group = [k for k in toload if k in self.instance._state.unmodified] + group = [k for k in toload if k in state.unmodified] if strategy._should_log_debug: - strategy.logger.debug("deferred load %s group %s" % (mapperutil.attribute_str(self.instance, self.key), group and ','.join(group) or 'None')) + strategy.logger.debug("deferred load %s group %s" % (mapperutil.state_attribute_str(state, self.key), group and ','.join(group) or 'None')) - session = sessionlib.object_session(self.instance) + session = sessionlib._state_session(state) if session is None: - raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (self.instance.__class__, self.key)) + raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; deferred load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key)) query = session.query(localparent) - if not self.optimizing_statement: - ident = self.instance._instance_key[1] - query._get(None, ident=ident, only_load_props=group, refresh_instance=self.instance._state) - else: - statement, params = self.optimizing_statement(self.instance) - query.from_statement(statement).params(params)._get(None, only_load_props=group, refresh_instance=self.instance._state) + ident = state.key[1] + query._get(None, ident=ident, only_load_props=group, refresh_instance=state) return attributes.ATTR_WAS_SET class DeferredOption(StrategizedOption): @@ -223,55 +239,63 @@ class UndeferGroupOption(MapperOption): class AbstractRelationLoader(LoaderStrategy): def init(self): super(AbstractRelationLoader, self).init() - for attr in ['primaryjoin', 
'secondaryjoin', 'secondary', 'foreign_keys', 'mapper', 'target', 'table', 'uselist', 'cascade', 'attributeext', 'order_by', 'remote_side', 'direction']: + for attr in ['mapper', 'target', 'table', 'uselist']: setattr(self, attr, getattr(self.parent_property, attr)) - self._should_log_debug = logging.is_debug_enabled(self.logger) + self._should_log_debug = log.is_debug_enabled(self.logger) - def _init_instance_attribute(self, instance, callable_=None): + def _init_instance_attribute(self, state, callable_=None): if callable_: - instance._state.set_callable(self.key, callable_) + state.set_callable(self.key, callable_) else: - instance._state.initialize(self.key) + state.initialize(self.key) def _register_attribute(self, class_, callable_=None, **kwargs): - self.logger.info("register managed %s attribute %s on class %s" % ((self.uselist and "list-holding" or "scalar"), self.key, self.parent.class_.__name__)) - sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=self.attributeext, cascade=self.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, **kwargs) + self.logger.info("%s register managed %s attribute" % (self, (self.uselist and "collection" or "scalar"))) + + if self.parent_property.backref: + attribute_ext = self.parent_property.backref.extension + else: + attribute_ext = None + + sessionlib.register_attribute(class_, self.key, uselist=self.uselist, useobject=True, extension=attribute_ext, cascade=self.parent_property.cascade, trackparent=True, typecallable=self.parent_property.collection_class, callable_=callable_, comparator=self.parent_property.comparator, parententity=self.parent, **kwargs) class NoLoader(AbstractRelationLoader): def init_class_attribute(self): self.is_class_level = True self._register_attribute(self.parent.class_) - def create_row_processor(self, selectcontext, mapper, row): - def new_execute(instance, 
row, ispostselect, **flags): - if not ispostselect: - if self._should_log_debug: - self.logger.debug("initializing blank scalar/collection on %s" % mapperutil.attribute_str(instance, self.key)) - self._init_instance_attribute(instance) - return (new_execute, None, None) + def create_row_processor(self, selectcontext, path, mapper, row, adapter): + def new_execute(state, row, **flags): + self._init_instance_attribute(state) + + if self._should_log_debug: + new_execute = self.debug_callable(new_execute, self.logger, None, + lambda state, row, **flags: "initializing blank scalar/collection on %s" % mapperutil.state_attribute_str(state, self.key) + ) + return (new_execute, None) -NoLoader.logger = logging.class_logger(NoLoader) +NoLoader.logger = log.class_logger(NoLoader) class LazyLoader(AbstractRelationLoader): def init(self): super(LazyLoader, self).init() (self.__lazywhere, self.__bind_to_col, self._equated_columns) = self.__create_lazy_clause(self.parent_property) - self.logger.info(str(self.parent_property) + " lazy loading clause " + str(self.__lazywhere)) + self.logger.info("%s lazy loading clause %s" % (self, self.__lazywhere)) # determine if our "lazywhere" clause is the same as the mapper's # get() clause. 
then we can just use mapper.get() #from sqlalchemy.orm import query self.use_get = not self.uselist and self.mapper._get_clause[0].compare(self.__lazywhere) if self.use_get: - self.logger.info(str(self.parent_property) + " will use query.get() to optimize instance loads") + self.logger.info("%s will use query.get() to optimize instance loads" % self) def init_class_attribute(self): self.is_class_level = True self._register_attribute(self.parent.class_, callable_=self.class_level_loader) - def lazy_clause(self, instance, reverse_direction=False): - if instance is None: + def lazy_clause(self, state, reverse_direction=False): + if state is None: return self._lazy_none_clause(reverse_direction) if not reverse_direction: @@ -285,8 +309,8 @@ class LazyLoader(AbstractRelationLoader): # use the "committed" (database) version to get query column values # also its a deferred value; so that when used by Query, the committed value is used # after an autoflush occurs - bindparam.value = lambda: mapper._get_committed_attr_by_column(instance, bind_to_col[bindparam.key]) - return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam) + bindparam.value = lambda: mapper._get_committed_state_attr_by_column(state, bind_to_col[bindparam.key]) + return visitors.cloned_traverse(criterion, {}, {'bindparam':visit_bindparam}) def _lazy_none_clause(self, reverse_direction=False): if not reverse_direction: @@ -305,13 +329,13 @@ class LazyLoader(AbstractRelationLoader): binary.right = expression.null() binary.operator = operators.is_ - return visitors.traverse(criterion, clone=True, visit_binary=visit_binary) + return visitors.cloned_traverse(criterion, {}, {'binary':visit_binary}) - def class_level_loader(self, instance, options=None, path=None): - if not mapper.has_mapper(instance): + def class_level_loader(self, state, options=None, path=None): + if not mapperutil._state_has_mapper(state): return None - localparent = mapper.object_mapper(instance) + localparent = 
mapper._state_mapper(state) # adjust for the PropertyLoader associated with the instance # not being our own PropertyLoader. This can occur when entity_name @@ -319,35 +343,41 @@ class LazyLoader(AbstractRelationLoader): # to the class. prop = localparent.get_property(self.key) if prop is not self.parent_property: - return prop._get_strategy(LazyLoader).setup_loader(instance) + return prop._get_strategy(LazyLoader).setup_loader(state) - return LoadLazyAttribute(instance, self.key, options, path) + return LoadLazyAttribute(state, self.key, options, path) - def setup_loader(self, instance, options=None, path=None): - return LoadLazyAttribute(instance, self.key, options, path) + def setup_loader(self, state, options=None, path=None): + return LoadLazyAttribute(state, self.key, options, path) - def create_row_processor(self, selectcontext, mapper, row): + def create_row_processor(self, selectcontext, path, mapper, row, adapter): if not self.is_class_level or len(selectcontext.options): - def new_execute(instance, row, ispostselect, **flags): - if not ispostselect: - if self._should_log_debug: - self.logger.debug("set instance-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) - # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, - # which will override the class-level behavior - - self._init_instance_attribute(instance, callable_=self.setup_loader(instance, selectcontext.options, selectcontext.query._current_path + selectcontext.path)) - return (new_execute, None, None) + path = path + (self.key,) + def new_execute(state, row, **flags): + # we are not the primary manager for this attribute on this class - set up a per-instance lazyloader, + # which will override the class-level behavior + self._init_instance_attribute(state, callable_=self.setup_loader(state, selectcontext.options, selectcontext.query._current_path + path)) + + if self._should_log_debug: + new_execute = 
self.debug_callable(new_execute, self.logger, None, + lambda state, row, **flags: "set instance-level lazy loader on %s" % mapperutil.state_attribute_str(state, self.key) + ) + + return (new_execute, None) else: - def new_execute(instance, row, ispostselect, **flags): - if not ispostselect: - if self._should_log_debug: - self.logger.debug("set class-level lazy loader on %s" % mapperutil.attribute_str(instance, self.key)) - # we are the primary manager for this attribute on this class - reset its per-instance attribute state, - # so that the class-level lazy loader is executed when next referenced on this instance. - # this usually is not needed unless the constructor of the object referenced the attribute before we got - # to load data into it. - instance._state.reset(self.key) - return (new_execute, None, None) + def new_execute(state, row, **flags): + # we are the primary manager for this attribute on this class - reset its per-instance attribute state, + # so that the class-level lazy loader is executed when next referenced on this instance. + # this usually is not needed unless the constructor of the object referenced the attribute before we got + # to load data into it. 
+ state.reset(self.key) + + if self._should_log_debug: + new_execute = self.debug_callable(new_execute, self.logger, None, + lambda state, row, **flags: "set class-level lazy loader on %s" % mapperutil.state_attribute_str(state, self.key) + ) + + return (new_execute, None) def __create_lazy_clause(cls, prop, reverse_direction=False): binds = {} @@ -374,16 +404,16 @@ class LazyLoader(AbstractRelationLoader): binds[col] = sql.bindparam(None, None, type_=col.type) return binds[col] return None - - lazywhere = prop.primaryjoin + lazywhere = prop.primaryjoin + if not prop.secondaryjoin or not reverse_direction: - lazywhere = visitors.traverse(lazywhere, before_clone=col_to_bind, clone=True) + lazywhere = visitors.replacement_traverse(lazywhere, {}, col_to_bind) if prop.secondaryjoin is not None: secondaryjoin = prop.secondaryjoin if reverse_direction: - secondaryjoin = visitors.traverse(secondaryjoin, before_clone=col_to_bind, clone=True) + secondaryjoin = visitors.replacement_traverse(secondaryjoin, {}, col_to_bind) lazywhere = sql.and_(lazywhere, secondaryjoin) bind_to_col = dict([(binds[col].key, col) for col in binds]) @@ -391,47 +421,44 @@ class LazyLoader(AbstractRelationLoader): return (lazywhere, bind_to_col, equated_columns) __create_lazy_clause = classmethod(__create_lazy_clause) -LazyLoader.logger = logging.class_logger(LazyLoader) +LazyLoader.logger = log.class_logger(LazyLoader) class LoadLazyAttribute(object): - """callable, serializable loader object used by LazyLoader""" + """serializable loader object used by LazyLoader""" - def __init__(self, instance, key, options, path): - self.instance = instance + def __init__(self, state, key, options, path): + self.state = state self.key = key self.options = options self.path = path def __getstate__(self): - return {'instance':self.instance, 'key':self.key, 'options':self.options, 'path':serialize_path(self.path)} + return {'state':self.state, 'key':self.key, 'options':self.options, 
'path':serialize_path(self.path)} def __setstate__(self, state): - self.instance = state['instance'] + self.state = state['state'] self.key = state['key'] - self.options= state['options'] + self.options = state['options'] self.path = deserialize_path(state['path']) def __call__(self): - instance = self.instance - - if not mapper.has_identity(instance): + state = self.state + if not mapper._state_has_identity(state): return None - instance_mapper = mapper.object_mapper(instance) + instance_mapper = mapper._state_mapper(state) prop = instance_mapper.get_property(self.key) strategy = prop._get_strategy(LazyLoader) if strategy._should_log_debug: - strategy.logger.debug("lazy load attribute %s on instance %s" % (self.key, mapperutil.instance_str(instance))) + strategy.logger.debug("loading %s" % mapperutil.state_attribute_str(state, self.key)) - session = sessionlib.object_session(instance) + session = sessionlib._state_session(state) if session is None: - try: - session = instance_mapper.get_session() - except exceptions.InvalidRequestError: - raise exceptions.UnboundExecutionError("Parent instance %s is not bound to a Session, and no contextual session is established; lazy load operation of attribute '%s' cannot proceed" % (instance.__class__, self.key)) + raise sa_exc.UnboundExecutionError("Parent instance %s is not bound to a Session; lazy load operation of attribute '%s' cannot proceed" % (mapperutil.state_str(state), self.key)) - q = session.query(prop.mapper).autoflush(False) + q = session.query(prop.mapper).autoflush(False)._adapt_all_clauses() + if self.path: q = q._with_current_path(self.path) @@ -441,7 +468,7 @@ class LoadLazyAttribute(object): ident = [] allnulls = True for primary_key in prop.mapper.primary_key: - val = instance_mapper._get_committed_attr_by_column(instance, strategy._equated_columns[primary_key]) + val = instance_mapper._get_committed_state_attr_by_column(state, strategy._equated_columns[primary_key]) allnulls = allnulls and val is None 
ident.append(val) if allnulls: @@ -450,14 +477,14 @@ class LoadLazyAttribute(object): q = q._conditional_options(*self.options) return q.get(ident) - if strategy.order_by is not False: - q = q.order_by(strategy.order_by) - elif strategy.secondary is not None and strategy.secondary.default_order_by() is not None: - q = q.order_by(strategy.secondary.default_order_by()) + if prop.order_by is not False: + q = q.order_by(prop.order_by) + elif prop.secondary is not None and prop.secondary.default_order_by() is not None: + q = q.order_by(prop.secondary.default_order_by()) if self.options: q = q._conditional_options(*self.options) - q = q.filter(strategy.lazy_clause(instance)) + q = q.filter(strategy.lazy_clause(state)) result = q.all() if strategy.uselist: @@ -478,19 +505,35 @@ class EagerLoader(AbstractRelationLoader): self.join_depth = self.parent_property.join_depth def init_class_attribute(self): - # class-level eager strategy; add the PropertyLoader - # to the parent's list of "eager loaders"; this tells the Query - # that eager loaders will be used in a normal query - self.parent._eager_loaders.add(self.parent_property) - - # initialize a lazy loader on the class level attribute self.parent_property._get_strategy(LazyLoader).init_class_attribute() - def setup_query(self, context, parentclauses=None, parentmapper=None, **kwargs): + def setup_query(self, context, entity, path, adapter, column_collection=None, parentmapper=None, **kwargs): """Add a left outer join to the statement thats being constructed.""" + + path = path + (self.key,) + + # check for user-defined eager alias + if ("eager_row_processor", path) in context.attributes: + clauses = context.attributes[("eager_row_processor", path)] + + adapter = entity._get_entity_clauses(context.query, context) + if adapter and clauses: + context.attributes[("eager_row_processor", path)] = clauses = adapter.wrap(clauses) + elif adapter: + context.attributes[("eager_row_processor", path)] = clauses = adapter + + else: - 
path = context.path - + clauses = self.__create_eager_join(context, entity, path, adapter, parentmapper) + if not clauses: + return + + context.attributes[("eager_row_processor", path)] = clauses + + for value in self.mapper._iterate_polymorphic_properties(): + value.setup(context, entity, path + (self.mapper.base_mapper,), clauses, parentmapper=self.mapper, column_collection=context.secondary_columns) + + def __create_eager_join(self, context, entity, path, adapter, parentmapper): # check for join_depth or basic recursion, # if the current path was not explicitly stated as # a desired "loaderstrategy" (i.e. via query.options()) @@ -502,159 +545,148 @@ class EagerLoader(AbstractRelationLoader): if self.mapper.base_mapper in path: return - if ("eager_row_processor", path) in context.attributes: - # if user defined eager_row_processor, that's contains_eager(). - # don't render LEFT OUTER JOIN, generate an AliasedClauses from - # the decorator (this is a hack here, cleaned up in 0.5) - cl = context.attributes[("eager_row_processor", path)] - if cl: - row = cl(None) - class ActsLikeAliasedClauses(object): - def aliased_column(self, col): - return row.map[col] - clauses = ActsLikeAliasedClauses() - else: - clauses = None - else: - clauses = self.__create_eager_join(context, path, parentclauses, parentmapper, **kwargs) - if not clauses: - return - - for value in self.mapper._iterate_polymorphic_properties(): - context.exec_with_path(self.mapper, value.key, value.setup, context, parentclauses=clauses, parentmapper=self.mapper) - - def __create_eager_join(self, context, path, parentclauses, parentmapper, **kwargs): if parentmapper is None: - localparent = context.mapper + localparent = entity.mapper else: localparent = parentmapper - - if context.eager_joins: - towrap = context.eager_joins + + # whether or not the Query will wrap the selectable in a subquery, + # and then attach eager load joins to that (i.e., in the case of LIMIT/OFFSET etc.) 
+ should_nest_selectable = context.query._should_nest_selectable + + if entity in context.eager_joins: + entity_key, default_towrap = entity, entity.selectable + elif should_nest_selectable or not context.from_clause or not sql_util.search(context.from_clause, entity.selectable): + # if no from_clause, or a from_clause we can't join to, or a subquery is going to be generated, + # store eager joins per _MappedEntity; Query._compile_context will + # add them as separate selectables to the select(), or splice them together + # after the subquery is generated + entity_key, default_towrap = entity, entity.selectable else: - towrap = context.from_clause - - # create AliasedClauses object to build up the eager query. this is cached after 1st creation. + # otherwise, create a single eager join from the from clause. + # Query._compile_context will adapt as needed and append to the + # FROM clause of the select(). + entity_key, default_towrap = None, context.from_clause + + towrap = context.eager_joins.setdefault(entity_key, default_towrap) + + # create AliasedClauses object to build up the eager query. this is cached after 1st creation. + # this also allows ORMJoin to cache the aliased joins it produces since we pass the same + # args each time in the typical case. 
+ path_key = util.WeakCompositeKey(*path) try: - clauses = self.clauses[path] + clauses = self.clauses[path_key] except KeyError: - clauses = mapperutil.PropertyAliasedClauses(self.parent_property, self.parent_property.primaryjoin, self.parent_property.secondaryjoin, parentclauses) - self.clauses[path] = clauses + self.clauses[path_key] = clauses = mapperutil.ORMAdapter(mapperutil.AliasedClass(self.mapper), + equivalents=self.mapper._equivalent_columns, + chain_to=adapter) - # place the "row_decorator" from the AliasedClauses into the QueryContext, where it will - # be picked up in create_row_processor() when results are fetched - context.attributes[("eager_row_processor", path)] = clauses.row_decorator - - if self.secondaryjoin is not None: - context.eager_joins = sql.outerjoin(towrap, clauses.secondary, clauses.primaryjoin).outerjoin(clauses.alias, clauses.secondaryjoin) - - # TODO: check for "deferred" cols on parent/child tables here ? this would only be - # useful if the primary/secondaryjoin are against non-PK columns on the tables (and therefore might be deferred) - - if self.order_by is False and self.secondary.default_order_by() is not None: - context.eager_order_by += clauses.secondary.default_order_by() + if adapter: + if getattr(adapter, 'aliased_class', None): + onclause = getattr(adapter.aliased_class, self.key, self.parent_property) + else: + onclause = getattr(mapperutil.AliasedClass(self.parent, adapter.selectable), self.key, self.parent_property) else: - context.eager_joins = towrap.outerjoin(clauses.alias, clauses.primaryjoin) - # ensure all the cols on the parent side are actually in the + onclause = self.parent_property + + context.eager_joins[entity_key] = eagerjoin = mapperutil.outerjoin(towrap, clauses.aliased_class, onclause) + + # send a hint to the Query as to where it may "splice" this join + eagerjoin.stop_on = entity.selectable + + if not self.parent_property.secondary and context.query._should_nest_selectable and not parentmapper: + # 
for parentclause that is the non-eager end of the join, + # ensure all the parent cols in the primaryjoin are actually in the # columns clause (i.e. are not deferred), so that aliasing applied by the Query propagates # those columns outward. This has the effect of "undefering" those columns. - for col in sql_util.find_columns(clauses.primaryjoin): + for col in sql_util.find_columns(self.parent_property.primaryjoin): if localparent.mapped_table.c.contains_column(col): + if adapter: + col = adapter.columns[col] context.primary_columns.append(col) - - if self.order_by is False and clauses.alias.default_order_by() is not None: - context.eager_order_by += clauses.alias.default_order_by() - - if clauses.order_by: - context.eager_order_by += util.to_list(clauses.order_by) + if self.parent_property.order_by is False: + if self.parent_property.secondaryjoin: + default_order_by = eagerjoin.left.right.default_order_by() + else: + default_order_by = eagerjoin.right.default_order_by() + if default_order_by: + context.eager_order_by += default_order_by + elif self.parent_property.order_by: + context.eager_order_by += eagerjoin._target_adapter.copy_and_process(util.to_list(self.parent_property.order_by)) + return clauses - def _create_row_decorator(self, selectcontext, row, path): - """Create a *row decorating* function that will apply eager - aliasing to the row. - - Also check that an identity key can be retrieved from the row, - else return None. 
- """ - - #print "creating row decorator for path ", "->".join([str(s) for s in path]) - - if ("eager_row_processor", path) in selectcontext.attributes: - decorator = selectcontext.attributes[("eager_row_processor", path)] - if decorator is None: - decorator = lambda row: row + def __create_eager_adapter(self, context, row, adapter, path): + if ("eager_row_processor", path) in context.attributes: + decorator = context.attributes[("eager_row_processor", path)] else: if self._should_log_debug: self.logger.debug("Could not locate aliased clauses for key: " + str(path)) - return None + return False + if adapter and decorator: + decorator = adapter.wrap(decorator) + elif adapter: + decorator = adapter + try: - decorated_row = decorator(row) - # check for identity key - identity_key = self.mapper.identity_key_from_row(decorated_row) - # and its good + identity_key = self.mapper.identity_key_from_row(row, decorator) return decorator except KeyError, k: # no identity key - dont return a row processor, will cause a degrade to lazy if self._should_log_debug: - self.logger.debug("could not locate identity key from row '%s'; missing column '%s'" % (repr(decorated_row), str(k))) - return None - - def create_row_processor(self, selectcontext, mapper, row): + self.logger.debug("could not locate identity key from row; missing column '%s'" % k) + return False - row_decorator = self._create_row_decorator(selectcontext, row, selectcontext.path) - pathstr = ','.join([str(x) for x in selectcontext.path]) - if row_decorator is not None: - def execute(instance, row, isnew, **flags): - decorated_row = row_decorator(row) - - if not self.uselist: - if self._should_log_debug: - self.logger.debug("eagerload scalar instance on %s" % mapperutil.attribute_str(instance, self.key)) + def create_row_processor(self, context, path, mapper, row, adapter): + path = path + (self.key,) + eager_adapter = self.__create_eager_adapter(context, row, adapter, path) + + if eager_adapter is not False: + key = 
self.key + _instance = self.mapper._instance_processor(context, path + (self.mapper.base_mapper,), eager_adapter) + + if not self.uselist: + def execute(state, row, isnew, **flags): if isnew: # set a scalar object instance directly on the # parent object, bypassing InstrumentedAttribute # event handlers. - # - instance.__dict__[self.key] = self.mapper._instance(selectcontext, decorated_row, None) + state.dict[key] = _instance(row, None) else: # call _instance on the row, even though the object has been created, # so that we further descend into properties - self.mapper._instance(selectcontext, decorated_row, None) - else: - if isnew or self.key not in instance._state.appenders: - # appender_key can be absent from selectcontext.attributes with isnew=False + _instance(row, None) + else: + def execute(state, row, isnew, **flags): + if isnew or (state, key) not in context.attributes: + # appender_key can be absent from context.attributes with isnew=False # when self-referential eager loading is used; the same instance may be present # in two distinct sets of result columns - - if self._should_log_debug: - self.logger.debug("initialize UniqueAppender on %s" % mapperutil.attribute_str(instance, self.key)) - collection = attributes.init_collection(instance, self.key) + collection = attributes.init_collection(state, key) appender = util.UniqueAppender(collection, 'append_without_event') - instance._state.appenders[self.key] = appender - - result_list = instance._state.appenders[self.key] - if self._should_log_debug: - self.logger.debug("eagerload list instance on %s" % mapperutil.attribute_str(instance, self.key)) + context.attributes[(state, key)] = appender + + result_list = context.attributes[(state, key)] - self.mapper._instance(selectcontext, decorated_row, result_list) + _instance(row, result_list) if self._should_log_debug: - self.logger.debug("Returning eager instance loader for %s" % str(self)) + execute = self.debug_callable(execute, self.logger, + "%s returning 
eager instance loader" % self, + lambda state, row, isnew, **flags: "%s eagerload %s" % (self, self.uselist and "scalar attribute" or "collection") + ) - return (execute, execute, None) + return (execute, execute) else: if self._should_log_debug: - self.logger.debug("eager loader %s degrading to lazy loader" % str(self)) - return self.parent_property._get_strategy(LazyLoader).create_row_processor(selectcontext, mapper, row) + self.logger.debug("%s degrading to lazy loader" % self) + return self.parent_property._get_strategy(LazyLoader).create_row_processor(context, path, mapper, row, adapter) - def __str__(self): - return str(self.parent) + "." + self.key - -EagerLoader.logger = logging.class_logger(EagerLoader) +EagerLoader.logger = log.class_logger(EagerLoader) class EagerLazyOption(StrategizedOption): def __init__(self, key, lazy=True, chained=False, mapper=None): @@ -665,20 +697,6 @@ class EagerLazyOption(StrategizedOption): def is_chained(self): return not self.lazy and self.chained - def process_query_property(self, query, paths): - if self.lazy: - if paths[-1] in query._eager_loaders: - query._eager_loaders = query._eager_loaders.difference(util.Set([paths[-1]])) - else: - if not self.chained: - paths = [paths[-1]] - res = util.Set() - for path in paths: - if len(path) - len(query._current_path) == 2: - res.add(path) - query._eager_loaders = query._eager_loaders.union(res) - super(EagerLazyOption, self).process_query_property(query, paths) - def get_strategy_class(self): if self.lazy: return LazyLoader @@ -687,24 +705,26 @@ class EagerLazyOption(StrategizedOption): elif self.lazy is None: return NoLoader -EagerLazyOption.logger = logging.class_logger(EagerLazyOption) - -class RowDecorateOption(PropertyOption): - def __init__(self, key, decorator=None, alias=None): - super(RowDecorateOption, self).__init__(key) - self.decorator = decorator +class LoadEagerFromAliasOption(PropertyOption): + def __init__(self, key, alias=None): + super(LoadEagerFromAliasOption, 
self).__init__(key) + if alias: + if not isinstance(alias, basestring): + m, alias, is_aliased_class = mapperutil._entity_info(alias) self.alias = alias def process_query_property(self, query, paths): - if self.alias is not None and self.decorator is None: - (mapper, propname) = paths[-1][-2:] - - prop = mapper.get_property(propname, resolve_synonyms=True) + if self.alias: if isinstance(self.alias, basestring): - self.alias = prop.target.alias(self.alias) + (mapper, propname) = paths[-1][-2:] - self.decorator = mapperutil.create_row_adapter(self.alias) - query._attributes[("eager_row_processor", paths[-1])] = self.decorator + prop = mapper.get_property(propname, resolve_synonyms=True) + self.alias = prop.target.alias(self.alias) + if not isinstance(self.alias, expression.Alias): + import pdb + pdb.set_trace() + query._attributes[("eager_row_processor", paths[-1])] = sql_util.ColumnAdapter(self.alias) + else: + query._attributes[("eager_row_processor", paths[-1])] = None -RowDecorateOption.logger = logging.class_logger(RowDecorateOption) diff --git a/lib/sqlalchemy/orm/sync.py b/lib/sqlalchemy/orm/sync.py index 39a7b5044..eca80df25 100644 --- a/lib/sqlalchemy/orm/sync.py +++ b/lib/sqlalchemy/orm/sync.py @@ -8,31 +8,27 @@ based on join conditions. 
""" -from sqlalchemy import schema, exceptions, util -from sqlalchemy.sql import visitors, operators, util as sqlutil -from sqlalchemy import logging -from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.interfaces import ONETOMANY, MANYTOONE, MANYTOMANY # legacy +from sqlalchemy.orm import exc, util as mapperutil def populate(source, source_mapper, dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: try: value = source_mapper._get_state_attr_by_column(source, l) - except exceptions.UnmappedColumnError: + except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, dest_mapper, r) try: dest_mapper._set_state_attr_by_column(dest, r, value) - except exceptions.UnmappedColumnError: - self._raise_col_to_prop(True, source_mapper, l, dest_mapper, r) + except exc.UnmappedColumnError: + _raise_col_to_prop(True, source_mapper, l, dest_mapper, r) def clear(dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: if r.primary_key: - raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest))) + raise AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (r, mapperutil.state_str(dest))) try: dest_mapper._set_state_attr_by_column(dest, r, None) - except exceptions.UnmappedColumnError: + except exc.UnmappedColumnError: _raise_col_to_prop(True, None, l, dest_mapper, r) def update(source, source_mapper, dest, old_prefix, synchronize_pairs): @@ -40,8 +36,8 @@ def update(source, source_mapper, dest, old_prefix, synchronize_pairs): try: oldvalue = source_mapper._get_committed_attr_by_column(source.obj(), l) value = source_mapper._get_state_attr_by_column(source, l) - except exceptions.UnmappedColumnError: - self._raise_col_to_prop(False, source_mapper, l, None, r) + except exc.UnmappedColumnError: + _raise_col_to_prop(False, source_mapper, l, None, r) dest[r.key] = value dest[old_prefix + r.key] = 
oldvalue @@ -49,16 +45,16 @@ def populate_dict(source, source_mapper, dict_, synchronize_pairs): for l, r in synchronize_pairs: try: value = source_mapper._get_state_attr_by_column(source, l) - except exceptions.UnmappedColumnError: + except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) - + dict_[r.key] = value def source_changes(uowcommit, source, source_mapper, synchronize_pairs): for l, r in synchronize_pairs: try: prop = source_mapper._get_col_to_prop(l) - except exceptions.UnmappedColumnError: + except exc.UnmappedColumnError: _raise_col_to_prop(False, source_mapper, l, None, r) (added, unchanged, deleted) = uowcommit.get_attribute_history(source, prop.key, passive=True) if added and deleted: @@ -70,7 +66,7 @@ def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs): for l, r in synchronize_pairs: try: prop = dest_mapper._get_col_to_prop(r) - except exceptions.UnmappedColumnError: + except exc.UnmappedColumnError: _raise_col_to_prop(True, None, l, dest_mapper, r) (added, unchanged, deleted) = uowcommit.get_attribute_history(dest, prop.key, passive=True) if added and deleted: @@ -80,7 +76,6 @@ def dest_changes(uowcommit, dest, dest_mapper, synchronize_pairs): def _raise_col_to_prop(isdest, source_mapper, source_column, dest_mapper, dest_column): if isdest: - raise exceptions.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." % (dest_column, source_mapper)) + raise exc.UnmappedColumnError("Can't execute sync rule for destination column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include this column (or use a viewonly=True relation)." 
% (dest_column, source_mapper)) else: - raise exceptions.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column)) - + raise exc.UnmappedColumnError("Can't execute sync rule for source column '%s'; mapper '%s' does not map this column. Try using an explicit `foreign_keys` collection which does not include destination column '%s' (or use a viewonly=True relation)." % (source_column, source_mapper, dest_column)) diff --git a/lib/sqlalchemy/orm/unitofwork.py b/lib/sqlalchemy/orm/unitofwork.py index 66b68770d..4edfeefdc 100644 --- a/lib/sqlalchemy/orm/unitofwork.py +++ b/lib/sqlalchemy/orm/unitofwork.py @@ -17,16 +17,19 @@ unique against their primary key identity using an *identity map* pattern. The Unit of Work then maintains lists of objects that are new, dirty, or deleted and provides the capability to flush all those changes at once. 
+ """ -import StringIO, weakref -from sqlalchemy import util, logging, topological, exceptions +import StringIO + +from sqlalchemy import util, log, topological from sqlalchemy.orm import attributes, interfaces from sqlalchemy.orm import util as mapperutil -from sqlalchemy.orm.mapper import object_mapper, _state_mapper, has_identity +from sqlalchemy.orm.mapper import _state_mapper # Load lazily object_session = None +_state_session = None class UOWEventHandler(interfaces.AttributeExtension): """An event handler added to all relation attributes which handles @@ -38,33 +41,33 @@ class UOWEventHandler(interfaces.AttributeExtension): self.class_ = class_ self.cascade = cascade - def _target_mapper(self, obj): - prop = object_mapper(obj).get_property(self.key) + def _target_mapper(self, state): + prop = _state_mapper(state).get_property(self.key) return prop.mapper - def append(self, obj, item, initiator): + def append(self, state, item, initiator): # process "save_update" cascade rules for when an instance is appended to the list of another instance - sess = object_session(obj) + sess = _state_session(state) if sess: if self.cascade.save_update and item not in sess: - sess.save_or_update(item, entity_name=self._target_mapper(obj).entity_name) + sess.save_or_update(item, entity_name=self._target_mapper(state).entity_name) - def remove(self, obj, item, initiator): - sess = object_session(obj) + def remove(self, state, item, initiator): + sess = _state_session(state) if sess: # expunge pending orphans if self.cascade.delete_orphan and item in sess.new: - if self._target_mapper(obj)._is_orphan(item): + if self._target_mapper(state)._is_orphan(attributes.instance_state(item)): sess.expunge(item) - def set(self, obj, newvalue, oldvalue, initiator): + def set(self, state, newvalue, oldvalue, initiator): # process "save_update" cascade rules for when an instance is attached to another instance if oldvalue is newvalue: return - sess = object_session(obj) + sess = 
_state_session(state) if sess: if newvalue is not None and self.cascade.save_update and newvalue not in sess: - sess.save_or_update(newvalue, entity_name=self._target_mapper(obj).entity_name) + sess.save_or_update(newvalue, entity_name=self._target_mapper(state).entity_name) if self.cascade.delete_orphan and oldvalue in sess.new: sess.expunge(oldvalue) @@ -86,184 +89,6 @@ def register_attribute(class_, key, *args, **kwargs): -class UnitOfWork(object): - """Main UOW object which stores lists of dirty/new/deleted objects. - - Provides top-level *flush* functionality as well as the - default transaction boundaries involved in a write - operation. - """ - - def __init__(self, session): - if session.weak_identity_map: - self.identity_map = attributes.WeakInstanceDict() - else: - self.identity_map = attributes.StrongInstanceDict() - - self.new = {} # InstanceState->object, strong refs object - self.deleted = {} # same - self.logger = logging.instance_logger(self, echoflag=session.echo_uow) - - def _remove_deleted(self, state): - if '_instance_key' in state.dict: - del self.identity_map[state.dict['_instance_key']] - self.deleted.pop(state, None) - self.new.pop(state, None) - - def _is_valid(self, state): - if '_instance_key' in state.dict: - return state.dict['_instance_key'] in self.identity_map - else: - return state in self.new - - def _register_clean(self, state): - """register the given object as 'clean' (i.e. 
persistent) within this unit of work, after - a save operation has taken place.""" - - mapper = _state_mapper(state) - instance_key = mapper._identity_key_from_state(state) - - if '_instance_key' not in state.dict: - state.dict['_instance_key'] = instance_key - - elif state.dict['_instance_key'] != instance_key: - # primary key switch - del self.identity_map[state.dict['_instance_key']] - state.dict['_instance_key'] = instance_key - - if hasattr(state, 'insert_order'): - delattr(state, 'insert_order') - - o = state.obj() - # prevent against last minute dereferences of the object - # TODO: identify a code path where state.obj() is None - if o is not None: - self.identity_map[state.dict['_instance_key']] = o - state.commit_all() - - # remove from new last, might be the last strong ref - self.new.pop(state, None) - - def register_new(self, obj): - """register the given object as 'new' (i.e. unsaved) within this unit of work.""" - - if hasattr(obj, '_instance_key'): - raise exceptions.InvalidRequestError("Object '%s' already has an identity - it can't be registered as new" % repr(obj)) - if obj._state not in self.new: - self.new[obj._state] = obj - obj._state.insert_order = len(self.new) - - def register_deleted(self, obj): - """register the given persistent object as 'to be deleted' within this unit of work.""" - - self.deleted[obj._state] = obj - - def locate_dirty(self): - """return a set of all persistent instances within this unit of work which - either contain changes or are marked as deleted. 
- """ - - # a little bit of inlining for speed - return util.IdentitySet([x for x in self.identity_map.values() - if x._state not in self.deleted - and ( - x._state.modified - or (x.__class__._class_state.has_mutable_scalars and x._state.is_modified()) - ) - ]) - - def flush(self, session, objects=None): - """create a dependency tree of all pending SQL operations within this unit of work and execute.""" - - dirty = [x for x in self.identity_map.all_states() - if x.modified - or (x.class_._class_state.has_mutable_scalars and x.is_modified()) - ] - - if not dirty and not self.deleted and not self.new: - return - - deleted = util.Set(self.deleted) - new = util.Set(self.new) - - dirty = util.Set(dirty).difference(deleted) - - flush_context = UOWTransaction(self, session) - - if session.extension is not None: - session.extension.before_flush(session, flush_context, objects) - - # create the set of all objects we want to operate upon - if objects: - # specific list passed in - objset = util.Set([o._state for o in objects]) - else: - # or just everything - objset = util.Set(self.identity_map.all_states()).union(new) - - # store objects whose fate has been decided - processed = util.Set() - - # put all saves/updates into the flush context. detect top-level orphans and throw them into deleted. - for state in new.union(dirty).intersection(objset).difference(deleted): - if state in processed: - continue - - obj = state.obj() - is_orphan = _state_mapper(state)._is_orphan(obj) - if is_orphan and not has_identity(obj): - raise exceptions.FlushError("instance %s is an unsaved, pending instance and is an orphan (is not attached to %s)" % - ( - obj, - ", nor ".join(["any parent '%s' instance via that classes' '%s' attribute" % (klass.__name__, key) for (key,klass) in _state_mapper(state).delete_orphans]) - )) - flush_context.register_object(state, isdelete=is_orphan) - processed.add(state) - - # put all remaining deletes into the flush context. 
- for state in deleted.intersection(objset).difference(processed): - flush_context.register_object(state, isdelete=True) - - if len(flush_context.tasks) == 0: - return - - session.create_transaction(autoflush=False) - flush_context.transaction = session.transaction - try: - flush_context.execute() - - if session.extension is not None: - session.extension.after_flush(session, flush_context) - session.commit() - except: - session.rollback() - raise - - flush_context.post_exec() - - if session.extension is not None: - session.extension.after_flush_postexec(session, flush_context) - - def prune_identity_map(self): - """Removes unreferenced instances cached in a strong-referencing identity map. - - Note that this method is only meaningful if "weak_identity_map" - on the parent Session is set to False and therefore this UnitOfWork's - identity map is a regular dictionary - - Removes any object in the identity map that is not referenced - in user code or scheduled for a unit of work operation. Returns - the number of objects pruned. - """ - - if isinstance(self.identity_map, attributes.WeakInstanceDict): - return 0 - ref_count = len(self.identity_map) - dirty = self.locate_dirty() - keepers = weakref.WeakValueDictionary(self.identity_map) - self.identity_map.clear() - self.identity_map.update(keepers) - return ref_count - len(self.identity_map) class UOWTransaction(object): """Handles the details of organizing and executing transaction @@ -275,8 +100,7 @@ class UOWTransaction(object): packages. """ - def __init__(self, uow, session): - self.uow = uow + def __init__(self, session): self.session = session self.mapper_flush_opts = session._mapper_flush_opts @@ -291,7 +115,7 @@ class UOWTransaction(object): # information. 
self.attributes = {} - self.logger = logging.instance_logger(self, echoflag=session.echo_uow) + self.logger = log.instance_logger(self, echoflag=session.echo_uow) def get_attribute_history(self, state, key, passive=True): hashkey = ("history", state, key) @@ -310,19 +134,18 @@ class UOWTransaction(object): (added, unchanged, deleted) = attributes.get_history(state, key, passive=passive) self.attributes[hashkey] = (added, unchanged, deleted, passive) - if added is None: + if added is None or not state.get_impl(key).uses_objects: return (added, unchanged, deleted) else: return ( - [getattr(c, '_state', c) for c in added], - [getattr(c, '_state', c) for c in unchanged], - [getattr(c, '_state', c) for c in deleted], + [c is not None and attributes.instance_state(c) or None for c in added], + [c is not None and attributes.instance_state(c) or None for c in unchanged], + [c is not None and attributes.instance_state(c) or None for c in deleted], ) - - def register_object(self, state, isdelete = False, listonly = False, postupdate=False, post_update_cols=None, **kwargs): + def register_object(self, state, isdelete=False, listonly=False, postupdate=False, post_update_cols=None): # if object is not in the overall session, do nothing - if not self.uow._is_valid(state): + if not self.session._contains_state(state): if self._should_log_debug: self.logger.debug("object %s not part of session, not registering for flush" % (mapperutil.state_str(state))) return @@ -331,12 +154,12 @@ class UOWTransaction(object): self.logger.debug("register object for flush: %s isdelete=%s listonly=%s postupdate=%s" % (mapperutil.state_str(state), isdelete, listonly, postupdate)) mapper = _state_mapper(state) - + task = self.get_task_by_mapper(mapper) if postupdate: task.append_postupdate(state, post_update_cols) else: - task.append(state, listonly, isdelete=isdelete, **kwargs) + task.append(state, listonly=listonly, isdelete=isdelete) def set_row_switch(self, state): """mark a deleted object as a 
'row switch'. @@ -451,22 +274,26 @@ class UOWTransaction(object): import uowdumper uowdumper.UOWDumper(tasks, buf) return buf.getvalue() - - def post_exec(self): + + def elements(self): + """return an iterator of all UOWTaskElements within this UOWTransaction.""" + for task in self.tasks.values(): + for elem in task.elements: + yield elem + elements = property(elements) + + def finalize_flush_changes(self): """mark processed objects as clean / deleted after a successful flush(). this method is called within the flush() method after the execute() method has succeeded and the transaction has been committed. """ - for task in self.tasks.values(): - for elem in task.elements: - if elem.state is None: - continue - if elem.isdelete: - self.uow._remove_deleted(elem.state) - else: - self.uow._register_clean(elem.state) + for elem in self.elements: + if elem.isdelete: + self.session._remove_newly_deleted(elem.state) + else: + self.session._register_newly_persistent(elem.state) def _sort_dependencies(self): nodes = topological.sort_with_cycles(self.dependencies, @@ -489,10 +316,9 @@ class UOWTransaction(object): class UOWTask(object): """Represents all of the objects in the UOWTransaction which correspond to - a particular mapper. This is the primary class of three classes used to generate - the elements of the dependency graph. + a particular mapper. 
+ """ - def __init__(self, uowtransaction, mapper, base_task=None): self.uowtransaction = uowtransaction @@ -515,6 +341,7 @@ class UOWTask(object): # mapping of InstanceState -> UOWTaskElement self._objects = {} + self.dependent_tasks = [] self.dependencies = util.Set() self.cyclical_dependencies = util.Set() @@ -564,11 +391,6 @@ class UOWTask(object): rec.update(listonly, isdelete) - def _append_cyclical_childtask(self, task): - if "cyclical" not in self._objects: - self._objects["cyclical"] = UOWTaskElement(None) - self._objects["cyclical"].childtasks.append(task) - def append_postupdate(self, state, post_update_cols): """issue a 'post update' UPDATE statement via this object's mapper immediately. @@ -577,8 +399,8 @@ class UOWTask(object): """ # postupdates are UPDATED immeditely (for now) - # convert post_update_cols list to a Set so that __hashcode__ is used to compare columns - # instead of __eq__ + # convert post_update_cols list to a Set so that __hash__() is used to compare columns + # instead of __eq__() self.mapper._save_obj([state], self.uowtransaction, postupdate=True, post_update_cols=util.Set(post_update_cols)) def __contains__(self, state): @@ -607,26 +429,42 @@ class UOWTask(object): for rec in callable(task): yield rec return property(collection) - - elements = property(lambda self:self._objects.values()) - polymorphic_elements = _polymorphic_collection(lambda task:task.elements) - - polymorphic_tosave_elements = property(lambda self: [rec for rec in self.polymorphic_elements - if not rec.isdelete]) - - polymorphic_todelete_elements = property(lambda self:[rec for rec in self.polymorphic_elements - if rec.isdelete]) + def _elements(self): + return self._objects.values() + elements = property(_elements) + + polymorphic_elements = _polymorphic_collection(_elements) - polymorphic_tosave_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements - if rec.state is not None and not rec.listonly and rec.isdelete is False]) + def 
polymorphic_tosave_elements(self): + return [rec for rec in self.polymorphic_elements if not rec.isdelete] + polymorphic_tosave_elements = property(polymorphic_tosave_elements) + + def polymorphic_todelete_elements(self): + return [rec for rec in self.polymorphic_elements if rec.isdelete] + polymorphic_todelete_elements = property(polymorphic_todelete_elements) + + def polymorphic_tosave_objects(self): + return [ + rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is False + ] + polymorphic_tosave_objects = property(polymorphic_tosave_objects) - polymorphic_todelete_objects = property(lambda self:[rec.state for rec in self.polymorphic_elements - if rec.state is not None and not rec.listonly and rec.isdelete is True]) + def polymorphic_todelete_objects(self): + return [ + rec.state for rec in self.polymorphic_elements + if rec.state is not None and not rec.listonly and rec.isdelete is True + ] + polymorphic_todelete_objects = property(polymorphic_todelete_objects) - polymorphic_dependencies = _polymorphic_collection(lambda task:task.dependencies) + def polymorphic_dependencies(self): + return self.dependencies + polymorphic_dependencies = _polymorphic_collection(polymorphic_dependencies) - polymorphic_cyclical_dependencies = _polymorphic_collection(lambda task:task.cyclical_dependencies) + def polymorphic_cyclical_dependencies(self): + return self.cyclical_dependencies + polymorphic_cyclical_dependencies = _polymorphic_collection(polymorphic_cyclical_dependencies) def _sort_circular_dependencies(self, trans, cycles): """Create a hierarchical tree of *subtasks* @@ -741,7 +579,7 @@ class UOWTask(object): if t is None: t = UOWTask(self.uowtransaction, originating_task.mapper) nexttasks[originating_task] = t - parenttask._append_cyclical_childtask(t) + parenttask.dependent_tasks.append(t) t.append(state, originating_task._objects[state].listonly, isdelete=originating_task._objects[state].isdelete) if state in 
dependencies: @@ -777,29 +615,17 @@ class UOWTask(object): return ret def __repr__(self): - if self.mapper is not None: - if self.mapper.__class__.__name__ == 'Mapper': - name = self.mapper.class_.__name__ + "/" + self.mapper.local_table.description - else: - name = repr(self.mapper) - else: - name = '(none)' - return ("UOWTask(%s) Mapper: '%s'" % (hex(id(self)), name)) + return ("UOWTask(%s) Mapper: '%r'" % (hex(id(self)), self.mapper)) class UOWTaskElement(object): - """An element within a UOWTask. - - Corresponds to a single object instance to be saved, deleted, or - just part of the transaction as a placeholder for further - dependencies (i.e. 'listonly'). - - may also store additional sub-UOWTasks. + """Corresponds to a single InstanceState to be saved, deleted, + or otherwise marked as having dependencies. A collection of + UOWTaskElements are held by a UOWTask. + """ - def __init__(self, state): self.state = state self.listonly = True - self.childtasks = [] self.isdelete = False self.__preprocessed = {} @@ -835,11 +661,11 @@ class UOWTaskElement(object): class UOWDependencyProcessor(object): """In between the saving and deleting of objects, process - *dependent* data, such as filling in a foreign key on a child item + dependent data, such as filling in a foreign key on a child item from a new primary key, or deleting association rows before a delete. This object acts as a proxy to a DependencyProcessor. 
+ """ - def __init__(self, processor, targettask): self.processor = processor self.targettask = targettask @@ -877,12 +703,12 @@ class UOWDependencyProcessor(object): return elem.state ret = False - elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None and not elem.is_preprocessed(self)] + elements = [getobj(elem) for elem in self.targettask.polymorphic_tosave_elements if not elem.is_preprocessed(self)] if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=False) - elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None and not elem.is_preprocessed(self)] + elements = [getobj(elem) for elem in self.targettask.polymorphic_todelete_elements if not elem.is_preprocessed(self)] if elements: ret = True self.processor.preprocess_dependencies(self.targettask, elements, trans, delete=True) @@ -892,9 +718,9 @@ class UOWDependencyProcessor(object): """process all objects contained within this ``UOWDependencyProcessor``s target task.""" if not delete: - self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements if elem.state is not None], trans, delete=False) + self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_tosave_elements], trans, delete=False) else: - self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements if elem.state is not None], trans, delete=True) + self.processor.process_dependencies(self.targettask, [elem.state for elem in self.targettask.polymorphic_todelete_elements], trans, delete=True) def get_object_dependencies(self, state, trans, passive): return trans.get_attribute_history(state, self.processor.key, passive=passive) @@ -907,7 +733,6 @@ class UOWDependencyProcessor(object): when toplogically sorting on a per-instance basis. 
""" - return self.processor.whose_dependent_on_who(state1, state2) def branch(self, task): @@ -917,7 +742,6 @@ class UOWDependencyProcessor(object): is broken up into many individual ``UOWTask`` objects. """ - return UOWDependencyProcessor(self.processor, task) @@ -944,13 +768,11 @@ class UOWExecutor(object): def execute_save_steps(self, trans, task): self.save_objects(trans, task) self.execute_cyclical_dependencies(trans, task, False) - self.execute_per_element_childtasks(trans, task, False) self.execute_dependencies(trans, task, False) self.execute_dependencies(trans, task, True) - + def execute_delete_steps(self, trans, task): self.execute_cyclical_dependencies(trans, task, True) - self.execute_per_element_childtasks(trans, task, True) self.delete_objects(trans, task) def execute_dependencies(self, trans, task, isdelete=None): @@ -964,12 +786,5 @@ class UOWExecutor(object): def execute_cyclical_dependencies(self, trans, task, isdelete): for dep in task.polymorphic_cyclical_dependencies: self.execute_dependency(trans, dep, isdelete) - - def execute_per_element_childtasks(self, trans, task, isdelete): - for element in task.polymorphic_tosave_elements + task.polymorphic_todelete_elements: - self.execute_element_childtasks(trans, element, isdelete) - - def execute_element_childtasks(self, trans, element, isdelete): - for child in element.childtasks: - self.execute(trans, [child], isdelete) - + for t in task.dependent_tasks: + self.execute(trans, [t], isdelete) diff --git a/lib/sqlalchemy/orm/uowdumper.py b/lib/sqlalchemy/orm/uowdumper.py index 4b3fed70a..09b82167d 100644 --- a/lib/sqlalchemy/orm/uowdumper.py +++ b/lib/sqlalchemy/orm/uowdumper.py @@ -6,17 +6,15 @@ """Dumps out a string representation of a UOWTask structure""" +from sqlalchemy import util from sqlalchemy.orm import unitofwork from sqlalchemy.orm import util as mapperutil -from sqlalchemy import util class UOWDumper(unitofwork.UOWExecutor): - def __init__(self, tasks, buf, verbose=False): - 
self.verbose = verbose + def __init__(self, tasks, buf): self.indent = 0 self.tasks = tasks self.buf = buf - self.headers = {} self.execute(None, tasks) def execute(self, trans, tasks, isdelete=None): @@ -62,88 +60,23 @@ class UOWDumper(unitofwork.UOWExecutor): for rec in l: if rec.listonly: continue - self.header("Save elements"+ self._inheritance_tag(task)) self.buf.write(self._indent()[:-1] + "+-" + self._repr_task_element(rec) + "\n") - self.closeheader() def delete_objects(self, trans, task): for rec in task.polymorphic_todelete_elements: if rec.listonly: continue - self.header("Delete elements"+ self._inheritance_tag(task)) self.buf.write(self._indent() + "- " + self._repr_task_element(rec) + "\n") - self.closeheader() - - def _inheritance_tag(self, task): - if not self.verbose: - return "" - else: - return (" (inheriting task %s)" % self._repr_task(task)) - - def header(self, text): - """Write a given header just once.""" - - if not self.verbose: - return - try: - self.headers[text] - except KeyError: - self.buf.write(self._indent() + "- " + text + "\n") - self.headers[text] = True - - def closeheader(self): - if not self.verbose: - return - self.buf.write(self._indent() + "- ------\n") def execute_dependency(self, transaction, dep, isdelete): self._dump_processor(dep, isdelete) - def execute_save_steps(self, trans, task): - super(UOWDumper, self).execute_save_steps(trans, task) - - def execute_delete_steps(self, trans, task): - super(UOWDumper, self).execute_delete_steps(trans, task) - - def execute_dependencies(self, trans, task, isdelete=None): - super(UOWDumper, self).execute_dependencies(trans, task, isdelete) - - def execute_cyclical_dependencies(self, trans, task, isdelete): - self.header("Cyclical %s dependencies" % (isdelete and "delete" or "save")) - super(UOWDumper, self).execute_cyclical_dependencies(trans, task, isdelete) - self.closeheader() - - def execute_per_element_childtasks(self, trans, task, isdelete): - super(UOWDumper, 
self).execute_per_element_childtasks(trans, task, isdelete) - - def execute_element_childtasks(self, trans, element, isdelete): - self.header("%s subelements of UOWTaskElement(%s)" % ((isdelete and "Delete" or "Save"), hex(id(element)))) - super(UOWDumper, self).execute_element_childtasks(trans, element, isdelete) - self.closeheader() - def _dump_processor(self, proc, deletes): if deletes: val = proc.targettask.polymorphic_todelete_elements else: val = proc.targettask.polymorphic_tosave_elements - if self.verbose: - self.buf.write(self._indent() + " +- %s attribute on %s (UOWDependencyProcessor(%d) processing %s)\n" % ( - repr(proc.processor.key), - ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")), - hex(id(proc)), - self._repr_task(proc.targettask)) - ) - elif False: - self.buf.write(self._indent() + " +- %s attribute on %s\n" % ( - repr(proc.processor.key), - ("%s's to be %s" % (self._repr_task_class(proc.targettask), deletes and "deleted" or "saved")), - ) - ) - - if len(val) == 0: - if self.verbose: - self.buf.write(self._indent() + " +- " + "(no objects)\n") for v in val: self.buf.write(self._indent() + " +- " + self._repr_task_element(v, proc.processor.key, process=True) + "\n") @@ -155,9 +88,7 @@ class UOWDumper(unitofwork.UOWExecutor): objid = "%s.%s" % (mapperutil.state_str(te.state), attribute) else: objid = mapperutil.state_str(te.state) - if self.verbose: - return "%s (UOWTaskElement(%s, %s))" % (objid, hex(id(te)), (te.listonly and 'listonly' or (te.isdelete and 'delete' or 'save'))) - elif process: + if process: return "Process %s" % (objid) else: return "%s %s" % ((te.isdelete and "Delete" or "Save"), objid) diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 19e5e59b9..09b5aa778 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -4,15 +4,19 @@ # This module is part of SQLAlchemy and is released under # the MIT License: 
http://www.opensource.org/licenses/mit-license.php -from sqlalchemy import sql, util, exceptions -from sqlalchemy.sql import util as sql_util -from sqlalchemy.sql.util import row_adapter as create_row_adapter -from sqlalchemy.sql import visitors -from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator +import new -all_cascades = util.Set(["delete", "delete-orphan", "all", "merge", +import sqlalchemy.exceptions as sa_exc +from sqlalchemy import sql, util +from sqlalchemy.sql import expression, util as sql_util, operators +from sqlalchemy.orm.interfaces import MapperExtension, EXT_CONTINUE, PropComparator, MapperProperty +from sqlalchemy.orm import attributes + +all_cascades = util.FrozenSet(["delete", "delete-orphan", "all", "merge", "expunge", "save-update", "refresh-expire", "none"]) +_INSTRUMENTOR = ('mapper', 'instrumentor') + class CascadeOptions(object): """Keeps track of the options sent to relation().cascade""" @@ -26,7 +30,7 @@ class CascadeOptions(object): self.refresh_expire = "refresh-expire" in values or "all" in values for x in values: if x not in all_cascades: - raise exceptions.ArgumentError("Invalid cascade option '%s'" % x) + raise sa_exc.ArgumentError("Invalid cascade option '%s'" % x) def __contains__(self, item): return getattr(self, item.replace("-", "_"), False) @@ -78,235 +82,277 @@ def polymorphic_union(table_map, typecolname, aliasname='p_union'): result.append(sql.select([col(name, table) for name in colnames], from_obj=[table])) return sql.union_all(*result).alias(aliasname) +def identity_key(*args, **kwargs): + """Get an identity key. -class ExtensionCarrier(object): - """stores a collection of MapperExtension objects. - - allows an extension methods to be called on contained MapperExtensions - in the order they were added to this object. Also includes a 'methods' dictionary - accessor which allows for a quick check if a particular method - is overridden on any contained MapperExtensions. 
+ Valid call signatures: + + * ``identity_key(class, ident, entity_name=None)`` + + class + mapped class (must be a positional argument) + + ident + primary key, if the key is composite this is a tuple + + entity_name + optional entity name + + * ``identity_key(instance=instance)`` + + instance + object instance (must be given as a keyword arg) + + * ``identity_key(class, row=row, entity_name=None)`` + + class + mapped class (must be a positional argument) + + row + result proxy row (must be given as a keyword arg) + + entity_name + optional entity name (must be given as a keyword arg) """ + from sqlalchemy.orm import class_mapper, object_mapper + if args: + if len(args) == 1: + class_ = args[0] + try: + row = kwargs.pop("row") + except KeyError: + ident = kwargs.pop("ident") + entity_name = kwargs.pop("entity_name", None) + elif len(args) == 2: + class_, ident = args + entity_name = kwargs.pop("entity_name", None) + elif len(args) == 3: + class_, ident, entity_name = args + else: + raise sa_exc.ArgumentError("expected up to three " + "positional arguments, got %s" % len(args)) + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys())) + mapper = class_mapper(class_, entity_name=entity_name) + if "ident" in locals(): + return mapper.identity_key_from_primary_key(ident) + return mapper.identity_key_from_row(row) + instance = kwargs.pop("instance") + if kwargs: + raise sa_exc.ArgumentError("unknown keyword arguments: %s" + % ", ".join(kwargs.keys())) + mapper = object_mapper(instance) + return mapper.identity_key_from_instance(instance) - def __init__(self, _elements=None): +class ExtensionCarrier(object): + """Fronts an ordered collection of MapperExtension objects. + + Bundles multiple MapperExtensions into a unified callable unit, + encapsulating ordering, looping and EXT_CONTINUE logic. The + ExtensionCarrier implements the MapperExtension interface, e.g.:: + + carrier.after_insert(...args...) 
+ + Also includes a 'methods' dictionary accessor which allows for a quick + check if a particular method is overridden on any contained + MapperExtensions. + + """ + + interface = util.Set([method for method in dir(MapperExtension) + if not method.startswith('_')]) + + def __init__(self, extensions=None): self.methods = {} - if _elements is not None: - self.__elements = [self.__inspect(e) for e in _elements] - else: - self.__elements = [] - - def copy(self): - return ExtensionCarrier(list(self.__elements)) - - def __iter__(self): - return iter(self.__elements) + self._extensions = [] + for ext in extensions or (): + self.append(ext) - def insert(self, extension): - """Insert a MapperExtension at the beginning of this ExtensionCarrier's list.""" + def copy(self): + return ExtensionCarrier(self._extensions) - self.__elements.insert(0, self.__inspect(extension)) + def push(self, extension): + """Insert a MapperExtension at the beginning of the collection.""" + self._register(extension) + self._extensions.insert(0, extension) def append(self, extension): - """Append a MapperExtension at the end of this ExtensionCarrier's list.""" + """Append a MapperExtension at the end of the collection.""" + self._register(extension) + self._extensions.append(extension) - self.__elements.append(self.__inspect(extension)) + def __iter__(self): + """Iterate over MapperExtensions in the collection.""" + return iter(self._extensions) + + def _register(self, extension): + """Register callable fronts for overridden interface methods.""" + for method in self.interface: + if method in self.methods: + continue + impl = getattr(extension, method, None) + if impl and impl is not getattr(MapperExtension, method): + self.methods[method] = self._create_do(method) + + def _create_do(self, method): + """Return a closure that loops over impls of the named method.""" - def __inspect(self, extension): - for meth in MapperExtension.__dict__.keys(): - if meth not in self.methods and hasattr(extension, 
meth) and getattr(extension, meth) is not getattr(MapperExtension, meth): - self.methods[meth] = self.__create_do(meth) - return extension - - def __create_do(self, funcname): def _do(*args, **kwargs): - for elem in self.__elements: - ret = getattr(elem, funcname)(*args, **kwargs) + for ext in self._extensions: + ret = getattr(ext, method)(*args, **kwargs) if ret is not EXT_CONTINUE: return ret else: return EXT_CONTINUE - try: - _do.__name__ = funcname + _do.__name__ = method.im_func.func_name except: - # cant set __name__ in py 2.3 pass return _do - - def _pass(self, *args, **kwargs): + + def _pass(*args, **kwargs): return EXT_CONTINUE - + _pass = staticmethod(_pass) + def __getattr__(self, key): + """Delegate MapperExtension methods to bundled fronts.""" + if key not in self.interface: + raise AttributeError(key) return self.methods.get(key, self._pass) -class AliasedClauses(object): - """Creates aliases of a mapped tables for usage in ORM queries, and provides expression adaptation.""" - - def __init__(self, alias, equivalents=None, chain_to=None, should_adapt=True): - self.alias = alias - self.equivalents = equivalents - self.row_decorator = self._create_row_adapter() - self.should_adapt = should_adapt - if should_adapt: - self.adapter = sql_util.ClauseAdapter(self.alias, equivalents=equivalents) +class ORMAdapter(sql_util.ColumnAdapter): + def __init__(self, entity, equivalents=None, chain_to=None): + mapper, selectable, is_aliased_class = _entity_info(entity) + if is_aliased_class: + self.aliased_class = entity else: - self.adapter = visitors.NullVisitor() - - if chain_to: - self.adapter.chain(chain_to.adapter) - - def aliased_column(self, column): - if not self.should_adapt: - return column - - conv = self.alias.corresponding_column(column) - if conv: - return conv - - # process column-level subqueries - aliased_column = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).traverse(column, clone=True) - - # anonymize labels which might have 
specific names - if isinstance(aliased_column, expression._Label): - aliased_column = aliased_column.label(None) - - # add to row decorator explicitly - self.row_decorator({}).map[column] = aliased_column - return aliased_column - - def adapt_clause(self, clause): - return self.adapter.traverse(clause, clone=True) - - def adapt_list(self, clauses): - return self.adapter.copy_and_process(clauses) - - def _create_row_adapter(self): - return create_row_adapter(self.alias, equivalent_columns=self.equivalents) + self.aliased_class = None + sql_util.ColumnAdapter.__init__(self, selectable, equivalents, chain_to) +class AliasedClass(object): + def __init__(self, cls, alias=None, name=None): + self.__mapper = _class_to_mapper(cls) + self.__target = self.__mapper.class_ + alias = alias or self.__mapper._with_polymorphic_selectable.alias() + self.__adapter = sql_util.ClauseAdapter(alias, equivalents=self.__mapper._equivalent_columns) + self.__alias = alias + self._sa_label_name = name + self.__name__ = 'AliasedClass_' + str(self.__target) + + def __adapt_prop(self, prop): + existing = getattr(self.__target, prop.key) + comparator = AliasedComparator(self, self.__adapter, existing.comparator) + queryattr = attributes.QueryableAttribute( + existing.impl, parententity=self, comparator=comparator) + setattr(self, prop.key, queryattr) + return queryattr -class PropertyAliasedClauses(AliasedClauses): - """extends AliasedClauses to add support for primary/secondary joins on a relation().""" - - def __init__(self, prop, primaryjoin, secondaryjoin, parentclauses=None, alias=None, should_adapt=True): - self.prop = prop - self.mapper = self.prop.mapper - self.table = self.prop.table - self.parentclauses = parentclauses - - if not alias: - from_obj = self.mapper._with_polymorphic_selectable() - alias = from_obj.alias() - - super(PropertyAliasedClauses, self).__init__(alias, equivalents=self.mapper._equivalent_columns, chain_to=parentclauses, should_adapt=should_adapt) - - if 
prop.secondary: - self.secondary = prop.secondary.alias() - primary_aliasizer = sql_util.ClauseAdapter(self.secondary) - secondary_aliasizer = sql_util.ClauseAdapter(self.alias, equivalents=self.equivalents).chain(sql_util.ClauseAdapter(self.secondary)) - - if parentclauses is not None: - primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, equivalents=parentclauses.equivalents)) - - self.secondaryjoin = secondary_aliasizer.traverse(secondaryjoin, clone=True) - self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True) - else: - primary_aliasizer = sql_util.ClauseAdapter(self.alias, exclude=prop.local_side, equivalents=self.equivalents) - if parentclauses is not None: - primary_aliasizer.chain(sql_util.ClauseAdapter(parentclauses.alias, exclude=prop.remote_side, equivalents=parentclauses.equivalents)) - - self.primaryjoin = primary_aliasizer.traverse(primaryjoin, clone=True) - self.secondary = None - self.secondaryjoin = None - - if prop.order_by: - if prop.secondary: - # usually this is not used but occasionally someone has a sort key in their secondary - # table, even tho SA does not support writing this column directly - self.order_by = secondary_aliasizer.copy_and_process(util.to_list(prop.order_by)) + def __getattr__(self, key): + prop = self.__mapper._get_property(key, raiseerr=False) + if prop: + return self.__adapt_prop(prop) + + for base in self.__target.__mro__: + try: + attr = object.__getattribute__(base, key) + except AttributeError: + continue else: - self.order_by = primary_aliasizer.copy_and_process(util.to_list(prop.order_by)) - + break else: - self.order_by = None + raise AttributeError(key) -class AliasedClass(object): - def __new__(cls, target): - from sqlalchemy.orm import attributes - mapper = _class_to_mapper(target) - alias = mapper.mapped_table.alias() - retcls = type(target.__name__ + "Alias", (cls,), {'alias':alias}) - retcls._class_state = mapper._class_state - for prop in mapper.iterate_properties: - existing 
= mapper._class_state.attrs[prop.key] - setattr(retcls, prop.key, attributes.InstrumentedAttribute(existing.impl, comparator=AliasedComparator(alias, existing.comparator))) - - return retcls + if hasattr(attr, 'func_code'): + is_method = getattr(self.__target, key, None) + if is_method and is_method.im_self is not None: + return new.instancemethod(attr.im_func, self, self) + else: + return None + elif hasattr(attr, '__get__'): + return attr.__get__(None, self) + else: + return attr - def __init__(self, alias): - self.alias = alias + def __repr__(self): + return '<AliasedClass at 0x%x; %s>' % ( + id(self), self.__target.__name__) class AliasedComparator(PropComparator): - def __init__(self, alias, comparator): - self.alias = alias + def __init__(self, aliasedclass, adapter, comparator): + self.aliasedclass = aliasedclass self.comparator = comparator - self.adapter = sql_util.ClauseAdapter(alias) + self.adapter = adapter + self.__clause_element = self.adapter.traverse(self.comparator.__clause_element__())._annotate({'parententity': aliasedclass}) - def clause_element(self): - return self.adapter.traverse(self.comparator.clause_element(), clone=True) + def __clause_element__(self): + return self.__clause_element def operate(self, op, *other, **kwargs): - return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs), clone=True) + return self.adapter.traverse(self.comparator.operate(op, *other, **kwargs)) def reverse_operate(self, op, other, **kwargs): - return self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs), clone=True) - -from sqlalchemy.sql import expression -_selectable = expression._selectable -def _orm_selectable(selectable): - if _is_mapped_class(selectable): - if _is_aliased_class(selectable): - return selectable.alias - else: - return _class_to_mapper(selectable)._with_polymorphic_selectable() - else: - return _selectable(selectable) -expression._selectable = _orm_selectable + return 
self.adapter.traverse(self.comparator.reverse_operate(op, *other, **kwargs)) + +def _orm_annotate(element, exclude=None): + def clone(elem): + if exclude and elem in exclude: + elem = elem._clone() + elif '_orm_adapt' not in elem._annotations: + elem = elem._annotate({'_orm_adapt':True}) + elem._copy_internals(clone=clone) + return elem + + if element is not None: + element = clone(element) + return element + class _ORMJoin(expression.Join): - """future functionality.""" __visit_name__ = expression.Join.__visit_name__ - + def __init__(self, left, right, onclause=None, isouter=False): - if _is_mapped_class(left) or _is_mapped_class(right): - if hasattr(left, '_orm_mappers'): - left_mapper = left._orm_mappers[1] - adapt_from = left.right + if hasattr(left, '_orm_mappers'): + left_mapper = left._orm_mappers[1] + adapt_from = left.right + + else: + left_mapper, left, left_is_aliased = _entity_info(left) + if left_is_aliased or not left_mapper: + adapt_from = left else: - left_mapper = _class_to_mapper(left) - if _is_aliased_class(left): - adapt_from = left.alias - else: - adapt_from = None + adapt_from = None - right_mapper = _class_to_mapper(right) + right_mapper, right, right_is_aliased = _entity_info(right) + if right_is_aliased: + adapt_to = right + else: + adapt_to = None + + if left_mapper or right_mapper: self._orm_mappers = (left_mapper, right_mapper) - + if isinstance(onclause, basestring): prop = left_mapper.get_property(onclause) + elif isinstance(onclause, attributes.QueryableAttribute): + adapt_from = onclause.__clause_element__() + prop = onclause.property + elif isinstance(onclause, MapperProperty): + prop = onclause + else: + prop = None - if _is_aliased_class(right): - adapt_to = right.alias - else: - adapt_to = None - - pj, sj, source, dest, target_adapter = prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True) + if prop: + pj, sj, source, dest, secondary, target_adapter = 
prop._create_joins(source_selectable=adapt_from, dest_selectable=adapt_to, source_polymorphic=True, dest_polymorphic=True) if sj: - left = sql.join(left, prop.secondary, onclause=pj) + left = sql.join(left, secondary, pj, isouter) onclause = sj else: onclause = pj + self._target_adapter = target_adapter + expression.Join.__init__(self, left, right, onclause, isouter) def join(self, right, onclause=None, isouter=False): @@ -315,37 +361,81 @@ class _ORMJoin(expression.Join): def outerjoin(self, right, onclause=None): return _ORMJoin(self, right, onclause, True) -def _join(left, right, onclause=None): - """future functionality.""" - - return _ORMJoin(left, right, onclause, False) - -def _outerjoin(left, right, onclause=None): - """future functionality.""" +def join(left, right, onclause=None, isouter=False): + return _ORMJoin(left, right, onclause, isouter) +def outerjoin(left, right, onclause=None): return _ORMJoin(left, right, onclause, True) - -def has_identity(object): - return hasattr(object, '_instance_key') -def _state_has_identity(state): - return '_instance_key' in state.dict +def with_parent(instance, prop): + """Return criterion which selects instances with a given parent. -def _is_mapped_class(cls): - return hasattr(cls, '_class_state') + instance + a parent instance, which should be persistent or detached. + + property + a class-attached descriptor, MapperProperty or string property name + attached to the parent instance. + + \**kwargs + all extra keyword arguments are propagated to the constructor of + Query. -def _is_aliased_class(obj): - return isinstance(obj, type) and issubclass(obj, AliasedClass) - -def has_mapper(object): - """Return True if the given object has had a mapper association - set up, either through loading, or via insertion in a session. 
""" + if isinstance(prop, basestring): + mapper = object_mapper(instance) + prop = mapper.get_property(prop, resolve_synonyms=True) + elif isinstance(prop, attributes.QueryableAttribute): + prop = prop.property + + return prop.compare(operators.eq, instance, value_is_parent=True) + + +def _entity_info(entity, entity_name=None, compile=True): + if isinstance(entity, AliasedClass): + return entity._AliasedClass__mapper, entity._AliasedClass__alias, True + elif _is_mapped_class(entity): + if isinstance(entity, type): + mapper = class_mapper(entity, entity_name, compile) + else: + if compile: + mapper = entity.compile() + else: + mapper = entity + return mapper, mapper._with_polymorphic_selectable, False + else: + return None, entity, False + +def _entity_descriptor(entity, key): + if isinstance(entity, AliasedClass): + desc = getattr(entity, key) + return desc, desc.property + elif isinstance(entity, type): + desc = attributes.manager_of_class(entity)[key] + return desc, desc.property + else: + desc = entity.class_manager[key] + return desc, desc.property + +def _orm_columns(entity): + mapper, selectable, is_aliased_class = _entity_info(entity) + if isinstance(selectable, expression.Selectable): + return [c for c in selectable.c] + else: + return [selectable] + +def _orm_selectable(entity): + mapper, selectable, is_aliased_class = _entity_info(entity) + return selectable - return hasattr(object, '_entity_name') +def _is_aliased_class(entity): + return isinstance(entity, AliasedClass) def _state_mapper(state, entity_name=None): - return state.class_._class_state.mappers[state.dict.get('_entity_name', entity_name)] + if state.entity_name is not attributes.NO_ENTITY_NAME: + # Override the given entity name if the object is not transient. + entity_name = state.entity_name + return state.manager.mappers[entity_name] def object_mapper(object, entity_name=None, raiseerror=True): """Given an object, return the primary Mapper associated with the object instance. 
@@ -363,36 +453,40 @@ def object_mapper(object, entity_name=None, raiseerror=True): be located. If False, return None. """ - - try: - mapper = object.__class__._class_state.mappers[getattr(object, '_entity_name', entity_name)] - except (KeyError, AttributeError): - if raiseerror: - raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (object.__class__.__name__, getattr(object, '_entity_name', entity_name))) - else: - return None - return mapper + state = attributes.instance_state(object) + if state.entity_name is not attributes.NO_ENTITY_NAME: + # Override the given entity name if the object is not transient. + entity_name = state.entity_name + return class_mapper( + type(object), entity_name=entity_name, + compile=False, raiseerror=raiseerror) def class_mapper(class_, entity_name=None, compile=True, raiseerror=True): - """Given a class and optional entity_name, return the primary Mapper associated with the key. + """Given a class (or an object) and optional entity_name, return the primary Mapper associated with the key. If no mapper can be located, raises ``InvalidRequestError``. 
- """ + """ + + if not isinstance(class_, type): + class_ = type(class_) try: - mapper = class_._class_state.mappers[entity_name] + class_manager = attributes.manager_of_class(class_) + mapper = class_manager.mappers[entity_name] except (KeyError, AttributeError): - if raiseerror: - raise exceptions.InvalidRequestError("Class '%s' entity name '%s' has no mapper associated with it" % (class_.__name__, entity_name)) - else: - return None + if not raiseerror: + return + raise sa_exc.InvalidRequestError( + "Class '%s' entity name '%s' has no mapper associated with it" % + (class_.__name__, entity_name)) if compile: - return mapper.compile() - else: - return mapper + mapper = mapper.compile() + return mapper def _class_to_mapper(class_or_mapper, entity_name=None, compile=True): - if isinstance(class_or_mapper, type): + if _is_aliased_class(class_or_mapper): + return class_or_mapper._AliasedClass__mapper + elif isinstance(class_or_mapper, type): return class_mapper(class_or_mapper, entity_name=entity_name, compile=compile) else: if compile: @@ -400,10 +494,32 @@ def _class_to_mapper(class_or_mapper, entity_name=None, compile=True): else: return class_or_mapper +def has_identity(object): + state = attributes.instance_state(object) + return _state_has_identity(state) + +def _state_has_identity(state): + return bool(state.key) + +def has_mapper(object): + state = attributes.instance_state(object) + return _state_has_mapper(state) + +def _state_has_mapper(state): + return state.entity_name is not attributes.NO_ENTITY_NAME + +def _is_mapped_class(cls): + from sqlalchemy.orm import mapperlib as mapper + if isinstance(cls, (AliasedClass, mapper.Mapper)): + return True + + manager = attributes.manager_of_class(cls) + return manager and _INSTRUMENTOR in manager.info + def instance_str(instance): """Return a string describing an instance.""" - return instance.__class__.__name__ + "@" + hex(id(instance)) + return state_str(attributes.instance_state(instance)) def state_str(state): 
"""Return a string describing an instance.""" @@ -415,12 +531,24 @@ def state_str(state): def attribute_str(instance, attribute): return instance_str(instance) + "." + attribute +def state_attribute_str(state, attribute): + return state_str(state) + "." + attribute + def identity_equal(a, b): if a is b: return True - id_a = getattr(a, '_instance_key', None) - id_b = getattr(b, '_instance_key', None) - if id_a is None or id_b is None: + if a is None or b is None: + return False + try: + state_a = attributes.instance_state(a) + state_b = attributes.instance_state(b) + except (KeyError, AttributeError): + return False + if state_a.key is None or state_b.key is None: return False - return id_a == id_b + return state_a.key == state_b.key +# TODO: Avoid circular import. +attributes.identity_equal = identity_equal +attributes._is_aliased_class = _is_aliased_class +attributes._entity_info = _entity_info diff --git a/lib/sqlalchemy/pool.py b/lib/sqlalchemy/pool.py index 31adf77d1..c1b29a1d0 100644 --- a/lib/sqlalchemy/pool.py +++ b/lib/sqlalchemy/pool.py @@ -18,7 +18,7 @@ SQLAlchemy connection pool. import weakref, time -from sqlalchemy import exceptions, logging +from sqlalchemy import exc, log from sqlalchemy import queue as Queue from sqlalchemy.util import thread, threading, pickle, as_interface @@ -118,7 +118,7 @@ class Pool(object): """ def __init__(self, creator, recycle=-1, echo=None, use_threadlocal=True, reset_on_return=True, listeners=None): - self.logger = logging.instance_logger(self, echoflag=echo) + self.logger = log.instance_logger(self, echoflag=echo) # the WeakValueDictionary works more nicely than a regular dict of # weakrefs. the latter can pile up dead reference objects which don't # get cleaned out. 
WVD adds from 1-6 method calls to a checkout @@ -342,7 +342,7 @@ class _ConnectionFairy(object): return self._connection_record.info except AttributeError: if self.connection is None: - raise exceptions.InvalidRequestError("This connection is closed") + raise exc.InvalidRequestError("This connection is closed") try: return self._detached_info except AttributeError: @@ -359,7 +359,7 @@ class _ConnectionFairy(object): """ if self.connection is None: - raise exceptions.InvalidRequestError("This connection is closed") + raise exc.InvalidRequestError("This connection is closed") if self._connection_record is not None: self._connection_record.invalidate(e=e) self.connection = None @@ -378,8 +378,8 @@ class _ConnectionFairy(object): def checkout(self): if self.connection is None: - raise exceptions.InvalidRequestError("This connection is closed") - self.__counter +=1 + raise exc.InvalidRequestError("This connection is closed") + self.__counter += 1 if not self._pool._on_checkout or self.__counter != 1: return self @@ -391,7 +391,7 @@ class _ConnectionFairy(object): for l in self._pool._on_checkout: l.checkout(self.connection, self._connection_record, self) return self - except exceptions.DisconnectionError, e: + except exc.DisconnectionError, e: if self._pool._should_log_info: self._pool.log( "Disconnection detected on checkout: %s" % e) @@ -402,7 +402,7 @@ class _ConnectionFairy(object): if self._pool._should_log_info: self._pool.log("Reconnection attempts exhausted on checkout") self.invalidate() - raise exceptions.InvalidRequestError("This connection is closed") + raise exc.InvalidRequestError("This connection is closed") def detach(self): """Separate this connection from its Pool. 
@@ -426,7 +426,7 @@ class _ConnectionFairy(object): self._connection_record = None def close(self): - self.__counter -=1 + self.__counter -= 1 if self.__counter == 0: self._close() @@ -601,7 +601,7 @@ class QueuePool(Pool): if not wait: return self.do_get() else: - raise exceptions.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out, timeout %d" % (self.size(), self.overflow(), self._timeout)) + raise exc.TimeoutError("QueuePool limit of size %d overflow %d reached, connection timed out, timeout %d" % (self.size(), self.overflow(), self._timeout)) if self._overflow_lock is not None: self._overflow_lock.acquire() @@ -658,10 +658,10 @@ class NullPool(Pool): return "NullPool" def do_return_conn(self, conn): - conn.close() + conn.close() def do_return_invalid(self, conn): - pass + pass def do_get(self): return self.create_connection() diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 9a4bf4109..1f0b52ace 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -27,7 +27,7 @@ are part of the SQL expression language, they are usable as components in SQL expressions. """ import re, inspect -from sqlalchemy import types, exceptions, util, databases +from sqlalchemy import types, exc, util, databases from sqlalchemy.sql import expression, visitors URL = None @@ -42,10 +42,11 @@ class SchemaItem(object): """Base class for items that define a database schema.""" __metaclass__ = expression._FigureVisitName - + quote = None + def _init_items(self, *args): """Initialize the list of child items for this SchemaItem.""" - + for item in args: if item is not None: item._set_parent(self) @@ -95,7 +96,7 @@ class _TableSingleton(expression._FigureVisitName): try: table = metadata.tables[key] if not useexisting and table._cant_override(*args, **kwargs): - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "Table '%s' is already defined for this MetaData instance. 
" "Specify 'useexisting=True' to redefine options and " "columns on an existing Table object." % key) @@ -104,7 +105,7 @@ class _TableSingleton(expression._FigureVisitName): return table except KeyError: if mustexist: - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( "Table '%s' not defined" % (key)) try: return type.__call__(self, name, metadata, *args, **kwargs) @@ -182,17 +183,19 @@ class Table(SchemaItem, expression.TableClause): Deprecated; this is an oracle-only argument - "schema" should be used in its place. - quote - When True, indicates that the Table identifier must be quoted. - This flag does *not* disable quoting; for case-insensitive names, - use an all lower case identifier. + quote + Force quoting of the identifier on or off, based on `True` or + `False`. Defaults to `None`. This flag is rarely needed, + as quoting is normally applied + automatically for known reserved words, as well as for + "case sensitive" identifiers. An identifier is "case sensitive" + if it contains non-lowercase letters, otherwise it's + considered to be "case insensitive". quote_schema - When True, indicates that the schema identifier must be quoted. - This flag does *not* disable quoting; for case-insensitive names, - use an all lower case identifier. + same as 'quote' but applies to the schema identifier. 
+ """ - super(Table, self).__init__(name) self.metadata = metadata self.schema = kwargs.pop('schema', kwargs.pop('owner', None)) @@ -214,7 +217,7 @@ class Table(SchemaItem, expression.TableClause): self._set_parent(metadata) - self.__extra_kwargs(**kwargs) + self.__extra_kwargs(**kwargs) # load column definitions from the database if 'autoload' is defined # we do it after the table is in the singleton dictionary to support @@ -234,7 +237,7 @@ class Table(SchemaItem, expression.TableClause): autoload_with = kwargs.pop('autoload_with', None) schema = kwargs.pop('schema', None) if schema and schema != self.schema: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't change schema of existing table from '%s' to '%s'", (self.schema, schema)) @@ -258,8 +261,8 @@ class Table(SchemaItem, expression.TableClause): ['autoload', 'autoload_with', 'schema', 'owner'])) def __extra_kwargs(self, **kwargs): - self.quote = kwargs.pop('quote', False) - self.quote_schema = kwargs.pop('quote_schema', False) + self.quote = kwargs.pop('quote', None) + self.quote_schema = kwargs.pop('quote_schema', None) if kwargs.get('info'): self._info = kwargs.pop('info') @@ -488,9 +491,13 @@ class Column(SchemaItem, expression._ColumnClause): or subtype of Integer. quote - When True, indicates that the Column identifier must be quoted. - This flag does *not* disable quoting; for case-insensitive names, - use an all lower case identifier. + Force quoting of the identifier on or off, based on `True` or + `False`. Defaults to `None`. This flag is rarely needed, + as quoting is normally applied + automatically for known reserved words, as well as for + "case sensitive" identifiers. An identifier is "case sensitive" + if it contains non-lowercase letters, otherwise it's + considered to be "case insensitive". 
""" name = kwargs.pop('name', None) @@ -499,7 +506,7 @@ class Column(SchemaItem, expression._ColumnClause): args = list(args) if isinstance(args[0], basestring): if name is not None: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "May not pass name positionally and as a keyword.") name = args.pop(0) if args: @@ -507,7 +514,7 @@ class Column(SchemaItem, expression._ColumnClause): (isinstance(args[0], type) and issubclass(args[0], types.AbstractType))): if type_ is not None: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "May not pass type_ positionally and as a keyword.") type_ = args.pop(0) @@ -520,15 +527,17 @@ class Column(SchemaItem, expression._ColumnClause): self.default = kwargs.pop('default', None) self.index = kwargs.pop('index', None) self.unique = kwargs.pop('unique', None) - self.quote = kwargs.pop('quote', False) + self.quote = kwargs.pop('quote', None) self.onupdate = kwargs.pop('onupdate', None) self.autoincrement = kwargs.pop('autoincrement', True) self.constraints = util.Set() self.foreign_keys = util.OrderedSet() + util.set_creation_order(self) + if kwargs.get('info'): self._info = kwargs.pop('info') if kwargs: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Unknown arguments passed to Column: " + repr(kwargs.keys())) def __str__(self): @@ -545,7 +554,7 @@ class Column(SchemaItem, expression._ColumnClause): bind = property(bind) def references(self, column): - """Return True if this references the given column via a foreign key.""" + """Return True if this Column references the given column via foreign key.""" for fk in self.foreign_keys: if fk.references(column.table): return True @@ -576,14 +585,14 @@ class Column(SchemaItem, expression._ColumnClause): def _set_parent(self, table): if self.name is None: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Column must be constructed with a name or assign .name " "before adding to a Table.") if self.key is None: self.key = self.name self.metadata = 
table.metadata if getattr(self, 'table', None) is not None: - raise exceptions.ArgumentError("this Column already has a table!") + raise exc.ArgumentError("this Column already has a table!") if not self._is_oid: self._pre_existing_column = table._columns.get(self.key) @@ -594,7 +603,7 @@ class Column(SchemaItem, expression._ColumnClause): if self.primary_key: table.primary_key.replace(self) elif self.key in table.primary_key: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Trying to redefine primary-key column '%s' as a " "non-primary-key column on table '%s'" % ( self.key, table.fullname)) @@ -604,14 +613,14 @@ class Column(SchemaItem, expression._ColumnClause): if self.index: if isinstance(self.index, basestring): - raise exceptions.ArgumentError( + raise exc.ArgumentError( "The 'index' keyword argument on Column is boolean only. " "To create indexes with a specific name, create an " "explicit Index object external to the Table.") Index('ix_%s' % self._label, self, unique=self.unique) elif self.unique: if isinstance(self.unique, basestring): - raise exceptions.ArgumentError( + raise exc.ArgumentError( "The 'unique' keyword argument on Column is boolean only. " "To create unique constraints or indexes with a specific " "name, append an explicit UniqueConstraint to the Table's " @@ -631,17 +640,17 @@ class Column(SchemaItem, expression._ColumnClause): """Create a copy of this ``Column``, unitialized. This is used in ``Table.tometadata``. - """ + """ return Column(self.name, self.type, self.default, key = self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, index=self.index, autoincrement=self.autoincrement, *[c.copy() for c in self.constraints]) - - def _make_proxy(self, selectable, name = None): + + def _make_proxy(self, selectable, name=None): """Create a *proxy* for this column. This is a copy of this ``Column`` referenced by a different parent (such as an alias or select statement). 
- """ + """ fk = [ForeignKey(f._colspec) for f in self.foreign_keys] c = Column(name or self.name, self.type, self.default, key = name or self.key, primary_key = self.primary_key, nullable = self.nullable, _is_oid = self._is_oid, quote=self.quote, *fk) c.table = selectable @@ -654,7 +663,6 @@ class Column(SchemaItem, expression._ColumnClause): [c._init_items(f) for f in fk] return c - def get_children(self, schema_visitor=False, **kwargs): if schema_visitor: return [x for x in (self.default, self.onupdate) if x is not None] + \ @@ -670,8 +678,8 @@ class ForeignKey(SchemaItem): For a composite (multiple column) FOREIGN KEY, use a ForeignKeyConstraint within the Table definition. - """ + """ def __init__(self, column, constraint=None, use_alter=False, name=None, onupdate=None, ondelete=None, deferrable=None, initially=None): """Construct a column-level FOREIGN KEY. @@ -742,14 +750,15 @@ class ForeignKey(SchemaItem): def references(self, table): """Return True if the given table is referenced by this ForeignKey.""" - return table.corresponding_column(self.column) is not None def get_referent(self, table): """Return the column in the given table referenced by this ForeignKey. Returns None if this ``ForeignKey`` does not reference the given table. 
+ """ + return table.corresponding_column(self.column) def column(self): @@ -766,22 +775,22 @@ class ForeignKey(SchemaItem): parenttable = c.table break else: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Parent column '%s' does not descend from a " "table-attached Column" % str(self.parent)) m = re.match(r"^(.+?)(?:\.(.+?))?(?:\.(.+?))?$", self._colspec, re.UNICODE) if m is None: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Invalid foreign key column specification: %s" % self._colspec) if m.group(3) is None: (tname, colname) = m.group(1, 2) schema = None else: - (schema,tname,colname) = m.group(1,2,3) + (schema, tname, colname) = m.group(1, 2, 3) if _get_table_key(tname, schema) not in parenttable.metadata: - raise exceptions.NoReferencedTableError( + raise exc.NoReferencedTableError( "Could not find table '%s' with which to generate a " "foreign key" % tname) table = Table(tname, parenttable.metadata, @@ -797,13 +806,13 @@ class ForeignKey(SchemaItem): else: self._column = table.c[colname] except KeyError, e: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Could not create ForeignKey '%s' on table '%s': " "table '%s' has no column named '%s'" % ( self._colspec, parenttable.name, table.name, str(e))) - - elif isinstance(self._colspec, expression.Operators): - self._column = self._colspec.clause_element() + + elif hasattr(self._colspec, '__clause_element__'): + self._column = self._colspec.__clause_element__() else: self._column = self._colspec @@ -906,12 +915,11 @@ class ColumnDefault(DefaultGenerator): defaulted = argspec[3] is not None and len(argspec[3]) or 0 if positionals - defaulted > 1: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "ColumnDefault Python function takes zero or one " "positional arguments") return fn - def _visit_name(self): if self.for_update: return "column_onupdate" @@ -926,12 +934,12 @@ class Sequence(DefaultGenerator): """Represents a named database sequence.""" def 
__init__(self, name, start=None, increment=None, schema=None, - optional=False, quote=False, **kwargs): + optional=False, quote=None, **kwargs): super(Sequence, self).__init__(**kwargs) self.name = name self.start = start self.increment = increment - self.optional=optional + self.optional = optional self.quote = quote self.schema = schema self.kwargs = kwargs @@ -960,7 +968,6 @@ class Sequence(DefaultGenerator): bind = _bind_or_error(self) bind.drop(self, checkfirst=checkfirst) - class Constraint(SchemaItem): """A table-level SQL constraint, such as a KEY. @@ -989,8 +996,11 @@ class Constraint(SchemaItem): self.initially = initially def __contains__(self, x): - return self.columns.contains_column(x) - + return x in self.columns + + def contains_column(self, col): + return self.columns.contains_column(col) + def keys(self): return self.columns.keys() @@ -1105,7 +1115,7 @@ class ForeignKeyConstraint(Constraint): self.onupdate = onupdate self.ondelete = ondelete if self.name is None and use_alter: - raise exceptions.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name") + raise exc.ArgumentError("Alterable ForeignKey/ForeignKeyConstraint requires a name") self.use_alter = use_alter def _set_parent(self, table): @@ -1113,7 +1123,7 @@ class ForeignKeyConstraint(Constraint): if self not in table.constraints: table.constraints.add(self) for (c, r) in zip(self.__colnames, self.__refcolnames): - self.append_element(c,r) + self.append_element(c, r) def append_element(self, col, refcol): fk = ForeignKey(refcol, constraint=self, name=self.name, onupdate=self.onupdate, ondelete=self.ondelete, use_alter=self.use_alter) @@ -1159,7 +1169,7 @@ class PrimaryKeyConstraint(Constraint): deferrable=kwargs.pop('deferrable', None), initially=kwargs.pop('initially', None)) if kwargs: - raise exceptions.ArgumentError( + raise exc.ArgumentError( 'Unknown PrimaryKeyConstraint argument(s): %s' % ', '.join([repr(x) for x in kwargs.keys()])) @@ -1174,14 +1184,14 @@ class 
PrimaryKeyConstraint(Constraint): def add(self, col): self.columns.add(col) - col.primary_key=True + col.primary_key = True append_column = add def replace(self, col): self.columns.replace(col) def remove(self, col): - col.primary_key=False + col.primary_key = False del self.columns[col.key] def copy(self): @@ -1222,7 +1232,7 @@ class UniqueConstraint(Constraint): deferrable=kwargs.pop('deferrable', None), initially=kwargs.pop('initially', None)) if kwargs: - raise exceptions.ArgumentError( + raise exc.ArgumentError( 'Unknown UniqueConstraint argument(s): %s' % ', '.join([repr(x) for x in kwargs.keys()])) @@ -1295,11 +1305,11 @@ class Index(SchemaItem): self._set_parent(column.table) elif column.table != self.table: # all columns muse be from same table - raise exceptions.ArgumentError( + raise exc.ArgumentError( "All index columns must be from same table. " "%s is from %s not %s" % (column, column.table, self.table)) elif column.name in [ c.name for c in self.columns ]: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "A column may not appear twice in the " "same index (%s already has column %s)" % (self.name, column)) self.columns.append(column) @@ -1370,7 +1380,7 @@ class MetaData(SchemaItem): self.ddl_listeners = util.defaultdict(list) if reflect: if not bind: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "A bind must be supplied in conjunction with reflect=True") self.reflect() @@ -1508,7 +1518,7 @@ class MetaData(SchemaItem): missing = [name for name in only if name not in available] if missing: s = schema and (" schema '%s'" % schema) or '' - raise exceptions.InvalidRequestError( + raise exc.InvalidRequestError( 'Could not reflect: requested table(s) not available ' 'in %s%s: (%s)' % (bind.engine.url, s, ', '.join(missing))) load = [name for name in only if name not in current] @@ -1777,12 +1787,12 @@ class DDL(object): """ if not isinstance(statement, basestring): - raise exceptions.ArgumentError( + raise exc.ArgumentError( 
"Expected a string or unicode SQL statement, got '%r'" % statement) if (on is not None and (not isinstance(on, basestring) and not callable(on))): - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Expected the name of a database dialect or a callable for " "'on' criteria, got type '%s'." % type(on).__name__) @@ -1858,10 +1868,10 @@ class DDL(object): """ if not hasattr(schema_item, 'ddl_listeners'): - raise exceptions.ArgumentError( + raise exc.ArgumentError( "%s does not support DDL events" % type(schema_item).__name__) if event not in schema_item.ddl_events: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Unknown event, expected one of (%s), got '%r'" % (', '.join(schema_item.ddl_events), event)) schema_item.ddl_listeners[event].append(self) @@ -1955,5 +1965,5 @@ def _bind_or_error(schemaitem): 'Execution can not proceed without a database to execute ' 'against. Either execute with an explicit connection or ' 'assign %s to enable implicit execution.') % (item, bindable) - raise exceptions.UnboundExecutionError(msg) + raise exc.UnboundExecutionError(msg) return bind diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index c966f396a..5ea9eb1e6 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -1,2 +1,2 @@ from sqlalchemy.sql.expression import * -from sqlalchemy.sql.visitors import ClauseVisitor, NoColumnVisitor +from sqlalchemy.sql.visitors import ClauseVisitor diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1fe9ef062..78bb4e31c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -19,7 +19,7 @@ is otherwise internal to SQLAlchemy. 
""" import string, re, itertools -from sqlalchemy import schema, engine, util, exceptions +from sqlalchemy import schema, engine, util, exc from sqlalchemy.sql import operators, functions from sqlalchemy.sql import expression as sql @@ -115,8 +115,6 @@ class DefaultCompiler(engine.Compiled): paradigm as visitors.ClauseVisitor but implements its own traversal. """ - __traverse_options__ = {'column_collections':False, 'entry':True} - operators = OPERATORS functions = FUNCTIONS @@ -162,17 +160,12 @@ class DefaultCompiler(engine.Compiled): # for aliases self.generated_ids = {} - # paramstyle from the dialect (comes from DB-API) - self.paramstyle = self.dialect.paramstyle - # true if the paramstyle is positional self.positional = self.dialect.positional + if self.positional: + self.positiontup = [] - self.bindtemplate = BIND_TEMPLATES[self.paramstyle] - - # a list of the compiled's bind parameter names, used to help - # formulate a positional argument list - self.positiontup = [] + self.bindtemplate = BIND_TEMPLATES[self.dialect.paramstyle] # an IdentifierPreparer that formats the quoting of identifiers self.preparer = self.dialect.identifier_preparer @@ -230,15 +223,18 @@ class DefaultCompiler(engine.Compiled): return "" def visit_grouping(self, grouping, **kwargs): - return "(" + self.process(grouping.elem) + ")" + return "(" + self.process(grouping.element) + ")" - def visit_label(self, label, result_map=None): + def visit_label(self, label, result_map=None, render_labels=False): + if not render_labels: + return self.process(label.element) + labelname = self._truncated_identifier("colident", label.name) if result_map is not None: - result_map[labelname.lower()] = (label.name, (label, label.obj, labelname), label.obj.type) + result_map[labelname.lower()] = (label.name, (label, label.element, labelname), label.element.type) - return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) + return " 
".join([self.process(label.element), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) def visit_column(self, column, result_map=None, **kwargs): @@ -261,16 +257,16 @@ class DefaultCompiler(engine.Compiled): if getattr(column, "is_literal", False): name = self.escape_literal_column(name) else: - name = self.preparer.quote(column, name) + name = self.preparer.quote(name, column.quote) if column.table is None or not column.table.named_with_column: return name else: if getattr(column.table, 'schema', None): - schema_prefix = self.preparer.quote(column.table, column.table.schema) + '.' + schema_prefix = self.preparer.quote(column.table.schema, column.table.quote_schema) + '.' else: schema_prefix = '' - return schema_prefix + self.preparer.quote(column.table, ANONYMOUS_LABEL.sub(self._process_anon, column.table.name)) + "." + name + return schema_prefix + self.preparer.quote(ANONYMOUS_LABEL.sub(self._process_anon, column.table.name), column.table.quote) + "." 
+ name def escape_literal_column(self, text): """provide escaping for the literal_column() construct.""" @@ -387,7 +383,7 @@ class DefaultCompiler(engine.Compiled): if name in self.binds: existing = self.binds[name] if existing is not bindparam and (existing.unique or bindparam.unique): - raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) + raise exc.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) self.binds[bindparam.key] = self.binds[name] = bindparam return self.bindparam_string(name) @@ -418,7 +414,7 @@ class DefaultCompiler(engine.Compiled): return truncname def _process_anon(self, match): - (ident, derived) = match.group(1,2) + (ident, derived) = match.group(1, 2) key = ('anonymous', ident) if key in self.generated_ids: @@ -436,8 +432,9 @@ class DefaultCompiler(engine.Compiled): def bindparam_string(self, name): if self.positional: self.positiontup.append(name) - - return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + return self.bindtemplate % {'name':name, 'position':len(self.positiontup)} + else: + return self.bindtemplate % {'name':name} def visit_alias(self, alias, asfrom=False, **kwargs): if asfrom: @@ -490,7 +487,7 @@ class DefaultCompiler(engine.Compiled): froms = select._get_display_froms(existingfroms) - correlate_froms = util.Set(itertools.chain(*([froms] + [f._get_from_objects() for f in froms]))) + correlate_froms = util.Set(sql._from_objects(*froms)) # TODO: might want to propigate existing froms for select(select(select)) # where innermost select should correlate to outermost @@ -504,6 +501,7 @@ class DefaultCompiler(engine.Compiled): [c for c in [ self.process( self.label_select_column(select, co, asfrom=asfrom), + render_labels=True, **column_clause_args) for co in select.inner_columns ] @@ -580,9 +578,9 @@ class DefaultCompiler(engine.Compiled): def visit_table(self, table, asfrom=False, 
**kwargs): if asfrom: if getattr(table, "schema", None): - return self.preparer.quote(table, table.schema) + "." + self.preparer.quote(table, table.name) + return self.preparer.quote(table.schema, table.quote_schema) + "." + self.preparer.quote(table.name, table.quote) else: - return self.preparer.quote(table, table.name) + return self.preparer.quote(table.name, table.quote) else: return "" @@ -603,7 +601,7 @@ class DefaultCompiler(engine.Compiled): return (insert + " INTO %s (%s) VALUES (%s)" % (preparer.format_table(insert_stmt.table), - ', '.join([preparer.quote(c[0], c[0].name) + ', '.join([preparer.quote(c[0].name, c[0].quote) for c in colparams]), ', '.join([c[1] for c in colparams]))) @@ -613,7 +611,7 @@ class DefaultCompiler(engine.Compiled): self.isupdate = True colparams = self._get_colparams(update_stmt) - text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0], c[0].name), c[1]) for c in colparams], ', ') + text = "UPDATE " + self.preparer.format_table(update_stmt.table) + " SET " + string.join(["%s=%s" % (self.preparer.quote(c[0].name, c[0].quote), c[1]) for c in colparams], ', ') if update_stmt._whereclause: text += " WHERE " + self.process(update_stmt._whereclause) @@ -837,7 +835,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join([self.preparer.quote(c, c.name) for c in constraint])) + self.append("(%s)" % ', '.join([self.preparer.quote(c.name, c.quote) for c in constraint])) self.define_constraint_deferrability(constraint) def visit_foreign_key_constraint(self, constraint): @@ -858,9 +856,9 @@ class SchemaGenerator(DDLBase): preparer.format_constraint(constraint)) table = list(constraint.elements)[0].column.table self.append("FOREIGN KEY(%s) REFERENCES %s (%s)" % ( - ', '.join([preparer.quote(f.parent, f.parent.name) for f in 
constraint.elements]), + ', '.join([preparer.quote(f.parent.name, f.parent.quote) for f in constraint.elements]), preparer.format_table(table), - ', '.join([preparer.quote(f.column, f.column.name) for f in constraint.elements]) + ', '.join([preparer.quote(f.column.name, f.column.quote) for f in constraint.elements]) )) if constraint.ondelete is not None: self.append(" ON DELETE %s" % constraint.ondelete) @@ -873,7 +871,7 @@ class SchemaGenerator(DDLBase): if constraint.name is not None: self.append("CONSTRAINT %s " % self.preparer.format_constraint(constraint)) - self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c, c.name) for c in constraint]))) + self.append(" UNIQUE (%s)" % (', '.join([self.preparer.quote(c.name, c.quote) for c in constraint]))) self.define_constraint_deferrability(constraint) def define_constraint_deferrability(self, constraint): @@ -896,7 +894,7 @@ class SchemaGenerator(DDLBase): self.append("INDEX %s ON %s (%s)" \ % (preparer.format_index(index), preparer.format_table(index.table), - string.join([preparer.quote(c, c.name) for c in index.columns], ', '))) + string.join([preparer.quote(c.name, c.quote) for c in index.columns], ', '))) self.execute() @@ -1005,9 +1003,12 @@ class IdentifierPreparer(object): or not self.legal_characters.match(unicode(value)) or (lc_value != value)) - def quote(self, obj, ident): - if getattr(obj, 'quote', False): + def quote(self, ident, force): + if force: return self.quote_identifier(ident) + elif force is False: + return ident + if ident in self.__strings: return self.__strings[ident] else: @@ -1017,53 +1018,47 @@ class IdentifierPreparer(object): self.__strings[ident] = ident return self.__strings[ident] - def should_quote(self, object): - return object.quote or self._requires_quotes(object.name) - def format_sequence(self, sequence, use_schema=True): - name = self.quote(sequence, sequence.name) + name = self.quote(sequence.name, sequence.quote) if not self.omit_schema and use_schema and 
sequence.schema is not None: - name = self.quote(sequence, sequence.schema) + "." + name + name = self.quote(sequence.schema, sequence.quote) + "." + name return name def format_label(self, label, name=None): - return self.quote(label, name or label.name) + return self.quote(name or label.name, label.quote) def format_alias(self, alias, name=None): - return self.quote(alias, name or alias.name) + return self.quote(name or alias.name, alias.quote) def format_savepoint(self, savepoint, name=None): - return self.quote(savepoint, name or savepoint.ident) + return self.quote(name or savepoint.ident, savepoint.quote) def format_constraint(self, constraint): - return self.quote(constraint, constraint.name) + return self.quote(constraint.name, constraint.quote) def format_index(self, index): - return self.quote(index, index.name) + return self.quote(index.name, index.quote) def format_table(self, table, use_schema=True, name=None): """Prepare a quoted table and schema name.""" if name is None: name = table.name - result = self.quote(table, name) + result = self.quote(name, table.quote) if not self.omit_schema and use_schema and getattr(table, "schema", None): - result = self.quote(table, table.schema) + "." + result + result = self.quote(table.schema, table.quote_schema) + "." + result return result def format_column(self, column, use_table=False, name=None, table_name=None): - """Prepare a quoted column name. - - deprecated. use preparer.quote(col, column.name) or combine with format_table() - """ + """Prepare a quoted column name.""" if name is None: name = column.name if not getattr(column, 'is_literal', False): if use_table: - return self.format_table(column.table, use_schema=False, name=table_name) + "." + self.quote(column, name) + return self.format_table(column.table, use_schema=False, name=table_name) + "." 
+ self.quote(name, column.quote) else: - return self.quote(column, name) + return self.quote(name, column.quote) else: # literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted if use_table: @@ -1079,7 +1074,7 @@ class IdentifierPreparer(object): # a longer sequence. if not self.omit_schema and use_schema and getattr(table, 'schema', None): - return (self.quote_identifier(table.schema), + return (self.quote(table.schema, table.quote_schema), self.format_table(table, use_schema=False)) else: return (self.format_table(table, use_schema=False), ) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 867fdd69c..7ce637701 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,12 +26,12 @@ to stay the same in future releases. """ import itertools, re -from sqlalchemy import util, exceptions +from sqlalchemy import util, exc from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes functions, schema, sql_util = None, None, None -DefaultDialect, ClauseAdapter = None, None +DefaultDialect, ClauseAdapter, Annotated = None, None, None __all__ = [ 'Alias', 'ClauseElement', @@ -503,15 +503,21 @@ def collate(expression, collation): def exists(*args, **kwargs): """Return an ``EXISTS`` clause as applied to a [sqlalchemy.sql.expression#Select] object. + + Calling styles are of the following forms:: + + # use on an existing select() + s = select([<columns>]).where(<criterion>) + s = exists(s) + + # construct a select() at once + exists(['*'], **select_arguments).where(<criterion>) + + # columns argument is optional, generates "EXISTS (SELECT *)" + # by default. + exists().where(<criterion>) - The resulting [sqlalchemy.sql.expression#_Exists] object can be executed by - itself or used as a subquery within an enclosing select. 
- - \*args, \**kwargs - all arguments are sent directly to the [sqlalchemy.sql.expression#select()] - function to produce a ``SELECT`` statement. """ - return _Exists(*args, **kwargs) def union(*selects, **kwargs): @@ -872,27 +878,36 @@ def _compound_select(keyword, *selects, **kwargs): return CompoundSelect(keyword, *selects, **kwargs) def _is_literal(element): - return not isinstance(element, ClauseElement) + return not isinstance(element, (ClauseElement, Operators)) + +def _from_objects(*elements, **kwargs): + return itertools.chain(*[element._get_from_objects(**kwargs) for element in elements]) +def _labeled(element): + if not hasattr(element, 'name'): + return element.label(None) + else: + return element + def _literal_as_text(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return _TextClause(unicode(element)) else: return element def _literal_as_column(element): - if isinstance(element, Operators): - return element.clause_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): return literal_column(str(element)) else: return element def _literal_as_binds(element, name=None, type_=None): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): if element is None: return null() @@ -902,17 +917,17 @@ def _literal_as_binds(element, name=None, type_=None): return element def _no_literals(element): - if isinstance(element, Operators): - return element.expression_element() + if hasattr(element, '__clause_element__'): + return element.__clause_element__() elif _is_literal(element): - raise exceptions.ArgumentError("Ambiguous literal: %r. 
Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) + raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' function to indicate a SQL expression literal, or 'literal()' to indicate a bound value." % element) else: return element def _corresponding_column_or_error(fromclause, column, require_embedded=False): c = fromclause.corresponding_column(column, require_embedded=require_embedded) if not c: - raise exceptions.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) + raise exc.InvalidRequestError("Given column '%s', attached to table '%s', failed to locate a corresponding column from table '%s'" % (str(column), str(getattr(column, 'table', None)), fromclause.description)) return c def _selectable(element): @@ -921,9 +936,8 @@ def _selectable(element): elif isinstance(element, Selectable): return element else: - raise exceptions.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) + raise exc.ArgumentError("Object '%s' is not a Selectable and does not implement `__selectable__()`" % repr(element)) - def is_column(col): """True if ``col`` is an instance of ``ColumnElement``.""" return isinstance(col, ColumnElement) @@ -941,7 +955,9 @@ class _FigureVisitName(type): class ClauseElement(object): """Base class for elements of a programmatically constructed SQL expression.""" __metaclass__ = _FigureVisitName - + _annotations = {} + supports_execution = False + def _clone(self): """Create a shallow copy of this ClauseElement. 
@@ -976,6 +992,14 @@ class ClauseElement(object): """ raise NotImplementedError(repr(self)) + + def _annotate(self, values): + """return a copy of this ClauseElement with the given annotations dictionary.""" + + global Annotated + if Annotated is None: + from sqlalchemy.sql.util import Annotated + return Annotated(self, values) def unique_params(self, *optionaldict, **kwargs): """Return a copy with ``bindparam()`` elments replaced. @@ -1006,14 +1030,14 @@ class ClauseElement(object): if len(optionaldict) == 1: kwargs.update(optionaldict[0]) elif len(optionaldict) > 1: - raise exceptions.ArgumentError("params() takes zero or one positional dictionary argument") + raise exc.ArgumentError("params() takes zero or one positional dictionary argument") def visit_bindparam(bind): if bind.key in kwargs: bind.value = kwargs[bind.key] if unique: bind._convert_to_unique() - return visitors.traverse(self, visit_bindparam=visit_bindparam, clone=True) + return visitors.cloned_traverse(self, {}, {'bindparam':visit_bindparam}) def compare(self, other): """Compare this ClauseElement to the given ClauseElement. @@ -1049,11 +1073,6 @@ class ClauseElement(object): def self_group(self, against=None): return self - def supports_execution(self): - """Return True if this clause element represents a complete executable statement.""" - - return False - def bind(self): """Returns the Engine or Connection to which this ClauseElement is bound, or None if none found.""" @@ -1062,7 +1081,7 @@ class ClauseElement(object): return self._bind except AttributeError: pass - for f in self._get_from_objects(): + for f in _from_objects(self): if f is self: continue engine = f.bind @@ -1083,7 +1102,7 @@ class ClauseElement(object): 'Engine for execution. Or, assign a bind to the statement ' 'or the Metadata of its underlying tables to enable ' 'implicit execution via this method.' 
% label) - raise exceptions.UnboundExecutionError(msg) + raise exc.UnboundExecutionError(msg) return e.execute_clauseelement(self, multiparams, params) def scalar(self, *multiparams, **params): @@ -1159,6 +1178,12 @@ class ClauseElement(object): self.__module__, self.__class__.__name__, id(self), friendly) +class _Immutable(object): + """mark a ClauseElement as 'immutable' when expressions are cloned.""" + + def _clone(self): + return self + class Operators(object): def __and__(self, other): return self.operate(operators.and_, other) @@ -1174,9 +1199,6 @@ class Operators(object): return self.operate(operators.op, opstring, b) return op - def clause_element(self): - raise NotImplementedError() - def operate(self, op, *other, **kwargs): raise NotImplementedError() @@ -1216,7 +1238,7 @@ class ColumnOperators(Operators): def ilike(self, other, escape=None): return self.operate(operators.ilike_op, other, escape=escape) - def in_(self, *other): + def in_(self, other): return self.operate(operators.in_op, other) def startswith(self, other, **kwargs): @@ -1279,18 +1301,18 @@ class _CompareMixin(ColumnOperators): def __compare(self, op, obj, negate=None, reverse=False, **kwargs): if obj is None or isinstance(obj, _Null): if op == operators.eq: - return _BinaryExpression(self.expression_element(), null(), operators.is_, negate=operators.isnot) + return _BinaryExpression(self, null(), operators.is_, negate=operators.isnot) elif op == operators.ne: - return _BinaryExpression(self.expression_element(), null(), operators.isnot, negate=operators.is_) + return _BinaryExpression(self, null(), operators.isnot, negate=operators.is_) else: - raise exceptions.ArgumentError("Only '='/'!=' operators can be used with NULL") + raise exc.ArgumentError("Only '='/'!=' operators can be used with NULL") else: obj = self._check_literal(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return 
_BinaryExpression(obj, self, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) else: - return _BinaryExpression(self.expression_element(), obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) + return _BinaryExpression(self, obj, op, type_=sqltypes.Boolean, negate=negate, modifiers=kwargs) def __operate(self, op, obj, reverse=False): obj = self._check_literal(obj) @@ -1298,9 +1320,9 @@ class _CompareMixin(ColumnOperators): type_ = self._compare_type(obj) if reverse: - return _BinaryExpression(obj, self.expression_element(), type_.adapt_operator(op), type_=type_) + return _BinaryExpression(obj, self, type_.adapt_operator(op), type_=type_) else: - return _BinaryExpression(self.expression_element(), obj, type_.adapt_operator(op), type_=type_) + return _BinaryExpression(self, obj, type_.adapt_operator(op), type_=type_) # a mapping of operators with the method they use, along with their negated # operator for comparison operators @@ -1329,17 +1351,10 @@ class _CompareMixin(ColumnOperators): o = _CompareMixin.operators[op] return o[0](self, op, other, reverse=True, *o[1:], **kwargs) - def in_(self, *other): - return self._in_impl(operators.in_op, operators.notin_op, *other) - - def _in_impl(self, op, negate_op, *other): - # Handle old style *args argument passing - if len(other) != 1 or not isinstance(other[0], Selectable) and (not hasattr(other[0], '__iter__') or isinstance(other[0], basestring)): - util.warn_deprecated('passing in_ arguments as varargs is deprecated, in_ takes a single argument that is a sequence or a selectable') - seq_or_selectable = other - else: - seq_or_selectable = other[0] + def in_(self, other): + return self._in_impl(operators.in_op, operators.notin_op, other) + def _in_impl(self, op, negate_op, seq_or_selectable): if isinstance(seq_or_selectable, Selectable): return self.__compare( op, seq_or_selectable, negate=negate_op) @@ -1348,7 +1363,7 @@ class _CompareMixin(ColumnOperators): for o in seq_or_selectable: if not 
_is_literal(o): if not isinstance( o, _CompareMixin): - raise exceptions.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) + raise exc.InvalidRequestError( "in() function accepts either a list of non-selectable values, or a selectable: "+repr(o) ) else: o = self._bind_param(o) args.append(o) @@ -1433,22 +1448,13 @@ class _CompareMixin(ColumnOperators): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): other.type = self.type return other - elif isinstance(other, Operators): - return other.expression_element() + elif hasattr(other, '__clause_element__'): + return other.__clause_element__() elif _is_literal(other): return self._bind_param(other) else: return other - def clause_element(self): - """Allow ``_CompareMixins`` to return the underlying ``ClauseElement``, for non-``ClauseElement`` ``_CompareMixins``.""" - return self - - def expression_element(self): - """Allow ``_CompareMixins`` to return the appropriate object to be used in expressions.""" - - return self - def _compare_type(self, obj): """Allow subclasses to override the type used in constructing ``_BinaryExpression`` objects. 
@@ -1480,23 +1486,22 @@ class ColumnElement(ClauseElement, _CompareMixin): primary_key = False foreign_keys = [] - + quote = None + def base_columns(self): - if hasattr(self, '_base_columns'): - return self._base_columns - self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) + if not hasattr(self, '_base_columns'): + self._base_columns = util.Set([c for c in self.proxy_set if not hasattr(c, 'proxies')]) return self._base_columns base_columns = property(base_columns) def proxy_set(self): - if hasattr(self, '_proxy_set'): - return self._proxy_set - s = util.Set([self]) - if hasattr(self, 'proxies'): - for c in self.proxies: - s = s.union(c.proxy_set) - self._proxy_set = s - return s + if not hasattr(self, '_proxy_set'): + s = util.Set([self]) + if hasattr(self, 'proxies'): + for c in self.proxies: + s.update(c.proxy_set) + self._proxy_set = s + return self._proxy_set proxy_set = property(proxy_set) def shares_lineage(self, othercolumn): @@ -1518,7 +1523,7 @@ class ColumnElement(ClauseElement, _CompareMixin): co = _ColumnClause(self.anon_label, selectable, type_=getattr(self, 'type', None)) co.proxies = [self] - selectable.columns[name]= co + selectable.columns[name] = co return co def anon_label(self): @@ -1613,7 +1618,7 @@ class ColumnCollection(util.OrderedProperties): def __contains__(self, other): if not isinstance(other, basestring): - raise exceptions.ArgumentError("__contains__ requires a string argument") + raise exc.ArgumentError("__contains__ requires a string argument") return util.OrderedProperties.__contains__(self, other) def contains_column(self, col): @@ -1641,6 +1646,9 @@ class ColumnSet(util.OrderedSet): l.append(c==local) return and_(*l) + def __hash__(self): + return hash(tuple(self._list)) + class Selectable(ClauseElement): """mark a class as being selectable""" @@ -1648,8 +1656,9 @@ class FromClause(Selectable): """Represent an element that can be used within the ``FROM`` clause of a ``SELECT`` statement.""" 
__visit_name__ = 'fromclause' - named_with_column=False + named_with_column = False _hide_froms = [] + quote = None def _get_from_objects(self, **modifiers): return [] @@ -1694,12 +1703,12 @@ class FromClause(Selectable): return fromclause in util.Set(self._cloned_set) def replace_selectable(self, old, alias): - """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" + """replace all occurences of FromClause 'old' with the given Alias object, returning a copy of this ``FromClause``.""" - global ClauseAdapter - if ClauseAdapter is None: - from sqlalchemy.sql.util import ClauseAdapter - return ClauseAdapter(alias).traverse(self, clone=True) + global ClauseAdapter + if ClauseAdapter is None: + from sqlalchemy.sql.util import ClauseAdapter + return ClauseAdapter(alias).traverse(self) def correspond_on_equivalents(self, column, equivalents): col = self.corresponding_column(column, require_embedded=True) @@ -1859,7 +1868,7 @@ class _BindParamClause(ClauseElement, _CompareMixin): def _convert_to_unique(self): if not self.unique: - self.unique=True + self.unique = True self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param') def _get_from_objects(self, **modifiers): @@ -1910,6 +1919,7 @@ class _TextClause(ClauseElement): __visit_name__ = 'textclause' _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE) + supports_execution = True _hide_froms = [] oid_column = None @@ -1950,12 +1960,6 @@ class _TextClause(ClauseElement): def _get_from_objects(self, **modifiers): return [] - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([]) - class _Null(ColumnElement): """Represent the NULL keyword in a SQL statement. 
@@ -2042,6 +2046,7 @@ class _CalculatedClause(ColumnElement): __visit_name__ = 'calculatedclause' def __init__(self, name, *clauses, **kwargs): + ColumnElement.__init__(self) self.name = name self.type = sqltypes.to_instance(kwargs.get('type_', None)) self._bind = kwargs.get('bind', None) @@ -2061,7 +2066,7 @@ class _CalculatedClause(ColumnElement): def clauses(self): if isinstance(self.clause_expr, _Grouping): - return self.clause_expr.elem + return self.clause_expr.element else: return self.clause_expr clauses = property(clauses) @@ -2239,8 +2244,13 @@ class _Exists(_UnaryExpression): __visit_name__ = _UnaryExpression.__visit_name__ def __init__(self, *args, **kwargs): - kwargs['correlate'] = True - s = select(*args, **kwargs).as_scalar().self_group() + if args and isinstance(args[0], _SelectBaseMixin): + s = args[0] + else: + if not args: + args = ([literal_column('*')],) + s = select(*args, **kwargs).as_scalar().self_group() + _UnaryExpression.__init__(self, s, operator=operators.exists) def select(self, whereclause=None, **params): @@ -2272,7 +2282,7 @@ class Join(FromClause): self.right = _selectable(right).self_group() if onclause is None: - self.onclause = self.__match_primaries(self.left, self.right) + self.onclause = self._match_primaries(self.left, self.right) else: self.onclause = onclause @@ -2310,7 +2320,7 @@ class Join(FromClause): def get_children(self, **kwargs): return self.left, self.right, self.onclause - def __match_primaries(self, primary, secondary): + def _match_primaries(self, primary, secondary): global sql_util if not sql_util: from sqlalchemy.sql import util as sql_util @@ -2359,7 +2369,7 @@ class Join(FromClause): return self.select(use_labels=True, correlate=False).alias(name) def _hide_froms(self): - return itertools.chain(*[x.left._get_from_objects() + x.right._get_from_objects() for x in self._cloned_set]) + return itertools.chain(*[_from_objects(x.left, x.right) for x in self._cloned_set]) _hide_froms = property(_hide_froms) def 
_get_from_objects(self, **modifiers): @@ -2382,9 +2392,10 @@ class Alias(FromClause): def __init__(self, selectable, alias=None): baseselectable = selectable while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable + baseselectable = baseselectable.element self.original = baseselectable - self.selectable = selectable + self.supports_execution = baseselectable.supports_execution + self.element = selectable if alias is None: if self.original.named_with_column: alias = getattr(self.original, 'name', None) @@ -2398,112 +2409,100 @@ class Alias(FromClause): def is_derived_from(self, fromclause): if fromclause in util.Set(self._cloned_set): return True - return self.selectable.is_derived_from(fromclause) - - def supports_execution(self): - return self.original.supports_execution() - - def _table_iterator(self): - return self.original._table_iterator() + return self.element.is_derived_from(fromclause) def _populate_column_collection(self): - for col in self.selectable.columns: + for col in self.element.columns: col._make_proxy(self) - if self.selectable.oid_column is not None: - self._oid_column = self.selectable.oid_column._make_proxy(self) + if self.element.oid_column is not None: + self._oid_column = self.element.oid_column._make_proxy(self) def _copy_internals(self, clone=_clone): - self._reset_exported() - self.selectable = _clone(self.selectable) - baseselectable = self.selectable - while isinstance(baseselectable, Alias): - baseselectable = baseselectable.selectable - self.original = baseselectable + self._reset_exported() + self.element = _clone(self.element) + baseselectable = self.element + while isinstance(baseselectable, Alias): + baseselectable = baseselectable.selectable + self.original = baseselectable def get_children(self, column_collections=True, aliased_selectables=True, **kwargs): if column_collections: for c in self.c: yield c if aliased_selectables: - yield self.selectable + yield self.element def _get_from_objects(self, 
**modifiers): return [self] def bind(self): - return self.selectable.bind + return self.element.bind bind = property(bind) -class _ColumnElementAdapter(ColumnElement): - """Adapts a ClauseElement which may or may not be a - ColumnElement subclass itself into an object which - acts like a ColumnElement. - """ +class _Grouping(ColumnElement): + """Represent a grouping within a column expression""" - def __init__(self, elem): - self.elem = elem - self.type = getattr(elem, 'type', None) + def __init__(self, element): + ColumnElement.__init__(self) + self.element = element + self.type = getattr(element, 'type', None) def key(self): - return self.elem.key + return self.element.key key = property(key) def _label(self): try: - return self.elem._label + return self.element._label except AttributeError: return self.anon_label _label = property(_label) def _copy_internals(self, clone=_clone): - self.elem = clone(self.elem) + self.element = clone(self.element) def get_children(self, **kwargs): - return self.elem, + return self.element, def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def __getattr__(self, attr): - return getattr(self.elem, attr) + return getattr(self.element, attr) def __getstate__(self): - return {'elem':self.elem, 'type':self.type} + return {'element':self.element, 'type':self.type} def __setstate__(self, state): - self.elem = state['elem'] + self.element = state['element'] self.type = state['type'] -class _Grouping(_ColumnElementAdapter): - """Represent a grouping within a column expression""" - pass - class _FromGrouping(FromClause): """Represent a grouping of a FROM clause""" __visit_name__ = 'grouping' - def __init__(self, elem): - self.elem = elem + def __init__(self, element): + self.element = element def columns(self): - return self.elem.columns + return self.element.columns columns = c = property(columns) def _hide_froms(self): - return 
self.elem._hide_froms + return self.element._hide_froms _hide_froms = property(_hide_froms) def get_children(self, **kwargs): - return self.elem, + return self.element, def _copy_internals(self, clone=_clone): - self.elem = clone(self.elem) + self.element = clone(self.element) def _get_from_objects(self, **modifiers): - return self.elem._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def __getattr__(self, attr): - return getattr(self.elem, attr) + return getattr(self.element, attr) class _Label(ColumnElement): """Represents a column label (AS). @@ -2516,12 +2515,12 @@ class _Label(ColumnElement): ``ColumnElement`` subclasses. """ - def __init__(self, name, obj, type_=None): - while isinstance(obj, _Label): - obj = obj.obj - self.name = name or "{ANON %d %s}" % (id(self), getattr(obj, 'name', 'anon')) - self.obj = obj.self_group(against=operators.as_) - self.type = sqltypes.to_instance(type_ or getattr(obj, 'type', None)) + def __init__(self, name, element, type_=None): + while isinstance(element, _Label): + element = element.element + self.name = name or "{ANON %d %s}" % (id(self), getattr(element, 'name', 'anon')) + self.element = element.self_group(against=operators.as_) + self.type = sqltypes.to_instance(type_ or getattr(element, 'type', None)) def key(self): return self.name @@ -2532,8 +2531,9 @@ class _Label(ColumnElement): _label = property(_label) def _proxy_attr(name): + get = util.attrgetter(name) def attr(self): - return getattr(self.obj, name) + return get(self.element) return property(attr) proxies = _proxy_attr('proxies') @@ -2542,27 +2542,24 @@ class _Label(ColumnElement): primary_key = _proxy_attr('primary_key') foreign_keys = _proxy_attr('foreign_keys') - def expression_element(self): - return self.obj - def get_children(self, **kwargs): - return self.obj, + return self.element, def _copy_internals(self, clone=_clone): - self.obj = clone(self.obj) + self.element = clone(self.element) def _get_from_objects(self, 
**modifiers): - return self.obj._get_from_objects(**modifiers) + return self.element._get_from_objects(**modifiers) def _make_proxy(self, selectable, name = None): - if isinstance(self.obj, (Selectable, ColumnElement)): - e = self.obj._make_proxy(selectable, name=self.name) + if isinstance(self.element, (Selectable, ColumnElement)): + e = self.element._make_proxy(selectable, name=self.name) else: e = column(self.name)._make_proxy(selectable=selectable) e.proxies.append(self) return e -class _ColumnClause(ColumnElement): +class _ColumnClause(_Immutable, ColumnElement): """Represents a generic column expression from any textual string. This includes columns associated with tables, aliases and select @@ -2602,16 +2599,7 @@ class _ColumnClause(ColumnElement): return self.name.encode('ascii', 'backslashreplace') description = property(description) - def _clone(self): - # ColumnClause is immutable - return self - def _label(self): - """Generate a 'label' string for this column. - """ - - # for a "literal" column, we've no idea what the text is - # therefore no 'label' can be automatically generated if self.is_literal: return None if not self.__label: @@ -2626,24 +2614,21 @@ class _ColumnClause(ColumnElement): counter = 1 while label in self.table.c: label = self.__label + "_" + str(counter) - counter +=1 + counter += 1 self.__label = label else: self.__label = self.name return self.__label - _label = property(_label) def label(self, name): - # if going off the "__label" property and its None, we have - # no label; return self if name is None: return self else: return super(_ColumnClause, self).label(name) def _get_from_objects(self, **modifiers): - if self.table is not None: + if self.table: return [self.table] else: return [] @@ -2651,20 +2636,20 @@ class _ColumnClause(ColumnElement): def _bind_param(self, obj): return _BindParamClause(self.name, obj, type_=self.type, unique=True) - def _make_proxy(self, selectable, name = None): + def _make_proxy(self, selectable, 
name=None, attach=True): # propigate the "is_literal" flag only if we are keeping our name, # otherwise its considered to be a label is_literal = self.is_literal and (name is None or name == self.name) c = _ColumnClause(name or self.name, selectable=selectable, _is_oid=self._is_oid, type_=self.type, is_literal=is_literal) c.proxies = [self] - if not self._is_oid: + if attach and not self._is_oid: selectable.columns[c.name] = c return c def _compare_type(self, obj): return self.type -class TableClause(FromClause): +class TableClause(_Immutable, FromClause): """Represents a "table" construct. Note that this represents tables only as another syntactical @@ -2691,10 +2676,6 @@ class TableClause(FromClause): return self.name.encode('ascii', 'backslashreplace') description = property(description) - def _clone(self): - # TableClause is immutable - return self - def append_column(self, c): self._columns[c.name] = c c.table = self @@ -2724,10 +2705,11 @@ class TableClause(FromClause): def _get_from_objects(self, **modifiers): return [self] - class _SelectBaseMixin(object): """Base class for ``Select`` and ``CompoundSelects``.""" + supports_execution = True + def __init__(self, use_labels=False, for_update=False, limit=None, offset=None, order_by=None, group_by=None, bind=None, autocommit=False): self.use_labels = use_labels self.for_update = for_update @@ -2773,11 +2755,6 @@ class _SelectBaseMixin(object): """ return self.as_scalar().label(name) - def supports_execution(self): - """part of the ClauseElement contract; returns ``True`` in all cases for this class.""" - - return True - def autocommit(self): """return a new selectable with the 'autocommit' flag set to True.""" @@ -2860,15 +2837,15 @@ class _SelectBaseMixin(object): class _ScalarSelect(_Grouping): __visit_name__ = 'grouping' - def __init__(self, elem): - self.elem = elem - cols = list(elem.inner_columns) + def __init__(self, element): + self.element = element + cols = list(element.inner_columns) if len(cols) != 
1: - raise exceptions.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") + raise exc.InvalidRequestError("Scalar select can only be created from a Select object that has exactly one column expression.") self.type = cols[0].type def columns(self): - raise exceptions.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") + raise exc.InvalidRequestError("Scalar Select expression has no columns; use this object directly within a column-level expression.") columns = c = property(columns) def self_group(self, **kwargs): @@ -2893,7 +2870,7 @@ class CompoundSelect(_SelectBaseMixin, FromClause): if not numcols: numcols = len(s.c) elif len(s.c) != numcols: - raise exceptions.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % + raise exc.ArgumentError("All selectables passed to CompoundSelect must have identical numbers of columns; select #%d has %d columns, select #%d has %d" % (1, len(self.selects[0].c), n+1, len(s.c)) ) if s._order_by_clause: @@ -2936,11 +2913,6 @@ class CompoundSelect(_SelectBaseMixin, FromClause): return (column_collections and list(self.c) or []) + \ [self._order_by_clause, self._group_by_clause] + list(self.selects) - def _table_iterator(self): - for s in self.selects: - for t in s._table_iterator(): - yield t - def bind(self): if self._bind: return self._bind @@ -2976,6 +2948,7 @@ class Select(_SelectBaseMixin, FromClause): self._distinct = distinct self._correlate = util.Set() + self._froms = util.OrderedSet() if columns: self._raw_columns = [ @@ -2983,22 +2956,23 @@ class Select(_SelectBaseMixin, FromClause): for c in [_literal_as_column(c) for c in columns] ] + + self._froms.update(_from_objects(*self._raw_columns)) else: self._raw_columns = [] - - if from_obj: - self._froms = util.Set([ - _is_literal(f) and _TextClause(f) 
or f - for f in util.to_list(from_obj) - ]) - else: - self._froms = util.Set() - + if whereclause: self._whereclause = _literal_as_text(whereclause) + self._froms.update(_from_objects(self._whereclause, is_where=True)) else: self._whereclause = None + if from_obj: + self._froms.update([ + _is_literal(f) and _TextClause(f) or f + for f in util.to_list(from_obj) + ]) + if having: self._having = _literal_as_text(having) else: @@ -3020,36 +2994,28 @@ class Select(_SelectBaseMixin, FromClause): correlating. """ - froms = util.OrderedSet() - - for col in self._raw_columns: - froms.update(col._get_from_objects()) - - if self._whereclause is not None: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - if self._froms: - froms.update(self._froms) + froms = self._froms toremove = itertools.chain(*[f._hide_froms for f in froms]) - froms.difference_update(toremove) + if toremove: + froms = froms.difference(toremove) if len(froms) > 1 or self._correlate: if self._correlate: - froms.difference_update(_cloned_intersection(froms, self._correlate)) + froms = froms.difference(_cloned_intersection(froms, self._correlate)) if self._should_correlate and existing_froms: - froms.difference_update(_cloned_intersection(froms, existing_froms)) + froms = froms.difference(_cloned_intersection(froms, existing_froms)) if not len(froms): - raise exceptions.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) + raise exc.InvalidRequestError("Select statement '%s' returned no FROM clauses due to auto-correlation; specify correlate(<tables>) to control correlation manually." % self) return froms froms = property(_get_display_froms, doc="""Return a list of all FromClause elements which will be applied to the FROM clause of the resulting statement.""") def type(self): - raise exceptions.InvalidRequestError("Select objects don't have a type. 
Call as_scalar() on this Select object to return a 'scalar' version of this Select.") + raise exc.InvalidRequestError("Select objects don't have a type. Call as_scalar() on this Select object to return a 'scalar' version of this Select.") type = property(type) def locate_all_froms(self): @@ -3059,22 +3025,10 @@ class Select(_SelectBaseMixin, FromClause): is specifically for those FromClause elements that would actually be rendered. """ - if hasattr(self, '_all_froms'): - return self._all_froms - - froms = util.Set( - itertools.chain(* - [self._froms] + - [f._get_from_objects() for f in self._froms] + - [col._get_from_objects() for col in self._raw_columns] - ) - ) + if not hasattr(self, '_all_froms'): + self._all_froms = self._froms.union(_from_objects(*list(self._froms))) - if self._whereclause: - froms.update(self._whereclause._get_from_objects(is_where=True)) - - self._all_froms = froms - return froms + return self._all_froms def inner_columns(self): """an iteratorof all ColumnElement expressions which would @@ -3092,7 +3046,7 @@ class Select(_SelectBaseMixin, FromClause): def is_derived_from(self, fromclause): if self in util.Set(fromclause._cloned_set): return True - + for f in self.locate_all_froms(): if f.is_derived_from(fromclause): return True @@ -3112,7 +3066,7 @@ class Select(_SelectBaseMixin, FromClause): """return child elements as per the ClauseElement specification.""" return (column_collections and list(self.columns) or []) + \ - list(self.locate_all_froms()) + \ + list(self._froms) + \ [x for x in (self._whereclause, self._having, self._order_by_clause, self._group_by_clause) if x is not None] def column(self, column): @@ -3125,6 +3079,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) s._raw_columns = s._raw_columns + [column] + s._froms = s._froms.union(_from_objects(column)) return s def where(self, whereclause): @@ -3185,7 +3140,7 @@ class Select(_SelectBaseMixin, FromClause): """ s = 
self._generate() - s._should_correlate=False + s._should_correlate = False if fromclauses == (None,): s._correlate = util.Set() else: @@ -3195,7 +3150,7 @@ class Select(_SelectBaseMixin, FromClause): def append_correlation(self, fromclause): """append the given correlation expression to this select() construct.""" - self._should_correlate=False + self._should_correlate = False self._correlate = self._correlate.union([fromclause]) def append_column(self, column): @@ -3207,6 +3162,7 @@ class Select(_SelectBaseMixin, FromClause): column = column.self_group(against=operators.comma_op) self._raw_columns = self._raw_columns + [column] + self._froms = self._froms.union(_from_objects(column)) self._reset_exported() def append_prefix(self, clause): @@ -3221,10 +3177,13 @@ class Select(_SelectBaseMixin, FromClause): The expression will be joined to existing WHERE criterion via AND. """ + whereclause = _literal_as_text(whereclause) + self._froms = self._froms.union(_from_objects(whereclause, is_where=True)) + if self._whereclause is not None: - self._whereclause = and_(self._whereclause, _literal_as_text(whereclause)) + self._whereclause = and_(self._whereclause, whereclause) else: - self._whereclause = _literal_as_text(whereclause) + self._whereclause = whereclause def append_having(self, having): """append the given expression to this select() construct's HAVING criterion. 
@@ -3311,31 +3270,23 @@ class Select(_SelectBaseMixin, FromClause): return intersect_all(self, other, **kwargs) - def _table_iterator(self): - for t in visitors.NoColumnVisitor().iterate(self): - if isinstance(t, TableClause): - yield t - def bind(self): if self._bind: return self._bind - for f in self._froms: - if f is self: - continue - e = f.bind - if e: - self._bind = e - return e - # look through the columns (largely synomous with looking - # through the FROMs except in the case of _CalculatedClause/_Function) - for c in self._raw_columns: - if getattr(c, 'table', None) is self: - continue - e = c.bind + if not self._froms: + for c in self._raw_columns: + e = c.bind + if e: + self._bind = e + return e + else: + e = list(self._froms)[0].bind if e: self._bind = e return e + return None + def _set_bind(self, bind): self._bind = bind bind = property(bind, _set_bind) @@ -3343,11 +3294,7 @@ class Select(_SelectBaseMixin, FromClause): class _UpdateBase(ClauseElement): """Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.""" - def supports_execution(self): - return True - - def _table_iterator(self): - return iter([self.table]) + supports_execution = True def _generate(self): s = self.__class__.__new__(self.__class__) @@ -3407,7 +3354,7 @@ class Insert(_ValuesBase): self._bind = bind self.table = table self.select = None - self.inline=inline + self.inline = inline if prefixes: self._prefixes = [_literal_as_text(p) for p in prefixes] else: @@ -3502,10 +3449,11 @@ class Delete(_UpdateBase): self._whereclause = clone(self._whereclause) class _IdentifiedClause(ClauseElement): + supports_execution = True + quote = None + def __init__(self, ident): self.ident = ident - def supports_execution(self): - return True class SavepointClause(_IdentifiedClause): pass diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index dfd638ecb..46dcaba66 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -44,7 
+44,7 @@ def between_op(a, b, c): return a.between(b, c) def in_op(a, b): - return a.in_(*b) + return a.in_(b) def notin_op(a, b): raise NotImplementedError() diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index d299982cf..944a68def 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1,4 +1,4 @@ -from sqlalchemy import exceptions, schema, topological, util, sql +from sqlalchemy import exc, schema, topological, util, sql from sqlalchemy.sql import expression, operators, visitors from itertools import chain @@ -8,43 +8,57 @@ def sort_tables(tables, reverse=False): """sort a collection of Table objects in order of their foreign-key dependency.""" tuples = [] - class TVisitor(schema.SchemaVisitor): - def visit_foreign_key(_self, fkey): - if fkey.use_alter: - return - parent_table = fkey.column.table - if parent_table in tables: - child_table = fkey.parent.table - tuples.append( ( parent_table, child_table ) ) - vis = TVisitor() + def visit_foreign_key(fkey): + if fkey.use_alter: + return + parent_table = fkey.column.table + if parent_table in tables: + child_table = fkey.parent.table + tuples.append( ( parent_table, child_table ) ) + for table in tables: - vis.traverse(table) + visitors.traverse(table, {'schema_visitor':True}, {'foreign_key':visit_foreign_key}) sequence = topological.sort(tuples, tables) if reverse: return util.reversed(sequence) else: return sequence -def find_tables(clause, check_columns=False, include_aliases=False): +def search(clause, target): + if not clause: + return False + for elem in visitors.iterate(clause, {'column_collections':False}): + if elem is target: + return True + else: + return False + +def find_tables(clause, check_columns=False, include_aliases=False, include_joins=False, include_selects=False): """locate Table objects within the given expression.""" tables = [] - kwargs = {} + _visitors = {} + + def visit_something(elem): + tables.append(elem) + + if include_selects: + 
_visitors['select'] = _visitors['compound_select'] = visit_something + + if include_joins: + _visitors['join'] = visit_something + if include_aliases: - def visit_alias(alias): - tables.append(alias) - kwargs['visit_alias'] = visit_alias + _visitors['alias'] = visit_something if check_columns: def visit_column(column): tables.append(column.table) - kwargs['visit_column'] = visit_column + _visitors['column'] = visit_column - def visit_table(table): - tables.append(table) - kwargs['visit_table'] = visit_table + _visitors['table'] = visit_something - visitors.traverse(clause, traverse_options= {'column_collections':False}, **kwargs) + visitors.traverse(clause, {'column_collections':False}, _visitors) return tables def find_columns(clause): @@ -53,7 +67,7 @@ def find_columns(clause): cols = util.Set() def visit_column(col): cols.add(col) - visitors.traverse(clause, visit_column=visit_column) + visitors.traverse(clause, {}, {'column':visit_column}) return cols def join_condition(a, b, ignore_nonexistent_tables=False): @@ -72,7 +86,7 @@ def join_condition(a, b, ignore_nonexistent_tables=False): for fk in b.foreign_keys: try: col = fk.get_referent(a) - except exceptions.NoReferencedTableError: + except exc.NoReferencedTableError: if ignore_nonexistent_tables: continue else: @@ -81,27 +95,26 @@ def join_condition(a, b, ignore_nonexistent_tables=False): if col: crit.append(col == fk.parent) constraints.add(fk.constraint) - if a is not b: for fk in a.foreign_keys: try: col = fk.get_referent(b) - except exceptions.NoReferencedTableError: + except exc.NoReferencedTableError: if ignore_nonexistent_tables: continue else: raise - + if col: crit.append(col == fk.parent) constraints.add(fk.constraint) if len(crit) == 0: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't find any foreign key relationships " "between '%s' and '%s'" % (a.description, b.description)) elif len(constraints) > 1: - raise exceptions.ArgumentError( + raise exc.ArgumentError( "Can't 
determine join between '%s' and '%s'; " "tables have more than one foreign key " "constraint relationship between them. " @@ -111,7 +124,70 @@ def join_condition(a, b, ignore_nonexistent_tables=False): return (crit[0]) else: return sql.and_(*crit) + +class Annotated(object): + """clones a ClauseElement and applies an 'annotations' dictionary. + + Unlike regular clones, this clone also mimics __hash__() and + __cmp__() of the original element so that it takes its place + in hashed collections. + A reference to the original element is maintained, for the important + reason of keeping its hash value current. When GC'ed, the + hash value may be reused, causing conflicts. + + """ + def __new__(cls, *args): + if not args: + return object.__new__(cls) + else: + element, values = args + return object.__new__( + type.__new__(type, "Annotated%s" % element.__class__.__name__, (Annotated, element.__class__), {}) + ) + + def __init__(self, element, values): + self.__dict__ = element.__dict__.copy() + self.__element = element + self._annotations = values + + def _annotate(self, values): + _values = self._annotations.copy() + _values.update(values) + clone = self.__class__.__new__(self.__class__) + clone.__dict__ = self.__dict__.copy() + clone._annotations = _values + return clone + + def __hash__(self): + return hash(self.__element) + + def __cmp__(self, other): + return cmp(hash(self.__element), hash(other)) + +def splice_joins(left, right, stop_on=None): + if left is None: + return right + + stack = [(right, None)] + + adapter = ClauseAdapter(left) + ret = None + while stack: + (right, prevright) = stack.pop() + if isinstance(right, expression.Join) and right is not stop_on: + right = right._clone() + right._reset_exported() + right.onclause = adapter.traverse(right.onclause) + stack.append((right.left, right)) + else: + right = adapter.traverse(right) + if prevright: + prevright.left = right + if not ret: + ret = right + + return ret def reduce_columns(columns, *clauses): 
"""given a list of columns, return a 'reduced' set based on natural equivalents. @@ -151,7 +227,7 @@ def reduce_columns(columns, *clauses): omit.add(c) break for clause in clauses: - visitors.traverse(clause, visit_binary=visit_binary) + visitors.traverse(clause, {}, {'binary':visit_binary}) return expression.ColumnSet(columns.difference(omit)) @@ -159,7 +235,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re """traverse an expression and locate binary criterion pairs.""" if consider_as_foreign_keys and consider_as_referenced_keys: - raise exceptions.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") + raise exc.ArgumentError("Can only specify one of 'consider_as_foreign_keys' or 'consider_as_referenced_keys'") def visit_binary(binary): if not any_operator and binary.operator != operators.eq: @@ -184,7 +260,7 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None, consider_as_re elif binary.right.references(binary.left): pairs.append((binary.left, binary.right)) pairs = [] - visitors.traverse(expression, visit_binary=visit_binary) + visitors.traverse(expression, {}, {'binary':visit_binary}) return pairs def folded_equivalents(join, equivs=None): @@ -195,15 +271,15 @@ def folded_equivalents(join, equivs=None): This function is used by Join.select(fold_equivalents=True). TODO: deprecate ? 
- """ + """ if equivs is None: equivs = util.Set() def visit_binary(binary): if binary.operator == operators.eq and binary.left.name == binary.right.name: equivs.add(binary.right) equivs.add(binary.left) - visitors.traverse(join.onclause, visit_binary=visit_binary) + visitors.traverse(join.onclause, {}, {'binary':visit_binary}) collist = [] if isinstance(join.left, expression.Join): left = folded_equivalents(join.left, equivs) @@ -246,43 +322,8 @@ class AliasedRow(object): def keys(self): return self.row.keys() -def row_adapter(from_, equivalent_columns=None): - """create a row adapter callable against a selectable.""" - - if equivalent_columns is None: - equivalent_columns = {} - - def locate_col(col): - c = from_.corresponding_column(col) - if c: - return c - elif col in equivalent_columns: - for c2 in equivalent_columns[col]: - corr = from_.corresponding_column(c2) - if corr: - return corr - return col - - map = util.PopulateDict(locate_col) - - def adapt(row): - return AliasedRow(row, map) - return adapt - -class ColumnsInClause(visitors.ClauseVisitor): - """Given a selectable, visit clauses and determine if any columns - from the clause are in the selectable. - """ - - def __init__(self, selectable): - self.selectable = selectable - self.result = False - - def visit_column(self, column): - if self.selectable.c.get(column.key) is column: - self.result = True -class ClauseAdapter(visitors.ClauseVisitor): +class ClauseAdapter(visitors.ReplacingCloningVisitor): """Given a clause (like as in a WHERE criterion), locate columns which are embedded within a given selectable, and changes those columns to be that of the selectable. 
@@ -308,58 +349,76 @@ class ClauseAdapter(visitors.ClauseVisitor): condition to read:: s.c.col1 == table2.c.col1 - """ - - __traverse_options__ = {'column_collections':False} - def __init__(self, selectable, include=None, exclude=None, equivalents=None): - self.__traverse_options__ = self.__traverse_options__.copy() - self.__traverse_options__['stop_on'] = [selectable] + """ + def __init__(self, selectable, equivalents=None, include=None, exclude=None): + self.__traverse_options__ = {'column_collections':False, 'stop_on':[selectable]} self.selectable = selectable self.include = include self.exclude = exclude - self.equivalents = equivalents - - def traverse(self, obj, clone=True): - if not clone: - raise exceptions.ArgumentError("ClauseAdapter 'clone' argument must be True") - return visitors.ClauseVisitor.traverse(self, obj, clone=True) - - def copy_and_chain(self, adapter): - """create a copy of this adapter and chain to the given adapter. - - currently this adapter must be unchained to start, raises - an exception if it's already chained. - - Does not modify the given adapter. 
- """ + self.equivalents = equivalents or {} - if adapter is None: - return self + def _corresponding_column(self, col, require_embedded): + newcol = self.selectable.corresponding_column(col, require_embedded=require_embedded) - if hasattr(self, '_next'): - raise NotImplementedError("Can't chain_to on an already chained ClauseAdapter (yet)") - - ca = ClauseAdapter(self.selectable, self.include, self.exclude, self.equivalents) - ca._next = adapter - return ca + if not newcol and col in self.equivalents: + for equiv in self.equivalents[col]: + newcol = self.selectable.corresponding_column(equiv, require_embedded=require_embedded) + if newcol: + return newcol + return newcol - def before_clone(self, col): + def replace(self, col): if isinstance(col, expression.FromClause): if self.selectable.is_derived_from(col): return self.selectable + if not isinstance(col, expression.ColumnElement): return None - if self.include is not None: - if col not in self.include: - return None - if self.exclude is not None: - if col in self.exclude: - return None - newcol = self.selectable.corresponding_column(col, require_embedded=True) - if newcol is None and self.equivalents is not None and col in self.equivalents: - for equiv in self.equivalents[col]: - newcol = self.selectable.corresponding_column(equiv, require_embedded=True) - if newcol: - return newcol - return newcol + + if self.include and col not in self.include: + return None + elif self.exclude and col in self.exclude: + return None + + return self._corresponding_column(col, True) + +class ColumnAdapter(ClauseAdapter): + + def __init__(self, selectable, equivalents=None, chain_to=None, include=None, exclude=None): + ClauseAdapter.__init__(self, selectable, equivalents, include, exclude) + if chain_to: + self.chain(chain_to) + self.columns = util.PopulateDict(self._locate_col) + + def wrap(self, adapter): + ac = self.__class__.__new__(self.__class__) + ac.__dict__ = self.__dict__.copy() + ac._locate_col = 
ac._wrap(ac._locate_col, adapter._locate_col) + ac.adapt_clause = ac._wrap(ac.adapt_clause, adapter.adapt_clause) + ac.adapt_list = ac._wrap(ac.adapt_list, adapter.adapt_list) + ac.columns = util.PopulateDict(ac._locate_col) + return ac + + adapt_clause = ClauseAdapter.traverse + adapt_list = ClauseAdapter.copy_and_process + + def _wrap(self, local, wrapped): + def locate(col): + col = local(col) + return wrapped(col) + return locate + + def _locate_col(self, col): + c = self._corresponding_column(col, False) + if not c: + c = self.adapt_clause(col) + + # anonymize labels in case they have a hardcoded name + if isinstance(c, expression._Label): + c = c.label(None) + return c + + def adapted_row(self, row): + return AliasedRow(row, self.columns) + diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 9888a228a..738dae9c7 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -1,138 +1,29 @@ from sqlalchemy import util class ClauseVisitor(object): - """Traverses and visits ``ClauseElement`` structures. - - Calls visit_XXX() methods for each particular - ``ClauseElement`` subclass encountered. Traversal of a - hierarchy of ``ClauseElements`` is achieved via the - ``traverse()`` method, which is passed the lead - ``ClauseElement``. - - By default, ``ClauseVisitor`` traverses all elements - fully. Options can be specified at the class level via the - ``__traverse_options__`` dictionary which will be passed - to the ``get_children()`` method of each ``ClauseElement``; - these options can indicate modifications to the set of - elements returned, such as to not return column collections - (column_collections=False) or to return Schema-level items - (schema_visitor=True). - - ``ClauseVisitor`` also supports a simultaneous copy-and-traverse - operation, which will produce a copy of a given ``ClauseElement`` - structure while at the same time allowing ``ClauseVisitor`` subclasses - to modify the new structure in-place. 
- - """ __traverse_options__ = {} - def traverse_single(self, obj, **kwargs): - """visit a single element, without traversing its child elements.""" - + def traverse_single(self, obj): for v in self._iterate_visitors: meth = getattr(v, "visit_%s" % obj.__visit_name__, None) if meth: - return meth(obj, **kwargs) + return meth(obj) - traverse_chained = traverse_single - def iterate(self, obj): """traverse the given expression structure, returning an iterator of all elements.""" - - stack = [obj] - traversal = util.deque() - while stack: - t = stack.pop() - traversal.appendleft(t) - for c in t.get_children(**self.__traverse_options__): - stack.append(c) - return iter(traversal) - - def traverse(self, obj, clone=False): - """traverse and visit the given expression structure. - - Returns the structure given, or a copy of the structure if - clone=True. - - When the copy operation takes place, the before_clone() method - will receive each element before it is copied. If the method - returns a non-None value, the return value is taken as the - "copied" element and traversal will not descend further. - - The visit_XXX() methods receive the element *after* it's been - copied. To compare an element to another regardless of - one element being a cloned copy of the original, the - '_cloned_set' attribute of ClauseElement can be used for the compare, - i.e.:: - - original in copied._cloned_set - - - """ - if clone: - return self._cloned_traversal(obj) - else: - return self._non_cloned_traversal(obj) - - def copy_and_process(self, list_): - """Apply cloned traversal to the given list of elements, and return the new list.""" - - return [self._cloned_traversal(x) for x in list_] - def before_clone(self, elem): - """receive pre-copied elements during a cloning traversal. - - If the method returns a new element, the element is used - instead of creating a simple copy of the element. Traversal - will halt on the newly returned element if it is re-encountered. 
- """ - return None - - def _clone_element(self, elem, stop_on, cloned): - for v in self._iterate_visitors: - newelem = v.before_clone(elem) - if newelem: - stop_on.add(newelem) - return newelem - - if elem not in cloned: - # the full traversal will only make a clone of a particular element - # once. - cloned[elem] = elem._clone() - return cloned[elem] - - def _cloned_traversal(self, obj): - """a recursive traversal which creates copies of elements, returning the new structure.""" - - stop_on = self.__traverse_options__.get('stop_on', []) - return self._cloned_traversal_impl(obj, util.Set(stop_on), {}, _clone_toplevel=True) - - def _cloned_traversal_impl(self, elem, stop_on, cloned, _clone_toplevel=False): - if elem in stop_on: - return elem - - if _clone_toplevel: - elem = self._clone_element(elem, stop_on, cloned) - if elem in stop_on: - return elem - - def clone(element): - return self._clone_element(element, stop_on, cloned) - elem._copy_internals(clone=clone) + return iterate(obj, self.__traverse_options__) - self.traverse_single(elem) + def traverse(self, obj): + """traverse and visit the given expression structure.""" - for e in elem.get_children(**self.__traverse_options__): - if e not in stop_on: - self._cloned_traversal_impl(e, stop_on, cloned) - return elem + visitors = {} - def _non_cloned_traversal(self, obj): - """a non-recursive, non-cloning traversal.""" - - for target in self.iterate(obj): - self.traverse_single(target) - return obj + for name in dir(self): + if name.startswith('visit_'): + visitors[name[6:]] = getattr(self, name) + + return traverse(obj, self.__traverse_options__, visitors) def _iterate_visitors(self): """iterate through this visitor and each 'chained' visitor.""" @@ -152,31 +43,136 @@ class ClauseVisitor(object): tail._next = visitor return self -class NoColumnVisitor(ClauseVisitor): - """ClauseVisitor with 'column_collections' set to False; will not - traverse the front-facing Column collections on Table, Alias, Select, - and 
CompoundSelect objects. +class CloningVisitor(ClauseVisitor): + def copy_and_process(self, list_): + """Apply cloned traversal to the given list of elements, and return the new list.""" + + return [self.traverse(x) for x in list_] + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + visitors = {} + + for name in dir(self): + if name.startswith('visit_'): + visitors[name[6:]] = getattr(self, name) + + return cloned_traverse(obj, self.__traverse_options__, visitors) + +class ReplacingCloningVisitor(CloningVisitor): + def replace(self, elem): + """receive pre-copied elements during a cloning traversal. + + If the method returns a new element, the element is used + instead of creating a simple copy of the element. Traversal + will halt on the newly returned element if it is re-encountered. + """ + return None + + def traverse(self, obj): + """traverse and visit the given expression structure.""" + + def replace(elem): + for v in self._iterate_visitors: + e = v.replace(elem) + if e: + return e + return replacement_traverse(obj, self.__traverse_options__, replace) + +def iterate(obj, opts): + """traverse the given expression structure, returning an iterator. + + traversal is configured to be breadth-first. """ + stack = util.deque([obj]) + while stack: + t = stack.popleft() + yield t + for c in t.get_children(**opts): + stack.append(c) + +def iterate_depthfirst(obj, opts): + """traverse the given expression structure, returning an iterator. - __traverse_options__ = {'column_collections':False} - -class NullVisitor(ClauseVisitor): - def traverse(self, obj, clone=False): - next = getattr(self, '_next', None) - if next: - return next.traverse(obj, clone=clone) - else: - return obj - -def traverse(clause, **kwargs): - """traverse the given clause, applying visit functions passed in as keyword arguments.""" + traversal is configured to be depth-first. 
+ + """ + stack = util.deque([obj]) + traversal = util.deque() + while stack: + t = stack.pop() + traversal.appendleft(t) + for c in t.get_children(**opts): + stack.append(c) + return iter(traversal) + +def traverse_using(iterator, obj, visitors): + """visit the given expression structure using the given iterator of objects.""" + + for target in iterator: + meth = visitors.get(target.__visit_name__, None) + if meth: + meth(target) + return obj - clone = kwargs.pop('clone', False) - class Vis(ClauseVisitor): - __traverse_options__ = kwargs.pop('traverse_options', {}) - vis = Vis() - for key in kwargs: - setattr(vis, key, kwargs[key]) - return vis.traverse(clause, clone=clone) +def traverse(obj, opts, visitors): + """traverse and visit the given expression structure using the default iterator.""" + + return traverse_using(iterate(obj, opts), obj, visitors) + +def traverse_depthfirst(obj, opts, visitors): + """traverse and visit the given expression structure using the depth-first iterator.""" + + return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) + +def cloned_traverse(obj, opts, visitors): + cloned = {} + + def clone(element): + if element not in cloned: + cloned[element] = element._clone() + return cloned[element] + + obj = clone(obj) + stack = [obj] + + while stack: + t = stack.pop() + if t in cloned: + continue + t._copy_internals(clone=clone) + + meth = visitors.get(t.__visit_name__, None) + if meth: + meth(t) + + for c in t.get_children(**opts): + stack.append(c) + return obj + +def replacement_traverse(obj, opts, replace): + cloned = {} + stop_on = util.Set(opts.get('stop_on', [])) + + def clone(element): + newelem = replace(element) + if newelem: + stop_on.add(newelem) + return newelem + + if element not in cloned: + cloned[element] = element._clone() + return cloned[element] + obj = clone(obj) + stack = [obj] + while stack: + t = stack.pop() + if t in stop_on: + continue + t._copy_internals(clone=clone) + for c in t.get_children(**opts): + 
stack.append(c) + return obj diff --git a/lib/sqlalchemy/topological.py b/lib/sqlalchemy/topological.py index 996123979..9ef3dfaf4 100644 --- a/lib/sqlalchemy/topological.py +++ b/lib/sqlalchemy/topological.py @@ -19,7 +19,7 @@ conditions. """ from sqlalchemy import util -from sqlalchemy.exceptions import CircularDependencyError +from sqlalchemy.exc import CircularDependencyError __all__ = ['sort', 'sort_with_cycles', 'sort_as_tree'] @@ -207,9 +207,9 @@ def _sort(tuples, allitems, allow_cycles=False, ignore_self_cycles=False): for n in lead.cycles: if n is not lead: n._cyclical = True - for (n,k) in list(edges.edges_by_parent(n)): + for (n, k) in list(edges.edges_by_parent(n)): edges.add((lead, k)) - edges.remove((n,k)) + edges.remove((n, k)) continue else: # long cycles not allowed @@ -248,7 +248,7 @@ def _organize_as_tree(nodes): nodealldeps = node.all_deps() if nodealldeps: # iterate over independent node indexes in reverse order so we can efficiently remove them - for index in xrange(len(independents)-1,-1,-1): + for index in xrange(len(independents) - 1, -1, -1): child, childsubtree, childcycles = independents[index] # if there is a dependency between this node and an independent node if (childsubtree.intersection(nodealldeps) or childcycles.intersection(node.dependencies)): @@ -261,7 +261,7 @@ def _organize_as_tree(nodes): # remove the child from list of independent subtrees independents[index:index+1] = [] # add node as a new independent subtree - independents.append((node,subtree,cycles)) + independents.append((node, subtree, cycles)) # choose an arbitrary node from list of all independent subtrees head = independents.pop()[0] # add all other independent subtrees as a child of the chosen root diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index e06ec9a5a..bae079e64 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -24,7 +24,7 @@ __all__ = [ 'TypeEngine', 'TypeDecorator', 'AbstractType', import inspect import datetime as 
dt -from sqlalchemy import exceptions +from sqlalchemy import exc from sqlalchemy.util import pickle, Decimal as _python_Decimal import sqlalchemy.util as util NoneType = type(None) @@ -173,7 +173,6 @@ class TypeEngine(AbstractType): def get_col_spec(self): raise NotImplementedError() - def bind_processor(self, dialect): return None @@ -214,7 +213,7 @@ class TypeDecorator(AbstractType): def __init__(self, *args, **kwargs): if not hasattr(self.__class__, 'impl'): - raise exceptions.AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") + raise AssertionError("TypeDecorator implementations require a class-level variable 'impl' which refers to the class of type being decorated") self.impl = self.__class__.impl(*args, **kwargs) def dialect_impl(self, dialect, **kwargs): @@ -231,7 +230,7 @@ class TypeDecorator(AbstractType): typedesc = self.load_dialect_impl(dialect) tt = self.copy() if not isinstance(tt, self.__class__): - raise exceptions.AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) + raise AssertionError("Type object %s does not properly implement the copy() method, it must return an object of type %s" % (self, self.__class__)) tt.impl = typedesc self._impl_dict[dialect] = tt return tt @@ -299,7 +298,7 @@ class TypeDecorator(AbstractType): return self.impl.copy_value(value) def compare_values(self, x, y): - return self.impl.compare_values(x,y) + return self.impl.compare_values(x, y) def is_mutable(self): return self.impl.is_mutable() @@ -363,12 +362,14 @@ class Concatenable(object): class String(Concatenable, TypeEngine): """A sized string type. - Usually corresponds to VARCHAR. Can also take Python unicode objects + In SQL, corresponds to VARCHAR. Can also take Python unicode objects and encode to the database's encoding in bind params (and the reverse for result sets.) 
- a String with no length will adapt itself automatically to a Text - object at the dialect level (this behavior is deprecated in 0.4). + The `length` field is usually required when the `String` type is used within a + CREATE TABLE statement, since VARCHAR requires a length on most databases. + Currently SQLite is an exception to this. + """ def __init__(self, length=None, convert_unicode=False, assert_unicode=None): self.length = length @@ -393,7 +394,7 @@ class String(Concatenable, TypeEngine): "param value %r" % value) return value else: - raise exceptions.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) + raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) else: return value return process @@ -411,26 +412,6 @@ class String(Concatenable, TypeEngine): else: return None - def dialect_impl(self, dialect, **kwargs): - _for_ddl = kwargs.pop('_for_ddl', False) - if _for_ddl and self.length is None: - label = util.to_ascii(_for_ddl is True and - '' or (' for column "%s"' % str(_for_ddl))) - util.warn_deprecated( - "Using String type with no length for CREATE TABLE " - "is deprecated; use the Text type explicitly" + label) - return TypeEngine.dialect_impl(self, dialect, **kwargs) - - def get_search_list(self): - l = super(String, self).get_search_list() - # if we are String or Unicode with no length, - # return Text as the highest-priority type - # to be adapted by the dialect - if self.length is None and l[0] in (String, Unicode): - return (Text,) + l - else: - return l - def get_dbapi_type(self, dbapi): return dbapi.STRING @@ -632,7 +613,7 @@ class Interval(TypeDecorator): if value is None: return None return dt.datetime.utcfromtimestamp(0) + value - + def process_result_value(self, value, dialect): if dialect.__class__ in self.__supported: return value @@ -641,23 +622,68 @@ class Interval(TypeDecorator): return None return value - dt.datetime.utcfromtimestamp(0) -class FLOAT(Float): pass 
-TEXT = Text -class NUMERIC(Numeric): pass -class DECIMAL(Numeric): pass -class INT(Integer): pass +class FLOAT(Float): + """The SQL FLOAT type.""" + + +class NUMERIC(Numeric): + """The SQL NUMERIC type.""" + + +class DECIMAL(Numeric): + """The SQL DECIMAL type.""" + + +class INT(Integer): + """The SQL INT or INTEGER type.""" + + INTEGER = INT -class SMALLINT(Smallinteger): pass -class TIMESTAMP(DateTime): pass -class DATETIME(DateTime): pass -class DATE(Date): pass -class TIME(Time): pass -class CLOB(Text): pass -class VARCHAR(String): pass -class CHAR(String): pass -class NCHAR(Unicode): pass -class BLOB(Binary): pass -class BOOLEAN(Boolean): pass + +class SMALLINT(Smallinteger): + """The SQL SMALLINT type.""" + + +class TIMESTAMP(DateTime): + """The SQL TIMESTAMP type.""" + + +class DATETIME(DateTime): + """The SQL DATETIME type.""" + + +class DATE(Date): + """The SQL DATE type.""" + + +class TIME(Time): + """The SQL TIME type.""" + + +TEXT = Text + +class CLOB(Text): + """The SQL CLOB type.""" + + +class VARCHAR(String): + """The SQL VARCHAR type.""" + + +class CHAR(String): + """The SQL CHAR type.""" + + +class NCHAR(Unicode): + """The SQL NCHAR type.""" + + +class BLOB(Binary): + """The SQL BLOB type.""" + + +class BOOLEAN(Boolean): + """The SQL BOOLEAN type.""" NULLTYPE = NullType() diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index e88c4b3b9..ff1108c3b 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -8,7 +8,7 @@ import inspect, itertools, new, operator, sets, sys, warnings, weakref import __builtin__ types = __import__('types') -from sqlalchemy import exceptions +from sqlalchemy import exc try: import thread, threading @@ -18,14 +18,16 @@ except ImportError: try: Set = set + FrozenSet = frozenset set_types = set, sets.Set except NameError: set_types = sets.Set, - # layer some of __builtin__.set's binop behavior onto sets.Set - class Set(sets.Set): + + def py24_style_ops(): + """Layer some of __builtin__.set's binop 
behavior onto sets.Set.""" + def _binary_sanity_check(self, other): pass - def issubset(self, iterable): other = type(self)(iterable) return sets.Set.issubset(self, other) @@ -38,7 +40,6 @@ except NameError: def __ge__(self, other): sets.Set._binary_sanity_check(self, other) return sets.Set.__ge__(self, other) - # lt and gt still require a BaseSet def __lt__(self, other): sets.Set._binary_sanity_check(self, other) @@ -63,6 +64,14 @@ except NameError: if not isinstance(other, sets.BaseSet): return NotImplemented return sets.Set.__isub__(self, other) + return locals() + + py24_style_ops = py24_style_ops() + Set = type('Set', (sets.Set,), py24_style_ops) + FrozenSet = type('FrozenSet', (sets.ImmutableSet,), py24_style_ops) + del py24_style_ops + +EMPTY_SET = FrozenSet() try: import cPickle as pickle @@ -96,10 +105,16 @@ except ImportError: try: from operator import attrgetter -except: +except ImportError: def attrgetter(attribute): return lambda value: getattr(value, attribute) +try: + from operator import itemgetter +except ImportError: + def itemgetter(attribute): + return lambda value: value[attribute] + if sys.version_info >= (2, 5): class PopulateDict(dict): """a dict which populates missing values via a creation function. @@ -169,17 +184,17 @@ except ImportError: class deque(list): def appendleft(self, x): self.insert(0, x) - + def extendleft(self, iterable): self[0:0] = list(iterable) def popleft(self): return self.pop(0) - + def rotate(self, n): for i in xrange(n): self.appendleft(self.pop()) - + def to_list(x, default=None): if x is None: return default @@ -188,18 +203,34 @@ def to_list(x, default=None): else: return x -def array_as_starargs_decorator(func): +def array_as_starargs_decorator(fn): """Interpret a single positional array argument as *args for the decorated method. 
- + """ + def starargs_as_list(self, *args, **kwargs): - if len(args) == 1: - return func(self, *to_list(args[0], []), **kwargs) + if isinstance(args, basestring) or (len(args) == 1 and not isinstance(args[0], tuple)): + return fn(self, *to_list(args[0], []), **kwargs) else: - return func(self, *args, **kwargs) - return starargs_as_list - + return fn(self, *args, **kwargs) + starargs_as_list.__doc__ = fn.__doc__ + return function_named(starargs_as_list, fn.__name__) + +def array_as_starargs_fn_decorator(fn): + """Interpret a single positional array argument as + *args for the decorated function. + + """ + + def starargs_as_list(*args, **kwargs): + if isinstance(args, basestring) or (len(args) == 1 and not isinstance(args[0], tuple)): + return fn(*to_list(args[0], []), **kwargs) + else: + return fn(*args, **kwargs) + starargs_as_list.__doc__ = fn.__doc__ + return function_named(starargs_as_list, fn.__name__) + def to_set(x): if x is None: return Set() @@ -281,14 +312,121 @@ def get_func_kwargs(func): """Return the full set of legal kwargs for the given `func`.""" return inspect.getargspec(func)[0] +def format_argspec_plus(fn, grouped=True): + """Returns a dictionary of formatted, introspected function arguments. + + A enhanced variant of inspect.formatargspec to support code generation. + + fn + An inspectable callable + grouped + Defaults to True; include (parens, around, argument) lists + + Returns: + + args + Full inspect.formatargspec for fn + self_arg + The name of the first positional argument, or None + apply_pos + args, re-written in calling rather than receiving syntax. Arguments are + passed positionally. + apply_kw + Like apply_pos, except keyword-ish args are passed as keywords. 
+ + Example:: + + >>> format_argspec_plus(lambda self, a, b, c=3, **d: 123) + {'args': '(self, a, b, c=3, **d)', + 'self_arg': 'self', + 'apply_kw': '(self, a, b, c=c, **d)', + 'apply_pos': '(self, a, b, c, **d)'} + + """ + spec = inspect.getargspec(fn) + args = inspect.formatargspec(*spec) + self_arg = spec[0] and spec[0][0] or None + apply_pos = inspect.formatargspec(spec[0], spec[1], spec[2]) + defaulted_vals = spec[3] is not None and spec[0][0-len(spec[3]):] or () + apply_kw = inspect.formatargspec(spec[0], spec[1], spec[2], defaulted_vals, + formatvalue=lambda x: '=' + x) + if grouped: + return dict(args=args, self_arg=self_arg, + apply_pos=apply_pos, apply_kw=apply_kw) + else: + return dict(args=args[1:-1], self_arg=self_arg, + apply_pos=apply_pos[1:-1], apply_kw=apply_kw[1:-1]) + +def format_argspec_init(method, grouped=True): + """format_argspec_plus with considerations for typical __init__ methods + + Wraps format_argspec_plus with error handling strategies for typical + __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return format_argspec_plus(method, grouped=grouped) + except TypeError: + self_arg = 'self' + if method is object.__init__: + args = grouped and '(self)' or 'self' + else: + args = (grouped and '(self, *args, **kwargs)' + or 'self, *args, **kwargs') + return dict(self_arg='self', args=args, apply_pos=args, apply_kw=args) + +def getargspec_init(method): + """inspect.getargspec with considerations for typical __init__ methods + + Wraps inspect.getargspec with error handling for typical __init__ cases:: + + object.__init__ -> (self) + other unreflectable (usually C) -> (self, *args, **kwargs) + + """ + try: + return inspect.getargspec(method) + except TypeError: + if method is object.__init__: + return (['self'], None, None, None) + else: + return (['self'], 'args', 'kwargs', None) + def unbound_method_to_callable(func_or_cls): """Adjust the incoming callable such that 
a 'self' argument is not required.""" - + if isinstance(func_or_cls, types.MethodType) and not func_or_cls.im_self: return func_or_cls.im_func else: return func_or_cls +def class_hierarchy(cls): + """Return an unordered sequence of all classes related to cls. + + Traverses diamond hierarchies. + + Fibs slightly: subclasses of builtin types are not returned. Thus + class_hierarchy(class A(object)) returns (A, object), not A plus every + class systemwide that derives from object. + + """ + hier = Set([cls]) + process = list(cls.__mro__) + while process: + c = process.pop() + for b in [_ for _ in c.__bases__ if _ not in hier]: + process.append(b) + hier.add(b) + if c.__module__ == '__builtin__': + continue + for s in [_ for _ in c.__subclasses__() if _ not in hier]: + process.append(s) + hier.add(s) + return list(hier) + # from paste.deploy.converters def asbool(obj): if isinstance(obj, (str, unicode)): @@ -328,9 +466,12 @@ def duck_type_collection(specimen, default=None): return specimen.__emulates__ isa = isinstance(specimen, type) and issubclass or isinstance - if isa(specimen, list): return list - if isa(specimen, set_types): return Set - if isa(specimen, dict): return dict + if isa(specimen, list): + return list + elif isa(specimen, set_types): + return Set + elif isa(specimen, dict): + return dict if hasattr(specimen, 'append'): return list @@ -370,10 +511,23 @@ def assert_arg_type(arg, argtype, name): return arg else: if isinstance(argtype, tuple): - raise exceptions.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) + raise exc.ArgumentError("Argument '%s' is expected to be one of type %s, got '%s'" % (name, ' or '.join(["'%s'" % str(a) for a in argtype]), str(type(arg)))) else: - raise exceptions.ArgumentError("Argument '%s' is expected to be of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) + raise exc.ArgumentError("Argument '%s' is expected to be 
of type '%s', got '%s'" % (name, str(argtype), str(type(arg)))) +_creation_order = 1 +def set_creation_order(instance): + """assign a '_creation_order' sequence to the given instance. + + This allows multiple instances to be sorted in order of + creation (typically within a single thread; the counter is + not particularly threadsafe). + + """ + global _creation_order + instance._creation_order = _creation_order + _creation_order +=1 + def warn_exception(func, *args, **kwargs): """executes the given function, catches all exceptions and converts to a warning.""" try: @@ -430,22 +584,22 @@ class SimpleProperty(object): class NotImplProperty(object): - """a property that raises ``NotImplementedError``.""" + """a property that raises ``NotImplementedError``.""" - def __init__(self, doc): - self.__doc__ = doc + def __init__(self, doc): + self.__doc__ = doc - def __set__(self, obj, value): - raise NotImplementedError() + def __set__(self, obj, value): + raise NotImplementedError() - def __delete__(self, obj): - raise NotImplementedError() + def __delete__(self, obj): + raise NotImplementedError() - def __get__(self, obj, owner): - if obj is None: - return self - else: - raise NotImplementedError() + def __get__(self, obj, owner): + if obj is None: + return self + else: + raise NotImplementedError() class OrderedProperties(object): """An object that maintains the order in which attributes are set upon it. 
@@ -496,10 +650,10 @@ class OrderedProperties(object): def __contains__(self, key): return key in self._data - + def update(self, value): self._data.update(value) - + def get(self, key, default=None): if key in self: return self[key] @@ -529,7 +683,10 @@ class OrderedDict(dict): def clear(self): self._list = [] dict.clear(self) - + + def sort(self, fn=None): + self._list.sort(fn) + def update(self, ____sequence=None, **kwargs): if ____sequence is not None: if hasattr(____sequence, 'keys'): @@ -622,22 +779,24 @@ class OrderedSet(Set): if d is not None: self.update(d) - def add(self, key): - if key not in self: - self._list.append(key) - Set.add(self, key) + def add(self, element): + if element not in self: + self._list.append(element) + Set.add(self, element) def remove(self, element): Set.remove(self, element) self._list.remove(element) + def insert(self, pos, element): + if element not in self: + self._list.insert(pos, element) + Set.add(self, element) + def discard(self, element): - try: - Set.remove(self, element) - except KeyError: - pass - else: + if element in self: self._list.remove(element) + Set.remove(self, element) def clear(self): Set.clear(self) @@ -650,22 +809,22 @@ class OrderedSet(Set): return iter(self._list) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self._list) + return '%s(%r)' % (self.__class__.__name__, self._list) __str__ = __repr__ def update(self, iterable): - add = self.add - for i in iterable: - add(i) - return self + add = self.add + for i in iterable: + add(i) + return self __ior__ = update def union(self, other): - result = self.__class__(self) - result.update(other) - return result + result = self.__class__(self) + result.update(other) + return result __or__ = union @@ -698,10 +857,10 @@ class OrderedSet(Set): __iand__ = intersection_update def symmetric_difference_update(self, other): - Set.symmetric_difference_update(self, other) - self._list = [ a for a in self._list if a in self] - self._list += [ a for a in 
other._list if a in self] - return self + Set.symmetric_difference_update(self, other) + self._list = [ a for a in self._list if a in self] + self._list += [ a for a in other._list if a in self] + return self __ixor__ = symmetric_difference_update @@ -1021,6 +1180,35 @@ class ScopedRegistry(object): def _get_key(self): return self.scopefunc() +class WeakCompositeKey(object): + """an weak-referencable, hashable collection which is strongly referenced + until any one of its members is garbage collected. + + """ + keys = Set() + + def __init__(self, *args): + self.args = [self.__ref(arg) for arg in args] + WeakCompositeKey.keys.add(self) + + def __ref(self, arg): + if isinstance(arg, type): + return weakref.ref(arg, self.__remover) + else: + return lambda: arg + + def __remover(self, wr): + WeakCompositeKey.keys.discard(self) + + def __hash__(self): + return hash(tuple(self)) + + def __cmp__(self, other): + return cmp(tuple(self), tuple(other)) + + def __iter__(self): + return iter([arg() for arg in self.args]) + class _symbol(object): def __init__(self, name): """Construct a new named symbol.""" @@ -1059,7 +1247,6 @@ class symbol(object): finally: symbol._lock.release() - def as_interface(obj, cls=None, methods=None, required=None): """Ensure basic interface compliance for an instance or dict of callables. 
@@ -1155,21 +1342,12 @@ def function_named(fn, name): fn.func_defaults, fn.func_closure) return fn -def conditional_cache_decorator(func): - """apply conditional caching to the return value of a function.""" - - return cache_decorator(func, conditional=True) - -def cache_decorator(func, conditional=False): +def cache_decorator(func): """apply caching to the return value of a function.""" name = '_cached_' + func.__name__ - + def do_with_cache(self, *args, **kwargs): - if conditional: - cache = kwargs.pop('cache', False) - if not cache: - return func(self, *args, **kwargs) try: return getattr(self, name) except AttributeError: @@ -1177,21 +1355,109 @@ def cache_decorator(func, conditional=False): setattr(self, name, value) return value return do_with_cache - + def reset_cached(instance, name): try: delattr(instance, '_cached_' + name) except AttributeError: pass +class WeakIdentityMapping(weakref.WeakKeyDictionary): + """A WeakKeyDictionary with an object identity index. + + Adds a .by_id dictionary to a regular WeakKeyDictionary. Trades + performance during mutation operations for accelerated lookups by id(). + + The usual cautions about weak dictionaries and iteration also apply to + this subclass. 
+ + """ + _none = symbol('none') + + def __init__(self): + weakref.WeakKeyDictionary.__init__(self) + self.by_id = {} + self._weakrefs = {} + + def __setitem__(self, object, value): + oid = id(object) + self.by_id[oid] = value + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + weakref.WeakKeyDictionary.__setitem__(self, object, value) + + def __delitem__(self, object): + del self._weakrefs[id(object)] + del self.by_id[id(object)] + weakref.WeakKeyDictionary.__delitem__(self, object) + + def setdefault(self, object, default=None): + value = weakref.WeakKeyDictionary.setdefault(self, object, default) + oid = id(object) + if value is default: + self.by_id[oid] = default + if oid not in self._weakrefs: + self._weakrefs[oid] = self._ref(object) + return value + + def pop(self, object, default=_none): + if default is self._none: + value = weakref.WeakKeyDictionary.pop(self, object) + else: + value = weakref.WeakKeyDictionary.pop(self, object, default) + if id(object) in self.by_id: + del self._weakrefs[id(object)] + del self.by_id[id(object)] + return value + + def popitem(self): + item = weakref.WeakKeyDictionary.popitem(self) + oid = id(item[0]) + del self._weakrefs[oid] + del self.by_id[oid] + return item + + def clear(self): + self._weakrefs.clear() + self.by_id.clear() + weakref.WeakKeyDictionary.clear(self) + + def update(self, *a, **kw): + raise NotImplementedError + + def _cleanup(self, wr, key=None): + if key is None: + key = wr.key + try: + del self._weakrefs[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + try: + del self.by_id[key] + except (KeyError, AttributeError): # pragma: no cover + pass # pragma: no cover + if sys.version_info < (2, 4): # pragma: no cover + def _ref(self, object): + oid = id(object) + return weakref.ref(object, lambda wr: self._cleanup(wr, oid)) + else: + class _keyed_weakref(weakref.ref): + def __init__(self, object, callback): + weakref.ref.__init__(self, object, 
callback) + self.key = id(object) + + def _ref(self, object): + return self._keyed_weakref(object, self._cleanup) + + def warn(msg): if isinstance(msg, basestring): - warnings.warn(msg, exceptions.SAWarning, stacklevel=3) + warnings.warn(msg, exc.SAWarning, stacklevel=3) else: warnings.warn(msg, stacklevel=3) def warn_deprecated(msg): - warnings.warn(msg, exceptions.SADeprecationWarning, stacklevel=3) + warnings.warn(msg, exc.SADeprecationWarning, stacklevel=3) def deprecated(message=None, add_deprecation_to_docstring=True): """Decorates a function and issues a deprecation warning on use. @@ -1216,7 +1482,7 @@ def deprecated(message=None, add_deprecation_to_docstring=True): def decorate(fn): return _decorate_with_warning( - fn, exceptions.SADeprecationWarning, + fn, exc.SADeprecationWarning, message % dict(func=fn.__name__), header) return decorate @@ -1248,7 +1514,7 @@ def pending_deprecation(version, message=None, def decorate(fn): return _decorate_with_warning( - fn, exceptions.SAPendingDeprecationWarning, + fn, exc.SAPendingDeprecationWarning, message % dict(func=fn.__name__), header) return decorate |