summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTakeshi KOMIYA <i.tkomiya@gmail.com>2020-03-14 13:51:49 +0900
committerTakeshi KOMIYA <i.tkomiya@gmail.com>2020-03-14 13:51:49 +0900
commitf6d58bbefc4f57710d6b2875bdd6b1f0533d1f39 (patch)
tree61a553241898a73a114cd1dcca5682ee1aa6b350
parentf85b870ad59f39c8637160a4cd4d865ce1e1628e (diff)
downloadsphinx-git-f6d58bbefc4f57710d6b2875bdd6b1f0533d1f39.tar.gz
Fix #7304: pycode: Support operators (BinOp, BoolOp and UnaryOp)
-rw-r--r--sphinx/pycode/ast.py34
-rw-r--r--tests/test_pycode_ast.py22
2 files changed, 55 insertions, 1 deletions
diff --git a/sphinx/pycode/ast.py b/sphinx/pycode/ast.py
index 52617e3bc..4d8aa8955 100644
--- a/sphinx/pycode/ast.py
+++ b/sphinx/pycode/ast.py
@@ -9,7 +9,7 @@
"""
import sys
-from typing import List
+from typing import Dict, List, Type
if sys.version_info > (3, 8):
import ast
@@ -21,6 +21,29 @@ else:
import ast # type: ignore
+OPERATORS = {
+ ast.Add: "+",
+ ast.And: "and",
+ ast.BitAnd: "&",
+ ast.BitOr: "|",
+ ast.BitXor: "^",
+ ast.Div: "/",
+ ast.FloorDiv: "//",
+ ast.Invert: "~",
+ ast.LShift: "<<",
+ ast.MatMult: "@",
+ ast.Mult: "*",
+ ast.Mod: "%",
+ ast.Not: "not",
+ ast.Pow: "**",
+ ast.Or: "or",
+ ast.RShift: ">>",
+ ast.Sub: "-",
+ ast.UAdd: "+",
+ ast.USub: "-",
+} # type: Dict[Type[ast.AST], str]
+
+
def parse(code: str, mode: str = 'exec') -> "ast.AST":
"""Parse the *code* using built-in ast or typed_ast.
@@ -41,6 +64,8 @@ def unparse(node: ast.AST) -> str:
return None
elif isinstance(node, str):
return node
+ elif node.__class__ in OPERATORS:
+ return OPERATORS[node.__class__]
elif isinstance(node, ast.arg):
if node.annotation:
return "%s: %s" % (node.arg, unparse(node.annotation))
@@ -50,6 +75,11 @@ def unparse(node: ast.AST) -> str:
return unparse_arguments(node)
elif isinstance(node, ast.Attribute):
return "%s.%s" % (unparse(node.value), node.attr)
+ elif isinstance(node, ast.BinOp):
+ return " ".join(unparse(e) for e in [node.left, node.op, node.right])
+ elif isinstance(node, ast.BoolOp):
+ op = " %s " % unparse(node.op)
+ return op.join(unparse(e) for e in node.values)
elif isinstance(node, ast.Bytes):
return repr(node.s)
elif isinstance(node, ast.Call):
@@ -81,6 +111,8 @@ def unparse(node: ast.AST) -> str:
return repr(node.s)
elif isinstance(node, ast.Subscript):
return "%s[%s]" % (unparse(node.value), unparse(node.slice))
+ elif isinstance(node, ast.UnaryOp):
+ return "%s %s" % (unparse(node.op), unparse(node.operand))
elif isinstance(node, ast.Tuple):
return ", ".join(unparse(e) for e in node.elts)
elif sys.version_info > (3, 6) and isinstance(node, ast.Constant):
diff --git a/tests/test_pycode_ast.py b/tests/test_pycode_ast.py
index d195e5c6f..117feb8f7 100644
--- a/tests/test_pycode_ast.py
+++ b/tests/test_pycode_ast.py
@@ -16,21 +16,43 @@ from sphinx.pycode import ast
@pytest.mark.parametrize('source,expected', [
+ ("a + b", "a + b"), # Add
+ ("a and b", "a and b"), # And
("os.path", "os.path"), # Attribute
+ ("1 * 2", "1 * 2"), # BinOp
+ ("a & b", "a & b"), # BitAnd
+ ("a | b", "a | b"), # BitOr
+ ("a ^ b", "a ^ b"), # BitXor
+ ("a and b and c", "a and b and c"), # BoolOp
("b'bytes'", "b'bytes'"), # Bytes
("object()", "object()"), # Call
("1234", "1234"), # Constant
("{'key1': 'value1', 'key2': 'value2'}",
"{'key1': 'value1', 'key2': 'value2'}"), # Dict
+ ("a / b", "a / b"), # Div
("...", "..."), # Ellipsis
+ ("a // b", "a // b"), # FloorDiv
("Tuple[int, int]", "Tuple[int, int]"), # Index, Subscript
+ ("~ 1", "~ 1"), # Invert
("lambda x, y: x + y",
"lambda x, y: ..."), # Lambda
("[1, 2, 3]", "[1, 2, 3]"), # List
+ ("a << b", "a << b"), # LShift
+ ("a @ b", "a @ b"), # MatMult
+ ("a % b", "a % b"), # Mod
+ ("a * b", "a * b"), # Mult
("sys", "sys"), # Name, NameConstant
("1234", "1234"), # Num
+ ("not a", "not a"), # Not
+ ("a or b", "a or b"), # Or
+ ("a ** b", "a ** b"), # Pow
+ ("a >> b", "a >> b"), # RShift
("{1, 2, 3}", "{1, 2, 3}"), # Set
+ ("a - b", "a - b"), # Sub
("'str'", "'str'"), # Str
+ ("+ a", "+ a"), # UAdd
+ ("- 1", "- 1"), # UnaryOp
+ ("- a", "- a"), # USub
("(1, 2, 3)", "1, 2, 3"), # Tuple
])
def test_unparse(source, expected):