summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/dialects/postgresql/pg8000.py
blob: 1c45b50f278cb9bbad0a7de50c00ccf5351048a7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Support for the PostgreSQL database via the pg8000 driver.

Connecting
----------

URLs are of the form
`postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]`.

Unicode
-------

pg8000 requires that the PostgreSQL client encoding be configured in the postgresql.conf
file in order to use encodings other than ASCII.  Set this value to the same value as
the "encoding" parameter passed to create_engine(), usually "utf-8".

Interval
--------

Passing data to and from the Interval type is not yet supported.

"""
from sqlalchemy.engine import default
import decimal
from sqlalchemy import util
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql.base import PGDialect, PGCompiler, PGIdentifierPreparer

class _PGNumeric(sqltypes.Numeric):
    """Numeric type as used by the pg8000 dialect.

    No bind-side conversion is applied.  On the result side, values are
    returned untouched when ``asdecimal`` is set; otherwise any
    ``decimal.Decimal`` received from the driver is coerced to ``float``.
    """

    def bind_processor(self, dialect):
        # Outgoing parameter values are handed to the driver unchanged.
        return None

    def result_processor(self, dialect):
        if self.asdecimal:
            # Caller wants Decimal; presumably that is what pg8000 already
            # returns for NUMERIC columns, so nothing to do.
            return None

        def process(value):
            # Only Decimal instances are converted; anything else
            # (including None) passes through as-is.
            return float(value) if isinstance(value, decimal.Decimal) else value

        return process


class PostgreSQL_pg8000ExecutionContext(default.DefaultExecutionContext):
    """Execution context for the pg8000 dialect.

    Adds no behavior to ``DefaultExecutionContext``; it is installed as the
    dialect's ``execution_ctx_cls`` so driver-specific hooks have a home.
    """


class PostgreSQL_pg8000Compiler(PGCompiler):
    """SQL compiler for pg8000.

    The dialect uses the DB-API ``format`` paramstyle (``%s`` placeholders),
    so any literal ``%`` that ends up in the statement text must be doubled
    to ``%%`` or the driver would treat it as the start of a placeholder.
    """

    def visit_mod(self, binary, **kw):
        """Render the modulo operator, escaped as ``%%`` for the driver.

        Compilation keyword arguments are forwarded to both operands so that
        flags supplied by the caller reach nested expressions instead of
        being silently dropped.
        """
        return (
            self.process(binary.left, **kw)
            + " %% "
            + self.process(binary.right, **kw)
        )

    def post_process_text(self, text):
        """Double every ``%`` in a text() fragment.

        If the fragment already contains ``%%`` it was probably escaped by
        hand; warn, because after doubling it will become ``%%%%``.
        """
        if '%%' in text:
            util.warn("The SQLAlchemy postgresql dialect now automatically escapes '%' in text() "
                      "expressions to '%%'.")
        return text.replace('%', '%%')


class PostgreSQL_pg8000IdentifierPreparer(PGIdentifierPreparer):
    """Identifier preparer that also doubles ``%`` for the format paramstyle."""

    def _escape_identifier(self, value):
        """Return *value* with quote characters and percent signs escaped.

        Quote escaping (``escape_quote`` -> ``escape_to_quote``) is applied
        first; then every ``%`` is doubled so it is not read as a parameter
        marker.
        """
        quoted = value.replace(self.escape_quote, self.escape_to_quote)
        percent_escaped = quoted.replace('%', '%%')
        return percent_escaped

    
class PostgreSQL_pg8000(PGDialect):
    """PostgreSQL dialect backed by the pg8000 DB-API driver."""

    driver = 'pg8000'

    # The dialect declares unicode support for both statement text and
    # bound parameter values.
    supports_unicode_statements = True
    supports_unicode_binds = True

    # ``%s``-style placeholders; literal '%' is escaped by the compiler and
    # identifier preparer classes wired in below.
    default_paramstyle = 'format'
    supports_sane_multi_rowcount = False

    execution_ctx_cls = PostgreSQL_pg8000ExecutionContext
    statement_compiler = PostgreSQL_pg8000Compiler
    preparer = PostgreSQL_pg8000IdentifierPreparer

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Numeric: _PGNumeric,
            # Explicit Float entry keeps _PGNumeric from being used for Float.
            sqltypes.Float: sqltypes.Float,
        }
    )

    @classmethod
    def dbapi(cls):
        """Import the ``pg8000`` package and return its ``dbapi`` attribute."""
        pg8000_module = __import__('pg8000')
        return pg8000_module.dbapi

    def create_connect_args(self, url):
        """Translate *url* into ``(positional_args, keyword_args)`` for connecting."""
        kwargs = url.translate_connect_args(username='user')
        if 'port' in kwargs:
            # The port may arrive as a string; hand the driver an int.
            kwargs['port'] = int(kwargs['port'])
        # Any ?key=value query arguments are passed straight through.
        kwargs.update(url.query)
        return [], kwargs

    def is_disconnect(self, e):
        """Return True if the text of *e* reports a closed connection."""
        message = str(e)
        return "connection is closed" in message

# Module-level entry point; presumably the name SQLAlchemy's dialect loader looks up.
dialect = PostgreSQL_pg8000