summaryrefslogtreecommitdiff
path: root/yoyo/connections.py
blob: 777692d8188b3a1498698b2d876f60032c01f000 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# Copyright 2015 Oliver Cope
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

from collections import namedtuple

from .migrations import default_migration_table
from .backends import (PostgresqlBackend,
                       SQLiteBackend,
                       ODBCBackend,
                       MySQLBackend,
                       MySQLdbBackend)
from .compat import urlsplit, urlunsplit, parse_qsl, urlencode, quote, unquote

# Registry mapping a connection URI scheme to the backend class that handles
# it. ``get_backend`` lower-cases the parsed scheme before looking it up here,
# so keys must be lower case. Several schemes are aliases: ``postgresql``,
# ``postgres`` and ``psql`` all resolve to the same PostgreSQL backend.
BACKENDS = {
    'odbc': ODBCBackend,
    'postgresql': PostgresqlBackend,
    'postgres': PostgresqlBackend,
    'psql': PostgresqlBackend,
    'mysql': MySQLBackend,
    'mysql+mysqldb': MySQLdbBackend,
    'sqlite': SQLiteBackend,
}


_DatabaseURI = namedtuple('_DatabaseURI',
                          'scheme username password hostname port database '
                          'args')


class DatabaseURI(_DatabaseURI):
    """
    The components of a database connection URI.

    A namedtuple with the fields ``scheme``, ``username``, ``password``,
    ``hostname``, ``port``, ``database`` and ``args`` (a mapping of query
    string parameters). Any field other than ``scheme`` may be ``None``
    when the corresponding part is absent.

    ``str(uri)`` (or the ``uri`` property) reassembles the components into a
    URI string.
    """

    @property
    def netloc(self):
        """
        Return the network location part of the URI, in the form
        ``username:password@hostname:port``.

        The credentials are omitted when there is no username and the port
        is omitted when it is not set. Username and password are
        percent-encoded.
        """
        hostname = self.hostname or ''
        # An IPv6 literal must be bracketed, otherwise its colons cannot be
        # told apart from the host/port separator.
        if ':' in hostname and not hostname.startswith('['):
            hostname = '[{}]'.format(hostname)

        if self.port:
            hostpart = '{}:{}'.format(hostname, self.port)
        else:
            hostpart = hostname

        if self.username:
            # safe='' is required: by default quote() leaves '/' unescaped,
            # and a raw '/' in the credentials would terminate the netloc
            # early when the URI is parsed again.
            return '{}:{}@{}'.format(quote(self.username, safe=''),
                                     quote(self.password or '', safe=''),
                                     hostpart)
        return hostpart

    def __str__(self):
        """
        Return the URI as a string.

        A missing database or missing args are rendered as empty components
        rather than raising an error.
        """
        return urlunsplit((self.scheme,
                           self.netloc,
                           # urlunsplit cannot concatenate None with the
                           # netloc, so substitute empty values.
                           self.database or '',
                           urlencode(self.args or {}),
                           ''))

    @property
    def uri(self):
        """
        The URI as a string; equivalent to ``str(self)``.
        """
        return str(self)


class BadConnectionURI(Exception):
    """
    Raised when a database connection URI cannot be used, for example
    because it has no scheme or names an unrecognised scheme.
    """


def get_backend(uri, migration_table=default_migration_table):
    """
    Return a backend object for the database named by ``uri``.

    :param uri: connection URI in the format
                ``driver://user:pass@host:port/database_name?param=value``
    :param migration_table: passed through to the backend class as its
                            second argument
    :return: an instance of the backend class registered in ``BACKENDS``
             for the URI's scheme
    :raises BadConnectionURI: if the URI has no scheme, or the scheme is
                              not a key of ``BACKENDS``
    """
    connection = parse_uri(uri)
    scheme = connection.scheme
    try:
        # Registry keys are lower case; match the scheme case-insensitively.
        backend_factory = BACKENDS[scheme.lower()]
    except KeyError:
        message = 'Unrecognised database connection scheme %r' % (scheme,)
        raise BadConnectionURI(message)
    return backend_factory(connection, migration_table)


def parse_uri(s):
    """
    Split a connection URI string into a :class:`DatabaseURI`.

    The username and password are percent-decoded. The database is the URI
    path without its first character (the leading ``/``), or ``None`` if the
    path is empty. ``args`` is a dict built from the query string.

    Examples::

        >>> tuple(parse_uri('postgres://fred:bassett@server:5432/fredsdatabase'))
        ('postgres', 'fred', 'bassett', 'server', 5432, 'fredsdatabase', {})
        >>> tuple(parse_uri('mysql:///jimsdatabase'))
        ('mysql', None, None, None, None, 'jimsdatabase', {})
        >>> tuple(parse_uri('odbc://user:password@server/database?DSN=dsn'))
        ('odbc', 'user', 'password', 'server', None, 'database', {'DSN': 'dsn'})

    :raises BadConnectionURI: if the URI contains no scheme
    """
    parts = urlsplit(s)
    if not parts.scheme:
        raise BadConnectionURI("No scheme specified in connection URI %r" % s)

    def decode(value):
        # Leave a missing component as None rather than decoding it.
        return None if value is None else unquote(value)

    path = parts.path
    return DatabaseURI(
        scheme=parts.scheme,
        username=decode(parts.username),
        password=decode(parts.password),
        hostname=parts.hostname,
        port=parts.port,
        # Drop the leading '/' that separates the netloc from the path.
        database=path[1:] if path else None,
        args=dict(parse_qsl(parts.query)),
    )