diff options
author | olly <olly@ollycope.com> | 2015-04-15 09:10:29 +0000 |
---|---|---|
committer | olly <olly@ollycope.com> | 2015-04-15 09:10:29 +0000 |
commit | 3167df66154a38c09df91b9b1aba0953be1d87fa (patch) | |
tree | fc464a9214690a80fd1cf89842b1c619f39c8d10 | |
parent | da18d6a57b153d5b06220ad7e065e012b21fd137 (diff) | |
download | yoyo-3167df66154a38c09df91b9b1aba0953be1d87fa.tar.gz |
Refactor connection code
-rwxr-xr-x | yoyo/connections.py | 58 |
1 files changed, 36 insertions, 22 deletions
diff --git a/yoyo/connections.py b/yoyo/connections.py index 8ed338e..608fc6b 100755 --- a/yoyo/connections.py +++ b/yoyo/connections.py @@ -1,5 +1,16 @@ +from functools import wraps + _schemes = {} +drivers = { + 'odbc': 'pyodbc', + 'postgresql': 'psycopg2', + 'postgres': 'psycopg2', + 'psql': 'psycopg2', + 'mysql': 'MySQLdb', + 'sqlite': 'sqlite3', +} + class BadConnectionURI(Exception): """ @@ -15,15 +26,20 @@ def connection_for(scheme): """ def decorate(func): - _schemes[scheme] = func + + @wraps(func) + def with_driver(*args, **kwargs): + driver = __import__(drivers[scheme], globals(), locals()) + return func(driver, *args, **kwargs) + _schemes[scheme] = with_driver + return func return decorate @connection_for('odbc') -def connect_odbc(username, password, host, port, database, db_params): - import pyodbc +def connect_odbc(driver, username, password, host, port, database, db_params): kwargs = db_params if username is not None: @@ -39,12 +55,11 @@ def connect_odbc(username, password, host, port, database, db_params): connection_string = '' for k, v in kwargs: connection_string += k + '=' + v + ';' - return pyodbc.connect(connection_string), pyodbc.paramstyle + return driver.connect(connection_string), driver.paramstyle @connection_for('mysql') -def connect_mysql(username, password, host, port, database, db_params): - import MySQLdb +def connect_mysql(driver, username, password, host, port, database, db_params): kwargs = {} if username is not None: @@ -57,22 +72,20 @@ def connect_mysql(username, password, host, port, database, db_params): kwargs['port'] = port kwargs['db'] = database - return MySQLdb.connect(**kwargs), MySQLdb.paramstyle + return driver.connect(**kwargs), driver.paramstyle @connection_for('sqlite') -def connect_sqlite(username, password, host, port, database, db_params): - import sqlite3 - - return sqlite3.connect(database), sqlite3.paramstyle +def connect_sqlite( + driver, username, password, host, port, database, db_params): + return driver.connect(database), driver.paramstyle @connection_for('postgres') @connection_for('postgresql') @connection_for('psql') -def connect_postgres(username, password, host, port, database, db_params): - import psycopg2 - +def connect_postgres( + driver, username, password, host, port, database, db_params): connargs = [] if username is not None: connargs.append('user=%s' % username) @@ -83,35 +96,36 @@ def connect_postgres(username, password, host, port, database, db_params): if host is not None: connargs.append('host=%s' % host) connargs.append('dbname=%s' % database) - return psycopg2.connect(' '.join(connargs)), psycopg2.paramstyle + return driver.connect(' '.join(connargs)), driver.paramstyle def connect(uri): """ Connect to the given DB uri in the format - ``driver://user:pass@host:port/database_name?param=value``, returning a DB-API connection + ``driver://user:pass@host:port/database_name?param=value``, + returning a DB-API connection object and the paramstyle used by the DB-API module. """ - scheme, username, password, host, port, database, db_params = parse_uri(uri) + scheme, username, password, host, port, database, params = parse_uri(uri) try: connection_func = _schemes[scheme.lower()] except KeyError: raise BadConnectionURI('Unrecognised database connection scheme %r' % scheme) - return connection_func(username, password, host, port, database, db_params) + return connection_func(username, password, host, port, database, params) def parse_uri(uri): """ Examples:: - >>> parse_uri('postgres://fred:bassett@dbserver:5432/fredsdatabase') - ('postgres', 'fred', 'bassett', 'dbserver', 5432, 'fredsdatabase', None) + >>> parse_uri('postgres://fred:bassett@server:5432/fredsdatabase') + ('postgres', 'fred', 'bassett', 'server', 5432, 'fredsdatabase', None) >>> parse_uri('mysql:///jimsdatabase') ('mysql', None, None, None, None, 'jimsdatabase', None, None) - >>> parse_uri('odbc://user:password@dbserver/database?DSN=dsn') - ('odbc', 'user', 'password', 'dbserver', None, 'database', {'DSN':'dsn'}) + >>> parse_uri('odbc://user:password@server/database?DSN=dsn') + ('odbc', 'user', 'password', 'server', None, 'database', {'DSN':'dsn'}) """ scheme = username = password = host = port = database = None |