summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql
diff options
context:
space:
mode:
authormike bayer <mike_mp@zzzcomputing.com>2022-12-16 15:08:00 +0000
committerGerrit Code Review <gerrit@ci3.zzzcomputing.com>2022-12-16 15:08:00 +0000
commita41d0dc4bdcc698643b6a4d76f265f5aa4765bee (patch)
tree91ce64509ed3e87b6456c8c69513771eba6c0403 /lib/sqlalchemy/dialects/postgresql
parent95a6b50923660a5d69c5aa54d195582112c4358d (diff)
parent7b84c850606c7b093b4260c08ff4636ff1bdbfef (diff)
downloadsqlalchemy-a41d0dc4bdcc698643b6a4d76f265f5aa4765bee.tar.gz
Merge "add explicit REGCONFIG, pg full text functions" into main
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/__init__.py11
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py6
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py63
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py208
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg.py6
-rw-r--r--lib/sqlalchemy/dialects/postgresql/types.py25
6 files changed, 317 insertions, 2 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index 7890541ff..d2e213bbc 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -37,6 +37,12 @@ from .dml import insert
from .ext import aggregate_order_by
from .ext import array_agg
from .ext import ExcludeConstraint
+from .ext import phraseto_tsquery
+from .ext import plainto_tsquery
+from .ext import to_tsquery
+from .ext import to_tsvector
+from .ext import ts_headline
+from .ext import websearch_to_tsquery
from .hstore import HSTORE
from .hstore import hstore
from .json import JSON
@@ -72,8 +78,10 @@ from .types import MACADDR
from .types import MONEY
from .types import OID
from .types import REGCLASS
+from .types import REGCONFIG
from .types import TIME
from .types import TIMESTAMP
+from .types import TSQUERY
from .types import TSVECTOR
# Alias psycopg also as psycopg_async
@@ -102,6 +110,9 @@ __all__ = (
"MONEY",
"OID",
"REGCLASS",
+ "REGCONFIG",
+ "TSQUERY",
+ "TSVECTOR",
"DOUBLE_PRECISION",
"TIMESTAMP",
"TIME",
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index b8f614eba..3c1eaf918 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -142,6 +142,7 @@ from .base import PGDialect
from .base import PGExecutionContext
from .base import PGIdentifierPreparer
from .base import REGCLASS
+from .base import REGCONFIG
from ... import exc
from ... import pool
from ... import util
@@ -160,6 +161,10 @@ class AsyncpgString(sqltypes.String):
render_bind_cast = True
+class AsyncpgREGCONFIG(REGCONFIG):
+ render_bind_cast = True
+
+
class AsyncpgTime(sqltypes.Time):
render_bind_cast = True
@@ -899,6 +904,7 @@ class PGDialect_asyncpg(PGDialect):
PGDialect.colspecs,
{
sqltypes.String: AsyncpgString,
+ REGCONFIG: AsyncpgREGCONFIG,
sqltypes.Time: AsyncpgTime,
sqltypes.Date: AsyncpgDate,
sqltypes.DateTime: AsyncpgDateTime,
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index f9108094f..8287e828a 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -827,6 +827,8 @@ For example, the query::
would generate:
+.. sourcecode:: sql
+
SELECT to_tsquery('cat') @> to_tsquery('cat & rat')
@@ -840,6 +842,20 @@ produces a statement equivalent to::
SELECT CAST('some text' AS TSVECTOR) AS anon_1
+The ``func`` namespace is augmented by the PostgreSQL dialect to set up
+correct argument and return types for most full text search functions.
+These functions are used automatically by the :attr:`_sql.func` namespace
+assuming the ``sqlalchemy.dialects.postgresql`` package has been imported,
+or :func:`_sa.create_engine` has been invoked using a ``postgresql``
+dialect. These functions are documented at:
+
+* :class:`_postgresql.to_tsvector`
+* :class:`_postgresql.to_tsquery`
+* :class:`_postgresql.plainto_tsquery`
+* :class:`_postgresql.phraseto_tsquery`
+* :class:`_postgresql.websearch_to_tsquery`
+* :class:`_postgresql.ts_headline`
+
Specifying the "regconfig" with ``match()`` or custom operators
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -1402,6 +1418,7 @@ from . import hstore as _hstore
from . import json as _json
from . import pg_catalog
from . import ranges as _ranges
+from .ext import _regconfig_fn
from .ext import aggregate_order_by
from .named_types import CreateDomainType as CreateDomainType # noqa: F401
from .named_types import CreateEnumType as CreateEnumType # noqa: F401
@@ -1428,6 +1445,7 @@ from .types import PGInterval as PGInterval # noqa: F401
from .types import PGMacAddr as PGMacAddr # noqa: F401
from .types import PGUuid as PGUuid
from .types import REGCLASS as REGCLASS
+from .types import REGCONFIG as REGCONFIG # noqa: F401
from .types import TIME as TIME
from .types import TIMESTAMP as TIMESTAMP
from .types import TSVECTOR as TSVECTOR
@@ -1636,6 +1654,45 @@ ischema_names = {
class PGCompiler(compiler.SQLCompiler):
+ def visit_to_tsvector_func(self, element, **kw):
+ return self._assert_pg_ts_ext(element, **kw)
+
+ def visit_to_tsquery_func(self, element, **kw):
+ return self._assert_pg_ts_ext(element, **kw)
+
+ def visit_plainto_tsquery_func(self, element, **kw):
+ return self._assert_pg_ts_ext(element, **kw)
+
+ def visit_phraseto_tsquery_func(self, element, **kw):
+ return self._assert_pg_ts_ext(element, **kw)
+
+ def visit_websearch_to_tsquery_func(self, element, **kw):
+ return self._assert_pg_ts_ext(element, **kw)
+
+ def visit_ts_headline_func(self, element, **kw):
+ return self._assert_pg_ts_ext(element, **kw)
+
+ def _assert_pg_ts_ext(self, element, **kw):
+ if not isinstance(element, _regconfig_fn):
+ # other options here include trying to rewrite the function
+ # with the correct types. however, that means we have to
+ # "un-SQL-ize" the first argument, which can't work in a
+ # generalized way. Also, parent compiler class has already added
+ # the incorrect return type to the result map. So let's just
+ # make sure the function we want is used up front.
+
+ raise exc.CompileError(
+ f'Can\'t compile "{element.name}()" full text search '
+ f"function construct that does not originate from the "
+ f'"sqlalchemy.dialects.postgresql" package. '
+ f'Please ensure "import sqlalchemy.dialects.postgresql" is '
+ f"called before constructing "
+ f'"sqlalchemy.func.{element.name}()" to ensure registration '
+ f"of the correct argument and return types."
+ )
+
+ return f"{element.name}{self.function_argspec(element, **kw)}"
+
def render_bind_cast(self, type_, dbapi_type, sqltext):
return f"""{sqltext}::{
self.dialect.type_compiler_instance.process(
@@ -2381,6 +2438,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_TSVECTOR(self, type_, **kw):
return "TSVECTOR"
+ def visit_TSQUERY(self, type_, **kw):
+ return "TSQUERY"
+
def visit_INET(self, type_, **kw):
return "INET"
@@ -2396,6 +2456,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_OID(self, type_, **kw):
return "OID"
+ def visit_REGCONFIG(self, type_, **kw):
+ return "REGCONFIG"
+
def visit_REGCLASS(self, type_, **kw):
return "REGCLASS"
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index b0d8ef345..31fbf203b 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -8,8 +8,11 @@
from __future__ import annotations
from itertools import zip_longest
+from typing import Any
from typing import TYPE_CHECKING
+from typing import TypeVar
+from . import types
from .array import ARRAY
from ...sql import coercions
from ...sql import elements
@@ -18,8 +21,11 @@ from ...sql import functions
from ...sql import roles
from ...sql import schema
from ...sql.schema import ColumnCollectionConstraint
+from ...sql.sqltypes import TEXT
from ...sql.visitors import InternalTraversal
+_T = TypeVar("_T", bound=Any)
+
if TYPE_CHECKING:
from ...sql.visitors import _TraverseInternalsType
@@ -287,3 +293,205 @@ def array_agg(*arg, **kw):
"""
kw["_default_array_type"] = ARRAY
return functions.func.array_agg(*arg, **kw)
+
+
+class _regconfig_fn(functions.GenericFunction[_T]):
+ inherit_cache = True
+
+ def __init__(self, *args, **kwargs):
+ args = list(args)
+ if len(args) > 1:
+
+ initial_arg = coercions.expect(
+ roles.ExpressionElementRole,
+ args.pop(0),
+ name=getattr(self, "name", None),
+ apply_propagate_attrs=self,
+ type_=types.REGCONFIG,
+ )
+ initial_arg = [initial_arg]
+ else:
+ initial_arg = []
+
+ addtl_args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=getattr(self, "name", None),
+ apply_propagate_attrs=self,
+ )
+ for c in args
+ ]
+ super().__init__(*(initial_arg + addtl_args), **kwargs)
+
+
+class to_tsvector(_regconfig_fn):
+ """The PostgreSQL ``to_tsvector`` SQL function.
+
+ This function applies automatic casting of the REGCONFIG argument
+ to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+ and applies a return type of :class:`_postgresql.TSVECTOR`.
+
+ Assuming the PostgreSQL dialect has been imported, either by invoking
+ ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+ engine using ``create_engine("postgresql...")``,
+ :class:`_postgresql.to_tsvector` will be used automatically when invoking
+ ``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return
+ type handlers are used at compile and execution time.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ inherit_cache = True
+ type = types.TSVECTOR
+
+
+class to_tsquery(_regconfig_fn):
+ """The PostgreSQL ``to_tsquery`` SQL function.
+
+ This function applies automatic casting of the REGCONFIG argument
+ to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+ and applies a return type of :class:`_postgresql.TSQUERY`.
+
+ Assuming the PostgreSQL dialect has been imported, either by invoking
+ ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+ engine using ``create_engine("postgresql...")``,
+ :class:`_postgresql.to_tsquery` will be used automatically when invoking
+ ``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return
+ type handlers are used at compile and execution time.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ inherit_cache = True
+ type = types.TSQUERY
+
+
+class plainto_tsquery(_regconfig_fn):
+ """The PostgreSQL ``plainto_tsquery`` SQL function.
+
+ This function applies automatic casting of the REGCONFIG argument
+ to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+ and applies a return type of :class:`_postgresql.TSQUERY`.
+
+ Assuming the PostgreSQL dialect has been imported, either by invoking
+ ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+ engine using ``create_engine("postgresql...")``,
+ :class:`_postgresql.plainto_tsquery` will be used automatically when
+ invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct
+ argument and return type handlers are used at compile and execution time.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ inherit_cache = True
+ type = types.TSQUERY
+
+
+class phraseto_tsquery(_regconfig_fn):
+ """The PostgreSQL ``phraseto_tsquery`` SQL function.
+
+ This function applies automatic casting of the REGCONFIG argument
+ to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+ and applies a return type of :class:`_postgresql.TSQUERY`.
+
+ Assuming the PostgreSQL dialect has been imported, either by invoking
+ ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+ engine using ``create_engine("postgresql...")``,
+ :class:`_postgresql.phraseto_tsquery` will be used automatically when
+ invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct
+ argument and return type handlers are used at compile and execution time.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ inherit_cache = True
+ type = types.TSQUERY
+
+
+class websearch_to_tsquery(_regconfig_fn):
+ """The PostgreSQL ``websearch_to_tsquery`` SQL function.
+
+ This function applies automatic casting of the REGCONFIG argument
+ to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+ and applies a return type of :class:`_postgresql.TSQUERY`.
+
+ Assuming the PostgreSQL dialect has been imported, either by invoking
+ ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+ engine using ``create_engine("postgresql...")``,
+ :class:`_postgresql.websearch_to_tsquery` will be used automatically when
+ invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct
+ argument and return type handlers are used at compile and execution time.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ inherit_cache = True
+ type = types.TSQUERY
+
+
+class ts_headline(_regconfig_fn):
+ """The PostgreSQL ``ts_headline`` SQL function.
+
+ This function applies automatic casting of the REGCONFIG argument
+ to use the :class:`_postgresql.REGCONFIG` datatype automatically,
+ and applies a return type of :class:`_types.TEXT`.
+
+ Assuming the PostgreSQL dialect has been imported, either by invoking
+ ``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
+ engine using ``create_engine("postgresql...")``,
+ :class:`_postgresql.ts_headline` will be used automatically when invoking
+ ``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return
+ type handlers are used at compile and execution time.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ inherit_cache = True
+ type = TEXT
+
+ def __init__(self, *args, **kwargs):
+ args = list(args)
+
+ # parse types according to
+ # https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE
+ if len(args) < 2:
+ # invalid args; don't do anything
+ has_regconfig = False
+ elif (
+ isinstance(args[1], elements.ColumnElement)
+ and args[1].type._type_affinity is types.TSQUERY
+ ):
+ # tsquery is second argument, no regconfig argument
+ has_regconfig = False
+ else:
+ has_regconfig = True
+
+ if has_regconfig:
+ initial_arg = coercions.expect(
+ roles.ExpressionElementRole,
+ args.pop(0),
+ apply_propagate_attrs=self,
+ name=getattr(self, "name", None),
+ type_=types.REGCONFIG,
+ )
+ initial_arg = [initial_arg]
+ else:
+ initial_arg = []
+
+ addtl_args = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ name=getattr(self, "name", None),
+ apply_propagate_attrs=self,
+ )
+ for c in args
+ ]
+ super().__init__(*(initial_arg + addtl_args), **kwargs)
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg.py b/lib/sqlalchemy/dialects/postgresql/psycopg.py
index 400c3186e..67d1370f5 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg.py
@@ -70,6 +70,7 @@ from ._psycopg_common import _PGExecutionContext_common_psycopg
from .base import INTERVAL
from .base import PGCompiler
from .base import PGIdentifierPreparer
+from .base import REGCONFIG
from .json import JSON
from .json import JSONB
from .json import JSONPathType
@@ -90,6 +91,10 @@ class _PGString(sqltypes.String):
render_bind_cast = True
+class _PGREGCONFIG(REGCONFIG):
+ render_bind_cast = True
+
+
class _PGJSON(JSON):
render_bind_cast = True
@@ -270,6 +275,7 @@ class PGDialect_psycopg(_PGDialect_common_psycopg):
_PGDialect_common_psycopg.colspecs,
{
sqltypes.String: _PGString,
+ REGCONFIG: _PGREGCONFIG,
JSON: _PGJSON,
sqltypes.JSON: _PGJSON,
JSONB: _PGJSONB,
diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py
index 72703ff81..49fc70ba3 100644
--- a/lib/sqlalchemy/dialects/postgresql/types.py
+++ b/lib/sqlalchemy/dialects/postgresql/types.py
@@ -6,7 +6,6 @@
# mypy: ignore-errors
import datetime as dt
-from typing import Any
from ...sql import sqltypes
@@ -102,6 +101,28 @@ class OID(sqltypes.TypeEngine[int]):
__visit_name__ = "OID"
+class REGCONFIG(sqltypes.TypeEngine[str]):
+
+ """Provide the PostgreSQL REGCONFIG type.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ __visit_name__ = "REGCONFIG"
+
+
+class TSQUERY(sqltypes.TypeEngine[str]):
+
+ """Provide the PostgreSQL TSQUERY type.
+
+ .. versionadded:: 2.0.0b5
+
+ """
+
+ __visit_name__ = "TSQUERY"
+
+
class REGCLASS(sqltypes.TypeEngine[str]):
"""Provide the PostgreSQL REGCLASS type.
@@ -207,7 +228,7 @@ class BIT(sqltypes.TypeEngine[int]):
PGBit = BIT
-class TSVECTOR(sqltypes.TypeEngine[Any]):
+class TSVECTOR(sqltypes.TypeEngine[str]):
"""The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
text search type TSVECTOR.