summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-03-28 07:19:14 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-03-28 07:19:14 +0000
commite0b638a704f7f0abf88d1a80a95cf052954e048c (patch)
tree6b8717e491820322c48b5b271ec6b192fa8a86cd /lib/sqlalchemy
parentccbcbda43e74a1d09d50aa2f8212b3cb9adafd23 (diff)
downloadsqlalchemy-e0b638a704f7f0abf88d1a80a95cf052954e048c.tar.gz
- column label and bind param "truncation" also generate
deterministic names now, based on their ordering within the full statement being compiled. this means the same statement will produce the same string across application restarts and allowing DB query plan caching to work better. - cleanup to sql.ClauseParameters since it was just falling apart, API made more explicit - many unit test tweaks to adjust for bind params not being "pre" truncated, changes to ClauseParameters
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ansisql.py77
-rw-r--r--lib/sqlalchemy/engine/base.py2
-rw-r--r--lib/sqlalchemy/engine/default.py12
-rw-r--r--lib/sqlalchemy/ext/sqlsoup.py4
-rw-r--r--lib/sqlalchemy/orm/mapper.py6
-rw-r--r--lib/sqlalchemy/schema.py1
-rw-r--r--lib/sqlalchemy/sql.py50
7 files changed, 100 insertions, 52 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 050e605eb..a75263d91 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -132,6 +132,11 @@ class ANSICompiler(sql.Compiled):
# a dictionary of select columns labels mapped to their "generated" label
self.column_labels = {}
+ # a dictionary of ClauseElement subclasses to counters, which are used to
+ # generate truncated identifier names or "anonymous" identifiers such as
+ # for aliases
+ self.generated_ids = {}
+
# True if this compiled represents an INSERT
self.isinsert = False
@@ -242,24 +247,27 @@ class ANSICompiler(sql.Compiled):
return ""
def visit_label(self, label):
- labelname = label.name
- if len(labelname) >= self.dialect.max_identifier_length():
- labelname = labelname[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:]
+ labelname = self._truncated_identifier("colident", label.name)
if len(self.select_stack):
self.typemap.setdefault(labelname.lower(), label.obj.type)
if isinstance(label.obj, sql._ColumnClause):
- self.column_labels[label.obj._label] = labelname.lower()
+ self.column_labels[label.obj._label] = labelname
self.strings[label] = self.strings[label.obj] + " AS " + self.preparer.format_label(label, labelname)
def visit_column(self, column):
- if len(self.select_stack):
- # if we are within a visit to a Select, set up the "typemap"
- # for this column which is used to translate result set values
- self.typemap.setdefault(column.name.lower(), column.type)
- self.column_labels.setdefault(column._label, column.name.lower())
+ # there is actually somewhat of a ruleset when you would *not* necessarily
+ # want to truncate a column identifier, if its mapped to the name of a
+ # physical column. but thats very hard to identify at this point, and
+ # the identifier length should be greater than the id lengths of any physical
+ # columns so should not matter.
+ if not column.is_literal:
+ name = self._truncated_identifier("colident", column.name)
+ else:
+ name = column.name
+
if column.table is None or not column.table.named_with_column():
- self.strings[column] = self.preparer.format_column(column)
+ self.strings[column] = self.preparer.format_column(column, name=name)
else:
if column.table.oid_column is column:
n = self.dialect.oid_column_name(column)
@@ -270,7 +278,13 @@ class ANSICompiler(sql.Compiled):
else:
self.strings[column] = None
else:
- self.strings[column] = self.preparer.format_column_with_table(column)
+ self.strings[column] = self.preparer.format_column_with_table(column, column_name=name)
+
+ if len(self.select_stack):
+ # if we are within a visit to a Select, set up the "typemap"
+ # for this column which is used to translate result set values
+ self.typemap.setdefault(name.lower(), column.type)
+ self.column_labels.setdefault(column._label, name.lower())
def visit_fromclause(self, fromclause):
self.froms[fromclause] = fromclause.name
@@ -394,11 +408,23 @@ class ANSICompiler(sql.Compiled):
bind_name = bindparam.key
if len(bind_name) >= self.dialect.max_identifier_length():
- bind_name = bind_name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:]
+ bind_name = self._truncated_identifier("bindparam", bind_name)
# add to bind_names for translation
self.bind_names[bindparam] = bind_name
return bind_name
-
+
+ def _truncated_identifier(self, ident_class, name):
+ if (ident_class, name) in self.generated_ids:
+ return self.generated_ids[(ident_class, name)]
+ if len(name) >= self.dialect.max_identifier_length():
+ counter = self.generated_ids.get(ident_class, 1)
+ truncname = name[0:self.dialect.max_identifier_length() - 6] + "_" + hex(counter)[2:]
+ self.generated_ids[ident_class] = counter + 1
+ else:
+ truncname = name
+ self.generated_ids[(ident_class, name)] = truncname
+ return truncname
+
def bindparam_string(self, name):
return self.bindtemplate % name
@@ -1043,30 +1069,33 @@ class ANSIIdentifierPreparer(object):
def format_alias(self, alias):
return self.__generic_obj_format(alias, alias.name)
- def format_table(self, table, use_schema=True):
+ def format_table(self, table, use_schema=True, name=None):
"""Prepare a quoted table and schema name."""
- result = self.__generic_obj_format(table, table.name)
+ if name is None:
+ name = table.name
+ result = self.__generic_obj_format(table, name)
if use_schema and getattr(table, "schema", None):
result = self.__generic_obj_format(table, table.schema) + "." + result
return result
- def format_column(self, column, use_table=False):
+ def format_column(self, column, use_table=False, name=None):
"""Prepare a quoted column name."""
-
+ if name is None:
+ name = column.name
if not getattr(column, 'is_literal', False):
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, column.name)
+ return self.format_table(column.table, use_schema=False) + "." + self.__generic_obj_format(column, name)
else:
- return self.__generic_obj_format(column, column.name)
+ return self.__generic_obj_format(column, name)
else:
# literal textual elements get stuck into ColumnClause alot, which shouldnt get quoted
if use_table:
- return self.format_table(column.table, use_schema=False) + "." + column.name
+ return self.format_table(column.table, use_schema=False) + "." + name
else:
- return column.name
+ return name
- def format_column_with_table(self, column):
+ def format_column_with_table(self, column, column_name=None):
"""Prepare a quoted column name with table name."""
-
- return self.format_column(column, use_table=True)
+
+ return self.format_column(column, use_table=True, name=column_name)
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 84ad6478f..c2ae272f5 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -883,7 +883,7 @@ class ResultProxy(object):
elif isinstance(key, basestring) and key.lower() in self.props:
rec = self.props[key.lower()]
elif isinstance(key, sql.ColumnElement):
- label = self.column_labels.get(key._label, key.name)
+ label = self.column_labels.get(key._label, key.name).lower()
if label in self.props:
rec = self.props[label]
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 798d02d32..e9ea6c149 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -255,20 +255,20 @@ class DefaultExecutionContext(base.ExecutionContext):
# its a pk, add the value to our last_inserted_ids list,
# or, if its a SQL-side default, dont do any of that, but we'll need
# the SQL-generated value after execution.
- elif not param.has_key(c.key) or param[c.key] is None:
+ elif not c.key in param or param.get_original(c.key) is None:
if isinstance(c.default, schema.PassiveDefault):
self._lastrow_has_defaults = True
newid = drunner.get_column_default(c)
if newid is not None:
- param[c.key] = newid
+ param.set_value(c.key, newid)
if c.primary_key:
- last_inserted_ids.append(param[c.key])
+ last_inserted_ids.append(param.get_processed(c.key))
elif c.primary_key:
need_lastrowid = True
# its an explicitly passed pk value - add it to
# our last_inserted_ids list.
elif c.primary_key:
- last_inserted_ids.append(param[c.key])
+ last_inserted_ids.append(param.get_processed(c.key))
if need_lastrowid:
self._last_inserted_ids = None
else:
@@ -290,8 +290,8 @@ class DefaultExecutionContext(base.ExecutionContext):
pass
# its not in the bind parameters, and theres an "onupdate" defined for the column;
# execute it and add to bind params
- elif c.onupdate is not None and (not param.has_key(c.key) or param[c.key] is None):
+ elif c.onupdate is not None and (not c.key in param or param.get_original(c.key) is None):
value = drunner.get_column_onupdate(c)
if value is not None:
- param[c.key] = value
+ param.set_value(c.key, value)
self._last_updated_params = param
diff --git a/lib/sqlalchemy/ext/sqlsoup.py b/lib/sqlalchemy/ext/sqlsoup.py
index b899c043d..21c1fac51 100644
--- a/lib/sqlalchemy/ext/sqlsoup.py
+++ b/lib/sqlalchemy/ext/sqlsoup.py
@@ -187,13 +187,13 @@ If you join tables that have an identical column name, wrap your join
with `with_labels`, to disambiguate columns with their table name::
>>> db.with_labels(join1).c.keys()
- ['users_name', 'users_email', 'users_password', 'users_classname', 'users_admin', 'loans_book_id', 'loans_user_name', 'loans_loan_date']
+ [u'users_name', u'users_email', u'users_password', u'users_classname', u'users_admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']
You can also join directly to a labeled object::
>>> labeled_loans = db.with_labels(db.loans)
>>> db.join(db.users, labeled_loans, isouter=True).c.keys()
- ['name', 'email', 'password', 'classname', 'admin', 'loans_book_id', 'loans_user_name', 'loans_loan_date']
+ [u'name', u'email', u'password', u'classname', u'admin', u'loans_book_id', u'loans_user_name', u'loans_loan_date']
Advanced Use
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 3d7ddb5d6..0279cca53 100644
--- a/lib/sqlalchemy/orm/mapper.py
+++ b/lib/sqlalchemy/orm/mapper.py
@@ -1237,13 +1237,13 @@ class Mapper(object):
self.set_attr_by_column(obj, c, row[c])
else:
for c in table.c:
- if c.primary_key or not params.has_key(c.name):
+ if c.primary_key or not c.key in params:
continue
v = self.get_attr_by_column(obj, c, False)
if v is NO_ATTRIBUTE:
continue
- elif v != params.get_original(c.name):
- self.set_attr_by_column(obj, c, params.get_original(c.name))
+ elif v != params.get_original(c.key):
+ self.set_attr_by_column(obj, c, params.get_original(c.key))
def delete_obj(self, objects, uowtransaction):
"""Issue ``DELETE`` statements for a list of objects.
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 5ed95fabb..ff835cec9 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -613,6 +613,7 @@ class Column(SchemaItem, sql._ColumnClause):
[c._init_items(f) for f in fk]
return c
+
def _case_sens(self):
"""Redirect the `case_sensitive` accessor to use the ultimate
parent column which created this one."""
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index be43bb21b..78d07bec8 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -442,7 +442,7 @@ class AbstractDialect(object):
Used by ``Compiled`` objects."""
pass
-class ClauseParameters(dict):
+class ClauseParameters(object):
"""Represent a dictionary/iterator of bind parameter key names/values.
Tracks the original ``BindParam`` objects as well as the
@@ -453,39 +453,54 @@ class ClauseParameters(dict):
def __init__(self, dialect, positional=None):
super(ClauseParameters, self).__init__(self)
- self.dialect=dialect
+ self.dialect = dialect
self.binds = {}
+ self.binds_to_names = {}
+ self.binds_to_values = {}
self.positional = positional or []
def set_parameter(self, bindparam, value, name):
- self[name] = value
+ self.binds[bindparam.key] = bindparam
self.binds[name] = bindparam
-
+ self.binds_to_names[bindparam] = name
+ self.binds_to_values[bindparam] = value
+
def get_original(self, key):
"""Return the given parameter as it was originally placed in
this ``ClauseParameters`` object, without any ``Type``
conversion."""
+ return self.binds_to_values[self.binds[key]]
- return super(ClauseParameters, self).__getitem__(key)
-
+ def get_processed(self, key):
+ bind = self.binds[key]
+ value = self.binds_to_values[bind]
+ return bind.typeprocess(value, self.dialect)
+
def __getitem__(self, key):
- v = super(ClauseParameters, self).__getitem__(key)
- if self.binds.has_key(key):
- v = self.binds[key].typeprocess(v, self.dialect)
- return v
-
+ return self.get_processed(key)
+
+ def __contains__(self, key):
+ return key in self.binds
+
+ def set_value(self, key, value):
+ bind = self.binds[key]
+ self.binds_to_values[bind] = value
+
def get_original_dict(self):
- return self.copy()
+ return dict([(self.binds_to_names[b], self.binds_to_values[b]) for b in self.binds_to_names.keys()])
def get_raw_list(self):
- return [self[key] for key in self.positional]
+ return [self.get_processed(key) for key in self.positional]
def get_raw_dict(self):
d = {}
- for k in self:
- d[k] = self[k]
+ for k in self.binds_to_names.values():
+ d[k] = self.get_processed(k)
return d
+ def __repr__(self):
+ return repr(self.get_original_dict())
+
class ClauseVisitor(object):
"""A class that knows how to traverse and visit
``ClauseElements``.
@@ -1012,6 +1027,7 @@ class ColumnElement(Selectable, _CompareMixin):
with Selectable objects.
""")
+
def _one_fkey(self):
if len(self._foreign_keys):
return list(self._foreign_keys)[0]
@@ -1037,7 +1053,7 @@ class ColumnElement(Selectable, _CompareMixin):
for a column proxied from a Union (i.e. CompoundSelect), this
set will be just one element.
""")
-
+
def shares_lineage(self, othercolumn):
"""Return True if the given ``ColumnElement`` has a common ancestor to this ``ColumnElement``."""
@@ -1929,6 +1945,8 @@ class _ColumnClause(ColumnElement):
self.__label = "".join([x for x in self.__label if x in legal_characters])
return self.__label
+ is_labeled = property(lambda self:self.name != list(self.orig_set)[0].name)
+
_label = property(_get_label)
def label(self, name):