summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-04-06 01:15:46 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-04-06 01:15:46 +0000
commit680c27607328a8f89e446601f7bc7ed56394dc27 (patch)
tree4f5fdc632d648cb723373c06a82eba3332c27807 /lib/sqlalchemy
parent753b7c2d3ebe8753d70ff8ed33dfbcdddb5e5d29 (diff)
downloadsqlalchemy-680c27607328a8f89e446601f7bc7ed56394dc27.tar.gz
moves the binding of a TypeEngine object from "schema/statement creation" time into "compilation" time
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ansisql.py5
-rw-r--r--lib/sqlalchemy/databases/firebird.py2
-rw-r--r--lib/sqlalchemy/databases/mssql.py2
-rw-r--r--lib/sqlalchemy/databases/mysql.py2
-rw-r--r--lib/sqlalchemy/databases/oracle.py2
-rw-r--r--lib/sqlalchemy/databases/postgres.py2
-rw-r--r--lib/sqlalchemy/databases/sqlite.py2
-rw-r--r--lib/sqlalchemy/engine.py4
-rw-r--r--lib/sqlalchemy/ext/proxy.py39
-rw-r--r--lib/sqlalchemy/schema.py1
-rw-r--r--lib/sqlalchemy/sql.py50
-rw-r--r--lib/sqlalchemy/types.py24
12 files changed, 55 insertions, 80 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index dfc15a383..40e946651 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -189,7 +189,10 @@ class ANSICompiler(sql.Compiled):
def visit_index(self, index):
self.strings[index] = index.name
-
+
+ def visit_typeclause(self, typeclause):
+ self.strings[typeclause] = typeclause.type.engine_impl(self.engine).get_col_spec()
+
def visit_textclause(self, textclause):
if textclause.parens and len(textclause.text):
self.strings[textclause] = "(" + textclause.text + ")"
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py
index 7d5cfed11..7dc48a54a 100644
--- a/lib/sqlalchemy/databases/firebird.py
+++ b/lib/sqlalchemy/databases/firebird.py
@@ -238,7 +238,7 @@ class FBCompiler(ansisql.ANSICompiler):
class FBSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- colspec += " " + column.type.get_col_spec()
+ colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py
index 582ed9002..6a7ef91b3 100644
--- a/lib/sqlalchemy/databases/mssql.py
+++ b/lib/sqlalchemy/databases/mssql.py
@@ -460,7 +460,7 @@ class MSSQLCompiler(ansisql.ANSICompiler):
class MSSQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
- colspec = column.name + " " + column.type.get_col_spec()
+ colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
# install a IDENTITY Sequence if we have an implicit IDENTITY column
if column.primary_key and isinstance(column.type, types.Integer):
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index c55da97cb..a25a21e9b 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -263,7 +263,7 @@ class MySQLCompiler(ansisql.ANSICompiler):
class MySQLSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, first_pk=False):
- colspec = column.name + " " + column.type.get_col_spec()
+ colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index c673cb961..a475d29b7 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -306,7 +306,7 @@ class OracleCompiler(ansisql.ANSICompiler):
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
colspec = column.name
- colspec += " " + column.type.get_col_spec()
+ colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 72d426012..a7285b4b5 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
@@ -305,7 +305,7 @@ class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
if column.primary_key and isinstance(column.type, types.Integer) and (column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional)):
colspec += " SERIAL"
else:
- colspec += " " + column.type.get_col_spec()
+ colspec += " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 0e208854e..a7536ee4e 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -241,7 +241,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, override_pk=False, **kwargs):
- colspec = column.name + " " + column.type.get_col_spec()
+ colspec = column.name + " " + column.type.engine_impl(self.engine).get_col_spec()
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
index 727ee30ad..97c710762 100644
--- a/lib/sqlalchemy/engine.py
+++ b/lib/sqlalchemy/engine.py
@@ -319,7 +319,7 @@ class SQLEngine(schema.SchemaEngine):
self.positional = True
else:
raise DBAPIError("Unsupported paramstyle '%s'" % self._paramstyle)
-
+
def type_descriptor(self, typeobj):
"""provides a database-specific TypeEngine object, given the generic object
which comes from the types module. Subclasses will usually use the adapt_type()
@@ -808,7 +808,7 @@ class ResultProxy:
rec = self.props[key.lower()]
else:
rec = self.props[key]
- return rec[0].convert_result_value(row[rec[1]], self.engine)
+ return rec[0].engine_impl(self.engine).convert_result_value(row[rec[1]], self.engine)
def __iter__(self):
while True:
diff --git a/lib/sqlalchemy/ext/proxy.py b/lib/sqlalchemy/ext/proxy.py
index 2ca3116c1..38325bea3 100644
--- a/lib/sqlalchemy/ext/proxy.py
+++ b/lib/sqlalchemy/ext/proxy.py
@@ -36,11 +36,6 @@ class BaseProxyEngine(schema.SchemaEngine):
return None
return e.oid_column_name()
- def type_descriptor(self, typeobj):
- """Proxy point: return a ProxyTypeEngine
- """
- return ProxyTypeEngine(self, typeobj)
-
def __getattr__(self, attr):
# call get_engine() to give subclasses a chance to change
# connection establishment behavior
@@ -116,37 +111,3 @@ class ProxyEngine(BaseProxyEngine):
self.storage.engine = engine
-class ProxyType(object):
- """ProxyType base class; used by ProxyTypeEngine to construct proxying
- types
- """
- def __init__(self, engine, typeobj):
- self._engine = engine
- self.typeobj = typeobj
-
- def __getattribute__(self, attr):
- if attr.startswith('__') and attr.endswith('__'):
- return object.__getattribute__(self, attr)
-
- engine = object.__getattribute__(self, '_engine').engine
- typeobj = object.__getattribute__(self, 'typeobj')
- return getattr(engine.type_descriptor(typeobj), attr)
-
- def __repr__(self):
- return '<Proxy %s>' % (object.__getattribute__(self, 'typeobj'))
-
-class ProxyTypeEngine(object):
- """Proxy type engine; creates dynamic proxy type subclass that is instance
- of actual type, but proxies engine-dependant operations through the proxy
- engine.
- """
- def __new__(cls, engine, typeobj):
- """Create a new subclass of ProxyType and typeobj
- so that internal isinstance() calls will get the expected result.
- """
- if isinstance(typeobj, type):
- typeclass = typeobj
- else:
- typeclass = typeobj.__class__
- typed = type('ProxyTypeHelper', (ProxyType, typeclass), {})
- return typed(engine, typeobj)
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index eabfee9bb..24392b3d9 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -163,7 +163,6 @@ class Table(sql.TableClause, SchemaItem):
if column.primary_key:
self.primary_key.append(column)
column.table = self
- column.type = self.engine.type_descriptor(column.type)
def append_index(self, index):
self.indexes[index.name] = index
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index f0171571d..f6e2d03c9 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -139,17 +139,11 @@ def cast(clause, totype, **kwargs):
or
cast(table.c.timestamp, DATE)
"""
- engine = kwargs.get('engine', None)
- if engine is None:
- engine = getattr(clause, 'engine', None)
- if engine is not None:
- totype_desc = engine.type_descriptor(totype)
- # handle non-column clauses (e.g. cast(1234, TEXT)
- if not hasattr(clause, 'label'):
- clause = literal(clause)
- return Function('CAST', clause.label(totype_desc.get_col_spec()), type=totype, **kwargs)
- else:
- raise InvalidRequestError("No engine available, cannot generate cast for " + str(clause) + " to type " + str(totype))
+ # handle non-column clauses (e.g. cast(1234, TEXT)
+ if not hasattr(clause, 'label'):
+ clause = literal(clause)
+ totype = sqltypes.to_instance(totype)
+ return Function('CAST', CompoundClause("AS", clause, TypeClause(totype)), type=totype, **kwargs)
def exists(*args, **params):
params['correlate'] = True
@@ -295,7 +289,8 @@ class ClauseVisitor(object):
def visit_clauselist(self, list):pass
def visit_function(self, func):pass
def visit_label(self, label):pass
-
+ def visit_typeclause(self, typeclause):pass
+
class Compiled(ClauseVisitor):
"""represents a compiled SQL expression. the __str__ method of the Compiled object
should produce the actual text of the statement. Compiled objects are specific to the
@@ -671,13 +666,7 @@ class BindParamClause(ClauseElement, CompareMixin):
self.key = key
self.value = value
self.shortname = shortname
- self.type = type or sqltypes.NULLTYPE
- def _get_convert_type(self, engine):
- try:
- return self._converted_type
- except AttributeError:
- self._converted_type = engine.type_descriptor(self.type)
- return self._converted_type
+ self.type = sqltypes.to_instance(type)
def accept_visitor(self, visitor):
visitor.visit_bindparam(self)
def _get_from_objects(self):
@@ -685,7 +674,7 @@ class BindParamClause(ClauseElement, CompareMixin):
def copy_container(self):
return BindParamClause(self.key, self.value, self.shortname, self.type)
def typeprocess(self, value, engine):
- return self._get_convert_type(engine).convert_bind_param(value, engine)
+ return self.type.engine_impl(engine).convert_bind_param(value, engine)
def compare(self, other):
"""compares this BindParamClause to the given clause.
@@ -695,7 +684,14 @@ class BindParamClause(ClauseElement, CompareMixin):
def _make_proxy(self, selectable, name = None):
return self
# return self.obj._make_proxy(selectable, name=self.name)
-
+
+class TypeClause(ClauseElement):
+ """handles a type keyword in a SQL statement"""
+ def __init__(self, type):
+ self.type = type
+ def accept_visitor(self, visitor):
+ visitor.visit_typeclause(self)
+
class TextClause(ClauseElement):
"""represents literal a SQL text fragment. public constructor is the
text() function.
@@ -714,7 +710,7 @@ class TextClause(ClauseElement):
self.typemap = typemap
if typemap is not None:
for key in typemap.keys():
- typemap[key] = engine.type_descriptor(typemap[key])
+ typemap[key] = sqltypes.to_instance(typemap[key])
def repl(m):
self.bindparams[m.group(1)] = bindparam(m.group(1))
return ":%s" % m.group(1)
@@ -820,11 +816,9 @@ class Function(ClauseList, ColumnElement):
"""describes a SQL function. extends ClauseList to provide comparison operators."""
def __init__(self, name, *clauses, **kwargs):
self.name = name
- self.type = kwargs.get('type', sqltypes.NULLTYPE)
+ self.type = sqltypes.to_instance(kwargs.get('type', None))
self.packagenames = kwargs.get('packagenames', None) or []
self._engine = kwargs.get('engine', None)
- if self._engine is not None:
- self.type = self._engine.type_descriptor(self.type)
ClauseList.__init__(self, parens=True, *clauses)
key = property(lambda self:self.name)
def append(self, clause):
@@ -873,7 +867,7 @@ class BinaryClause(ClauseElement):
self.left = left
self.right = right
self.operator = operator
- self.type = type
+ self.type = sqltypes.to_instance(type)
self.parens = False
if isinstance(self.left, BinaryClause):
self.left.parens = True
@@ -1028,7 +1022,7 @@ class Label(ColumnElement):
while isinstance(obj, Label):
obj = obj.obj
self.obj = obj
- self.type = type or sqltypes.NullTypeEngine()
+ self.type = sqltypes.to_instance(type)
obj.parens=True
key = property(lambda s: s.name)
@@ -1049,7 +1043,7 @@ class ColumnClause(ColumnElement):
def __init__(self, text, selectable=None, type=None):
self.key = self.name = self.text = text
self.table = selectable
- self.type = type or sqltypes.NullTypeEngine()
+ self.type = sqltypes.to_instance(type)
self.__label = None
def _get_label(self):
if self.__label is None:
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index ecf791a37..7a3822a65 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -16,11 +16,22 @@ try:
import cPickle as pickle
except:
import pickle
-
+
class TypeEngine(object):
- basetypes = []
def __init__(self, *args, **kwargs):
pass
+ def _get_impl_dict(self):
+ try:
+ return self._impl_dict
+ except AttributeError:
+ self._impl_dict = {}
+ return self._impl_dict
+ impl_dict = property(_get_impl_dict)
+ def engine_impl(self, engine):
+ try:
+ return self.impl_dict[engine]
+ except:
+ return self.impl_dict.setdefault(engine, engine.type_descriptor(self))
def _get_impl(self):
if hasattr(self, '_impl'):
return self._impl
@@ -41,7 +52,14 @@ class TypeEngine(object):
return {}
def adapt_args(self):
return self
-
+
+def to_instance(typeobj):
+ if typeobj is None:
+ return NULLTYPE
+ elif isinstance(typeobj, type):
+ return typeobj()
+ else:
+ return typeobj
def adapt_type(typeobj, colspecs):
if isinstance(typeobj, type):
typeobj = typeobj()