diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-12-05 03:07:21 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-12-05 03:07:21 +0000 |
commit | 238c2c8dbe3ca5b92d298b39e96f81eb416d1413 (patch) | |
tree | b123393efbbb06a1e0ebc84385f5964efa98f0b1 /lib/sqlalchemy | |
parent | c6bda7dcc89ae5f7842f0e900d3917024a74eb29 (diff) | |
download | sqlalchemy-238c2c8dbe3ca5b92d298b39e96f81eb416d1413.tar.gz |
- basic framework for generic functions, [ticket:615]
- changed the various "literal" generation functions to use an anonymous
bind parameter. not much changes here except their labels now look
like ":param_1", ":param_2" instead of ":literal"
- from_obj keyword argument to select() can be a scalar or a list.
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 6 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 52 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 37 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 78 |
5 files changed, 139 insertions, 36 deletions
diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 7580a750a..3a13c4b69 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -354,11 +354,11 @@ class FBCompiler(compiler.DefaultCompiler): else: return self.process(alias.original, **kwargs) - def apply_function_parens(self, func): + def function_argspec(self, func): if func.clauses: - return super(FBCompiler, self).apply_function_parens(func) + return self.process(func.clause_expr) else: - return False + return "" def default_from(self): return " FROM rdb$database" diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4e6247810..45e4c036f 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1353,7 +1353,7 @@ class ResultProxy(object): if self.context.result_map: try: - (name, obj, type_) = self.context.result_map[colname] + (name, obj, type_) = self.context.result_map[colname.lower()] except KeyError: (name, obj, type_) = (colname, None, typemap.get(item[1], types.NULLTYPE)) else: diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index d94d3fef6..ea7c1b734 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -8,15 +8,10 @@ import string, re from sqlalchemy import schema, engine, util, exceptions -from sqlalchemy.sql import operators, visitors +from sqlalchemy.sql import operators, visitors, functions from sqlalchemy.sql import util as sql_util from sqlalchemy.sql import expression as sql -ANSI_FUNCS = util.Set([ - 'CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP', - 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP', - 'SESSION_USER', 'USER']) - RESERVED_WORDS = util.Set([ 'all', 'analyse', 'analyze', 'and', 'any', 'array', 'as', 'asc', 'asymmetric', 'authorization', 'between', @@ -87,6 +82,18 @@ OPERATORS = { operators.isnot : 'IS NOT' } +FUNCTIONS = { + functions.coalesce : 'coalesce%(expr)s', + functions.current_date: 'CURRENT_DATE', + functions.current_time: 'CURRENT_TIME', + functions.current_timestamp: 'CURRENT_TIMESTAMP', + functions.current_user: 'CURRENT_USER', + functions.localtime: 'LOCALTIME', + functions.localtimestamp: 'LOCALTIMESTAMP', + functions.session_user :'SESSION_USER', + functions.user: 'USER' +} + class DefaultCompiler(engine.Compiled): """Default implementation of Compiled. @@ -97,6 +104,7 @@ class DefaultCompiler(engine.Compiled): __traverse_options__ = {'column_collections':False, 'entry':True} operators = OPERATORS + functions = FUNCTIONS def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. @@ -215,7 +223,7 @@ class DefaultCompiler(engine.Compiled): labelname = self._truncated_identifier("colident", label.name) if result_map is not None: - result_map[labelname] = (label.name, (label, label.obj), label.obj.type) + result_map[labelname.lower()] = (label.name, (label, label.obj), label.obj.type) return " ".join([self.process(label.obj), self.operator_string(operators.as_), self.preparer.format_label(label, labelname)]) @@ -231,7 +239,7 @@ class DefaultCompiler(engine.Compiled): name = column.name if result_map is not None: - result_map[name] = (name, (column, ), column.type) + result_map[name.lower()] = (name, (column, ), column.type) if column._is_oid: n = self.dialect.oid_column_name(column) @@ -270,7 +278,7 @@ class DefaultCompiler(engine.Compiled): def visit_textclause(self, textclause, **kwargs): if textclause.typemap is not None: for colname, type_ in textclause.typemap.iteritems(): - self.result_map[colname] = (colname, None, type_) + self.result_map[colname.lower()] = (colname, None, type_) def do_bindparam(m): name = m.group(1) @@ -297,9 +305,6 @@ class DefaultCompiler(engine.Compiled): sep = " " + self.operator_string(clauselist.operator) + " " return sep.join([s for s in [self.process(c) for c in clauselist.clauses] if s is not None]) - def apply_function_parens(self, func): - return func.name.upper() not in ANSI_FUNCS or len(func.clauses) > 0 - def visit_calculatedclause(self, clause, **kwargs): return self.process(clause.clause_expr) @@ -308,12 +313,20 @@ class DefaultCompiler(engine.Compiled): def visit_function(self, func, result_map=None, **kwargs): if result_map is not None: - result_map[func.name] = (func.name, None, func.type) + result_map[func.name.lower()] = (func.name, None, func.type) + + name = self.function_string(func) - if not self.apply_function_parens(func): - return ".".join(func.packagenames + [func.name]) + if callable(name): + return name(*[self.process(x) for x in func.clause_expr]) else: - return ".".join(func.packagenames + [func.name]) + (not func.group and " " or "") + self.process(func.clause_expr) + return ".".join(func.packagenames + [name]) % {'expr':self.function_argspec(func)} + + def function_argspec(self, func): + return self.process(func.clause_expr) + + def function_string(self, func): + return self.functions.get(func.__class__, func.name + "%(expr)s") def visit_compound_select(self, cs, asfrom=False, parens=True, **kwargs): stack_entry = {'select':cs} @@ -358,10 +371,11 @@ class DefaultCompiler(engine.Compiled): def operator_string(self, operator): return self.operators.get(operator, str(operator)) - + def visit_bindparam(self, bindparam, **kwargs): - # apply truncation to the ultimate generated name - + # TODO: remove this whole "unique" thing, just use regular + # anonymous params to implement. params used for inserts/updates + # etc. should no longer be "unique". if bindparam.unique: count = 1 key = bindparam.key diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index c21b102c7..6ad29218f 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -26,10 +26,11 @@ to stay the same in future releases. """ import re +import datetime from sqlalchemy import util, exceptions from sqlalchemy.sql import operators, visitors from sqlalchemy import types as sqltypes - +functions = None __all__ = [ 'Alias', 'ClauseElement', @@ -596,7 +597,7 @@ def literal(value, type_=None): bind-parameter translation for this literal. """ - return _BindParamClause('literal', value, type_=type_, unique=True) + return _BindParamClause(None, value, type_=type_, unique=True) def label(name, obj): """Return a [sqlalchemy.sql.expression#_Label] object for the given [sqlalchemy.sql.expression#ColumnElement]. @@ -766,6 +767,14 @@ class _FunctionGenerator(object): def __call__(self, *c, **kwargs): o = self.opts.copy() o.update(kwargs) + if len(self.__names) == 1: + global functions + if functions is None: + from sqlalchemy.sql import functions + func = getattr(functions, self.__names[-1].lower(), None) + if func is not None: + return func(*c, **o) + return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) func = _FunctionGenerator() @@ -798,7 +807,7 @@ def _literal_as_column(element): else: return element -def _literal_as_binds(element, name='literal', type_=None): +def _literal_as_binds(element, name=None, type_=None): if isinstance(element, Operators): return element.expression_element() elif _is_literal(element): @@ -1344,7 +1353,7 @@ class _CompareMixin(ColumnOperators): return lambda other: self.__operate(operator, other) def _bind_param(self, obj): - return _BindParamClause('literal', obj, type_=self.type, unique=True) + return _BindParamClause(None, obj, type_=self.type, unique=True) def _check_literal(self, other): if isinstance(other, _BindParamClause) and isinstance(other.type, sqltypes.NullType): @@ -1762,6 +1771,10 @@ class _BindParamClause(ClauseElement, _CompareMixin): unicode : sqltypes.NCHAR, int : sqltypes.Integer, float : sqltypes.Numeric, + datetime.date : sqltypes.Date, + datetime.datetime : sqltypes.DateTime, + datetime.time : sqltypes.Time, + datetime.timedelta : sqltypes.Interval, type(None):sqltypes.NullType } @@ -1997,11 +2010,12 @@ class _Function(_CalculatedClause, FromClause): def __init__(self, name, *clauses, **kwargs): self.packagenames = kwargs.get('packagenames', None) or [] self.oid_column = None - kwargs['operator'] = operators.comma_op - _CalculatedClause.__init__(self, name, **kwargs) - for c in clauses: - self.append(c) - + self.name = name + self._bind = kwargs.get('bind', None) + args = [_literal_as_binds(c, self.name) for c in clauses] + self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group() + self.type = sqltypes.to_instance(kwargs.get('type_', None)) + key = property(lambda self:self.name) columns = property(lambda self:[self]) @@ -2012,9 +2026,6 @@ class _Function(_CalculatedClause, FromClause): def get_children(self, **kwargs): return _CalculatedClause.get_children(self, **kwargs) - def append(self, clause): - self.clauses.append(_literal_as_binds(clause, self.name)) - class _Cast(ColumnElement): @@ -2984,7 +2995,7 @@ class Select(_SelectBaseMixin, FromClause): if from_obj: self._froms = util.Set([ _is_literal(f) and _TextFromClause(f) or f - for f in from_obj + for f in util.to_list(from_obj) ]) else: self._froms = util.Set() diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py new file mode 100644 index 000000000..d39032b91 --- /dev/null +++ b/lib/sqlalchemy/sql/functions.py @@ -0,0 +1,78 @@ +from sqlalchemy import types as sqltypes +from sqlalchemy.sql.expression import _Function, _literal_as_binds, ClauseList, _FigureVisitName +from sqlalchemy.sql import operators + +class _GenericMeta(_FigureVisitName): + def __init__(cls, clsname, bases, dict): + cls.__visit_name__ = 'function' + type.__init__(cls, clsname, bases, dict) + + def __call__(self, *args, **kwargs): + args = [_literal_as_binds(c) for c in args] + return type.__call__(self, *args, **kwargs) + +class GenericFunction(_Function): + __metaclass__ = _GenericMeta + + def __init__(self, type_=None, group=True, args=(), **kwargs): + self.packagenames = [] + self.oid_column = None + self.name = self.__class__.__name__ + self._bind = kwargs.get('bind', None) + if group: + self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args).self_group() + else: + self.clause_expr = ClauseList(operator=operators.comma_op, group_contents=True, *args) + self.type = sqltypes.to_instance(type_ or getattr(self, '__return_type__', None)) + +class AnsiFunction(GenericFunction): + def __init__(self, **kwargs): + GenericFunction.__init__(self, **kwargs) + + +class coalesce(GenericFunction): + def __init__(self, *args, **kwargs): + kwargs.setdefault('type_', _type_from_args(args)) + GenericFunction.__init__(self, args=args, **kwargs) + +class concat(GenericFunction): + __return_type__ = sqltypes.String + def __init__(self, *args, **kwargs): + GenericFunction.__init__(self, args=args, **kwargs) + +class char_length(GenericFunction): + __return_type__ = sqltypes.Integer + + def __init__(self, arg, **kwargs): + GenericFunction.__init__(self, args=[arg], **kwargs) + +class current_date(AnsiFunction): + __return_type__ = sqltypes.Date + +class current_time(AnsiFunction): + __return_type__ = sqltypes.Time + +class current_timestamp(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class current_user(AnsiFunction): + __return_type__ = sqltypes.String + +class localtime(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class localtimestamp(AnsiFunction): + __return_type__ = sqltypes.DateTime + +class session_user(AnsiFunction): + __return_type__ = sqltypes.String + +class user(AnsiFunction): + __return_type__ = sqltypes.String + +def _type_from_args(args): + for a in args: + if not isinstance(a.type, sqltypes.NullType): + return a.type + else: + return sqltypes.NullType
\ No newline at end of file |