summaryrefslogtreecommitdiff
path: root/test/testlib/config.py
blob: ac9f397177b22aaad52291f43e75c204f37fcfeb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
import optparse, os, sys, re, ConfigParser, StringIO, time, warnings
# Placeholders for lazily imported modules: _log() imports ``logging`` into
# this global on first use.  ``require`` is not referenced anywhere in this
# module -- presumably a leftover; TODO confirm before removing.
logging, require = None, None


# Public names exported on star-import (note: this is a tuple).
__all__ = 'parser', 'configure', 'options',

# The engine under test.  The 'create_engine' post_configure hook assigns the
# real testing engine; configure_defaults() assigns a placeholder that raises
# RuntimeError on any attribute access.
db = None
# db_label/db_url are resolved by _engine_uri().  db_opts collects keyword
# options passed to engines.testing_engine(); it is filled in by option
# callbacks (--serverside, --enginestrategy) and the --mockpool hook.
db_label, db_url, db_opts = None, None, {}

# Parsed option values and the ConfigParser instance; both are assigned by
# configure() / configure_defaults().
options = None
file_config = None

# Built-in [db] macros selectable with --db.  configure() reads test.cfg and
# ~/.satest.cfg on top of these, so those files may override or add entries.
base_config = """
[db]
sqlite=sqlite:///:memory:
sqlite_file=sqlite:///querytest.db
postgres=postgres://scott:tiger@127.0.0.1:5432/test
mysql=mysql://scott:tiger@127.0.0.1:3306/test
oracle=oracle://scott:tiger@127.0.0.1:1521
oracle8=oracle://scott:tiger@127.0.0.1:1521/?use_ansi=0
mssql=mssql://scott:tiger@SQUAWK\\SQLEXPRESS/test
firebird=firebird://sysdba:masterkey@localhost//tmp/test.fdb
maxdb=maxdb://MONA:RED@/maxdb1
"""

# Shared command-line parser; options are registered on it at module level
# via parser.add_option.
parser = optparse.OptionParser(usage = "usage: %prog [options] [tests...]")

def configure():
    """Parse the command line and build the test configuration.

    Loads the built-in ``base_config`` defaults, then overlays ``test.cfg``
    and ``~/.satest.cfg`` when they exist.  ``sys.argv`` is parsed with
    ``parser`` and afterwards ``sys.argv[1:]`` holds only the positional
    arguments.  Finally every hook registered in ``post_configure`` is run,
    in registration order, as ``hook(options, file_config)``.

    Returns the ``(options, file_config)`` pair, which is also stored in the
    module globals of the same names.
    """
    global options, file_config

    file_config = ConfigParser.ConfigParser()
    file_config.readfp(StringIO.StringIO(base_config))
    user_cfgs = ['test.cfg', os.path.expanduser('~/.satest.cfg')]
    file_config.read(user_cfgs)

    # Callback options (--log-*, --coverage, --dbs, ...) fire inside
    # parse_args(), which is why file_config must be populated first.
    parsed = parser.parse_args()
    options = parsed[0]
    sys.argv[1:] = parsed[1]

    # Remaining setup is deferred until after parsing, so it runs after any
    # immediate actions such as starting coverage.
    for hook in post_configure:
        hook(options, file_config)

    return options, file_config

def configure_defaults():
    """Configure testlib from defaults only, ignoring the real command line.

    Reads the same configuration sources as configure(), but parses an
    empty argument list and runs none of the ``post_configure`` hooks, so no
    testing engine is created.  Instead ``db`` (in this module and in
    ``testlib.testing``) becomes a placeholder object whose every attribute
    access raises RuntimeError.

    Returns the ``(options, file_config)`` pair, which is also stored in the
    module globals of the same names.
    """
    global options, file_config, db

    file_config = ConfigParser.ConfigParser()
    file_config.readfp(StringIO.StringIO(base_config))
    file_config.read(['test.cfg', os.path.expanduser('~/.satest.cfg')])
    options = parser.parse_args([])[0]

    # Placeholder engine: code (e.g. decorators) that depends on a default
    # database gets a clear error message instead of an obscure failure.
    class _engine_bomb(object):
        def __getattr__(self, key):
            raise RuntimeError('No default engine available, testlib '
                               'was configured with defaults only.')

    db = _engine_bomb()

    import testlib.testing
    testlib.testing.db = db

    return options, file_config

def _log(option, opt_str, value, parser):
    """optparse callback for ``--log-info`` and ``--log-debug``.

    Sets the logger named by *value* to INFO or DEBUG depending on which
    option fired.  On first use the ``logging`` module is imported into the
    module-level ``logging`` global and ``logging.basicConfig()`` is called.
    """
    global logging
    if not logging:
        import logging
        logging.basicConfig()

    level = None
    if opt_str.endswith('-info'):
        level = logging.INFO
    elif opt_str.endswith('-debug'):
        level = logging.DEBUG

    if level is not None:
        logging.getLogger(value).setLevel(level)

def _start_coverage(option, opt_str, value, parser):
    """optparse callback for ``--coverage``.

    Erases previous coverage data and starts collection immediately, while
    options are still being parsed.  An atexit handler stops collection and
    writes a report for every ``.py`` file under the installed
    ``sqlalchemy`` package to the ``sys.stdout`` captured at call time.
    """
    import sys, atexit, coverage
    report_out = sys.stdout

    def _sqlalchemy_sources():
        # Imported here so sqlalchemy is only loaded when the report runs.
        import sqlalchemy
        package_dir = os.path.dirname(sqlalchemy.__file__)
        found = []
        for dirpath, dirnames, filenames in os.walk(package_dir):
            found.extend([os.path.join(dirpath, name)
                          for name in filenames if name.endswith('.py')])
        return found

    def _finish():
        coverage.stop()
        report_out.write("\nPreparing coverage report...\n")
        coverage.report(_sqlalchemy_sources(),
                        show_missing=False, ignore_errors=False,
                        file=report_out)

    atexit.register(_finish)
    coverage.erase()
    coverage.start()

def _list_dbs(*args):
    """optparse callback for ``--dbs``.

    Prints every macro in the ``[db]`` section of ``file_config`` with its
    URI, sorted by name, then exits the process with status 0.
    """
    print("Available --db options (use --dburi to override)")
    names = list(file_config.options('db'))
    names.sort()
    for name in names:
        print("%20s\t%s" % (name, file_config.get('db', name)))
    sys.exit(0)

def _server_side_cursors(options, opt_str, value, parser):
    """optparse callback for ``--serverside``: sets ``server_side_cursors``
    in the shared ``db_opts`` passed to the testing engine."""
    db_opts.update({'server_side_cursors': True})

def _engine_strategy(options, opt_str, value, parser):
    """optparse callback for ``--enginestrategy``: stores a non-empty
    strategy name in ``db_opts['strategy']``; empty values are ignored."""
    if not value:
        return
    db_opts['strategy'] = value

# Command-line options.  Options with action="callback" run their callback
# while parser.parse_args() is executing (i.e. before any post_configure
# hook); the rest only store values on ``options`` for the hooks to read.
opt = parser.add_option

# --- output / logging -----------------------------------------------------
opt("--verbose", action="store_true", dest="verbose",
    help="enable stdout echoing/printing")
opt("--quiet", action="store_true", dest="quiet", help="suppress output")
opt("--log-info", action="callback", type="string", callback=_log,
    help="turn on info logging for <LOG> (multiple OK)")
opt("--log-debug", action="callback", type="string", callback=_log,
    help="turn on debug logging for <LOG> (multiple OK)")
# Consumed by the _require post_configure hook.
opt("--require", action="append", dest="require", default=[],
    help="require a particular driver or module version (multiple OK)")

# --- database selection (resolved by _engine_uri) ---------------------------
opt("--db", action="store", dest="db", default="sqlite",
    help="Use prefab database uri")
# _list_dbs prints the [db] macros and exits the process.
opt('--dbs', action='callback', callback=_list_dbs,
    help="List available prefab dbs")
opt("--dburi", action="store", dest="dburi",
    help="Database uri (overrides --db)")
opt("--dropfirst", action="store_true", dest="dropfirst",
    help="Drop all tables in the target database first (use with caution on Oracle, MS-SQL)")

# --- engine configuration (mostly written into db_opts) ---------------------
opt("--mockpool", action="store_true", dest="mockpool",
    help="Use mock pool (asserts only one connection used)")
opt("--enginestrategy", action="callback", type="string",
    callback=_engine_strategy,
    help="Engine strategy (plain or threadlocal, defaults to plain)")
opt("--reversetop", action="store_true", dest="reversetop", default=False,
    help="Reverse the collection ordering for topological sorts (helps "
          "reveal dependency issues)")

# --- restrictions on mapped test objects ------------------------------------
opt("--unhashable", action="store_true", dest="unhashable", default=False,
    help="Disallow SQLAlchemy from performing a hash() on mapped test objects.")
opt("--noncomparable", action="store_true", dest="noncomparable", default=False,
    help="Disallow SQLAlchemy from performing == on mapped test objects.")
opt("--truthless", action="store_true", dest="truthless", default=False,
    help="Disallow SQLAlchemy from truth-evaluating mapped test objects.")
opt("--serverside", action="callback", callback=_server_side_cursors,
    help="Turn on server side cursors for PG")

# --- table options (applied by _set_table_options) --------------------------
opt("--mysql-engine", action="store", dest="mysql_engine", default=None,
    help="Use the specified MySQL storage engine for all tables, default is "
         "a db-default/InnoDB combo.")
opt("--table-option", action="append", dest="tableopts", default=[],
    help="Add a dialect-specific table option, key=value")

# --- coverage and profiling -------------------------------------------------
# --coverage starts measuring immediately, during option parsing.
opt("--coverage", action="callback", callback=_start_coverage,
    help="Dump a full coverage report after running tests")
opt("--profile", action="append", dest="profile_targets", default=[],
    help="Enable a named profile target (multiple OK.)")
opt("--profile-sort", action="store", dest="profile_sort", default=None,
    help="Sort profile stats with this comma-separated sort order")
opt("--profile-limit", type="int", action="store", dest="profile_limit",
    default=None,
    help="Limit function count in profile stats")

class _ordered_map(object):
    def __init__(self):
        self._keys = list()
        self._data = dict()

    def __setitem__(self, key, value):
        if key not in self._keys:
            self._keys.append(key)
        self._data[key] = value

    def __iter__(self):
        for key in self._keys:
            yield self._data[key]

# Registry of setup hooks that configure() runs after command-line parsing,
# each called as hook(options, file_config) in registration order.  At one
# point during refactoring, other modules injected their own hooks here;
# this could probably just become a plain list now.
post_configure = _ordered_map()

def _engine_uri(options, file_config):
    """Post-configure hook: resolve which database the tests run against.

    Sets the module globals ``db_label`` and ``db_url``:

    * ``--dburi`` takes precedence.  It is used verbatim as ``db_url``, and
      ``db_label`` is the part before the first ``:``.
    * Otherwise ``--db`` names a macro looked up in the ``[db]`` section of
      *file_config*.  With neither option, the ``sqlite`` macro is used.

    Raises RuntimeError if ``--dburi`` contains no ``:`` or if the macro is
    not defined in ``[db]``.
    """
    global db_label, db_url
    # Reset both globals so a repeated call can never reuse a stale URL.
    db_label = 'sqlite'
    db_url = None

    if options.dburi:
        db_url = options.dburi
        colon = db_url.find(':')
        if colon < 0:
            raise RuntimeError(
                "Invalid --dburi %r; expected a URI such as "
                "'sqlite:///:memory:'." % db_url)
        db_label = db_url[:colon]
    elif options.db:
        db_label = options.db

    if db_url is None:
        if db_label not in file_config.options('db'):
            raise RuntimeError(
                "Unknown engine.  Specify --dbs for known engines.")
        db_url = file_config.get('db', db_label)
post_configure['engine_uri'] = _engine_uri

def _require(options, file_config):
    """Post-configure hook: enforce package version requirements.

    Requirements come from repeated ``--require`` options and from the
    optional ``[require]`` section of *file_config*.  Each one is checked
    with ``pkg_resources.require``, which raises if it is not satisfied.

    A ``[require]`` entry applies only when its label equals the current
    ``db_label`` or starts with ``db_label + '.'``.  It is skipped when a
    ``--require`` option already names the same distribution, so the
    command line overrides the config file.

    Raises RuntimeError if requirements are present but setuptools
    (``pkg_resources``) cannot be imported.
    """
    if not (options.require or
            (file_config.has_section('require') and
             file_config.items('require'))):
        return

    try:
        import pkg_resources
    except ImportError:
        raise RuntimeError("setuptools is required for version requirements")

    def _dist_name(requirement):
        # Distribution name = text before the first version operator, e.g.
        # 'MySQLdb >= 1.2' -> 'MySQLdb'.  (The previous pattern '(<!>=)'
        # matched the literal text '<!>=' and never split anything.)
        return re.split(r'\s*[<>!=]', requirement.strip(), 1)[0]

    cmdline = []
    for requirement in options.require:
        pkg_resources.require(requirement)
        cmdline.append(_dist_name(requirement))

    if file_config.has_section('require'):
        db_prefix = '%s.' % db_label
        for label, requirement in file_config.items('require'):
            # Parenthesized: apply entries for this db and its sub-labels.
            # The old 'not a or b' skipped every 'db_label.*' entry.
            if not (label == db_label or label.startswith(db_prefix)):
                continue
            if _dist_name(requirement) in cmdline:
                continue
            pkg_resources.require(requirement)
post_configure['require'] = _require

def _engine_pool(options, file_config):
    """Post-configure hook for ``--mockpool``: sets ``db_opts['poolclass']``
    to ``sqlalchemy.pool.AssertionPool`` for the testing engine."""
    if not options.mockpool:
        return
    from sqlalchemy import pool as sa_pool
    db_opts['poolclass'] = sa_pool.AssertionPool
post_configure['engine_pool'] = _engine_pool

def _create_testing_engine(options, file_config):
    """Post-configure hook: build the engine under test.

    Calls ``testlib.engines.testing_engine(db_url, db_opts)`` and publishes
    the result as both this module's ``db`` and ``testlib.testing.db``.
    """
    global db
    from testlib import engines
    from testlib import testing as testing_mod

    engine = engines.testing_engine(db_url, db_opts)
    db = engine
    testing_mod.db = engine
post_configure['create_engine'] = _create_testing_engine

def _prep_testing_database(options, file_config):
    from testlib import engines
    from sqlalchemy import schema

    try:
        # also create alt schemas etc. here?
        if options.dropfirst:
            e = engines.utf8_engine()
            existing = e.table_names()
            if existing:
                if not options.quiet:
                    print "Dropping existing tables in database: " + db_url
                    try:
                        print "Tables: %s" % ', '.join(existing)
                    except:
                        pass
                    print "Abort within 5 seconds..."
                    time.sleep(5)
                md = schema.MetaData(e, reflect=True)
                md.drop_all()
            e.dispose()
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception, e:
        if not options.quiet:
            warnings.warn(RuntimeWarning(
                "Error checking for existing tables in testing "
                "database: %s" % e))
post_configure['prep_db'] = _prep_testing_database

def _set_table_options(options, file_config):
    """Post-configure hook: copy table options into
    ``testlib.schema.table_options``.

    Each ``--table-option`` is a ``key=value`` string.  Only the first ``=``
    separates key from value, so values may themselves contain ``=``.
    ``--mysql-engine`` is applied last, so it overrides a ``mysql_engine``
    given via ``--table-option``.

    Raises RuntimeError for a ``--table-option`` without ``=``.
    """
    import testlib.schema

    table_options = testlib.schema.table_options
    for spec in options.tableopts:
        if '=' not in spec:
            raise RuntimeError(
                "Invalid --table-option %r; expected key=value" % spec)
        key, value = spec.split('=', 1)
        table_options[key] = value

    if options.mysql_engine:
        table_options['mysql_engine'] = options.mysql_engine
post_configure['table_options'] = _set_table_options

def _reverse_topological(options, file_config):
    """Post-configure hook for ``--reversetop``.

    Replaces ``sqlalchemy.topological.QueueDependencySorter``, and the
    ``DependencySorter`` name in ``sqlalchemy.orm.unitofwork``, with a
    subclass that reverses its input tuples and items.  Per the option's
    help text, this helps reveal dependency issues.
    """
    if not options.reversetop:
        return

    from sqlalchemy.orm import unitofwork
    from sqlalchemy import topological

    class RevQueueDepSort(topological.QueueDependencySorter):
        # NOTE(review): the base __init__ is not called; this assumes it only
        # stores ``tuples`` and ``allitems`` -- confirm against
        # sqlalchemy.topological.
        def __init__(self, tuples, allitems):
            self.tuples = list(tuples)[::-1]
            self.allitems = list(allitems)[::-1]

    topological.QueueDependencySorter = RevQueueDepSort
    unitofwork.DependencySorter = RevQueueDepSort
post_configure['topological'] = _reverse_topological

def _set_profile_targets(options, file_config):
    """Post-configure hook: copy the ``--profile*`` and ``--quiet`` settings
    into ``testlib.profiling.profile_config``.

    If ``profiling.all_targets`` knows the magic ``all`` target and it was
    requested together with other targets, the target set is collapsed to
    just ``all``.
    """
    from testlib import profiling

    cfg = profiling.profile_config

    for name in options.profile_targets:
        cfg['targets'].add(name)

    if options.profile_sort:
        cfg['sort'] = options.profile_sort.split(',')
    if options.profile_limit:
        cfg['limit'] = options.profile_limit
    if options.quiet:
        cfg['report'] = False

    # magic "all" target
    if 'all' not in profiling.all_targets:
        return
    requested = cfg['targets']
    if 'all' in requested and len(requested) != 1:
        requested.clear()
        requested.add('all')
post_configure['profile_targets'] = _set_profile_targets