summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorstaticdev <staticdev-support@protonmail.com>2022-12-10 12:25:07 +0000
committerGitHub <noreply@github.com>2022-12-10 12:25:07 +0000
commitfc906c0ec4d1a5aebb530c2d1a7c3be1051d284c (patch)
tree48a74dc2dab277732700b291b5f03c11d43fa2ff
parent7ea25f07c271637e01d5a54f7aba79b1d5db61e1 (diff)
parent1398e40df157101c5e8308dda45ff6d7be348c65 (diff)
downloadisort-fc906c0ec4d1a5aebb530c2d1a7c3be1051d284c.tar.gz
Merge pull request #2017 from XuehaiPan/black-pyi
Fix `black` compatibility for `.pyi` type stub files
-rw-r--r--isort/output.py12
-rw-r--r--tests/unit/profiles/test_black.py86
2 files changed, 90 insertions, 8 deletions
diff --git a/isort/output.py b/isort/output.py
index c59be936..3cb3c08b 100644
--- a/isort/output.py
+++ b/isort/output.py
@@ -209,16 +209,20 @@ def sorted_imports(
break
if config.lines_after_imports != -1:
- formatted_output[imports_tail:0] = [
- "" for line in range(config.lines_after_imports)
- ]
+ lines_after_imports = config.lines_after_imports
+ if config.profile == "black" and extension == "pyi": # special case for black
+ lines_after_imports = 1
+ formatted_output[imports_tail:0] = ["" for line in range(lines_after_imports)]
elif extension != "pyi" and next_construct.startswith(STATEMENT_DECLARATIONS):
formatted_output[imports_tail:0] = ["", ""]
else:
formatted_output[imports_tail:0] = [""]
if config.lines_before_imports != -1:
- formatted_output[:0] = ["" for line in range(config.lines_before_imports)]
+ lines_before_imports = config.lines_before_imports
+ if config.profile == "black" and extension == "pyi": # special case for black
+ lines_before_imports = 1
+ formatted_output[:0] = ["" for line in range(lines_before_imports)]
if parsed.place_imports:
new_out_lines = []
diff --git a/tests/unit/profiles/test_black.py b/tests/unit/profiles/test_black.py
index 05444b8c..a4bb2d52 100644
--- a/tests/unit/profiles/test_black.py
+++ b/tests/unit/profiles/test_black.py
@@ -19,20 +19,25 @@ def black_format(code: str, is_pyi: bool = False, line_length: int = 88) -> str:
return code
-def black_test(code: str, expected_output: str = ""):
+def black_test(code: str, expected_output: str = "", *, is_pyi: bool = False, **config_kwargs):
"""Tests that the given code:
- Behaves the same when formatted multiple times with isort.
- Agrees with black formatting.
- Matches the desired output or itself if none is provided.
"""
expected_output = expected_output or code
+ config_kwargs = {
+ "extension": "pyi" if is_pyi else None,
+ "profile": "black",
+ **config_kwargs,
+ }
# output should stay consistent over multiple runs
- output = isort.code(code, profile="black")
- assert output == isort.code(code, profile="black")
+ output = isort.code(code, **config_kwargs)
+ assert output == isort.code(code, **config_kwargs)
# output should agree with black
- black_output = black_format(output)
+ black_output = black_format(output, is_pyi=is_pyi)
assert output == black_output
# output should match expected output
@@ -369,3 +374,76 @@ if TYPE_CHECKING:
DEFAULT_LINE_LENGTH = 88
""",
)
+
+
+def test_black_pyi_file():
+ """Test consistent code formatting between isort and black for `.pyi` files.
+
+ black only allows no more than two consecutive blank lines in a `.pyi` file.
+ """
+
+ black_test(
+ """# comment
+
+
+import math
+
+from typing import Sequence
+import numpy as np
+
+
+def add(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ ...
+
+
+def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ ...
+""",
+ """# comment
+
+
+import math
+from typing import Sequence
+
+import numpy as np
+
+
+def add(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ ...
+
+
+def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray:
+ ...
+""",
+ is_pyi=False,
+ lines_before_imports=2,
+ lines_after_imports=2,
+ )
+
+ black_test(
+ """# comment
+
+
+import math
+
+from typing import Sequence
+import numpy as np
+
+
+def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
+def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
+""",
+ """# comment
+
+import math
+from typing import Sequence
+
+import numpy as np
+
+def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
+def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ...
+""",
+ is_pyi=True,
+ lines_before_imports=2, # will be ignored
+ lines_after_imports=2, # will be ignored
+ )