summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorolly <olly@ollycope.com>2015-04-15 09:10:29 +0000
committerolly <olly@ollycope.com>2015-04-15 09:10:29 +0000
commit3167df66154a38c09df91b9b1aba0953be1d87fa (patch)
treefc464a9214690a80fd1cf89842b1c619f39c8d10
parentda18d6a57b153d5b06220ad7e065e012b21fd137 (diff)
downloadyoyo-3167df66154a38c09df91b9b1aba0953be1d87fa.tar.gz
Refactor connection code
-rwxr-xr-xyoyo/connections.py58
1 files changed, 36 insertions, 22 deletions
diff --git a/yoyo/connections.py b/yoyo/connections.py
index 8ed338e..608fc6b 100755
--- a/yoyo/connections.py
+++ b/yoyo/connections.py
@@ -1,5 +1,16 @@
+from functools import wraps
+
_schemes = {}
+drivers = {
+ 'odbc': 'pyodbc',
+ 'postgresql': 'psycopg2',
+ 'postgres': 'psycopg2',
+ 'psql': 'psycopg2',
+ 'mysql': 'MySQLdb',
+ 'sqlite': 'sqlite3',
+}
+
class BadConnectionURI(Exception):
"""
@@ -15,15 +26,20 @@ def connection_for(scheme):
"""
def decorate(func):
- _schemes[scheme] = func
+
+ @wraps(func)
+ def with_driver(*args, **kwargs):
+ driver = __import__(drivers[scheme], globals(), locals())
+ return func(driver, *args, **kwargs)
+ _schemes[scheme] = with_driver
+
return func
return decorate
@connection_for('odbc')
-def connect_odbc(username, password, host, port, database, db_params):
- import pyodbc
+def connect_odbc(driver, username, password, host, port, database, db_params):
kwargs = db_params
if username is not None:
@@ -39,12 +55,11 @@ def connect_odbc(username, password, host, port, database, db_params):
connection_string = ''
for k, v in kwargs:
connection_string += k + '=' + v + ';'
- return pyodbc.connect(connection_string), pyodbc.paramstyle
+ return driver.connect(connection_string), driver.paramstyle
@connection_for('mysql')
-def connect_mysql(username, password, host, port, database, db_params):
- import MySQLdb
+def connect_mysql(driver, username, password, host, port, database, db_params):
kwargs = {}
if username is not None:
@@ -57,22 +72,20 @@ def connect_mysql(username, password, host, port, database, db_params):
kwargs['port'] = port
kwargs['db'] = database
- return MySQLdb.connect(**kwargs), MySQLdb.paramstyle
+ return driver.connect(**kwargs), driver.paramstyle
@connection_for('sqlite')
-def connect_sqlite(username, password, host, port, database, db_params):
- import sqlite3
-
- return sqlite3.connect(database), sqlite3.paramstyle
+def connect_sqlite(
+ driver, username, password, host, port, database, db_params):
+ return driver.connect(database), driver.paramstyle
@connection_for('postgres')
@connection_for('postgresql')
@connection_for('psql')
-def connect_postgres(username, password, host, port, database, db_params):
- import psycopg2
-
+def connect_postgres(
+ driver, username, password, host, port, database, db_params):
connargs = []
if username is not None:
connargs.append('user=%s' % username)
@@ -83,35 +96,36 @@ def connect_postgres(username, password, host, port, database, db_params):
if host is not None:
connargs.append('host=%s' % host)
connargs.append('dbname=%s' % database)
- return psycopg2.connect(' '.join(connargs)), psycopg2.paramstyle
+ return driver.connect(' '.join(connargs)), driver.paramstyle
def connect(uri):
"""
Connect to the given DB uri in the format
- ``driver://user:pass@host:port/database_name?param=value``, returning a DB-API connection
+ ``driver://user:pass@host:port/database_name?param=value``,
+ returning a DB-API connection
object and the paramstyle used by the DB-API module.
"""
- scheme, username, password, host, port, database, db_params = parse_uri(uri)
+ scheme, username, password, host, port, database, params = parse_uri(uri)
try:
connection_func = _schemes[scheme.lower()]
except KeyError:
raise BadConnectionURI('Unrecognised database connection scheme %r' %
scheme)
- return connection_func(username, password, host, port, database, db_params)
+ return connection_func(username, password, host, port, database, params)
def parse_uri(uri):
"""
Examples::
- >>> parse_uri('postgres://fred:bassett@dbserver:5432/fredsdatabase')
- ('postgres', 'fred', 'bassett', 'dbserver', 5432, 'fredsdatabase', None)
+ >>> parse_uri('postgres://fred:bassett@server:5432/fredsdatabase')
+ ('postgres', 'fred', 'bassett', 'server', 5432, 'fredsdatabase', None)
>>> parse_uri('mysql:///jimsdatabase')
('mysql', None, None, None, None, 'jimsdatabase', None, None)
- >>> parse_uri('odbc://user:password@dbserver/database?DSN=dsn')
- ('odbc', 'user', 'password', 'dbserver', None, 'database', {'DSN':'dsn'})
+ >>> parse_uri('odbc://user:password@server/database?DSN=dsn')
+ ('odbc', 'user', 'password', 'server', None, 'database', {'DSN':'dsn'})
"""
scheme = username = password = host = port = database = None