summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2021-11-25 18:22:59 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2021-11-25 18:22:59 +0000
commit8ddb3ef165d0c2d6d7167bb861bb349e68b5e8df (patch)
tree1f61463f9888eedbd156b35858af266135f7d6e7 /lib/sqlalchemy/sql/compiler.py
parent3619f084bfb5208ae45686a0993d620b2711adf2 (diff)
parent939de240d31a5441ad7380738d410a976d4ecc3a (diff)
downloadsqlalchemy-8ddb3ef165d0c2d6d7167bb861bb349e68b5e8df.tar.gz
Merge "propose emulated setinputsizes embedded in the compiler" into main
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py115
1 files changed, 62 insertions, 53 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 0dd61d675..28c1bf069 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -228,6 +228,7 @@ FUNCTIONS = {
functions.grouping_sets: "GROUPING SETS",
}
+
EXTRACT_MAP = {
"month": "month",
"day": "day",
@@ -1037,57 +1038,28 @@ class SQLCompiler(Compiled):
return pd
@util.memoized_instancemethod
- def _get_set_input_sizes_lookup(
- self, include_types=None, exclude_types=None
- ):
- if not hasattr(self, "bind_names"):
- return None
-
+ def _get_set_input_sizes_lookup(self):
dialect = self.dialect
- dbapi = self.dialect.dbapi
- # _unwrapped_dialect_impl() is necessary so that we get the
- # correct dialect type for a custom TypeDecorator, or a Variant,
- # which is also a TypeDecorator. Special types like Interval,
- # that use TypeDecorator but also might be mapped directly
- # for a dialect impl, also subclass Emulated first which overrides
- # this behavior in those cases to behave like the default.
+ include_types = dialect.include_set_input_sizes
+ exclude_types = dialect.exclude_set_input_sizes
- if include_types is None and exclude_types is None:
+ dbapi = dialect.dbapi
- def _lookup_type(typ):
- dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
- return dbtype
+ def lookup_type(typ):
+ dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
- else:
-
- def _lookup_type(typ):
- # note we get dbtype from the possibly TypeDecorator-wrapped
- # dialect_impl, but the dialect_impl itself that we use for
- # include/exclude is the unwrapped version.
-
- dialect_impl = typ._unwrapped_dialect_impl(dialect)
-
- dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
-
- if (
- dbtype is not None
- and (
- exclude_types is None
- or dbtype not in exclude_types
- and type(dialect_impl) not in exclude_types
- )
- and (
- include_types is None
- or dbtype in include_types
- or type(dialect_impl) in include_types
- )
- ):
- return dbtype
- else:
- return None
+ if (
+ dbtype is not None
+ and (exclude_types is None or dbtype not in exclude_types)
+ and (include_types is None or dbtype in include_types)
+ ):
+ return dbtype
+ else:
+ return None
inputsizes = {}
+
literal_execute_params = self.literal_execute_params
for bindparam in self.bind_names:
@@ -1096,10 +1068,10 @@ class SQLCompiler(Compiled):
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- _lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ) for typ in bindparam.type.types
]
else:
- inputsizes[bindparam] = _lookup_type(bindparam.type)
+ inputsizes[bindparam] = lookup_type(bindparam.type)
return inputsizes
@@ -2060,7 +2032,25 @@ class SQLCompiler(Compiled):
parameter, values
)
- typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+ dialect = self.dialect
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
+
+ if (
+ self.dialect._bind_typing_render_casts
+ and typ_dialect_impl.render_bind_cast
+ ):
+
+ def _render_bindtemplate(name):
+ return self.render_bind_cast(
+ parameter.type,
+ typ_dialect_impl,
+ self.bindtemplate % {"name": name},
+ )
+
+ else:
+
+ def _render_bindtemplate(name):
+ return self.bindtemplate % {"name": name}
if not values:
to_update = []
@@ -2085,14 +2075,16 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values, 1)
for j, value in enumerate(tuple_element, 1)
]
+
replacement_expression = (
- "VALUES " if self.dialect.tuple_in_values else ""
+ "VALUES " if dialect.tuple_in_values else ""
) + ", ".join(
"(%s)"
% (
", ".join(
- self.bindtemplate
- % {"name": to_update[i * len(tuple_element) + j][0]}
+ _render_bindtemplate(
+ to_update[i * len(tuple_element) + j][0]
+ )
for j, value in enumerate(tuple_element)
)
)
@@ -2104,7 +2096,7 @@ class SQLCompiler(Compiled):
for i, value in enumerate(values, 1)
]
replacement_expression = ", ".join(
- self.bindtemplate % {"name": key} for key, value in to_update
+ _render_bindtemplate(key) for key, value in to_update
)
return to_update, replacement_expression
@@ -2373,6 +2365,7 @@ class SQLCompiler(Compiled):
m = re.match(
r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
)
+ assert m, "unexpected format for expanding parameter"
wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
m.group(2),
m.group(1),
@@ -2460,13 +2453,18 @@ class SQLCompiler(Compiled):
name,
post_compile=post_compile,
expanding=bindparam.expanding,
+ bindparam_type=bindparam.type,
**kwargs
)
if bindparam.expanding:
ret = "(%s)" % ret
+
return ret
+ def render_bind_cast(self, type_, dbapi_type, sqltext):
+ raise NotImplementedError()
+
def render_literal_bindparam(
self, bindparam, render_literal_value=NO_ARG, **kw
):
@@ -2553,6 +2551,7 @@ class SQLCompiler(Compiled):
post_compile=False,
expanding=False,
escaped_from=None,
+ bindparam_type=None,
**kw
):
@@ -2580,8 +2579,18 @@ class SQLCompiler(Compiled):
self.escaped_bind_names[escaped_from] = name
if post_compile:
return "__[POSTCOMPILE_%s]" % name
- else:
- return self.bindtemplate % {"name": name}
+
+ ret = self.bindtemplate % {"name": name}
+
+ if (
+ bindparam_type is not None
+ and self.dialect._bind_typing_render_casts
+ ):
+ type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
+ if type_impl.render_bind_cast:
+ ret = self.render_bind_cast(bindparam_type, type_impl, ret)
+
+ return ret
def visit_cte(
self,