diff options
author | staticdev <staticdev-support@protonmail.com> | 2022-12-10 12:25:07 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-10 12:25:07 +0000 |
commit | fc906c0ec4d1a5aebb530c2d1a7c3be1051d284c (patch) | |
tree | 48a74dc2dab277732700b291b5f03c11d43fa2ff | |
parent | 7ea25f07c271637e01d5a54f7aba79b1d5db61e1 (diff) | |
parent | 1398e40df157101c5e8308dda45ff6d7be348c65 (diff) | |
download | isort-fc906c0ec4d1a5aebb530c2d1a7c3be1051d284c.tar.gz |
Merge pull request #2017 from XuehaiPan/black-pyi
Fix `black` compatibility for `.pyi` type stub files
-rw-r--r-- | isort/output.py | 12 | ||||
-rw-r--r-- | tests/unit/profiles/test_black.py | 86 |
2 files changed, 90 insertions, 8 deletions
diff --git a/isort/output.py b/isort/output.py index c59be936..3cb3c08b 100644 --- a/isort/output.py +++ b/isort/output.py @@ -209,16 +209,20 @@ def sorted_imports( break if config.lines_after_imports != -1: - formatted_output[imports_tail:0] = [ - "" for line in range(config.lines_after_imports) - ] + lines_after_imports = config.lines_after_imports + if config.profile == "black" and extension == "pyi": # special case for black + lines_after_imports = 1 + formatted_output[imports_tail:0] = ["" for line in range(lines_after_imports)] elif extension != "pyi" and next_construct.startswith(STATEMENT_DECLARATIONS): formatted_output[imports_tail:0] = ["", ""] else: formatted_output[imports_tail:0] = [""] if config.lines_before_imports != -1: - formatted_output[:0] = ["" for line in range(config.lines_before_imports)] + lines_before_imports = config.lines_before_imports + if config.profile == "black" and extension == "pyi": # special case for black + lines_before_imports = 1 + formatted_output[:0] = ["" for line in range(lines_before_imports)] if parsed.place_imports: new_out_lines = [] diff --git a/tests/unit/profiles/test_black.py b/tests/unit/profiles/test_black.py index 05444b8c..a4bb2d52 100644 --- a/tests/unit/profiles/test_black.py +++ b/tests/unit/profiles/test_black.py @@ -19,20 +19,25 @@ def black_format(code: str, is_pyi: bool = False, line_length: int = 88) -> str: return code -def black_test(code: str, expected_output: str = ""): +def black_test(code: str, expected_output: str = "", *, is_pyi: bool = False, **config_kwargs): """Tests that the given code: - Behaves the same when formatted multiple times with isort. - Agrees with black formatting. - Matches the desired output or itself if none is provided. """ expected_output = expected_output or code + config_kwargs = { + "extension": "pyi" if is_pyi else None, + "profile": "black", + **config_kwargs, + } # output should stay consistent over multiple runs - output = isort.code(code, profile="black") - assert output == isort.code(code, profile="black") + output = isort.code(code, **config_kwargs) + assert output == isort.code(code, **config_kwargs) # output should agree with black - black_output = black_format(output) + black_output = black_format(output, is_pyi=is_pyi) assert output == black_output # output should match expected output @@ -369,3 +374,76 @@ if TYPE_CHECKING: DEFAULT_LINE_LENGTH = 88 """, ) + + +def test_black_pyi_file(): + """Test consistent code formatting between isort and black for `.pyi` files. + + black only allows no more than two consecutive blank lines in a `.pyi` file. + """ + + black_test( + """# comment + + +import math + +from typing import Sequence +import numpy as np + + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... + + +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... +""", + """# comment + + +import math +from typing import Sequence + +import numpy as np + + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... + + +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: + ... +""", + is_pyi=False, + lines_before_imports=2, + lines_after_imports=2, + ) + + black_test( + """# comment + + +import math + +from typing import Sequence +import numpy as np + + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +""", + """# comment + +import math +from typing import Sequence + +import numpy as np + +def add(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +def sub(a: np.ndarray, b: np.ndarray) -> np.ndarray: ... +""", + is_pyi=True, + lines_before_imports=2, # will be ignored + lines_after_imports=2, # will be ignored + ) |