diff options
-rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 93 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/default.py | 13 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 20 |
3 files changed, 111 insertions, 15 deletions
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index fc35df2bb..df45f69bb 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -659,6 +659,16 @@ class OracleCompiler(compiler.DefaultCompiler): """Need to determine how to get ``LIMIT``/``OFFSET`` into a ``UNION`` for Oracle.""" pass + def create_insert_update_bind(self, col, value): + key = col.key + # TODO: make this check more specific to reserved words + if len(key) < 30: + key += '_' + bindparam = sql.bindparam(key, value, shortname=col.key, type_=col.type) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) + + def visit_select(self, select, **kwargs): """Look for ``LIMIT`` and OFFSET in a select statement, and if so tries to wrap it in a subquery with ``row_number()`` criterion. @@ -735,3 +745,86 @@ dialect.schemagenerator = OracleSchemaGenerator dialect.schemadropper = OracleSchemaDropper dialect.preparer = OracleIdentifierPreparer dialect.defaultrunner = OracleDefaultRunner + + +RESERVED_WORDS = util.Set(''' +SHARE +RAW +DROP +BETWEEN +FROM +DESC +OPTION +PRIOR +LONG +THEN +DEFAULT +ALTER +IS +INTO +MINUS +INTEGER +NUMBER +GRANT +IDENTIFIED +ALL +TO +ORDER +ON +FLOAT +DATE +HAVING +CLUSTER +NOWAIT +RESOURCE +ANY +TABLE +INDEX +FOR +UPDATE +WHERE +CHECK +SMALLINT +WITH +DELETE +BY +ASC +REVOKE +LIKE +SIZE +RENAME +NOCOMPRESS +NULL +GROUP +VALUES +AS +IN +VIEW +EXCLUSIVE +COMPRESS +SYNONYM +SELECT +INSERT +EXISTS +NOT +TRIGGER +ELSE +CREATE +INTERSECT +PCTFREE +DISTINCT +CONNECT +SET +MODE +OF +UNIQUE +VARCHAR2 +VARCHAR +LOCK +OR +CHAR +DECIMAL +UNION +PUBLIC +AND +START'''.splitlines())
\ No newline at end of file diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index dfeefa337..f716d06f5 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -371,7 +371,7 @@ class DefaultExecutionContext(base.ExecutionContext): else: val = drunner.get_column_onupdate(c) if val is not None: - param[c.key] = val + param[self.compiled.binds[c.key].key] = val self.compiled_parameters = params else: @@ -385,12 +385,15 @@ class DefaultExecutionContext(base.ExecutionContext): val = drunner.get_column_onupdate(c) if val is not None: - compiled_parameters[c.key] = val + compiled_parameters[self.compiled.binds[c.key].key] = val if self.isinsert: - self._last_inserted_ids = [compiled_parameters.get(c.key, None) for c in self.compiled.statement.table.primary_key] - self._last_inserted_params = compiled_parameters + self._last_inserted_ids = [ + k and compiled_parameters.get(k.key, None) or None for k in + [self.compiled.binds.get(c.key, None) for c in self.compiled.statement.table.primary_key] + ] + self._last_inserted_params = dict([(key, compiled_parameters[self.compiled.bind_names[b]]) for key, b in self.compiled.binds.iteritems()]) else: - self._last_updated_params = compiled_parameters + self._last_updated_params = dict([(key, compiled_parameters[self.compiled.bind_names[b]]) for key, b in self.compiled.binds.iteritems()]) self.postfetch_cols = self.compiled.postfetch diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6a048a780..4a45d6c15 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -649,6 +649,11 @@ class DefaultCompiler(engine.Compiled): self.stack.pop(-1) return text + + def create_insert_update_bind(self, col, value): + bindparam = sql.bindparam(col.key, value, type_=col.type) + self.binds[col.key] = bindparam + return self.bindparam_string(self._truncate_bindparam(bindparam)) def _get_colparams(self, stmt): """create a set of tuples representing column/string pairs for use @@ -656,18 +661,13 @@ class DefaultCompiler(engine.Compiled): """ - def create_bind_param(col, value): - bindparam = sql.bindparam(col.key, value, type_=col.type) - self.binds[col.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) - self.postfetch = [] self.prefetch = [] # no parameters in the statement, no parameters in the # compiled params - return binds for all columns if self.column_keys is None and stmt.parameters is None: - return [(c, create_bind_param(c, None)) for c in stmt.table.columns] + return [(c, self.create_insert_update_bind(c, None)) for c in stmt.table.columns] # if we have statement parameters - set defaults in the # compiled params @@ -686,7 +686,7 @@ class DefaultCompiler(engine.Compiled): if c.key in parameters: value = parameters[c.key] if sql._is_literal(value): - value = create_bind_param(c, value) + value = self.create_insert_update_bind(c, value) else: self.postfetch.append(c) value = self.process(value.self_group()) @@ -699,7 +699,7 @@ class DefaultCompiler(engine.Compiled): not self.dialect.supports_pk_autoincrement) or (c.default is not None and not isinstance(c.default, schema.Sequence))): - values.append((c, create_bind_param(c, None))) + values.append((c, self.create_insert_update_bind(c, None))) self.prefetch.append(c) elif isinstance(c.default, schema.ColumnDefault): if isinstance(c.default.arg, sql.ClauseElement): @@ -708,7 +708,7 @@ class DefaultCompiler(engine.Compiled): # dont add primary key column to postfetch self.postfetch.append(c) else: - values.append((c, create_bind_param(c, None))) + values.append((c, self.create_insert_update_bind(c, None))) self.prefetch.append(c) elif isinstance(c.default, schema.PassiveDefault): if not c.primary_key: @@ -725,7 +725,7 @@ class DefaultCompiler(engine.Compiled): values.append((c, self.process(c.onupdate.arg.self_group()))) self.postfetch.append(c) else: - values.append((c, create_bind_param(c, None))) + values.append((c, self.create_insert_update_bind(c, None))) self.prefetch.append(c) elif isinstance(c.onupdate, schema.PassiveDefault): self.postfetch.append(c) |