diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-03-13 00:24:54 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-03-13 00:24:54 +0000 |
commit | c5e1abc7f7adce841775ea92b72bcf95207027af (patch) | |
tree | b406fd4e6ede57ed8805a40a909c3c69207d3414 /lib/sqlalchemy | |
parent | 2ce45d70c7e499fd6c239d963f50cd839b28629b (diff) | |
download | sqlalchemy-c5e1abc7f7adce841775ea92b72bcf95207027af.tar.gz |
refactor to Compiled.get_params() to return new ClauseParameters object, a more intelligent bind parameter dictionary that does type conversions late and preserves the unconverted value; used to fix mappers not comparing correct value in post-fetch [ticket:110]
removed pre_exec assertion from oracle/firebird regarding "check for sequence/primary key value"
fix to Unicode type to check for null, fixes [ticket:109]
create_engine() now uses genericized parameters; host/hostname, db/dbname/database, password/passwd, etc. for all engine connections
fix to select([func(column)]) so that it creates a FROM clause to the column's table, fixes [ticket:111]
doc updates for column defaults, indexes, connection pooling, engine params
unit tests for the above bugfixes
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 22 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/firebird.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 15 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/postgres.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/engine.py | 20 | ||||
-rw-r--r-- | lib/sqlalchemy/mapping/mapper.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/schema.py | 2 | ||||
-rw-r--r-- | lib/sqlalchemy/sql.py | 23 | ||||
-rw-r--r-- | lib/sqlalchemy/types.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/util.py | 3 |
11 files changed, 62 insertions, 50 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 7f95cd392..b039b346b 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -118,7 +118,8 @@ class ANSICompiler(sql.Compiled): objects compiled within this object. The output is dependent on the paramstyle of the DBAPI being used; if a named style, the return result will be a dictionary with keynames matching the compiled statement. If a positional style, the output - will be a list corresponding to the bind positions in the compiled statement. + will be a list, with an iterator that will return parameter + values in an order corresponding to the bind positions in the compiled statement. for an executemany style of call, this method should be called for each element in the list of parameter groups that will ultimately be executed. @@ -129,32 +130,23 @@ class ANSICompiler(sql.Compiled): bindparams = {} bindparams.update(params) + d = sql.ClauseParameters(self.engine) if self.positional: - d = OrderedDict() for k in self.positiontup: b = self.binds[k] - if self.engine is not None: - d[k] = b.typeprocess(b.value, self.engine) - else: - d[k] = b.value + d.set_parameter(k, b.value, b) else: - d = {} for b in self.binds.values(): - if self.engine is not None: - d[b.key] = b.typeprocess(b.value, self.engine) - else: - d[b.key] = b.value + d.set_parameter(b.key, b.value, b) for key, value in bindparams.iteritems(): try: b = self.binds[key] except KeyError: continue - if self.engine is not None: - d[b.key] = b.typeprocess(value, self.engine) - else: - d[b.key] = value + d.set_parameter(b.key, value, b) + #print "FROM", params, "TO", d return d def get_named_params(self, parameters): diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 4dd4aa2a6..7d5cfed11 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -176,19 +176,8 @@ class FBSQLEngine(ansisql.ANSISQLEngine): return self.context.last_inserted_ids def pre_exec(self, proxy, compiled, parameters, **kwargs): - # this is just an assertion that all the primary key columns in an insert statement - # have a value set up, or have a default generator ready to go - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters - else: - plist = [parameters] - for param in plist: - for primary_key in compiled.statement.table.primary_key: - if not param.has_key(primary_key.key) or param[primary_key.key] is None: - if primary_key.default is None: - raise "Column '%s.%s': Firebird primary key columns require a default value or a schema.Sequence to create ids" % (primary_key.table.name, primary_key.name) - + pass + def _executemany(self, c, statement, parameters): rowcount = 0 for param in parameters: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index 8b262877c..c55da97cb 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -134,7 +134,7 @@ class MySQLEngine(ansisql.ANSISQLEngine): def __init__(self, opts, module = None, **params): if module is None: self.module = mysql - self.opts = opts or {} + self.opts = self._translate_connect_args(('host', 'db', 'user', 'passwd'), opts) ansisql.ANSISQLEngine.__init__(self, **params) def connect_args(self): diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 8f8058680..21b478001 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -90,7 +90,7 @@ def descriptor(): class OracleSQLEngine(ansisql.ANSISQLEngine): def __init__(self, opts, use_ansi = True, module = None, **params): self._use_ansi = use_ansi - self.opts = opts or {} + self.opts = self._translate_connect_args((None, 'dsn', 'user', 'password'), opts) if module is None: self.module = cx_Oracle else: @@ -181,18 +181,7 @@ order by UCC.CONSTRAINT_NAME""",{'table_name' : table.name.upper()}) return self.context.last_inserted_ids def pre_exec(self, proxy, compiled, parameters, **kwargs): - # this is just an assertion that all the primary key columns in an insert statement - # have a value set up, or have a default generator ready to go - if getattr(compiled, "isinsert", False): - if isinstance(parameters, list): - plist = parameters - else: - plist = [parameters] - for param in plist: - for primary_key in compiled.statement.table.primary_key: - if not param.has_key(primary_key.key) or param[primary_key.key] is None: - if primary_key.default is None: - raise "Column '%s.%s': Oracle primary key columns require a default value or a schema.Sequence to create ids" % (primary_key.table.name, primary_key.name) + pass def _executemany(self, c, statement, parameters): rowcount = 0 diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index db20b636c..72d426012 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -181,7 +181,7 @@ class PGSQLEngine(ansisql.ANSISQLEngine): self.version = 1 except: self.version = 1 - self.opts = opts or {} + self.opts = self._translate_connect_args(('host', 'database', 'user', 'password'), opts) if self.opts.has_key('port'): if self.version == 2: self.opts['port'] = int(self.opts['port']) diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py index 269402f81..e44e0a950 100644 --- a/lib/sqlalchemy/engine.py +++ b/lib/sqlalchemy/engine.py @@ -203,6 +203,25 @@ class SQLEngine(schema.SchemaEngine): self._figure_paramstyle() self.logger = logger or util.Logger(origin='engine') + def _translate_connect_args(self, names, args): + """translates a dictionary of connection arguments to those used by a specific dbapi. + the names parameter is a tuple of argument names in the form ('host', 'database', 'user', 'password') + where the given strings match the corresponding argument names for the dbapi. Will return a dictionary + with the dbapi-specific parameters, the generic ones removed, and any additional parameters still remaining, + from the dictionary represented by args. Will return a blank dictionary if args is null.""" + if args is None: + return {} + a = args.copy() + standard_names = [('host','hostname'), ('database', 'dbname'), ('user', 'username'), ('password', 'passwd', 'pw')] + for n in names: + sname = standard_names.pop(0) + if n is None: + continue + for sn in sname: + if sn != n and a.has_key(sn): + a[n] = a[sn] + del a[sn] + return a def _get_ischema(self): # We use a property for ischema so that the accessor # creation only happens as needed, since otherwise we @@ -563,7 +582,6 @@ class SQLEngine(schema.SchemaEngine): parameters = [compiled.get_params(**m) for m in parameters] else: parameters = compiled.get_params(**parameters) - def proxy(statement=None, parameters=None): if statement is None: return cursor diff --git a/lib/sqlalchemy/mapping/mapper.py b/lib/sqlalchemy/mapping/mapper.py index 554b2d5b4..a77c2db12 100644 --- a/lib/sqlalchemy/mapping/mapper.py +++ b/lib/sqlalchemy/mapping/mapper.py @@ -651,8 +651,8 @@ class Mapper(object): for c in table.c: if c.primary_key or not params.has_key(c.name): continue - if self._getattrbycolumn(obj, c) != params[c.name]: - self._setattrbycolumn(obj, c, params[c.name]) + if self._getattrbycolumn(obj, c) != params.get_original(c.name): + self._setattrbycolumn(obj, c, params.get_original(c.name)) def delete_obj(self, objects, uow): """called by a UnitOfWork object to delete objects, which involves a diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 5cb9f2043..756c03b6e 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -267,7 +267,7 @@ class Column(sql.ColumnClause, SchemaItem): name will all be included in the index, in the order of their creation. - unique=None : True or undex name. Indicates that this column is + unique=None : True or index name. Indicates that this column is indexed in a unique index . Pass true to autogenerate the index name. Pass a string to specify the index name. Multiple columns that specify the same index name will all be included in the index, in the diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 89b4b5585..4eaf33e00 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -232,6 +232,27 @@ def _is_literal(element): def is_column(col): return isinstance(col, ColumnElement) +class ClauseParameters(util.OrderedDict): + """represents a dictionary/iterator of bind parameter key names/values. Includes parameters compiled with a Compiled object as well as additional arguments passed to the Compiled object's get_params() method. Parameter values will be converted as per the TypeEngine objects present in the bind parameter objects. The non-converted value can be retrieved via the get_original method. For Compiled objects that compile positional parameters, the values() iteration of the object will return the parameter values in the correct order.""" + def __init__(self, engine=None): + super(ClauseParameters, self).__init__(self) + self.engine = engine + self.binds = {} + def set_parameter(self, key, value, bindparam): + self[key] = value + self.binds[key] = bindparam + def get_original(self, key): + return super(ClauseParameters, self).__getitem__(key) + def __getitem__(self, key): + v = super(ClauseParameters, self).__getitem__(key) + if self.engine is not None and self.binds.has_key(key): + v = self.binds[key].typeprocess(v, self.engine) + return v + def values(self): + return [self[key] for key in self] + def get_original_dict(self): + return self.copy() + class ClauseVisitor(object): """Defines the visiting of ClauseElements.""" def visit_column(self, column):pass @@ -779,6 +800,8 @@ class Function(ClauseList, ColumnElement): clause = BindParamClause(self.name, clause, shortname=self.name, type=None) self.clauses.append(clause) def _process_from_dict(self, data, asfrom): + super(Function, self)._process_from_dict(data, asfrom) + # this helps a Select object get the engine from us data.setdefault(self, self) def copy_container(self): clauses = [clause.copy_container() for clause in self.clauses] diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 96e6f1edb..89fd3bd2c 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -91,12 +91,12 @@ class String(TypeEngine): class Unicode(String): def convert_bind_param(self, value, engine): - if isinstance(value, unicode): + if value is not None and isinstance(value, unicode): return value.encode(engine.encoding) else: return value def convert_result_value(self, value, engine): - if not isinstance(value, unicode): + if value is not None and not isinstance(value, unicode): return value.decode(engine.encoding) else: return value diff --git a/lib/sqlalchemy/util.py b/lib/sqlalchemy/util.py index 2b522d571..303cf5683 100644 --- a/lib/sqlalchemy/util.py +++ b/lib/sqlalchemy/util.py @@ -191,7 +191,8 @@ class OrderedDict(dict): def itervalues(self): return iter([self[key] for key in self._list]) - def iterkeys(self): return self.__iter__() + def iterkeys(self): + return self.__iter__() def iteritems(self): return iter([(key, self[key]) for key in self.keys()]) |