summaryrefslogtreecommitdiff
path: root/Tools
diff options
context:
space:
mode:
Diffstat (limited to 'Tools')
-rw-r--r--Tools/BUILD.bazel1
-rw-r--r--Tools/ci-run.sh76
-rw-r--r--Tools/cython-mode.el303
-rw-r--r--Tools/dataclass_test_data/test_dataclasses.py4266
-rw-r--r--Tools/dump_github_issues.py142
-rw-r--r--Tools/gen_tests_for_posix_pxds.py41
-rw-r--r--Tools/make_dataclass_tests.py443
-rw-r--r--Tools/rules.bzl4
8 files changed, 4956 insertions, 320 deletions
diff --git a/Tools/BUILD.bazel b/Tools/BUILD.bazel
index e69de29bb..8b1378917 100644
--- a/Tools/BUILD.bazel
+++ b/Tools/BUILD.bazel
@@ -0,0 +1 @@
+
diff --git a/Tools/ci-run.sh b/Tools/ci-run.sh
index 09e9ae318..d42365fd6 100644
--- a/Tools/ci-run.sh
+++ b/Tools/ci-run.sh
@@ -1,24 +1,24 @@
#!/usr/bin/bash
+set -x
+
GCC_VERSION=${GCC_VERSION:=8}
# Set up compilers
if [[ $TEST_CODE_STYLE == "1" ]]; then
- echo "Skipping compiler setup"
+ echo "Skipping compiler setup: Code style run"
elif [[ $OSTYPE == "linux-gnu"* ]]; then
echo "Setting up linux compiler"
echo "Installing requirements [apt]"
sudo apt-add-repository -y "ppa:ubuntu-toolchain-r/test"
sudo apt update -y -q
- sudo apt install -y -q ccache gdb python-dbg python3-dbg gcc-$GCC_VERSION || exit 1
+ sudo apt install -y -q gdb python3-dbg gcc-$GCC_VERSION || exit 1
ALTERNATIVE_ARGS=""
if [[ $BACKEND == *"cpp"* ]]; then
sudo apt install -y -q g++-$GCC_VERSION || exit 1
ALTERNATIVE_ARGS="--slave /usr/bin/g++ g++ /usr/bin/g++-$GCC_VERSION"
fi
- sudo /usr/sbin/update-ccache-symlinks
- echo "/usr/lib/ccache" >> $GITHUB_PATH # export ccache to path
sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-$GCC_VERSION 60 $ALTERNATIVE_ARGS
@@ -32,7 +32,27 @@ elif [[ $OSTYPE == "darwin"* ]]; then
export CC="clang -Wno-deprecated-declarations"
export CXX="clang++ -stdlib=libc++ -Wno-deprecated-declarations"
else
- echo "No setup specified for $OSTYPE"
+ echo "Skipping compiler setup: No setup specified for $OSTYPE"
+fi
+
+if [[ $COVERAGE == "1" ]]; then
+ echo "Skip setting up compilation caches"
+elif [[ $OSTYPE == "msys" ]]; then
+ echo "Set up sccache"
+ echo "TODO: Make a soft symlink to sccache"
+else
+ echo "Set up ccache"
+
+ echo "/usr/lib/ccache" >> $GITHUB_PATH # export ccache to path
+
+ echo "Make a soft symlinks to ccache"
+ cp ccache /usr/local/bin/
+ ln -s ccache /usr/local/bin/gcc
+ ln -s ccache /usr/local/bin/g++
+ ln -s ccache /usr/local/bin/cc
+ ln -s ccache /usr/local/bin/c++
+ ln -s ccache /usr/local/bin/clang
+ ln -s ccache /usr/local/bin/clang++
fi
# Set up miniconda
@@ -50,14 +70,17 @@ echo "===================="
echo "|VERSIONS INSTALLED|"
echo "===================="
echo "Python $PYTHON_SYS_VERSION"
+
if [[ $CC ]]; then
which ${CC%% *}
${CC%% *} --version
fi
+
if [[ $CXX ]]; then
which ${CXX%% *}
${CXX%% *} --version
fi
+
echo "===================="
# Install python requirements
@@ -68,12 +91,17 @@ if [[ $PYTHON_VERSION == "2.7"* ]]; then
elif [[ $PYTHON_VERSION == "3."[45]* ]]; then
python -m pip install wheel || exit 1
python -m pip install -r test-requirements-34.txt || exit 1
+elif [[ $PYTHON_VERSION == "pypy-2.7" ]]; then
+ pip install wheel || exit 1
+ pip install -r test-requirements-pypy27.txt || exit 1
+elif [[ $PYTHON_VERSION == "3.1"[2-9]* ]]; then
+ python -m pip install wheel || exit 1
+ python -m pip install -r test-requirements-312.txt || exit 1
else
- python -m pip install -U pip setuptools wheel || exit 1
+ python -m pip install -U pip "setuptools<60" wheel || exit 1
if [[ $PYTHON_VERSION != *"-dev" || $COVERAGE == "1" ]]; then
python -m pip install -r test-requirements.txt || exit 1
-
if [[ $PYTHON_VERSION != "pypy"* && $PYTHON_VERSION != "3."[1]* ]]; then
python -m pip install -r test-requirements-cpython.txt || exit 1
fi
@@ -108,7 +136,16 @@ export PATH="/usr/lib/ccache:$PATH"
# Most modern compilers allow the last conflicting option
# to override the previous ones, so '-O0 -O3' == '-O3'
# This is true for the latest msvc, gcc and clang
-CFLAGS="-O0 -ggdb -Wall -Wextra"
+if [[ $OSTYPE == "msys" ]]; then # for MSVC cl
+ # /wd disables warnings
+ # 4711 warns that function `x` was selected for automatic inline expansion
+ # 4127 warns that a conditional expression is constant, should be fixed here https://github.com/cython/cython/pull/4317
+ # (off by default) 5045 warns that the compiler will insert Spectre mitigations for memory load if the /Qspectre switch is specified
+ # (off by default) 4820 warns about the code in Python\3.9.6\x64\include ...
+ CFLAGS="-Od /Z7 /MP /W4 /wd4711 /wd4127 /wd5045 /wd4820"
+else
+ CFLAGS="-O0 -ggdb -Wall -Wextra"
+fi
# Trying to cover debug assertions in the CI without adding
# extra jobs. Therefore, odd-numbered minor versions of Python
# running C++ jobs get NDEBUG undefined, and even-numbered
@@ -123,6 +160,9 @@ fi
if [[ $NO_CYTHON_COMPILE != "1" && $PYTHON_VERSION != "pypy"* ]]; then
BUILD_CFLAGS="$CFLAGS -O2"
+ if [[ $CYTHON_COMPILE_ALL == "1" && $OSTYPE != "msys" ]]; then
+ BUILD_CFLAGS="$CFLAGS -O3 -g0 -mtune=generic" # make wheel sizes comparable to standard wheel build
+ fi
if [[ $PYTHON_SYS_VERSION == "2"* ]]; then
BUILD_CFLAGS="$BUILD_CFLAGS -fno-strict-aliasing"
fi
@@ -134,24 +174,30 @@ if [[ $NO_CYTHON_COMPILE != "1" && $PYTHON_VERSION != "pypy"* ]]; then
if [[ $CYTHON_COMPILE_ALL == "1" ]]; then
SETUP_ARGS="$SETUP_ARGS --cython-compile-all"
fi
- #SETUP_ARGS="$SETUP_ARGS
- # $(python -c 'import sys; print("-j5" if sys.version_info >= (3,5) else "")')"
+ # It looks like parallel build may be causing occasional link failures on Windows
+ # "with exit code 1158". DW isn't completely sure of this, but has disabled it in
+ # the hope it helps
+ SETUP_ARGS="$SETUP_ARGS
+ $(python -c 'import sys; print("-j5" if sys.version_info >= (3,5) and not sys.platform.startswith("win") else "")')"
CFLAGS=$BUILD_CFLAGS \
python setup.py build_ext -i $SETUP_ARGS || exit 1
# COVERAGE can be either "" (empty or not set) or "1" (when we set it)
# STACKLESS can be either "" (empty or not set) or "true" (when we set it)
- # CYTHON_COMPILE_ALL can be either "" (empty or not set) or "1" (when we set it)
if [[ $COVERAGE != "1" && $STACKLESS != "true" && $BACKEND != *"cpp"* &&
- $CYTHON_COMPILE_ALL != "1" && $LIMITED_API == "" && $EXTRA_CFLAGS == "" ]]; then
+ $LIMITED_API == "" && $EXTRA_CFLAGS == "" ]]; then
python setup.py bdist_wheel || exit 1
+ ls -l dist/ || true
fi
+
+ echo "Extension modules created during the build:"
+ find Cython -name "*.so" -ls | sort -k11
fi
if [[ $TEST_CODE_STYLE == "1" ]]; then
- make -C docs html || echo "FIXME: docs build failed!"
-elif [[ $PYTHON_VERSION != "pypy"* ]]; then
+ make -C docs html || exit 1
+elif [[ $PYTHON_VERSION != "pypy"* && $OSTYPE != "msys" ]]; then
# Run the debugger tests in python-dbg if available
# (but don't fail, because they currently do fail)
PYTHON_DBG=$(python -c 'import sys; print("%d.%d" % sys.version_info[:2])')
@@ -181,6 +227,6 @@ python runtests.py \
EXIT_CODE=$?
-ccache -s 2>/dev/null || true
+ccache -s -v -v 2>/dev/null || true
exit $EXIT_CODE
diff --git a/Tools/cython-mode.el b/Tools/cython-mode.el
deleted file mode 100644
index e4be99f5b..000000000
--- a/Tools/cython-mode.el
+++ /dev/null
@@ -1,303 +0,0 @@
-;;; cython-mode.el --- Major mode for editing Cython files
-
-;; License: Apache-2.0
-
-;;; Commentary:
-
-;; This should work with python-mode.el as well as either the new
-;; python.el or the old.
-
-;;; Code:
-
-;; Load python-mode if available, otherwise use builtin emacs python package
-(when (not (require 'python-mode nil t))
- (require 'python))
-(eval-when-compile (require 'rx))
-
-;;;###autoload
-(add-to-list 'auto-mode-alist '("\\.pyx\\'" . cython-mode))
-;;;###autoload
-(add-to-list 'auto-mode-alist '("\\.pxd\\'" . cython-mode))
-;;;###autoload
-(add-to-list 'auto-mode-alist '("\\.pxi\\'" . cython-mode))
-
-
-(defvar cython-buffer nil
- "Variable pointing to the cython buffer which was compiled.")
-
-(defun cython-compile ()
- "Compile the file via Cython."
- (interactive)
- (let ((cy-buffer (current-buffer)))
- (with-current-buffer
- (compile compile-command)
- (set (make-local-variable 'cython-buffer) cy-buffer)
- (add-to-list (make-local-variable 'compilation-finish-functions)
- 'cython-compilation-finish))))
-
-(defun cython-compilation-finish (buffer how)
- "Called when Cython compilation finishes."
- ;; XXX could annotate source here
- )
-
-(defvar cython-mode-map
- (let ((map (make-sparse-keymap)))
- ;; Will inherit from `python-mode-map' thanks to define-derived-mode.
- (define-key map "\C-c\C-c" 'cython-compile)
- map)
- "Keymap used in `cython-mode'.")
-
-(defvar cython-font-lock-keywords
- `(;; ctypedef statement: "ctypedef (...type... alias)?"
- (,(rx
- ;; keyword itself
- symbol-start (group "ctypedef")
- ;; type specifier: at least 1 non-identifier symbol + 1 identifier
- ;; symbol and anything but a comment-starter after that.
- (opt (regexp "[^a-zA-Z0-9_\n]+[a-zA-Z0-9_][^#\n]*")
- ;; type alias: an identifier
- symbol-start (group (regexp "[a-zA-Z_]+[a-zA-Z0-9_]*"))
- ;; space-or-comments till the end of the line
- (* space) (opt "#" (* nonl)) line-end))
- (1 font-lock-keyword-face)
- (2 font-lock-type-face nil 'noerror))
- ;; new keywords in Cython language
- (,(rx symbol-start
- (or "by" "cdef" "cimport" "cpdef"
- "extern" "gil" "include" "nogil" "property" "public"
- "readonly" "DEF" "IF" "ELIF" "ELSE"
- "new" "del" "cppclass" "namespace" "const"
- "__stdcall" "__cdecl" "__fastcall" "inline" "api")
- symbol-end)
- . font-lock-keyword-face)
- ;; Question mark won't match at a symbol-end, so 'except?' must be
- ;; special-cased. It's simpler to handle it separately than weaving it
- ;; into the lengthy list of other keywords.
- (,(rx symbol-start "except?") . font-lock-keyword-face)
- ;; C and Python types (highlight as builtins)
- (,(rx symbol-start
- (or
- "object" "dict" "list"
- ;; basic c type names
- "void" "char" "int" "float" "double" "bint"
- ;; longness/signed/constness
- "signed" "unsigned" "long" "short"
- ;; special basic c types
- "size_t" "Py_ssize_t" "Py_UNICODE" "Py_UCS4" "ssize_t" "ptrdiff_t")
- symbol-end)
- . font-lock-builtin-face)
- (,(rx symbol-start "NULL" symbol-end)
- . font-lock-constant-face)
- ;; cdef is used for more than functions, so simply highlighting the next
- ;; word is problematic. struct, enum and property work though.
- (,(rx symbol-start
- (group (or "struct" "enum" "union"
- (seq "ctypedef" (+ space "fused"))))
- (+ space) (group (regexp "[a-zA-Z_]+[a-zA-Z0-9_]*")))
- (1 font-lock-keyword-face prepend) (2 font-lock-type-face))
- ("\\_<property[ \t]+\\([a-zA-Z_]+[a-zA-Z0-9_]*\\)"
- 1 font-lock-function-name-face))
- "Additional font lock keywords for Cython mode.")
-
-;;;###autoload
-(defgroup cython nil "Major mode for editing and compiling Cython files"
- :group 'languages
- :prefix "cython-"
- :link '(url-link :tag "Homepage" "http://cython.org"))
-
-;;;###autoload
-(defcustom cython-default-compile-format "cython -a %s"
- "Format for the default command to compile a Cython file.
-It will be passed to `format' with `buffer-file-name' as the only other argument."
- :group 'cython
- :type 'string)
-
-;; Some functions defined differently in the different python modes
-(defun cython-comment-line-p ()
- "Return non-nil if current line is a comment."
- (save-excursion
- (back-to-indentation)
- (eq ?# (char-after (point)))))
-
-(defun cython-in-string/comment ()
- "Return non-nil if point is in a comment or string."
- (nth 8 (syntax-ppss)))
-
-(defalias 'cython-beginning-of-statement
- (cond
- ;; python-mode.el
- ((fboundp 'py-beginning-of-statement)
- 'py-beginning-of-statement)
- ;; old python.el
- ((fboundp 'python-beginning-of-statement)
- 'python-beginning-of-statement)
- ;; new python.el
- ((fboundp 'python-nav-beginning-of-statement)
- 'python-nav-beginning-of-statement)
- (t (error "Couldn't find implementation for `cython-beginning-of-statement'"))))
-
-(defalias 'cython-beginning-of-block
- (cond
- ;; python-mode.el
- ((fboundp 'py-beginning-of-block)
- 'py-beginning-of-block)
- ;; old python.el
- ((fboundp 'python-beginning-of-block)
- 'python-beginning-of-block)
- ;; new python.el
- ((fboundp 'python-nav-beginning-of-block)
- 'python-nav-beginning-of-block)
- (t (error "Couldn't find implementation for `cython-beginning-of-block'"))))
-
-(defalias 'cython-end-of-statement
- (cond
- ;; python-mode.el
- ((fboundp 'py-end-of-statement)
- 'py-end-of-statement)
- ;; old python.el
- ((fboundp 'python-end-of-statement)
- 'python-end-of-statement)
- ;; new python.el
- ((fboundp 'python-nav-end-of-statement)
- 'python-nav-end-of-statement)
- (t (error "Couldn't find implementation for `cython-end-of-statement'"))))
-
-(defun cython-open-block-statement-p (&optional bos)
- "Return non-nil if statement at point opens a Cython block.
-BOS non-nil means point is known to be at beginning of statement."
- (save-excursion
- (unless bos (cython-beginning-of-statement))
- (looking-at (rx (and (or "if" "else" "elif" "while" "for" "def" "cdef" "cpdef"
- "class" "try" "except" "finally" "with"
- "EXAMPLES:" "TESTS:" "INPUT:" "OUTPUT:")
- symbol-end)))))
-
-(defun cython-beginning-of-defun ()
- "`beginning-of-defun-function' for Cython.
-Finds beginning of innermost nested class or method definition.
-Returns the name of the definition found at the end, or nil if
-reached start of buffer."
- (let ((ci (current-indentation))
- (def-re (rx line-start (0+ space) (or "def" "cdef" "cpdef" "class") (1+ space)
- (group (1+ (or word (syntax symbol))))))
- found lep) ;; def-line
- (if (cython-comment-line-p)
- (setq ci most-positive-fixnum))
- (while (and (not (bobp)) (not found))
- ;; Treat bol at beginning of function as outside function so
- ;; that successive C-M-a makes progress backwards.
- ;;(setq def-line (looking-at def-re))
- (unless (bolp) (end-of-line))
- (setq lep (line-end-position))
- (if (and (re-search-backward def-re nil 'move)
- ;; Must be less indented or matching top level, or
- ;; equally indented if we started on a definition line.
- (let ((in (current-indentation)))
- (or (and (zerop ci) (zerop in))
- (= lep (line-end-position)) ; on initial line
- ;; Not sure why it was like this -- fails in case of
- ;; last internal function followed by first
- ;; non-def statement of the main body.
- ;;(and def-line (= in ci))
- (= in ci)
- (< in ci)))
- (not (cython-in-string/comment)))
- (setq found t)))))
-
-(defun cython-end-of-defun ()
- "`end-of-defun-function' for Cython.
-Finds end of innermost nested class or method definition."
- (let ((orig (point))
- (pattern (rx line-start (0+ space) (or "def" "cdef" "cpdef" "class") space)))
- ;; Go to start of current block and check whether it's at top
- ;; level. If it is, and not a block start, look forward for
- ;; definition statement.
- (when (cython-comment-line-p)
- (end-of-line)
- (forward-comment most-positive-fixnum))
- (when (not (cython-open-block-statement-p))
- (cython-beginning-of-block))
- (if (zerop (current-indentation))
- (unless (cython-open-block-statement-p)
- (while (and (re-search-forward pattern nil 'move)
- (cython-in-string/comment))) ; just loop
- (unless (eobp)
- (beginning-of-line)))
- ;; Don't move before top-level statement that would end defun.
- (end-of-line)
- (beginning-of-defun))
- ;; If we got to the start of buffer, look forward for
- ;; definition statement.
- (when (and (bobp) (not (looking-at (rx (or "def" "cdef" "cpdef" "class")))))
- (while (and (not (eobp))
- (re-search-forward pattern nil 'move)
- (cython-in-string/comment)))) ; just loop
- ;; We're at a definition statement (or end-of-buffer).
- ;; This is where we should have started when called from end-of-defun
- (unless (eobp)
- (let ((block-indentation (current-indentation)))
- (python-nav-end-of-statement)
- (while (and (forward-line 1)
- (not (eobp))
- (or (and (> (current-indentation) block-indentation)
- (or (cython-end-of-statement) t))
- ;; comment or empty line
- (looking-at (rx (0+ space) (or eol "#"))))))
- (forward-comment -1))
- ;; Count trailing space in defun (but not trailing comments).
- (skip-syntax-forward " >")
- (unless (eobp) ; e.g. missing final newline
- (beginning-of-line)))
- ;; Catch pathological cases like this, where the beginning-of-defun
- ;; skips to a definition we're not in:
- ;; if ...:
- ;; ...
- ;; else:
- ;; ... # point here
- ;; ...
- ;; def ...
- (if (< (point) orig)
- (goto-char (point-max)))))
-
-(defun cython-current-defun ()
- "`add-log-current-defun-function' for Cython."
- (save-excursion
- ;; Move up the tree of nested `class' and `def' blocks until we
- ;; get to zero indentation, accumulating the defined names.
- (let ((start t)
- accum)
- (while (or start (> (current-indentation) 0))
- (setq start nil)
- (cython-beginning-of-block)
- (end-of-line)
- (beginning-of-defun)
- (if (looking-at (rx (0+ space) (or "def" "cdef" "cpdef" "class") (1+ space)
- (group (1+ (or word (syntax symbol))))))
- (push (match-string 1) accum)))
- (if accum (mapconcat 'identity accum ".")))))
-
-;;;###autoload
-(define-derived-mode cython-mode python-mode "Cython"
- "Major mode for Cython development, derived from Python mode.
-
-\\{cython-mode-map}"
- (font-lock-add-keywords nil cython-font-lock-keywords)
- (set (make-local-variable 'outline-regexp)
- (rx (* space) (or "class" "def" "cdef" "cpdef" "elif" "else" "except" "finally"
- "for" "if" "try" "while" "with")
- symbol-end))
- (set (make-local-variable 'beginning-of-defun-function)
- #'cython-beginning-of-defun)
- (set (make-local-variable 'end-of-defun-function)
- #'cython-end-of-defun)
- (set (make-local-variable 'compile-command)
- (format cython-default-compile-format (shell-quote-argument (or buffer-file-name ""))))
- (set (make-local-variable 'add-log-current-defun-function)
- #'cython-current-defun)
- (add-hook 'which-func-functions #'cython-current-defun nil t)
- (add-to-list (make-local-variable 'compilation-finish-functions)
- 'cython-compilation-finish))
-
-(provide 'cython-mode)
-
-;;; cython-mode.el ends here
diff --git a/Tools/dataclass_test_data/test_dataclasses.py b/Tools/dataclass_test_data/test_dataclasses.py
new file mode 100644
index 000000000..e2eab6957
--- /dev/null
+++ b/Tools/dataclass_test_data/test_dataclasses.py
@@ -0,0 +1,4266 @@
+# Deliberately use "from dataclasses import *". Every name in __all__
+# is tested, so they all must be present. This is a way to catch
+# missing ones.
+
+from dataclasses import *
+
+import abc
+import pickle
+import inspect
+import builtins
+import types
+import weakref
+import unittest
+from unittest.mock import Mock
+from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
+from typing import get_type_hints
+from collections import deque, OrderedDict, namedtuple
+from functools import total_ordering
+
+import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
+import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
+
+# Just any custom exception we can catch.
+class CustomError(Exception): pass
+
+class TestCase(unittest.TestCase):
+ def test_no_fields(self):
+ @dataclass
+ class C:
+ pass
+
+ o = C()
+ self.assertEqual(len(fields(C)), 0)
+
+ def test_no_fields_but_member_variable(self):
+ @dataclass
+ class C:
+ i = 0
+
+ o = C()
+ self.assertEqual(len(fields(C)), 0)
+
+ def test_one_field_no_default(self):
+ @dataclass
+ class C:
+ x: int
+
+ o = C(42)
+ self.assertEqual(o.x, 42)
+
+ def test_field_default_default_factory_error(self):
+ msg = "cannot specify both default and default_factory"
+ with self.assertRaisesRegex(ValueError, msg):
+ @dataclass
+ class C:
+ x: int = field(default=1, default_factory=int)
+
+ def test_field_repr(self):
+ int_field = field(default=1, init=True, repr=False)
+ int_field.name = "id"
+ repr_output = repr(int_field)
+ expected_output = "Field(name='id',type=None," \
+ f"default=1,default_factory={MISSING!r}," \
+ "init=True,repr=False,hash=None," \
+ "compare=True,metadata=mappingproxy({})," \
+ f"kw_only={MISSING!r}," \
+ "_field_type=None)"
+
+ self.assertEqual(repr_output, expected_output)
+
+ def test_named_init_params(self):
+ @dataclass
+ class C:
+ x: int
+
+ o = C(x=32)
+ self.assertEqual(o.x, 32)
+
+ def test_two_fields_one_default(self):
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+
+ o = C(3)
+ self.assertEqual((o.x, o.y), (3, 0))
+
+ # Non-defaults following defaults.
+ with self.assertRaisesRegex(TypeError,
+ "non-default argument 'y' follows "
+ "default argument"):
+ @dataclass
+ class C:
+ x: int = 0
+ y: int
+
+ # A derived class adds a non-default field after a default one.
+ with self.assertRaisesRegex(TypeError,
+ "non-default argument 'y' follows "
+ "default argument"):
+ @dataclass
+ class B:
+ x: int = 0
+
+ @dataclass
+ class C(B):
+ y: int
+
+ # Override a base class field and add a default to
+ # a field which didn't use to have a default.
+ with self.assertRaisesRegex(TypeError,
+ "non-default argument 'y' follows "
+ "default argument"):
+ @dataclass
+ class B:
+ x: int
+ y: int
+
+ @dataclass
+ class C(B):
+ x: int = 0
+
+ def test_overwrite_hash(self):
+ # Test that declaring this class isn't an error. It should
+ # use the user-provided __hash__.
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __hash__(self):
+ return 301
+ self.assertEqual(hash(C(100)), 301)
+
+ # Test that declaring this class isn't an error. It should
+ # use the generated __hash__.
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return False
+ self.assertEqual(hash(C(100)), hash((100,)))
+
+ # But this one should generate an exception, because with
+ # unsafe_hash=True, it's an error to have a __hash__ defined.
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __hash__'):
+ @dataclass(unsafe_hash=True)
+ class C:
+ def __hash__(self):
+ pass
+
+ # Creating this class should not generate an exception,
+ # because even though __hash__ exists before @dataclass is
+ # called, (due to __eq__ being defined), since it's None
+ # that's okay.
+ @dataclass(unsafe_hash=True)
+ class C:
+ x: int
+ def __eq__(self):
+ pass
+ # The generated hash function works as we'd expect.
+ self.assertEqual(hash(C(10)), hash((10,)))
+
+ # Creating this class should generate an exception, because
+ # __hash__ exists and is not None, which it would be if it
+ # had been auto-generated due to __eq__ being defined.
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __hash__'):
+ @dataclass(unsafe_hash=True)
+ class C:
+ x: int
+ def __eq__(self):
+ pass
+ def __hash__(self):
+ pass
+
+ def test_overwrite_fields_in_derived_class(self):
+ # Note that x from C1 replaces x in Base, but the order remains
+ # the same as defined in Base.
+ @dataclass
+ class Base:
+ x: Any = 15.0
+ y: int = 0
+
+ @dataclass
+ class C1(Base):
+ z: int = 10
+ x: int = 15
+
+ o = Base()
+ self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
+
+ o = C1()
+ self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
+
+ o = C1(x=5)
+ self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
+
+ def test_field_named_self(self):
+ @dataclass
+ class C:
+ self: str
+ c=C('foo')
+ self.assertEqual(c.self, 'foo')
+
+ # Make sure the first parameter is not named 'self'.
+ sig = inspect.signature(C.__init__)
+ first = next(iter(sig.parameters))
+ self.assertNotEqual('self', first)
+
+ # But we do use 'self' if no field named self.
+ @dataclass
+ class C:
+ selfx: str
+
+ # Make sure the first parameter is named 'self'.
+ sig = inspect.signature(C.__init__)
+ first = next(iter(sig.parameters))
+ self.assertEqual('self', first)
+
+ def test_field_named_object(self):
+ @dataclass
+ class C:
+ object: str
+ c = C('foo')
+ self.assertEqual(c.object, 'foo')
+
+ def test_field_named_object_frozen(self):
+ @dataclass(frozen=True)
+ class C:
+ object: str
+ c = C('foo')
+ self.assertEqual(c.object, 'foo')
+
+ def test_field_named_like_builtin(self):
+ # Attribute names can shadow built-in names
+ # since code generation is used.
+ # Ensure that this is not happening.
+ exclusions = {'None', 'True', 'False'}
+ builtins_names = sorted(
+ b for b in builtins.__dict__.keys()
+ if not b.startswith('__') and b not in exclusions
+ )
+ attributes = [(name, str) for name in builtins_names]
+ C = make_dataclass('C', attributes)
+
+ c = C(*[name for name in builtins_names])
+
+ for name in builtins_names:
+ self.assertEqual(getattr(c, name), name)
+
+ def test_field_named_like_builtin_frozen(self):
+ # Attribute names can shadow built-in names
+ # since code generation is used.
+ # Ensure that this is not happening
+ # for frozen data classes.
+ exclusions = {'None', 'True', 'False'}
+ builtins_names = sorted(
+ b for b in builtins.__dict__.keys()
+ if not b.startswith('__') and b not in exclusions
+ )
+ attributes = [(name, str) for name in builtins_names]
+ C = make_dataclass('C', attributes, frozen=True)
+
+ c = C(*[name for name in builtins_names])
+
+ for name in builtins_names:
+ self.assertEqual(getattr(c, name), name)
+
+ def test_0_field_compare(self):
+ # Ensure that order=False is the default.
+ @dataclass
+ class C0:
+ pass
+
+ @dataclass(order=False)
+ class C1:
+ pass
+
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(), cls())
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
+ fn(cls(), cls())
+
+ @dataclass(order=True)
+ class C:
+ pass
+ self.assertLessEqual(C(), C())
+ self.assertGreaterEqual(C(), C())
+
+ def test_1_field_compare(self):
+ # Ensure that order=False is the default.
+ @dataclass
+ class C0:
+ x: int
+
+ @dataclass(order=False)
+ class C1:
+ x: int
+
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(1), cls(1))
+ self.assertNotEqual(cls(0), cls(1))
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
+ fn(cls(0), cls(0))
+
+ @dataclass(order=True)
+ class C:
+ x: int
+ self.assertLess(C(0), C(1))
+ self.assertLessEqual(C(0), C(1))
+ self.assertLessEqual(C(1), C(1))
+ self.assertGreater(C(1), C(0))
+ self.assertGreaterEqual(C(1), C(0))
+ self.assertGreaterEqual(C(1), C(1))
+
+ def test_simple_compare(self):
+ # Ensure that order=False is the default.
+ @dataclass
+ class C0:
+ x: int
+ y: int
+
+ @dataclass(order=False)
+ class C1:
+ x: int
+ y: int
+
+ for cls in [C0, C1]:
+ with self.subTest(cls=cls):
+ self.assertEqual(cls(0, 0), cls(0, 0))
+ self.assertEqual(cls(1, 2), cls(1, 2))
+ self.assertNotEqual(cls(1, 0), cls(0, 0))
+ self.assertNotEqual(cls(1, 0), cls(1, 1))
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
+ fn(cls(0, 0), cls(0, 0))
+
+ @dataclass(order=True)
+ class C:
+ x: int
+ y: int
+
+ for idx, fn in enumerate([lambda a, b: a == b,
+ lambda a, b: a <= b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ self.assertTrue(fn(C(0, 0), C(0, 0)))
+
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a != b]):
+ with self.subTest(idx=idx):
+ self.assertTrue(fn(C(0, 0), C(0, 1)))
+ self.assertTrue(fn(C(0, 1), C(1, 0)))
+ self.assertTrue(fn(C(1, 0), C(1, 1)))
+
+ for idx, fn in enumerate([lambda a, b: a > b,
+ lambda a, b: a >= b,
+ lambda a, b: a != b]):
+ with self.subTest(idx=idx):
+ self.assertTrue(fn(C(0, 1), C(0, 0)))
+ self.assertTrue(fn(C(1, 0), C(0, 1)))
+ self.assertTrue(fn(C(1, 1), C(1, 0)))
+
+ def test_compare_subclasses(self):
+ # Comparisons fail for subclasses, even if no fields
+ # are added.
+ @dataclass
+ class B:
+ i: int
+
+ @dataclass
+ class C(B):
+ pass
+
+ for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
+ (lambda a, b: a != b, True)]):
+ with self.subTest(idx=idx):
+ self.assertEqual(fn(B(0), C(0)), expected)
+
+ for idx, fn in enumerate([lambda a, b: a < b,
+ lambda a, b: a <= b,
+ lambda a, b: a > b,
+ lambda a, b: a >= b]):
+ with self.subTest(idx=idx):
+ with self.assertRaisesRegex(TypeError,
+ "not supported between instances of 'B' and 'C'"):
+ fn(B(0), C(0))
+
+ def test_eq_order(self):
+ # Test combining eq and order.
+ for (eq, order, result ) in [
+ (False, False, 'neither'),
+ (False, True, 'exception'),
+ (True, False, 'eq_only'),
+ (True, True, 'both'),
+ ]:
+ with self.subTest(eq=eq, order=order):
+ if result == 'exception':
+ with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
+ @dataclass(eq=eq, order=order)
+ class C:
+ pass
+ else:
+ @dataclass(eq=eq, order=order)
+ class C:
+ pass
+
+ if result == 'neither':
+ self.assertNotIn('__eq__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ elif result == 'both':
+ self.assertIn('__eq__', C.__dict__)
+ self.assertIn('__lt__', C.__dict__)
+ self.assertIn('__le__', C.__dict__)
+ self.assertIn('__gt__', C.__dict__)
+ self.assertIn('__ge__', C.__dict__)
+ elif result == 'eq_only':
+ self.assertIn('__eq__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ else:
+ assert False, f'unknown result {result!r}'
+
+ def test_field_no_default(self):
+ @dataclass
+ class C:
+ x: int = field()
+
+ self.assertEqual(C(5).x, 5)
+
+ with self.assertRaisesRegex(TypeError,
+ r"__init__\(\) missing 1 required "
+ "positional argument: 'x'"):
+ C()
+
+ def test_field_default(self):
+ default = object()
+ @dataclass
+ class C:
+ x: object = field(default=default)
+
+ self.assertIs(C.x, default)
+ c = C(10)
+ self.assertEqual(c.x, 10)
+
+ # If we delete the instance attribute, we should then see the
+ # class attribute.
+ del c.x
+ self.assertIs(c.x, default)
+
+ self.assertIs(C().x, default)
+
+ def test_not_in_repr(self):
+ @dataclass
+ class C:
+ x: int = field(repr=False)
+ with self.assertRaises(TypeError):
+ C()
+ c = C(10)
+ self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
+
+ @dataclass
+ class C:
+ x: int = field(repr=False)
+ y: int
+ c = C(10, 20)
+ self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
+
+ def test_not_in_compare(self):
+ @dataclass
+ class C:
+ x: int = 0
+ y: int = field(compare=False, default=4)
+
+ self.assertEqual(C(), C(0, 20))
+ self.assertEqual(C(1, 10), C(1, 20))
+ self.assertNotEqual(C(3), C(4, 10))
+ self.assertNotEqual(C(3, 10), C(4, 10))
+
+ def test_no_unhashable_default(self):
+ # See bpo-44674.
+ class Unhashable:
+ __hash__ = None
+
+ unhashable_re = 'mutable default .* for field a is not allowed'
+ with self.assertRaisesRegex(ValueError, unhashable_re):
+ @dataclass
+ class A:
+ a: dict = {}
+
+ with self.assertRaisesRegex(ValueError, unhashable_re):
+ @dataclass
+ class A:
+ a: Any = Unhashable()
+
+ # Make sure that the machinery looking for hashability is using the
+ # class's __hash__, not the instance's __hash__.
+ with self.assertRaisesRegex(ValueError, unhashable_re):
+ unhashable = Unhashable()
+ # This shouldn't make the variable hashable.
+ unhashable.__hash__ = lambda: 0
+ @dataclass
+ class A:
+ a: Any = unhashable
+
+ def test_hash_field_rules(self):
+ # Test all 6 cases of:
+ # hash=True/False/None
+ # compare=True/False
+ for (hash_, compare, result ) in [
+ (True, False, 'field' ),
+ (True, True, 'field' ),
+ (False, False, 'absent'),
+ (False, True, 'absent'),
+ (None, False, 'absent'),
+ (None, True, 'field' ),
+ ]:
+ with self.subTest(hash=hash_, compare=compare):
+ @dataclass(unsafe_hash=True)
+ class C:
+ x: int = field(compare=compare, hash=hash_, default=5)
+
+ if result == 'field':
+ # __hash__ contains the field.
+ self.assertEqual(hash(C(5)), hash((5,)))
+ elif result == 'absent':
+ # The field is not present in the hash.
+ self.assertEqual(hash(C(5)), hash(()))
+ else:
+ assert False, f'unknown result {result!r}'
+
+ def test_init_false_no_default(self):
+ # If init=False and no default value, then the field won't be
+ # present in the instance.
+ @dataclass
+ class C:
+ x: int = field(init=False)
+
+ self.assertNotIn('x', C().__dict__)
+
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+ z: int = field(init=False)
+ t: int = 10
+
+ self.assertNotIn('z', C(0).__dict__)
+ self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
+
+ def test_class_marker(self):
+ @dataclass
+ class C:
+ x: int
+ y: str = field(init=False, default=None)
+ z: str = field(repr=False)
+
+ the_fields = fields(C)
+ # the_fields is a tuple of 3 items, each value
+ # is in __annotations__.
+ self.assertIsInstance(the_fields, tuple)
+ for f in the_fields:
+ self.assertIs(type(f), Field)
+ self.assertIn(f.name, C.__annotations__)
+
+ self.assertEqual(len(the_fields), 3)
+
+ self.assertEqual(the_fields[0].name, 'x')
+ self.assertEqual(the_fields[0].type, int)
+ self.assertFalse(hasattr(C, 'x'))
+ self.assertTrue (the_fields[0].init)
+ self.assertTrue (the_fields[0].repr)
+ self.assertEqual(the_fields[1].name, 'y')
+ self.assertEqual(the_fields[1].type, str)
+ self.assertIsNone(getattr(C, 'y'))
+ self.assertFalse(the_fields[1].init)
+ self.assertTrue (the_fields[1].repr)
+ self.assertEqual(the_fields[2].name, 'z')
+ self.assertEqual(the_fields[2].type, str)
+ self.assertFalse(hasattr(C, 'z'))
+ self.assertTrue (the_fields[2].init)
+ self.assertFalse(the_fields[2].repr)
+
+ def test_field_order(self):
+ @dataclass
+ class B:
+ a: str = 'B:a'
+ b: str = 'B:b'
+ c: str = 'B:c'
+
+ @dataclass
+ class C(B):
+ b: str = 'C:b'
+
+ self.assertEqual([(f.name, f.default) for f in fields(C)],
+ [('a', 'B:a'),
+ ('b', 'C:b'),
+ ('c', 'B:c')])
+
+ @dataclass
+ class D(B):
+ c: str = 'D:c'
+
+ self.assertEqual([(f.name, f.default) for f in fields(D)],
+ [('a', 'B:a'),
+ ('b', 'B:b'),
+ ('c', 'D:c')])
+
+ @dataclass
+ class E(D):
+ a: str = 'E:a'
+ d: str = 'E:d'
+
+ self.assertEqual([(f.name, f.default) for f in fields(E)],
+ [('a', 'E:a'),
+ ('b', 'B:b'),
+ ('c', 'D:c'),
+ ('d', 'E:d')])
+
+ def test_class_attrs(self):
+ # We only have a class attribute if a default value is
+ # specified, either directly or via a field with a default.
+ default = object()
+ @dataclass
+ class C:
+ x: int
+ y: int = field(repr=False)
+ z: object = default
+ t: int = field(default=100)
+
+ self.assertFalse(hasattr(C, 'x'))
+ self.assertFalse(hasattr(C, 'y'))
+ self.assertIs (C.z, default)
+ self.assertEqual(C.t, 100)
+
+ def test_disallowed_mutable_defaults(self):
+ # For the known types, don't allow mutable default values.
+ for typ, empty, non_empty in [(list, [], [1]),
+ (dict, {}, {0:1}),
+ (set, set(), set([1])),
+ ]:
+ with self.subTest(typ=typ):
+ # Can't use a zero-length value.
+ with self.assertRaisesRegex(ValueError,
+ f'mutable default {typ} for field '
+ 'x is not allowed'):
+ @dataclass
+ class Point:
+ x: typ = empty
+
+
+ # Nor a non-zero-length value
+ with self.assertRaisesRegex(ValueError,
+ f'mutable default {typ} for field '
+ 'y is not allowed'):
+ @dataclass
+ class Point:
+ y: typ = non_empty
+
+ # Check subtypes also fail.
+ class Subclass(typ): pass
+
+ with self.assertRaisesRegex(ValueError,
+ f"mutable default .*Subclass'>"
+ ' for field z is not allowed'
+ ):
+ @dataclass
+ class Point:
+ z: typ = Subclass()
+
+ # Because this is a ClassVar, it can be mutable.
+ @dataclass
+ class C:
+ z: ClassVar[typ] = typ()
+
+ # Because this is a ClassVar, it can be mutable.
+ @dataclass
+ class C:
+ x: ClassVar[typ] = Subclass()
+
+ def test_deliberately_mutable_defaults(self):
+ # If a mutable default isn't in the known list of
+ # (list, dict, set), then it's okay.
+ class Mutable:
+ def __init__(self):
+ self.l = []
+
+ @dataclass
+ class C:
+ x: Mutable
+
+ # These 2 instances will share this value of x.
+ lst = Mutable()
+ o1 = C(lst)
+ o2 = C(lst)
+ self.assertEqual(o1, o2)
+ o1.x.l.extend([1, 2])
+ self.assertEqual(o1, o2)
+ self.assertEqual(o1.x.l, [1, 2])
+ self.assertIs(o1.x, o2.x)
+
+ def test_no_options(self):
+ # Call with dataclass().
+ @dataclass()
+ class C:
+ x: int
+
+ self.assertEqual(C(42).x, 42)
+
+ def test_not_tuple(self):
+ # Make sure we can't be compared to a tuple.
+ @dataclass
+ class Point:
+ x: int
+ y: int
+ self.assertNotEqual(Point(1, 2), (1, 2))
+
+ # And that we can't compare to another unrelated dataclass.
+ @dataclass
+ class C:
+ x: int
+ y: int
+ self.assertNotEqual(Point(1, 3), C(1, 3))
+
+ def test_not_other_dataclass(self):
+ # Test that some of the problems with namedtuple don't happen
+ # here.
+ @dataclass
+ class Point3D:
+ x: int
+ y: int
+ z: int
+
+ @dataclass
+ class Date:
+ year: int
+ month: int
+ day: int
+
+ self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
+ self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
+
+ # Make sure we can't unpack.
+ with self.assertRaisesRegex(TypeError, 'unpack'):
+ x, y, z = Point3D(4, 5, 6)
+
+ # Make sure another class with the same field names isn't
+ # equal.
+ @dataclass
+ class Point3Dv1:
+ x: int = 0
+ y: int = 0
+ z: int = 0
+ self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
+
+ def test_function_annotations(self):
+ # Some dummy class and instance to use as a default.
+ class F:
+ pass
+ f = F()
+
+ def validate_class(cls):
+ # First, check __annotations__, even though they're not
+ # function annotations.
+ self.assertEqual(cls.__annotations__['i'], int)
+ self.assertEqual(cls.__annotations__['j'], str)
+ self.assertEqual(cls.__annotations__['k'], F)
+ self.assertEqual(cls.__annotations__['l'], float)
+ self.assertEqual(cls.__annotations__['z'], complex)
+
+ # Verify __init__.
+
+ signature = inspect.signature(cls.__init__)
+ # Check the return type, should be None.
+ self.assertIs(signature.return_annotation, None)
+
+ # Check each parameter.
+ params = iter(signature.parameters.values())
+ param = next(params)
+ # This is testing an internal name, and probably shouldn't be tested.
+ self.assertEqual(param.name, 'self')
+ param = next(params)
+ self.assertEqual(param.name, 'i')
+ self.assertIs (param.annotation, int)
+ self.assertEqual(param.default, inspect.Parameter.empty)
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ param = next(params)
+ self.assertEqual(param.name, 'j')
+ self.assertIs (param.annotation, str)
+ self.assertEqual(param.default, inspect.Parameter.empty)
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ param = next(params)
+ self.assertEqual(param.name, 'k')
+ self.assertIs (param.annotation, F)
+ # Don't test for the default, since it's set to MISSING.
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ param = next(params)
+ self.assertEqual(param.name, 'l')
+ self.assertIs (param.annotation, float)
+ # Don't test for the default, since it's set to MISSING.
+ self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
+ self.assertRaises(StopIteration, next, params)
+
+
+ @dataclass
+ class C:
+ i: int
+ j: str
+ k: F = f
+ l: float=field(default=None)
+ z: complex=field(default=3+4j, init=False)
+
+ validate_class(C)
+
+ # Now repeat with __hash__.
+ @dataclass(frozen=True, unsafe_hash=True)
+ class C:
+ i: int
+ j: str
+ k: F = f
+ l: float=field(default=None)
+ z: complex=field(default=3+4j, init=False)
+
+ validate_class(C)
+
+ def test_missing_default(self):
+ # Test that MISSING works the same as a default not being
+ # specified.
+ @dataclass
+ class C:
+ x: int=field(default=MISSING)
+ with self.assertRaisesRegex(TypeError,
+ r'__init__\(\) missing 1 required '
+ 'positional argument'):
+ C()
+ self.assertNotIn('x', C.__dict__)
+
+ @dataclass
+ class D:
+ x: int
+ with self.assertRaisesRegex(TypeError,
+ r'__init__\(\) missing 1 required '
+ 'positional argument'):
+ D()
+ self.assertNotIn('x', D.__dict__)
+
+ def test_missing_default_factory(self):
+ # Test that MISSING works the same as a default factory not
+ # being specified (which is really the same as a default not
+ # being specified, too).
+ @dataclass
+ class C:
+ x: int=field(default_factory=MISSING)
+ with self.assertRaisesRegex(TypeError,
+ r'__init__\(\) missing 1 required '
+ 'positional argument'):
+ C()
+ self.assertNotIn('x', C.__dict__)
+
+ @dataclass
+ class D:
+ x: int=field(default=MISSING, default_factory=MISSING)
+ with self.assertRaisesRegex(TypeError,
+ r'__init__\(\) missing 1 required '
+ 'positional argument'):
+ D()
+ self.assertNotIn('x', D.__dict__)
+
+ def test_missing_repr(self):
+ self.assertIn('MISSING_TYPE object', repr(MISSING))
+
+ def test_dont_include_other_annotations(self):
+ @dataclass
+ class C:
+ i: int
+ def foo(self) -> int:
+ return 4
+ @property
+ def bar(self) -> int:
+ return 5
+ self.assertEqual(list(C.__annotations__), ['i'])
+ self.assertEqual(C(10).foo(), 4)
+ self.assertEqual(C(10).bar, 5)
+ self.assertEqual(C(10).i, 10)
+
+ def test_post_init(self):
+ # Just make sure it gets called
+ @dataclass
+ class C:
+ def __post_init__(self):
+ raise CustomError()
+ with self.assertRaises(CustomError):
+ C()
+
+ @dataclass
+ class C:
+ i: int = 10
+ def __post_init__(self):
+ if self.i == 10:
+ raise CustomError()
+ with self.assertRaises(CustomError):
+ C()
+ # post-init gets called, but doesn't raise. This is just
+ # checking that self is used correctly.
+ C(5)
+
+ # If there's not an __init__, then post-init won't get called.
+ @dataclass(init=False)
+ class C:
+ def __post_init__(self):
+ raise CustomError()
+ # Creating the class won't raise
+ C()
+
+ @dataclass
+ class C:
+ x: int = 0
+ def __post_init__(self):
+ self.x *= 2
+ self.assertEqual(C().x, 0)
+ self.assertEqual(C(2).x, 4)
+
+ # Make sure that if we're frozen, post-init can't set
+ # attributes.
+ @dataclass(frozen=True)
+ class C:
+ x: int = 0
+ def __post_init__(self):
+ self.x *= 2
+ with self.assertRaises(FrozenInstanceError):
+ C()
+
+ def test_post_init_super(self):
+ # Make sure super() post-init isn't called by default.
+ class B:
+ def __post_init__(self):
+ raise CustomError()
+
+ @dataclass
+ class C(B):
+ def __post_init__(self):
+ self.x = 5
+
+ self.assertEqual(C().x, 5)
+
+ # Now call super(), and it will raise.
+ @dataclass
+ class C(B):
+ def __post_init__(self):
+ super().__post_init__()
+
+ with self.assertRaises(CustomError):
+ C()
+
+ # Make sure post-init is called, even if not defined in our
+ # class.
+ @dataclass
+ class C(B):
+ pass
+
+ with self.assertRaises(CustomError):
+ C()
+
+ def test_post_init_staticmethod(self):
+ flag = False
+ @dataclass
+ class C:
+ x: int
+ y: int
+ @staticmethod
+ def __post_init__():
+ nonlocal flag
+ flag = True
+
+ self.assertFalse(flag)
+ c = C(3, 4)
+ self.assertEqual((c.x, c.y), (3, 4))
+ self.assertTrue(flag)
+
+ def test_post_init_classmethod(self):
+ @dataclass
+ class C:
+ flag = False
+ x: int
+ y: int
+ @classmethod
+ def __post_init__(cls):
+ cls.flag = True
+
+ self.assertFalse(C.flag)
+ c = C(3, 4)
+ self.assertEqual((c.x, c.y), (3, 4))
+ self.assertTrue(C.flag)
+
+ def test_post_init_not_auto_added(self):
+ # See bpo-46757, which had proposed always adding __post_init__. As
+ # Raymond Hettinger pointed out, that would be a breaking change. So,
+ # add a test to make sure that the current behavior doesn't change.
+
+ @dataclass
+ class A0:
+ pass
+
+ @dataclass
+ class B0:
+ b_called: bool = False
+ def __post_init__(self):
+ self.b_called = True
+
+ @dataclass
+ class C0(A0, B0):
+ c_called: bool = False
+ def __post_init__(self):
+ super().__post_init__()
+ self.c_called = True
+
+ # Since A0 has no __post_init__, and one wasn't automatically added
+ # (because that's the rule: it's never added by @dataclass, it's only
+ # the class author that can add it), then B0.__post_init__ is called.
+ # Verify that.
+ c = C0()
+ self.assertTrue(c.b_called)
+ self.assertTrue(c.c_called)
+
+ ######################################
+ # Now, the same thing, except A1 defines __post_init__.
+ @dataclass
+ class A1:
+ def __post_init__(self):
+ pass
+
+ @dataclass
+ class B1:
+ b_called: bool = False
+ def __post_init__(self):
+ self.b_called = True
+
+ @dataclass
+ class C1(A1, B1):
+ c_called: bool = False
+ def __post_init__(self):
+ super().__post_init__()
+ self.c_called = True
+
+ # This time, B1.__post_init__ isn't being called. This mimics what
+ # would happen if A1.__post_init__ had been automatically added,
+ # instead of manually added as we see here. This test isn't really
+ # needed, but I'm including it just to demonstrate the changed
+ # behavior when A1 does define __post_init__.
+ c = C1()
+ self.assertFalse(c.b_called)
+ self.assertTrue(c.c_called)
+
+ def test_class_var(self):
+ # Make sure ClassVars are ignored in __init__, __repr__, etc.
+ @dataclass
+ class C:
+ x: int
+ y: int = 10
+ z: ClassVar[int] = 1000
+ w: ClassVar[int] = 2000
+ t: ClassVar[int] = 3000
+ s: ClassVar = 4000
+
+ c = C(5)
+ self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
+ self.assertEqual(len(fields(C)), 2) # We have 2 fields.
+ self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
+ self.assertEqual(c.z, 1000)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+ self.assertEqual(c.s, 4000)
+ C.z += 1
+ self.assertEqual(c.z, 1001)
+ c = C(20)
+ self.assertEqual((c.x, c.y), (20, 10))
+ self.assertEqual(c.z, 1001)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+ self.assertEqual(c.s, 4000)
+
+ def test_class_var_no_default(self):
+ # If a ClassVar has no default value, it should not be set on the class.
+ @dataclass
+ class C:
+ x: ClassVar[int]
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_class_var_default_factory(self):
+ # It makes no sense for a ClassVar to have a default factory. When
+ # would it be called? Call it yourself, since it's class-wide.
+ with self.assertRaisesRegex(TypeError,
+ 'cannot have a default factory'):
+ @dataclass
+ class C:
+ x: ClassVar[int] = field(default_factory=int)
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_class_var_with_default(self):
+ # If a ClassVar has a default value, it should be set on the class.
+ @dataclass
+ class C:
+ x: ClassVar[int] = 10
+ self.assertEqual(C.x, 10)
+
+ @dataclass
+ class C:
+ x: ClassVar[int] = field(default=10)
+ self.assertEqual(C.x, 10)
+
+ def test_class_var_frozen(self):
+ # Make sure ClassVars work even if we're frozen.
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int = 10
+ z: ClassVar[int] = 1000
+ w: ClassVar[int] = 2000
+ t: ClassVar[int] = 3000
+
+ c = C(5)
+ self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
+ self.assertEqual(len(fields(C)), 2) # We have 2 fields
+ self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
+ self.assertEqual(c.z, 1000)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+ # We can still modify the ClassVar, it's only instances that are
+ # frozen.
+ C.z += 1
+ self.assertEqual(c.z, 1001)
+ c = C(20)
+ self.assertEqual((c.x, c.y), (20, 10))
+ self.assertEqual(c.z, 1001)
+ self.assertEqual(c.w, 2000)
+ self.assertEqual(c.t, 3000)
+
+ def test_init_var_no_default(self):
+ # If an InitVar has no default value, it should not be set on the class.
+ @dataclass
+ class C:
+ x: InitVar[int]
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_init_var_default_factory(self):
+ # It makes no sense for an InitVar to have a default factory. When
+ # would it be called? Call it yourself, since it's class-wide.
+ with self.assertRaisesRegex(TypeError,
+ 'cannot have a default factory'):
+ @dataclass
+ class C:
+ x: InitVar[int] = field(default_factory=int)
+
+ self.assertNotIn('x', C.__dict__)
+
+ def test_init_var_with_default(self):
+ # If an InitVar has a default value, it should be set on the class.
+ @dataclass
+ class C:
+ x: InitVar[int] = 10
+ self.assertEqual(C.x, 10)
+
+ @dataclass
+ class C:
+ x: InitVar[int] = field(default=10)
+ self.assertEqual(C.x, 10)
+
+ def test_init_var(self):
+ @dataclass
+ class C:
+ x: int = None
+ init_param: InitVar[int] = None
+
+ def __post_init__(self, init_param):
+ if self.x is None:
+ self.x = init_param*2
+
+ c = C(init_param=10)
+ self.assertEqual(c.x, 20)
+
+ def test_init_var_preserve_type(self):
+ self.assertEqual(InitVar[int].type, int)
+
+ # Make sure the repr is correct.
+ self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
+ self.assertEqual(repr(InitVar[List[int]]),
+ 'dataclasses.InitVar[typing.List[int]]')
+ self.assertEqual(repr(InitVar[list[int]]),
+ 'dataclasses.InitVar[list[int]]')
+ self.assertEqual(repr(InitVar[int|str]),
+ 'dataclasses.InitVar[int | str]')
+
+ def test_init_var_inheritance(self):
+ # Note that this deliberately tests that a dataclass need not
+ # have a __post_init__ function if it has an InitVar field.
+ # It could just be used in a derived class, as shown here.
+ @dataclass
+ class Base:
+ x: int
+ init_base: InitVar[int]
+
+ # We can instantiate by passing the InitVar, even though
+ # it's not used.
+ b = Base(0, 10)
+ self.assertEqual(vars(b), {'x': 0})
+
+ @dataclass
+ class C(Base):
+ y: int
+ init_derived: InitVar[int]
+
+ def __post_init__(self, init_base, init_derived):
+ self.x = self.x + init_base
+ self.y = self.y + init_derived
+
+ c = C(10, 11, 50, 51)
+ self.assertEqual(vars(c), {'x': 21, 'y': 101})
+
+ def test_default_factory(self):
+ # Test a factory that returns a new list.
+ @dataclass
+ class C:
+ x: int
+ y: list = field(default_factory=list)
+
+ c0 = C(3)
+ c1 = C(3)
+ self.assertEqual(c0.x, 3)
+ self.assertEqual(c0.y, [])
+ self.assertEqual(c0, c1)
+ self.assertIsNot(c0.y, c1.y)
+ self.assertEqual(astuple(C(5, [1])), (5, [1]))
+
+ # Test a factory that returns a shared list.
+ l = []
+ @dataclass
+ class C:
+ x: int
+ y: list = field(default_factory=lambda: l)
+
+ c0 = C(3)
+ c1 = C(3)
+ self.assertEqual(c0.x, 3)
+ self.assertEqual(c0.y, [])
+ self.assertEqual(c0, c1)
+ self.assertIs(c0.y, c1.y)
+ self.assertEqual(astuple(C(5, [1])), (5, [1]))
+
+ # Test various other field flags.
+ # repr
+ @dataclass
+ class C:
+ x: list = field(default_factory=list, repr=False)
+ self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
+ self.assertEqual(C().x, [])
+
+ # hash
+ @dataclass(unsafe_hash=True)
+ class C:
+ x: list = field(default_factory=list, hash=False)
+ self.assertEqual(astuple(C()), ([],))
+ self.assertEqual(hash(C()), hash(()))
+
+ # init (see also test_default_factory_with_no_init)
+ @dataclass
+ class C:
+ x: list = field(default_factory=list, init=False)
+ self.assertEqual(astuple(C()), ([],))
+
+ # compare
+ @dataclass
+ class C:
+ x: list = field(default_factory=list, compare=False)
+ self.assertEqual(C(), C([1]))
+
+ def test_default_factory_with_no_init(self):
+ # We need a factory with a side effect.
+ factory = Mock()
+
+ @dataclass
+ class C:
+ x: list = field(default_factory=factory, init=False)
+
+ # Make sure the default factory is called for each new instance.
+ C().x
+ self.assertEqual(factory.call_count, 1)
+ C().x
+ self.assertEqual(factory.call_count, 2)
+
+ def test_default_factory_not_called_if_value_given(self):
+ # We need a factory that we can test if it's been called.
+ factory = Mock()
+
+ @dataclass
+ class C:
+ x: int = field(default_factory=factory)
+
+ # Make sure that if a field has a default factory function,
+ # it's not called if a value is specified.
+ C().x
+ self.assertEqual(factory.call_count, 1)
+ self.assertEqual(C(10).x, 10)
+ self.assertEqual(factory.call_count, 1)
+ C().x
+ self.assertEqual(factory.call_count, 2)
+
+ def test_default_factory_derived(self):
+ # See bpo-32896.
+ @dataclass
+ class Foo:
+ x: dict = field(default_factory=dict)
+
+ @dataclass
+ class Bar(Foo):
+ y: int = 1
+
+ self.assertEqual(Foo().x, {})
+ self.assertEqual(Bar().x, {})
+ self.assertEqual(Bar().y, 1)
+
+ @dataclass
+ class Baz(Foo):
+ pass
+ self.assertEqual(Baz().x, {})
+
+ def test_intermediate_non_dataclass(self):
+ # Test that an intermediate class that defines
+ # annotations does not define fields.
+
+ @dataclass
+ class A:
+ x: int
+
+ class B(A):
+ y: int
+
+ @dataclass
+ class C(B):
+ z: int
+
+ c = C(1, 3)
+ self.assertEqual((c.x, c.z), (1, 3))
+
+ # .y was not initialized.
+ with self.assertRaisesRegex(AttributeError,
+ 'object has no attribute'):
+ c.y
+
+ # And if we again derive a non-dataclass, no fields are added.
+ class D(C):
+ t: int
+ d = D(4, 5)
+ self.assertEqual((d.x, d.z), (4, 5))
+
+ def test_classvar_default_factory(self):
+ # It's an error for a ClassVar to have a factory function.
+ with self.assertRaisesRegex(TypeError,
+ 'cannot have a default factory'):
+ @dataclass
+ class C:
+ x: ClassVar[int] = field(default_factory=int)
+
+ def test_is_dataclass(self):
+ class NotDataClass:
+ pass
+
+ self.assertFalse(is_dataclass(0))
+ self.assertFalse(is_dataclass(int))
+ self.assertFalse(is_dataclass(NotDataClass))
+ self.assertFalse(is_dataclass(NotDataClass()))
+
+ @dataclass
+ class C:
+ x: int
+
+ @dataclass
+ class D:
+ d: C
+ e: int
+
+ c = C(10)
+ d = D(c, 4)
+
+ self.assertTrue(is_dataclass(C))
+ self.assertTrue(is_dataclass(c))
+ self.assertFalse(is_dataclass(c.x))
+ self.assertTrue(is_dataclass(d.d))
+ self.assertFalse(is_dataclass(d.e))
+
+ def test_is_dataclass_when_getattr_always_returns(self):
+ # See bpo-37868.
+ class A:
+ def __getattr__(self, key):
+ return 0
+ self.assertFalse(is_dataclass(A))
+ a = A()
+
+ # Also test for an instance attribute.
+ class B:
+ pass
+ b = B()
+ b.__dataclass_fields__ = []
+
+ for obj in a, b:
+ with self.subTest(obj=obj):
+ self.assertFalse(is_dataclass(obj))
+
+ # Indirect tests for _is_dataclass_instance().
+ with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
+ asdict(obj)
+ with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
+ astuple(obj)
+ with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
+ replace(obj, x=0)
+
+ def test_is_dataclass_genericalias(self):
+ @dataclass
+ class A(types.GenericAlias):
+ origin: type
+ args: type
+ self.assertTrue(is_dataclass(A))
+ a = A(list, int)
+ self.assertTrue(is_dataclass(type(a)))
+ self.assertTrue(is_dataclass(a))
+
+
+ def test_helper_fields_with_class_instance(self):
+ # Check that we can call fields() on either a class or instance,
+ # and get back the same thing.
+ @dataclass
+ class C:
+ x: int
+ y: float
+
+ self.assertEqual(fields(C), fields(C(0, 0.0)))
+
+ def test_helper_fields_exception(self):
+ # Check that TypeError is raised if not passed a dataclass or
+ # instance.
+ with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
+ fields(0)
+
+ class C: pass
+ with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
+ fields(C)
+ with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
+ fields(C())
+
+ def test_helper_asdict(self):
+ # Basic tests for asdict(), it should return a new dictionary.
+ @dataclass
+ class C:
+ x: int
+ y: int
+ c = C(1, 2)
+
+ self.assertEqual(asdict(c), {'x': 1, 'y': 2})
+ self.assertEqual(asdict(c), asdict(c))
+ self.assertIsNot(asdict(c), asdict(c))
+ c.x = 42
+ self.assertEqual(asdict(c), {'x': 42, 'y': 2})
+ self.assertIs(type(asdict(c)), dict)
+
+ def test_helper_asdict_raises_on_classes(self):
+ # asdict() should raise on a class object.
+ @dataclass
+ class C:
+ x: int
+ y: int
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ asdict(C)
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ asdict(int)
+
+ def test_helper_asdict_copy_values(self):
+ @dataclass
+ class C:
+ x: int
+ y: List[int] = field(default_factory=list)
+ initial = []
+ c = C(1, initial)
+ d = asdict(c)
+ self.assertEqual(d['y'], initial)
+ self.assertIsNot(d['y'], initial)
+ c = C(1)
+ d = asdict(c)
+ d['y'].append(1)
+ self.assertEqual(c.y, [])
+
+ def test_helper_asdict_nested(self):
+ @dataclass
+ class UserId:
+ token: int
+ group: int
+ @dataclass
+ class User:
+ name: str
+ id: UserId
+ u = User('Joe', UserId(123, 1))
+ d = asdict(u)
+ self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
+ self.assertIsNot(asdict(u), asdict(u))
+ u.id.group = 2
+ self.assertEqual(asdict(u), {'name': 'Joe',
+ 'id': {'token': 123, 'group': 2}})
+
+ def test_helper_asdict_builtin_containers(self):
+ @dataclass
+ class User:
+ name: str
+ id: int
+ @dataclass
+ class GroupList:
+ id: int
+ users: List[User]
+ @dataclass
+ class GroupTuple:
+ id: int
+ users: Tuple[User, ...]
+ @dataclass
+ class GroupDict:
+ id: int
+ users: Dict[str, User]
+ a = User('Alice', 1)
+ b = User('Bob', 2)
+ gl = GroupList(0, [a, b])
+ gt = GroupTuple(0, (a, b))
+ gd = GroupDict(0, {'first': a, 'second': b})
+ self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
+ {'name': 'Bob', 'id': 2}]})
+ self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
+ {'name': 'Bob', 'id': 2})})
+ self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
+ 'second': {'name': 'Bob', 'id': 2}}})
+
+ def test_helper_asdict_builtin_object_containers(self):
+ @dataclass
+ class Child:
+ d: object
+
+ @dataclass
+ class Parent:
+ child: Child
+
+ self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
+ self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
+
+ def test_helper_asdict_factory(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+ c = C(1, 2)
+ d = asdict(c, dict_factory=OrderedDict)
+ self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
+ self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
+ c.x = 42
+ d = asdict(c, dict_factory=OrderedDict)
+ self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
+ self.assertIs(type(d), OrderedDict)
+
+ def test_helper_asdict_namedtuple(self):
+ T = namedtuple('T', 'a b c')
+ @dataclass
+ class C:
+ x: str
+ y: T
+ c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
+
+ d = asdict(c)
+ self.assertEqual(d, {'x': 'outer',
+ 'y': T(1,
+ {'x': 'inner',
+ 'y': T(11, 12, 13)},
+ 2),
+ }
+ )
+
+ # Now with a dict_factory. OrderedDict is convenient, but
+ # since it compares to dicts, we also need to have separate
+ # assertIs tests.
+ d = asdict(c, dict_factory=OrderedDict)
+ self.assertEqual(d, {'x': 'outer',
+ 'y': T(1,
+ {'x': 'inner',
+ 'y': T(11, 12, 13)},
+ 2),
+ }
+ )
+
+ # Make sure that the returned dicts are actually OrderedDicts.
+ self.assertIs(type(d), OrderedDict)
+ self.assertIs(type(d['y'][1]), OrderedDict)
+
+ def test_helper_asdict_namedtuple_key(self):
+ # Ensure that a field that contains a dict which has a
+ # namedtuple as a key works with asdict().
+
+ @dataclass
+ class C:
+ f: dict
+ T = namedtuple('T', 'a')
+
+ c = C({T('an a'): 0})
+
+ self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
+
+ def test_helper_asdict_namedtuple_derived(self):
+ class T(namedtuple('Tbase', 'a')):
+ def my_a(self):
+ return self.a
+
+ @dataclass
+ class C:
+ f: T
+
+ t = T(6)
+ c = C(t)
+
+ d = asdict(c)
+ self.assertEqual(d, {'f': T(a=6)})
+ # Make sure that t has been copied, not used directly.
+ self.assertIsNot(d['f'], t)
+ self.assertEqual(d['f'].my_a(), 6)
+
+ def test_helper_astuple(self):
+ # Basic tests for astuple(), it should return a new tuple.
+ @dataclass
+ class C:
+ x: int
+ y: int = 0
+ c = C(1)
+
+ self.assertEqual(astuple(c), (1, 0))
+ self.assertEqual(astuple(c), astuple(c))
+ self.assertIsNot(astuple(c), astuple(c))
+ c.y = 42
+ self.assertEqual(astuple(c), (1, 42))
+ self.assertIs(type(astuple(c)), tuple)
+
+ def test_helper_astuple_raises_on_classes(self):
+ # astuple() should raise on a class object.
+ @dataclass
+ class C:
+ x: int
+ y: int
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ astuple(C)
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ astuple(int)
+
+ def test_helper_astuple_copy_values(self):
+ @dataclass
+ class C:
+ x: int
+ y: List[int] = field(default_factory=list)
+ initial = []
+ c = C(1, initial)
+ t = astuple(c)
+ self.assertEqual(t[1], initial)
+ self.assertIsNot(t[1], initial)
+ c = C(1)
+ t = astuple(c)
+ t[1].append(1)
+ self.assertEqual(c.y, [])
+
+ def test_helper_astuple_nested(self):
+ @dataclass
+ class UserId:
+ token: int
+ group: int
+ @dataclass
+ class User:
+ name: str
+ id: UserId
+ u = User('Joe', UserId(123, 1))
+ t = astuple(u)
+ self.assertEqual(t, ('Joe', (123, 1)))
+ self.assertIsNot(astuple(u), astuple(u))
+ u.id.group = 2
+ self.assertEqual(astuple(u), ('Joe', (123, 2)))
+
+ def test_helper_astuple_builtin_containers(self):
+ @dataclass
+ class User:
+ name: str
+ id: int
+ @dataclass
+ class GroupList:
+ id: int
+ users: List[User]
+ @dataclass
+ class GroupTuple:
+ id: int
+ users: Tuple[User, ...]
+ @dataclass
+ class GroupDict:
+ id: int
+ users: Dict[str, User]
+ a = User('Alice', 1)
+ b = User('Bob', 2)
+ gl = GroupList(0, [a, b])
+ gt = GroupTuple(0, (a, b))
+ gd = GroupDict(0, {'first': a, 'second': b})
+ self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
+ self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
+ self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
+
+ def test_helper_astuple_builtin_object_containers(self):
+ @dataclass
+ class Child:
+ d: object
+
+ @dataclass
+ class Parent:
+ child: Child
+
+ self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
+ self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
+
+ def test_helper_astuple_factory(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+ NT = namedtuple('NT', 'x y')
+ def nt(lst):
+ return NT(*lst)
+ c = C(1, 2)
+ t = astuple(c, tuple_factory=nt)
+ self.assertEqual(t, NT(1, 2))
+ self.assertIsNot(t, astuple(c, tuple_factory=nt))
+ c.x = 42
+ t = astuple(c, tuple_factory=nt)
+ self.assertEqual(t, NT(42, 2))
+ self.assertIs(type(t), NT)
+
+ def test_helper_astuple_namedtuple(self):
+ T = namedtuple('T', 'a b c')
+ @dataclass
+ class C:
+ x: str
+ y: T
+ c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
+
+ t = astuple(c)
+ self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
+
+ # Now, using a tuple_factory. list is convenient here.
+ t = astuple(c, tuple_factory=list)
+ self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
+
+ def test_dynamic_class_creation(self):
+ cls_dict = {'__annotations__': {'x': int, 'y': int},
+ }
+
+ # Create the class.
+ cls = type('C', (), cls_dict)
+
+ # Make it a dataclass.
+ cls1 = dataclass(cls)
+
+ self.assertEqual(cls1, cls)
+ self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
+
+ def test_dynamic_class_creation_using_field(self):
+ cls_dict = {'__annotations__': {'x': int, 'y': int},
+ 'y': field(default=5),
+ }
+
+ # Create the class.
+ cls = type('C', (), cls_dict)
+
+ # Make it a dataclass.
+ cls1 = dataclass(cls)
+
+ self.assertEqual(cls1, cls)
+ self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
+
+ def test_init_in_order(self):
+ @dataclass
+ class C:
+ a: int
+ b: int = field()
+ c: list = field(default_factory=list, init=False)
+ d: list = field(default_factory=list)
+ e: int = field(default=4, init=False)
+ f: int = 4
+
+ calls = []
+ def setattr(self, name, value):
+ calls.append((name, value))
+
+ C.__setattr__ = setattr
+ c = C(0, 1)
+ self.assertEqual(('a', 0), calls[0])
+ self.assertEqual(('b', 1), calls[1])
+ self.assertEqual(('c', []), calls[2])
+ self.assertEqual(('d', []), calls[3])
+ self.assertNotIn(('e', 4), calls)
+ self.assertEqual(('f', 4), calls[4])
+
+ def test_items_in_dicts(self):
+ @dataclass
+ class C:
+ a: int
+ b: list = field(default_factory=list, init=False)
+ c: list = field(default_factory=list)
+ d: int = field(default=4, init=False)
+ e: int = 0
+
+ c = C(0)
+ # Class dict
+ self.assertNotIn('a', C.__dict__)
+ self.assertNotIn('b', C.__dict__)
+ self.assertNotIn('c', C.__dict__)
+ self.assertIn('d', C.__dict__)
+ self.assertEqual(C.d, 4)
+ self.assertIn('e', C.__dict__)
+ self.assertEqual(C.e, 0)
+ # Instance dict
+ self.assertIn('a', c.__dict__)
+ self.assertEqual(c.a, 0)
+ self.assertIn('b', c.__dict__)
+ self.assertEqual(c.b, [])
+ self.assertIn('c', c.__dict__)
+ self.assertEqual(c.c, [])
+ self.assertNotIn('d', c.__dict__)
+ self.assertIn('e', c.__dict__)
+ self.assertEqual(c.e, 0)
+
+ def test_alternate_classmethod_constructor(self):
+ # Since __post_init__ can't take params, use a classmethod
+ # alternate constructor. This is mostly an example to show
+ # how to use this technique.
+ @dataclass
+ class C:
+ x: int
+ @classmethod
+ def from_file(cls, filename):
+ # In a real example, create a new instance
+ # and populate 'x' from contents of a file.
+ value_in_file = 20
+ return cls(value_in_file)
+
+ self.assertEqual(C.from_file('filename').x, 20)
+
+ def test_field_metadata_default(self):
+ # Make sure the default metadata is read-only and of
+ # zero length.
+ @dataclass
+ class C:
+ i: int
+
+ self.assertFalse(fields(C)[0].metadata)
+ self.assertEqual(len(fields(C)[0].metadata), 0)
+ with self.assertRaisesRegex(TypeError,
+ 'does not support item assignment'):
+ fields(C)[0].metadata['test'] = 3
+
+ def test_field_metadata_mapping(self):
+ # Make sure only a mapping can be passed as metadata
+ # zero length.
+ with self.assertRaises(TypeError):
+ @dataclass
+ class C:
+ i: int = field(metadata=0)
+
+ # Make sure an empty dict works.
+ d = {}
+ @dataclass
+ class C:
+ i: int = field(metadata=d)
+ self.assertFalse(fields(C)[0].metadata)
+ self.assertEqual(len(fields(C)[0].metadata), 0)
+ # Update should work (see bpo-35960).
+ d['foo'] = 1
+ self.assertEqual(len(fields(C)[0].metadata), 1)
+ self.assertEqual(fields(C)[0].metadata['foo'], 1)
+ with self.assertRaisesRegex(TypeError,
+ 'does not support item assignment'):
+ fields(C)[0].metadata['test'] = 3
+
+ # Make sure a non-empty dict works.
+ d = {'test': 10, 'bar': '42', 3: 'three'}
+ @dataclass
+ class C:
+ i: int = field(metadata=d)
+ self.assertEqual(len(fields(C)[0].metadata), 3)
+ self.assertEqual(fields(C)[0].metadata['test'], 10)
+ self.assertEqual(fields(C)[0].metadata['bar'], '42')
+ self.assertEqual(fields(C)[0].metadata[3], 'three')
+ # Update should work.
+ d['foo'] = 1
+ self.assertEqual(len(fields(C)[0].metadata), 4)
+ self.assertEqual(fields(C)[0].metadata['foo'], 1)
+ with self.assertRaises(KeyError):
+ # Non-existent key.
+ fields(C)[0].metadata['baz']
+ with self.assertRaisesRegex(TypeError,
+ 'does not support item assignment'):
+ fields(C)[0].metadata['test'] = 3
+
+ def test_field_metadata_custom_mapping(self):
+ # Try a custom mapping.
+ class SimpleNameSpace:
+ def __init__(self, **kw):
+ self.__dict__.update(kw)
+
+ def __getitem__(self, item):
+ if item == 'xyzzy':
+ return 'plugh'
+ return getattr(self, item)
+
+ def __len__(self):
+ return self.__dict__.__len__()
+
+ @dataclass
+ class C:
+ i: int = field(metadata=SimpleNameSpace(a=10))
+
+ self.assertEqual(len(fields(C)[0].metadata), 1)
+ self.assertEqual(fields(C)[0].metadata['a'], 10)
+ with self.assertRaises(AttributeError):
+ fields(C)[0].metadata['b']
+ # Make sure we're still talking to our custom mapping.
+ self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
+
+ def test_generic_dataclasses(self):
+ T = TypeVar('T')
+
+ @dataclass
+ class LabeledBox(Generic[T]):
+ content: T
+ label: str = '<unknown>'
+
+ box = LabeledBox(42)
+ self.assertEqual(box.content, 42)
+ self.assertEqual(box.label, '<unknown>')
+
+ # Subscripting the resulting class should work, etc.
+ Alias = List[LabeledBox[int]]
+
+ def test_generic_extending(self):
+ S = TypeVar('S')
+ T = TypeVar('T')
+
+ @dataclass
+ class Base(Generic[T, S]):
+ x: T
+ y: S
+
+ @dataclass
+ class DataDerived(Base[int, T]):
+ new_field: str
+ Alias = DataDerived[str]
+ c = Alias(0, 'test1', 'test2')
+ self.assertEqual(astuple(c), (0, 'test1', 'test2'))
+
+ class NonDataDerived(Base[int, T]):
+ def new_method(self):
+ return self.y
+ Alias = NonDataDerived[float]
+ c = Alias(10, 1.0)
+ self.assertEqual(c.new_method(), 1.0)
+
+ def test_generic_dynamic(self):
+ T = TypeVar('T')
+
+ @dataclass
+ class Parent(Generic[T]):
+ x: T
+ Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
+ bases=(Parent[int], Generic[T]), namespace={'other': 42})
+ self.assertIs(Child[int](1, 2).z, None)
+ self.assertEqual(Child[int](1, 2, 3).z, 3)
+ self.assertEqual(Child[int](1, 2, 3).other, 42)
+ # Check that type aliases work correctly.
+ Alias = Child[T]
+ self.assertEqual(Alias[int](1, 2).x, 1)
+ # Check MRO resolution.
+ self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
+
+ def test_dataclasses_pickleable(self):
+ global P, Q, R
+ @dataclass
+ class P:
+ x: int
+ y: int = 0
+ @dataclass
+ class Q:
+ x: int
+ y: int = field(default=0, init=False)
+ @dataclass
+ class R:
+ x: int
+ y: List[int] = field(default_factory=list)
+ q = Q(1)
+ q.y = 2
+ samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
+ for sample in samples:
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(sample=sample, proto=proto):
+ new_sample = pickle.loads(pickle.dumps(sample, proto))
+ self.assertEqual(sample.x, new_sample.x)
+ self.assertEqual(sample.y, new_sample.y)
+ self.assertIsNot(sample, new_sample)
+ new_sample.x = 42
+ another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
+ self.assertEqual(new_sample.x, another_new_sample.x)
+ self.assertEqual(sample.y, another_new_sample.y)
+
+ def test_dataclasses_qualnames(self):
+ @dataclass(order=True, unsafe_hash=True, frozen=True)
+ class A:
+ x: int
+ y: int
+
+ self.assertEqual(A.__init__.__name__, "__init__")
+ for function in (
+ '__eq__',
+ '__lt__',
+ '__le__',
+ '__gt__',
+ '__ge__',
+ '__hash__',
+ '__init__',
+ '__repr__',
+ '__setattr__',
+ '__delattr__',
+ ):
+ self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
+
+ with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
+ A()
+
+
+class TestFieldNoAnnotation(unittest.TestCase):
+ def test_field_without_annotation(self):
+ with self.assertRaisesRegex(TypeError,
+ "'f' is a field but has no type annotation"):
+ @dataclass
+ class C:
+ f = field()
+
+ def test_field_without_annotation_but_annotation_in_base(self):
+ @dataclass
+ class B:
+ f: int
+
+ with self.assertRaisesRegex(TypeError,
+ "'f' is a field but has no type annotation"):
+ # This is still an error: make sure we don't pick up the
+ # type annotation in the base class.
+ @dataclass
+ class C(B):
+ f = field()
+
+ def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
+ # Same test, but with the base class not a dataclass.
+ class B:
+ f: int
+
+ with self.assertRaisesRegex(TypeError,
+ "'f' is a field but has no type annotation"):
+ # This is still an error: make sure we don't pick up the
+ # type annotation in the base class.
+ @dataclass
+ class C(B):
+ f = field()
+
+
+class TestDocString(unittest.TestCase):
+ def assertDocStrEqual(self, a, b):
+ # Because 3.6 and 3.7 differ in how inspect.signature work
+ # (see bpo #32108), for the time being just compare them with
+ # whitespace stripped.
+ self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
+
+ def test_existing_docstring_not_overridden(self):
+ @dataclass
+ class C:
+ """Lorem ipsum"""
+ x: int
+
+ self.assertEqual(C.__doc__, "Lorem ipsum")
+
+ def test_docstring_no_fields(self):
+ @dataclass
+ class C:
+ pass
+
+ self.assertDocStrEqual(C.__doc__, "C()")
+
+ def test_docstring_one_field(self):
+ @dataclass
+ class C:
+ x: int
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int)")
+
+ def test_docstring_two_fields(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
+
+ def test_docstring_three_fields(self):
+ @dataclass
+ class C:
+ x: int
+ y: int
+ z: str
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
+
+ def test_docstring_one_field_with_default(self):
+ @dataclass
+ class C:
+ x: int = 3
+
+ self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
+
+ def test_docstring_one_field_with_default_none(self):
+ @dataclass
+ class C:
+ x: Union[int, type(None)] = None
+
+ self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")
+
+ def test_docstring_list_field(self):
+ @dataclass
+ class C:
+ x: List[int]
+
+ self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
+
+ def test_docstring_list_field_with_default_factory(self):
+ @dataclass
+ class C:
+ x: List[int] = field(default_factory=list)
+
+ self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
+
+ def test_docstring_deque_field(self):
+ @dataclass
+ class C:
+ x: deque
+
+ self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
+
+ def test_docstring_deque_field_with_default_factory(self):
+ @dataclass
+ class C:
+ x: deque = field(default_factory=deque)
+
+ self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
+
+
+class TestInit(unittest.TestCase):
+ def test_base_has_init(self):
+ class B:
+ def __init__(self):
+ self.z = 100
+ pass
+
+ # Make sure that declaring this class doesn't raise an error.
+ # The issue is that we can't override __init__ in our class,
+ # but it should be okay to add __init__ to us if our base has
+ # an __init__.
+ @dataclass
+ class C(B):
+ x: int = 0
+ c = C(10)
+ self.assertEqual(c.x, 10)
+ self.assertNotIn('z', vars(c))
+
+ # Make sure that if we don't add an init, the base __init__
+ # gets called.
+ @dataclass(init=False)
+ class C(B):
+ x: int = 10
+ c = C()
+ self.assertEqual(c.x, 10)
+ self.assertEqual(c.z, 100)
+
+ def test_no_init(self):
+ @dataclass(init=False)
+ class C:
+ i: int = 0
+ self.assertEqual(C().i, 0)
+
+ @dataclass(init=False)
+ class C:
+ i: int = 2
+ def __init__(self):
+ self.i = 3
+ self.assertEqual(C().i, 3)
+
+ def test_overwriting_init(self):
+ # If the class has __init__, use it no matter the value of
+ # init=.
+
+ @dataclass
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(3).x, 6)
+
+ @dataclass(init=True)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(4).x, 8)
+
+ @dataclass(init=False)
+ class C:
+ x: int
+ def __init__(self, x):
+ self.x = 2 * x
+ self.assertEqual(C(5).x, 10)
+
+ def test_inherit_from_protocol(self):
+ # Dataclasses inheriting from protocol should preserve their own `__init__`.
+ # See bpo-45081.
+
+ class P(Protocol):
+ a: int
+
+ @dataclass
+ class C(P):
+ a: int
+
+ self.assertEqual(C(5).a, 5)
+
+ @dataclass
+ class D(P):
+ def __init__(self, a):
+ self.a = a * 2
+
+ self.assertEqual(D(5).a, 10)
+
+
+class TestRepr(unittest.TestCase):
+ def test_repr(self):
+ @dataclass
+ class B:
+ x: int
+
+ @dataclass
+ class C(B):
+ y: int = 10
+
+ o = C(4)
+ self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
+
+ @dataclass
+ class D(C):
+ x: int = 20
+ self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
+
+ @dataclass
+ class C:
+ @dataclass
+ class D:
+ i: int
+ @dataclass
+ class E:
+ pass
+ self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
+ self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
+
+ def test_no_repr(self):
+ # Test a class with no __repr__ and repr=False.
+ @dataclass(repr=False)
+ class C:
+ x: int
+ self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
+ repr(C(3)))
+
+ # Test a class with a __repr__ and repr=False.
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'C-class'
+ self.assertEqual(repr(C(3)), 'C-class')
+
+ def test_overwriting_repr(self):
+ # If the class has __repr__, use it no matter the value of
+ # repr=.
+
+ @dataclass
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ @dataclass(repr=True)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+ @dataclass(repr=False)
+ class C:
+ x: int
+ def __repr__(self):
+ return 'x'
+ self.assertEqual(repr(C(0)), 'x')
+
+
+class TestEq(unittest.TestCase):
+ def test_no_eq(self):
+ # Test a class with no __eq__ and eq=False.
+ @dataclass(eq=False)
+ class C:
+ x: int
+ self.assertNotEqual(C(0), C(0))
+ c = C(3)
+ self.assertEqual(c, c)
+
+ # Test a class with an __eq__ and eq=False.
+ @dataclass(eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 10
+ self.assertEqual(C(3), 10)
+
+ def test_overwriting_eq(self):
+ # If the class has __eq__, use it no matter the value of
+ # eq=.
+
+ @dataclass
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 3
+ self.assertEqual(C(1), 3)
+ self.assertNotEqual(C(1), 1)
+
+ @dataclass(eq=True)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 4
+ self.assertEqual(C(1), 4)
+ self.assertNotEqual(C(1), 1)
+
+ @dataclass(eq=False)
+ class C:
+ x: int
+ def __eq__(self, other):
+ return other == 5
+ self.assertEqual(C(1), 5)
+ self.assertNotEqual(C(1), 1)
+
+
+class TestOrdering(unittest.TestCase):
+ def test_functools_total_ordering(self):
+ # Test that functools.total_ordering works with this class.
+ @total_ordering
+ @dataclass
+ class C:
+ x: int
+ def __lt__(self, other):
+ # Perform the test "backward", just to make
+ # sure this is being called.
+ return self.x >= other
+
+ self.assertLess(C(0), -1)
+ self.assertLessEqual(C(0), -1)
+ self.assertGreater(C(0), 1)
+ self.assertGreaterEqual(C(0), 1)
+
+ def test_no_order(self):
+ # Test that no ordering functions are added by default.
+ @dataclass(order=False)
+ class C:
+ x: int
+ # Make sure no order methods are added.
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__lt__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+
+ # Test that __lt__ is still called
+ @dataclass(order=False)
+ class C:
+ x: int
+ def __lt__(self, other):
+ return False
+ # Make sure other methods aren't added.
+ self.assertNotIn('__le__', C.__dict__)
+ self.assertNotIn('__ge__', C.__dict__)
+ self.assertNotIn('__gt__', C.__dict__)
+
+ def test_overwriting_order(self):
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __lt__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __lt__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __le__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __le__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __gt__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __gt__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __ge__'
+ '.*using functools.total_ordering'):
+ @dataclass(order=True)
+ class C:
+ x: int
+ def __ge__(self):
+ pass
+
+class TestHash(unittest.TestCase):
+ def test_unsafe_hash(self):
+ @dataclass(unsafe_hash=True)
+ class C:
+ x: int
+ y: str
+ self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
+
+ def test_hash_rules(self):
+ def non_bool(value):
+ # Map to something else that's True, but not a bool.
+ if value is None:
+ return None
+ if value:
+ return (3,)
+ return 0
+
+ def test(case, unsafe_hash, eq, frozen, with_hash, result):
+ with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
+ frozen=frozen):
+ if result != 'exception':
+ if with_hash:
+ @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
+ class C:
+ def __hash__(self):
+ return 0
+ else:
+ @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
+ class C:
+ pass
+
+ # See if the result matches what's expected.
+ if result == 'fn':
+ # __hash__ contains the function we generated.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNotNone(C.__dict__['__hash__'])
+
+ elif result == '':
+ # __hash__ is not present in our class.
+ if not with_hash:
+ self.assertNotIn('__hash__', C.__dict__)
+
+ elif result == 'none':
+ # __hash__ is set to None.
+ self.assertIn('__hash__', C.__dict__)
+ self.assertIsNone(C.__dict__['__hash__'])
+
+ elif result == 'exception':
+ # Creating the class should cause an exception.
+ # This only happens with with_hash==True.
+ assert(with_hash)
+ with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
+ @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
+ class C:
+ def __hash__(self):
+ return 0
+
+ else:
+ assert False, f'unknown result {result!r}'
+
+ # There are 8 cases of:
+ # unsafe_hash=True/False
+ # eq=True/False
+ # frozen=True/False
+ # And for each of these, a different result if
+ # __hash__ is defined or not.
+ for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
+ (False, False, False, '', ''),
+ (False, False, True, '', ''),
+ (False, True, False, 'none', ''),
+ (False, True, True, 'fn', ''),
+ (True, False, False, 'fn', 'exception'),
+ (True, False, True, 'fn', 'exception'),
+ (True, True, False, 'fn', 'exception'),
+ (True, True, True, 'fn', 'exception'),
+ ], 1):
+ test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
+ test(case, unsafe_hash, eq, frozen, True, res_defined_hash)
+
+ # Test non-bool truth values, too. This is just to
+ # make sure the data-driven table in the decorator
+ # handles non-bool values.
+ test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
+ test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)
+
+
+ def test_eq_only(self):
+ # If a class defines __eq__, __hash__ is automatically added
+ # and set to None. This is normal Python behavior, not
+ # related to dataclasses. Make sure we don't interfere with
+ # that (see bpo=32546).
+
+ @dataclass
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == other.i
+ self.assertEqual(C(1), C(1))
+ self.assertNotEqual(C(1), C(4))
+
+ # And make sure things work in this case if we specify
+ # unsafe_hash=True.
+ @dataclass(unsafe_hash=True)
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == other.i
+ self.assertEqual(C(1), C(1.0))
+ self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+ # And check that the classes __eq__ is being used, despite
+ # specifying eq=True.
+ @dataclass(unsafe_hash=True, eq=True)
+ class C:
+ i: int
+ def __eq__(self, other):
+ return self.i == 3 and self.i == other.i
+ self.assertEqual(C(3), C(3))
+ self.assertNotEqual(C(1), C(1))
+ self.assertEqual(hash(C(1)), hash(C(1.0)))
+
+ def test_0_field_hash(self):
+ @dataclass(frozen=True)
+ class C:
+ pass
+ self.assertEqual(hash(C()), hash(()))
+
+ @dataclass(unsafe_hash=True)
+ class C:
+ pass
+ self.assertEqual(hash(C()), hash(()))
+
+ def test_1_field_hash(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ self.assertEqual(hash(C(4)), hash((4,)))
+ self.assertEqual(hash(C(42)), hash((42,)))
+
+ @dataclass(unsafe_hash=True)
+ class C:
+ x: int
+ self.assertEqual(hash(C(4)), hash((4,)))
+ self.assertEqual(hash(C(42)), hash((42,)))
+
+ def test_hash_no_args(self):
+ # Test dataclasses with no hash= argument. This exists to
+ # make sure that if the @dataclass parameter name is changed
+ # or the non-default hashing behavior changes, the default
+ # hashability keeps working the same way.
+
+ class Base:
+ def __hash__(self):
+ return 301
+
+ # If frozen or eq is None, then use the default value (do not
+ # specify any value in the decorator).
+ for frozen, eq, base, expected in [
+ (None, None, object, 'unhashable'),
+ (None, None, Base, 'unhashable'),
+ (None, False, object, 'object'),
+ (None, False, Base, 'base'),
+ (None, True, object, 'unhashable'),
+ (None, True, Base, 'unhashable'),
+ (False, None, object, 'unhashable'),
+ (False, None, Base, 'unhashable'),
+ (False, False, object, 'object'),
+ (False, False, Base, 'base'),
+ (False, True, object, 'unhashable'),
+ (False, True, Base, 'unhashable'),
+ (True, None, object, 'tuple'),
+ (True, None, Base, 'tuple'),
+ (True, False, object, 'object'),
+ (True, False, Base, 'base'),
+ (True, True, object, 'tuple'),
+ (True, True, Base, 'tuple'),
+ ]:
+
+ with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
+ # First, create the class.
+ if frozen is None and eq is None:
+ @dataclass
+ class C(base):
+ i: int
+ elif frozen is None:
+ @dataclass(eq=eq)
+ class C(base):
+ i: int
+ elif eq is None:
+ @dataclass(frozen=frozen)
+ class C(base):
+ i: int
+ else:
+ @dataclass(frozen=frozen, eq=eq)
+ class C(base):
+ i: int
+
+ # Now, make sure it hashes as expected.
+ if expected == 'unhashable':
+ c = C(10)
+ with self.assertRaisesRegex(TypeError, 'unhashable type'):
+ hash(c)
+
+ elif expected == 'base':
+ self.assertEqual(hash(C(10)), 301)
+
+ elif expected == 'object':
+ # I'm not sure what test to use here. object's
+ # hash isn't based on id(), so calling hash()
+ # won't tell us much. So, just check the
+ # function used is object's.
+ self.assertIs(C.__hash__, object.__hash__)
+
+ elif expected == 'tuple':
+ self.assertEqual(hash(C(42)), hash((42,)))
+
+ else:
+ assert False, f'unknown value for expected={expected!r}'
+
+
+class TestFrozen(unittest.TestCase):
+ def test_frozen(self):
+ @dataclass(frozen=True)
+ class C:
+ i: int
+
+ c = C(10)
+ self.assertEqual(c.i, 10)
+ with self.assertRaises(FrozenInstanceError):
+ c.i = 5
+ self.assertEqual(c.i, 10)
+
+ def test_inherit(self):
+ @dataclass(frozen=True)
+ class C:
+ i: int
+
+ @dataclass(frozen=True)
+ class D(C):
+ j: int
+
+ d = D(0, 10)
+ with self.assertRaises(FrozenInstanceError):
+ d.i = 5
+ with self.assertRaises(FrozenInstanceError):
+ d.j = 6
+ self.assertEqual(d.i, 0)
+ self.assertEqual(d.j, 10)
+
+ def test_inherit_nonfrozen_from_empty_frozen(self):
+ @dataclass(frozen=True)
+ class C:
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'cannot inherit non-frozen dataclass from a frozen one'):
+ @dataclass
+ class D(C):
+ j: int
+
+ def test_inherit_nonfrozen_from_empty(self):
+ @dataclass
+ class C:
+ pass
+
+ @dataclass
+ class D(C):
+ j: int
+
+ d = D(3)
+ self.assertEqual(d.j, 3)
+ self.assertIsInstance(d, C)
+
+ # Test both ways: with an intermediate normal (non-dataclass)
+ # class and without an intermediate class.
+ def test_inherit_nonfrozen_from_frozen(self):
+ for intermediate_class in [True, False]:
+ with self.subTest(intermediate_class=intermediate_class):
+ @dataclass(frozen=True)
+ class C:
+ i: int
+
+ if intermediate_class:
+ class I(C): pass
+ else:
+ I = C
+
+ with self.assertRaisesRegex(TypeError,
+ 'cannot inherit non-frozen dataclass from a frozen one'):
+ @dataclass
+ class D(I):
+ pass
+
+ def test_inherit_frozen_from_nonfrozen(self):
+ for intermediate_class in [True, False]:
+ with self.subTest(intermediate_class=intermediate_class):
+ @dataclass
+ class C:
+ i: int
+
+ if intermediate_class:
+ class I(C): pass
+ else:
+ I = C
+
+ with self.assertRaisesRegex(TypeError,
+ 'cannot inherit frozen dataclass from a non-frozen one'):
+ @dataclass(frozen=True)
+ class D(I):
+ pass
+
+ def test_inherit_from_normal_class(self):
+ for intermediate_class in [True, False]:
+ with self.subTest(intermediate_class=intermediate_class):
+ class C:
+ pass
+
+ if intermediate_class:
+ class I(C): pass
+ else:
+ I = C
+
+ @dataclass(frozen=True)
+ class D(I):
+ i: int
+
+ d = D(10)
+ with self.assertRaises(FrozenInstanceError):
+ d.i = 5
+
+ def test_non_frozen_normal_derived(self):
+ # See bpo-32953.
+
+ @dataclass(frozen=True)
+ class D:
+ x: int
+ y: int = 10
+
+ class S(D):
+ pass
+
+ s = S(3)
+ self.assertEqual(s.x, 3)
+ self.assertEqual(s.y, 10)
+ s.cached = True
+
+ # But can't change the frozen attributes.
+ with self.assertRaises(FrozenInstanceError):
+ s.x = 5
+ with self.assertRaises(FrozenInstanceError):
+ s.y = 5
+ self.assertEqual(s.x, 3)
+ self.assertEqual(s.y, 10)
+ self.assertEqual(s.cached, True)
+
+ def test_overwriting_frozen(self):
+ # frozen uses __setattr__ and __delattr__.
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __setattr__'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __setattr__(self):
+ pass
+
+ with self.assertRaisesRegex(TypeError,
+ 'Cannot overwrite attribute __delattr__'):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ def __delattr__(self):
+ pass
+
+ @dataclass(frozen=False)
+ class C:
+ x: int
+ def __setattr__(self, name, value):
+ self.__dict__['x'] = value * 2
+ self.assertEqual(C(10).x, 20)
+
+ def test_frozen_hash(self):
+ @dataclass(frozen=True)
+ class C:
+ x: Any
+
+ # If x is immutable, we can compute the hash. No exception is
+ # raised.
+ hash(C(3))
+
+ # If x is mutable, computing the hash is an error.
+ with self.assertRaisesRegex(TypeError, 'unhashable type'):
+ hash(C({}))
+
+
+class TestSlots(unittest.TestCase):
+ def test_simple(self):
+ @dataclass
+ class C:
+ __slots__ = ('x',)
+ x: Any
+
+ # There was a bug where a variable in a slot was assumed to
+ # also have a default value (of type
+ # types.MemberDescriptorType).
+ with self.assertRaisesRegex(TypeError,
+ r"__init__\(\) missing 1 required positional argument: 'x'"):
+ C()
+
+ # We can create an instance, and assign to x.
+ c = C(10)
+ self.assertEqual(c.x, 10)
+ c.x = 5
+ self.assertEqual(c.x, 5)
+
+ # We can't assign to anything else.
+ with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
+ c.y = 5
+
+ def test_derived_added_field(self):
+ # See bpo-33100.
+ @dataclass
+ class Base:
+ __slots__ = ('x',)
+ x: Any
+
+ @dataclass
+ class Derived(Base):
+ x: int
+ y: int
+
+ d = Derived(1, 2)
+ self.assertEqual((d.x, d.y), (1, 2))
+
+ # We can add a new field to the derived instance.
+ d.z = 10
+
+ def test_generated_slots(self):
+ @dataclass(slots=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ self.assertEqual((c.x, c.y), (1, 2))
+
+ c.x = 3
+ c.y = 4
+ self.assertEqual((c.x, c.y), (3, 4))
+
+ with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
+ c.z = 5
+
+ def test_add_slots_when_slots_exists(self):
+ with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
+ @dataclass(slots=True)
+ class C:
+ __slots__ = ('x',)
+ x: int
+
+ def test_generated_slots_value(self):
+
+ class Root:
+ __slots__ = {'x'}
+
+ class Root2(Root):
+ __slots__ = {'k': '...', 'j': ''}
+
+ class Root3(Root2):
+ __slots__ = ['h']
+
+ class Root4(Root3):
+ __slots__ = 'aa'
+
+ @dataclass(slots=True)
+ class Base(Root4):
+ y: int
+ j: str
+ h: str
+
+ self.assertEqual(Base.__slots__, ('y', ))
+
+ @dataclass(slots=True)
+ class Derived(Base):
+ aa: float
+ x: str
+ z: int
+ k: str
+ h: str
+
+ self.assertEqual(Derived.__slots__, ('z', ))
+
+ @dataclass
+ class AnotherDerived(Base):
+ z: int
+
+ self.assertNotIn('__slots__', AnotherDerived.__dict__)
+
+ def test_cant_inherit_from_iterator_slots(self):
+
+ class Root:
+ __slots__ = iter(['a'])
+
+ class Root2(Root):
+ __slots__ = ('b', )
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "^Slots of 'Root' cannot be determined"
+ ):
+ @dataclass(slots=True)
+ class C(Root2):
+ x: int
+
+ def test_returns_new_class(self):
+ class A:
+ x: int
+
+ B = dataclass(A, slots=True)
+ self.assertIsNot(A, B)
+
+ self.assertFalse(hasattr(A, "__slots__"))
+ self.assertTrue(hasattr(B, "__slots__"))
+
+ # Can't be local to test_frozen_pickle.
+ @dataclass(frozen=True, slots=True)
+ class FrozenSlotsClass:
+ foo: str
+ bar: int
+
+ @dataclass(frozen=True)
+ class FrozenWithoutSlotsClass:
+ foo: str
+ bar: int
+
+ def test_frozen_pickle(self):
+ # bpo-43999
+
+ self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
+ with self.subTest(proto=proto):
+ obj = self.FrozenSlotsClass("a", 1)
+ p = pickle.loads(pickle.dumps(obj, protocol=proto))
+ self.assertIsNot(obj, p)
+ self.assertEqual(obj, p)
+
+ obj = self.FrozenWithoutSlotsClass("a", 1)
+ p = pickle.loads(pickle.dumps(obj, protocol=proto))
+ self.assertIsNot(obj, p)
+ self.assertEqual(obj, p)
+
+ def test_slots_with_default_no_init(self):
+ # Originally reported in bpo-44649.
+ @dataclass(slots=True)
+ class A:
+ a: str
+ b: str = field(default='b', init=False)
+
+ obj = A("a")
+ self.assertEqual(obj.a, 'a')
+ self.assertEqual(obj.b, 'b')
+
+ def test_slots_with_default_factory_no_init(self):
+ # Originally reported in bpo-44649.
+ @dataclass(slots=True)
+ class A:
+ a: str
+ b: str = field(default_factory=lambda:'b', init=False)
+
+ obj = A("a")
+ self.assertEqual(obj.a, 'a')
+ self.assertEqual(obj.b, 'b')
+
+ def test_slots_no_weakref(self):
+ @dataclass(slots=True)
+ class A:
+ # No weakref.
+ pass
+
+ self.assertNotIn("__weakref__", A.__slots__)
+ a = A()
+ with self.assertRaisesRegex(TypeError,
+ "cannot create weak reference"):
+ weakref.ref(a)
+
+ def test_slots_weakref(self):
+ @dataclass(slots=True, weakref_slot=True)
+ class A:
+ a: int
+
+ self.assertIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+ def test_slots_weakref_base_str(self):
+ class Base:
+ __slots__ = '__weakref__'
+
+ @dataclass(slots=True)
+ class A(Base):
+ a: int
+
+ # __weakref__ is in the base class, not A. But an A is still weakref-able.
+ self.assertIn("__weakref__", Base.__slots__)
+ self.assertNotIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+ def test_slots_weakref_base_tuple(self):
+ # Same as test_slots_weakref_base, but use a tuple instead of a string
+ # in the base class.
+ class Base:
+ __slots__ = ('__weakref__',)
+
+ @dataclass(slots=True)
+ class A(Base):
+ a: int
+
+ # __weakref__ is in the base class, not A. But an A is still
+ # weakref-able.
+ self.assertIn("__weakref__", Base.__slots__)
+ self.assertNotIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+ def test_weakref_slot_without_slot(self):
+ with self.assertRaisesRegex(TypeError,
+ "weakref_slot is True but slots is False"):
+ @dataclass(weakref_slot=True)
+ class A:
+ a: int
+
+ def test_weakref_slot_make_dataclass(self):
+ A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True)
+ self.assertIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+ # And make sure if raises if slots=True is not given.
+ with self.assertRaisesRegex(TypeError,
+ "weakref_slot is True but slots is False"):
+ B = make_dataclass('B', [('a', int),], weakref_slot=True)
+
+ def test_weakref_slot_subclass_weakref_slot(self):
+ @dataclass(slots=True, weakref_slot=True)
+ class Base:
+ field: int
+
+ # A *can* also specify weakref_slot=True if it wants to (gh-93521)
+ @dataclass(slots=True, weakref_slot=True)
+ class A(Base):
+ ...
+
+ # __weakref__ is in the base class, not A. But an instance of A
+ # is still weakref-able.
+ self.assertIn("__weakref__", Base.__slots__)
+ self.assertNotIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+ def test_weakref_slot_subclass_no_weakref_slot(self):
+ @dataclass(slots=True, weakref_slot=True)
+ class Base:
+ field: int
+
+ @dataclass(slots=True)
+ class A(Base):
+ ...
+
+ # __weakref__ is in the base class, not A. Even though A doesn't
+ # specify weakref_slot, it should still be weakref-able.
+ self.assertIn("__weakref__", Base.__slots__)
+ self.assertNotIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+ def test_weakref_slot_normal_base_weakref_slot(self):
+ class Base:
+ __slots__ = ('__weakref__',)
+
+ @dataclass(slots=True, weakref_slot=True)
+ class A(Base):
+ field: int
+
+ # __weakref__ is in the base class, not A. But an instance of
+ # A is still weakref-able.
+ self.assertIn("__weakref__", Base.__slots__)
+ self.assertNotIn("__weakref__", A.__slots__)
+ a = A(1)
+ weakref.ref(a)
+
+
+class TestDescriptors(unittest.TestCase):
+ def test_set_name(self):
+ # See bpo-33141.
+
+ # Create a descriptor.
+ class D:
+ def __set_name__(self, owner, name):
+ self.name = name + 'x'
+ def __get__(self, instance, owner):
+ if instance is not None:
+ return 1
+ return self
+
+ # This is the case of just normal descriptor behavior, no
+ # dataclass code is involved in initializing the descriptor.
+ @dataclass
+ class C:
+ c: int=D()
+ self.assertEqual(C.c.name, 'cx')
+
+ # Now test with a default value and init=False, which is the
+ # only time this is really meaningful. If not using
+ # init=False, then the descriptor will be overwritten, anyway.
+ @dataclass
+ class C:
+ c: int=field(default=D(), init=False)
+ self.assertEqual(C.c.name, 'cx')
+ self.assertEqual(C().c, 1)
+
+ def test_non_descriptor(self):
+ # PEP 487 says __set_name__ should work on non-descriptors.
+ # Create a descriptor.
+
+ class D:
+ def __set_name__(self, owner, name):
+ self.name = name + 'x'
+
+ @dataclass
+ class C:
+ c: int=field(default=D(), init=False)
+ self.assertEqual(C.c.name, 'cx')
+
+ def test_lookup_on_instance(self):
+ # See bpo-33175.
+ class D:
+ pass
+
+ d = D()
+ # Create an attribute on the instance, not type.
+ d.__set_name__ = Mock()
+
+ # Make sure d.__set_name__ is not called.
+ @dataclass
+ class C:
+ i: int=field(default=d, init=False)
+
+ self.assertEqual(d.__set_name__.call_count, 0)
+
+ def test_lookup_on_class(self):
+ # See bpo-33175.
+ class D:
+ pass
+ D.__set_name__ = Mock()
+
+ # Make sure D.__set_name__ is called.
+ @dataclass
+ class C:
+ i: int=field(default=D(), init=False)
+
+ self.assertEqual(D.__set_name__.call_count, 1)
+
+ def test_init_calls_set(self):
+ class D:
+ pass
+
+ D.__set__ = Mock()
+
+ @dataclass
+ class C:
+ i: D = D()
+
+ # Make sure D.__set__ is called.
+ D.__set__.reset_mock()
+ c = C(5)
+ self.assertEqual(D.__set__.call_count, 1)
+
+ def test_getting_field_calls_get(self):
+ class D:
+ pass
+
+ D.__set__ = Mock()
+ D.__get__ = Mock()
+
+ @dataclass
+ class C:
+ i: D = D()
+
+ c = C(5)
+
+ # Make sure D.__get__ is called.
+ D.__get__.reset_mock()
+ value = c.i
+ self.assertEqual(D.__get__.call_count, 1)
+
+ def test_setting_field_calls_set(self):
+ class D:
+ pass
+
+ D.__set__ = Mock()
+
+ @dataclass
+ class C:
+ i: D = D()
+
+ c = C(5)
+
+ # Make sure D.__set__ is called.
+ D.__set__.reset_mock()
+ c.i = 10
+ self.assertEqual(D.__set__.call_count, 1)
+
+ def test_setting_uninitialized_descriptor_field(self):
+ class D:
+ pass
+
+ D.__set__ = Mock()
+
+ @dataclass
+ class C:
+ i: D
+
+ # D.__set__ is not called because there's no D instance to call it on
+ D.__set__.reset_mock()
+ c = C(5)
+ self.assertEqual(D.__set__.call_count, 0)
+
+ # D.__set__ still isn't called after setting i to an instance of D
+ # because descriptors don't behave like that when stored as instance vars
+ c.i = D()
+ c.i = 5
+ self.assertEqual(D.__set__.call_count, 0)
+
+ def test_default_value(self):
+ class D:
+ def __get__(self, instance: Any, owner: object) -> int:
+ if instance is None:
+ return 100
+
+ return instance._x
+
+ def __set__(self, instance: Any, value: int) -> None:
+ instance._x = value
+
+ @dataclass
+ class C:
+ i: D = D()
+
+ c = C()
+ self.assertEqual(c.i, 100)
+
+ c = C(5)
+ self.assertEqual(c.i, 5)
+
+ def test_no_default_value(self):
+ class D:
+ def __get__(self, instance: Any, owner: object) -> int:
+ if instance is None:
+ raise AttributeError()
+
+ return instance._x
+
+ def __set__(self, instance: Any, value: int) -> None:
+ instance._x = value
+
+ @dataclass
+ class C:
+ i: D = D()
+
+ with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'):
+ c = C()
+
+class TestStringAnnotations(unittest.TestCase):
+ def test_classvar(self):
+ # Some expressions recognized as ClassVar really aren't. But
+ # if you're using string annotations, it's not an exact
+ # science.
+ # These tests assume that both "import typing" and "from
+ # typing import *" have been run in this file.
+ for typestr in ('ClassVar[int]',
+ 'ClassVar [int]',
+ ' ClassVar [int]',
+ 'ClassVar',
+ ' ClassVar ',
+ 'typing.ClassVar[int]',
+ 'typing.ClassVar[str]',
+ ' typing.ClassVar[str]',
+ 'typing .ClassVar[str]',
+ 'typing. ClassVar[str]',
+ 'typing.ClassVar [str]',
+ 'typing.ClassVar [ str]',
+
+ # Not syntactically valid, but these will
+ # be treated as ClassVars.
+ 'typing.ClassVar.[int]',
+ 'typing.ClassVar+',
+ ):
+ with self.subTest(typestr=typestr):
+ @dataclass
+ class C:
+ x: typestr
+
+ # x is a ClassVar, so C() takes no args.
+ C()
+
+ # And it won't appear in the class's dict because it doesn't
+ # have a default.
+ self.assertNotIn('x', C.__dict__)
+
+ def test_isnt_classvar(self):
+ for typestr in ('CV',
+ 't.ClassVar',
+ 't.ClassVar[int]',
+ 'typing..ClassVar[int]',
+ 'Classvar',
+ 'Classvar[int]',
+ 'typing.ClassVarx[int]',
+ 'typong.ClassVar[int]',
+ 'dataclasses.ClassVar[int]',
+ 'typingxClassVar[str]',
+ ):
+ with self.subTest(typestr=typestr):
+ @dataclass
+ class C:
+ x: typestr
+
+ # x is not a ClassVar, so C() takes one arg.
+ self.assertEqual(C(10).x, 10)
+
+ def test_initvar(self):
+ # These tests assume that both "import dataclasses" and "from
+ # dataclasses import *" have been run in this file.
+ for typestr in ('InitVar[int]',
+ 'InitVar [int]'
+ ' InitVar [int]',
+ 'InitVar',
+ ' InitVar ',
+ 'dataclasses.InitVar[int]',
+ 'dataclasses.InitVar[str]',
+ ' dataclasses.InitVar[str]',
+ 'dataclasses .InitVar[str]',
+ 'dataclasses. InitVar[str]',
+ 'dataclasses.InitVar [str]',
+ 'dataclasses.InitVar [ str]',
+
+ # Not syntactically valid, but these will
+ # be treated as InitVars.
+ 'dataclasses.InitVar.[int]',
+ 'dataclasses.InitVar+',
+ ):
+ with self.subTest(typestr=typestr):
+ @dataclass
+ class C:
+ x: typestr
+
+ # x is an InitVar, so doesn't create a member.
+ with self.assertRaisesRegex(AttributeError,
+ "object has no attribute 'x'"):
+ C(1).x
+
+ def test_isnt_initvar(self):
+ for typestr in ('IV',
+ 'dc.InitVar',
+ 'xdataclasses.xInitVar',
+ 'typing.xInitVar[int]',
+ ):
+ with self.subTest(typestr=typestr):
+ @dataclass
+ class C:
+ x: typestr
+
+ # x is not an InitVar, so there will be a member x.
+ self.assertEqual(C(10).x, 10)
+
+ def test_classvar_module_level_import(self):
+ from test import dataclass_module_1
+ from test import dataclass_module_1_str
+ from test import dataclass_module_2
+ from test import dataclass_module_2_str
+
+ for m in (dataclass_module_1, dataclass_module_1_str,
+ dataclass_module_2, dataclass_module_2_str,
+ ):
+ with self.subTest(m=m):
+ # There's a difference in how the ClassVars are
+ # interpreted when using string annotations or
+ # not. See the imported modules for details.
+ if m.USING_STRINGS:
+ c = m.CV(10)
+ else:
+ c = m.CV()
+ self.assertEqual(c.cv0, 20)
+
+
+ # There's a difference in how the InitVars are
+ # interpreted when using string annotations or
+ # not. See the imported modules for details.
+ c = m.IV(0, 1, 2, 3, 4)
+
+ for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
+ with self.subTest(field_name=field_name):
+ with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
+ # Since field_name is an InitVar, it's
+ # not an instance field.
+ getattr(c, field_name)
+
+ if m.USING_STRINGS:
+ # iv4 is interpreted as a normal field.
+ self.assertIn('not_iv4', c.__dict__)
+ self.assertEqual(c.not_iv4, 4)
+ else:
+ # iv4 is interpreted as an InitVar, so it
+ # won't exist on the instance.
+ self.assertNotIn('not_iv4', c.__dict__)
+
+ def test_text_annotations(self):
+ from test import dataclass_textanno
+
+ self.assertEqual(
+ get_type_hints(dataclass_textanno.Bar),
+ {'foo': dataclass_textanno.Foo})
+ self.assertEqual(
+ get_type_hints(dataclass_textanno.Bar.__init__),
+ {'foo': dataclass_textanno.Foo,
+ 'return': type(None)})
+
+
+class TestMakeDataclass(unittest.TestCase):
+ def test_simple(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', int, field(default=5))],
+ namespace={'add_one': lambda self: self.x + 1})
+ c = C(10)
+ self.assertEqual((c.x, c.y), (10, 5))
+ self.assertEqual(c.add_one(), 11)
+
+
+ def test_no_mutate_namespace(self):
+ # Make sure a provided namespace isn't mutated.
+ ns = {}
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', int, field(default=5))],
+ namespace=ns)
+ self.assertEqual(ns, {})
+
+ def test_base(self):
+ class Base1:
+ pass
+ class Base2:
+ pass
+ C = make_dataclass('C',
+ [('x', int)],
+ bases=(Base1, Base2))
+ c = C(2)
+ self.assertIsInstance(c, C)
+ self.assertIsInstance(c, Base1)
+ self.assertIsInstance(c, Base2)
+
+ def test_base_dataclass(self):
+ @dataclass
+ class Base1:
+ x: int
+ class Base2:
+ pass
+ C = make_dataclass('C',
+ [('y', int)],
+ bases=(Base1, Base2))
+ with self.assertRaisesRegex(TypeError, 'required positional'):
+ c = C(2)
+ c = C(1, 2)
+ self.assertIsInstance(c, C)
+ self.assertIsInstance(c, Base1)
+ self.assertIsInstance(c, Base2)
+
+ self.assertEqual((c.x, c.y), (1, 2))
+
+ def test_init_var(self):
+ def post_init(self, y):
+ self.x *= y
+
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', InitVar[int]),
+ ],
+ namespace={'__post_init__': post_init},
+ )
+ c = C(2, 3)
+ self.assertEqual(vars(c), {'x': 6})
+ self.assertEqual(len(fields(c)), 1)
+
+ def test_class_var(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', ClassVar[int], 10),
+ ('z', ClassVar[int], field(default=20)),
+ ])
+ c = C(1)
+ self.assertEqual(vars(c), {'x': 1})
+ self.assertEqual(len(fields(c)), 1)
+ self.assertEqual(C.y, 10)
+ self.assertEqual(C.z, 20)
+
+ def test_other_params(self):
+ C = make_dataclass('C',
+ [('x', int),
+ ('y', ClassVar[int], 10),
+ ('z', ClassVar[int], field(default=20)),
+ ],
+ init=False)
+ # Make sure we have a repr, but no init.
+ self.assertNotIn('__init__', vars(C))
+ self.assertIn('__repr__', vars(C))
+
+ # Make sure random other params don't work.
+ with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
+ C = make_dataclass('C',
+ [],
+ xxinit=False)
+
+ def test_no_types(self):
+ C = make_dataclass('Point', ['x', 'y', 'z'])
+ c = C(1, 2, 3)
+ self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
+ self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ 'y': 'typing.Any',
+ 'z': 'typing.Any'})
+
+ C = make_dataclass('Point', ['x', ('y', int), 'z'])
+ c = C(1, 2, 3)
+ self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
+ self.assertEqual(C.__annotations__, {'x': 'typing.Any',
+ 'y': int,
+ 'z': 'typing.Any'})
+
+ def test_invalid_type_specification(self):
+ for bad_field in [(),
+ (1, 2, 3, 4),
+ ]:
+ with self.subTest(bad_field=bad_field):
+ with self.assertRaisesRegex(TypeError, r'Invalid field: '):
+ make_dataclass('C', ['a', bad_field])
+
+ # And test for things with no len().
+ for bad_field in [float,
+ lambda x:x,
+ ]:
+ with self.subTest(bad_field=bad_field):
+ with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
+ make_dataclass('C', ['a', bad_field])
+
+ def test_duplicate_field_names(self):
+ for field in ['a', 'ab']:
+ with self.subTest(field=field):
+ with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
+ make_dataclass('C', [field, 'a', field])
+
+ def test_keyword_field_names(self):
+ for field in ['for', 'async', 'await', 'as']:
+ with self.subTest(field=field):
+ with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+ make_dataclass('C', ['a', field])
+ with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+ make_dataclass('C', [field])
+ with self.assertRaisesRegex(TypeError, 'must not be keywords'):
+ make_dataclass('C', [field, 'a'])
+
+ def test_non_identifier_field_names(self):
+ for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
+ with self.subTest(field=field):
+ with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
+ make_dataclass('C', ['a', field])
+ with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
+ make_dataclass('C', [field])
+ with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
+ make_dataclass('C', [field, 'a'])
+
+ def test_underscore_field_names(self):
+ # Unlike namedtuple, it's okay if dataclass field names have
+ # an underscore.
+ make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
+
+ def test_funny_class_names_names(self):
+ # No reason to prevent weird class names, since
+ # types.new_class allows them.
+ for classname in ['()', 'x,y', '*', '2@3', '']:
+ with self.subTest(classname=classname):
+ C = make_dataclass(classname, ['a', 'b'])
+ self.assertEqual(C.__name__, classname)
+
+class TestReplace(unittest.TestCase):
+ def test(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ c1 = replace(c, x=3)
+ self.assertEqual(c1.x, 3)
+ self.assertEqual(c1.y, 2)
+
+ def test_frozen(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+ z: int = field(init=False, default=10)
+ t: int = field(init=False, default=100)
+
+ c = C(1, 2)
+ c1 = replace(c, x=3)
+ self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
+ self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
+
+
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, x=3, z=20, t=50)
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, z=20)
+ replace(c, x=3, z=20, t=50)
+
+ # Make sure the result is still frozen.
+ with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
+ c1.x = 3
+
+ # Make sure we can't replace an attribute that doesn't exist,
+ # if we're also replacing one that does exist. Test this
+ # here, because setting attributes on frozen instances is
+ # handled slightly differently from non-frozen ones.
+ with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
+ "keyword argument 'a'"):
+ c1 = replace(c, x=20, a=5)
+
+ def test_invalid_field_name(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ c = C(1, 2)
+ with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
+ "keyword argument 'z'"):
+ c1 = replace(c, z=3)
+
+ def test_invalid_object(self):
+ @dataclass(frozen=True)
+ class C:
+ x: int
+ y: int
+
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ replace(C, x=3)
+
+ with self.assertRaisesRegex(TypeError, 'dataclass instance'):
+ replace(0, x=3)
+
+ def test_no_init(self):
+ @dataclass
+ class C:
+ x: int
+ y: int = field(init=False, default=10)
+
+ c = C(1)
+ c.y = 20
+
+ # Make sure y gets the default value.
+ c1 = replace(c, x=5)
+ self.assertEqual((c1.x, c1.y), (5, 10))
+
+ # Trying to replace y is an error.
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, x=2, y=30)
+
+ with self.assertRaisesRegex(ValueError, 'init=False'):
+ replace(c, y=30)
+
+ def test_classvar(self):
+ @dataclass
+ class C:
+ x: int
+ y: ClassVar[int] = 1000
+
+ c = C(1)
+ d = C(2)
+
+ self.assertIs(c.y, d.y)
+ self.assertEqual(c.y, 1000)
+
+ # Trying to replace y is an error: can't replace ClassVars.
+ with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
+ "unexpected keyword argument 'y'"):
+ replace(c, y=30)
+
+ replace(c, x=5)
+
+ def test_initvar_is_specified(self):
+ @dataclass
+ class C:
+ x: int
+ y: InitVar[int]
+
+ def __post_init__(self, y):
+ self.x *= y
+
+ c = C(1, 10)
+ self.assertEqual(c.x, 10)
+ with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
+ "specified with replace()"):
+ replace(c, x=3)
+ c = replace(c, x=3, y=5)
+ self.assertEqual(c.x, 15)
+
+ def test_initvar_with_default_value(self):
+ @dataclass
+ class C:
+ x: int
+ y: InitVar[int] = None
+ z: InitVar[int] = 42
+
+ def __post_init__(self, y, z):
+ if y is not None:
+ self.x += y
+ if z is not None:
+ self.x += z
+
+ c = C(x=1, y=10, z=1)
+ self.assertEqual(replace(c), C(x=12))
+ self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
+ self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))
+
+ def test_recursive_repr(self):
+ @dataclass
+ class C:
+ f: "C"
+
+ c = C(None)
+ c.f = c
+ self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")
+
+ def test_recursive_repr_two_attrs(self):
+ @dataclass
+ class C:
+ f: "C"
+ g: "C"
+
+ c = C(None, None)
+ c.f = c
+ c.g = c
+ self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
+ ".<locals>.C(f=..., g=...)")
+
+ def test_recursive_repr_indirection(self):
+ @dataclass
+ class C:
+ f: "D"
+
+ @dataclass
+ class D:
+ f: "C"
+
+ c = C(None)
+ d = D(None)
+ c.f = d
+ d.f = c
+ self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
+ ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
+ ".<locals>.D(f=...))")
+
+ def test_recursive_repr_indirection_two(self):
+ @dataclass
+ class C:
+ f: "D"
+
+ @dataclass
+ class D:
+ f: "E"
+
+ @dataclass
+ class E:
+ f: "C"
+
+ c = C(None)
+ d = D(None)
+ e = E(None)
+ c.f = d
+ d.f = e
+ e.f = c
+ self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
+ ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
+ ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
+ ".<locals>.E(f=...)))")
+
+ def test_recursive_repr_misc_attrs(self):
+ @dataclass
+ class C:
+ f: "C"
+ g: int
+
+ c = C(None, 1)
+ c.f = c
+ self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
+ ".<locals>.C(f=..., g=1)")
+
+ ## def test_initvar(self):
+ ## @dataclass
+ ## class C:
+ ## x: int
+ ## y: InitVar[int]
+
+ ## c = C(1, 10)
+ ## d = C(2, 20)
+
+ ## # In our case, replacing an InitVar is a no-op
+ ## self.assertEqual(c, replace(c, y=5))
+
+ ## replace(c, x=5)
+
+class TestAbstract(unittest.TestCase):
+ def test_abc_implementation(self):
+ class Ordered(abc.ABC):
+ @abc.abstractmethod
+ def __lt__(self, other):
+ pass
+
+ @abc.abstractmethod
+ def __le__(self, other):
+ pass
+
+ @dataclass(order=True)
+ class Date(Ordered):
+ year: int
+ month: 'Month'
+ day: 'int'
+
+ self.assertFalse(inspect.isabstract(Date))
+ self.assertGreater(Date(2020,12,25), Date(2020,8,31))
+
+ def test_maintain_abc(self):
+ class A(abc.ABC):
+ @abc.abstractmethod
+ def foo(self):
+ pass
+
+ @dataclass
+ class Date(A):
+ year: int
+ month: 'Month'
+ day: 'int'
+
+ self.assertTrue(inspect.isabstract(Date))
+ msg = 'class Date without an implementation for abstract method foo'
+ self.assertRaisesRegex(TypeError, msg, Date)
+
+
+class TestMatchArgs(unittest.TestCase):
+ def test_match_args(self):
+ @dataclass
+ class C:
+ a: int
+ self.assertEqual(C(42).__match_args__, ('a',))
+
+ def test_explicit_match_args(self):
+ ma = ()
+ @dataclass
+ class C:
+ a: int
+ __match_args__ = ma
+ self.assertIs(C(42).__match_args__, ma)
+
+ def test_bpo_43764(self):
+ @dataclass(repr=False, eq=False, init=False)
+ class X:
+ a: int
+ b: int
+ c: int
+ self.assertEqual(X.__match_args__, ("a", "b", "c"))
+
+ def test_match_args_argument(self):
+ @dataclass(match_args=False)
+ class X:
+ a: int
+ self.assertNotIn('__match_args__', X.__dict__)
+
+ @dataclass(match_args=False)
+ class Y:
+ a: int
+ __match_args__ = ('b',)
+ self.assertEqual(Y.__match_args__, ('b',))
+
+ @dataclass(match_args=False)
+ class Z(Y):
+ z: int
+ self.assertEqual(Z.__match_args__, ('b',))
+
+ # Ensure parent dataclass __match_args__ is seen, if child class
+ # specifies match_args=False.
+ @dataclass
+ class A:
+ a: int
+ z: int
+ @dataclass(match_args=False)
+ class B(A):
+ b: int
+ self.assertEqual(B.__match_args__, ('a', 'z'))
+
+ def test_make_dataclasses(self):
+ C = make_dataclass('C', [('x', int), ('y', int)])
+ self.assertEqual(C.__match_args__, ('x', 'y'))
+
+ C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
+ self.assertEqual(C.__match_args__, ('x', 'y'))
+
+ C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
+ self.assertNotIn('__match__args__', C.__dict__)
+
+ C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
+ self.assertEqual(C.__match_args__, ('z',))
+
+
+class TestKeywordArgs(unittest.TestCase):
+ def test_no_classvar_kwarg(self):
+ msg = 'field a is a ClassVar but specifies kw_only'
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: ClassVar[int] = field(kw_only=True)
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: ClassVar[int] = field(kw_only=False)
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass(kw_only=True)
+ class A:
+ a: ClassVar[int] = field(kw_only=False)
+
+ def test_field_marked_as_kwonly(self):
+ #######################
+ # Using dataclass(kw_only=True)
+ @dataclass(kw_only=True)
+ class A:
+ a: int
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=True)
+ class A:
+ a: int = field(kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=True)
+ class A:
+ a: int = field(kw_only=False)
+ self.assertFalse(fields(A)[0].kw_only)
+
+ #######################
+ # Using dataclass(kw_only=False)
+ @dataclass(kw_only=False)
+ class A:
+ a: int
+ self.assertFalse(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=False)
+ class A:
+ a: int = field(kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=False)
+ class A:
+ a: int = field(kw_only=False)
+ self.assertFalse(fields(A)[0].kw_only)
+
+ #######################
+ # Not specifying dataclass(kw_only)
+ @dataclass
+ class A:
+ a: int
+ self.assertFalse(fields(A)[0].kw_only)
+
+ @dataclass
+ class A:
+ a: int = field(kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass
+ class A:
+ a: int = field(kw_only=False)
+ self.assertFalse(fields(A)[0].kw_only)
+
+ def test_match_args(self):
+ # kw fields don't show up in __match_args__.
+ @dataclass(kw_only=True)
+ class C:
+ a: int
+ self.assertEqual(C(a=42).__match_args__, ())
+
+ @dataclass
+ class C:
+ a: int
+ b: int = field(kw_only=True)
+ self.assertEqual(C(42, b=10).__match_args__, ('a',))
+
+ def test_KW_ONLY(self):
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int
+ A(3, c=5, b=4)
+ msg = "takes 2 positional arguments but 4 were given"
+ with self.assertRaisesRegex(TypeError, msg):
+ A(3, 4, 5)
+
+
+ @dataclass(kw_only=True)
+ class B:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int
+ B(a=3, b=4, c=5)
+ msg = "takes 1 positional argument but 4 were given"
+ with self.assertRaisesRegex(TypeError, msg):
+ B(3, 4, 5)
+
+ # Explicitly make a field that follows KW_ONLY be non-keyword-only.
+ @dataclass
+ class C:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int = field(kw_only=False)
+ c = C(1, 2, b=3)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+ c = C(1, b=3, c=2)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+ c = C(1, b=3, c=2)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+ c = C(c=2, b=3, a=1)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+
+ def test_KW_ONLY_as_string(self):
+ @dataclass
+ class A:
+ a: int
+ _: 'dataclasses.KW_ONLY'
+ b: int
+ c: int
+ A(3, c=5, b=4)
+ msg = "takes 2 positional arguments but 4 were given"
+ with self.assertRaisesRegex(TypeError, msg):
+ A(3, 4, 5)
+
+ def test_KW_ONLY_twice(self):
+ msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: int
+ X: KW_ONLY
+ Y: KW_ONLY
+ b: int
+ c: int
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: int
+ X: KW_ONLY
+ b: int
+ Y: KW_ONLY
+ c: int
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: int
+ X: KW_ONLY
+ b: int
+ c: int
+ Y: KW_ONLY
+
+ # But this usage is okay, since it's not using KW_ONLY.
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int = field(kw_only=True)
+
+ # And if inheriting, it's okay.
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int
+ @dataclass
+ class B(A):
+ _: KW_ONLY
+ d: int
+
+ # Make sure the error is raised in a derived class.
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int
+ @dataclass
+ class B(A):
+ X: KW_ONLY
+ d: int
+ Y: KW_ONLY
+
+
+ def test_post_init(self):
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: InitVar[int]
+ c: int
+ d: InitVar[int]
+ def __post_init__(self, b, d):
+ raise CustomError(f'{b=} {d=}')
+ with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
+ A(1, c=2, b=3, d=4)
+
+ @dataclass
+ class B:
+ a: int
+ _: KW_ONLY
+ b: InitVar[int]
+ c: int
+ d: InitVar[int]
+ def __post_init__(self, b, d):
+ self.a = b
+ self.c = d
+ b = B(1, c=2, b=3, d=4)
+ self.assertEqual(asdict(b), {'a': 3, 'c': 4})
+
+ def test_defaults(self):
+ # For kwargs, make sure we can have defaults after non-defaults.
+ @dataclass
+ class A:
+ a: int = 0
+ _: KW_ONLY
+ b: int
+ c: int = 1
+ d: int
+
+ a = A(d=4, b=3)
+ self.assertEqual(a.a, 0)
+ self.assertEqual(a.b, 3)
+ self.assertEqual(a.c, 1)
+ self.assertEqual(a.d, 4)
+
+ # Make sure we still check for non-kwarg non-defaults not following
+ # defaults.
+ err_regex = "non-default argument 'z' follows default argument"
+ with self.assertRaisesRegex(TypeError, err_regex):
+ @dataclass
+ class A:
+ a: int = 0
+ z: int
+ _: KW_ONLY
+ b: int
+ c: int = 1
+ d: int
+
+ def test_make_dataclass(self):
+ A = make_dataclass("A", ['a'], kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ B = make_dataclass("B",
+ ['a', ('b', int, field(kw_only=False))],
+ kw_only=True)
+ self.assertTrue(fields(B)[0].kw_only)
+ self.assertFalse(fields(B)[1].kw_only)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/Tools/dump_github_issues.py b/Tools/dump_github_issues.py
new file mode 100644
index 000000000..daec51c50
--- /dev/null
+++ b/Tools/dump_github_issues.py
@@ -0,0 +1,142 @@
+"""
+Dump the GitHub issues of the current project to a file (.json.gz).
+
+Usage: python3 Tools/dump_github_issues.py
+"""
+
+import configparser
+import gzip
+import json
+import os.path
+
+from datetime import datetime
+from urllib.request import urlopen
+
+GIT_CONFIG_FILE = ".git/config"
+
+
+class RateLimitReached(Exception):
+ pass
+
+
+def gen_urls(repo):
+ i = 0
+ while True:
+ yield f"https://api.github.com/repos/{repo}/issues?state=all&per_page=100&page={i}"
+ i += 1
+
+
+def read_rate_limit():
+ with urlopen("https://api.github.com/rate_limit") as p:
+ return json.load(p)
+
+
+def parse_rate_limit(limits):
+ limits = limits['resources']['core']
+ return limits['limit'], limits['remaining'], datetime.fromtimestamp(limits['reset'])
+
+
+def load_url(url):
+ with urlopen(url) as p:
+ data = json.load(p)
+ if isinstance(data, dict) and 'rate limit' in data.get('message', ''):
+ raise RateLimitReached()
+
+ assert isinstance(data, list), type(data)
+ return data or None # None indicates empty last page
+
+
+def join_list_data(lists):
+ result = []
+ for data in lists:
+ if not data:
+ break
+ result.extend(data)
+ return result
+
+
+def output_filename(repo):
+ timestamp = datetime.now()
+ return f"github_issues_{repo.replace('/', '_')}_{timestamp.strftime('%Y%m%d_%H%M%S')}.json.gz"
+
+
+def write_gzjson(file_name, data, indent=2):
+ with gzip.open(file_name, "wt", encoding='utf-8') as gz:
+ json.dump(data, gz, indent=indent)
+
+
+def find_origin_url(git_config=GIT_CONFIG_FILE):
+ assert os.path.exists(git_config)
+ parser = configparser.ConfigParser()
+ parser.read(git_config)
+ return parser.get('remote "origin"', 'url')
+
+
+def parse_repo_name(git_url):
+ if git_url.endswith('.git'):
+ git_url = git_url[:-4]
+ return '/'.join(git_url.split('/')[-2:])
+
+
+def dump_issues(repo):
+ """Main entry point."""
+ print(f"Reading issues from repo '{repo}'")
+ urls = gen_urls(repo)
+ try:
+ paged_data = map(load_url, urls)
+ issues = join_list_data(paged_data)
+ except RateLimitReached:
+ limit, remaining, reset_time = parse_rate_limit(read_rate_limit())
+ print(f"FAILURE: Rate limits ({limit}) reached, remaining: {remaining}, reset at {reset_time}")
+ return
+
+ filename = output_filename(repo)
+ print(f"Writing {len(issues)} to {filename}")
+ write_gzjson(filename, issues)
+
+
+### TESTS
+
+def test_join_list_data():
+ assert join_list_data([]) == []
+ assert join_list_data([[1,2]]) == [1,2]
+ assert join_list_data([[1,2], [3]]) == [1,2,3]
+ assert join_list_data([[0], [1,2], [3]]) == [0,1,2,3]
+ assert join_list_data([[0], [1,2], [[[]],[]]]) == [0,1,2,[[]],[]]
+
+
+def test_output_filename():
+ filename = output_filename("re/po")
+ import re
+ assert re.match(r"github_issues_re_po_[0-9]{8}_[0-9]{6}\.json", filename)
+
+
+def test_find_origin_url():
+ assert find_origin_url()
+
+
+def test_parse_repo_name():
+ assert parse_repo_name("https://github.com/cython/cython") == "cython/cython"
+ assert parse_repo_name("git+ssh://git@github.com/cython/cython.git") == "cython/cython"
+ assert parse_repo_name("git+ssh://git@github.com/fork/cython.git") == "fork/cython"
+
+
+def test_write_gzjson():
+ import tempfile
+ with tempfile.NamedTemporaryFile() as tmp:
+ write_gzjson(tmp.name, [{}])
+
+ # test JSON format
+ with gzip.open(tmp.name) as f:
+ assert json.load(f) == [{}]
+
+ # test indentation
+ with gzip.open(tmp.name) as f:
+ assert f.read() == b'[\n {}\n]'
+
+
+### MAIN
+
+if __name__ == '__main__':
+ repo_name = parse_repo_name(find_origin_url())
+ dump_issues(repo_name)
diff --git a/Tools/gen_tests_for_posix_pxds.py b/Tools/gen_tests_for_posix_pxds.py
new file mode 100644
index 000000000..c92c49a6a
--- /dev/null
+++ b/Tools/gen_tests_for_posix_pxds.py
@@ -0,0 +1,41 @@
+#!/usr/bin/python3
+
+from pathlib import Path
+
+PROJECT_ROOT = Path(__file__) / "../.."
+POSIX_PXDS_DIR = PROJECT_ROOT / "Cython/Includes/posix"
+TEST_PATH = PROJECT_ROOT / "tests/compile/posix_pxds.pyx"
+
+def main():
+ datas = [
+ "# tag: posix\n"
+ "# mode: compile\n"
+ "\n"
+ "# This file is generated by `Tools/gen_tests_for_posix_pxds.py`.\n"
+ "\n"
+ "cimport posix\n"
+ ]
+
+ filenames = sorted(map(lambda path: path.name, POSIX_PXDS_DIR.iterdir()))
+
+ for name in filenames:
+ if name == "__init__.pxd":
+ continue
+ if name.endswith(".pxd"):
+ name = name[:-4]
+ else:
+ continue
+
+ s = (
+ "cimport posix.{name}\n"
+ "from posix cimport {name}\n"
+ "from posix.{name} cimport *\n"
+ ).format(name=name)
+
+ datas.append(s)
+
+ with open(TEST_PATH, "w", encoding="utf-8", newline="\n") as f:
+ f.write("\n".join(datas))
+
+if __name__ == "__main__":
+ main()
diff --git a/Tools/make_dataclass_tests.py b/Tools/make_dataclass_tests.py
new file mode 100644
index 000000000..dc38eee70
--- /dev/null
+++ b/Tools/make_dataclass_tests.py
@@ -0,0 +1,443 @@
+# Used to generate tests/run/test_dataclasses.pyx but translating the CPython test suite
+# dataclass file. Initially run using Python 3.10 - this file is not designed to be
+# backwards compatible since it will be run manually and infrequently.
+
+import ast
+import os.path
+import sys
+
+unavailable_functions = frozenset(
+ {
+ "dataclass_textanno", # part of CPython test module
+ "dataclass_module_1", # part of CPython test module
+ "make_dataclass", # not implemented in Cython dataclasses (probably won't be implemented)
+ }
+)
+
+skip_tests = frozenset(
+ {
+ # needs Cython compile
+ # ====================
+ ("TestCase", "test_field_default_default_factory_error"),
+ ("TestCase", "test_two_fields_one_default"),
+ ("TestCase", "test_overwrite_hash"),
+ ("TestCase", "test_eq_order"),
+ ("TestCase", "test_no_unhashable_default"),
+ ("TestCase", "test_disallowed_mutable_defaults"),
+ ("TestCase", "test_classvar_default_factory"),
+ ("TestCase", "test_field_metadata_mapping"),
+ ("TestFieldNoAnnotation", "test_field_without_annotation"),
+ (
+ "TestFieldNoAnnotation",
+ "test_field_without_annotation_but_annotation_in_base",
+ ),
+ (
+ "TestFieldNoAnnotation",
+ "test_field_without_annotation_but_annotation_in_base_not_dataclass",
+ ),
+ ("TestOrdering", "test_overwriting_order"),
+ ("TestHash", "test_hash_rules"),
+ ("TestHash", "test_hash_no_args"),
+ ("TestFrozen", "test_inherit_nonfrozen_from_empty_frozen"),
+ ("TestFrozen", "test_inherit_nonfrozen_from_frozen"),
+ ("TestFrozen", "test_inherit_frozen_from_nonfrozen"),
+ ("TestFrozen", "test_overwriting_frozen"),
+ ("TestSlots", "test_add_slots_when_slots_exists"),
+ ("TestSlots", "test_cant_inherit_from_iterator_slots"),
+ ("TestSlots", "test_weakref_slot_without_slot"),
+ ("TestKeywordArgs", "test_no_classvar_kwarg"),
+ ("TestKeywordArgs", "test_KW_ONLY_twice"),
+ ("TestKeywordArgs", "test_defaults"),
+ # uses local variable in class definition
+ ("TestCase", "test_default_factory"),
+ ("TestCase", "test_default_factory_with_no_init"),
+ ("TestCase", "test_field_default"),
+ ("TestCase", "test_function_annotations"),
+ ("TestDescriptors", "test_lookup_on_instance"),
+ ("TestCase", "test_default_factory_not_called_if_value_given"),
+ ("TestCase", "test_class_attrs"),
+ ("TestCase", "test_hash_field_rules"),
+ ("TestStringAnnotations",), # almost all the texts here use local variables
+ # Currently unsupported
+ # =====================
+ (
+ "TestOrdering",
+ "test_functools_total_ordering",
+ ), # combination of cython dataclass and total_ordering
+ ("TestCase", "test_missing_default_factory"), # we're MISSING MISSING
+ ("TestCase", "test_missing_default"), # MISSING
+ ("TestCase", "test_missing_repr"), # MISSING
+ ("TestSlots",), # __slots__ isn't understood
+ ("TestMatchArgs",),
+ ("TestKeywordArgs", "test_field_marked_as_kwonly"),
+ ("TestKeywordArgs", "test_match_args"),
+ ("TestKeywordArgs", "test_KW_ONLY"),
+ ("TestKeywordArgs", "test_KW_ONLY_as_string"),
+ ("TestKeywordArgs", "test_post_init"),
+ (
+ "TestCase",
+ "test_class_var_frozen",
+ ), # __annotations__ not present on cdef classes https://github.com/cython/cython/issues/4519
+ ("TestCase", "test_dont_include_other_annotations"), # __annotations__
+ ("TestDocString",), # don't think cython dataclasses currently set __doc__
+ # either cython.dataclasses.field or cython.dataclasses.dataclass called directly as functions
+ # (will probably never be supported)
+ ("TestCase", "test_field_repr"),
+ ("TestCase", "test_dynamic_class_creation"),
+ ("TestCase", "test_dynamic_class_creation_using_field"),
+ # Requires inheritance from non-cdef class
+ ("TestCase", "test_is_dataclass_genericalias"),
+ ("TestCase", "test_generic_extending"),
+ ("TestCase", "test_generic_dataclasses"),
+ ("TestCase", "test_generic_dynamic"),
+ ("TestInit", "test_inherit_from_protocol"),
+ ("TestAbstract", "test_abc_implementation"),
+ ("TestAbstract", "test_maintain_abc"),
+ # Requires multiple inheritance from extension types
+ ("TestCase", "test_post_init_not_auto_added"),
+ # Refers to nonlocal from enclosing function
+ (
+ "TestCase",
+ "test_post_init_staticmethod",
+ ), # TODO replicate the gist of the test elsewhere
+ # PEP487 isn't support in Cython
+ ("TestDescriptors", "test_non_descriptor"),
+ ("TestDescriptors", "test_set_name"),
+ ("TestDescriptors", "test_setting_field_calls_set"),
+ ("TestDescriptors", "test_setting_uninitialized_descriptor_field"),
+ # Looks up __dict__, which cdef classes don't typically have
+ ("TestCase", "test_init_false_no_default"),
+ ("TestCase", "test_init_var_inheritance"), # __dict__ again
+ ("TestCase", "test_base_has_init"),
+ ("TestInit", "test_base_has_init"), # needs __dict__ for vars
+ # Requires arbitrary attributes to be writeable
+ ("TestCase", "test_post_init_super"),
+ ('TestCase', 'test_init_in_order'),
+ # Cython being strict about argument types - expected difference
+ ("TestDescriptors", "test_getting_field_calls_get"),
+ ("TestDescriptors", "test_init_calls_set"),
+ ("TestHash", "test_eq_only"),
+ # I think an expected difference with cdef classes - the property will be in the dict
+ ("TestCase", "test_items_in_dicts"),
+ # These tests are probably fine, but the string substitution in this file doesn't get it right
+ ("TestRepr", "test_repr"),
+ ("TestCase", "test_not_in_repr"),
+ ('TestRepr', 'test_no_repr'),
+ # class variable doesn't exist in Cython so uninitialized variable appears differently - for now this is deliberate
+ ('TestInit', 'test_no_init'),
+ # I believe the test works but the ordering functions do appear in the class dict (and default slot wrappers which
+ # just raise NotImplementedError
+ ('TestOrdering', 'test_no_order'),
+ # not possible to add attributes on extension types
+ ("TestCase", "test_post_init_classmethod"),
+ # Cannot redefine the same field in a base dataclass (tested in dataclass_e6)
+ ("TestCase", "test_field_order"),
+ (
+ "TestCase",
+ "test_overwrite_fields_in_derived_class",
+ ),
+ # Bugs
+ #======
+ # not specifically a dataclass issue - a C int crashes classvar
+ ("TestCase", "test_class_var"),
+ (
+ "TestFrozen",
+ ), # raises AttributeError, not FrozenInstanceError (may be hard to fix)
+ ('TestCase', 'test_post_init'), # Works except for AttributeError instead of FrozenInstanceError
+ ("TestReplace", "test_frozen"), # AttributeError not FrozenInstanceError
+ (
+ "TestCase",
+ "test_dataclasses_qualnames",
+ ), # doesn't define __setattr__ and just relies on Cython to enforce readonly properties
+ ("TestCase", "test_compare_subclasses"), # wrong comparison
+ ("TestCase", "test_simple_compare"), # wrong comparison
+ (
+ "TestCase",
+ "test_field_named_self",
+ ), # I think just an error in inspecting the signature
+ (
+ "TestCase",
+ "test_init_var_default_factory",
+ ), # should be raising a compile error
+ ("TestCase", "test_init_var_no_default"), # should be raising a compile error
+ ("TestCase", "test_init_var_with_default"), # not sure...
+ ("TestReplace", "test_initvar_with_default_value"), # needs investigating
+ # Maybe bugs?
+ # ==========
+ # non-default argument 'z' follows default argument in dataclass __init__ - this message looks right to me!
+ ("TestCase", "test_class_marker"),
+ # cython.dataclasses.field parameter 'metadata' must be a literal value - possibly not something we can support?
+ ("TestCase", "test_field_metadata_custom_mapping"),
+ (
+ "TestCase",
+ "test_class_var_default_factory",
+ ), # possibly to do with ClassVar being assigned a field
+ (
+ "TestCase",
+ "test_class_var_with_default",
+ ), # possibly to do with ClassVar being assigned a field
+ (
+ "TestDescriptors",
+ ), # mostly don't work - I think this may be a limitation of cdef classes but needs investigating
+ }
+)
+
+version_specific_skips = {
+ # The version numbers are the first version that the test should be run on
+ ("TestCase", "test_init_var_preserve_type"): (
+ 3,
+ 10,
+ ), # needs language support for | operator on types
+}
+
+class DataclassInDecorators(ast.NodeVisitor):
+ found = False
+
+ def visit_Name(self, node):
+ if node.id == "dataclass":
+ self.found = True
+ return self.generic_visit(node)
+
+ def generic_visit(self, node):
+ if self.found:
+ return # skip
+ return super().generic_visit(node)
+
+
+def dataclass_in_decorators(decorator_list):
+ finder = DataclassInDecorators()
+ for dec in decorator_list:
+ finder.visit(dec)
+ if finder.found:
+ return True
+ return False
+
+
+class SubstituteNameString(ast.NodeTransformer):
+ def __init__(self, substitutions):
+ super().__init__()
+ self.substitutions = substitutions
+
+ def visit_Constant(self, node):
+ # attempt to handle some difference in class names
+ # (note: requires Python>=3.8)
+ if isinstance(node.value, str):
+ if node.value.find("<locals>") != -1:
+ import re
+
+ new_value = new_value2 = re.sub("[\w.]*<locals>", "", node.value)
+ for key, value in self.substitutions.items():
+ new_value2 = re.sub(f"(?<![\w])[.]{key}(?![\w])", value, new_value2)
+ if new_value != new_value2:
+ node.value = new_value2
+ return node
+
+
+class SubstituteName(SubstituteNameString):
+ def visit_Name(self, node):
+ if isinstance(node.ctx, ast.Store): # don't reassign lhs
+ return node
+ replacement = self.substitutions.get(node.id, None)
+ if replacement is not None:
+ return ast.Name(id=replacement, ctx=node.ctx)
+ else:
+ return node
+
+
+class IdentifyCdefClasses(ast.NodeVisitor):
+ def __init__(self):
+ super().__init__()
+ self.top_level_class = True
+ self.classes = {}
+ self.cdef_classes = set()
+
+ def visit_ClassDef(self, node):
+ top_level_class, self.top_level_class = self.top_level_class, False
+ try:
+ if not top_level_class:
+ self.classes[node.name] = node
+ if dataclass_in_decorators(node.decorator_list):
+ self.handle_cdef_class(node)
+ self.generic_visit(node) # any nested classes in it?
+ else:
+ self.generic_visit(node)
+ finally:
+ self.top_level_class = top_level_class
+
+ def visit_FunctionDef(self, node):
+ classes, self.classes = self.classes, {}
+ self.generic_visit(node)
+ self.classes = classes
+
+ def handle_cdef_class(self, cls_node):
+ if cls_node not in self.cdef_classes:
+ self.cdef_classes.add(cls_node)
+ # go back through previous classes we've seen and pick out any first bases
+ if cls_node.bases and isinstance(cls_node.bases[0], ast.Name):
+ base0_node = self.classes.get(cls_node.bases[0].id)
+ if base0_node:
+ self.handle_cdef_class(base0_node)
+
+
+class ExtractDataclassesToTopLevel(ast.NodeTransformer):
+ def __init__(self, cdef_classes_set):
+ super().__init__()
+ self.nested_name = []
+ self.current_function_global_classes = []
+ self.global_classes = []
+ self.cdef_classes_set = cdef_classes_set
+ self.used_names = set()
+ self.collected_substitutions = {}
+ self.uses_unavailable_name = False
+ self.top_level_class = True
+
+ def visit_ClassDef(self, node):
+ if not self.top_level_class:
+ # Include any non-toplevel class in this to be able
+ # to test inheritance.
+
+ self.generic_visit(node) # any nested classes in it?
+ if not node.body:
+ node.body.append(ast.Pass)
+
+ # First, make it a C class.
+ if node in self.cdef_classes_set:
+ node.decorator_list.append(ast.Name(id="cclass", ctx=ast.Load()))
+ # otherwise move it to the global scope, but don't make it cdef
+ # change the name
+ old_name = node.name
+ new_name = "_".join([node.name] + self.nested_name)
+ while new_name in self.used_names:
+ new_name = new_name + "_"
+ node.name = new_name
+ self.current_function_global_classes.append(node)
+ self.used_names.add(new_name)
+ # hmmmm... possibly there's a few cases where there's more than one name?
+ self.collected_substitutions[old_name] = node.name
+
+ return ast.Assign(
+ targets=[ast.Name(id=old_name, ctx=ast.Store())],
+ value=ast.Name(id=new_name, ctx=ast.Load()),
+ lineno=-1,
+ )
+ else:
+ top_level_class, self.top_level_class = self.top_level_class, False
+ self.nested_name.append(node.name)
+ if tuple(self.nested_name) in skip_tests:
+ self.top_level_class = top_level_class
+ self.nested_name.pop()
+ return None
+ self.generic_visit(node)
+ self.nested_name.pop()
+ if not node.body:
+ node.body.append(ast.Pass())
+ self.top_level_class = top_level_class
+ return node
+
+ def visit_FunctionDef(self, node):
+ self.nested_name.append(node.name)
+ if tuple(self.nested_name) in skip_tests:
+ self.nested_name.pop()
+ return None
+ if tuple(self.nested_name) in version_specific_skips:
+ version = version_specific_skips[tuple(self.nested_name)]
+ decorator = ast.parse(
+ f"skip_on_versions_below({version})", mode="eval"
+ ).body
+ node.decorator_list.append(decorator)
+ collected_subs, self.collected_substitutions = self.collected_substitutions, {}
+ uses_unavailable_name, self.uses_unavailable_name = (
+ self.uses_unavailable_name,
+ False,
+ )
+ current_func_globs, self.current_function_global_classes = (
+ self.current_function_global_classes,
+ [],
+ )
+
+ # visit once to work out what the substitutions should be
+ self.generic_visit(node)
+ if self.collected_substitutions:
+ # replace strings in this function
+ node = SubstituteNameString(self.collected_substitutions).visit(node)
+ replacer = SubstituteName(self.collected_substitutions)
+ # replace any base classes
+ for global_class in self.current_function_global_classes:
+ global_class = replacer.visit(global_class)
+ self.global_classes.append(self.current_function_global_classes)
+
+ self.nested_name.pop()
+ self.collected_substitutions = collected_subs
+ if self.uses_unavailable_name:
+ node = None
+ self.uses_unavailable_name = uses_unavailable_name
+ self.current_function_global_classes = current_func_globs
+ return node
+
+ def visit_Name(self, node):
+ if node.id in unavailable_functions:
+ self.uses_unavailable_name = True
+ return self.generic_visit(node)
+
+ def visit_Import(self, node):
+ return None # drop imports, we add these into the text ourself
+
+ def visit_ImportFrom(self, node):
+ return None # drop imports, we add these into the text ourself
+
+ def visit_Call(self, node):
+ if (
+ isinstance(node.func, ast.Attribute)
+ and node.func.attr == "assertRaisesRegex"
+ ):
+ # we end up with a bunch of subtle name changes that are very hard to correct for
+ # therefore, replace with "assertRaises"
+ node.func.attr = "assertRaises"
+ node.args.pop()
+ return self.generic_visit(node)
+
+ def visit_Module(self, node):
+ self.generic_visit(node)
+ node.body[0:0] = self.global_classes
+ return node
+
+ def visit_AnnAssign(self, node):
+ # string annotations are forward declarations but the string will be wrong
+ # (because we're renaming the class)
+ if (isinstance(node.annotation, ast.Constant) and
+ isinstance(node.annotation.value, str)):
+ # although it'd be good to resolve these declarations, for the
+ # sake of the tests they only need to be "object"
+ node.annotation = ast.Name(id="object", ctx=ast.Load)
+
+ return node
+
+
+def main():
+ script_path = os.path.split(sys.argv[0])[0]
+ filename = "test_dataclasses.py"
+ py_module_path = os.path.join(script_path, "dataclass_test_data", filename)
+ with open(py_module_path, "r") as f:
+ tree = ast.parse(f.read(), filename)
+
+ cdef_class_finder = IdentifyCdefClasses()
+ cdef_class_finder.visit(tree)
+ transformer = ExtractDataclassesToTopLevel(cdef_class_finder.cdef_classes)
+ tree = transformer.visit(tree)
+
+ output_path = os.path.join(script_path, "..", "tests", "run", filename + "x")
+ with open(output_path, "w") as f:
+ print("# AUTO-GENERATED BY Tools/make_dataclass_tests.py", file=f)
+ print("# DO NOT EDIT", file=f)
+ print(file=f)
+ # the directive doesn't get applied outside the include if it's put
+ # in the pxi file
+ print("# cython: language_level=3", file=f)
+ # any extras Cython needs to add go in this include file
+ print('include "test_dataclasses.pxi"', file=f)
+ print(file=f)
+ print(ast.unparse(tree), file=f)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/Tools/rules.bzl b/Tools/rules.bzl
index cd3eed58f..c59af6a99 100644
--- a/Tools/rules.bzl
+++ b/Tools/rules.bzl
@@ -11,8 +11,8 @@ load("@cython//Tools:rules.bzl", "pyx_library")
pyx_library(name = 'mylib',
srcs = ['a.pyx', 'a.pxd', 'b.py', 'pkg/__init__.py', 'pkg/c.pyx'],
- py_deps = ['//py_library/dep'],
- data = ['//other/data'],
+ # python library deps passed to py_library
+ deps = ['//py_library/dep']
)
The __init__.py file must be in your srcs list so that Cython can resolve