diff options
Diffstat (limited to 'lib/sqlalchemy/ansisql.py')
-rw-r--r-- | lib/sqlalchemy/ansisql.py | 30 |
1 files changed, 24 insertions, 6 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index ac0b52200..d5b31c8aa 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -23,6 +23,11 @@ ANSI_FUNCS = sets.ImmutableSet([ ]) +RESERVED_WORDS = util.Set(['all', 'analyse', 'analyze', 'and', 'any', 'array', 'as', 'asc', 'asymmetric', 'authorization', 'between', 'binary', 'both', 'case', 'cast', 'check', 'collate', 'column', 'constraint', 'create', 'cross', 'current_date', 'current_role', 'current_time', 'current_timestamp', 'current_user', 'default', 'deferrable', 'desc', 'distinct', 'do', 'else', 'end', 'except', 'false', 'for', 'foreign', 'freeze', 'from', 'full', 'grant', 'group', 'having', 'ilike', 'in', 'initially', 'inner', 'intersect', 'into', 'is', 'isnull', 'join', 'leading', 'left', 'like', 'limit', 'localtime', 'localtimestamp', 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset', 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps', 'placing', 'primary', 'references', 'right', 'select', 'session_user', 'similar', 'some', 'symmetric', 'table', 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user', 'using', 'verbose', 'when', 'where']) + +LEGAL_CHARACTERS = util.Set(string.ascii_lowercase + string.ascii_uppercase + string.digits + '_$') +ILLEGAL_INITIAL_CHARACTERS = util.Set(string.digits + '$') + def create_engine(): return engine.ComposedSQLEngine(None, ANSIDialect()) @@ -749,21 +754,34 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): # some tests would need to be rewritten if this is done. #return value.upper() - def _requires_quotes(self, value, natural_case): + def _reserved_words(self): + return RESERVED_WORDS + + def _legal_characters(self): + return LEGAL_CHARACTERS + + def _illegal_initial_characters(self): + return ILLEGAL_INITIAL_CHARACTERS + + def _requires_quotes(self, value, case_sensitive): """return true if the given identifier requires quoting.""" - return False + return \ + value in self._reserved_words() \ + or (value[0] in self._illegal_initial_characters()) \ + or bool(len([x for x in str(value) if x not in self._legal_characters()])) \ + or (case_sensitive and value.lower() != value) def visit_table(self, table): if table in self.__visited: return - if table.quote or self._requires_quotes(table.name, table.natural_case): + if table.quote or self._requires_quotes(table.name, table.case_sensitive): tablestring = self._quote_identifier(table.name) else: tablestring = table.name if table.schema: - if table.quote_schema or self._requires_quotes(table.schema, table.natural_case_schema): + if table.quote_schema or self._requires_quotes(table.schema, table.case_sensitive_schema): schemastring = self._quote_identifier(table.schema) else: schemastring = table.schema @@ -775,7 +793,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): def visit_column(self, column): if column in self.__visited: return - if column.quote or self._requires_quotes(column.name, column.natural_case): + if column.quote or self._requires_quotes(column.name, column.case_sensitive): self.__strings[column] = self._quote_identifier(column.name) else: self.__strings[column] = column.name @@ -783,7 +801,7 @@ class ANSIIdentifierPreparer(schema.SchemaVisitor): def visit_sequence(self, sequence): if sequence in self.__visited: return - if sequence.quote or self._requires_quotes(sequence.name, sequence.natural_case): + if sequence.quote or self._requires_quotes(sequence.name, sequence.case_sensitive): self.__strings[sequence] = self._quote_identifier(sequence.name) else: self.__strings[sequence] = sequence.name |