summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/compiler.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/compiler.py')
-rw-r--r--lib/sqlalchemy/sql/compiler.py74
1 files changed, 62 insertions, 12 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b703c59f2..15ddd7d6f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -139,8 +139,16 @@ RESERVED_WORDS = set(
)
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+FK_ON_DELETE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_ON_UPDATE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
@@ -758,12 +766,11 @@ class SQLCompiler(Compiled):
else:
col = with_cols[element.element]
except KeyError:
- # treat it like text()
- util.warn_limited(
- "Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element),
+ elements._no_text_coercion(
+ element.element,
+ exc.CompileError,
+ "Can't resolve label reference for ORDER BY / GROUP BY.",
)
- return self.process(element._text_clause)
else:
kwargs["render_label_as_label"] = col
return self.process(
@@ -1076,10 +1083,24 @@ class SQLCompiler(Compiled):
if func._has_args:
name += "%(expr)s"
else:
- name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % {
- "expr": self.function_argspec(func, **kwargs)
- }
+ name = func.name
+ name = (
+ self.preparer.quote(name)
+ if self.preparer._requires_quotes_illegal_chars(name)
+ else name
+ )
+ name = name + "%(expr)s"
+ return ".".join(
+ [
+ (
+ self.preparer.quote(tok)
+ if self.preparer._requires_quotes_illegal_chars(tok)
+ else tok
+ )
+ for tok in func.packagenames
+ ]
+ + [name]
+ ) % {"expr": self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
@@ -3153,9 +3174,13 @@ class DDLCompiler(Compiled):
def define_constraint_cascades(self, constraint):
text = ""
if constraint.ondelete is not None:
- text += " ON DELETE %s" % constraint.ondelete
+ text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
+ constraint.ondelete, FK_ON_DELETE
+ )
if constraint.onupdate is not None:
- text += " ON UPDATE %s" % constraint.onupdate
+ text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
+ constraint.onupdate, FK_ON_UPDATE
+ )
return text
def define_constraint_deferrability(self, constraint):
@@ -3166,7 +3191,9 @@ class DDLCompiler(Compiled):
else:
text += " NOT DEFERRABLE"
if constraint.initially is not None:
- text += " INITIALLY %s" % constraint.initially
+ text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
+ constraint.initially, FK_INITIALLY
+ )
return text
def define_constraint_match(self, constraint):
@@ -3416,6 +3443,24 @@ class IdentifierPreparer(object):
return value.replace(self.escape_to_quote, self.escape_quote)
+ def validate_sql_phrase(self, element, reg):
+ """keyword sequence filter.
+
+ a filter for elements that are intended to represent keyword sequences,
+ such as "INITIALLY", "INTIALLY DEFERRED", etc. no special characters
+ should be present.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if element is not None and not reg.match(element):
+ raise exc.CompileError(
+ "Unexpected SQL phrase: %r (matching against %r)"
+ % (element, reg.pattern)
+ )
+ return element
+
def quote_identifier(self, value):
"""Quote an identifier.
@@ -3439,6 +3484,11 @@ class IdentifierPreparer(object):
or (lc_value != value)
)
+ def _requires_quotes_illegal_chars(self, value):
+ """Return True if the given identifier requires quoting, but
+ not taking case convention into account."""
+ return not self.legal_characters.match(util.text_type(value))
+
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema name.