summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/base.py
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2012-04-24 16:44:53 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2012-04-24 16:44:53 -0400
commitcfe56e3735d2ba34923c36e9f015253e535ed1bd (patch)
treeb2aa6a91a25dcae56bbe833ea04d50ac94ac20a7 /lib/sqlalchemy/dialects/postgresql/base.py
parenta55d6c5f35769ea61ea5240aff9f763229d3007e (diff)
downloadsqlalchemy-cfe56e3735d2ba34923c36e9f015253e535ed1bd.tar.gz
- [feature] postgresql.ARRAY features an optional
"dimension" argument, will assign a specific number of dimensions to the array which will render in DDL as ARRAY[][]..., also improves performance of bind/result processing. [ticket:2441]
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql/base.py')
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py89
1 files changed, 50 insertions, 39 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index d47b9e757..c3ff73fa1 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -329,7 +329,7 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
"""
__visit_name__ = 'ARRAY'
- def __init__(self, item_type, as_tuple=False):
+ def __init__(self, item_type, as_tuple=False, dimensions=None):
"""Construct an ARRAY.
E.g.::
@@ -349,6 +349,14 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
as psycopg2 return lists by default. When tuples are
returned, the results are hashable.
+ :param dimensions: if non-None, the ARRAY will assume a fixed
+ number of dimensions. This will cause the DDL emitted for this
+ ARRAY to include the exact number of bracket clauses ``[]``,
+ and will also optimize the performance of the type overall.
+ Note that PG arrays are always implicitly "non-dimensioned",
+ meaning they can store any number of dimensions no matter how
+ they were declared.
+
"""
if isinstance(item_type, ARRAY):
raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
@@ -357,58 +365,59 @@ class ARRAY(sqltypes.Concatenable, sqltypes.TypeEngine):
item_type = item_type()
self.item_type = item_type
self.as_tuple = as_tuple
+ self.dimensions = dimensions
def compare_values(self, x, y):
return x == y
- def bind_processor(self, dialect):
- item_proc = self.item_type.dialect_impl(dialect).bind_processor(dialect)
- if item_proc:
- def convert_item(item):
- if isinstance(item, (list, tuple)):
- return [convert_item(child) for child in item]
- else:
- return item_proc(item)
+ def _proc_array(self, arr, itemproc, dim, collection):
+ if dim == 1 or (
+ dim is None and
+ (not arr or not isinstance(arr[0], (list, tuple)))
+ ):
+ if itemproc:
+ return collection(itemproc(x) for x in arr)
+ else:
+ return collection(arr)
else:
- def convert_item(item):
- if isinstance(item, (list, tuple)):
- return [convert_item(child) for child in item]
- else:
- return item
+ return collection(
+ self._proc_array(
+ x, itemproc,
+ dim - 1 if dim is not None else None,
+ collection)
+ for x in arr
+ )
+
+ def bind_processor(self, dialect):
+ item_proc = self.item_type.\
+ dialect_impl(dialect).\
+ bind_processor(dialect)
def process(value):
if value is None:
return value
- return [convert_item(item) for item in value]
+ else:
+ return self._proc_array(
+ value,
+ item_proc,
+ self.dimensions,
+ list)
return process
def result_processor(self, dialect, coltype):
- item_proc = self.item_type.dialect_impl(dialect).result_processor(dialect, coltype)
- if item_proc:
- def convert_item(item):
- if isinstance(item, list):
- r = [convert_item(child) for child in item]
- if self.as_tuple:
- r = tuple(r)
- return r
- else:
- return item_proc(item)
- else:
- def convert_item(item):
- if isinstance(item, list):
- r = [convert_item(child) for child in item]
- if self.as_tuple:
- r = tuple(r)
- return r
- else:
- return item
+ item_proc = self.item_type.\
+ dialect_impl(dialect).\
+ result_processor(dialect, coltype)
def process(value):
if value is None:
return value
- r = [convert_item(item) for item in value]
- if self.as_tuple:
- r = tuple(r)
- return r
+ else:
+ return self._proc_array(
+ value,
+ item_proc,
+ self.dimensions,
+ tuple if self.as_tuple else list)
return process
+
PGArray = ARRAY
class ENUM(sqltypes.Enum):
@@ -841,7 +850,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
return "BYTEA"
def visit_ARRAY(self, type_):
- return self.process(type_.item_type) + '[]'
+ return self.process(type_.item_type) + ('[]' * (type_.dimensions
+ if type_.dimensions
+ is not None else 1))
class PGIdentifierPreparer(compiler.IdentifierPreparer):