summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimothy Edmund Crosley <timothy.crosley@gmail.com>2015-07-15 00:40:12 -0700
committerTimothy Edmund Crosley <timothy.crosley@gmail.com>2015-07-15 00:40:12 -0700
commit064566cc59ff7bdb8b2d71f69fd2a542606b386c (patch)
treee1be5327f37aaea9530a9db5e2677f8e5c11ea4c
parent48718ce5cc6303d16a642677306bf196001ed196 (diff)
parentf157b06c61e6b77cc2c0d38710ab974d044cec02 (diff)
downloadisort-064566cc59ff7bdb8b2d71f69fd2a542606b386c.tar.gz
Merge pull request #290 from timothycrosley/feature/fix-issue-284
Feature/fix issue 284
-rw-r--r--isort/isort.py31
-rw-r--r--test_isort.py9
2 files changed, 32 insertions, 8 deletions
diff --git a/isort/isort.py b/isort/isort.py
index 1fd5f8be..0acfd815 100644
--- a/isort/isort.py
+++ b/isort/isort.py
@@ -458,7 +458,9 @@ class SortImports(object):
section_title = self.config.get('import_heading_' + str(section_name).lower(), '')
if section_title:
- section_output.insert(0, "# " + section_title)
+ section_comment = "# {0}".format(section_title)
+ if not section_comment in self.out_lines[0:1]:
+ section_output.insert(0, section_comment)
output += section_output + ['']
while [character.strip() for character in output[-1:]] == [""]:
@@ -603,7 +605,13 @@ class SortImports(object):
def _skip_line(self, line):
skip_line = self._in_quote
- if '"' in line or "'" in line:
+ if self.index == 1 and line.startswith("#"):
+ self._in_top_comment = True
+ elif self._in_top_comment:
+ if not line.startswith("#"):
+ self._in_top_comment = False
+ self._first_comment_index_end = self.index
+ elif '"' in line or "'" in line:
index = 0
if self._first_comment_index_start == -1:
self._first_comment_index_start = self.index
@@ -613,7 +621,7 @@ class SortImports(object):
elif self._in_quote:
if line[index:index + len(self._in_quote)] == self._in_quote:
self._in_quote = False
- if self._first_comment_index_end == -1:
+ if self._first_comment_index_end < self._first_comment_index_start:
self._first_comment_index_end = self.index
elif line[index] in ("'", '"'):
long_quote = line[index:index + 3]
@@ -626,7 +634,7 @@ class SortImports(object):
break
index += 1
- return skip_line or self._in_quote
+ return skip_line or self._in_quote or self._in_top_comment
def _strip_syntax(self, import_string):
import_string = import_string.replace("_import", "[[i]]")
@@ -643,6 +651,7 @@ class SortImports(object):
def _parse(self):
"""Parses a python file taking out and categorizing imports."""
self._in_quote = False
+ self._in_top_comment = False
while not self._at_end():
line = self._get_line()
skip_line = self._skip_line(line)
@@ -733,11 +742,14 @@ class SortImports(object):
if comments:
self.comments['from'].setdefault(import_from, []).extend(comments)
- if len(self.out_lines) > self.import_index:
+ if len(self.out_lines) > max(self.import_index, self._first_comment_index_end, 1) - 1:
last = self.out_lines and self.out_lines[-1].rstrip() or ""
while last.startswith("#") and not last.endswith('"""') and not last.endswith("'''"):
self.comments['above']['from'].setdefault(import_from, []).insert(0, self.out_lines.pop(-1))
- last = self.out_lines and self.out_lines[-1].rstrip() or ""
+ if len(self.out_lines) > max(self.import_index - 1, self._first_comment_index_end, 1) - 1:
+ last = self.out_lines[-1].rstrip()
+ else:
+ last = ""
if root.get(import_from, False):
root[import_from].update(imports)
@@ -749,11 +761,14 @@ class SortImports(object):
self.comments['straight'][module] = comments
comments = None
- if len(self.out_lines) > self.import_index:
+ if len(self.out_lines) > max(self.import_index, self._first_comment_index_end, 1) - 1:
last = self.out_lines and self.out_lines[-1].rstrip() or ""
while last.startswith("#") and not last.endswith('"""') and not last.endswith("'''"):
self.comments['above']['from'].setdefault(module, []).insert(0, self.out_lines.pop(-1))
- last = self.out_lines and self.out_lines[-1].rstrip() or ""
+ if len(self.out_lines) > max(self.import_index - 1, self._first_comment_index_end, 1) - 1:
+ last = self.out_lines[-1].rstrip()
+ else:
+ last = ""
self.imports[self.place_module(module)][import_type].add(module)
diff --git a/test_isort.py b/test_isort.py
index 756b1932..dd1c3f63 100644
--- a/test_isort.py
+++ b/test_isort.py
@@ -1441,3 +1441,12 @@ def test_other_file_encodings():
assert SortImports(file_path=tmp_fname).output == file_contents
finally:
shutil.rmtree(tmp_dir, ignore_errors=True)
+
+
+def test_comment_at_top_of_file():
+ """Test to ensure isort correctly handles top of file comments"""
+ test_input = ("# Comment one\n"
+ "from django import forms\n"
+ "# Comment two\n"
+ "from django.contrib.gis.geos import GEOSException\n")
+ assert SortImports(file_contents=test_input).output == test_input