summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/mapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/mapper.py')
-rw-r--r--lib/sqlalchemy/mapper.py57
1 files changed, 41 insertions, 16 deletions
diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py
index 855c7e289..ae8f70859 100644
--- a/lib/sqlalchemy/mapper.py
+++ b/lib/sqlalchemy/mapper.py
@@ -22,7 +22,7 @@ import sqlalchemy.util as util
import sqlalchemy.objectstore as objectstore
import random, copy, types
-__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql']
+__ALL__ = ['eagermapper', 'eagerloader', 'lazymapper', 'lazyloader', 'eagerload', 'lazyload', 'assignmapper', 'mapper', 'lazyloader', 'lazymapper', 'clear_mappers', 'objectstore', 'sql', 'MapperExtension']
def relation(*args, **params):
if isinstance(args[0], type) and len(args) == 1:
@@ -32,14 +32,16 @@ def relation(*args, **params):
else:
return relation_mapper(*args, **params)
-def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin = None, lazy = True, **options):
+def relation_loader(mapper, secondary = None, primaryjoin = None, secondaryjoin = None, lazy = True, **kwargs):
if lazy:
- return LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
+ return LazyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
+ elif lazy is None:
+ return PropertyLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
else:
- return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **options)
+ return EagerLoader(mapper, secondary, primaryjoin, secondaryjoin, **kwargs)
-def relation_mapper(class_, table = None, secondary = None, primaryjoin = None, secondaryjoin = None, primarytable = None, properties = None, lazy = True, foreignkey = None, primary_keys = None, thiscol = None, **options):
- return relation_loader(mapper(class_, table, primarytable=primarytable, properties=properties, primary_keys=primary_keys, **options), secondary, primaryjoin, secondaryjoin, lazy = lazy, foreignkey = foreignkey, thiscol = thiscol, **options)
+def relation_mapper(class_, table=None, secondary=None, primaryjoin=None, secondaryjoin=None, **kwargs):
+ return relation_loader(mapper(class_, table, **kwargs), secondary, primaryjoin, secondaryjoin, **kwargs)
class assignmapper(object):
def __init__(self, table, **kwargs):
@@ -110,8 +112,13 @@ class Mapper(object):
is_primary = False,
inherits = None,
inherit_condition = None,
+ extension = None,
**kwargs):
-
+
+ if extension is None:
+ self.extension = MapperExtension()
+ else:
+ self.extension = extension
self.hashkey = hashkey
self.class_ = class_
self.scope = scope
@@ -174,6 +181,7 @@ class Mapper(object):
if properties is not None:
for key, prop in properties.iteritems():
if isinstance(prop, schema.Column):
+ self.columns[key] = prop
prop = ColumnProperty(prop)
self.props[key] = prop
if isinstance(prop, ColumnProperty):
@@ -264,8 +272,11 @@ class Mapper(object):
except IndexError:
return None
- def identity_key(self, instance):
- return objectstore.get_id_key(tuple([self._getattrbycolumn(instance, column) for column in self.primary_keys[self.table]]), self.class_, self.primarytable)
+ def identity_key(self, *primary_keys):
+ return objectstore.get_id_key(tuple(primary_keys), self.class_, self.primarytable)
+
+ def instance_key(self, instance):
+ return self.identity_key(**[self._getattrbycolumn(instance, column) for column in self.primary_keys[self.table]])
def compile(self, whereclause = None, **options):
"""works like select, except returns the SQL statement object without
@@ -418,8 +429,6 @@ class Mapper(object):
identitykey = self._identity_key(row)
if objectstore.uow().has_key(identitykey):
instance = objectstore.uow()._get(identitykey)
- if result is not None:
- result.append_nohistory(instance)
if populate_existing:
isnew = not imap.has_key(identitykey)
@@ -428,6 +437,10 @@ class Mapper(object):
for prop in self.props.values():
prop.execute(instance, row, identitykey, imap, isnew)
+ if self.extension.append_result(self, row, imap, result, instance, populate_existing=populate_existing):
+ if result is not None:
+ result.append_nohistory(instance)
+
return instance
# look in result-local identitymap for it.
@@ -439,7 +452,8 @@ class Mapper(object):
if row[col.label] is None:
return None
# plugin point
- instance = self.class_()
+ if self.extension.create_instance(self, row, imap, self.class_) is None:
+ instance = self.class_()
instance._mapper = self.hashkey
instance._instance_key = identitykey
@@ -449,8 +463,6 @@ class Mapper(object):
instance = imap[identitykey]
isnew = False
- if result is not None:
- result.append_nohistory(instance)
# plugin point
@@ -458,6 +470,10 @@ class Mapper(object):
# instances from the row and possibly populate this item.
for prop in self.props.values():
prop.execute(instance, row, identitykey, imap, isnew)
+
+ if self.extension.append_result(self, row, imap, result, instance, populate_existing=populate_existing):
+ if result is not None:
+ result.append_nohistory(instance)
return instance
@@ -533,7 +549,7 @@ class PropertyLoader(MapperProperty):
"""describes an object property that holds a single item or list of items that correspond to a related
database table."""
- def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None):
+ def __init__(self, argument, secondary, primaryjoin, secondaryjoin, foreignkey = None, uselist = None, private = False, thiscol = None, **kwargs):
self.uselist = uselist
self.argument = argument
self.secondary = secondary
@@ -806,7 +822,10 @@ class PropertyLoader(MapperProperty):
associationrow[colmap[self.secondary].key] = self.parent._getattrbycolumn(obj, colmap[self.parent.primarytable])
elif colmap.has_key(self.target) and colmap.has_key(self.secondary):
associationrow[colmap[self.secondary].key] = self.mapper._getattrbycolumn(child, colmap[self.target])
-
+
+ def execute(self, instance, row, identitykey, imap, isnew):
+ pass
+
class LazyLoader(PropertyLoader):
def execute(self, instance, row, identitykey, imap, isnew):
if isnew:
@@ -962,6 +981,12 @@ class BinaryVisitor(sql.ClauseVisitor):
def visit_binary(self, binary):
self.func(binary)
+
+class MapperExtension(object):
+ def create_instance(self, mapper, row, imap, class_):
+ return None
+ def append_result(self, mapper, row, imap, result, instance, populate_existing=False):
+ return True
def hash_key(obj):
if obj is None: