summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2007-12-05 03:07:21 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2007-12-05 03:07:21 +0000
commit238c2c8dbe3ca5b92d298b39e96f81eb416d1413 (patch)
treeb123393efbbb06a1e0ebc84385f5964efa98f0b1 /lib/sqlalchemy/sql
parentc6bda7dcc89ae5f7842f0e900d3917024a74eb29 (diff)
downloadsqlalchemy-238c2c8dbe3ca5b92d298b39e96f81eb416d1413.tar.gz
- basic framework for generic functions, [ticket:615]
- changed the various "literal" generation functions to use an anonymous bind parameter. not much changes here except their labels now look like ":param_1", ":param_2" instead of ":literal" - from_obj keyword argument to select() can be a scalar or a list.
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py52
-rw-r--r--lib/sqlalchemy/sql/expression.py37
-rw-r--r--lib/sqlalchemy/sql/functions.py78
3 files changed, 135 insertions, 32 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d94d3fef6..ea7c1b734 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -8,15 +8,10 @@
import string, re
from sqlalchemy import schema, engine, util, exceptions
-from sqlalchemy.sql import operators, visitors
+from sqlalchemy.sql import operators, visitors, functions
from sqlalchemy.sql import util as sql_util
from sqlalchemy.sql import expression as sql
-ANSI_FUNCS = util.Set([
- 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
- 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
- 'SESSION_USER', 'USER'])
-
RESERVED_WORDS = util.Set([
'all', 'analyse', 'analyze', 'and', 'any', 'array',
'as', 'asc', 'asymmetric', 'authorization', 'between',
@@ -87,6 +82,18 @@ OPERATORS = {
operators.isnot : 'IS NOT'
}
+FUNCTIONS = {
+ functions.coalesce : 'coalesce%(expr)s',
+ functions.current_date: 'CURRENT_DATE',
+ functions.current_time: 'CURRENT_TIME',
+ functions.current_timestamp: 'CURRENT_TIMESTAMP',
+ functions.current_user: 'CURRENT_USER',
+ functions.localtime: 'LOCALTIME',
+ functions.localtimestamp: 'LOCALTIMESTAMP',
+ functions.session_user :'SESSION_USER',
+ functions.user: 'USER'
+}
+
class DefaultCompiler(engine.Compiled):
"""Default implementation of Compiled.
@@ -97,6 +104,7 @@ class DefaultCompiler(engine.Compiled):
__traverse_options__ = {'column_collections':False, 'entry':True}
operators = OPERATORS
+ functions = FUNCTIONS
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
"""Construct a new ``DefaultCompiler`` object.
@@ -215,7 +223,7 @@ class DefaultCompiler(engine.Compiled):
labelname = self._truncated_identifier("colident", label.name)
if result_map is not None:
- result_map[labelname] = (label.name, (label, label.obj), label.obj.type)
+ result_map[labelname.lower()] = (label.name, (label, label.obj), label.obj.type)
return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)])
@@ -231,7 +239,7 @@ class DefaultCompiler(engine.Compiled):
name = column.name
if result_map is not None:
- result_map[name] = (name, (column, ), column.type)
+ result_map[name.lower()] = (name, (column, ), column.type)
if column._is_oid:
n = self.dialect.oid_column_name(column)
@@ -270,7 +278,7 @@ class DefaultCompiler(engine.Compiled):
def visit_textclause(self, textclause, **kwargs):
if textclause.typemap is not None:
for colname, type_ in textclause.typemap.iteritems():
- self.result_map[colname] = (colname, None, type_)
+ self.result_map[colname.lower()] = (colname, None, type_)
def do_bindparam(m):
name = m.group(1)
@@ -297,9 +305,6 @@ class DefaultCompiler(engine.Compiled):
sep = " " + self.operator_string(clauselist.operator) + " "
return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None])
- def apply_function_parens(self, func):
- return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0
-
def visit_calculatedclause(self, clause, **kwargs):
return self.process(clause.clause_expr)
@@ -308,12 +313,20 @@ class DefaultCompiler(engine.Compiled):
def visit_function(self, func, result_map=None, **kwargs):
if result_map is not None:
- result_map[func.name] = (func.name, None, func.type)
+ result_map[func.name.lower()] = (func.name, None, func.type)
+
+ name = self.function_string(func)
- if not self.apply_function_parens(func):
- return ".".join(func.packagenames + [func.name])
+ if callable(name):
+ return name(*[self.process(x) for x in func.clause_expr])
else:
- return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr)
+ return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)}
+
+ def function_argspec(self, func):
+ return self.process(func.clause_expr)
+
+ def function_string(self, func):
+ return self.functions.get(func.__class__, func.name + "%(expr)s")
def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs):
stack_entry = {'select':cs}
@@ -358,10 +371,11 @@ class DefaultCompiler(engine.Compiled):
def operator_string(self, operator):
return self.operators.get(operator, str(operator))
-
+
def visit_bindparam(self, bindparam, **kwargs):
- # apply truncation to the ultimate generated name
-
+ # TODO: remove this whole "unique" thing, just use regular
+ # anonymous params to implement. params used for inserts/updates
+ # etc. should no longer be "unique".
if bindparam.unique:
count = 1
key = bindparam.key
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index c21b102c7..6ad29218f 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -26,10 +26,11 @@ to stay the same in future releases.
"""
import re
+import datetime
from sqlalchemy import util, exceptions
from sqlalchemy.sql import operators, visitors
from sqlalchemy import types as sqltypes
-
+functions = None
__all__ = [
'Alias', 'ClauseElement',
@@ -596,7 +597,7 @@ def literal(value, type_=None):
bind-parameter translation for this literal.
"""
- return _BindParamClause('literal', value, type_=type_, unique=True)
+ return _BindParamClause(None, value, type_=type_, unique=True)
def label(name, obj):
"""Return a [sqlalchemy.sql.expression#_Label] object for the given [sqlalchemy.sql.expression#ColumnElement].
@@ -766,6 +767,14 @@ class _FunctionGenerator(object):
def __call__(self, *c, **kwargs):
o = self.opts.copy()
o.update(kwargs)
+ if len(self.__names) == 1:
+ global functions
+ if functions is None:
+ from sqlalchemy.sql import functions
+ func = getattr(functions, self.__names[-1].lower(), None)
+ if func is not None:
+ return func(*c, **o)
+
return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
func = _FunctionGenerator()
@@ -798,7 +807,7 @@ def _literal_as_column(element):
else:
return element
-def _literal_as_binds(element, name='literal', type_=None):
+def _literal_as_binds(element, name=None, type_=None):
if isinstance(element, Operators):
return element.expression_element()
elif _is_literal(element):
@@ -1344,7 +1353,7 @@ class _CompareMixin(ColumnOperators):
return lambda other: self.__operate(operator, other)
def _bind_param(self, obj):
- return _BindParamClause('literal', obj, type_=self.type, unique=True)
+ return _BindParamClause(None, obj, type_=self.type, unique=True)
def _check_literal(self, other):
if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType):
@@ -1762,6 +1771,10 @@ class _BindParamClause(ClauseElement, _CompareMixin):
unicode : sqltypes.NCHAR,
int : sqltypes.Integer,
float : sqltypes.Numeric,
+ datetime.date : sqltypes.Date,
+ datetime.datetime : sqltypes.DateTime,
+ datetime.time : sqltypes.Time,
+ datetime.timedelta : sqltypes.Interval,
type(None):sqltypes.NullType
}
@@ -1997,11 +2010,12 @@ class _Function(_CalculatedClause, FromClause):
def __init__(self, name, *clauses, **kwargs):
self.packagenames = kwargs.get('packagenames', None) or []
self.oid_column = None
- kwargs['operator'] = operators.comma_op
- _CalculatedClause.__init__(self, name, **kwargs)
- for c in clauses:
- self.append(c)
-
+ self.name = name
+ self._bind = kwargs.get('bind', None)
+ args = [_literal_as_binds(c, self.name) for c in clauses]
+ self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group()
+ self.type = sqltypes.to_instance(kwargs.get('type_', None))
+
key = property(lambda self:self.name)
columns = property(lambda self:[self])
@@ -2012,9 +2026,6 @@ class _Function(_CalculatedClause, FromClause):
def get_children(self, **kwargs):
return _CalculatedClause.get_children(self, **kwargs)
- def append(self, clause):
- self.clauses.append(_literal_as_binds(clause, self.name))
-
class _Cast(ColumnElement):
@@ -2984,7 +2995,7 @@ class Select(_SelectBaseMixin, FromClause):
if from_obj:
self._froms = util.Set([
_is_literal(f) and _TextFromClause(f) or f
- for f in from_obj
+ for f in util.to_list(from_obj)
])
else:
self._froms = util.Set()
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
new file mode 100644
index 000000000..d39032b91
--- /dev/null
+++ b/lib/sqlalchemy/sql/functions.py
@@ -0,0 +1,78 @@
+from sqlalchemy import types as sqltypes
+from sqlalchemy.sql.expression import _Function, _literal_as_binds, ClauseList, _FigureVisitName
+from sqlalchemy.sql import operators
+
+class _GenericMeta(_FigureVisitName):
+ def __init__(cls, clsname, bases, dict):
+ cls.__visit_name__ = 'function'
+ type.__init__(cls, clsname, bases, dict)
+
+ def __call__(self, *args, **kwargs):
+ args = [_literal_as_binds(c) for c in args]
+ return type.__call__(self, *args, **kwargs)
+
+class GenericFunction(_Function):
+ __metaclass__ = _GenericMeta
+
+ def __init__(self, type_=None, group=True, args=(), **kwargs):
+ self.packagenames = []
+ self.oid_column = None
+ self.name = self.__class__.__name__
+ self._bind = kwargs.get('bind', None)
+ if group:
+ self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group()
+ else:
+ self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args)
+ self.type = sqltypes.to_instance(type_ or getattr(self, '__return_type__', None))
+
+class AnsiFunction(GenericFunction):
+ def __init__(self, **kwargs):
+ GenericFunction.__init__(self, **kwargs)
+
+
+class coalesce(GenericFunction):
+ def __init__(self, *args, **kwargs):
+ kwargs.setdefault('type_', _type_from_args(args))
+ GenericFunction.__init__(self, args=args, **kwargs)
+
+class concat(GenericFunction):
+ __return_type__ = sqltypes.String
+ def __init__(self, *args, **kwargs):
+ GenericFunction.__init__(self, args=args, **kwargs)
+
+class char_length(GenericFunction):
+ __return_type__ = sqltypes.Integer
+
+ def __init__(self, arg, **kwargs):
+ GenericFunction.__init__(self, args=[arg], **kwargs)
+
+class current_date(AnsiFunction):
+ __return_type__ = sqltypes.Date
+
+class current_time(AnsiFunction):
+ __return_type__ = sqltypes.Time
+
+class current_timestamp(AnsiFunction):
+ __return_type__ = sqltypes.DateTime
+
+class current_user(AnsiFunction):
+ __return_type__ = sqltypes.String
+
+class localtime(AnsiFunction):
+ __return_type__ = sqltypes.DateTime
+
+class localtimestamp(AnsiFunction):
+ __return_type__ = sqltypes.DateTime
+
+class session_user(AnsiFunction):
+ __return_type__ = sqltypes.String
+
+class user(AnsiFunction):
+ __return_type__ = sqltypes.String
+
+def _type_from_args(args):
+ for a in args:
+ if not isinstance(a.type, sqltypes.NullType):
+ return a.type
+ else:
+ return sqltypes.NullType \ No newline at end of file