summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGES6
-rw-r--r--examples/postgis/postgis.py161
-rw-r--r--lib/sqlalchemy/engine/base.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/sql/expression.py8
-rw-r--r--lib/sqlalchemy/sql/functions.py4
6 files changed, 132 insertions, 51 deletions
diff --git a/CHANGES b/CHANGES
index cdfe4617a..e098cbdbc 100644
--- a/CHANGES
+++ b/CHANGES
@@ -186,6 +186,12 @@ CHANGES
- sql
- Columns can again contain percent signs within their
names. [ticket:1256]
+
+ - sqlalchemy.sql.expression.Function is now a public
+ class. It can be subclassed to provide user-defined
+ SQL functions in an imperative style, including
+ with pre-established behaviors. The postgis.py
+ example illustrates one usage of this.
- PickleType now favors == comparison by default,
if the incoming object (such as a dict) implements
diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py
index c463cca26..802aa0ea9 100644
--- a/examples/postgis/postgis.py
+++ b/examples/postgis/postgis.py
@@ -1,7 +1,7 @@
"""A naive example illustrating techniques to help
embed PostGIS functionality.
-The techniques here could be used a capable developer
+The techniques here could be used by a capable developer
as the basis for a comprehensive PostGIS SQLAlchemy extension.
Please note this is an entirely incomplete proof of concept
only, and PostGIS support is *not* a supported feature
@@ -40,23 +40,79 @@ from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.types import TypeEngine
from sqlalchemy.sql import expression
+# Python datatypes
+
+class GisElement(object):
+ """Represents a geometry value."""
+
+ @property
+ def wkt(self):
+ return func.AsText(literal(self, Geometry))
+
+ @property
+ def wkb(self):
+ return func.AsBinary(literal(self, Geometry))
+
+ def __str__(self):
+ return self.desc
+
+ def __repr__(self):
+ return "<%s at 0x%x; %r>" % (self.__class__.__name__, id(self), self.desc)
+
+class PersistentGisElement(GisElement):
+ """Represents a Geometry value as loaded from the database."""
+
+ def __init__(self, desc):
+ self.desc = desc
+
+class TextualGisElement(GisElement, expression.Function):
+ """Represents a Geometry value as expressed within application code; i.e. in wkt format.
+
+ Extends expression.Function so that the value is interpreted as
+ GeomFromText(value) in a SQL expression context.
+
+ """
+
+ def __init__(self, desc, srid=-1):
+ assert isinstance(desc, basestring)
+ self.desc = desc
+ expression.Function.__init__(self, "GeomFromText", desc, srid)
+
+
+# SQL datatypes.
+
class Geometry(TypeEngine):
- """Base PostGIS Geometry column type"""
+ """Base PostGIS Geometry column type.
+
+ Converts bind/result values to/from a PersistentGisElement.
+
+ """
name = 'GEOMETRY'
- def __init__(self, dimension, srid=-1):
+ def __init__(self, dimension=None, srid=-1):
self.dimension = dimension
self.srid = srid
-
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is not None:
+ return value.desc
+ else:
+ return value
+ return process
+
def result_processor(self, dialect):
def process(value):
if value is not None:
- return gis_element(value)
+ return PersistentGisElement(value)
else:
return value
return process
+# other datatypes can be added as needed, which
+# currently only affect DDL statements.
+
class Point(Geometry):
name = 'POINT'
@@ -66,10 +122,25 @@ class Curve(Geometry):
class LineString(Curve):
name = 'LINESTRING'
-# ... add other types as needed
+# ... etc.
+
+# DDL integration
class GISDDL(object):
+ """A DDL extension which integrates SQLAlchemy table create/drop
+ methods with PostGis' AddGeometryColumn/DropGeometryColumn functions.
+
+ Usage::
+
+ sometable = Table('sometable', metadata, ...)
+
+ GISDDL(sometable)
+
+ sometable.create()
+
+ """
+
def __init__(self, table):
for event in ('before-create', 'after-create', 'before-drop', 'after-drop'):
table.ddl_listeners[event].append(self)
@@ -95,23 +166,25 @@ class GISDDL(object):
elif event == 'after-drop':
table._columns = self._stack.pop()
+# ORM integration
+
def _to_postgis(value):
+ """Interpret a value as a GIS-compatible construct."""
+
if hasattr(value, '__clause_element__'):
return value.__clause_element__()
- elif isinstance(value, expression.ClauseElement):
+ elif isinstance(value, (expression.ClauseElement, GisElement)):
return value
elif isinstance(value, basestring):
- return func.GeomFromText(value, -1)
- elif isinstance(value, gis_element):
- return value.desc
+ return TextualGisElement(value)
elif value is None:
return None
else:
raise Exception("Invalid type")
-
+
class GisAttribute(AttributeExtension):
- """Intercepts 'set' events on a mapped instance and
+ """Intercepts 'set' events on a mapped instance attribute and
converts the incoming value to a GIS expression.
"""
@@ -123,44 +196,36 @@ class GisComparator(ColumnProperty.ColumnComparator):
"""Intercepts standard Column operators on mapped class attributes
and overrides their behavior.
-
"""
+ # override the __eq__() operator
def __eq__(self, other):
return self.__clause_element__().op('~=')(_to_postgis(other))
+ # add a custom operator
def intersects(self, other):
return self.__clause_element__().op('&&')(_to_postgis(other))
-
-class gis_element(object):
- """Represents a geometry value.
-
- This is just the raw string returned by PostGIS,
- plus some helper functions.
-
- """
-
- def __init__(self, desc):
- self.desc = desc
-
- @property
- def wkt(self):
- return func.AsText(self.desc)
-
- @property
- def wkb(self):
- return func.AsBinary(self.desc)
-
+ # any number of GIS operators can be overridden/added here
+ # using the techniques above.
+
+
def GISColumn(*args, **kw):
- """Define a declarative column property with GIS behavior."""
+ """Define a declarative column property with GIS behavior.
+ This just produces orm.column_property() with the appropriate
+ extension and comparator_factory arguments. The given arguments
+ are passed through to Column. The declarative module extracts
+ the Column for inclusion in the mapped table.
+
+ """
return column_property(
Column(*args, **kw),
extension=GisAttribute(),
comparator_factory=GisComparator
)
-
+
+# illustrate usage
if __name__ == '__main__':
from sqlalchemy import *
from sqlalchemy.orm import *
@@ -187,8 +252,7 @@ if __name__ == '__main__':
session = sessionmaker(bind=engine)()
- # Add objects using strings for the geometry objects; the attribute extension
- # converts them to GeomFromText
+ # Add objects. We can use strings...
session.add_all([
Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'),
Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'),
@@ -197,18 +261,29 @@ if __name__ == '__main__':
Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'),
])
- # GeomFromText can be called directly here as well.
- session.add(
- Road(road_name='Dave Cres', road_geom=func.GeomFromText('LINESTRING(198231 263418,198213 268322)', -1)),
- )
+ # or use an explicit TextualGisElement (similar to saying func.GeomFromText())
+ r = Road(road_name='Dave Cres', road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1))
+ session.add(r)
+
+ # pre flush, the TextualGisElement represents the string we sent.
+ assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)'
+ assert session.scalar(r.road_geom.wkt) == 'LINESTRING(198231 263418,198213 268322)'
session.commit()
+
+ # after flush and/or commit, all the TextualGisElements become PersistentGisElements.
+ assert str(r.road_geom) == "01020000000200000000000000B832084100000000E813104100000000283208410000000088601041"
r1 = session.query(Road).filter(Road.road_name=='Graeme Ave').one()
-
- # illustrate the overridden __eq__() operator
+
+ # illustrate the overridden __eq__() operator.
+
+ # strings come in as TextualGisElements
r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one()
+
+ # PersistentGisElements work directly
r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()
+
assert r1 is r2 is r3
# illustrate the "intersects" operator
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index c547c0e54..de6346b8b 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -962,7 +962,7 @@ class Connection(Connectable):
# poor man's multimethod/generic function thingy
executors = {
- expression._Function: _execute_function,
+ expression.Function: _execute_function,
expression.ClauseElement: _execute_clauseelement,
Compiled: _execute_compiled,
schema.SchemaItem: _execute_default,
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 31fc9ae1e..0430f053b 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -463,7 +463,7 @@ class DefaultCompiler(engine.Compiled):
not isinstance(column.table, sql.Select):
return _CompileLabel(column, sql._generated_label(column.name))
elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
- and (not hasattr(column, 'name') or isinstance(column, sql._Function)):
+ and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
return _CompileLabel(column, column.anon_label)
else:
return column
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 07df207dd..7204e2956 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -820,12 +820,12 @@ def text(text, bind=None, *args, **kwargs):
return _TextClause(text, bind=bind, *args, **kwargs)
def null():
- """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
+ """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement."""
return _Null()
class _FunctionGenerator(object):
- """Generate ``_Function`` objects based on getattr calls."""
+ """Generate :class:`Function` objects based on getattr calls."""
def __init__(self, **opts):
self.__names = []
@@ -856,7 +856,7 @@ class _FunctionGenerator(object):
if func is not None:
return func(*c, **o)
- return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
+ return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
# "func" global - i.e. func.count()
func = _FunctionGenerator()
@@ -2228,7 +2228,7 @@ class _CalculatedClause(ColumnElement):
def _compare_type(self, obj):
return self.type
-class _Function(_CalculatedClause, FromClause):
+class Function(_CalculatedClause, FromClause):
"""Describe a SQL function.
Extends ``_CalculatedClause``, turn the *clauselist* into function
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index b57b242f5..1bcc6d864 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -1,6 +1,6 @@
from sqlalchemy import types as sqltypes
from sqlalchemy.sql.expression import (
- ClauseList, _Function, _literal_as_binds, text
+ ClauseList, Function, _literal_as_binds, text
)
from sqlalchemy.sql import operators
from sqlalchemy.sql.visitors import VisitableType
@@ -10,7 +10,7 @@ class _GenericMeta(VisitableType):
args = [_literal_as_binds(c) for c in args]
return type.__call__(self, *args, **kwargs)
-class GenericFunction(_Function):
+class GenericFunction(Function):
__metaclass__ = _GenericMeta
def __init__(self, type_=None, group=True, args=(), **kwargs):