summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorToshio Kuratomi <toshio@fedoraproject.org>2014-11-25 00:45:59 -0800
committerToshio Kuratomi <toshio@fedoraproject.org>2014-11-25 00:45:59 -0800
commit0287e9a23d29e253054ad6a110c7f5ba6a939595 (patch)
tree9969f9c75570ef88c4c7a0df995c4ed170e3c2a4
parent19606afe5f47f044c6d49935e2bd37f3c66b81e3 (diff)
downloadansible-0287e9a23d29e253054ad6a110c7f5ba6a939595.tar.gz
Normalize the identifier quoting so we can reuse the functions for mysql
-rw-r--r--lib/ansible/module_utils/database.py39
1 files changed, 23 insertions, 16 deletions
diff --git a/lib/ansible/module_utils/database.py b/lib/ansible/module_utils/database.py
index cb6c7c46b1..68b294a436 100644
--- a/lib/ansible/module_utils/database.py
+++ b/lib/ansible/module_utils/database.py
@@ -35,13 +35,14 @@ class UnclosedQuoteError(SQLParseError):
# maps a type of identifier to the maximum number of dot levels that are
# allowed to specifiy that identifier. For example, a database column can be
# specified by up to 4 levels: database.schema.table.column
-_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1)
+_PG_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, schema=2, table=3, column=4, role=1)
+_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1)
-def _find_end_quote(identifier):
+def _find_end_quote(identifier, quote_char='"'):
accumulate = 0
while True:
try:
- quote = identifier.index('"')
+ quote = identifier.index(quote_char)
except ValueError:
raise UnclosedQuoteError
accumulate = accumulate + quote
@@ -49,7 +50,7 @@ def _find_end_quote(identifier):
next_char = identifier[quote+1]
except IndexError:
return accumulate
- if next_char == '"':
+ if next_char == quote_char:
try:
identifier = identifier[quote+2:]
accumulate = accumulate + 2
@@ -59,15 +60,15 @@ def _find_end_quote(identifier):
return accumulate
-def _identifier_parse(identifier):
+def _identifier_parse(identifier, quote_char='"'):
if not identifier:
raise SQLParseError('Identifier name unspecified or unquoted trailing dot')
already_quoted = False
- if identifier.startswith('"'):
+ if identifier.startswith(quote_char):
already_quoted = True
try:
- end_quote = _find_end_quote(identifier[1:]) + 1
+ end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
except UnclosedQuoteError:
already_quoted = False
else:
@@ -87,27 +88,33 @@ def _identifier_parse(identifier):
try:
dot = identifier.index('.')
except ValueError:
- identifier = identifier.replace('"', '""')
- identifier = ''.join(('"', identifier, '"'))
+ identifier = identifier.replace(quote_char, quote_char*2)
+ identifier = ''.join((quote_char, identifier, quote_char))
further_identifiers = [identifier]
else:
if dot == 0 or dot >= len(identifier) - 1:
- identifier = identifier.replace('"', '""')
- identifier = ''.join(('"', identifier, '"'))
+ identifier = identifier.replace(quote_char, quote_char*2)
+ identifier = ''.join((quote_char, identifier, quote_char))
further_identifiers = [identifier]
else:
first_identifier = identifier[:dot]
next_identifier = identifier[dot+1:]
further_identifiers = _identifier_parse(next_identifier)
- first_identifier = first_identifier.replace('"', '""')
- first_identifier = ''.join(('"', first_identifier, '"'))
+ first_identifier = first_identifier.replace(quote_char, quote_char*2)
+ first_identifier = ''.join((quote_char, first_identifier, quote_char))
further_identifiers.insert(0, first_identifier)
return further_identifiers
def pg_quote_identifier(identifier, id_type):
- identifier_fragments = _identifier_parse(identifier)
- if len(identifier_fragments) > _IDENTIFIER_TO_DOT_LEVEL[id_type]:
- raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type]))
+ identifier_fragments = _identifier_parse(identifier, quote_char='"')
+ if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
+ raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
+ return '.'.join(identifier_fragments)
+
+def mysql_quote_identifier(identifier, id_type):
+ identifier_fragments = _identifier_parse(identifier, quote_char='`')
+ if len(identifier_fragments) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
+ raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _IDENTIFIER_TO_DOT_LEVEL[id_type]))
return '.'.join(identifier_fragments)