diff options
author | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-09-04 00:08:57 +0000 |
---|---|---|
committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2007-09-04 00:08:57 +0000 |
commit | 3126d464e7124cde24b18ba7efc318913d2ac40d (patch) | |
tree | 33dd8fdda1ea8a3aae5c75cfbbdded60a1b01997 | |
parent | c9924a4a145f06eac427fe60c54d4c58b894167f (diff) | |
download | sqlalchemy-3126d464e7124cde24b18ba7efc318913d2ac40d.tar.gz |
- removed "parameters" argument from clauseelement.compile(), replaced with
"column_keys". the parameters sent to execute() only interact with the
insert/update statement compilation process in terms of the column names
present but not the values for those columns.
produces more consistent execute/executemany behavior, simplifies things a
bit internally.
-rw-r--r-- | CHANGES | 7 | ||||
-rw-r--r-- | lib/sqlalchemy/databases/mssql.py | 4 | ||||
-rw-r--r-- | lib/sqlalchemy/engine/base.py | 66 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 26 | ||||
-rw-r--r-- | lib/sqlalchemy/sql/expression.py | 28 | ||||
-rw-r--r-- | test/sql/query.py | 20 | ||||
-rw-r--r-- | test/sql/select.py | 18 | ||||
-rw-r--r-- | test/testlib/testing.py | 9 |
8 files changed, 97 insertions, 81 deletions
@@ -13,6 +13,13 @@ CHANGES so mappers within inheritance relationships need to be constructed in inheritance order (which should be the normal case anyway). +- removed "parameters" argument from clauseelement.compile(), replaced with + "column_keys". the parameters sent to execute() only interact with the + insert/update statement compilation process in terms of the column names + present but not the values for those columns. + produces more consistent execute/executemany behavior, simplifies things a + bit internally. + 0.4.0beta5 ---------- diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 3e8949e14..7b0defc2a 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -847,8 +847,8 @@ dialect_mapping = { class MSSQLCompiler(compiler.DefaultCompiler): - def __init__(self, dialect, statement, parameters, **kwargs): - super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs) + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) self.tablealiases = {} def get_select_precolumns(self, select): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index bd6f5b97c..32bf7b780 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -421,7 +421,7 @@ class Compiled(object): defaults. """ - def __init__(self, dialect, statement, parameters, bind=None): + def __init__(self, dialect, statement, column_keys=None, bind=None): """Construct a new ``Compiled`` object. dialect @@ -430,26 +430,16 @@ class Compiled(object): statement ``ClauseElement`` to be compiled. - parameters - Optional dictionary indicating a set of bind parameters - specified with this ``Compiled`` object. These parameters - are the *default* values corresponding to the - ``ClauseElement``'s ``_BindParamClauses`` when the - ``Compiled`` is executed. In the case of an ``INSERT`` or - ``UPDATE`` statement, these parameters will also result in - the creation of new ``_BindParamClause`` objects for each - key and will also affect the generated column list in an - ``INSERT`` statement and the ``SET`` clauses of an - ``UPDATE`` statement. The keys of the parameter dictionary - can either be the string names of columns or - ``_ColumnClause`` objects. + column_keys + a list of column names to be compiled into an INSERT or UPDATE + statement. bind Optional Engine or Connection to compile this statement against. """ self.dialect = dialect self.statement = statement - self.parameters = parameters + self.column_keys = column_keys self.bind = bind self.can_execute = statement.supports_execution() @@ -778,8 +768,8 @@ class Connection(Connectable): return self.execute(object, *multiparams, **params).scalar() - def statement_compiler(self, statement, parameters, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) def execute(self, object, *multiparams, **params): """Executes and returns a ResultProxy.""" @@ -808,25 +798,43 @@ class Connection(Connectable): parameters = list(multiparams) return parameters + def __distill_params_and_keys(self, multiparams, params): + if multiparams is None or len(multiparams) == 0: + if params: + parameters = params + keys = params.keys() + else: + parameters = None + keys = [] + executemany = False + elif len(multiparams) == 1 and isinstance(multiparams[0], (list, tuple, dict)): + parameters = multiparams[0] + if isinstance(parameters, dict): + keys = parameters.keys() + else: + keys = parameters[0].keys() + executemany = False + else: + parameters = list(multiparams) + keys = parameters[0].keys() + executemany = True + return (parameters, keys, executemany) + def _execute_function(self, func, multiparams, params): return self._execute_clauseelement(func.select(), multiparams, params) def _execute_clauseelement(self, elem, multiparams=None, params=None): - if multiparams: - param = multiparams[0] - executemany = len(multiparams) > 1 - else: - param = params - executemany = False - return self._execute_compiled(elem.compile(dialect=self.dialect, parameters=param, inline=executemany), multiparams, params) + (params, keys, executemany) = self.__distill_params_and_keys(multiparams, params) + return self._execute_compiled(elem.compile(dialect=self.dialect, column_keys=keys, inline=executemany), distilled_params=params) - def _execute_compiled(self, compiled, multiparams=None, params=None): + def _execute_compiled(self, compiled, multiparams=None, params=None, distilled_params=None): """Execute a sql.Compiled object.""" if not compiled.can_execute: raise exceptions.ArgumentError("Not an executeable clause: %s" % (str(compiled))) - params = self.__distill_params(multiparams, params) - context = self.__create_execution_context(compiled=compiled, parameters=params) + if distilled_params is None: + distilled_params = self.__distill_params(multiparams, params) + context = self.__create_execution_context(compiled=compiled, parameters=distilled_params) context.pre_execution() self.__execute_raw(context) @@ -1119,8 +1127,8 @@ class Engine(Connectable): connection = self.contextual_connect(close_with_result=True) return connection._execute_compiled(compiled, multiparams, params) - def statement_compiler(self, statement, parameters, **kwargs): - return self.dialect.statement_compiler(self.dialect, statement, parameters, bind=self, **kwargs) + def statement_compiler(self, statement, **kwargs): + return self.dialect.statement_compiler(self.dialect, statement, bind=self, **kwargs) def connect(self, **kwargs): """Return a newly allocated Connection object.""" diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 1cfebdc27..eb416803a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -90,7 +90,7 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): operators = OPERATORS - def __init__(self, dialect, statement, parameters=None, inline=False, **kwargs): + def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs): """Construct a new ``DefaultCompiler`` object. dialect @@ -99,16 +99,12 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): statement ClauseElement to be compiled - parameters - optional dictionary indicating a set of bind parameters - specified with this Compiled object. These parameters are - the *default* key/value pairs when the Compiled is executed, - and also may affect the actual compilation, as in the case - of an INSERT where the actual columns inserted will - correspond to the keys present in the parameters. + column_keys + a list of column names to be compiled into an INSERT or UPDATE + statement. """ - super(DefaultCompiler, self).__init__(dialect, statement, parameters, **kwargs) + super(DefaultCompiler, self).__init__(dialect, statement, column_keys, **kwargs) # if we are insert/update. set to true when we visit an INSERT or UPDATE self.isinsert = self.isupdate = False @@ -217,12 +213,10 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): to produce a ClauseParameters structure, representing the bind arguments for a single statement execution, or one element of an executemany execution. """ - + d = sql_util.ClauseParameters(self.dialect, self.positiontup) - pd = self.parameters or {} - if params is not None: - pd.update(params) + pd = params or {} bind_names = self.bind_names for key, bind in self.binds.iteritems(): @@ -658,15 +652,15 @@ class DefaultCompiler(engine.Compiled, visitors.ClauseVisitor): # no parameters in the statement, no parameters in the # compiled params - return binds for all columns - if self.parameters is None and stmt.parameters is None: + if self.column_keys is None and stmt.parameters is None: return [(c, create_bind_param(c, None)) for c in stmt.table.columns] # if we have statement parameters - set defaults in the # compiled params - if self.parameters is None: + if self.column_keys is None: parameters = {} else: - parameters = dict([(getattr(k, 'key', k), v) for k, v in self.parameters.iteritems()]) + parameters = dict([(getattr(key, 'key', key), None) for key in self.column_keys]) if stmt.parameters is not None: for k, v in stmt.parameters.iteritems(): diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index ac56289e8..f88c418eb 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -963,19 +963,24 @@ class ClauseElement(object): def execute(self, *multiparams, **params): """Compile and execute this ``ClauseElement``.""" - - if multiparams: - compile_params = multiparams[0] + + if len(multiparams) == 0: + keys = params.keys() + elif isinstance(multiparams[0], dict): + keys = multiparams[0].keys() + elif isinstance(multiparams[0], (list, tuple)): + keys = multiparams[0][0].keys() else: - compile_params = params - return self.compile(bind=self.bind, parameters=compile_params, inline=(len(multiparams) > 1)).execute(*multiparams, **params) + keys = None + + return self.compile(bind=self.bind, column_keys=keys, inline=(len(multiparams) > 1)).execute(*multiparams, **params) def scalar(self, *multiparams, **params): """Compile and execute this ``ClauseElement``, returning the result's scalar representation.""" return self.execute(*multiparams, **params).scalar() - def compile(self, bind=None, parameters=None, compiler=None, dialect=None, inline=False): + def compile(self, bind=None, column_keys=None, compiler=None, dialect=None, inline=False): """Compile this SQL expression. Uses the given ``Compiler``, or the given ``AbstractDialect`` @@ -999,21 +1004,18 @@ class ClauseElement(object): ``SET`` and ``VALUES`` clause of those statements. """ - if isinstance(parameters, (list, tuple)): - parameters = parameters[0] - if compiler is None: if dialect is not None: - compiler = dialect.statement_compiler(dialect, self, parameters, inline=inline) + compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) elif bind is not None: - compiler = bind.statement_compiler(self, parameters, inline=inline) + compiler = bind.statement_compiler(self, column_keys=column_keys, inline=inline) elif self.bind is not None: - compiler = self.bind.statement_compiler(self, parameters, inline=inline) + compiler = self.bind.statement_compiler(self, column_keys=column_keys, inline=inline) if compiler is None: from sqlalchemy.engine.default import DefaultDialect dialect = DefaultDialect() - compiler = dialect.statement_compiler(dialect, self, parameters=parameters, inline=inline) + compiler = dialect.statement_compiler(dialect, self, column_keys=column_keys, inline=inline) compiler.compile() return compiler diff --git a/test/sql/query.py b/test/sql/query.py index 4e68fb980..a519dd974 100644 --- a/test/sql/query.py +++ b/test/sql/query.py @@ -32,6 +32,14 @@ class QueryTest(PersistTest): users.insert().execute(user_id = 7, user_name = 'jack') assert users.count().scalar() == 1 + def test_insert_heterogeneous_params(self): + users.insert().execute( + {'user_id':7, 'user_name':'jack'}, + {'user_id':8, 'user_name':'ed'}, + {'user_id':9} + ) + assert users.select().execute().fetchall() == [(7, 'jack'), (8, 'ed'), (9, None)] + def testupdate(self): users.insert().execute(user_id = 7, user_name = 'jack') @@ -353,9 +361,9 @@ class QueryTest(PersistTest): ) meta.create_all() try: - t.insert().execute(value=func.length("one")) + t.insert(values=dict(value=func.length("one"))).execute() assert t.select().execute().fetchone()['value'] == 3 - t.update().execute(value=func.length("asfda")) + t.update(values=dict(value=func.length("asfda"))).execute() assert t.select().execute().fetchone()['value'] == 5 r = t.insert(values=dict(value=func.length("sfsaafsda"))).execute() @@ -363,14 +371,14 @@ class QueryTest(PersistTest): assert t.select(t.c.id==id).execute().fetchone()['value'] == 9 t.update(values={t.c.value:func.length("asdf")}).execute() assert t.select().execute().fetchone()['value'] == 4 - + print "--------------------------" t2.insert().execute() - t2.insert().execute(value=func.length("one")) - t2.insert().execute(value=func.length("asfda") + -19, stuff="hi") + t2.insert(values=dict(value=func.length("one"))).execute() + t2.insert(values=dict(value=func.length("asfda") + -19)).execute(stuff="hi") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(7,None), (3,None), (-14,"hi")] - t2.update().execute(value=func.length("asdsafasd"), stuff="some stuff") + t2.update(values=dict(value=func.length("asdsafasd"))).execute(stuff="some stuff") assert select([t2.c.value, t2.c.stuff]).execute().fetchall() == [(9,"some stuff"), (9,"some stuff"), (9,"some stuff")] t2.delete().execute() diff --git a/test/sql/select.py b/test/sql/select.py index 5eaea7480..edca33bc0 100644 --- a/test/sql/select.py +++ b/test/sql/select.py @@ -567,10 +567,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = def testtextbinds(self): self.assert_compile( - text("select * from foo where lala=:bar and hoho=:whee"), + text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar', 4), bindparam('whee', 7)]), "select * from foo where lala=:bar and hoho=:whee", checkparams={'bar':4, 'whee': 7}, - params={'bar':4, 'whee': 7, 'hoho':10}, ) self.assert_compile( @@ -582,10 +581,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = dialect = postgres.dialect() self.assert_compile( - text("select * from foo where lala=:bar and hoho=:whee"), + text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), "select * from foo where lala=%(bar)s and hoho=%(whee)s", checkparams={'bar':4, 'whee': 7}, - params={'bar':4, 'whee': 7, 'hoho':10}, dialect=dialect ) self.assert_compile( @@ -598,10 +596,9 @@ WHERE mytable.myid = myothertable.otherid) AS t2view WHERE t2view.mytable_myid = dialect = sqlite.dialect() self.assert_compile( - text("select * from foo where lala=:bar and hoho=:whee"), + text("select * from foo where lala=:bar and hoho=:whee", bindparams=[bindparam('bar',4), bindparam('whee',7)]), "select * from foo where lala=? and hoho=?", checkparams=[4, 7], - params={'bar':4, 'whee': 7, 'hoho':10}, dialect=dialect ) @@ -936,11 +933,6 @@ EXISTS (select yay from foo where boo = lar)", except exceptions.CompileError, err: assert str(err) == "Bind parameter 'mytable_myid_1' conflicts with unique bind parameter of the same name" - # check that the bind params sent along with a compile() call - # get preserved when the params are retreived later - s = select([table1], table1.c.myid == bindparam('test')) - c = s.compile(parameters = {'test' : 7}) - self.assert_(c.get_params().get_original_dict() == {'test' : 7}) def testbindascol(self): t = table('foo', column('id')) @@ -1134,7 +1126,7 @@ class CRUDTest(SQLCompileTest): self.assert_compile(table.insert(inline=True), "INSERT INTO sometable (foo) VALUES (foobar())", params={}) def testinsertexpression(self): - self.assert_compile(insert(table1), "INSERT INTO mytable (myid) VALUES (lala())", params=dict(myid=func.lala())) + self.assert_compile(insert(table1, values=dict(myid=func.lala())), "INSERT INTO mytable (myid) VALUES (lala())") def testupdate(self): self.assert_compile(update(table1, table1.c.myid == 7), "UPDATE mytable SET name=:name WHERE mytable.myid = :mytable_myid", params = {table1.c.name:'fred'}) @@ -1144,7 +1136,7 @@ class CRUDTest(SQLCompileTest): self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}), "UPDATE mytable SET name=mytable.myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.myid : 9}), "UPDATE mytable SET myid=:myid, description=:description WHERE mytable.myid = :mytable_myid", params = {'mytable_myid': 12, 'myid': 9, 'description': 'test'}) s = table1.update(table1.c.myid == 12, values = {table1.c.name : 'lala'}) - c = s.compile(parameters = {'mytable_id':9,'name':'h0h0'}) + c = s.compile(column_keys=['mytable_id', 'name']) self.assert_compile(update(table1, table1.c.myid == 12, values = {table1.c.name : table1.c.myid}).values({table1.c.name:table1.c.name + 'foo'}), "UPDATE mytable SET name=(mytable.name || :mytable_name), description=:description WHERE mytable.myid = :mytable_myid", params = {'description':'test'}) self.assert_(str(s) == str(c)) diff --git a/test/testlib/testing.py b/test/testlib/testing.py index 0038fddfe..26873b25f 100644 --- a/test/testlib/testing.py +++ b/test/testlib/testing.py @@ -211,8 +211,13 @@ class SQLCompileTest(PersistTest): def assert_compile(self, clause, result, params=None, checkparams=None, dialect=None): if dialect is None: dialect = getattr(self, '__dialect__', None) - - c = clause.compile(parameters=params, dialect=dialect) + + if params is None: + keys = None + else: + keys = params.keys() + + c = clause.compile(column_keys=keys, dialect=dialect) print "\nSQL String:\n" + str(c) + repr(c.get_params()) |