summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/sync.py
blob: cf48202b0f9bcf0d3f2748b1c0bed7f12a16d2c6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# orm/sync.py
# Copyright (C) 2005, 2006, 2007 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

"""Contains the ClauseSynchronizer class, which is used to map
attributes between two objects in a manner corresponding to a SQL
clause that compares column values.
"""

from sqlalchemy import sql, schema, exceptions
from sqlalchemy import logging
from sqlalchemy.orm import util as mapperutil
import operator

# Relationship direction codes; ClauseSynchronizer.compile() compares its
# `direction` against these to decide which mapper is the sync source.
ONETOMANY, MANYTOONE, MANYTOMANY = range(3)

class ClauseSynchronizer(object):
    """Given a SQL clause, usually a series of one or more binary
    expressions between columns, and a set of 'source' and
    'destination' mappers, compiles a set of SyncRules corresponding
    to that information.

    The ClauseSynchronizer can then be executed given a set of
    parent/child objects or destination dictionary, which will iterate
    through each of its SyncRules and execute them.  Each SyncRule
    will copy the value of a single attribute from the parent to the
    child, corresponding to the pair of columns in a particular binary
    expression, using the source and destination mappers to map those
    two columns to object attributes within parent and child.
    """

    def __init__(self, parent_mapper, child_mapper, direction):
        self.parent_mapper = parent_mapper
        self.child_mapper = child_mapper
        # one of ONETOMANY, MANYTOONE, MANYTOMANY
        self.direction = direction
        self.syncrules = []

    def compile(self, sqlclause, foreign_keys=None, issecondary=None):
        """Append a SyncRule to self.syncrules for each column-equality
        comparison found in `sqlclause`.

        foreign_keys -- optional collection of columns to treat as the
            destination ("foreign key") side of each comparison.  When
            None, the ForeignKey declarations on the columns themselves
            decide which side is the destination.
        issecondary -- only consulted for MANYTOMANY: when false the
            parent mapper is the rule's source, when true the child
            mapper is; the flag is also stored on each rule.

        Raises ArgumentError when foreign_keys is None and a comparison
        joins a table to itself, or when no rule could be generated.
        """
        def references(fk_column, target):
            """Return True if one of fk_column's foreign keys points at target."""
            # Compare by identity: Column.__eq__ builds a SQL expression
            # instead of returning a bool, so `target in [...]` would
            # depend on that expression object's truthiness.
            for fk in fk_column.foreign_keys:
                if fk.column is target:
                    return True
            return False

        def compile_binary(binary):
            """Assemble a SyncRule given a single binary condition."""
            # Only "column == column" comparisons can be synchronized.
            if binary.operator != operator.eq or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                return

            source_column = None
            dest_column = None

            if foreign_keys is None:
                if binary.left.table == binary.right.table:
                    raise exceptions.ArgumentError("need foreign_keys argument for self-referential sync")

                # The column carrying the ForeignKey is the destination.
                if references(binary.right, binary.left):
                    dest_column = binary.right
                    source_column = binary.left
                elif references(binary.left, binary.right):
                    dest_column = binary.left
                    source_column = binary.right
            else:
                if binary.left in foreign_keys:
                    source_column = binary.right
                    dest_column = binary.left
                elif binary.right in foreign_keys:
                    source_column = binary.left
                    dest_column = binary.right

            if source_column is None or dest_column is None:
                return

            if self.direction == ONETOMANY:
                self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper))
            elif self.direction == MANYTOONE:
                self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper))
            else:
                if not issecondary:
                    self.syncrules.append(SyncRule(self.parent_mapper, source_column, dest_column, dest_mapper=self.child_mapper, issecondary=issecondary))
                else:
                    self.syncrules.append(SyncRule(self.child_mapper, source_column, dest_column, dest_mapper=self.parent_mapper, issecondary=issecondary))

        rules_before = len(self.syncrules)
        BinaryVisitor(compile_binary).traverse(sqlclause)
        if len(self.syncrules) == rules_before:
            raise exceptions.ArgumentError("No syncrules generated for join criterion " + str(sqlclause))

    def dest_columns(self):
        """Return the non-None destination columns of all compiled rules."""
        return [r.dest_column for r in self.syncrules if r.dest_column is not None]

    def execute(self, source, dest, obj=None, child=None, clearkeys=None):
        """Run every compiled rule, in order, with the same arguments."""
        for rule in self.syncrules:
            rule.execute(source, dest, obj, child, clearkeys)

class SyncRule(object):
    """A single attribute-copy instruction derived from one column pair.

    For a join such as ``table1.A == table2.B`` on a one-to-many from
    table1 to table2, the rule reads attribute A from the source object
    and writes it to attribute B of the destination.

    The rule holds the source mapper and column, the destination column,
    the destination mapper (needed when the destination is a mapped
    instance rather than a dict), and the ``issecondary`` flag, which
    selects whether ``obj`` or ``child`` acts as the source when
    execute() receives no explicit source object.
    """

    def __init__(self, source_mapper, source_column, dest_column, dest_mapper=None, issecondary=None):
        self.source_mapper = source_mapper
        self.source_column = source_column
        self.dest_column = dest_column
        self.dest_mapper = dest_mapper
        self.issecondary = issecondary

    def dest_primary_key(self):
        """Return True if dest_column is a primary-key column of dest_mapper.

        Always False when there is no destination mapper.  The answer is
        computed on first call and cached on the instance.
        """
        if not hasattr(self, '_dest_primary_key'):
            mapper = self.dest_mapper
            if mapper is None:
                is_pk = False
            else:
                is_pk = self.dest_column in mapper.pks_by_table[self.dest_column.table]
            self._dest_primary_key = is_pk
        return self._dest_primary_key

    def execute(self, source, dest, obj, child, clearkeys):
        """Copy (or blank out) the source attribute onto dest.

        source -- object to read from.  If None, ``obj`` is used when
            issecondary is False and ``child`` when issecondary is True.
        dest -- either a dict, which receives the value under
            dest_column.key, or an instance written through
            dest_mapper.set_attr_by_column().
        clearkeys -- when true, or when no source object is available,
            the destination is set to None instead of copied.

        Raises exceptions.AssertionError if a mapped destination instance
        would have a primary-key column blanked out.
        """
        if source is None:
            if self.issecondary is True:
                source = child
            elif self.issecondary is False:
                source = obj

        blanking = bool(clearkeys) or source is None
        if blanking:
            value = None
        else:
            value = self.source_mapper.get_attr_by_column(source, self.source_column)

        if isinstance(dest, dict):
            dest[self.dest_column.key] = value
            return

        if blanking and self.dest_primary_key():
            raise exceptions.AssertionError("Dependency rule tried to blank-out primary key column '%s' on instance '%s'" % (str(self.dest_column), mapperutil.instance_str(dest)))

        if logging.is_debug_enabled(self.logger):
            self.logger.debug("execute() instances: %s(%s)->%s(%s) ('%s')" % (mapperutil.instance_str(source), str(self.source_column), mapperutil.instance_str(dest), str(self.dest_column), value))
        self.dest_mapper.set_attr_by_column(dest, self.dest_column, value)

# Attach a class-level logger shared by all SyncRule instances;
# SyncRule.execute() checks it before emitting debug output.
setattr(SyncRule, 'logger', logging.class_logger(SyncRule))

class BinaryVisitor(sql.ClauseVisitor):
    """Clause visitor that passes every binary expression it visits to a
    callback; the callback's return value is ignored."""

    def __init__(self, func):
        # callable invoked as func(binary) for each binary expression
        self.func = func

    def visit_binary(self, binary):
        callback = self.func
        callback(binary)