summaryrefslogtreecommitdiff
path: root/test/sql/test_functions.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/sql/test_functions.py')
-rw-r--r--test/sql/test_functions.py34
1 files changed, 32 insertions, 2 deletions
diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py
index e460a90cb..96e0a9129 100644
--- a/test/sql/test_functions.py
+++ b/test/sql/test_functions.py
@@ -31,9 +31,11 @@ from sqlalchemy.dialects import mysql
from sqlalchemy.dialects import oracle
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects import sqlite
+from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import column
from sqlalchemy.sql import functions
from sqlalchemy.sql import LABEL_STYLE_TABLENAME_PLUS_COL
+from sqlalchemy.sql import operators
from sqlalchemy.sql import quoted_name
from sqlalchemy.sql import table
from sqlalchemy.sql.compiler import BIND_TEMPLATES
@@ -99,6 +101,36 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
functions._registry["_default"].pop("fake_func")
+ @testing.combinations(
+ (operators.in_op, [1, 2, 3], "myfunc() IN (1, 2, 3)"),
+ (operators.add, 5, "myfunc() + 5"),
+ (operators.eq, column("q"), "myfunc() = q"),
+ argnames="op,other,expected",
+ )
+ @testing.combinations((True,), (False,), argnames="use_custom")
+ def test_operators_custom(self, op, other, expected, use_custom):
+ if use_custom:
+
+ class MyFunc(FunctionElement):
+ name = "myfunc"
+ type = Integer()
+
+ @compiles(MyFunc)
+ def visit_myfunc(element, compiler, **kw):
+ return "myfunc(%s)" % compiler.process(element.clauses, **kw)
+
+ expr = op(MyFunc(), other)
+ else:
+ expr = op(func.myfunc(type_=Integer), other)
+
+ self.assert_compile(
+ select(1).where(expr),
+ "SELECT 1 WHERE %s" % (expected,),
+ literal_binds=True,
+ render_postcompile=True,
+ dialect="default_enhanced",
+ )
+
def test_use_labels(self):
self.assert_compile(
select(func.foo()).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL),
@@ -106,8 +138,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
)
def test_use_labels_function_element(self):
- from sqlalchemy.ext.compiler import compiles
-
class max_(FunctionElement):
name = "max"