diff options
author | mike bayer <mike_mp@zzzcomputing.com> | 2021-11-25 18:22:59 +0000 |
---|---|---|
committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2021-11-25 18:22:59 +0000 |
commit | 8ddb3ef165d0c2d6d7167bb861bb349e68b5e8df (patch) | |
tree | 1f61463f9888eedbd156b35858af266135f7d6e7 /lib/sqlalchemy/sql/compiler.py | |
parent | 3619f084bfb5208ae45686a0993d620b2711adf2 (diff) | |
parent | 939de240d31a5441ad7380738d410a976d4ecc3a (diff) | |
download | sqlalchemy-8ddb3ef165d0c2d6d7167bb861bb349e68b5e8df.tar.gz |
Merge "propose emulated setinputsizes embedded in the compiler" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 115 |
1 files changed, 62 insertions, 53 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 0dd61d675..28c1bf069 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -228,6 +228,7 @@ FUNCTIONS = { functions.grouping_sets: "GROUPING SETS", } + EXTRACT_MAP = { "month": "month", "day": "day", @@ -1037,57 +1038,28 @@ class SQLCompiler(Compiled): return pd @util.memoized_instancemethod - def _get_set_input_sizes_lookup( - self, include_types=None, exclude_types=None - ): - if not hasattr(self, "bind_names"): - return None - + def _get_set_input_sizes_lookup(self): dialect = self.dialect - dbapi = self.dialect.dbapi - # _unwrapped_dialect_impl() is necessary so that we get the - # correct dialect type for a custom TypeDecorator, or a Variant, - # which is also a TypeDecorator. Special types like Interval, - # that use TypeDecorator but also might be mapped directly - # for a dialect impl, also subclass Emulated first which overrides - # this behavior in those cases to behave like the default. + include_types = dialect.include_set_input_sizes + exclude_types = dialect.exclude_set_input_sizes - if include_types is None and exclude_types is None: + dbapi = dialect.dbapi - def _lookup_type(typ): - dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi) - return dbtype + def lookup_type(typ): + dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi) - else: - - def _lookup_type(typ): - # note we get dbtype from the possibly TypeDecorator-wrapped - # dialect_impl, but the dialect_impl itself that we use for - # include/exclude is the unwrapped version. - - dialect_impl = typ._unwrapped_dialect_impl(dialect) - - dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi) - - if ( - dbtype is not None - and ( - exclude_types is None - or dbtype not in exclude_types - and type(dialect_impl) not in exclude_types - ) - and ( - include_types is None - or dbtype in include_types - or type(dialect_impl) in include_types - ) - ): - return dbtype - else: - return None + if ( + dbtype is not None + and (exclude_types is None or dbtype not in exclude_types) + and (include_types is None or dbtype in include_types) + ): + return dbtype + else: + return None inputsizes = {} + literal_execute_params = self.literal_execute_params for bindparam in self.bind_names: @@ -1096,10 +1068,10 @@ class SQLCompiler(Compiled): if bindparam.type._is_tuple_type: inputsizes[bindparam] = [ - _lookup_type(typ) for typ in bindparam.type.types + lookup_type(typ) for typ in bindparam.type.types ] else: - inputsizes[bindparam] = _lookup_type(bindparam.type) + inputsizes[bindparam] = lookup_type(bindparam.type) return inputsizes @@ -2060,7 +2032,25 @@ class SQLCompiler(Compiled): parameter, values ) - typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect) + dialect = self.dialect + typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect) + + if ( + self.dialect._bind_typing_render_casts + and typ_dialect_impl.render_bind_cast + ): + + def _render_bindtemplate(name): + return self.render_bind_cast( + parameter.type, + typ_dialect_impl, + self.bindtemplate % {"name": name}, + ) + + else: + + def _render_bindtemplate(name): + return self.bindtemplate % {"name": name} if not values: to_update = [] @@ -2085,14 +2075,16 @@ class SQLCompiler(Compiled): for i, tuple_element in enumerate(values, 1) for j, value in enumerate(tuple_element, 1) ] + replacement_expression = ( - "VALUES " if self.dialect.tuple_in_values else "" + "VALUES " if dialect.tuple_in_values else "" ) + ", ".join( "(%s)" % ( ", ".join( - self.bindtemplate - % {"name": to_update[i * len(tuple_element) + j][0]} + _render_bindtemplate( + to_update[i * len(tuple_element) + j][0] + ) for j, value in enumerate(tuple_element) ) ) @@ -2104,7 +2096,7 @@ class SQLCompiler(Compiled): for i, value in enumerate(values, 1) ] replacement_expression = ", ".join( - self.bindtemplate % {"name": key} for key, value in to_update + _render_bindtemplate(key) for key, value in to_update ) return to_update, replacement_expression @@ -2373,6 +2365,7 @@ class SQLCompiler(Compiled): m = re.match( r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped ) + assert m, "unexpected format for expanding parameter" wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % ( m.group(2), m.group(1), @@ -2460,13 +2453,18 @@ class SQLCompiler(Compiled): name, post_compile=post_compile, expanding=bindparam.expanding, + bindparam_type=bindparam.type, **kwargs ) if bindparam.expanding: ret = "(%s)" % ret + return ret + def render_bind_cast(self, type_, dbapi_type, sqltext): + raise NotImplementedError() + def render_literal_bindparam( self, bindparam, render_literal_value=NO_ARG, **kw ): @@ -2553,6 +2551,7 @@ class SQLCompiler(Compiled): post_compile=False, expanding=False, escaped_from=None, + bindparam_type=None, **kw ): @@ -2580,8 +2579,18 @@ class SQLCompiler(Compiled): self.escaped_bind_names[escaped_from] = name if post_compile: return "__[POSTCOMPILE_%s]" % name - else: - return self.bindtemplate % {"name": name} + + ret = self.bindtemplate % {"name": name} + + if ( + bindparam_type is not None + and self.dialect._bind_typing_render_casts + ): + type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect) + if type_impl.render_bind_cast: + ret = self.render_bind_cast(bindparam_type, type_impl, ret) + + return ret def visit_cte( self, |