diff options
Diffstat (limited to 'lib/sqlalchemy/sql/dml.py')
-rw-r--r-- | lib/sqlalchemy/sql/dml.py | 101 |
1 files changed, 47 insertions, 54 deletions
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 3dc4e917c..467a764d6 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -39,54 +39,8 @@ class DMLState(CompileState): isdelete = False isinsert = False - @classmethod - def _create_insert(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isinsert=True, **kw) - - @classmethod - def _create_update(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isupdate=True, **kw) - - @classmethod - def _create_delete(cls, statement, compiler, **kw): - return DMLState(statement, compiler, isdelete=True, **kw) - - def __init__( - self, - statement, - compiler, - isinsert=False, - isupdate=False, - isdelete=False, - **kw - ): - self.statement = statement - - if isupdate: - self.isupdate = True - self._preserve_parameter_order = ( - statement._preserve_parameter_order - ) - if statement._ordered_values is not None: - self._process_ordered_values(statement) - elif statement._values is not None: - self._process_values(statement) - elif statement._multi_values: - self._process_multi_values(statement) - self._extra_froms = self._make_extra_froms(statement) - elif isinsert: - self.isinsert = True - if statement._select_names: - self._process_select_values(statement) - if statement._values is not None: - self._process_values(statement) - if statement._multi_values: - self._process_multi_values(statement) - elif isdelete: - self.isdelete = True - self._extra_froms = self._make_extra_froms(statement) - else: - assert False, "one of isinsert, isupdate, or isdelete must be set" + def __init__(self, statement, compiler, **kw): + raise NotImplementedError() def _make_extra_froms(self, statement): froms = [] @@ -174,6 +128,51 @@ class DMLState(CompileState): ) +@CompileState.plugin_for("default", "insert") +class InsertDMLState(DMLState): + isinsert = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isinsert = True + if statement._select_names: + self._process_select_values(statement) + if statement._values is not None: + self._process_values(statement) + if statement._multi_values: + self._process_multi_values(statement) + + +@CompileState.plugin_for("default", "update") +class UpdateDMLState(DMLState): + isupdate = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isupdate = True + self._preserve_parameter_order = statement._preserve_parameter_order + if statement._ordered_values is not None: + self._process_ordered_values(statement) + elif statement._values is not None: + self._process_values(statement) + elif statement._multi_values: + self._process_multi_values(statement) + self._extra_froms = self._make_extra_froms(statement) + + +@CompileState.plugin_for("default", "delete") +class DeleteDMLState(DMLState): + isdelete = True + + def __init__(self, statement, compiler, **kw): + self.statement = statement + + self.isdelete = True + self._extra_froms = self._make_extra_froms(statement) + + class UpdateBase( roles.DMLRole, HasCTE, @@ -754,8 +753,6 @@ class Insert(ValuesBase): _supports_multi_parameters = True - _compile_state_factory = DMLState._create_insert - select = None include_insert_from_select_defaults = False @@ -964,8 +961,6 @@ class Update(DMLWhereBase, ValuesBase): __visit_name__ = "update" - _compile_state_factory = DMLState._create_update - _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), @@ -1210,8 +1205,6 @@ class Delete(DMLWhereBase, UpdateBase): __visit_name__ = "delete" - _compile_state_factory = DMLState._create_delete - _traverse_internals = ( [ ("table", InternalTraversal.dp_clauseelement), |