summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/mapping/mapper.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-03-31 04:27:05 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-03-31 04:27:05 +0000
commit90ffb177ed88cac43d4c3cbdc568d0d0a93fd579 (patch)
tree50366ca0dd25b1acaadde48d0045260173a11ad8 /lib/sqlalchemy/mapping/mapper.py
parent17e341e36714fa4d87b1dc6e95618538ce038161 (diff)
downloadsqlalchemy-90ffb177ed88cac43d4c3cbdc568d0d0a93fd579.tar.gz
starting to refactor mapper slightly, adding entity_name, version_id_col, allowing keywords in mapper.options()
Diffstat (limited to 'lib/sqlalchemy/mapping/mapper.py')
-rw-r--r--lib/sqlalchemy/mapping/mapper.py51
1 files changed, 23 insertions, 28 deletions
diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py
index f8faea855..7e12459c5 100644
--- a/lib/sqlalchemy/mapping/mapper.py
+++ b/lib/sqlalchemy/mapping/mapper.py
@@ -40,7 +40,9 @@ class Mapper(object):
extension = None,
order_by = False,
allow_column_override = False,
+ entity_name = None,
always_refresh = False,
+ version_id_col = None,
**kwargs):
if primarytable is not None:
@@ -55,6 +57,8 @@ class Mapper(object):
self.order_by = order_by
self._options = {}
self.always_refresh = always_refresh
+ self.entity_name = entity_name
+ self.version_id_col = version_id_col
if not issubclass(class_, object):
raise ArgumentError("Class '%s' is not a new-style class" % class_.__name__)
@@ -85,7 +89,7 @@ class Mapper(object):
# stricter set of tables to create "sync rules" by,based on the immediate
# inherited table, rather than all inherited tables
self._synchronizer = sync.ClauseSynchronizer(self, self, sync.ONETOMANY)
- self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), TableFinder(table))
+ self._synchronizer.compile(self.table.onclause, util.HashSet([inherits.noninherited_table]), mapperutil.TableFinder(table))
# the old rule
#self._synchronizer.compile(self.table.onclause, inherits.tables, TableFinder(table))
else:
@@ -100,7 +104,7 @@ class Mapper(object):
# locate all tables contained within the "table" passed in, which
# may be a join or other construct
- self.tables = TableFinder(self.table)
+ self.tables = mapperutil.TableFinder(self.table)
# determine primary key columns, either passed in, or get them from our set of tables
self.pks_by_table = {}
@@ -350,9 +354,10 @@ class Mapper(object):
compiling or executing it"""
return self._compile(whereclause, **options)
- def copy(self):
+ def copy(self, **kwargs):
mapper = Mapper.__new__(Mapper)
mapper.__dict__.update(self.__dict__)
+ mapper.__dict__.update(kwargs)
mapper.props = self.props.copy()
return mapper
@@ -374,7 +379,7 @@ class Mapper(object):
return callit
return Proxy()
- def options(self, *options):
+ def options(self, *options, **kwargs):
"""uses this mapper as a prototype for a new mapper with different behavior.
*options is a list of options directives, which include eagerload(), lazyload(), and noload()"""
@@ -382,7 +387,7 @@ class Mapper(object):
try:
return self._options[optkey]
except KeyError:
- mapper = self.copy()
+ mapper = self.copy(**kwargs)
for option in options:
option.process(mapper)
self._options[optkey] = mapper
@@ -610,7 +615,13 @@ class Mapper(object):
self.extension.before_update(self, obj)
hasdata = False
for col in table.columns:
- if self.pks_by_table[table].contains(col):
+ if col is self.version_id_col:
+ if not isinsert:
+ params[col._label] = self._getattrbycolumn(obj, col)
+ params[col.key] = params[col._label] + 1
+ else:
+ params[col.key] = 1
+ elif self.pks_by_table[table].contains(col):
# column is a primary key ?
if not isinsert:
# doing an UPDATE? put primary key values as "WHERE" parameters
@@ -664,6 +675,8 @@ class Mapper(object):
clause = sql.and_()
for col in self.pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col._label))
+ if self.version_id_col is not None:
+ clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col._label))
statement = table.update(clause)
rows = 0
for rec in update:
@@ -729,11 +742,15 @@ class Mapper(object):
delete.append(params)
for col in self.pks_by_table[table]:
params[col.key] = self._getattrbycolumn(obj, col)
+ if self.version_id_col is not None:
+ params[self.version_id_col.key] = self._getattrbycolumn(obj, self.version_id_col)
self.extension.before_delete(self, obj)
if len(delete):
clause = sql.and_()
for col in self.pks_by_table[table]:
clause.clauses.append(col == sql.bindparam(col.key))
+ if self.version_id_col is not None:
+ clause.clauses.append(self.version_id_col == sql.bindparam(self.version_id_col.key))
statement = table.delete(clause)
c = statement.execute(*delete)
if table.engine.supports_sane_rowcount() and c.rowcount != len(delete):
@@ -1036,28 +1053,6 @@ class MapperExtension(object):
if self.next is not None:
self.next.before_delete(mapper, instance)
-class TableFinder(sql.ClauseVisitor):
- """given a Clause, locates all the Tables within it into a list."""
- def __init__(self, table, check_columns=False):
- self.tables = []
- self.check_columns = check_columns
- if table is not None:
- table.accept_visitor(self)
- def visit_table(self, table):
- self.tables.append(table)
- def __len__(self):
- return len(self.tables)
- def __getitem__(self, i):
- return self.tables[i]
- def __iter__(self):
- return iter(self.tables)
- def __contains__(self, obj):
- return obj in self.tables
- def __add__(self, obj):
- return self.tables + list(obj)
- def visit_column(self, column):
- if self.check_columns:
- column.table.accept_visitor(self)
def hash_key(obj):
if obj is None: