summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/util.py')
-rw-r--r--lib/sqlalchemy/sql/util.py36
1 files changed, 24 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 0989cb43e..93998c9a9 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -52,30 +52,42 @@ def find_columns(clause):
def reduce_columns(columns, *clauses):
- raise NotImplementedError()
+ """given a list of columns, return a 'reduced' set based on natural equivalents.
+
+ the set is reduced to the smallest list of columns which have no natural
+ equivalent present in the list. A "natural equivalent" means that two columns
+ will ultimately represent the same value because they are related by a foreign key.
+
+ \*clauses is an optional list of join clauses which will be traversed
+ to further identify columns that are "equivalent".
- # TODO !!!
- all_proxied_cols = util.Set(chain(*[c.proxy_set for c in columns]))
+ This function is primarily used to determine the most minimal "primary key"
+ from a selectable, by reducing the set of primary key columns present
+ in the the selectable to just those that are not repeated.
+
+ """
columns = util.Set(columns)
- equivs = {}
+ omit = util.Set()
for col in columns:
for fk in col.foreign_keys:
- if fk.column in all_proxied_cols:
- for c in columns:
- if col.references(c):
- equivs[col] = c
+ for c in columns:
+ if c is col:
+ continue
+ if fk.column.shares_lineage(c):
+ omit.add(col)
+ break
if clauses:
def visit_binary(binary):
- if binary.operator == operators.eq and binary.left in columns and binary.right in columns:
- equivs[binary.left] = binary.right
+ cols = columns.difference(omit)
+ if binary.operator == operators.eq and binary.left in cols and binary.right in cols:
+ omit.add(binary.right)
for clause in clauses:
visitors.traverse(clause, visit_binary=visit_binary)
- result = util.Set([c for c in columns if c not in equivs])
- return expression.ColumnSet(result)
+ return expression.ColumnSet(columns.difference(omit))
class ColumnsInClause(visitors.ClauseVisitor):
"""Given a selectable, visit clauses and determine if any columns