summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/asyncpg.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-11-22 14:28:26 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-11-23 16:52:55 -0500
commit939de240d31a5441ad7380738d410a976d4ecc3a (patch)
treee5261a905636fa473760b1e81894453112bbaa66 /lib/sqlalchemy/dialects/postgresql/asyncpg.py
parentd3a4e96196cd47858de072ae589c6554088edc24 (diff)
downloadsqlalchemy-939de240d31a5441ad7380738d410a976d4ecc3a.tar.gz
propose emulated setinputsizes embedded in the compiler
Add a new system so that PostgreSQL and other dialects have a reliable way to add casts to bound parameters in SQL statements, replacing previous use of setinputsizes() for PG dialects. Rationale: 1. psycopg3 will be using the same SQLAlchemy-side "setinputsizes" as asyncpg, so we will be seeing a lot more of this. 2. The full rendering that SQLAlchemy's compilation performs appears in the engine log as well as in error messages. Without this, we introduce three levels of SQL rendering: the compiler, the hidden "setinputsizes" in SQLAlchemy, and then whatever the DBAPI driver does. With this new approach, users reporting bugs etc. will be less confused by the presence of up to two separate layers of "hidden rendering"; SQLAlchemy's rendering is again fully transparent. 3. Calling upon a setinputsizes() method for every statement execution is expensive; this way, the work is done behind the caching layer. 4. For "fast insertmany()", I also want there to be a fast approach towards setinputsizes. As it was, we were going to take a SQL INSERT with thousands of bound parameter placeholders and run a whole second pass on it to apply typecasts; this way, we can at least build the SQL string once without a huge second pass over the whole string. 5. psycopg2 can use this same system for its ARRAY casts. 6. The general need for PostgreSQL to have lots of type casts is now mostly in the base PostgreSQL dialect and works independently of a DBAPI being present; dependence on DBAPI symbols that aren't complete / consistent / hashable is removed. I was originally going to try to build this into bind_expression(), but it was revealed this worked poorly with custom bind_expression() as well as with empty sets. The current impl also doesn't need to run a second expression pass over the POSTCOMPILE sections, which came out better than I originally thought it would. Change-Id: I363e6d593d059add7bcc6d1f6c3f91dd2e683c0c
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/asyncpg.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py152
1 file changed, 38 insertions, 114 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index fe1f9fd5a..4ac0971e5 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -134,32 +134,28 @@ except ImportError:
_python_UUID = None
+class AsyncpgString(sqltypes.String):
+ render_bind_cast = True
+
+
class AsyncpgTime(sqltypes.Time):
- def get_dbapi_type(self, dbapi):
- return dbapi.TIME
+ render_bind_cast = True
class AsyncpgDate(sqltypes.Date):
- def get_dbapi_type(self, dbapi):
- return dbapi.DATE
+ render_bind_cast = True
class AsyncpgDateTime(sqltypes.DateTime):
- def get_dbapi_type(self, dbapi):
- if self.timezone:
- return dbapi.TIMESTAMP_W_TZ
- else:
- return dbapi.TIMESTAMP
+ render_bind_cast = True
class AsyncpgBoolean(sqltypes.Boolean):
- def get_dbapi_type(self, dbapi):
- return dbapi.BOOLEAN
+ render_bind_cast = True
class AsyncPgInterval(INTERVAL):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTERVAL
+ render_bind_cast = True
@classmethod
def adapt_emulated_to_native(cls, interval, **kw):
@@ -168,49 +164,45 @@ class AsyncPgInterval(INTERVAL):
class AsyncPgEnum(ENUM):
- def get_dbapi_type(self, dbapi):
- return dbapi.ENUM
+ render_bind_cast = True
class AsyncpgInteger(sqltypes.Integer):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ render_bind_cast = True
class AsyncpgBigInteger(sqltypes.BigInteger):
- def get_dbapi_type(self, dbapi):
- return dbapi.BIGINTEGER
+ render_bind_cast = True
class AsyncpgJSON(json.JSON):
- def get_dbapi_type(self, dbapi):
- return dbapi.JSON
+ render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class AsyncpgJSONB(json.JSONB):
- def get_dbapi_type(self, dbapi):
- return dbapi.JSONB
+ render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
- def get_dbapi_type(self, dbapi):
- raise NotImplementedError("should not be here")
+ pass
class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ __visit_name__ = "json_int_index"
+
+ render_bind_cast = True
class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
- def get_dbapi_type(self, dbapi):
- return dbapi.STRING
+ __visit_name__ = "json_str_index"
+
+ render_bind_cast = True
class AsyncpgJSONPathType(json.JSONPathType):
@@ -224,8 +216,7 @@ class AsyncpgJSONPathType(json.JSONPathType):
class AsyncpgUUID(UUID):
- def get_dbapi_type(self, dbapi):
- return dbapi.UUID
+ render_bind_cast = True
def bind_processor(self, dialect):
if not self.as_uuid and dialect.use_native_uuid:
@@ -249,8 +240,7 @@ class AsyncpgUUID(UUID):
class AsyncpgNumeric(sqltypes.Numeric):
- def get_dbapi_type(self, dbapi):
- return dbapi.NUMBER
+ render_bind_cast = True
def bind_processor(self, dialect):
return None
@@ -281,18 +271,16 @@ class AsyncpgNumeric(sqltypes.Numeric):
class AsyncpgFloat(AsyncpgNumeric):
- def get_dbapi_type(self, dbapi):
- return dbapi.FLOAT
+ __visit_name__ = "float"
+ render_bind_cast = True
class AsyncpgREGCLASS(REGCLASS):
- def get_dbapi_type(self, dbapi):
- return dbapi.STRING
+ render_bind_cast = True
class AsyncpgOID(OID):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ render_bind_cast = True
class PGExecutionContext_asyncpg(PGExecutionContext):
@@ -317,11 +305,6 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
if not self.compiled:
return
- # we have to exclude ENUM because "enum" not really a "type"
- # we can cast to, it has to be the name of the type itself.
- # for now we just omit it from casting
- self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
-
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
@@ -367,15 +350,7 @@ class AsyncAdapt_asyncpg_cursor:
self._adapt_connection._handle_exception(error)
def _parameter_placeholders(self, params):
- if not self._inputsizes:
- return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
- else:
- return tuple(
- "$%d::%s" % (idx, typ) if typ else "$%d" % idx
- for idx, typ in enumerate(
- (_pg_types.get(typ) for typ in self._inputsizes), 1
- )
- )
+ return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))
async def _prepare_and_execute(self, operation, parameters):
adapt_connection = self._adapt_connection
@@ -464,7 +439,7 @@ class AsyncAdapt_asyncpg_cursor:
)
def setinputsizes(self, *inputsizes):
- self._inputsizes = inputsizes
+ raise NotImplementedError()
def __iter__(self):
while self._rows:
@@ -798,6 +773,12 @@ class AsyncAdapt_asyncpg_dbapi:
"all prepared caches in response to this exception)",
)
+ # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't
+ # used, however the test suite looks for these in a few cases.
+ STRING = util.symbol("STRING")
+ NUMBER = util.symbol("NUMBER")
+ DATETIME = util.symbol("DATETIME")
+
@util.memoized_property
def _asyncpg_error_translate(self):
import asyncpg
@@ -814,50 +795,6 @@ class AsyncAdapt_asyncpg_dbapi:
def Binary(self, value):
return value
- STRING = util.symbol("STRING")
- TIMESTAMP = util.symbol("TIMESTAMP")
- TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
- TIME = util.symbol("TIME")
- DATE = util.symbol("DATE")
- INTERVAL = util.symbol("INTERVAL")
- NUMBER = util.symbol("NUMBER")
- FLOAT = util.symbol("FLOAT")
- BOOLEAN = util.symbol("BOOLEAN")
- INTEGER = util.symbol("INTEGER")
- BIGINTEGER = util.symbol("BIGINTEGER")
- BYTES = util.symbol("BYTES")
- DECIMAL = util.symbol("DECIMAL")
- JSON = util.symbol("JSON")
- JSONB = util.symbol("JSONB")
- ENUM = util.symbol("ENUM")
- UUID = util.symbol("UUID")
- BYTEA = util.symbol("BYTEA")
-
- DATETIME = TIMESTAMP
- BINARY = BYTEA
-
-
-_pg_types = {
- AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
- AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
- AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
- AsyncAdapt_asyncpg_dbapi.DATE: "date",
- AsyncAdapt_asyncpg_dbapi.TIME: "time",
- AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
- AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
- AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
- AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
- AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
- AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
- AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
- AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
- AsyncAdapt_asyncpg_dbapi.JSON: "json",
- AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
- AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
- AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
- AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
-}
-
class PGDialect_asyncpg(PGDialect):
driver = "asyncpg"
@@ -865,19 +802,20 @@ class PGDialect_asyncpg(PGDialect):
supports_server_side_cursors = True
+ render_bind_cast = True
+
default_paramstyle = "format"
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_asyncpg
statement_compiler = PGCompiler_asyncpg
preparer = PGIdentifierPreparer_asyncpg
- use_setinputsizes = True
-
use_native_uuid = True
colspecs = util.update_copy(
PGDialect.colspecs,
{
+ sqltypes.String: AsyncpgString,
sqltypes.Time: AsyncpgTime,
sqltypes.Date: AsyncpgDate,
sqltypes.DateTime: AsyncpgDateTime,
@@ -977,20 +915,6 @@ class PGDialect_asyncpg(PGDialect):
e, self.dbapi.InterfaceError
) and "connection is closed" in str(e)
- def do_set_input_sizes(self, cursor, list_of_tuples, context):
- if self.positional:
- cursor.setinputsizes(
- *[dbtype for key, dbtype, sqltype in list_of_tuples]
- )
- else:
- cursor.setinputsizes(
- **{
- key: dbtype
- for key, dbtype, sqltype in list_of_tuples
- if dbtype
- }
- )
-
async def setup_asyncpg_json_codec(self, conn):
"""set up JSON codec for asyncpg.