diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 19 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/visitors.py | 29 |
2 files changed, 22 insertions, 26 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d906bf5d4..1ab0ba405 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -579,8 +579,9 @@ class SQLCompiler(engine.Compiled): else: return fn(" " + operator + " ") - def visit_bindparam(self, bindparam, within_columns_clause=False, + def visit_bindparam(self, bindparam, within_columns_clause=False, literal_binds=False, **kwargs): + if literal_binds or \ (within_columns_clause and \ self.ansi_bind_rules): @@ -591,6 +592,7 @@ class SQLCompiler(engine.Compiled): within_columns_clause=True, **kwargs) name = self._truncate_bindparam(bindparam) + if name in self.binds: existing = self.binds[name] if existing is not bindparam: @@ -600,7 +602,8 @@ class SQLCompiler(engine.Compiled): "unique bind parameter of the same name" % bindparam.key ) - elif getattr(existing, '_is_crud', False): + elif getattr(existing, '_is_crud', False) or \ + getattr(bindparam, '_is_crud', False): raise exc.CompileError( "bindparam() name '%s' is reserved " "for automatic usage in the VALUES or SET " @@ -992,18 +995,8 @@ class SQLCompiler(engine.Compiled): bindparam = sql.bindparam(col.key, value, type_=col.type, required=required) bindparam._is_crud = True - if col.key in self.binds: - raise exc.CompileError( - "bindparam() name '%s' is reserved " - "for automatic usage in the VALUES or SET clause of this " - "insert/update statement. Please use a " - "name other than column name when using bindparam() " - "with insert() or update() (for example, 'b_%s')." - % (col.key, col.key) - ) + return bindparam._compiler_dispatch(self) - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 8011aa109..0c6be97d7 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -44,22 +44,25 @@ class VisitableType(type): super(VisitableType, cls).__init__(clsname, bases, clsdict) return - # set up an optimized visit dispatch function - # for use by the compiler - if '__visit_name__' in cls.__dict__: - visit_name = cls.__visit_name__ - if isinstance(visit_name, str): - getter = operator.attrgetter("visit_%s" % visit_name) - def _compiler_dispatch(self, visitor, **kw): - return getter(visitor)(self, **kw) - else: - def _compiler_dispatch(self, visitor, **kw): - return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) - - cls._compiler_dispatch = _compiler_dispatch + _generate_dispatch(cls) super(VisitableType, cls).__init__(clsname, bases, clsdict) +def _generate_dispatch(cls): + # set up an optimized visit dispatch function + # for use by the compiler + if '__visit_name__' in cls.__dict__: + visit_name = cls.__visit_name__ + if isinstance(visit_name, str): + getter = operator.attrgetter("visit_%s" % visit_name) + def _compiler_dispatch(self, visitor, **kw): + return getter(visitor)(self, **kw) + else: + def _compiler_dispatch(self, visitor, **kw): + return getattr(visitor, 'visit_%s' % self.__visit_name__)(self, **kw) + + cls._compiler_dispatch = _compiler_dispatch + class Visitable(object): """Base class for visitable objects, applies the ``VisitableType`` metaclass. |