summaryrefslogtreecommitdiff
path: root/alembic/autogenerate/generate.py
blob: fcd3533d124729be711ad932a498744ebb65d23c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from .. import util
from .api import _produce_migration_diffs


class GeneratedRevision(object):
    """State for a single revision script that is about to be written.

    The revision id and generation options (``head``, ``splice``,
    ``branch_labels``, ``version_path``) are read from the owning
    ``RevisionContext.command_args``.  ``template_args`` holds extra keyword
    arguments that :meth:`to_script` forwards to
    ``script_directory.generate_revision``.
    """

    def __init__(self, revision_context):
        """Capture generation options from *revision_context*.

        :param revision_context: owning context; must expose a
            ``command_args`` mapping containing ``rev_id``, ``head``,
            ``splice``, ``branch_labels`` and ``version_path`` keys.
        """
        self.revision_context = revision_context
        # Extra keyword arguments forwarded verbatim to generate_revision().
        self.template_args = {}
        # Handed to _produce_migration_diffs together with template_args;
        # presumably import lines needed by the rendered script -- confirm.
        self.imports = set()

        args = revision_context.command_args
        # A falsy rev_id (None / "") falls back to util.rev_id().
        self.rev_id = args['rev_id'] or util.rev_id()

        self.head = args['head']
        self.splice = args['splice']
        self.branch_labels = args['branch_labels']
        self.version_path = args['version_path']

    def to_script(self):
        """Write this revision through ``script_directory.generate_revision``.

        Passes the stored revision id, the ``message`` command argument,
        ``refresh=True``, the stored generation options and every entry of
        ``template_args`` as keyword arguments.

        :returns: whatever ``generate_revision`` returns.
        """
        return self.revision_context.script_directory.generate_revision(
            self.rev_id,
            self.revision_context.command_args['message'],
            refresh=True,
            head=self.head,
            splice=self.splice,
            # The attribute set in __init__ is the plural ``branch_labels``;
            # reading ``branch_label`` raised AttributeError on every call.
            branch_labels=self.branch_labels,
            version_path=self.version_path,
            **self.template_args)


class RevisionContext(object):
    """Coordinate generation of new revision scripts.

    Keeps the script directory and parsed command arguments, and owns the
    list of ``GeneratedRevision`` objects to be rendered.  A single
    ``GeneratedRevision`` is created when the context is constructed.
    """

    def __init__(self, script_directory, command_args):
        self.script_directory = script_directory
        self.command_args = command_args
        # Kept as a list so the autogenerate and rendering steps can iterate
        # uniformly over every pending revision.
        pending = GeneratedRevision(self)
        self.generated_revisions = [pending]

    def run_autogenerate(self, rev, context):
        """Fill each pending revision with autogenerated migration diffs.

        :param rev: revision(s) the target database is at, resolved via
            ``script_directory.get_revisions``.
        :param context: passed through to ``_produce_migration_diffs``.
        :raises util.CommandError: if ``sql`` is set in ``command_args``, or
            if ``rev`` does not resolve to the same revisions as ``"heads"``.
        """
        if self.command_args['sql']:
            raise util.CommandError(
                "Using --sql with --autogenerate does not make any sense")

        scripts = self.script_directory
        current = set(scripts.get_revisions(rev))
        heads = set(scripts.get_revisions("heads"))
        if current != heads:
            raise util.CommandError("Target database is not up to date.")

        for pending in self.generated_revisions:
            # Return value is unused; presumably template_args / imports are
            # populated in place -- confirm against _produce_migration_diffs.
            _produce_migration_diffs(
                context, pending.template_args, pending.imports)

    def run_no_autogenerate(self, rev, context):
        """Hook for the non-autogenerate path; intentionally a no-op."""
        return None

    def generate_scripts(self):
        """Lazily yield ``to_script()`` for each pending revision, in order."""
        revisions = self.generated_revisions
        for pending in revisions:
            yield pending.to_script()