summaryrefslogtreecommitdiff
path: root/alembic/autogenerate/generate.py
blob: c68615660be1b2822d5f813c5fee31321850b38d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from .. import util
from . import api
from . import compose
from . import compare
from . import render
from ..operations import ops


class RevisionContext(object):
    """Collect and write out the migration scripts for a revision command.

    ``generated_revisions`` starts out holding a single default
    :class:`.ops.MigrationScript` built from ``command_args``.  Either
    :meth:`run_autogenerate` or :meth:`run_no_autogenerate` then finalizes
    that list (optionally letting a ``process_revision_directives`` hook
    rewrite it), and :meth:`generate_scripts` writes each directive out via
    ``script_directory``.
    """

    def __init__(self, config, script_directory, command_args):
        """
        :param config: configuration object; also exposed to every script
         template under the name ``config``.
        :param script_directory: object providing ``get_revisions()`` and
         ``generate_revision()``.
        :param command_args: mapping of command options; the keys read by
         this class are ``sql``, ``rev_id``, ``message``, ``head``,
         ``splice``, ``branch_label`` and ``version_path``.
        """
        self.config = config
        self.script_directory = script_directory
        self.command_args = command_args
        # Let templates use config for e.g. multiple databases.
        self.template_args = {
            'config': config,
        }
        self.generated_revisions = [
            self._default_revision()
        ]

    def _to_script(self, migration_script):
        """Write one migration directive out as a revision file.

        When the directive carries an autogenerate context, its operations
        are rendered into the template variables first.

        :return: whatever ``script_directory.generate_revision()`` returns.
        """
        # Work on a per-script shallow copy so rendering one directive
        # cannot leak template variables into the shared defaults.
        template_args = dict(self.template_args)

        # Directives that never went through run_autogenerate() /
        # run_no_autogenerate() may not have the attribute at all; treat
        # them as "not autogenerated" rather than raising AttributeError.
        autogen_context = getattr(migration_script, '_autogen_context', None)
        if autogen_context is not None:
            render._render_migration_script(
                autogen_context, migration_script, template_args
            )

        return self.script_directory.generate_revision(
            migration_script.rev_id,
            migration_script.message,
            refresh=True,
            head=migration_script.head,
            splice=migration_script.splice,
            branch_labels=migration_script.branch_label,
            version_path=migration_script.version_path,
            **template_args)

    def _run_process_revision_directives(self, rev, context):
        """Invoke the user's ``process_revision_directives`` hook, if any.

        The hook receives ``generated_revisions`` itself, so it may modify
        the list in place (add, remove or replace directives).
        """
        hook = context.opts.get('process_revision_directives', None)
        if hook:
            hook(context, rev, self.generated_revisions)

    def run_autogenerate(self, rev, context):
        """Fill the default directive from a model/database comparison.

        :param rev: revision identifier(s) resolved through
         ``script_directory.get_revisions()``; presumably the revision the
         target database is currently at -- confirm against the caller.
        :param context: migration context; its ``opts`` may supply a
         ``process_revision_directives`` hook.
        :raises util.CommandError: if ``--sql`` mode was requested, or if
         ``rev`` does not resolve to the same revisions as ``"heads"``.
        """
        if self.command_args['sql']:
            raise util.CommandError(
                "Using --sql with --autogenerate does not make any sense")
        if set(self.script_directory.get_revisions(rev)) != \
                set(self.script_directory.get_revisions("heads")):
            raise util.CommandError("Target database is not up to date.")

        autogen_context = api._autogen_context(context)

        diffs = []
        compare._produce_net_changes(autogen_context, diffs)

        # The detected changes always go into the first (default) directive.
        compose._to_migration_script(
            autogen_context, self.generated_revisions[0], diffs)

        self._run_process_revision_directives(rev, context)

        # Tag every directive -- including any the hook added -- so that
        # _to_script() renders its operations.
        for script in self.generated_revisions:
            script._autogen_context = autogen_context

    def run_no_autogenerate(self, rev, context):
        """Finalize directives for a plain (non-autogenerate) revision.

        Runs the ``process_revision_directives`` hook, if configured, and
        then marks every directive as having no autogenerate context.
        """
        self._run_process_revision_directives(rev, context)

        for script in self.generated_revisions:
            script._autogen_context = None

    def _default_revision(self):
        """Build the initial directive from ``command_args``.

        Upgrade/downgrade operation lists start empty.  A new id from
        ``util.rev_id()`` is used whenever ``rev_id`` is falsy.
        """
        op = ops.MigrationScript(
            rev_id=self.command_args['rev_id'] or util.rev_id(),
            message=self.command_args['message'],
            imports=set(),
            upgrade_ops=ops.UpgradeOps([]),
            downgrade_ops=ops.DowngradeOps([]),
            head=self.command_args['head'],
            splice=self.command_args['splice'],
            branch_label=self.command_args['branch_label'],
            version_path=self.command_args['version_path']
        )
        op._autogen_context = None
        return op

    def generate_scripts(self):
        """Lazily write out each directive in ``generated_revisions``.

        This is a generator: nothing is written until it is iterated, and
        each item is the return value of :meth:`_to_script`.
        """
        for generated_revision in self.generated_revisions:
            yield self._to_script(generated_revision)