summaryrefslogtreecommitdiff
path: root/test/dialect/postgresql/test_compiler.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2022-12-13 20:07:14 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2022-12-15 10:36:48 -0500
commit7b84c850606c7b093b4260c08ff4636ff1bdbfef (patch)
treea2500d653134f05981ea3b6618ff63dcefa91716 /test/dialect/postgresql/test_compiler.py
parente0eea374c2df82f879d69b99ba2230c743bbae27 (diff)
downloadsqlalchemy-7b84c850606c7b093b4260c08ff4636ff1bdbfef.tar.gz
add explicit REGCONFIG, pg full text functions
Added support for explicit use of PG full text functions with asyncpg and psycopg (SQLAlchemy 2.0 only), with regards to the ``REGCONFIG`` type cast for the first argument, which previously would be incorrectly cast to a VARCHAR, causing failures on these dialects that rely upon explicit type casts. This includes support for :class:`_postgresql.to_tsvector`, :class:`_postgresql.to_tsquery`, :class:`_postgresql.plainto_tsquery`, :class:`_postgresql.phraseto_tsquery`, :class:`_postgresql.websearch_to_tsquery`, :class:`_postgresql.ts_headline`, each of which will determine based on number of arguments passed if the first string argument should be interpreted as a PostgreSQL "REGCONFIG" value; if so, the argument is typed using a newly added type object :class:`_postgresql.REGCONFIG` which is then explicitly cast in the SQL expression. Fixes: #8977 Change-Id: Ib36698a984fd4194bd6e0eb663105f790f3db7d3
Diffstat (limited to 'test/dialect/postgresql/test_compiler.py')
-rw-r--r--test/dialect/postgresql/test_compiler.py179
1 files changed, 179 insertions, 0 deletions
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index cf5f1c826..57b147c90 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -1,6 +1,7 @@
from sqlalchemy import and_
from sqlalchemy import BigInteger
from sqlalchemy import bindparam
+from sqlalchemy import case
from sqlalchemy import cast
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
@@ -28,6 +29,7 @@ from sqlalchemy import Table
from sqlalchemy import testing
from sqlalchemy import Text
from sqlalchemy import text
+from sqlalchemy import true
from sqlalchemy import tuple_
from sqlalchemy import types as sqltypes
from sqlalchemy import UniqueConstraint
@@ -43,6 +45,8 @@ from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import JSONPATH
from sqlalchemy.dialects.postgresql import Range
+from sqlalchemy.dialects.postgresql import REGCONFIG
+from sqlalchemy.dialects.postgresql import TSQUERY
from sqlalchemy.dialects.postgresql import TSRANGE
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
@@ -54,6 +58,8 @@ from sqlalchemy.sql import literal_column
from sqlalchemy.sql import operators
from sqlalchemy.sql import table
from sqlalchemy.sql import util as sql_util
+from sqlalchemy.sql.functions import GenericFunction
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
@@ -3183,6 +3189,12 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
column("title", String(128)),
column("body", String(128)),
)
+ self.matchtable = Table(
+ "matchtable",
+ MetaData(),
+ Column("id", Integer, primary_key=True),
+ Column("title", String(200)),
+ )
def _raise_query(self, q):
"""
@@ -3287,6 +3299,173 @@ class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL):
"""plainto_tsquery('english', %(to_tsvector_2)s)""",
)
+ @testing.combinations(
+ ("to_tsvector",),
+ ("to_tsquery",),
+ ("plainto_tsquery",),
+ ("phraseto_tsquery",),
+ ("websearch_to_tsquery",),
+ ("ts_headline",),
+ argnames="to_ts_name",
+ )
+ def test_dont_compile_non_imported(self, to_ts_name):
+ new_func = type(
+ to_ts_name,
+ (GenericFunction,),
+ {
+ "_register": False,
+ "inherit_cache": True,
+ },
+ )
+
+ with expect_raises_message(
+ exc.CompileError,
+ rf"Can't compile \"{to_ts_name}\(\)\" full text search "
+ f"function construct that does not originate from the "
+ f'"sqlalchemy.dialects.postgresql" package. '
+ f'Please ensure "import sqlalchemy.dialects.postgresql" is '
+ f"called before constructing "
+ rf"\"sqlalchemy.func.{to_ts_name}\(\)\" to ensure "
+ f"registration of the correct "
+ f"argument and return types.",
+ ):
+ select(new_func("x", "y")).compile(dialect=postgresql.dialect())
+
+ @testing.combinations(
+ (func.to_tsvector,),
+ (func.to_tsquery,),
+ (func.plainto_tsquery,),
+ (func.phraseto_tsquery,),
+ (func.websearch_to_tsquery,),
+ argnames="to_ts_func",
+ )
+ @testing.variation("use_regconfig", [True, False, "literal"])
+ def test_to_regconfig_fns(self, to_ts_func, use_regconfig):
+ """test #8977"""
+ matchtable = self.matchtable
+
+ fn_name = to_ts_func().name
+
+ if use_regconfig.literal:
+ regconfig = literal("english", REGCONFIG)
+ elif use_regconfig:
+ regconfig = "english"
+ else:
+ regconfig = None
+
+ if regconfig is None:
+ if fn_name == "to_tsvector":
+ fn = to_ts_func(matchtable.c.title).match("python")
+ expected = (
+ "to_tsvector(matchtable.title) @@ "
+ "plainto_tsquery($1::VARCHAR)"
+ )
+ else:
+ fn = func.to_tsvector(matchtable.c.title).op("@@")(
+ to_ts_func("python")
+ )
+ expected = (
+ f"to_tsvector(matchtable.title) @@ {fn_name}($1::VARCHAR)"
+ )
+ else:
+ if fn_name == "to_tsvector":
+ fn = to_ts_func(regconfig, matchtable.c.title).match("python")
+ expected = (
+ "to_tsvector($1::REGCONFIG, matchtable.title) @@ "
+ "plainto_tsquery($2::VARCHAR)"
+ )
+ else:
+ fn = func.to_tsvector(matchtable.c.title).op("@@")(
+ to_ts_func(regconfig, "python")
+ )
+ expected = (
+ f"to_tsvector(matchtable.title) @@ "
+ f"{fn_name}($1::REGCONFIG, $2::VARCHAR)"
+ )
+
+ stmt = matchtable.select().where(fn)
+
+ self.assert_compile(
+ stmt,
+ "SELECT matchtable.id, matchtable.title "
+ f"FROM matchtable WHERE {expected}",
+ dialect="postgresql+asyncpg",
+ )
+
+ @testing.variation("use_regconfig", [True, False, "literal"])
+ @testing.variation("include_options", [True, False])
+ @testing.variation("tsquery_in_expr", [True, False])
+ def test_ts_headline(
+ self, connection, use_regconfig, include_options, tsquery_in_expr
+ ):
+ """test #8977"""
+ if use_regconfig.literal:
+ regconfig = literal("english", REGCONFIG)
+ elif use_regconfig:
+ regconfig = "english"
+ else:
+ regconfig = None
+
+ text = (
+ "The most common type of search is to find all documents "
+ "containing given query terms and return them in order of "
+ "their similarity to the query."
+ )
+ tsquery = func.to_tsquery("english", "query & similarity")
+
+ if regconfig is None:
+ tsquery_str = "to_tsquery($2::REGCONFIG, $3::VARCHAR)"
+ else:
+ tsquery_str = "to_tsquery($3::REGCONFIG, $4::VARCHAR)"
+
+ if tsquery_in_expr:
+ tsquery = case((true(), tsquery), else_=null())
+ tsquery_str = f"CASE WHEN true THEN {tsquery_str} ELSE NULL END"
+
+ is_(tsquery.type._type_affinity, TSQUERY)
+
+ args = [text, tsquery]
+ if regconfig is not None:
+ args.insert(0, regconfig)
+ if include_options:
+ args.append(
+ "MaxFragments=10, MaxWords=7, "
+ "MinWords=3, StartSel=<<, StopSel=>>"
+ )
+
+ fn = func.ts_headline(*args)
+ stmt = select(fn)
+
+ if regconfig is None and not include_options:
+ self.assert_compile(
+ stmt,
+ f"SELECT ts_headline($1::VARCHAR, "
+ f"{tsquery_str}) AS ts_headline_1",
+ dialect="postgresql+asyncpg",
+ )
+ elif regconfig is None and include_options:
+ self.assert_compile(
+ stmt,
+ f"SELECT ts_headline($1::VARCHAR, "
+ f"{tsquery_str}, $4::VARCHAR) AS ts_headline_1",
+ dialect="postgresql+asyncpg",
+ )
+ elif regconfig is not None and not include_options:
+ self.assert_compile(
+ stmt,
+ f"SELECT ts_headline($1::REGCONFIG, $2::VARCHAR, "
+ f"{tsquery_str}) AS ts_headline_1",
+ dialect="postgresql+asyncpg",
+ )
+ else:
+ self.assert_compile(
+ stmt,
+ f"SELECT ts_headline($1::REGCONFIG, $2::VARCHAR, "
+ f"{tsquery_str}, $5::VARCHAR) "
+ "AS ts_headline_1",
+ dialect="postgresql+asyncpg",
+ )
+
class RegexpTest(fixtures.TestBase, testing.AssertsCompiledSQL):
__dialect__ = "postgresql"