summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/engine.py
blob: 811fa943313f8da12fd74d13ea832f8d67e21e11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Builds upon the schema and sql packages to provide a central object for tying schema objects
and sql constructs to database-specific query compilation and execution."""

import sqlalchemy.schema as schema
import sqlalchemy.pool
import sqlalchemy.util as util
import sqlalchemy.sql as sql
import StringIO

class SchemaIterator(schema.SchemaVisitor):
    """A visitor that can gather text into a buffer and execute the contents of the buffer.

    Text is accumulated with append(); execute() hands the accumulated text to
    ``sqlproxy`` and then empties the buffer.  run() must be supplied by a
    subclass.
    """

    def __init__(self, sqlproxy, **params):
        """
        sqlproxy - a callable invoked by execute() with the buffered text as its
                   only argument; its return value is returned by execute().
        params - accepted for interface compatibility; not used here.
        """
        self.sqlproxy = sqlproxy
        self.buffer = StringIO.StringIO()

    def run(self):
        """Perform the visit; not implemented in this base class."""
        raise NotImplementedError()

    def append(self, s):
        """Append the string ``s`` to the pending buffer."""
        self.buffer.write(s)

    def execute(self):
        """Pass the buffered text to ``sqlproxy`` and return its result.

        The buffer is emptied afterwards even if ``sqlproxy`` raises, so the
        next batch of text starts from a clean buffer.
        """
        try:
            return self.sqlproxy(self.buffer.getvalue())
        finally:
            # Rewind before truncating: not every StringIO implementation moves
            # the write position on truncate() (io.StringIO does not), and a
            # stale position would pad the next write with NUL characters.
            self.buffer.seek(0)
            self.buffer.truncate()

class SQLEngine(schema.SchemaEngine):
    """Base class for a series of database-specific engines.  Serves as an abstract factory for
    implementation objects as well as database connections, transactions, SQL generators, etc.

    Subclasses are expected to implement connect_args(), dbapi(), compile(),
    schemagenerator() and schemadropper(), which raise NotImplementedError here.
    """

    def __init__(self, pool = None, echo = False, **params):
        """
        pool - NOTE(review): currently ignored; the pool is always obtained from
               sqlalchemy.pool.manage() using connect_args().  Confirm whether an
               explicitly supplied pool was meant to take precedence.
        echo - when true, execute() logs each statement and its parameters.
        params - accepted for subclasses; not used here.
        """
        # get a handle on the connection pool via the connect arguments
        # this ensures the SQLEngine instance integrates with the pool referenced
        # by direct usage of pool.manager(<module>).connect(*args, **params)
        (cargs, cparams) = self.connect_args()
        self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams)
        self._echo = echo
        # holds 'transaction' (the active connection) and 'tcount' (nesting depth);
        # presumably scoped per thread by util.ThreadLocal - confirm in util.
        self.context = util.ThreadLocal()

    def schemagenerator(self, proxy, **params):
        """Return a schema generator for this engine; implemented by subclasses."""
        raise NotImplementedError()

    def schemadropper(self, proxy, **params):
        """Return a schema dropper for this engine; implemented by subclasses."""
        raise NotImplementedError()

    def columnimpl(self, column):
        """Return the selectable implementation object for ``column``."""
        return sql.ColumnSelectable(column)

    def connect_args(self):
        """Return an (args, kwargs) tuple used to locate the connection pool."""
        raise NotImplementedError()

    def dbapi(self):
        """Return the DB-API module handed to sqlalchemy.pool.manage()."""
        raise NotImplementedError()

    def compile(self, statement):
        """Compile ``statement`` for this database; implemented by subclasses."""
        raise NotImplementedError()

    def proxy(self):
        """Return a callable ``(statement, parameters=None)`` that forwards to execute()."""
        return lambda s, p = None: self.execute(s, p)

    def connection(self):
        """Check out a connection from this engine's pool."""
        return self._pool.connect()

    def _current_transaction(self):
        """Return the connection of the active transaction in self.context, or None."""
        return getattr(self.context, 'transaction', None)

    def transaction(self, func):
        """Call ``func()`` between begin() and commit(); roll back and re-raise on error."""
        self.begin()
        try:
            func()
        except:
            # deliberately bare: roll back on any exception, then re-raise it unchanged
            self.rollback()
            raise
        self.commit()

    def begin(self):
        """Start a transaction, or deepen the nesting of the active one.

        Only the outermost begin() checks out a connection; nested calls just
        increment the counter so that only the matching outermost commit()
        actually commits.
        """
        if self._current_transaction() is None:
            self.context.transaction = self.connection()
            self.context.tcount = 1
        else:
            self.context.tcount += 1

    def rollback(self):
        """Roll back the active transaction at any nesting depth and clear it.

        Does nothing if no transaction is active.
        """
        trans = self._current_transaction()
        if trans is not None:
            try:
                trans.rollback()
            finally:
                # clear state even if rollback fails, so a dead connection is
                # not left registered as the active transaction
                self.context.transaction = None
                self.context.tcount = None

    def commit(self):
        """Decrement the nesting depth; commit when the outermost level is reached.

        Does nothing if no transaction is active.
        """
        trans = self._current_transaction()
        if trans is not None:
            count = self.context.tcount - 1
            self.context.tcount = count
            if count == 0:
                trans.commit()
                self.context.transaction = None
                self.context.tcount = None

    def execute(self, statement, parameters, connection = None, **params):
        """Execute ``statement`` with ``parameters`` and return the DB-API cursor.

        statement - the SQL string passed to cursor.execute().
        parameters - bind parameters; None is treated as an empty dict.
        connection - connection to execute on.  When omitted, the connection of
                     the active transaction (from begin()) is used if there is
                     one; otherwise a connection is checked out from the pool.
        params - accepted for interface compatibility; not used here.
        """
        if parameters is None:
            parameters = {}

        if self._echo:
            self.log(statement)
            self.log(repr(parameters))

        if connection is None:
            # reuse the transaction's connection so statements issued between
            # begin() and commit()/rollback() actually take part in it
            connection = self._current_transaction()
            if connection is None:
                connection = self.connection()

        c = connection.cursor()
        c.execute(statement, parameters)
        return c

    def log(self, msg):
        """Write ``msg`` to standard output (used by execute() when echo is on)."""
        print(msg)