diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-03-31 04:27:05 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-03-31 04:27:05 +0000 |
commit | 90ffb177ed88cac43d4c3cbdc568d0d0a93fd579 (patch) | |
tree | 50366ca0dd25b1acaadde48d0045260173a11ad8 /lib/sqlalchemy/mapping/mapper.py | |
parent | 17e341e36714fa4d87b1dc6e95618538ce038161 (diff) | |
download | sqlalchemy-90ffb177ed88cac43d4c3cbdc568d0d0a93fd579.tar.gz |
starting to refactor mapper slightly, adding entity_name, version_id_col, allowing keywords in mapper.options()
Diffstat (limited to 'lib/sqlalchemy/mapping/mapper.py')
-rw-r--r-- | lib/sqlalchemy/mapping/mapper.py | 51 |
1 files changed, 23 insertions, 28 deletions
diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index f8faea855..7e12459c5 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -40,7 +40,9 @@ class Mapper(object): extension = None, order_by = False, allow_column_override = False, + entity_name = None, always_refresh = False, + version_id_col = None, **kwargs): if primarytable is not None: @@ -55,6 +57,8 @@ class Mapper(object): self.order_by = order_by self._options = {} self.always_refresh = always_refresh + self.entity_name = entity_name + self.version_id_col = version_id_col if not issubclass(class_, object): raise ArgumentError("Class '%s' is not a new-style class" % class_.__name__) @@ -85,7 +89,7 @@ class Mapper(object): # stricter set of tables to create "sync rules" by,based on the immediate # inherited table, rather than all inherited tables self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY) - self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), TableFinder(table)) + self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), mapperutil.TableFinder(table)) # the old rule #self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table)) else: @@ -100,7 +104,7 @@ class Mapper(object): # locate all tables contained within the "table" passed in, which # may be a join or other construct - self.tables = TableFinder(self.table) + self.tables = mapperutil.TableFinder(self.table) # determine primary key columns, either passed in, or get them from our set of tables self.pks_by_table = {} @@ -350,9 +354,10 @@ class Mapper(object): compiling or executing it""" return self._compile(whereclause, **options) - def copy(self): + def copy(self, **kwargs): mapper = Mapper.__new__(Mapper) mapper.__dict__.update(self.__dict__) + mapper.__dict__.update(kwargs) mapper.props = self.props.copy() return mapper @@ -374,7 +379,7 @@ class Mapper(object): return callit return Proxy() - def options(self, *options): + def options(self, *options, **kwargs): """uses this mapper as a prototype for a new mapper with different behavior. *options is a list of options directives, which include eagerload(), lazyload(), and noload()""" @@ -382,7 +387,7 @@ class Mapper(object): try: return self._options[optkey] except KeyError: - mapper = self.copy() + mapper = self.copy(**kwargs) for option in options: option.process(mapper) self._options[optkey] = mapper @@ -610,7 +615,13 @@ class Mapper(object): self.extension.before_update(self, obj) hasdata = False for col in table.columns: - if self.pks_by_table[table].contains(col): + if col is self.version_id_col: + if not isinsert: + params[col._label] = self._getattrbycolumn(obj, col) + params[col.key] = params[col._label] + 1 + else: + params[col.key] = 1 + elif self.pks_by_table[table].contains(col): # column is a primary key ? if not isinsert: # doing an UPDATE? put primary key values as "WHERE" parameters @@ -664,6 +675,8 @@ class Mapper(object): clause = sql.and_() for col in self.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label)) + if self.version_id_col is not None: + clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label)) statement = table.update(clause) rows = 0 for rec in update: @@ -729,11 +742,15 @@ class Mapper(object): delete.append(params) for col in self.pks_by_table[table]: params[col.key] = self._getattrbycolumn(obj, col) + if self.version_id_col is not None: + params[self.version_id_col.key] = self._getattrbycolumn(obj, self.version_id_col) self.extension.before_delete(self, obj) if len(delete): clause = sql.and_() for col in self.pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key)) + if self.version_id_col is not None: + clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key)) statement = table.delete(clause) c = statement.execute(*delete) if table.engine.supports_sane_rowcount() and c.rowcount != len(delete): @@ -1036,28 +1053,6 @@ class MapperExtension(object): if self.next is not None: self.next.before_delete(mapper, instance) -class TableFinder(sql.ClauseVisitor): - """given a Clause, locates all the Tables within it into a list.""" - def __init__(self, table, check_columns=False): - self.tables = [] - self.check_columns = check_columns - if table is not None: - table.accept_visitor(self) - def visit_table(self, table): - self.tables.append(table) - def __len__(self): - return len(self.tables) - def __getitem__(self, i): - return self.tables[i] - def __iter__(self): - return iter(self.tables) - def __contains__(self, obj): - return obj in self.tables - def __add__(self, obj): - return self.tables + list(obj) - def visit_column(self, column): - if self.check_columns: - column.table.accept_visitor(self) def hash_key(obj): if obj is None: |