summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-04-06 01:15:46 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-04-06 01:15:46 +0000
commit680c27607328a8f89e446601f7bc7ed56394dc27 (patch)
tree4f5fdc632d648cb723373c06a82eba3332c27807 /lib/sqlalchemy/sql.py
parent753b7c2d3ebe8753d70ff8ed33dfbcdddb5e5d29 (diff)
downloadsqlalchemy-680c27607328a8f89e446601f7bc7ed56394dc27.tar.gz
moves the binding of a TypeEngine object from "schema/statement creation" time into "compilation" time
Diffstat (limited to 'lib/sqlalchemy/sql.py')
-rw-r--r--lib/sqlalchemy/sql.py50
1 files changed, 22 insertions, 28 deletions
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index f0171571d..f6e2d03c9 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -139,17 +139,11 @@ def cast(clause, totype, **kwargs):
or
cast(table.c.timestamp, DATE)
"""
- engine = kwargs.get('engine', None)
- if engine is None:
- engine = getattr(clause, 'engine', None)
- if engine is not None:
- totype_desc = engine.type_descriptor(totype)
- # handle non-column clauses (e.g. cast(1234, TEXT)
- if not hasattr(clause, 'label'):
- clause = literal(clause)
- return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs)
- else:
- raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype))
+ # handle non-column clauses (e.g. cast(1234, TEXT)
+ if not hasattr(clause, 'label'):
+ clause = literal(clause)
+ totype = sqltypes.to_instance(totype)
+ return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs)
def exists(*args, **params):
params['correlate'] = True
@@ -295,7 +289,8 @@ class ClauseVisitor(object):
def visit_clauselist(self, list):pass
def visit_function(self, func):pass
def visit_label(self, label):pass
-
+ def visit_typeclause(self, typeclause):pass
+
class Compiled(ClauseVisitor):
"""represents a compiled SQL expression. the __str__ method of the Compiled object
should produce the actual text of the statement. Compiled objects are specific to the
@@ -671,13 +666,7 @@ class BindParamClause(ClauseElement, CompareMixin):
self.key = key
self.value = value
self.shortname = shortname
- self.type = type or sqltypes.NULLTYPE
- def _get_convert_type(self, engine):
- try:
- return self._converted_type
- except AttributeError:
- self._converted_type = engine.type_descriptor(self.type)
- return self._converted_type
+ self.type = sqltypes.to_instance(type)
def accept_visitor(self, visitor):
visitor.visit_bindparam(self)
def _get_from_objects(self):
@@ -685,7 +674,7 @@ class BindParamClause(ClauseElement, CompareMixin):
def copy_container(self):
return BindParamClause(self.key, self.value, self.shortname, self.type)
def typeprocess(self, value, engine):
- return self._get_convert_type(engine).convert_bind_param(value, engine)
+ return self.type.engine_impl(engine).convert_bind_param(value, engine)
def compare(self, other):
"""compares this BindParamClause to the given clause.
@@ -695,7 +684,14 @@ class BindParamClause(ClauseElement, CompareMixin):
def _make_proxy(self, selectable, name = None):
return self
# return self.obj._make_proxy(selectable, name=self.name)
-
+
+class TypeClause(ClauseElement):
+ """handles a type keyword in a SQL statement"""
+ def __init__(self, type):
+ self.type = type
+ def accept_visitor(self, visitor):
+ visitor.visit_typeclause(self)
+
class TextClause(ClauseElement):
"""represents literal a SQL text fragment. public constructor is the
text() function.
@@ -714,7 +710,7 @@ class TextClause(ClauseElement):
self.typemap = typemap
if typemap is not None:
for key in typemap.keys():
- typemap[key] = engine.type_descriptor(typemap[key])
+ typemap[key] = sqltypes.to_instance(typemap[key])
def repl(m):
self.bindparams[m.group(1)] = bindparam(m.group(1))
return ":%s" % m.group(1)
@@ -820,11 +816,9 @@ class Function(ClauseList, ColumnElement):
"""describes a SQL function. extends ClauseList to provide comparison operators."""
def __init__(self, name, *clauses, **kwargs):
self.name = name
- self.type = kwargs.get('type', sqltypes.NULLTYPE)
+ self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
self._engine = kwargs.get('engine', None)
- if self._engine is not None:
- self.type = self._engine.type_descriptor(self.type)
ClauseList.__init__(self, parens=True, *clauses)
key = property(lambda self:self.name)
def append(self, clause):
@@ -873,7 +867,7 @@ class BinaryClause(ClauseElement):
self.left = left
self.right = right
self.operator = operator
- self.type = type
+ self.type = sqltypes.to_instance(type)
self.parens = False
if isinstance(self.left, BinaryClause):
self.left.parens = True
@@ -1028,7 +1022,7 @@ class Label(ColumnElement):
while isinstance(obj, Label):
obj = obj.obj
self.obj = obj
- self.type = type or sqltypes.NullTypeEngine()
+ self.type = sqltypes.to_instance(type)
obj.parens=True
key = property(lambda s: s.name)
@@ -1049,7 +1043,7 @@ class ColumnClause(ColumnElement):
def __init__(self, text, selectable=None, type=None):
self.key = self.name = self.text = text
self.table = selectable
- self.type = type or sqltypes.NullTypeEngine()
+ self.type = sqltypes.to_instance(type)
self.__label = None
def _get_label(self):
if self.__label is None: