diff options
author | michele.simionato <devnull@localhost> | 2008-12-09 07:15:54 +0000 |
---|---|---|
committer | michele.simionato <devnull@localhost> | 2008-12-09 07:15:54 +0000 |
commit | de607ad6f4c7c85ee1299c9058b11379a3354277 (patch) | |
tree | e73a51a6c06f83bb7f6efc2827ff13db689bee70 | |
parent | 208b408ed6d1a69544324472d59e1b0e0b5044e2 (diff) | |
download | micheles-de607ad6f4c7c85ee1299c9058b11379a3354277.tar.gz |
Many improvements
-rwxr-xr-x | decorator/decorator.py | 35 | ||||
-rw-r--r-- | decorator/documentation.py | 126 | ||||
-rw-r--r-- | sqlplain/Makefile | 2 | ||||
-rw-r--r-- | sqlplain/automatize.py | 13 | ||||
-rw-r--r-- | sqlplain/connection.py | 16 | ||||
-rw-r--r-- | sqlplain/memoize.py | 2 | ||||
-rw-r--r-- | sqlplain/postgres_support.py | 1 | ||||
-rw-r--r-- | sqlplain/postgres_util.py | 12 | ||||
-rw-r--r-- | sqlplain/python.py | 98 | ||||
-rw-r--r-- | sqlplain/sql_support.py | 2 | ||||
-rw-r--r-- | sqlplain/sqlite_util.py | 1 | ||||
-rw-r--r-- | sqlplain/uri.py | 107 | ||||
-rw-r--r-- | sqlplain/util.py | 79 |
13 files changed, 208 insertions, 286 deletions
diff --git a/decorator/decorator.py b/decorator/decorator.py index e9a35e0..39debfa 100755 --- a/decorator/decorator.py +++ b/decorator/decorator.py @@ -93,25 +93,28 @@ def makefn(src, funcdata, save_source=True, **evaldict): func = evaldict[name] return funcdata.update(func, __source__=src) -def decorator(caller, func=None): +def decorator_apply(caller, func): + "decorator.apply(caller, func) is akin to decorator(caller)(func)" + fd = FuncData(func) + name = fd.name + signature = fd.signature + for arg in signature.split(','): + argname = arg.strip(' *') + assert not argname in('_func_', '_call_'), ( + '%s is a reserved argument name!' % argname) + src = """def %(name)s(%(signature)s): + return _call_(_func_, %(signature)s)""" % locals() + return makefn(src, fd, save_source=False, _func_=func, _call_=caller) + +def decorator(caller): """ - decorator(caller) converts a caller function into a decorator; - decorator(caller, func) is akin to decorator(caller)(func). + decorator(caller) converts a caller function into a decorator. """ - if func: - fd = FuncData(func) - name = fd.name - signature = fd.signature - for arg in signature.split(','): - argname = arg.strip(' *') - assert not argname in('_func_', '_call_'), ( - '%s is a reserved argument name!' % argname) - src = """def %(name)s(%(signature)s): - return _call_(_func_, %(signature)s)""" % locals() - return makefn(src, fd, save_source=False, _func_=func, _call_=caller) - src = 'def %s(func): return decorator(caller, func)' % caller.__name__ + src = 'def %s(func): return appl(caller, func)' % caller.__name__ return makefn(src, FuncData(caller), save_source=False, - caller=caller, decorator=decorator) + caller=caller, appl=decorator_apply) + +decorator.apply = decorator_apply @decorator def deprecated(func, *args, **kw): diff --git a/decorator/documentation.py b/decorator/documentation.py index 3d5f88a..dadb231 100644 --- a/decorator/documentation.py +++ b/decorator/documentation.py @@ -14,12 +14,13 @@ The ``decorator`` module Introduction ------------------------------------------------ -Python 2.4 decorators are an interesting example of why syntactic sugar -matters: in principle, their introduction changed nothing, since they do -not provide any new functionality which was not already present in the -language; in practice, their introduction has significantly changed the way -we structure our programs in Python. I believe the change is for the best, -and that decorators are a great idea since: +Python decorators are an interesting example of why syntactic sugar +matters. In principle, their introduction in Python 2.4 changed +nothing, since they do not provide any new functionality which was not +already present in the language; in practice, their introduction has +significantly changed the way we structure our programs in Python. I +believe the change is for the best, and that decorators are a great +idea since: * decorators help reducing boilerplate code; * decorators help separation of concerns; @@ -145,12 +146,15 @@ First of all, you must import ``decorator``: >>> from decorator import decorator -Then you must define an helper function with signature ``(f, *args, **kw)`` +Then you must define a helper function with signature ``(f, *args, **kw)`` which calls the original function ``f`` with arguments ``args`` and ``kw`` and implements the tracing capability: $$_trace +At this point you can define your decorator in terms of the helper function +via ``decorator.apply``: + $$trace Therefore, you can write the following: @@ -192,7 +196,8 @@ That includes even functions with exotic signatures like the following: calling exotic_signature with args ((1, 2),), {} 3 -Notice that exotic signatures have been disabled in Python 3.0. +Notice that the support for exotic signatures has been deprecated +in Python 2.6 and removed in Python 3.0. ``decorator`` is a decorator --------------------------------------------- @@ -274,38 +279,6 @@ tuple(kwargs.iteritems()))`` as key for the memoize dictionary. Notice that in general it is impossible to memoize correctly something that depends on mutable arguments. -``locked`` ---------------------------------------------------------------- - -There are good use cases for decorators is in multithreaded programming. -For instance, a ``locked`` decorator can remove the boilerplate -for acquiring/releasing locks [#]_. - -.. [#] In Python 2.5, the preferred way to manage locking is via - the ``with`` statement: http://docs.python.org/lib/with-locks.html - -To show an example of usage, suppose one wants to write some data to -an external resource which can be accessed by a single user at once -(for instance a printer). Then the access to the writing function must -be locked: - -.. code-block:: python - - import time - - datalist = [] # for simplicity the written data are stored into a list. - - @locked - def write(data): - "Writing to a sigle-access resource" - time.sleep(1) - datalist.append(data) - - -Since the writing function is locked, we are guaranteed that at any given time -there is at most one writer. An example multithreaded program that invokes -``write`` and prints the datalist is shown in the next section. - ``delayed`` and ``threaded`` -------------------------------------------- @@ -350,8 +323,28 @@ to deserve a name: threaded = delayed(0) Threaded procedures will be executed in a separated thread as soon -as they are called. Here is an example using the ``write`` -routine defined before: +as they are called. Here is an example. + +Suppose one wants to write some data to +an external resource which can be accessed by a single user at once +(for instance a printer). Then the access to the writing function must +be locked: + +.. code-block:: python + + import time + + datalist = [] # for simplicity the written data are stored into a list. + + def write(data): + "Writing to a sigle-access resource" + with threading.Lock(): + time.sleep(1) + datalist.append(data) + + +Since the writing function is locked, we are guaranteed that at any given time +there is at most one writer. Here is an example. >>> @threaded ... def writedata(data): @@ -406,35 +399,6 @@ Please wait ... >>> print read_data() some data -``redirecting_stdout`` -------------------------------------------- - -Decorators help in removing the boilerplate associated to ``try .. finally`` -blocks. We saw the case of ``locked``; here is another example: - -$$redirecting_stdout - -Here is an example of usage: - ->>> from StringIO import StringIO - ->>> out = StringIO() - ->>> @redirecting_stdout(out) -... def helloworld(): -... print "hello, world!" - ->>> helloworld() - ->>> out.getvalue() -'hello, world!\n' - -Similar tricks can be used to remove the boilerplate associate with -transactional databases. I think you got the idea, so I will leave -the transactional example as an exercise for the reader. Of course -in Python 2.5 these use cases can also be addressed with the ``with`` -statement. - Class decorators and decorator factories -------------------------------------------------------------------- @@ -529,9 +493,8 @@ of the original function. If not, you will get an error at calling time, not at decoration time. With ``new_wrapper`` at your disposal, it is a breeze to define an utility -to upgrade old-style decorators to signature-preserving decorators: +to upgrade old-style decorators to signature-preserving decorators. -$$upgrade_dec ``tail_recursive`` ------------------------------------------------------------ @@ -760,7 +723,7 @@ def _trace(f, *args, **kw): print "calling %s with args %s, %s" % (f.func_name, args, kw) return f(*args, **kw) def trace(f): - return decorator(_trace, f) + return decorator.apply(_trace, f) def delayed(nsec): def call(proc, *args, **kw): @@ -793,7 +756,7 @@ def _memoize(func, *args, **kw): def memoize(f): f.cache = {} - return decorator(_memoize, f) + return decorator.apply(_memoize, f) @decorator def locked(func, *args, **kw): @@ -821,17 +784,6 @@ def blocking(not_avail="Not Available"): return f.result return decorator(call) -def redirecting_stdout(new_stdout): - def call(func, *args, **kw): - save_stdout = sys.stdout - sys.stdout = new_stdout - try: - result = func(*args, **kw) - finally: - sys.stdout = save_stdout - return result - return decorator(call) - class User(object): "Will just be able to see a page" @@ -869,7 +821,7 @@ class Restricted(object): '%s does not have the permission to run %s!' % (userclass.__name__, func.__name__)) def __call__(self, func): - return decorator(self.call, func) + return decorator.apply(self.call, func) class Action(object): diff --git a/sqlplain/Makefile b/sqlplain/Makefile index b6d758b..8b7816d 100644 --- a/sqlplain/Makefile +++ b/sqlplain/Makefile @@ -1,4 +1,4 @@ count: wc -l __init__.py automatize.py configurator.py connection.py \ mssql_support.py postgres_support.py sqlite_support.py \ - sql_support.py uri.py util.py runtransac.py insert.py memoize.py + sql_support.py uri.py util.py runtransac.py memoize.py diff --git a/sqlplain/automatize.py b/sqlplain/automatize.py index 357e165..8b1d267 100644 --- a/sqlplain/automatize.py +++ b/sqlplain/automatize.py @@ -1,12 +1,23 @@ import os, sys, subprocess, re from sqlplain.configurator import configurator -from sqlplain.connection import LazyConn from sqlplain.util import create_db, create_schema from sqlplain.namedtuple import namedtuple VERSION = re.compile(r'(\d[\d\.-]+)') Chunk = namedtuple('Chunk', 'version fname code') +try: + CalledProcessError = subprocess.CalledProcessError +except AttributeError: + class CalledProcessError(Exception): pass + +def getoutput(commandlist): + po = subprocess.Popen(commandlist, stdout=subprocess.PIPE) + out, err = po.communicate() + if po.returncode or err: + raise CalledProcessError('%s [return code %d]' % (err, po.returncode)) + return out + def collect(directory, exts): ''' Read the files with a given set of extensions from a directory diff --git a/sqlplain/connection.py b/sqlplain/connection.py index 6cd5c0e..2bb4bc0 100644 --- a/sqlplain/connection.py +++ b/sqlplain/connection.py @@ -1,4 +1,4 @@ -import sys, threading, itertools +import sys, threading, itertools, string from operator import attrgetter try: from collections import namedtuple @@ -120,8 +120,6 @@ class LazyConnection(object): try: if args: cursor.execute(templ, args) - #elif self.dbtype == 'sqlite': - # cursor.executescript(templ) else: cursor.execute(templ) except Exception, e: @@ -178,6 +176,18 @@ class LazyConnection(object): res.header = Ntuple(*header) return res + def executescript(self, sql, *dicts, **kw): + "A driver-independent method to execute sql templates" + d = {} + for dct in dicts + (kw,): + d.update(dct) + if d: + sql = string.Template(sql).substitute(d) + if self.dbtype == 'sqlite': + self._curs.executescript(sql) + else: # psycopg and pymssql are already able to execute chunks + self.execute(sql) + def close(self): """The next time you will call an active method, a fresh new connection will be instantiated""" diff --git a/sqlplain/memoize.py b/sqlplain/memoize.py index 5f12f7c..9dad114 100644 --- a/sqlplain/memoize.py +++ b/sqlplain/memoize.py @@ -1,4 +1,4 @@ -from sqlplain.python import decorator +from decorator import decorator class Memoize(object): diff --git a/sqlplain/postgres_support.py b/sqlplain/postgres_support.py index c29a22e..e91a456 100644 --- a/sqlplain/postgres_support.py +++ b/sqlplain/postgres_support.py @@ -9,6 +9,7 @@ ISOLATION_LEVELS = None, 0, 1, 2 def connect(params, isolation_level=None, **kw): user, pwd, host, port, db = params port = port or 5432 + #import pdb; pdb.set_trace() conn = dbapi2.connect( database=db, host=host, port=port, user=user, password=pwd, **kw) if isolation_level is None: diff --git a/sqlplain/postgres_util.py b/sqlplain/postgres_util.py index c0df642..9105077 100644 --- a/sqlplain/postgres_util.py +++ b/sqlplain/postgres_util.py @@ -1,4 +1,5 @@ from sqlplain.util import openclose +from sqlplain.automatize import getoutput def create_db_postgres(uri): openclose(uri.copy(database='template1'), @@ -15,6 +16,9 @@ def exists_table_postgres(conn, tname): def bulk_insert_postgres(conn, file, table, sep='\t', null='\N', columns=None): conn._curs.copy_from(file, table, sep, null, columns) +def dump_postgres(conn, file, table, sep='\t', null='\N', columns=None): + conn._curs.copy_to(file, table, sep, null, columns) + def exists_db_postgres(uri): dbname = uri['database'] for row in openclose( @@ -22,3 +26,11 @@ def exists_db_postgres(uri): if row[0] == dbname: return True return False + +def get_schema_postgres(uri, objectname): + cmd = ['pg_dump', '-s', + '-t', objectname, + '-h', uri['host'], + '-U', uri['user'], + '-d', uri['database']] + return getoutput(cmd) diff --git a/sqlplain/python.py b/sqlplain/python.py deleted file mode 100644 index 9fdc691..0000000 --- a/sqlplain/python.py +++ /dev/null @@ -1,98 +0,0 @@ -import os, sys, re, inspect, warnings -from tempfile import mkstemp - -PYTHON3 = sys.version >= '3' - -DEF = re.compile('\s*def\s*([_\w][_\w\d]*)\s*\(') - -def _callermodule(level=2): - return sys._getframe(level).f_globals.get('__name__', '?') - -def getsignature(func): - "Return the signature of a function as a string" - argspec = inspect.getargspec(func) - return inspect.formatargspec(formatvalue=lambda val: "", *argspec)[1:-1] - -class FuncData(object): - def __init__(self, func=None, name=None, signature=None, - defaults=None, doc=None, module=None, funcdict=None): - if func: - self.name = func.__name__ - self.signature = getsignature(func) - self.defaults = func.func_defaults - self.doc = func.__doc__ - self.module = func.__module__ - self.dict = func.__dict__ - if name: - self.name = name - if signature: - self.signature = signature - if defaults: - self.defaults = defaults - if doc: - self.doc = doc - if module: - self.module = module - if funcdict: - self.dict = funcdict - - def update(self, func, **kw): - func.__name__ = getattr(self, 'name', 'noname') - func.__doc__ = getattr(self, 'doc', None) - func.__dict__ = getattr(self, 'dict', {}) - func.func_defaults = getattr(self, 'defaults', None) - func.__module__ = getattr(self, 'module', _callermodule()) - func.__dict__.update(kw) - return func - - def __getitem__(self, name): - return getattr(self, name) - -def makefn(src, funcdata, save_source=True, **evaldict): - src += os.linesep # add a newline just for safety - name = DEF.match(src).group(1) # extract the function name from the source - if save_source: - fhandle, fname = mkstemp() - os.write(fhandle, src) - os.close(fhandle) - else: - fname = '?' - code = compile(src, fname, 'single') - exec code in evaldict - func = evaldict[name] - return funcdata.update(func, __source__=src) - -def decorator(caller, func=None): - """ - decorator(caller) converts a caller function into a decorator; - decorator(caller, func) is akin to decorator(caller)(func). - """ - if func: - fd = FuncData(func) - name = fd.name - signature = fd.signature - for arg in signature.split(','): - argname = arg.strip(' *') - assert not argname in('_func_', '_call_'), ( - '%s is a reserved argument name!' % argname) - src = """def %(name)s(%(signature)s): - return _call_(_func_, %(signature)s)""" % locals() - return makefn(src, fd, save_source=False, _func_=func, _call_=caller) - src = 'def %s(func): return decorator(caller, func)' % caller.__name__ - return makefn(src, FuncData(caller), save_source=False, - caller=caller, decorator=decorator) - -@decorator -def deprecated(func, *args, **kw): - "A decorator for deprecated functions" - warnings.warn('Calling the deprecated function %r' % func.__name__, - DeprecationWarning, stacklevel=3) - return func(*args, **kw) - -def upgrade_dec(dec): - def new_dec(func): - fd = FuncData(func) - src = '''def %(name)s(%(signature)s): - return decorated(%(signature)s)''' % fd - return makefn(src, fd, save_source=False, decorated=dec(func)) - return FuncData(dec).update(new_dec) diff --git a/sqlplain/sql_support.py b/sqlplain/sql_support.py index 7c7c2c2..bef6c72 100644 --- a/sqlplain/sql_support.py +++ b/sqlplain/sql_support.py @@ -1,5 +1,5 @@ import re, sys -from sqlplain.python import makefn, FuncData +from decorator import makefn, FuncData from sqlplain.memoize import Memoize STRING_OR_COMMENT = re.compile(r"('[^']*'|--.*\n)") diff --git a/sqlplain/sqlite_util.py b/sqlplain/sqlite_util.py index a6fba5f..b7bba6e 100644 --- a/sqlplain/sqlite_util.py +++ b/sqlplain/sqlite_util.py @@ -1,3 +1,4 @@ +import os from sqlplain.util import openclose def exists_table_sqlite(conn, tname): diff --git a/sqlplain/uri.py b/sqlplain/uri.py index 9b5c1c3..93c4580 100644 --- a/sqlplain/uri.py +++ b/sqlplain/uri.py @@ -8,42 +8,54 @@ from sqlplain.configurator import configurator SUPPORTED_DBTYPES = 'mssql', 'postgres', 'sqlite' -class URI(dict): - def __init__(self, uri): - """ - Extract the connection parameters from a SQLAlchemy-like uri string. - Return a dictionary with keys +def imp(mod): + return __import__(mod, globals(), locals(), ['']) - - uri - - dbtype - - server # means host:port - - database - - host - - port +class URI(object): + """ + Extract: the connection parameters from a SQLAlchemy-like uri string. + Has attributes - In the case of mssql, the host may contain an instance name. - """ + - dbtype + - server # means host:port + - database + - host + - port + - scriptdir + + In the case of mssql, the host may contain an instance name. + """ + def __init__(self, uri): if isinstance(uri, URI): # copy data from uri - self.update(uri) + vars(self).update(vars(uri)) + return assert uri and isinstance(uri, str), '%r is not a valid string!' % uri + self.scriptdir = None if not '://' in uri: # assume it is an alias try: + section = configurator.scriptdir + except AttributeError: # missing [scripdir] section in conf + pass + else: + scriptdir = section.get(uri) + if scriptdir: + self.scriptdir = os.path.expanduser(scriptdir) + try: uri = configurator.uri[uri] except KeyError: raise NameError( - '%s is not a valid URI, not a recognized alias' % uri) - #else: - # uri.dir = configurator.dir.get(uri) + '%s is not a valid URI, not a recognized alias in %s' % + (uri, configurator._conf_file)) if not uri.startswith(SUPPORTED_DBTYPES): raise NameError('Invalid URI %s' % uri) dbtype, partial_uri = uri.split('://') if dbtype == 'sqlite': # strip a leading slash, since according to # SQLAlchemy conventions full_uri starts with three slashes or more - self['dbtype'] = dbtype - self['user'] = '' - self['password'] = '' - self['database'] = partial_uri[1:] - self['host'] = 'localhost' + self.dbtype = dbtype + self.user = '' + self.password = '' + self.database = partial_uri[1:] + self.host = 'localhost' return elif not ('@' in partial_uri and '/' in partial_uri and \ ':' in partial_uri): @@ -51,21 +63,34 @@ class URI(dict): 'Wrong uri %s: should be dbtype://user:passwd@host:port/db' % partial_uri) user_pwd, host_db = partial_uri.split('@') - self['dbtype'] = dbtype - self['server'], self['database'] = host_db.split('/') - self['user'], self['password'] = user_pwd.split(':') - self['user'] = self['user'] or os.environ.get('USER') - if not self['user']: + self.dbtype = dbtype + self.server, self.database = host_db.split('/') + self.user, self.password = user_pwd.split(':') + self.user = self.user or os.environ.get('USER') + if not self.user: raise ValueError('Empty username and $USER!') - if ':' in self['server']: # look if an explicit port is passed - self['host'], self['port'] = self['server'].split(':') + if ':' in self.server: # look if an explicit port is passed + self.host, self.port = self.server.split(':') else: - self['host'], self['port'] = self['server'], None + self.host, self.port = self.server, None def copy(self, **kw): + "Returns a copy of the URI object with different attributes" new = self.__class__(self) - new.update(kw) + vars(new).update(kw) return new + + def import_driver(self): + "Import the right driver and populate the util module" + from sqlplain import util + dbtype = self.dbtype + driver = imp('sqlplain.%s_support' % dbtype) + driver_util = imp('sqlplain.%s_util' % dbtype) + # dynamically populate thw 'util' module with the driver-specific func + for name, value in vars(driver_util).iteritems(): + if name.endswith(dbtype): + setattr(util, name, value) + return driver def get_driver_connect_params(self): """ @@ -73,17 +98,19 @@ class URI(dict): the URI and returns the right connection factory, as well as its arguments user, pwd, host, port, db. """ - module = 'sqlplain.%(dbtype)s_support' % self - driver = __import__(module, globals(), locals(), ['']) - if self['dbtype'] == 'sqlite': - params = self['database'] + driver = self.import_driver() + if self.dbtype == 'sqlite': + params = self.database else: - params = (self['user'], self['password'], self['host'], - self['port'], self['database']) + params = (self.user, self.password, self.host, + self.port, self.database) return driver.dbapi2, driver.connect, params - + + def __getitem__(self, name): + return self.__dict__[name] + def __str__(self): - if self['dbtype'] == 'sqlite': - return 'sqlite:///' + self['database'] + if self.dbtype == 'sqlite': + return 'sqlite:///' + self.database t = '%(dbtype)s://%(user)s:xxxxx@%(server)s/%(database)s' return t % self diff --git a/sqlplain/util.py b/sqlplain/util.py index ca0dc2c..44d0a3d 100644 --- a/sqlplain/util.py +++ b/sqlplain/util.py @@ -2,7 +2,7 @@ Notice: create_db and drop_db are not transactional. """ -import os, sys +import os, sys, re from sqlplain.uri import URI from sqlplain import lazyconnect, transact, do from sqlplain.namedtuple import namedtuple @@ -10,36 +10,13 @@ from sqlplain.namedtuple import namedtuple VERSION = re.compile(r'(\d[\d\.-]+)') Chunk = namedtuple('Chunk', 'version fname code') -def collect(directory, exts): - ''' - Read the files with a given set of extensions from a directory - and returns them ordered by version number. - ''' - sql = [] - for fname in os.listdir(directory): - if fname.endswith(exts) and not fname.startswith('_'): - version = VERSION.search(fname) - if version: - code = file(os.path.join(directory, fname)).read() - sql.append(Chunk(version, fname, code)) - return sorted(sql) - -# dispatch on the database type - -def _call_with_uri(procname, uri, *args): - "Call a procedure by name, passing to it an URI string" - proc = globals().get(procname + '_' + uri['dbtype']) - if proc is None: - raise NameError('Missing procedure %s, database not supported' % - proc.__name__) - return proc(uri, *args) - -def _call_with_conn(procname, conn, *args): - proc = globals().get(procname + '_' + conn.dbtype) +def _call(procname, uri_or_conn, *args, **kw): + "Call a procedure by name, by dispatching on the database type" + dbtype = uri_or_conn.dbtype + proc = globals().get(procname + '_' + dbtype) if proc is None: - raise NameError('Missing procedure %s, database not supported' % - proc.__name__) - return proc(conn, *args) + raise NameError('Missing procedure %s for %s' % (procname, dbtype)) + return proc(uri_or_conn, *args, **kw) # exported utilities @@ -60,11 +37,26 @@ def openclose(uri, templ, *args, **kw): def exists_db(uri): "Check is a database exists" - return _call_with_uri('exists_db', URI(uri)) + return _call('exists_db', URI(uri)) def drop_db(uri): "Drop an existing database" - _call_with_uri('drop_db', URI(uri)) + _call('drop_db', URI(uri)) + +# helper for createdb +def _collect(directory, exts): + ''' + Read the files with a given set of extensions from a directory + and returns them ordered by version number. + ''' + sql = [] + for fname in os.listdir(directory): + if fname.endswith(exts) and not fname.startswith('_'): + version = VERSION.search(fname) + if version: + code = file(os.path.join(directory, fname)).read() + sql.append(Chunk(version, fname, code)) + return sorted(sql) def create_db(uri, force=False, scriptdir=None, **kw): """ @@ -73,29 +65,40 @@ def create_db(uri, force=False, scriptdir=None, **kw): is dropped and recreated. """ uri = URI(uri) + uri.import_driver() # import the driver if exists_db(uri): if force: - _call_with_uri('drop_db', uri) + _call('drop_db', uri) else: raise RuntimeError( 'There is already a database %s!' % uri) - _call_with_uri('create_db', uri) + _call('create_db', uri) db = lazyconnect(uri, **kw) + scriptdir = uri.scriptdir or scriptdir if scriptdir: - chunks = collect(dir, ('.sql', '.py')) + chunks = _collect(scriptdir, ('.sql', '.py')) for chunk in chunks: if chunk.fname.endswith('.sql'): - db.execute(chunk.code) + db.executescript(chunk.code) elif chunk.fname.endswith('.py'): exec chunk.code in {} return db def bulk_insert(conn, file, table, sep='\t'): - return _call_with_conn('bulk_insert', conn, file, table, sep) + return _call('bulk_insert', conn, file, table, sep) def exists_table(conn, tname): "Check if a table exists" - return _call_with_conn(conn, tname) + return _call('exists_table', conn, tname) + +def drop_table(conn, tname, force=False): + """ + Drop a table. If the table does not exist, raise an error, unless + force is True. + """ + if not exists_table(tname) and force: + return # do not raise an error + return conn.execute('DROP TABLE %s' % tname) ########################## schema management ########################### |