diff options
-rw-r--r-- | CHANGES | 6 | ||||
-rw-r--r-- | examples/postgis/postgis.py | 161 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 8 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/functions.py | 4 |
6 files changed, 132 insertions, 51 deletions
@@ -186,6 +186,12 @@ CHANGES - sql - Columns can again contain percent signs within their names. [ticket:1256] + + - sqlalchemy.sql.expression.Function is now a public + class. It can be subclassed to provide user-defined + SQL functions in an imperative style, including + with pre-established behaviors. The postgis.py + example illustrates one usage of this. - PickleType now favors == comparison by default, if the incoming object (such as a dict) implements diff --git a/examples/postgis/postgis.py b/examples/postgis/postgis.py index c463cca26..802aa0ea9 100644 --- a/examples/postgis/postgis.py +++ b/examples/postgis/postgis.py @@ -1,7 +1,7 @@ """A naive example illustrating techniques to help embed PostGIS functionality. -The techniques here could be used a capable developer +The techniques here could be used by a capable developer as the basis for a comprehensive PostGIS SQLAlchemy extension. Please note this is an entirely incomplete proof of concept only, and PostGIS support is *not* a supported feature @@ -40,23 +40,79 @@ from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.types import TypeEngine from sqlalchemy.sql import expression +# Python datatypes + +class GisElement(object): + """Represents a geometry value.""" + + @property + def wkt(self): + return func.AsText(literal(self, Geometry)) + + @property + def wkb(self): + return func.AsBinary(literal(self, Geometry)) + + def __str__(self): + return self.desc + + def __repr__(self): + return "<%s at 0x%x; %r>" % (self.__class__.__name__, id(self), self.desc) + +class PersistentGisElement(GisElement): + """Represents a Geometry value as loaded from the database.""" + + def __init__(self, desc): + self.desc = desc + +class TextualGisElement(GisElement, expression.Function): + """Represents a Geometry value as expressed within application code; i.e. in wkt format. + + Extends expression.Function so that the value is interpreted as + GeomFromText(value) in a SQL expression context. + + """ + + def __init__(self, desc, srid=-1): + assert isinstance(desc, basestring) + self.desc = desc + expression.Function.__init__(self, "GeomFromText", desc, srid) + + +# SQL datatypes. + class Geometry(TypeEngine): - """Base PostGIS Geometry column type""" + """Base PostGIS Geometry column type. + + Converts bind/result values to/from a PersistentGisElement. + + """ name = 'GEOMETRY' - def __init__(self, dimension, srid=-1): + def __init__(self, dimension=None, srid=-1): self.dimension = dimension self.srid = srid - + + def bind_processor(self, dialect): + def process(value): + if value is not None: + return value.desc + else: + return value + return process + def result_processor(self, dialect): def process(value): if value is not None: - return gis_element(value) + return PersistentGisElement(value) else: return value return process +# other datatypes can be added as needed, which +# currently only affect DDL statements. + class Point(Geometry): name = 'POINT' @@ -66,10 +122,25 @@ class Curve(Geometry): class LineString(Curve): name = 'LINESTRING' -# ... add other types as needed +# ... etc. + +# DDL integration class GISDDL(object): + """A DDL extension which integrates SQLAlchemy table create/drop + methods with PostGis' AddGeometryColumn/DropGeometryColumn functions. + + Usage:: + + sometable = Table('sometable', metadata, ...) + + GISDDL(sometable) + + sometable.create() + + """ + def __init__(self, table): for event in ('before-create', 'after-create', 'before-drop', 'after-drop'): table.ddl_listeners[event].append(self) @@ -95,23 +166,25 @@ class GISDDL(object): elif event == 'after-drop': table._columns = self._stack.pop() +# ORM integration + def _to_postgis(value): + """Interpret a value as a GIS-compatible construct.""" + if hasattr(value, '__clause_element__'): return value.__clause_element__() - elif isinstance(value, expression.ClauseElement): + elif isinstance(value, (expression.ClauseElement, GisElement)): return value elif isinstance(value, basestring): - return func.GeomFromText(value, -1) - elif isinstance(value, gis_element): - return value.desc + return TextualGisElement(value) elif value is None: return None else: raise Exception("Invalid type") - + class GisAttribute(AttributeExtension): - """Intercepts 'set' events on a mapped instance and + """Intercepts 'set' events on a mapped instance attribute and converts the incoming value to a GIS expression. """ @@ -123,44 +196,36 @@ class GisComparator(ColumnProperty.ColumnComparator): """Intercepts standard Column operators on mapped class attributes and overrides their behavior. - """ + # override the __eq__() operator def __eq__(self, other): return self.__clause_element__().op('~=')(_to_postgis(other)) + # add a custom operator def intersects(self, other): return self.__clause_element__().op('&&')(_to_postgis(other)) - -class gis_element(object): - """Represents a geometry value. - - This is just the raw string returned by PostGIS, - plus some helper functions. - - """ - - def __init__(self, desc): - self.desc = desc - - @property - def wkt(self): - return func.AsText(self.desc) - - @property - def wkb(self): - return func.AsBinary(self.desc) - + # any number of GIS operators can be overridden/added here + # using the techniques above. + + def GISColumn(*args, **kw): - """Define a declarative column property with GIS behavior.""" + """Define a declarative column property with GIS behavior. + This just produces orm.column_property() with the appropriate + extension and comparator_factory arguments. The given arguments + are passed through to Column. The declarative module extracts + the Column for inclusion in the mapped table. + + """ return column_property( Column(*args, **kw), extension=GisAttribute(), comparator_factory=GisComparator ) - + +# illustrate usage if __name__ == '__main__': from sqlalchemy import * from sqlalchemy.orm import * @@ -187,8 +252,7 @@ if __name__ == '__main__': session = sessionmaker(bind=engine)() - # Add objects using strings for the geometry objects; the attribute extension - # converts them to GeomFromText + # Add objects. We can use strings... session.add_all([ Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'), Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'), @@ -197,18 +261,29 @@ if __name__ == '__main__': Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'), ]) - # GeomFromText can be called directly here as well. - session.add( - Road(road_name='Dave Cres', road_geom=func.GeomFromText('LINESTRING(198231 263418,198213 268322)', -1)), - ) + # or use an explicit TextualGisElement (similar to saying func.GeomFromText()) + r = Road(road_name='Dave Cres', road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1)) + session.add(r) + + # pre flush, the TextualGisElement represents the string we sent. + assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)' + assert session.scalar(r.road_geom.wkt) == 'LINESTRING(198231 263418,198213 268322)' session.commit() + + # after flush and/or commit, all the TextualGisElements become PersistentGisElements. + assert str(r.road_geom) == "01020000000200000000000000B832084100000000E813104100000000283208410000000088601041" r1 = session.query(Road).filter(Road.road_name=='Graeme Ave').one() - - # illustrate the overridden __eq__() operator + + # illustrate the overridden __eq__() operator. + + # strings come in as TextualGisElements r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one() + + # PersistentGisElements work directly r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one() + assert r1 is r2 is r3 # illustrate the "intersects" operator diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index c547c0e54..de6346b8b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -962,7 +962,7 @@ class Connection(Connectable): # poor man's multimethod/generic function thingy executors = { - expression._Function: _execute_function, + expression.Function: _execute_function, expression.ClauseElement: _execute_clauseelement, Compiled: _execute_compiled, schema.SchemaItem: _execute_default, diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 31fc9ae1e..0430f053b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -463,7 +463,7 @@ class DefaultCompiler(engine.Compiled): not isinstance(column.table, sql.Select): return _CompileLabel(column, sql._generated_label(column.name)) elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \ - and (not hasattr(column, 'name') or isinstance(column, sql._Function)): + and (not hasattr(column, 'name') or isinstance(column, sql.Function)): return _CompileLabel(column, column.anon_label) else: return column diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 07df207dd..7204e2956 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -820,12 +820,12 @@ def text(text, bind=None, *args, **kwargs): return _TextClause(text, bind=bind, *args, **kwargs) def null(): - """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement.""" + """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement.""" return _Null() class _FunctionGenerator(object): - """Generate ``_Function`` objects based on getattr calls.""" + """Generate :class:`Function` objects based on getattr calls.""" def __init__(self, **opts): self.__names = [] @@ -856,7 +856,7 @@ class _FunctionGenerator(object): if func is not None: return func(*c, **o) - return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) + return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o) # "func" global - i.e. func.count() func = _FunctionGenerator() @@ -2228,7 +2228,7 @@ class _CalculatedClause(ColumnElement): def _compare_type(self, obj): return self.type -class _Function(_CalculatedClause, FromClause): +class Function(_CalculatedClause, FromClause): """Describe a SQL function. Extends ``_CalculatedClause``, turn the *clauselist* into function diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index b57b242f5..1bcc6d864 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -1,6 +1,6 @@ from sqlalchemy import types as sqltypes from sqlalchemy.sql.expression import ( - ClauseList, _Function, _literal_as_binds, text + ClauseList, Function, _literal_as_binds, text ) from sqlalchemy.sql import operators from sqlalchemy.sql.visitors import VisitableType @@ -10,7 +10,7 @@ class _GenericMeta(VisitableType): args = [_literal_as_binds(c) for c in args] return type.__call__(self, *args, **kwargs) -class GenericFunction(_Function): +class GenericFunction(Function): __metaclass__ = _GenericMeta def __init__(self, type_=None, group=True, args=(), **kwargs): |