diff options
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 1 | ||||
-rw-r--r-- | lib/sqlalchemy/engine.py | 49 | ||||
-rw-r--r-- | lib/sqlalchemy/mapper.py | 30 | ||||
-rw-r--r-- | lib/sqlalchemy/objectstore.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 7 | ||||
-rw-r--r-- | test/tables.py | 2 |
7 files changed, 55 insertions, 38 deletions
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index bb1f3c6ae..0ef64d480 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -40,7 +40,7 @@ class OracleDateTime(sqltypes.DateTime): return "DATE" class OracleText(sqltypes.TEXT): def get_col_spec(self): - return "TEXT" + return "CLOB" class OracleString(sqltypes.String): def get_col_spec(self): return "VARCHAR(%(length)s)" % {'length' : self.length} diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 86116f83f..f3648e69c 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -117,7 +117,6 @@ class PGSQLEngine(ansisql.ANSISQLEngine): return None def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): - # if a sequence was explicitly defined we do it here if compiled is None: return if getattr(compiled, "isinsert", False): if isinstance(parameters, list): diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 93569a05c..c910916d1 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -124,7 +124,7 @@ class SQLEngine(schema.SchemaEngine): connection.commit() def proxy(self): - return lambda s, p = None: self.execute(s, p, commit=True) + return lambda s, p = None: self.execute(s, p) def connection(self): return self._pool.connect() @@ -172,8 +172,6 @@ class SQLEngine(schema.SchemaEngine): self.do_rollback(self.context.transaction) self.context.transaction = None self.context.tcount = None - else: - self.do_rollback(self.connection()) def commit(self): if self.context.transaction is not None: @@ -183,8 +181,6 @@ class SQLEngine(schema.SchemaEngine): self.do_commit(self.context.transaction) self.context.transaction = None self.context.tcount = None - else: - self.do_commit(self.connection()) def pre_exec(self, connection, cursor, statement, parameters, many = False, echo = None, **kwargs): pass @@ -202,19 +198,23 @@ class SQLEngine(schema.SchemaEngine): else: c = connection.cursor() - self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs) - - if echo is True or self.echo: - self.log(statement) - self.log(repr(parameters)) - - if isinstance(parameters, list): - self._executemany(c, statement, parameters) - else: - self._execute(c, statement, parameters) - self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs) - if commit: - connection.commit() + try: + self.pre_exec(connection, c, statement, parameters, echo = echo, **kwargs) + + if echo is True or self.echo: + self.log(statement) + self.log(repr(parameters)) + if isinstance(parameters, list): + self._executemany(c, statement, parameters) + else: + self._execute(c, statement, parameters) + self.post_exec(connection, c, statement, parameters, echo = echo, **kwargs) + if commit or self.context.transaction is None: + self.do_commit(connection) + except: + self.do_rollback(connection) + # TODO: wrap DB exceptions ? + raise return ResultProxy(c, self, typemap = typemap) def _execute(self, c, statement, parameters): @@ -247,7 +247,18 @@ class ResultProxy: i+=1 def _get_col(self, row, key): - rec = self.props[key.lower()] + if isinstance(key, schema.Column): + try: + rec = self.props[key.label.lower()] + except KeyError: + try: + rec = self.props[key.key.lower()] + except KeyError: + rec = self.props[key.name.lower()] + elif isinstance(key, str): + rec = self.props[key.lower()] + else: + rec = self.props[key] return rec[0].convert_result_value(row[rec[1]]) def fetchall(self): diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index c90b50769..f541182aa 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -312,7 +312,7 @@ class Mapper(object): objectstore.uow().register_clean(value) if len(mappers): - return result + otherresults + return [result] + otherresults else: return result @@ -375,9 +375,21 @@ class Mapper(object): in this case, the developer must insure that an adequate set of columns exists in the rowset with which to build new object instances.""" if arg is not None and isinstance(arg, sql.Select): - return self._select_statement(arg, **params) + return self.select_statement(arg, **params) else: - return self._select_whereclause(arg, **params) + return self.select_whereclause(arg, **params) + + def select_whereclause(self, whereclause = None, order_by = None, **params): + statement = self._compile(whereclause, order_by = order_by) + return self.select_statement(statement, **params) + + def select_statement(self, statement, **params): + statement.use_labels = True + return self.instances(statement.execute(**params)) + + def select_text(self, text, **params): + t = sql.text(text, engine=self.primarytable.engine) + return self.instances(t.execute(**params)) def _getattrbycolumn(self, obj, column): try: @@ -494,13 +506,6 @@ class Mapper(object): statement.use_labels = True return statement - def _select_whereclause(self, whereclause = None, order_by = None, **params): - statement = self._compile(whereclause, order_by = order_by) - return self._select_statement(statement, **params) - - def _select_statement(self, statement, **params): - statement.use_labels = True - return self.instances(statement.execute(**params)) def _identity_key(self, row): return objectstore.get_row_key(row, self.class_, self.primarytable, self.primary_keys[self.table]) @@ -539,7 +544,7 @@ class Mapper(object): # check if primary keys in the result are None - this indicates # an instance of the object is not present in the row for col in self.primary_keys[self.table]: - if row[col.label] is None: + if row[col] is None: return None # plugin point instance = self.extension.create_instance(self, row, imap, self.class_) @@ -622,8 +627,7 @@ class ColumnProperty(MapperProperty): def execute(self, instance, row, identitykey, imap, isnew): if isnew: - instance.__dict__[self.key] = row[self.columns[0].label] - #setattr(instance, self.key, row[self.columns[0].label]) + instance.__dict__[self.key] = row[self.columns[0]] class PropertyLoader(MapperProperty): diff --git a/lib/sqlalchemy/objectstore.py b/lib/sqlalchemy/objectstore.py index 9b414f181..6081a7150 100644 --- a/lib/sqlalchemy/objectstore.py +++ b/lib/sqlalchemy/objectstore.py @@ -52,7 +52,7 @@ def get_row_key(row, class_, table, primary_keys): may be synonymous with the table argument or can be a larger construct containing that table. return value: a tuple object which is used as an identity key. """ - return (class_, table, tuple([row[column.label] for column in primary_keys])) + return (class_, table, tuple([row[column] for column in primary_keys])) def begin(): """begins a new UnitOfWork transaction. the next commit will affect only diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index b3a4ad293..a5a97a9e8 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -121,8 +121,8 @@ def bindparam(key, value = None, type=None): else: return BindParamClause(key, value, type=type) -def text(text): - return TextClause(text) +def text(text, engine=None): + return TextClause(text, engine=engine) def null(): return Null() @@ -383,9 +383,10 @@ class BindParamClause(ClauseElement): class TextClause(ClauseElement): """represents any plain text WHERE clause or full SQL statement""" - def __init__(self, text = ""): + def __init__(self, text = "", engine=None): self.text = text self.parens = False + self.engine = engine def accept_visitor(self, visitor): visitor.visit_textclause(self) def hash_key(self): diff --git a/test/tables.py b/test/tables.py index aceed904b..8bddce1d0 100644 --- a/test/tables.py +++ b/test/tables.py @@ -159,6 +159,8 @@ class Address(object): return "Address: " + repr(getattr(self, 'address_id', None)) + " " + repr(getattr(self, 'user_id', None)) + " " + repr(self.email_address) class Order(object): + def __init__(self): + self.isopen=0 def __repr__(self): return "Order: " + repr(self.description) + " " + repr(self.isopen) + " " + repr(getattr(self, 'items', None)) |