diff options
Diffstat (limited to 'Tools')
-rw-r--r-- | Tools/BUILD.bazel | 1 | ||||
-rw-r--r-- | Tools/ci-run.sh | 76 | ||||
-rw-r--r-- | Tools/cython-mode.el | 303 | ||||
-rw-r--r-- | Tools/dataclass_test_data/test_dataclasses.py | 4266 | ||||
-rw-r--r-- | Tools/dump_github_issues.py | 142 | ||||
-rw-r--r-- | Tools/gen_tests_for_posix_pxds.py | 41 | ||||
-rw-r--r-- | Tools/make_dataclass_tests.py | 443 | ||||
-rw-r--r-- | Tools/rules.bzl | 4 |
8 files changed, 4956 insertions, 320 deletions
diff --git a/Tools/BUILD.bazel b/Tools/BUILD.bazel index e69de29bb..8b1378917 100644 --- a/Tools/BUILD.bazel +++ b/Tools/BUILD.bazel @@ -0,0 +1 @@ + diff --git a/Tools/ci-run.sh b/Tools/ci-run.sh index 09e9ae318..d42365fd6 100644 --- a/Tools/ci-run.sh +++ b/Tools/ci-run.sh @@ -1,24 +1,24 @@ #!/usr/bin/bash +set -x + GCC_VERSION=${GCC_VERSION:=8} # Set up compilers if [[ $TEST_CODE_STYLE == "1" ]]; then - echo "Skipping compiler setup" + echo "Skipping compiler setup: Code style run" elif [[ $OSTYPE == "linux-gnu"* ]]; then echo "Setting up linux compiler" echo "Installing requirements [apt]" sudo apt-add-repository -y "ppa:ubuntu-toolchain-r/test" sudo apt update -y -q - sudo apt install -y -q ccache gdb python-dbg python3-dbg gcc-$GCC_VERSION || exit 1 + sudo apt install -y -q gdb python3-dbg gcc-$GCC_VERSION || exit 1 ALTERNATIVE_ARGS="" if [[ $BACKEND == *"cpp"* ]]; then sudo apt install -y -q g++-$GCC_VERSION || exit 1 ALTERNATIVE_ARGS="--slave /usr/bin/g++ g++ /usr/bin/g++-$GCC_VERSION" fi - sudo /usr/sbin/update-ccache-symlinks - echo "/usr/lib/ccache" >> $GITHUB_PATH # export ccache to path sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-$GCC_VERSION 60 $ALTERNATIVE_ARGS @@ -32,7 +32,27 @@ elif [[ $OSTYPE == "darwin"* ]]; then export CC="clang -Wno-deprecated-declarations" export CXX="clang++ -stdlib=libc++ -Wno-deprecated-declarations" else - echo "No setup specified for $OSTYPE" + echo "Skipping compiler setup: No setup specified for $OSTYPE" +fi + +if [[ $COVERAGE == "1" ]]; then + echo "Skip setting up compilation caches" +elif [[ $OSTYPE == "msys" ]]; then + echo "Set up sccache" + echo "TODO: Make a soft symlink to sccache" +else + echo "Set up ccache" + + echo "/usr/lib/ccache" >> $GITHUB_PATH # export ccache to path + + echo "Make a soft symlinks to ccache" + cp ccache /usr/local/bin/ + ln -s ccache /usr/local/bin/gcc + ln -s ccache /usr/local/bin/g++ + ln -s ccache /usr/local/bin/cc + ln -s ccache /usr/local/bin/c++ + ln -s ccache /usr/local/bin/clang + ln -s ccache /usr/local/bin/clang++ fi # Set up miniconda @@ -50,14 +70,17 @@ echo "====================" echo "|VERSIONS INSTALLED|" echo "====================" echo "Python $PYTHON_SYS_VERSION" + if [[ $CC ]]; then which ${CC%% *} ${CC%% *} --version fi + if [[ $CXX ]]; then which ${CXX%% *} ${CXX%% *} --version fi + echo "====================" # Install python requirements @@ -68,12 +91,17 @@ if [[ $PYTHON_VERSION == "2.7"* ]]; then elif [[ $PYTHON_VERSION == "3."[45]* ]]; then python -m pip install wheel || exit 1 python -m pip install -r test-requirements-34.txt || exit 1 +elif [[ $PYTHON_VERSION == "pypy-2.7" ]]; then + pip install wheel || exit 1 + pip install -r test-requirements-pypy27.txt || exit 1 +elif [[ $PYTHON_VERSION == "3.1"[2-9]* ]]; then + python -m pip install wheel || exit 1 + python -m pip install -r test-requirements-312.txt || exit 1 else - python -m pip install -U pip setuptools wheel || exit 1 + python -m pip install -U pip "setuptools<60" wheel || exit 1 if [[ $PYTHON_VERSION != *"-dev" || $COVERAGE == "1" ]]; then python -m pip install -r test-requirements.txt || exit 1 - if [[ $PYTHON_VERSION != "pypy"* && $PYTHON_VERSION != "3."[1]* ]]; then python -m pip install -r test-requirements-cpython.txt || exit 1 fi @@ -108,7 +136,16 @@ export PATH="/usr/lib/ccache:$PATH" # Most modern compilers allow the last conflicting option # to override the previous ones, so '-O0 -O3' == '-O3' # This is true for the latest msvc, gcc and clang -CFLAGS="-O0 -ggdb -Wall -Wextra" +if [[ $OSTYPE == "msys" ]]; then # for MSVC cl + # /wd disables warnings + # 4711 warns that function `x` was selected for automatic inline expansion + # 4127 warns that a conditional expression is constant, should be fixed here https://github.com/cython/cython/pull/4317 + # (off by default) 5045 warns that the compiler will insert Spectre mitigations for memory load if the /Qspectre switch is specified + # (off by default) 4820 warns about the code in Python\3.9.6\x64\include ... + CFLAGS="-Od /Z7 /MP /W4 /wd4711 /wd4127 /wd5045 /wd4820" +else + CFLAGS="-O0 -ggdb -Wall -Wextra" +fi # Trying to cover debug assertions in the CI without adding # extra jobs. Therefore, odd-numbered minor versions of Python # running C++ jobs get NDEBUG undefined, and even-numbered @@ -123,6 +160,9 @@ fi if [[ $NO_CYTHON_COMPILE != "1" && $PYTHON_VERSION != "pypy"* ]]; then BUILD_CFLAGS="$CFLAGS -O2" + if [[ $CYTHON_COMPILE_ALL == "1" && $OSTYPE != "msys" ]]; then + BUILD_CFLAGS="$CFLAGS -O3 -g0 -mtune=generic" # make wheel sizes comparable to standard wheel build + fi if [[ $PYTHON_SYS_VERSION == "2"* ]]; then BUILD_CFLAGS="$BUILD_CFLAGS -fno-strict-aliasing" fi @@ -134,24 +174,30 @@ if [[ $NO_CYTHON_COMPILE != "1" && $PYTHON_VERSION != "pypy"* ]]; then if [[ $CYTHON_COMPILE_ALL == "1" ]]; then SETUP_ARGS="$SETUP_ARGS --cython-compile-all" fi - #SETUP_ARGS="$SETUP_ARGS - # $(python -c 'import sys; print("-j5" if sys.version_info >= (3,5) else "")')" + # It looks like parallel build may be causing occasional link failures on Windows + # "with exit code 1158". DW isn't completely sure of this, but has disabled it in + # the hope it helps + SETUP_ARGS="$SETUP_ARGS + $(python -c 'import sys; print("-j5" if sys.version_info >= (3,5) and not sys.platform.startswith("win") else "")')" CFLAGS=$BUILD_CFLAGS \ python setup.py build_ext -i $SETUP_ARGS || exit 1 # COVERAGE can be either "" (empty or not set) or "1" (when we set it) # STACKLESS can be either "" (empty or not set) or "true" (when we set it) - # CYTHON_COMPILE_ALL can be either "" (empty or not set) or "1" (when we set it) if [[ $COVERAGE != "1" && $STACKLESS != "true" && $BACKEND != *"cpp"* && - $CYTHON_COMPILE_ALL != "1" && $LIMITED_API == "" && $EXTRA_CFLAGS == "" ]]; then + $LIMITED_API == "" && $EXTRA_CFLAGS == "" ]]; then python setup.py bdist_wheel || exit 1 + ls -l dist/ || true fi + + echo "Extension modules created during the build:" + find Cython -name "*.so" -ls | sort -k11 fi if [[ $TEST_CODE_STYLE == "1" ]]; then - make -C docs html || echo "FIXME: docs build failed!" -elif [[ $PYTHON_VERSION != "pypy"* ]]; then + make -C docs html || exit 1 +elif [[ $PYTHON_VERSION != "pypy"* && $OSTYPE != "msys" ]]; then # Run the debugger tests in python-dbg if available # (but don't fail, because they currently do fail) PYTHON_DBG=$(python -c 'import sys; print("%d.%d" % sys.version_info[:2])') @@ -181,6 +227,6 @@ python runtests.py \ EXIT_CODE=$? -ccache -s 2>/dev/null || true +ccache -s -v -v 2>/dev/null || true exit $EXIT_CODE diff --git a/Tools/cython-mode.el b/Tools/cython-mode.el deleted file mode 100644 index e4be99f5b..000000000 --- a/Tools/cython-mode.el +++ /dev/null @@ -1,303 +0,0 @@ -;;; cython-mode.el --- Major mode for editing Cython files - -;; License: Apache-2.0 - -;;; Commentary: - -;; This should work with python-mode.el as well as either the new -;; python.el or the old. - -;;; Code: - -;; Load python-mode if available, otherwise use builtin emacs python package -(when (not (require 'python-mode nil t)) - (require 'python)) -(eval-when-compile (require 'rx)) - -;;;###autoload -(add-to-list 'auto-mode-alist '("\\.pyx\\'" . cython-mode)) -;;;###autoload -(add-to-list 'auto-mode-alist '("\\.pxd\\'" . cython-mode)) -;;;###autoload -(add-to-list 'auto-mode-alist '("\\.pxi\\'" . cython-mode)) - - -(defvar cython-buffer nil - "Variable pointing to the cython buffer which was compiled.") - -(defun cython-compile () - "Compile the file via Cython." - (interactive) - (let ((cy-buffer (current-buffer))) - (with-current-buffer - (compile compile-command) - (set (make-local-variable 'cython-buffer) cy-buffer) - (add-to-list (make-local-variable 'compilation-finish-functions) - 'cython-compilation-finish)))) - -(defun cython-compilation-finish (buffer how) - "Called when Cython compilation finishes." - ;; XXX could annotate source here - ) - -(defvar cython-mode-map - (let ((map (make-sparse-keymap))) - ;; Will inherit from `python-mode-map' thanks to define-derived-mode. - (define-key map "\C-c\C-c" 'cython-compile) - map) - "Keymap used in `cython-mode'.") - -(defvar cython-font-lock-keywords - `(;; ctypedef statement: "ctypedef (...type... alias)?" - (,(rx - ;; keyword itself - symbol-start (group "ctypedef") - ;; type specifier: at least 1 non-identifier symbol + 1 identifier - ;; symbol and anything but a comment-starter after that. - (opt (regexp "[^a-zA-Z0-9_\n]+[a-zA-Z0-9_][^#\n]*") - ;; type alias: an identifier - symbol-start (group (regexp "[a-zA-Z_]+[a-zA-Z0-9_]*")) - ;; space-or-comments till the end of the line - (* space) (opt "#" (* nonl)) line-end)) - (1 font-lock-keyword-face) - (2 font-lock-type-face nil 'noerror)) - ;; new keywords in Cython language - (,(rx symbol-start - (or "by" "cdef" "cimport" "cpdef" - "extern" "gil" "include" "nogil" "property" "public" - "readonly" "DEF" "IF" "ELIF" "ELSE" - "new" "del" "cppclass" "namespace" "const" - "__stdcall" "__cdecl" "__fastcall" "inline" "api") - symbol-end) - . font-lock-keyword-face) - ;; Question mark won't match at a symbol-end, so 'except?' must be - ;; special-cased. It's simpler to handle it separately than weaving it - ;; into the lengthy list of other keywords. - (,(rx symbol-start "except?") . font-lock-keyword-face) - ;; C and Python types (highlight as builtins) - (,(rx symbol-start - (or - "object" "dict" "list" - ;; basic c type names - "void" "char" "int" "float" "double" "bint" - ;; longness/signed/constness - "signed" "unsigned" "long" "short" - ;; special basic c types - "size_t" "Py_ssize_t" "Py_UNICODE" "Py_UCS4" "ssize_t" "ptrdiff_t") - symbol-end) - . font-lock-builtin-face) - (,(rx symbol-start "NULL" symbol-end) - . font-lock-constant-face) - ;; cdef is used for more than functions, so simply highlighting the next - ;; word is problematic. struct, enum and property work though. - (,(rx symbol-start - (group (or "struct" "enum" "union" - (seq "ctypedef" (+ space "fused")))) - (+ space) (group (regexp "[a-zA-Z_]+[a-zA-Z0-9_]*"))) - (1 font-lock-keyword-face prepend) (2 font-lock-type-face)) - ("\\_<property[ \t]+\\([a-zA-Z_]+[a-zA-Z0-9_]*\\)" - 1 font-lock-function-name-face)) - "Additional font lock keywords for Cython mode.") - -;;;###autoload -(defgroup cython nil "Major mode for editing and compiling Cython files" - :group 'languages - :prefix "cython-" - :link '(url-link :tag "Homepage" "http://cython.org")) - -;;;###autoload -(defcustom cython-default-compile-format "cython -a %s" - "Format for the default command to compile a Cython file. -It will be passed to `format' with `buffer-file-name' as the only other argument." - :group 'cython - :type 'string) - -;; Some functions defined differently in the different python modes -(defun cython-comment-line-p () - "Return non-nil if current line is a comment." - (save-excursion - (back-to-indentation) - (eq ?# (char-after (point))))) - -(defun cython-in-string/comment () - "Return non-nil if point is in a comment or string." - (nth 8 (syntax-ppss))) - -(defalias 'cython-beginning-of-statement - (cond - ;; python-mode.el - ((fboundp 'py-beginning-of-statement) - 'py-beginning-of-statement) - ;; old python.el - ((fboundp 'python-beginning-of-statement) - 'python-beginning-of-statement) - ;; new python.el - ((fboundp 'python-nav-beginning-of-statement) - 'python-nav-beginning-of-statement) - (t (error "Couldn't find implementation for `cython-beginning-of-statement'")))) - -(defalias 'cython-beginning-of-block - (cond - ;; python-mode.el - ((fboundp 'py-beginning-of-block) - 'py-beginning-of-block) - ;; old python.el - ((fboundp 'python-beginning-of-block) - 'python-beginning-of-block) - ;; new python.el - ((fboundp 'python-nav-beginning-of-block) - 'python-nav-beginning-of-block) - (t (error "Couldn't find implementation for `cython-beginning-of-block'")))) - -(defalias 'cython-end-of-statement - (cond - ;; python-mode.el - ((fboundp 'py-end-of-statement) - 'py-end-of-statement) - ;; old python.el - ((fboundp 'python-end-of-statement) - 'python-end-of-statement) - ;; new python.el - ((fboundp 'python-nav-end-of-statement) - 'python-nav-end-of-statement) - (t (error "Couldn't find implementation for `cython-end-of-statement'")))) - -(defun cython-open-block-statement-p (&optional bos) - "Return non-nil if statement at point opens a Cython block. -BOS non-nil means point is known to be at beginning of statement." - (save-excursion - (unless bos (cython-beginning-of-statement)) - (looking-at (rx (and (or "if" "else" "elif" "while" "for" "def" "cdef" "cpdef" - "class" "try" "except" "finally" "with" - "EXAMPLES:" "TESTS:" "INPUT:" "OUTPUT:") - symbol-end))))) - -(defun cython-beginning-of-defun () - "`beginning-of-defun-function' for Cython. -Finds beginning of innermost nested class or method definition. -Returns the name of the definition found at the end, or nil if -reached start of buffer." - (let ((ci (current-indentation)) - (def-re (rx line-start (0+ space) (or "def" "cdef" "cpdef" "class") (1+ space) - (group (1+ (or word (syntax symbol)))))) - found lep) ;; def-line - (if (cython-comment-line-p) - (setq ci most-positive-fixnum)) - (while (and (not (bobp)) (not found)) - ;; Treat bol at beginning of function as outside function so - ;; that successive C-M-a makes progress backwards. - ;;(setq def-line (looking-at def-re)) - (unless (bolp) (end-of-line)) - (setq lep (line-end-position)) - (if (and (re-search-backward def-re nil 'move) - ;; Must be less indented or matching top level, or - ;; equally indented if we started on a definition line. - (let ((in (current-indentation))) - (or (and (zerop ci) (zerop in)) - (= lep (line-end-position)) ; on initial line - ;; Not sure why it was like this -- fails in case of - ;; last internal function followed by first - ;; non-def statement of the main body. - ;;(and def-line (= in ci)) - (= in ci) - (< in ci))) - (not (cython-in-string/comment))) - (setq found t))))) - -(defun cython-end-of-defun () - "`end-of-defun-function' for Cython. -Finds end of innermost nested class or method definition." - (let ((orig (point)) - (pattern (rx line-start (0+ space) (or "def" "cdef" "cpdef" "class") space))) - ;; Go to start of current block and check whether it's at top - ;; level. If it is, and not a block start, look forward for - ;; definition statement. - (when (cython-comment-line-p) - (end-of-line) - (forward-comment most-positive-fixnum)) - (when (not (cython-open-block-statement-p)) - (cython-beginning-of-block)) - (if (zerop (current-indentation)) - (unless (cython-open-block-statement-p) - (while (and (re-search-forward pattern nil 'move) - (cython-in-string/comment))) ; just loop - (unless (eobp) - (beginning-of-line))) - ;; Don't move before top-level statement that would end defun. - (end-of-line) - (beginning-of-defun)) - ;; If we got to the start of buffer, look forward for - ;; definition statement. - (when (and (bobp) (not (looking-at (rx (or "def" "cdef" "cpdef" "class"))))) - (while (and (not (eobp)) - (re-search-forward pattern nil 'move) - (cython-in-string/comment)))) ; just loop - ;; We're at a definition statement (or end-of-buffer). - ;; This is where we should have started when called from end-of-defun - (unless (eobp) - (let ((block-indentation (current-indentation))) - (python-nav-end-of-statement) - (while (and (forward-line 1) - (not (eobp)) - (or (and (> (current-indentation) block-indentation) - (or (cython-end-of-statement) t)) - ;; comment or empty line - (looking-at (rx (0+ space) (or eol "#")))))) - (forward-comment -1)) - ;; Count trailing space in defun (but not trailing comments). - (skip-syntax-forward " >") - (unless (eobp) ; e.g. missing final newline - (beginning-of-line))) - ;; Catch pathological cases like this, where the beginning-of-defun - ;; skips to a definition we're not in: - ;; if ...: - ;; ... - ;; else: - ;; ... # point here - ;; ... - ;; def ... - (if (< (point) orig) - (goto-char (point-max))))) - -(defun cython-current-defun () - "`add-log-current-defun-function' for Cython." - (save-excursion - ;; Move up the tree of nested `class' and `def' blocks until we - ;; get to zero indentation, accumulating the defined names. - (let ((start t) - accum) - (while (or start (> (current-indentation) 0)) - (setq start nil) - (cython-beginning-of-block) - (end-of-line) - (beginning-of-defun) - (if (looking-at (rx (0+ space) (or "def" "cdef" "cpdef" "class") (1+ space) - (group (1+ (or word (syntax symbol)))))) - (push (match-string 1) accum))) - (if accum (mapconcat 'identity accum "."))))) - -;;;###autoload -(define-derived-mode cython-mode python-mode "Cython" - "Major mode for Cython development, derived from Python mode. - -\\{cython-mode-map}" - (font-lock-add-keywords nil cython-font-lock-keywords) - (set (make-local-variable 'outline-regexp) - (rx (* space) (or "class" "def" "cdef" "cpdef" "elif" "else" "except" "finally" - "for" "if" "try" "while" "with") - symbol-end)) - (set (make-local-variable 'beginning-of-defun-function) - #'cython-beginning-of-defun) - (set (make-local-variable 'end-of-defun-function) - #'cython-end-of-defun) - (set (make-local-variable 'compile-command) - (format cython-default-compile-format (shell-quote-argument (or buffer-file-name "")))) - (set (make-local-variable 'add-log-current-defun-function) - #'cython-current-defun) - (add-hook 'which-func-functions #'cython-current-defun nil t) - (add-to-list (make-local-variable 'compilation-finish-functions) - 'cython-compilation-finish)) - -(provide 'cython-mode) - -;;; cython-mode.el ends here diff --git a/Tools/dataclass_test_data/test_dataclasses.py b/Tools/dataclass_test_data/test_dataclasses.py new file mode 100644 index 000000000..e2eab6957 --- /dev/null +++ b/Tools/dataclass_test_data/test_dataclasses.py @@ -0,0 +1,4266 @@ +# Deliberately use "from dataclasses import *". Every name in __all__ +# is tested, so they all must be present. This is a way to catch +# missing ones. + +from dataclasses import * + +import abc +import pickle +import inspect +import builtins +import types +import weakref +import unittest +from unittest.mock import Mock +from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol +from typing import get_type_hints +from collections import deque, OrderedDict, namedtuple +from functools import total_ordering + +import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. +import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. + +# Just any custom exception we can catch. +class CustomError(Exception): pass + +class TestCase(unittest.TestCase): + def test_no_fields(self): + @dataclass + class C: + pass + + o = C() + self.assertEqual(len(fields(C)), 0) + + def test_no_fields_but_member_variable(self): + @dataclass + class C: + i = 0 + + o = C() + self.assertEqual(len(fields(C)), 0) + + def test_one_field_no_default(self): + @dataclass + class C: + x: int + + o = C(42) + self.assertEqual(o.x, 42) + + def test_field_default_default_factory_error(self): + msg = "cannot specify both default and default_factory" + with self.assertRaisesRegex(ValueError, msg): + @dataclass + class C: + x: int = field(default=1, default_factory=int) + + def test_field_repr(self): + int_field = field(default=1, init=True, repr=False) + int_field.name = "id" + repr_output = repr(int_field) + expected_output = "Field(name='id',type=None," \ + f"default=1,default_factory={MISSING!r}," \ + "init=True,repr=False,hash=None," \ + "compare=True,metadata=mappingproxy({})," \ + f"kw_only={MISSING!r}," \ + "_field_type=None)" + + self.assertEqual(repr_output, expected_output) + + def test_named_init_params(self): + @dataclass + class C: + x: int + + o = C(x=32) + self.assertEqual(o.x, 32) + + def test_two_fields_one_default(self): + @dataclass + class C: + x: int + y: int = 0 + + o = C(3) + self.assertEqual((o.x, o.y), (3, 0)) + + # Non-defaults following defaults. + with self.assertRaisesRegex(TypeError, + "non-default argument 'y' follows " + "default argument"): + @dataclass + class C: + x: int = 0 + y: int + + # A derived class adds a non-default field after a default one. + with self.assertRaisesRegex(TypeError, + "non-default argument 'y' follows " + "default argument"): + @dataclass + class B: + x: int = 0 + + @dataclass + class C(B): + y: int + + # Override a base class field and add a default to + # a field which didn't use to have a default. + with self.assertRaisesRegex(TypeError, + "non-default argument 'y' follows " + "default argument"): + @dataclass + class B: + x: int + y: int + + @dataclass + class C(B): + x: int = 0 + + def test_overwrite_hash(self): + # Test that declaring this class isn't an error. It should + # use the user-provided __hash__. + @dataclass(frozen=True) + class C: + x: int + def __hash__(self): + return 301 + self.assertEqual(hash(C(100)), 301) + + # Test that declaring this class isn't an error. It should + # use the generated __hash__. + @dataclass(frozen=True) + class C: + x: int + def __eq__(self, other): + return False + self.assertEqual(hash(C(100)), hash((100,))) + + # But this one should generate an exception, because with + # unsafe_hash=True, it's an error to have a __hash__ defined. + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __hash__'): + @dataclass(unsafe_hash=True) + class C: + def __hash__(self): + pass + + # Creating this class should not generate an exception, + # because even though __hash__ exists before @dataclass is + # called, (due to __eq__ being defined), since it's None + # that's okay. + @dataclass(unsafe_hash=True) + class C: + x: int + def __eq__(self): + pass + # The generated hash function works as we'd expect. + self.assertEqual(hash(C(10)), hash((10,))) + + # Creating this class should generate an exception, because + # __hash__ exists and is not None, which it would be if it + # had been auto-generated due to __eq__ being defined. + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __hash__'): + @dataclass(unsafe_hash=True) + class C: + x: int + def __eq__(self): + pass + def __hash__(self): + pass + + def test_overwrite_fields_in_derived_class(self): + # Note that x from C1 replaces x in Base, but the order remains + # the same as defined in Base. + @dataclass + class Base: + x: Any = 15.0 + y: int = 0 + + @dataclass + class C1(Base): + z: int = 10 + x: int = 15 + + o = Base() + self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') + + o = C1() + self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') + + o = C1(x=5) + self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') + + def test_field_named_self(self): + @dataclass + class C: + self: str + c=C('foo') + self.assertEqual(c.self, 'foo') + + # Make sure the first parameter is not named 'self'. + sig = inspect.signature(C.__init__) + first = next(iter(sig.parameters)) + self.assertNotEqual('self', first) + + # But we do use 'self' if no field named self. + @dataclass + class C: + selfx: str + + # Make sure the first parameter is named 'self'. + sig = inspect.signature(C.__init__) + first = next(iter(sig.parameters)) + self.assertEqual('self', first) + + def test_field_named_object(self): + @dataclass + class C: + object: str + c = C('foo') + self.assertEqual(c.object, 'foo') + + def test_field_named_object_frozen(self): + @dataclass(frozen=True) + class C: + object: str + c = C('foo') + self.assertEqual(c.object, 'foo') + + def test_field_named_like_builtin(self): + # Attribute names can shadow built-in names + # since code generation is used. + # Ensure that this is not happening. + exclusions = {'None', 'True', 'False'} + builtins_names = sorted( + b for b in builtins.__dict__.keys() + if not b.startswith('__') and b not in exclusions + ) + attributes = [(name, str) for name in builtins_names] + C = make_dataclass('C', attributes) + + c = C(*[name for name in builtins_names]) + + for name in builtins_names: + self.assertEqual(getattr(c, name), name) + + def test_field_named_like_builtin_frozen(self): + # Attribute names can shadow built-in names + # since code generation is used. + # Ensure that this is not happening + # for frozen data classes. + exclusions = {'None', 'True', 'False'} + builtins_names = sorted( + b for b in builtins.__dict__.keys() + if not b.startswith('__') and b not in exclusions + ) + attributes = [(name, str) for name in builtins_names] + C = make_dataclass('C', attributes, frozen=True) + + c = C(*[name for name in builtins_names]) + + for name in builtins_names: + self.assertEqual(getattr(c, name), name) + + def test_0_field_compare(self): + # Ensure that order=False is the default. + @dataclass + class C0: + pass + + @dataclass(order=False) + class C1: + pass + + for cls in [C0, C1]: + with self.subTest(cls=cls): + self.assertEqual(cls(), cls()) + for idx, fn in enumerate([lambda a, b: a < b, + lambda a, b: a <= b, + lambda a, b: a > b, + lambda a, b: a >= b]): + with self.subTest(idx=idx): + with self.assertRaisesRegex(TypeError, + f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): + fn(cls(), cls()) + + @dataclass(order=True) + class C: + pass + self.assertLessEqual(C(), C()) + self.assertGreaterEqual(C(), C()) + + def test_1_field_compare(self): + # Ensure that order=False is the default. + @dataclass + class C0: + x: int + + @dataclass(order=False) + class C1: + x: int + + for cls in [C0, C1]: + with self.subTest(cls=cls): + self.assertEqual(cls(1), cls(1)) + self.assertNotEqual(cls(0), cls(1)) + for idx, fn in enumerate([lambda a, b: a < b, + lambda a, b: a <= b, + lambda a, b: a > b, + lambda a, b: a >= b]): + with self.subTest(idx=idx): + with self.assertRaisesRegex(TypeError, + f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): + fn(cls(0), cls(0)) + + @dataclass(order=True) + class C: + x: int + self.assertLess(C(0), C(1)) + self.assertLessEqual(C(0), C(1)) + self.assertLessEqual(C(1), C(1)) + self.assertGreater(C(1), C(0)) + self.assertGreaterEqual(C(1), C(0)) + self.assertGreaterEqual(C(1), C(1)) + + def test_simple_compare(self): + # Ensure that order=False is the default. + @dataclass + class C0: + x: int + y: int + + @dataclass(order=False) + class C1: + x: int + y: int + + for cls in [C0, C1]: + with self.subTest(cls=cls): + self.assertEqual(cls(0, 0), cls(0, 0)) + self.assertEqual(cls(1, 2), cls(1, 2)) + self.assertNotEqual(cls(1, 0), cls(0, 0)) + self.assertNotEqual(cls(1, 0), cls(1, 1)) + for idx, fn in enumerate([lambda a, b: a < b, + lambda a, b: a <= b, + lambda a, b: a > b, + lambda a, b: a >= b]): + with self.subTest(idx=idx): + with self.assertRaisesRegex(TypeError, + f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): + fn(cls(0, 0), cls(0, 0)) + + @dataclass(order=True) + class C: + x: int + y: int + + for idx, fn in enumerate([lambda a, b: a == b, + lambda a, b: a <= b, + lambda a, b: a >= b]): + with self.subTest(idx=idx): + self.assertTrue(fn(C(0, 0), C(0, 0))) + + for idx, fn in enumerate([lambda a, b: a < b, + lambda a, b: a <= b, + lambda a, b: a != b]): + with self.subTest(idx=idx): + self.assertTrue(fn(C(0, 0), C(0, 1))) + self.assertTrue(fn(C(0, 1), C(1, 0))) + self.assertTrue(fn(C(1, 0), C(1, 1))) + + for idx, fn in enumerate([lambda a, b: a > b, + lambda a, b: a >= b, + lambda a, b: a != b]): + with self.subTest(idx=idx): + self.assertTrue(fn(C(0, 1), C(0, 0))) + self.assertTrue(fn(C(1, 0), C(0, 1))) + self.assertTrue(fn(C(1, 1), C(1, 0))) + + def test_compare_subclasses(self): + # Comparisons fail for subclasses, even if no fields + # are added. + @dataclass + class B: + i: int + + @dataclass + class C(B): + pass + + for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), + (lambda a, b: a != b, True)]): + with self.subTest(idx=idx): + self.assertEqual(fn(B(0), C(0)), expected) + + for idx, fn in enumerate([lambda a, b: a < b, + lambda a, b: a <= b, + lambda a, b: a > b, + lambda a, b: a >= b]): + with self.subTest(idx=idx): + with self.assertRaisesRegex(TypeError, + "not supported between instances of 'B' and 'C'"): + fn(B(0), C(0)) + + def test_eq_order(self): + # Test combining eq and order. + for (eq, order, result ) in [ + (False, False, 'neither'), + (False, True, 'exception'), + (True, False, 'eq_only'), + (True, True, 'both'), + ]: + with self.subTest(eq=eq, order=order): + if result == 'exception': + with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): + @dataclass(eq=eq, order=order) + class C: + pass + else: + @dataclass(eq=eq, order=order) + class C: + pass + + if result == 'neither': + self.assertNotIn('__eq__', C.__dict__) + self.assertNotIn('__lt__', C.__dict__) + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + elif result == 'both': + self.assertIn('__eq__', C.__dict__) + self.assertIn('__lt__', C.__dict__) + self.assertIn('__le__', C.__dict__) + self.assertIn('__gt__', C.__dict__) + self.assertIn('__ge__', C.__dict__) + elif result == 'eq_only': + self.assertIn('__eq__', C.__dict__) + self.assertNotIn('__lt__', C.__dict__) + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + else: + assert False, f'unknown result {result!r}' + + def test_field_no_default(self): + @dataclass + class C: + x: int = field() + + self.assertEqual(C(5).x, 5) + + with self.assertRaisesRegex(TypeError, + r"__init__\(\) missing 1 required " + "positional argument: 'x'"): + C() + + def test_field_default(self): + default = object() + @dataclass + class C: + x: object = field(default=default) + + self.assertIs(C.x, default) + c = C(10) + self.assertEqual(c.x, 10) + + # If we delete the instance attribute, we should then see the + # class attribute. + del c.x + self.assertIs(c.x, default) + + self.assertIs(C().x, default) + + def test_not_in_repr(self): + @dataclass + class C: + x: int = field(repr=False) + with self.assertRaises(TypeError): + C() + c = C(10) + self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') + + @dataclass + class C: + x: int = field(repr=False) + y: int + c = C(10, 20) + self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') + + def test_not_in_compare(self): + @dataclass + class C: + x: int = 0 + y: int = field(compare=False, default=4) + + self.assertEqual(C(), C(0, 20)) + self.assertEqual(C(1, 10), C(1, 20)) + self.assertNotEqual(C(3), C(4, 10)) + self.assertNotEqual(C(3, 10), C(4, 10)) + + def test_no_unhashable_default(self): + # See bpo-44674. + class Unhashable: + __hash__ = None + + unhashable_re = 'mutable default .* for field a is not allowed' + with self.assertRaisesRegex(ValueError, unhashable_re): + @dataclass + class A: + a: dict = {} + + with self.assertRaisesRegex(ValueError, unhashable_re): + @dataclass + class A: + a: Any = Unhashable() + + # Make sure that the machinery looking for hashability is using the + # class's __hash__, not the instance's __hash__. + with self.assertRaisesRegex(ValueError, unhashable_re): + unhashable = Unhashable() + # This shouldn't make the variable hashable. + unhashable.__hash__ = lambda: 0 + @dataclass + class A: + a: Any = unhashable + + def test_hash_field_rules(self): + # Test all 6 cases of: + # hash=True/False/None + # compare=True/False + for (hash_, compare, result ) in [ + (True, False, 'field' ), + (True, True, 'field' ), + (False, False, 'absent'), + (False, True, 'absent'), + (None, False, 'absent'), + (None, True, 'field' ), + ]: + with self.subTest(hash=hash_, compare=compare): + @dataclass(unsafe_hash=True) + class C: + x: int = field(compare=compare, hash=hash_, default=5) + + if result == 'field': + # __hash__ contains the field. + self.assertEqual(hash(C(5)), hash((5,))) + elif result == 'absent': + # The field is not present in the hash. + self.assertEqual(hash(C(5)), hash(())) + else: + assert False, f'unknown result {result!r}' + + def test_init_false_no_default(self): + # If init=False and no default value, then the field won't be + # present in the instance. + @dataclass + class C: + x: int = field(init=False) + + self.assertNotIn('x', C().__dict__) + + @dataclass + class C: + x: int + y: int = 0 + z: int = field(init=False) + t: int = 10 + + self.assertNotIn('z', C(0).__dict__) + self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) + + def test_class_marker(self): + @dataclass + class C: + x: int + y: str = field(init=False, default=None) + z: str = field(repr=False) + + the_fields = fields(C) + # the_fields is a tuple of 3 items, each value + # is in __annotations__. + self.assertIsInstance(the_fields, tuple) + for f in the_fields: + self.assertIs(type(f), Field) + self.assertIn(f.name, C.__annotations__) + + self.assertEqual(len(the_fields), 3) + + self.assertEqual(the_fields[0].name, 'x') + self.assertEqual(the_fields[0].type, int) + self.assertFalse(hasattr(C, 'x')) + self.assertTrue (the_fields[0].init) + self.assertTrue (the_fields[0].repr) + self.assertEqual(the_fields[1].name, 'y') + self.assertEqual(the_fields[1].type, str) + self.assertIsNone(getattr(C, 'y')) + self.assertFalse(the_fields[1].init) + self.assertTrue (the_fields[1].repr) + self.assertEqual(the_fields[2].name, 'z') + self.assertEqual(the_fields[2].type, str) + self.assertFalse(hasattr(C, 'z')) + self.assertTrue (the_fields[2].init) + self.assertFalse(the_fields[2].repr) + + def test_field_order(self): + @dataclass + class B: + a: str = 'B:a' + b: str = 'B:b' + c: str = 'B:c' + + @dataclass + class C(B): + b: str = 'C:b' + + self.assertEqual([(f.name, f.default) for f in fields(C)], + [('a', 'B:a'), + ('b', 'C:b'), + ('c', 'B:c')]) + + @dataclass + class D(B): + c: str = 'D:c' + + self.assertEqual([(f.name, f.default) for f in fields(D)], + [('a', 'B:a'), + ('b', 'B:b'), + ('c', 'D:c')]) + + @dataclass + class E(D): + a: str = 'E:a' + d: str = 'E:d' + + self.assertEqual([(f.name, f.default) for f in fields(E)], + [('a', 'E:a'), + ('b', 'B:b'), + ('c', 'D:c'), + ('d', 'E:d')]) + + def test_class_attrs(self): + # We only have a class attribute if a default value is + # specified, either directly or via a field with a default. + default = object() + @dataclass + class C: + x: int + y: int = field(repr=False) + z: object = default + t: int = field(default=100) + + self.assertFalse(hasattr(C, 'x')) + self.assertFalse(hasattr(C, 'y')) + self.assertIs (C.z, default) + self.assertEqual(C.t, 100) + + def test_disallowed_mutable_defaults(self): + # For the known types, don't allow mutable default values. + for typ, empty, non_empty in [(list, [], [1]), + (dict, {}, {0:1}), + (set, set(), set([1])), + ]: + with self.subTest(typ=typ): + # Can't use a zero-length value. + with self.assertRaisesRegex(ValueError, + f'mutable default {typ} for field ' + 'x is not allowed'): + @dataclass + class Point: + x: typ = empty + + + # Nor a non-zero-length value + with self.assertRaisesRegex(ValueError, + f'mutable default {typ} for field ' + 'y is not allowed'): + @dataclass + class Point: + y: typ = non_empty + + # Check subtypes also fail. + class Subclass(typ): pass + + with self.assertRaisesRegex(ValueError, + f"mutable default .*Subclass'>" + ' for field z is not allowed' + ): + @dataclass + class Point: + z: typ = Subclass() + + # Because this is a ClassVar, it can be mutable. + @dataclass + class C: + z: ClassVar[typ] = typ() + + # Because this is a ClassVar, it can be mutable. + @dataclass + class C: + x: ClassVar[typ] = Subclass() + + def test_deliberately_mutable_defaults(self): + # If a mutable default isn't in the known list of + # (list, dict, set), then it's okay. + class Mutable: + def __init__(self): + self.l = [] + + @dataclass + class C: + x: Mutable + + # These 2 instances will share this value of x. + lst = Mutable() + o1 = C(lst) + o2 = C(lst) + self.assertEqual(o1, o2) + o1.x.l.extend([1, 2]) + self.assertEqual(o1, o2) + self.assertEqual(o1.x.l, [1, 2]) + self.assertIs(o1.x, o2.x) + + def test_no_options(self): + # Call with dataclass(). + @dataclass() + class C: + x: int + + self.assertEqual(C(42).x, 42) + + def test_not_tuple(self): + # Make sure we can't be compared to a tuple. + @dataclass + class Point: + x: int + y: int + self.assertNotEqual(Point(1, 2), (1, 2)) + + # And that we can't compare to another unrelated dataclass. + @dataclass + class C: + x: int + y: int + self.assertNotEqual(Point(1, 3), C(1, 3)) + + def test_not_other_dataclass(self): + # Test that some of the problems with namedtuple don't happen + # here. + @dataclass + class Point3D: + x: int + y: int + z: int + + @dataclass + class Date: + year: int + month: int + day: int + + self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) + self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) + + # Make sure we can't unpack. + with self.assertRaisesRegex(TypeError, 'unpack'): + x, y, z = Point3D(4, 5, 6) + + # Make sure another class with the same field names isn't + # equal. + @dataclass + class Point3Dv1: + x: int = 0 + y: int = 0 + z: int = 0 + self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) + + def test_function_annotations(self): + # Some dummy class and instance to use as a default. + class F: + pass + f = F() + + def validate_class(cls): + # First, check __annotations__, even though they're not + # function annotations. + self.assertEqual(cls.__annotations__['i'], int) + self.assertEqual(cls.__annotations__['j'], str) + self.assertEqual(cls.__annotations__['k'], F) + self.assertEqual(cls.__annotations__['l'], float) + self.assertEqual(cls.__annotations__['z'], complex) + + # Verify __init__. + + signature = inspect.signature(cls.__init__) + # Check the return type, should be None. + self.assertIs(signature.return_annotation, None) + + # Check each parameter. + params = iter(signature.parameters.values()) + param = next(params) + # This is testing an internal name, and probably shouldn't be tested. + self.assertEqual(param.name, 'self') + param = next(params) + self.assertEqual(param.name, 'i') + self.assertIs (param.annotation, int) + self.assertEqual(param.default, inspect.Parameter.empty) + self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) + param = next(params) + self.assertEqual(param.name, 'j') + self.assertIs (param.annotation, str) + self.assertEqual(param.default, inspect.Parameter.empty) + self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) + param = next(params) + self.assertEqual(param.name, 'k') + self.assertIs (param.annotation, F) + # Don't test for the default, since it's set to MISSING. + self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) + param = next(params) + self.assertEqual(param.name, 'l') + self.assertIs (param.annotation, float) + # Don't test for the default, since it's set to MISSING. + self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) + self.assertRaises(StopIteration, next, params) + + + @dataclass + class C: + i: int + j: str + k: F = f + l: float=field(default=None) + z: complex=field(default=3+4j, init=False) + + validate_class(C) + + # Now repeat with __hash__. + @dataclass(frozen=True, unsafe_hash=True) + class C: + i: int + j: str + k: F = f + l: float=field(default=None) + z: complex=field(default=3+4j, init=False) + + validate_class(C) + + def test_missing_default(self): + # Test that MISSING works the same as a default not being + # specified. + @dataclass + class C: + x: int=field(default=MISSING) + with self.assertRaisesRegex(TypeError, + r'__init__\(\) missing 1 required ' + 'positional argument'): + C() + self.assertNotIn('x', C.__dict__) + + @dataclass + class D: + x: int + with self.assertRaisesRegex(TypeError, + r'__init__\(\) missing 1 required ' + 'positional argument'): + D() + self.assertNotIn('x', D.__dict__) + + def test_missing_default_factory(self): + # Test that MISSING works the same as a default factory not + # being specified (which is really the same as a default not + # being specified, too). + @dataclass + class C: + x: int=field(default_factory=MISSING) + with self.assertRaisesRegex(TypeError, + r'__init__\(\) missing 1 required ' + 'positional argument'): + C() + self.assertNotIn('x', C.__dict__) + + @dataclass + class D: + x: int=field(default=MISSING, default_factory=MISSING) + with self.assertRaisesRegex(TypeError, + r'__init__\(\) missing 1 required ' + 'positional argument'): + D() + self.assertNotIn('x', D.__dict__) + + def test_missing_repr(self): + self.assertIn('MISSING_TYPE object', repr(MISSING)) + + def test_dont_include_other_annotations(self): + @dataclass + class C: + i: int + def foo(self) -> int: + return 4 + @property + def bar(self) -> int: + return 5 + self.assertEqual(list(C.__annotations__), ['i']) + self.assertEqual(C(10).foo(), 4) + self.assertEqual(C(10).bar, 5) + self.assertEqual(C(10).i, 10) + + def test_post_init(self): + # Just make sure it gets called + @dataclass + class C: + def __post_init__(self): + raise CustomError() + with self.assertRaises(CustomError): + C() + + @dataclass + class C: + i: int = 10 + def __post_init__(self): + if self.i == 10: + raise CustomError() + with self.assertRaises(CustomError): + C() + # post-init gets called, but doesn't raise. This is just + # checking that self is used correctly. + C(5) + + # If there's not an __init__, then post-init won't get called. + @dataclass(init=False) + class C: + def __post_init__(self): + raise CustomError() + # Creating the class won't raise + C() + + @dataclass + class C: + x: int = 0 + def __post_init__(self): + self.x *= 2 + self.assertEqual(C().x, 0) + self.assertEqual(C(2).x, 4) + + # Make sure that if we're frozen, post-init can't set + # attributes. + @dataclass(frozen=True) + class C: + x: int = 0 + def __post_init__(self): + self.x *= 2 + with self.assertRaises(FrozenInstanceError): + C() + + def test_post_init_super(self): + # Make sure super() post-init isn't called by default. + class B: + def __post_init__(self): + raise CustomError() + + @dataclass + class C(B): + def __post_init__(self): + self.x = 5 + + self.assertEqual(C().x, 5) + + # Now call super(), and it will raise. + @dataclass + class C(B): + def __post_init__(self): + super().__post_init__() + + with self.assertRaises(CustomError): + C() + + # Make sure post-init is called, even if not defined in our + # class. + @dataclass + class C(B): + pass + + with self.assertRaises(CustomError): + C() + + def test_post_init_staticmethod(self): + flag = False + @dataclass + class C: + x: int + y: int + @staticmethod + def __post_init__(): + nonlocal flag + flag = True + + self.assertFalse(flag) + c = C(3, 4) + self.assertEqual((c.x, c.y), (3, 4)) + self.assertTrue(flag) + + def test_post_init_classmethod(self): + @dataclass + class C: + flag = False + x: int + y: int + @classmethod + def __post_init__(cls): + cls.flag = True + + self.assertFalse(C.flag) + c = C(3, 4) + self.assertEqual((c.x, c.y), (3, 4)) + self.assertTrue(C.flag) + + def test_post_init_not_auto_added(self): + # See bpo-46757, which had proposed always adding __post_init__. As + # Raymond Hettinger pointed out, that would be a breaking change. So, + # add a test to make sure that the current behavior doesn't change. + + @dataclass + class A0: + pass + + @dataclass + class B0: + b_called: bool = False + def __post_init__(self): + self.b_called = True + + @dataclass + class C0(A0, B0): + c_called: bool = False + def __post_init__(self): + super().__post_init__() + self.c_called = True + + # Since A0 has no __post_init__, and one wasn't automatically added + # (because that's the rule: it's never added by @dataclass, it's only + # the class author that can add it), then B0.__post_init__ is called. + # Verify that. + c = C0() + self.assertTrue(c.b_called) + self.assertTrue(c.c_called) + + ###################################### + # Now, the same thing, except A1 defines __post_init__. + @dataclass + class A1: + def __post_init__(self): + pass + + @dataclass + class B1: + b_called: bool = False + def __post_init__(self): + self.b_called = True + + @dataclass + class C1(A1, B1): + c_called: bool = False + def __post_init__(self): + super().__post_init__() + self.c_called = True + + # This time, B1.__post_init__ isn't being called. This mimics what + # would happen if A1.__post_init__ had been automatically added, + # instead of manually added as we see here. This test isn't really + # needed, but I'm including it just to demonstrate the changed + # behavior when A1 does define __post_init__. + c = C1() + self.assertFalse(c.b_called) + self.assertTrue(c.c_called) + + def test_class_var(self): + # Make sure ClassVars are ignored in __init__, __repr__, etc. + @dataclass + class C: + x: int + y: int = 10 + z: ClassVar[int] = 1000 + w: ClassVar[int] = 2000 + t: ClassVar[int] = 3000 + s: ClassVar = 4000 + + c = C(5) + self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') + self.assertEqual(len(fields(C)), 2) # We have 2 fields. + self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. + self.assertEqual(c.z, 1000) + self.assertEqual(c.w, 2000) + self.assertEqual(c.t, 3000) + self.assertEqual(c.s, 4000) + C.z += 1 + self.assertEqual(c.z, 1001) + c = C(20) + self.assertEqual((c.x, c.y), (20, 10)) + self.assertEqual(c.z, 1001) + self.assertEqual(c.w, 2000) + self.assertEqual(c.t, 3000) + self.assertEqual(c.s, 4000) + + def test_class_var_no_default(self): + # If a ClassVar has no default value, it should not be set on the class. + @dataclass + class C: + x: ClassVar[int] + + self.assertNotIn('x', C.__dict__) + + def test_class_var_default_factory(self): + # It makes no sense for a ClassVar to have a default factory. When + # would it be called? Call it yourself, since it's class-wide. + with self.assertRaisesRegex(TypeError, + 'cannot have a default factory'): + @dataclass + class C: + x: ClassVar[int] = field(default_factory=int) + + self.assertNotIn('x', C.__dict__) + + def test_class_var_with_default(self): + # If a ClassVar has a default value, it should be set on the class. + @dataclass + class C: + x: ClassVar[int] = 10 + self.assertEqual(C.x, 10) + + @dataclass + class C: + x: ClassVar[int] = field(default=10) + self.assertEqual(C.x, 10) + + def test_class_var_frozen(self): + # Make sure ClassVars work even if we're frozen. + @dataclass(frozen=True) + class C: + x: int + y: int = 10 + z: ClassVar[int] = 1000 + w: ClassVar[int] = 2000 + t: ClassVar[int] = 3000 + + c = C(5) + self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') + self.assertEqual(len(fields(C)), 2) # We have 2 fields + self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars + self.assertEqual(c.z, 1000) + self.assertEqual(c.w, 2000) + self.assertEqual(c.t, 3000) + # We can still modify the ClassVar, it's only instances that are + # frozen. + C.z += 1 + self.assertEqual(c.z, 1001) + c = C(20) + self.assertEqual((c.x, c.y), (20, 10)) + self.assertEqual(c.z, 1001) + self.assertEqual(c.w, 2000) + self.assertEqual(c.t, 3000) + + def test_init_var_no_default(self): + # If an InitVar has no default value, it should not be set on the class. + @dataclass + class C: + x: InitVar[int] + + self.assertNotIn('x', C.__dict__) + + def test_init_var_default_factory(self): + # It makes no sense for an InitVar to have a default factory. When + # would it be called? Call it yourself, since it's class-wide. + with self.assertRaisesRegex(TypeError, + 'cannot have a default factory'): + @dataclass + class C: + x: InitVar[int] = field(default_factory=int) + + self.assertNotIn('x', C.__dict__) + + def test_init_var_with_default(self): + # If an InitVar has a default value, it should be set on the class. + @dataclass + class C: + x: InitVar[int] = 10 + self.assertEqual(C.x, 10) + + @dataclass + class C: + x: InitVar[int] = field(default=10) + self.assertEqual(C.x, 10) + + def test_init_var(self): + @dataclass + class C: + x: int = None + init_param: InitVar[int] = None + + def __post_init__(self, init_param): + if self.x is None: + self.x = init_param*2 + + c = C(init_param=10) + self.assertEqual(c.x, 20) + + def test_init_var_preserve_type(self): + self.assertEqual(InitVar[int].type, int) + + # Make sure the repr is correct. + self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]') + self.assertEqual(repr(InitVar[List[int]]), + 'dataclasses.InitVar[typing.List[int]]') + self.assertEqual(repr(InitVar[list[int]]), + 'dataclasses.InitVar[list[int]]') + self.assertEqual(repr(InitVar[int|str]), + 'dataclasses.InitVar[int | str]') + + def test_init_var_inheritance(self): + # Note that this deliberately tests that a dataclass need not + # have a __post_init__ function if it has an InitVar field. + # It could just be used in a derived class, as shown here. + @dataclass + class Base: + x: int + init_base: InitVar[int] + + # We can instantiate by passing the InitVar, even though + # it's not used. + b = Base(0, 10) + self.assertEqual(vars(b), {'x': 0}) + + @dataclass + class C(Base): + y: int + init_derived: InitVar[int] + + def __post_init__(self, init_base, init_derived): + self.x = self.x + init_base + self.y = self.y + init_derived + + c = C(10, 11, 50, 51) + self.assertEqual(vars(c), {'x': 21, 'y': 101}) + + def test_default_factory(self): + # Test a factory that returns a new list. + @dataclass + class C: + x: int + y: list = field(default_factory=list) + + c0 = C(3) + c1 = C(3) + self.assertEqual(c0.x, 3) + self.assertEqual(c0.y, []) + self.assertEqual(c0, c1) + self.assertIsNot(c0.y, c1.y) + self.assertEqual(astuple(C(5, [1])), (5, [1])) + + # Test a factory that returns a shared list. + l = [] + @dataclass + class C: + x: int + y: list = field(default_factory=lambda: l) + + c0 = C(3) + c1 = C(3) + self.assertEqual(c0.x, 3) + self.assertEqual(c0.y, []) + self.assertEqual(c0, c1) + self.assertIs(c0.y, c1.y) + self.assertEqual(astuple(C(5, [1])), (5, [1])) + + # Test various other field flags. + # repr + @dataclass + class C: + x: list = field(default_factory=list, repr=False) + self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') + self.assertEqual(C().x, []) + + # hash + @dataclass(unsafe_hash=True) + class C: + x: list = field(default_factory=list, hash=False) + self.assertEqual(astuple(C()), ([],)) + self.assertEqual(hash(C()), hash(())) + + # init (see also test_default_factory_with_no_init) + @dataclass + class C: + x: list = field(default_factory=list, init=False) + self.assertEqual(astuple(C()), ([],)) + + # compare + @dataclass + class C: + x: list = field(default_factory=list, compare=False) + self.assertEqual(C(), C([1])) + + def test_default_factory_with_no_init(self): + # We need a factory with a side effect. + factory = Mock() + + @dataclass + class C: + x: list = field(default_factory=factory, init=False) + + # Make sure the default factory is called for each new instance. + C().x + self.assertEqual(factory.call_count, 1) + C().x + self.assertEqual(factory.call_count, 2) + + def test_default_factory_not_called_if_value_given(self): + # We need a factory that we can test if it's been called. + factory = Mock() + + @dataclass + class C: + x: int = field(default_factory=factory) + + # Make sure that if a field has a default factory function, + # it's not called if a value is specified. + C().x + self.assertEqual(factory.call_count, 1) + self.assertEqual(C(10).x, 10) + self.assertEqual(factory.call_count, 1) + C().x + self.assertEqual(factory.call_count, 2) + + def test_default_factory_derived(self): + # See bpo-32896. + @dataclass + class Foo: + x: dict = field(default_factory=dict) + + @dataclass + class Bar(Foo): + y: int = 1 + + self.assertEqual(Foo().x, {}) + self.assertEqual(Bar().x, {}) + self.assertEqual(Bar().y, 1) + + @dataclass + class Baz(Foo): + pass + self.assertEqual(Baz().x, {}) + + def test_intermediate_non_dataclass(self): + # Test that an intermediate class that defines + # annotations does not define fields. + + @dataclass + class A: + x: int + + class B(A): + y: int + + @dataclass + class C(B): + z: int + + c = C(1, 3) + self.assertEqual((c.x, c.z), (1, 3)) + + # .y was not initialized. + with self.assertRaisesRegex(AttributeError, + 'object has no attribute'): + c.y + + # And if we again derive a non-dataclass, no fields are added. + class D(C): + t: int + d = D(4, 5) + self.assertEqual((d.x, d.z), (4, 5)) + + def test_classvar_default_factory(self): + # It's an error for a ClassVar to have a factory function. + with self.assertRaisesRegex(TypeError, + 'cannot have a default factory'): + @dataclass + class C: + x: ClassVar[int] = field(default_factory=int) + + def test_is_dataclass(self): + class NotDataClass: + pass + + self.assertFalse(is_dataclass(0)) + self.assertFalse(is_dataclass(int)) + self.assertFalse(is_dataclass(NotDataClass)) + self.assertFalse(is_dataclass(NotDataClass())) + + @dataclass + class C: + x: int + + @dataclass + class D: + d: C + e: int + + c = C(10) + d = D(c, 4) + + self.assertTrue(is_dataclass(C)) + self.assertTrue(is_dataclass(c)) + self.assertFalse(is_dataclass(c.x)) + self.assertTrue(is_dataclass(d.d)) + self.assertFalse(is_dataclass(d.e)) + + def test_is_dataclass_when_getattr_always_returns(self): + # See bpo-37868. + class A: + def __getattr__(self, key): + return 0 + self.assertFalse(is_dataclass(A)) + a = A() + + # Also test for an instance attribute. + class B: + pass + b = B() + b.__dataclass_fields__ = [] + + for obj in a, b: + with self.subTest(obj=obj): + self.assertFalse(is_dataclass(obj)) + + # Indirect tests for _is_dataclass_instance(). + with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): + asdict(obj) + with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): + astuple(obj) + with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'): + replace(obj, x=0) + + def test_is_dataclass_genericalias(self): + @dataclass + class A(types.GenericAlias): + origin: type + args: type + self.assertTrue(is_dataclass(A)) + a = A(list, int) + self.assertTrue(is_dataclass(type(a))) + self.assertTrue(is_dataclass(a)) + + + def test_helper_fields_with_class_instance(self): + # Check that we can call fields() on either a class or instance, + # and get back the same thing. + @dataclass + class C: + x: int + y: float + + self.assertEqual(fields(C), fields(C(0, 0.0))) + + def test_helper_fields_exception(self): + # Check that TypeError is raised if not passed a dataclass or + # instance. + with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): + fields(0) + + class C: pass + with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): + fields(C) + with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): + fields(C()) + + def test_helper_asdict(self): + # Basic tests for asdict(), it should return a new dictionary. + @dataclass + class C: + x: int + y: int + c = C(1, 2) + + self.assertEqual(asdict(c), {'x': 1, 'y': 2}) + self.assertEqual(asdict(c), asdict(c)) + self.assertIsNot(asdict(c), asdict(c)) + c.x = 42 + self.assertEqual(asdict(c), {'x': 42, 'y': 2}) + self.assertIs(type(asdict(c)), dict) + + def test_helper_asdict_raises_on_classes(self): + # asdict() should raise on a class object. + @dataclass + class C: + x: int + y: int + with self.assertRaisesRegex(TypeError, 'dataclass instance'): + asdict(C) + with self.assertRaisesRegex(TypeError, 'dataclass instance'): + asdict(int) + + def test_helper_asdict_copy_values(self): + @dataclass + class C: + x: int + y: List[int] = field(default_factory=list) + initial = [] + c = C(1, initial) + d = asdict(c) + self.assertEqual(d['y'], initial) + self.assertIsNot(d['y'], initial) + c = C(1) + d = asdict(c) + d['y'].append(1) + self.assertEqual(c.y, []) + + def test_helper_asdict_nested(self): + @dataclass + class UserId: + token: int + group: int + @dataclass + class User: + name: str + id: UserId + u = User('Joe', UserId(123, 1)) + d = asdict(u) + self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) + self.assertIsNot(asdict(u), asdict(u)) + u.id.group = 2 + self.assertEqual(asdict(u), {'name': 'Joe', + 'id': {'token': 123, 'group': 2}}) + + def test_helper_asdict_builtin_containers(self): + @dataclass + class User: + name: str + id: int + @dataclass + class GroupList: + id: int + users: List[User] + @dataclass + class GroupTuple: + id: int + users: Tuple[User, ...] + @dataclass + class GroupDict: + id: int + users: Dict[str, User] + a = User('Alice', 1) + b = User('Bob', 2) + gl = GroupList(0, [a, b]) + gt = GroupTuple(0, (a, b)) + gd = GroupDict(0, {'first': a, 'second': b}) + self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, + {'name': 'Bob', 'id': 2}]}) + self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, + {'name': 'Bob', 'id': 2})}) + self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, + 'second': {'name': 'Bob', 'id': 2}}}) + + def test_helper_asdict_builtin_object_containers(self): + @dataclass + class Child: + d: object + + @dataclass + class Parent: + child: Child + + self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) + self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) + + def test_helper_asdict_factory(self): + @dataclass + class C: + x: int + y: int + c = C(1, 2) + d = asdict(c, dict_factory=OrderedDict) + self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) + self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) + c.x = 42 + d = asdict(c, dict_factory=OrderedDict) + self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) + self.assertIs(type(d), OrderedDict) + + def test_helper_asdict_namedtuple(self): + T = namedtuple('T', 'a b c') + @dataclass + class C: + x: str + y: T + c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) + + d = asdict(c) + self.assertEqual(d, {'x': 'outer', + 'y': T(1, + {'x': 'inner', + 'y': T(11, 12, 13)}, + 2), + } + ) + + # Now with a dict_factory. OrderedDict is convenient, but + # since it compares to dicts, we also need to have separate + # assertIs tests. + d = asdict(c, dict_factory=OrderedDict) + self.assertEqual(d, {'x': 'outer', + 'y': T(1, + {'x': 'inner', + 'y': T(11, 12, 13)}, + 2), + } + ) + + # Make sure that the returned dicts are actually OrderedDicts. + self.assertIs(type(d), OrderedDict) + self.assertIs(type(d['y'][1]), OrderedDict) + + def test_helper_asdict_namedtuple_key(self): + # Ensure that a field that contains a dict which has a + # namedtuple as a key works with asdict(). + + @dataclass + class C: + f: dict + T = namedtuple('T', 'a') + + c = C({T('an a'): 0}) + + self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) + + def test_helper_asdict_namedtuple_derived(self): + class T(namedtuple('Tbase', 'a')): + def my_a(self): + return self.a + + @dataclass + class C: + f: T + + t = T(6) + c = C(t) + + d = asdict(c) + self.assertEqual(d, {'f': T(a=6)}) + # Make sure that t has been copied, not used directly. + self.assertIsNot(d['f'], t) + self.assertEqual(d['f'].my_a(), 6) + + def test_helper_astuple(self): + # Basic tests for astuple(), it should return a new tuple. + @dataclass + class C: + x: int + y: int = 0 + c = C(1) + + self.assertEqual(astuple(c), (1, 0)) + self.assertEqual(astuple(c), astuple(c)) + self.assertIsNot(astuple(c), astuple(c)) + c.y = 42 + self.assertEqual(astuple(c), (1, 42)) + self.assertIs(type(astuple(c)), tuple) + + def test_helper_astuple_raises_on_classes(self): + # astuple() should raise on a class object. + @dataclass + class C: + x: int + y: int + with self.assertRaisesRegex(TypeError, 'dataclass instance'): + astuple(C) + with self.assertRaisesRegex(TypeError, 'dataclass instance'): + astuple(int) + + def test_helper_astuple_copy_values(self): + @dataclass + class C: + x: int + y: List[int] = field(default_factory=list) + initial = [] + c = C(1, initial) + t = astuple(c) + self.assertEqual(t[1], initial) + self.assertIsNot(t[1], initial) + c = C(1) + t = astuple(c) + t[1].append(1) + self.assertEqual(c.y, []) + + def test_helper_astuple_nested(self): + @dataclass + class UserId: + token: int + group: int + @dataclass + class User: + name: str + id: UserId + u = User('Joe', UserId(123, 1)) + t = astuple(u) + self.assertEqual(t, ('Joe', (123, 1))) + self.assertIsNot(astuple(u), astuple(u)) + u.id.group = 2 + self.assertEqual(astuple(u), ('Joe', (123, 2))) + + def test_helper_astuple_builtin_containers(self): + @dataclass + class User: + name: str + id: int + @dataclass + class GroupList: + id: int + users: List[User] + @dataclass + class GroupTuple: + id: int + users: Tuple[User, ...] + @dataclass + class GroupDict: + id: int + users: Dict[str, User] + a = User('Alice', 1) + b = User('Bob', 2) + gl = GroupList(0, [a, b]) + gt = GroupTuple(0, (a, b)) + gd = GroupDict(0, {'first': a, 'second': b}) + self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) + self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) + self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) + + def test_helper_astuple_builtin_object_containers(self): + @dataclass + class Child: + d: object + + @dataclass + class Parent: + child: Child + + self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) + self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) + + def test_helper_astuple_factory(self): + @dataclass + class C: + x: int + y: int + NT = namedtuple('NT', 'x y') + def nt(lst): + return NT(*lst) + c = C(1, 2) + t = astuple(c, tuple_factory=nt) + self.assertEqual(t, NT(1, 2)) + self.assertIsNot(t, astuple(c, tuple_factory=nt)) + c.x = 42 + t = astuple(c, tuple_factory=nt) + self.assertEqual(t, NT(42, 2)) + self.assertIs(type(t), NT) + + def test_helper_astuple_namedtuple(self): + T = namedtuple('T', 'a b c') + @dataclass + class C: + x: str + y: T + c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) + + t = astuple(c) + self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) + + # Now, using a tuple_factory. list is convenient here. + t = astuple(c, tuple_factory=list) + self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) + + def test_dynamic_class_creation(self): + cls_dict = {'__annotations__': {'x': int, 'y': int}, + } + + # Create the class. + cls = type('C', (), cls_dict) + + # Make it a dataclass. + cls1 = dataclass(cls) + + self.assertEqual(cls1, cls) + self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) + + def test_dynamic_class_creation_using_field(self): + cls_dict = {'__annotations__': {'x': int, 'y': int}, + 'y': field(default=5), + } + + # Create the class. + cls = type('C', (), cls_dict) + + # Make it a dataclass. + cls1 = dataclass(cls) + + self.assertEqual(cls1, cls) + self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) + + def test_init_in_order(self): + @dataclass + class C: + a: int + b: int = field() + c: list = field(default_factory=list, init=False) + d: list = field(default_factory=list) + e: int = field(default=4, init=False) + f: int = 4 + + calls = [] + def setattr(self, name, value): + calls.append((name, value)) + + C.__setattr__ = setattr + c = C(0, 1) + self.assertEqual(('a', 0), calls[0]) + self.assertEqual(('b', 1), calls[1]) + self.assertEqual(('c', []), calls[2]) + self.assertEqual(('d', []), calls[3]) + self.assertNotIn(('e', 4), calls) + self.assertEqual(('f', 4), calls[4]) + + def test_items_in_dicts(self): + @dataclass + class C: + a: int + b: list = field(default_factory=list, init=False) + c: list = field(default_factory=list) + d: int = field(default=4, init=False) + e: int = 0 + + c = C(0) + # Class dict + self.assertNotIn('a', C.__dict__) + self.assertNotIn('b', C.__dict__) + self.assertNotIn('c', C.__dict__) + self.assertIn('d', C.__dict__) + self.assertEqual(C.d, 4) + self.assertIn('e', C.__dict__) + self.assertEqual(C.e, 0) + # Instance dict + self.assertIn('a', c.__dict__) + self.assertEqual(c.a, 0) + self.assertIn('b', c.__dict__) + self.assertEqual(c.b, []) + self.assertIn('c', c.__dict__) + self.assertEqual(c.c, []) + self.assertNotIn('d', c.__dict__) + self.assertIn('e', c.__dict__) + self.assertEqual(c.e, 0) + + def test_alternate_classmethod_constructor(self): + # Since __post_init__ can't take params, use a classmethod + # alternate constructor. This is mostly an example to show + # how to use this technique. + @dataclass + class C: + x: int + @classmethod + def from_file(cls, filename): + # In a real example, create a new instance + # and populate 'x' from contents of a file. + value_in_file = 20 + return cls(value_in_file) + + self.assertEqual(C.from_file('filename').x, 20) + + def test_field_metadata_default(self): + # Make sure the default metadata is read-only and of + # zero length. + @dataclass + class C: + i: int + + self.assertFalse(fields(C)[0].metadata) + self.assertEqual(len(fields(C)[0].metadata), 0) + with self.assertRaisesRegex(TypeError, + 'does not support item assignment'): + fields(C)[0].metadata['test'] = 3 + + def test_field_metadata_mapping(self): + # Make sure only a mapping can be passed as metadata + # zero length. + with self.assertRaises(TypeError): + @dataclass + class C: + i: int = field(metadata=0) + + # Make sure an empty dict works. + d = {} + @dataclass + class C: + i: int = field(metadata=d) + self.assertFalse(fields(C)[0].metadata) + self.assertEqual(len(fields(C)[0].metadata), 0) + # Update should work (see bpo-35960). + d['foo'] = 1 + self.assertEqual(len(fields(C)[0].metadata), 1) + self.assertEqual(fields(C)[0].metadata['foo'], 1) + with self.assertRaisesRegex(TypeError, + 'does not support item assignment'): + fields(C)[0].metadata['test'] = 3 + + # Make sure a non-empty dict works. + d = {'test': 10, 'bar': '42', 3: 'three'} + @dataclass + class C: + i: int = field(metadata=d) + self.assertEqual(len(fields(C)[0].metadata), 3) + self.assertEqual(fields(C)[0].metadata['test'], 10) + self.assertEqual(fields(C)[0].metadata['bar'], '42') + self.assertEqual(fields(C)[0].metadata[3], 'three') + # Update should work. + d['foo'] = 1 + self.assertEqual(len(fields(C)[0].metadata), 4) + self.assertEqual(fields(C)[0].metadata['foo'], 1) + with self.assertRaises(KeyError): + # Non-existent key. + fields(C)[0].metadata['baz'] + with self.assertRaisesRegex(TypeError, + 'does not support item assignment'): + fields(C)[0].metadata['test'] = 3 + + def test_field_metadata_custom_mapping(self): + # Try a custom mapping. + class SimpleNameSpace: + def __init__(self, **kw): + self.__dict__.update(kw) + + def __getitem__(self, item): + if item == 'xyzzy': + return 'plugh' + return getattr(self, item) + + def __len__(self): + return self.__dict__.__len__() + + @dataclass + class C: + i: int = field(metadata=SimpleNameSpace(a=10)) + + self.assertEqual(len(fields(C)[0].metadata), 1) + self.assertEqual(fields(C)[0].metadata['a'], 10) + with self.assertRaises(AttributeError): + fields(C)[0].metadata['b'] + # Make sure we're still talking to our custom mapping. + self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') + + def test_generic_dataclasses(self): + T = TypeVar('T') + + @dataclass + class LabeledBox(Generic[T]): + content: T + label: str = '<unknown>' + + box = LabeledBox(42) + self.assertEqual(box.content, 42) + self.assertEqual(box.label, '<unknown>') + + # Subscripting the resulting class should work, etc. + Alias = List[LabeledBox[int]] + + def test_generic_extending(self): + S = TypeVar('S') + T = TypeVar('T') + + @dataclass + class Base(Generic[T, S]): + x: T + y: S + + @dataclass + class DataDerived(Base[int, T]): + new_field: str + Alias = DataDerived[str] + c = Alias(0, 'test1', 'test2') + self.assertEqual(astuple(c), (0, 'test1', 'test2')) + + class NonDataDerived(Base[int, T]): + def new_method(self): + return self.y + Alias = NonDataDerived[float] + c = Alias(10, 1.0) + self.assertEqual(c.new_method(), 1.0) + + def test_generic_dynamic(self): + T = TypeVar('T') + + @dataclass + class Parent(Generic[T]): + x: T + Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], + bases=(Parent[int], Generic[T]), namespace={'other': 42}) + self.assertIs(Child[int](1, 2).z, None) + self.assertEqual(Child[int](1, 2, 3).z, 3) + self.assertEqual(Child[int](1, 2, 3).other, 42) + # Check that type aliases work correctly. + Alias = Child[T] + self.assertEqual(Alias[int](1, 2).x, 1) + # Check MRO resolution. + self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) + + def test_dataclasses_pickleable(self): + global P, Q, R + @dataclass + class P: + x: int + y: int = 0 + @dataclass + class Q: + x: int + y: int = field(default=0, init=False) + @dataclass + class R: + x: int + y: List[int] = field(default_factory=list) + q = Q(1) + q.y = 2 + samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] + for sample in samples: + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(sample=sample, proto=proto): + new_sample = pickle.loads(pickle.dumps(sample, proto)) + self.assertEqual(sample.x, new_sample.x) + self.assertEqual(sample.y, new_sample.y) + self.assertIsNot(sample, new_sample) + new_sample.x = 42 + another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) + self.assertEqual(new_sample.x, another_new_sample.x) + self.assertEqual(sample.y, another_new_sample.y) + + def test_dataclasses_qualnames(self): + @dataclass(order=True, unsafe_hash=True, frozen=True) + class A: + x: int + y: int + + self.assertEqual(A.__init__.__name__, "__init__") + for function in ( + '__eq__', + '__lt__', + '__le__', + '__gt__', + '__ge__', + '__hash__', + '__init__', + '__repr__', + '__setattr__', + '__delattr__', + ): + self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}") + + with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"): + A() + + +class TestFieldNoAnnotation(unittest.TestCase): + def test_field_without_annotation(self): + with self.assertRaisesRegex(TypeError, + "'f' is a field but has no type annotation"): + @dataclass + class C: + f = field() + + def test_field_without_annotation_but_annotation_in_base(self): + @dataclass + class B: + f: int + + with self.assertRaisesRegex(TypeError, + "'f' is a field but has no type annotation"): + # This is still an error: make sure we don't pick up the + # type annotation in the base class. + @dataclass + class C(B): + f = field() + + def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): + # Same test, but with the base class not a dataclass. + class B: + f: int + + with self.assertRaisesRegex(TypeError, + "'f' is a field but has no type annotation"): + # This is still an error: make sure we don't pick up the + # type annotation in the base class. + @dataclass + class C(B): + f = field() + + +class TestDocString(unittest.TestCase): + def assertDocStrEqual(self, a, b): + # Because 3.6 and 3.7 differ in how inspect.signature work + # (see bpo #32108), for the time being just compare them with + # whitespace stripped. + self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) + + def test_existing_docstring_not_overridden(self): + @dataclass + class C: + """Lorem ipsum""" + x: int + + self.assertEqual(C.__doc__, "Lorem ipsum") + + def test_docstring_no_fields(self): + @dataclass + class C: + pass + + self.assertDocStrEqual(C.__doc__, "C()") + + def test_docstring_one_field(self): + @dataclass + class C: + x: int + + self.assertDocStrEqual(C.__doc__, "C(x:int)") + + def test_docstring_two_fields(self): + @dataclass + class C: + x: int + y: int + + self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") + + def test_docstring_three_fields(self): + @dataclass + class C: + x: int + y: int + z: str + + self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") + + def test_docstring_one_field_with_default(self): + @dataclass + class C: + x: int = 3 + + self.assertDocStrEqual(C.__doc__, "C(x:int=3)") + + def test_docstring_one_field_with_default_none(self): + @dataclass + class C: + x: Union[int, type(None)] = None + + self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)") + + def test_docstring_list_field(self): + @dataclass + class C: + x: List[int] + + self.assertDocStrEqual(C.__doc__, "C(x:List[int])") + + def test_docstring_list_field_with_default_factory(self): + @dataclass + class C: + x: List[int] = field(default_factory=list) + + self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") + + def test_docstring_deque_field(self): + @dataclass + class C: + x: deque + + self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") + + def test_docstring_deque_field_with_default_factory(self): + @dataclass + class C: + x: deque = field(default_factory=deque) + + self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") + + +class TestInit(unittest.TestCase): + def test_base_has_init(self): + class B: + def __init__(self): + self.z = 100 + pass + + # Make sure that declaring this class doesn't raise an error. + # The issue is that we can't override __init__ in our class, + # but it should be okay to add __init__ to us if our base has + # an __init__. + @dataclass + class C(B): + x: int = 0 + c = C(10) + self.assertEqual(c.x, 10) + self.assertNotIn('z', vars(c)) + + # Make sure that if we don't add an init, the base __init__ + # gets called. + @dataclass(init=False) + class C(B): + x: int = 10 + c = C() + self.assertEqual(c.x, 10) + self.assertEqual(c.z, 100) + + def test_no_init(self): + @dataclass(init=False) + class C: + i: int = 0 + self.assertEqual(C().i, 0) + + @dataclass(init=False) + class C: + i: int = 2 + def __init__(self): + self.i = 3 + self.assertEqual(C().i, 3) + + def test_overwriting_init(self): + # If the class has __init__, use it no matter the value of + # init=. + + @dataclass + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(3).x, 6) + + @dataclass(init=True) + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(4).x, 8) + + @dataclass(init=False) + class C: + x: int + def __init__(self, x): + self.x = 2 * x + self.assertEqual(C(5).x, 10) + + def test_inherit_from_protocol(self): + # Dataclasses inheriting from protocol should preserve their own `__init__`. + # See bpo-45081. + + class P(Protocol): + a: int + + @dataclass + class C(P): + a: int + + self.assertEqual(C(5).a, 5) + + @dataclass + class D(P): + def __init__(self, a): + self.a = a * 2 + + self.assertEqual(D(5).a, 10) + + +class TestRepr(unittest.TestCase): + def test_repr(self): + @dataclass + class B: + x: int + + @dataclass + class C(B): + y: int = 10 + + o = C(4) + self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') + + @dataclass + class D(C): + x: int = 20 + self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') + + @dataclass + class C: + @dataclass + class D: + i: int + @dataclass + class E: + pass + self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') + self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') + + def test_no_repr(self): + # Test a class with no __repr__ and repr=False. + @dataclass(repr=False) + class C: + x: int + self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', + repr(C(3))) + + # Test a class with a __repr__ and repr=False. + @dataclass(repr=False) + class C: + x: int + def __repr__(self): + return 'C-class' + self.assertEqual(repr(C(3)), 'C-class') + + def test_overwriting_repr(self): + # If the class has __repr__, use it no matter the value of + # repr=. + + @dataclass + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + @dataclass(repr=True) + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + @dataclass(repr=False) + class C: + x: int + def __repr__(self): + return 'x' + self.assertEqual(repr(C(0)), 'x') + + +class TestEq(unittest.TestCase): + def test_no_eq(self): + # Test a class with no __eq__ and eq=False. + @dataclass(eq=False) + class C: + x: int + self.assertNotEqual(C(0), C(0)) + c = C(3) + self.assertEqual(c, c) + + # Test a class with an __eq__ and eq=False. + @dataclass(eq=False) + class C: + x: int + def __eq__(self, other): + return other == 10 + self.assertEqual(C(3), 10) + + def test_overwriting_eq(self): + # If the class has __eq__, use it no matter the value of + # eq=. + + @dataclass + class C: + x: int + def __eq__(self, other): + return other == 3 + self.assertEqual(C(1), 3) + self.assertNotEqual(C(1), 1) + + @dataclass(eq=True) + class C: + x: int + def __eq__(self, other): + return other == 4 + self.assertEqual(C(1), 4) + self.assertNotEqual(C(1), 1) + + @dataclass(eq=False) + class C: + x: int + def __eq__(self, other): + return other == 5 + self.assertEqual(C(1), 5) + self.assertNotEqual(C(1), 1) + + +class TestOrdering(unittest.TestCase): + def test_functools_total_ordering(self): + # Test that functools.total_ordering works with this class. + @total_ordering + @dataclass + class C: + x: int + def __lt__(self, other): + # Perform the test "backward", just to make + # sure this is being called. + return self.x >= other + + self.assertLess(C(0), -1) + self.assertLessEqual(C(0), -1) + self.assertGreater(C(0), 1) + self.assertGreaterEqual(C(0), 1) + + def test_no_order(self): + # Test that no ordering functions are added by default. + @dataclass(order=False) + class C: + x: int + # Make sure no order methods are added. + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__lt__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + + # Test that __lt__ is still called + @dataclass(order=False) + class C: + x: int + def __lt__(self, other): + return False + # Make sure other methods aren't added. + self.assertNotIn('__le__', C.__dict__) + self.assertNotIn('__ge__', C.__dict__) + self.assertNotIn('__gt__', C.__dict__) + + def test_overwriting_order(self): + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __lt__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __lt__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __le__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __le__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __gt__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __gt__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __ge__' + '.*using functools.total_ordering'): + @dataclass(order=True) + class C: + x: int + def __ge__(self): + pass + +class TestHash(unittest.TestCase): + def test_unsafe_hash(self): + @dataclass(unsafe_hash=True) + class C: + x: int + y: str + self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) + + def test_hash_rules(self): + def non_bool(value): + # Map to something else that's True, but not a bool. + if value is None: + return None + if value: + return (3,) + return 0 + + def test(case, unsafe_hash, eq, frozen, with_hash, result): + with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq, + frozen=frozen): + if result != 'exception': + if with_hash: + @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) + class C: + def __hash__(self): + return 0 + else: + @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) + class C: + pass + + # See if the result matches what's expected. + if result == 'fn': + # __hash__ contains the function we generated. + self.assertIn('__hash__', C.__dict__) + self.assertIsNotNone(C.__dict__['__hash__']) + + elif result == '': + # __hash__ is not present in our class. + if not with_hash: + self.assertNotIn('__hash__', C.__dict__) + + elif result == 'none': + # __hash__ is set to None. + self.assertIn('__hash__', C.__dict__) + self.assertIsNone(C.__dict__['__hash__']) + + elif result == 'exception': + # Creating the class should cause an exception. + # This only happens with with_hash==True. + assert(with_hash) + with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'): + @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) + class C: + def __hash__(self): + return 0 + + else: + assert False, f'unknown result {result!r}' + + # There are 8 cases of: + # unsafe_hash=True/False + # eq=True/False + # frozen=True/False + # And for each of these, a different result if + # __hash__ is defined or not. + for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([ + (False, False, False, '', ''), + (False, False, True, '', ''), + (False, True, False, 'none', ''), + (False, True, True, 'fn', ''), + (True, False, False, 'fn', 'exception'), + (True, False, True, 'fn', 'exception'), + (True, True, False, 'fn', 'exception'), + (True, True, True, 'fn', 'exception'), + ], 1): + test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash) + test(case, unsafe_hash, eq, frozen, True, res_defined_hash) + + # Test non-bool truth values, too. This is just to + # make sure the data-driven table in the decorator + # handles non-bool values. + test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash) + test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash) + + + def test_eq_only(self): + # If a class defines __eq__, __hash__ is automatically added + # and set to None. This is normal Python behavior, not + # related to dataclasses. Make sure we don't interfere with + # that (see bpo=32546). + + @dataclass + class C: + i: int + def __eq__(self, other): + return self.i == other.i + self.assertEqual(C(1), C(1)) + self.assertNotEqual(C(1), C(4)) + + # And make sure things work in this case if we specify + # unsafe_hash=True. + @dataclass(unsafe_hash=True) + class C: + i: int + def __eq__(self, other): + return self.i == other.i + self.assertEqual(C(1), C(1.0)) + self.assertEqual(hash(C(1)), hash(C(1.0))) + + # And check that the classes __eq__ is being used, despite + # specifying eq=True. + @dataclass(unsafe_hash=True, eq=True) + class C: + i: int + def __eq__(self, other): + return self.i == 3 and self.i == other.i + self.assertEqual(C(3), C(3)) + self.assertNotEqual(C(1), C(1)) + self.assertEqual(hash(C(1)), hash(C(1.0))) + + def test_0_field_hash(self): + @dataclass(frozen=True) + class C: + pass + self.assertEqual(hash(C()), hash(())) + + @dataclass(unsafe_hash=True) + class C: + pass + self.assertEqual(hash(C()), hash(())) + + def test_1_field_hash(self): + @dataclass(frozen=True) + class C: + x: int + self.assertEqual(hash(C(4)), hash((4,))) + self.assertEqual(hash(C(42)), hash((42,))) + + @dataclass(unsafe_hash=True) + class C: + x: int + self.assertEqual(hash(C(4)), hash((4,))) + self.assertEqual(hash(C(42)), hash((42,))) + + def test_hash_no_args(self): + # Test dataclasses with no hash= argument. This exists to + # make sure that if the @dataclass parameter name is changed + # or the non-default hashing behavior changes, the default + # hashability keeps working the same way. + + class Base: + def __hash__(self): + return 301 + + # If frozen or eq is None, then use the default value (do not + # specify any value in the decorator). + for frozen, eq, base, expected in [ + (None, None, object, 'unhashable'), + (None, None, Base, 'unhashable'), + (None, False, object, 'object'), + (None, False, Base, 'base'), + (None, True, object, 'unhashable'), + (None, True, Base, 'unhashable'), + (False, None, object, 'unhashable'), + (False, None, Base, 'unhashable'), + (False, False, object, 'object'), + (False, False, Base, 'base'), + (False, True, object, 'unhashable'), + (False, True, Base, 'unhashable'), + (True, None, object, 'tuple'), + (True, None, Base, 'tuple'), + (True, False, object, 'object'), + (True, False, Base, 'base'), + (True, True, object, 'tuple'), + (True, True, Base, 'tuple'), + ]: + + with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected): + # First, create the class. + if frozen is None and eq is None: + @dataclass + class C(base): + i: int + elif frozen is None: + @dataclass(eq=eq) + class C(base): + i: int + elif eq is None: + @dataclass(frozen=frozen) + class C(base): + i: int + else: + @dataclass(frozen=frozen, eq=eq) + class C(base): + i: int + + # Now, make sure it hashes as expected. + if expected == 'unhashable': + c = C(10) + with self.assertRaisesRegex(TypeError, 'unhashable type'): + hash(c) + + elif expected == 'base': + self.assertEqual(hash(C(10)), 301) + + elif expected == 'object': + # I'm not sure what test to use here. object's + # hash isn't based on id(), so calling hash() + # won't tell us much. So, just check the + # function used is object's. + self.assertIs(C.__hash__, object.__hash__) + + elif expected == 'tuple': + self.assertEqual(hash(C(42)), hash((42,))) + + else: + assert False, f'unknown value for expected={expected!r}' + + +class TestFrozen(unittest.TestCase): + def test_frozen(self): + @dataclass(frozen=True) + class C: + i: int + + c = C(10) + self.assertEqual(c.i, 10) + with self.assertRaises(FrozenInstanceError): + c.i = 5 + self.assertEqual(c.i, 10) + + def test_inherit(self): + @dataclass(frozen=True) + class C: + i: int + + @dataclass(frozen=True) + class D(C): + j: int + + d = D(0, 10) + with self.assertRaises(FrozenInstanceError): + d.i = 5 + with self.assertRaises(FrozenInstanceError): + d.j = 6 + self.assertEqual(d.i, 0) + self.assertEqual(d.j, 10) + + def test_inherit_nonfrozen_from_empty_frozen(self): + @dataclass(frozen=True) + class C: + pass + + with self.assertRaisesRegex(TypeError, + 'cannot inherit non-frozen dataclass from a frozen one'): + @dataclass + class D(C): + j: int + + def test_inherit_nonfrozen_from_empty(self): + @dataclass + class C: + pass + + @dataclass + class D(C): + j: int + + d = D(3) + self.assertEqual(d.j, 3) + self.assertIsInstance(d, C) + + # Test both ways: with an intermediate normal (non-dataclass) + # class and without an intermediate class. + def test_inherit_nonfrozen_from_frozen(self): + for intermediate_class in [True, False]: + with self.subTest(intermediate_class=intermediate_class): + @dataclass(frozen=True) + class C: + i: int + + if intermediate_class: + class I(C): pass + else: + I = C + + with self.assertRaisesRegex(TypeError, + 'cannot inherit non-frozen dataclass from a frozen one'): + @dataclass + class D(I): + pass + + def test_inherit_frozen_from_nonfrozen(self): + for intermediate_class in [True, False]: + with self.subTest(intermediate_class=intermediate_class): + @dataclass + class C: + i: int + + if intermediate_class: + class I(C): pass + else: + I = C + + with self.assertRaisesRegex(TypeError, + 'cannot inherit frozen dataclass from a non-frozen one'): + @dataclass(frozen=True) + class D(I): + pass + + def test_inherit_from_normal_class(self): + for intermediate_class in [True, False]: + with self.subTest(intermediate_class=intermediate_class): + class C: + pass + + if intermediate_class: + class I(C): pass + else: + I = C + + @dataclass(frozen=True) + class D(I): + i: int + + d = D(10) + with self.assertRaises(FrozenInstanceError): + d.i = 5 + + def test_non_frozen_normal_derived(self): + # See bpo-32953. + + @dataclass(frozen=True) + class D: + x: int + y: int = 10 + + class S(D): + pass + + s = S(3) + self.assertEqual(s.x, 3) + self.assertEqual(s.y, 10) + s.cached = True + + # But can't change the frozen attributes. + with self.assertRaises(FrozenInstanceError): + s.x = 5 + with self.assertRaises(FrozenInstanceError): + s.y = 5 + self.assertEqual(s.x, 3) + self.assertEqual(s.y, 10) + self.assertEqual(s.cached, True) + + def test_overwriting_frozen(self): + # frozen uses __setattr__ and __delattr__. + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __setattr__'): + @dataclass(frozen=True) + class C: + x: int + def __setattr__(self): + pass + + with self.assertRaisesRegex(TypeError, + 'Cannot overwrite attribute __delattr__'): + @dataclass(frozen=True) + class C: + x: int + def __delattr__(self): + pass + + @dataclass(frozen=False) + class C: + x: int + def __setattr__(self, name, value): + self.__dict__['x'] = value * 2 + self.assertEqual(C(10).x, 20) + + def test_frozen_hash(self): + @dataclass(frozen=True) + class C: + x: Any + + # If x is immutable, we can compute the hash. No exception is + # raised. + hash(C(3)) + + # If x is mutable, computing the hash is an error. + with self.assertRaisesRegex(TypeError, 'unhashable type'): + hash(C({})) + + +class TestSlots(unittest.TestCase): + def test_simple(self): + @dataclass + class C: + __slots__ = ('x',) + x: Any + + # There was a bug where a variable in a slot was assumed to + # also have a default value (of type + # types.MemberDescriptorType). + with self.assertRaisesRegex(TypeError, + r"__init__\(\) missing 1 required positional argument: 'x'"): + C() + + # We can create an instance, and assign to x. + c = C(10) + self.assertEqual(c.x, 10) + c.x = 5 + self.assertEqual(c.x, 5) + + # We can't assign to anything else. + with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"): + c.y = 5 + + def test_derived_added_field(self): + # See bpo-33100. + @dataclass + class Base: + __slots__ = ('x',) + x: Any + + @dataclass + class Derived(Base): + x: int + y: int + + d = Derived(1, 2) + self.assertEqual((d.x, d.y), (1, 2)) + + # We can add a new field to the derived instance. + d.z = 10 + + def test_generated_slots(self): + @dataclass(slots=True) + class C: + x: int + y: int + + c = C(1, 2) + self.assertEqual((c.x, c.y), (1, 2)) + + c.x = 3 + c.y = 4 + self.assertEqual((c.x, c.y), (3, 4)) + + with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"): + c.z = 5 + + def test_add_slots_when_slots_exists(self): + with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'): + @dataclass(slots=True) + class C: + __slots__ = ('x',) + x: int + + def test_generated_slots_value(self): + + class Root: + __slots__ = {'x'} + + class Root2(Root): + __slots__ = {'k': '...', 'j': ''} + + class Root3(Root2): + __slots__ = ['h'] + + class Root4(Root3): + __slots__ = 'aa' + + @dataclass(slots=True) + class Base(Root4): + y: int + j: str + h: str + + self.assertEqual(Base.__slots__, ('y', )) + + @dataclass(slots=True) + class Derived(Base): + aa: float + x: str + z: int + k: str + h: str + + self.assertEqual(Derived.__slots__, ('z', )) + + @dataclass + class AnotherDerived(Base): + z: int + + self.assertNotIn('__slots__', AnotherDerived.__dict__) + + def test_cant_inherit_from_iterator_slots(self): + + class Root: + __slots__ = iter(['a']) + + class Root2(Root): + __slots__ = ('b', ) + + with self.assertRaisesRegex( + TypeError, + "^Slots of 'Root' cannot be determined" + ): + @dataclass(slots=True) + class C(Root2): + x: int + + def test_returns_new_class(self): + class A: + x: int + + B = dataclass(A, slots=True) + self.assertIsNot(A, B) + + self.assertFalse(hasattr(A, "__slots__")) + self.assertTrue(hasattr(B, "__slots__")) + + # Can't be local to test_frozen_pickle. + @dataclass(frozen=True, slots=True) + class FrozenSlotsClass: + foo: str + bar: int + + @dataclass(frozen=True) + class FrozenWithoutSlotsClass: + foo: str + bar: int + + def test_frozen_pickle(self): + # bpo-43999 + + self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar")) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + obj = self.FrozenSlotsClass("a", 1) + p = pickle.loads(pickle.dumps(obj, protocol=proto)) + self.assertIsNot(obj, p) + self.assertEqual(obj, p) + + obj = self.FrozenWithoutSlotsClass("a", 1) + p = pickle.loads(pickle.dumps(obj, protocol=proto)) + self.assertIsNot(obj, p) + self.assertEqual(obj, p) + + def test_slots_with_default_no_init(self): + # Originally reported in bpo-44649. + @dataclass(slots=True) + class A: + a: str + b: str = field(default='b', init=False) + + obj = A("a") + self.assertEqual(obj.a, 'a') + self.assertEqual(obj.b, 'b') + + def test_slots_with_default_factory_no_init(self): + # Originally reported in bpo-44649. + @dataclass(slots=True) + class A: + a: str + b: str = field(default_factory=lambda:'b', init=False) + + obj = A("a") + self.assertEqual(obj.a, 'a') + self.assertEqual(obj.b, 'b') + + def test_slots_no_weakref(self): + @dataclass(slots=True) + class A: + # No weakref. + pass + + self.assertNotIn("__weakref__", A.__slots__) + a = A() + with self.assertRaisesRegex(TypeError, + "cannot create weak reference"): + weakref.ref(a) + + def test_slots_weakref(self): + @dataclass(slots=True, weakref_slot=True) + class A: + a: int + + self.assertIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_slots_weakref_base_str(self): + class Base: + __slots__ = '__weakref__' + + @dataclass(slots=True) + class A(Base): + a: int + + # __weakref__ is in the base class, not A. But an A is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_slots_weakref_base_tuple(self): + # Same as test_slots_weakref_base, but use a tuple instead of a string + # in the base class. + class Base: + __slots__ = ('__weakref__',) + + @dataclass(slots=True) + class A(Base): + a: int + + # __weakref__ is in the base class, not A. But an A is still + # weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_without_slot(self): + with self.assertRaisesRegex(TypeError, + "weakref_slot is True but slots is False"): + @dataclass(weakref_slot=True) + class A: + a: int + + def test_weakref_slot_make_dataclass(self): + A = make_dataclass('A', [('a', int),], slots=True, weakref_slot=True) + self.assertIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + # And make sure if raises if slots=True is not given. + with self.assertRaisesRegex(TypeError, + "weakref_slot is True but slots is False"): + B = make_dataclass('B', [('a', int),], weakref_slot=True) + + def test_weakref_slot_subclass_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + # A *can* also specify weakref_slot=True if it wants to (gh-93521) + @dataclass(slots=True, weakref_slot=True) + class A(Base): + ... + + # __weakref__ is in the base class, not A. But an instance of A + # is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_subclass_no_weakref_slot(self): + @dataclass(slots=True, weakref_slot=True) + class Base: + field: int + + @dataclass(slots=True) + class A(Base): + ... + + # __weakref__ is in the base class, not A. Even though A doesn't + # specify weakref_slot, it should still be weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + def test_weakref_slot_normal_base_weakref_slot(self): + class Base: + __slots__ = ('__weakref__',) + + @dataclass(slots=True, weakref_slot=True) + class A(Base): + field: int + + # __weakref__ is in the base class, not A. But an instance of + # A is still weakref-able. + self.assertIn("__weakref__", Base.__slots__) + self.assertNotIn("__weakref__", A.__slots__) + a = A(1) + weakref.ref(a) + + +class TestDescriptors(unittest.TestCase): + def test_set_name(self): + # See bpo-33141. + + # Create a descriptor. + class D: + def __set_name__(self, owner, name): + self.name = name + 'x' + def __get__(self, instance, owner): + if instance is not None: + return 1 + return self + + # This is the case of just normal descriptor behavior, no + # dataclass code is involved in initializing the descriptor. + @dataclass + class C: + c: int=D() + self.assertEqual(C.c.name, 'cx') + + # Now test with a default value and init=False, which is the + # only time this is really meaningful. If not using + # init=False, then the descriptor will be overwritten, anyway. + @dataclass + class C: + c: int=field(default=D(), init=False) + self.assertEqual(C.c.name, 'cx') + self.assertEqual(C().c, 1) + + def test_non_descriptor(self): + # PEP 487 says __set_name__ should work on non-descriptors. + # Create a descriptor. + + class D: + def __set_name__(self, owner, name): + self.name = name + 'x' + + @dataclass + class C: + c: int=field(default=D(), init=False) + self.assertEqual(C.c.name, 'cx') + + def test_lookup_on_instance(self): + # See bpo-33175. + class D: + pass + + d = D() + # Create an attribute on the instance, not type. + d.__set_name__ = Mock() + + # Make sure d.__set_name__ is not called. + @dataclass + class C: + i: int=field(default=d, init=False) + + self.assertEqual(d.__set_name__.call_count, 0) + + def test_lookup_on_class(self): + # See bpo-33175. + class D: + pass + D.__set_name__ = Mock() + + # Make sure D.__set_name__ is called. + @dataclass + class C: + i: int=field(default=D(), init=False) + + self.assertEqual(D.__set_name__.call_count, 1) + + def test_init_calls_set(self): + class D: + pass + + D.__set__ = Mock() + + @dataclass + class C: + i: D = D() + + # Make sure D.__set__ is called. + D.__set__.reset_mock() + c = C(5) + self.assertEqual(D.__set__.call_count, 1) + + def test_getting_field_calls_get(self): + class D: + pass + + D.__set__ = Mock() + D.__get__ = Mock() + + @dataclass + class C: + i: D = D() + + c = C(5) + + # Make sure D.__get__ is called. + D.__get__.reset_mock() + value = c.i + self.assertEqual(D.__get__.call_count, 1) + + def test_setting_field_calls_set(self): + class D: + pass + + D.__set__ = Mock() + + @dataclass + class C: + i: D = D() + + c = C(5) + + # Make sure D.__set__ is called. + D.__set__.reset_mock() + c.i = 10 + self.assertEqual(D.__set__.call_count, 1) + + def test_setting_uninitialized_descriptor_field(self): + class D: + pass + + D.__set__ = Mock() + + @dataclass + class C: + i: D + + # D.__set__ is not called because there's no D instance to call it on + D.__set__.reset_mock() + c = C(5) + self.assertEqual(D.__set__.call_count, 0) + + # D.__set__ still isn't called after setting i to an instance of D + # because descriptors don't behave like that when stored as instance vars + c.i = D() + c.i = 5 + self.assertEqual(D.__set__.call_count, 0) + + def test_default_value(self): + class D: + def __get__(self, instance: Any, owner: object) -> int: + if instance is None: + return 100 + + return instance._x + + def __set__(self, instance: Any, value: int) -> None: + instance._x = value + + @dataclass + class C: + i: D = D() + + c = C() + self.assertEqual(c.i, 100) + + c = C(5) + self.assertEqual(c.i, 5) + + def test_no_default_value(self): + class D: + def __get__(self, instance: Any, owner: object) -> int: + if instance is None: + raise AttributeError() + + return instance._x + + def __set__(self, instance: Any, value: int) -> None: + instance._x = value + + @dataclass + class C: + i: D = D() + + with self.assertRaisesRegex(TypeError, 'missing 1 required positional argument'): + c = C() + +class TestStringAnnotations(unittest.TestCase): + def test_classvar(self): + # Some expressions recognized as ClassVar really aren't. But + # if you're using string annotations, it's not an exact + # science. + # These tests assume that both "import typing" and "from + # typing import *" have been run in this file. + for typestr in ('ClassVar[int]', + 'ClassVar [int]', + ' ClassVar [int]', + 'ClassVar', + ' ClassVar ', + 'typing.ClassVar[int]', + 'typing.ClassVar[str]', + ' typing.ClassVar[str]', + 'typing .ClassVar[str]', + 'typing. ClassVar[str]', + 'typing.ClassVar [str]', + 'typing.ClassVar [ str]', + + # Not syntactically valid, but these will + # be treated as ClassVars. + 'typing.ClassVar.[int]', + 'typing.ClassVar+', + ): + with self.subTest(typestr=typestr): + @dataclass + class C: + x: typestr + + # x is a ClassVar, so C() takes no args. + C() + + # And it won't appear in the class's dict because it doesn't + # have a default. + self.assertNotIn('x', C.__dict__) + + def test_isnt_classvar(self): + for typestr in ('CV', + 't.ClassVar', + 't.ClassVar[int]', + 'typing..ClassVar[int]', + 'Classvar', + 'Classvar[int]', + 'typing.ClassVarx[int]', + 'typong.ClassVar[int]', + 'dataclasses.ClassVar[int]', + 'typingxClassVar[str]', + ): + with self.subTest(typestr=typestr): + @dataclass + class C: + x: typestr + + # x is not a ClassVar, so C() takes one arg. + self.assertEqual(C(10).x, 10) + + def test_initvar(self): + # These tests assume that both "import dataclasses" and "from + # dataclasses import *" have been run in this file. + for typestr in ('InitVar[int]', + 'InitVar [int]' + ' InitVar [int]', + 'InitVar', + ' InitVar ', + 'dataclasses.InitVar[int]', + 'dataclasses.InitVar[str]', + ' dataclasses.InitVar[str]', + 'dataclasses .InitVar[str]', + 'dataclasses. InitVar[str]', + 'dataclasses.InitVar [str]', + 'dataclasses.InitVar [ str]', + + # Not syntactically valid, but these will + # be treated as InitVars. + 'dataclasses.InitVar.[int]', + 'dataclasses.InitVar+', + ): + with self.subTest(typestr=typestr): + @dataclass + class C: + x: typestr + + # x is an InitVar, so doesn't create a member. + with self.assertRaisesRegex(AttributeError, + "object has no attribute 'x'"): + C(1).x + + def test_isnt_initvar(self): + for typestr in ('IV', + 'dc.InitVar', + 'xdataclasses.xInitVar', + 'typing.xInitVar[int]', + ): + with self.subTest(typestr=typestr): + @dataclass + class C: + x: typestr + + # x is not an InitVar, so there will be a member x. + self.assertEqual(C(10).x, 10) + + def test_classvar_module_level_import(self): + from test import dataclass_module_1 + from test import dataclass_module_1_str + from test import dataclass_module_2 + from test import dataclass_module_2_str + + for m in (dataclass_module_1, dataclass_module_1_str, + dataclass_module_2, dataclass_module_2_str, + ): + with self.subTest(m=m): + # There's a difference in how the ClassVars are + # interpreted when using string annotations or + # not. See the imported modules for details. + if m.USING_STRINGS: + c = m.CV(10) + else: + c = m.CV() + self.assertEqual(c.cv0, 20) + + + # There's a difference in how the InitVars are + # interpreted when using string annotations or + # not. See the imported modules for details. + c = m.IV(0, 1, 2, 3, 4) + + for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): + with self.subTest(field_name=field_name): + with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): + # Since field_name is an InitVar, it's + # not an instance field. + getattr(c, field_name) + + if m.USING_STRINGS: + # iv4 is interpreted as a normal field. + self.assertIn('not_iv4', c.__dict__) + self.assertEqual(c.not_iv4, 4) + else: + # iv4 is interpreted as an InitVar, so it + # won't exist on the instance. + self.assertNotIn('not_iv4', c.__dict__) + + def test_text_annotations(self): + from test import dataclass_textanno + + self.assertEqual( + get_type_hints(dataclass_textanno.Bar), + {'foo': dataclass_textanno.Foo}) + self.assertEqual( + get_type_hints(dataclass_textanno.Bar.__init__), + {'foo': dataclass_textanno.Foo, + 'return': type(None)}) + + +class TestMakeDataclass(unittest.TestCase): + def test_simple(self): + C = make_dataclass('C', + [('x', int), + ('y', int, field(default=5))], + namespace={'add_one': lambda self: self.x + 1}) + c = C(10) + self.assertEqual((c.x, c.y), (10, 5)) + self.assertEqual(c.add_one(), 11) + + + def test_no_mutate_namespace(self): + # Make sure a provided namespace isn't mutated. + ns = {} + C = make_dataclass('C', + [('x', int), + ('y', int, field(default=5))], + namespace=ns) + self.assertEqual(ns, {}) + + def test_base(self): + class Base1: + pass + class Base2: + pass + C = make_dataclass('C', + [('x', int)], + bases=(Base1, Base2)) + c = C(2) + self.assertIsInstance(c, C) + self.assertIsInstance(c, Base1) + self.assertIsInstance(c, Base2) + + def test_base_dataclass(self): + @dataclass + class Base1: + x: int + class Base2: + pass + C = make_dataclass('C', + [('y', int)], + bases=(Base1, Base2)) + with self.assertRaisesRegex(TypeError, 'required positional'): + c = C(2) + c = C(1, 2) + self.assertIsInstance(c, C) + self.assertIsInstance(c, Base1) + self.assertIsInstance(c, Base2) + + self.assertEqual((c.x, c.y), (1, 2)) + + def test_init_var(self): + def post_init(self, y): + self.x *= y + + C = make_dataclass('C', + [('x', int), + ('y', InitVar[int]), + ], + namespace={'__post_init__': post_init}, + ) + c = C(2, 3) + self.assertEqual(vars(c), {'x': 6}) + self.assertEqual(len(fields(c)), 1) + + def test_class_var(self): + C = make_dataclass('C', + [('x', int), + ('y', ClassVar[int], 10), + ('z', ClassVar[int], field(default=20)), + ]) + c = C(1) + self.assertEqual(vars(c), {'x': 1}) + self.assertEqual(len(fields(c)), 1) + self.assertEqual(C.y, 10) + self.assertEqual(C.z, 20) + + def test_other_params(self): + C = make_dataclass('C', + [('x', int), + ('y', ClassVar[int], 10), + ('z', ClassVar[int], field(default=20)), + ], + init=False) + # Make sure we have a repr, but no init. + self.assertNotIn('__init__', vars(C)) + self.assertIn('__repr__', vars(C)) + + # Make sure random other params don't work. + with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): + C = make_dataclass('C', + [], + xxinit=False) + + def test_no_types(self): + C = make_dataclass('Point', ['x', 'y', 'z']) + c = C(1, 2, 3) + self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) + self.assertEqual(C.__annotations__, {'x': 'typing.Any', + 'y': 'typing.Any', + 'z': 'typing.Any'}) + + C = make_dataclass('Point', ['x', ('y', int), 'z']) + c = C(1, 2, 3) + self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) + self.assertEqual(C.__annotations__, {'x': 'typing.Any', + 'y': int, + 'z': 'typing.Any'}) + + def test_invalid_type_specification(self): + for bad_field in [(), + (1, 2, 3, 4), + ]: + with self.subTest(bad_field=bad_field): + with self.assertRaisesRegex(TypeError, r'Invalid field: '): + make_dataclass('C', ['a', bad_field]) + + # And test for things with no len(). + for bad_field in [float, + lambda x:x, + ]: + with self.subTest(bad_field=bad_field): + with self.assertRaisesRegex(TypeError, r'has no len\(\)'): + make_dataclass('C', ['a', bad_field]) + + def test_duplicate_field_names(self): + for field in ['a', 'ab']: + with self.subTest(field=field): + with self.assertRaisesRegex(TypeError, 'Field name duplicated'): + make_dataclass('C', [field, 'a', field]) + + def test_keyword_field_names(self): + for field in ['for', 'async', 'await', 'as']: + with self.subTest(field=field): + with self.assertRaisesRegex(TypeError, 'must not be keywords'): + make_dataclass('C', ['a', field]) + with self.assertRaisesRegex(TypeError, 'must not be keywords'): + make_dataclass('C', [field]) + with self.assertRaisesRegex(TypeError, 'must not be keywords'): + make_dataclass('C', [field, 'a']) + + def test_non_identifier_field_names(self): + for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: + with self.subTest(field=field): + with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): + make_dataclass('C', ['a', field]) + with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): + make_dataclass('C', [field]) + with self.assertRaisesRegex(TypeError, 'must be valid identifiers'): + make_dataclass('C', [field, 'a']) + + def test_underscore_field_names(self): + # Unlike namedtuple, it's okay if dataclass field names have + # an underscore. + make_dataclass('C', ['_', '_a', 'a_a', 'a_']) + + def test_funny_class_names_names(self): + # No reason to prevent weird class names, since + # types.new_class allows them. + for classname in ['()', 'x,y', '*', '2@3', '']: + with self.subTest(classname=classname): + C = make_dataclass(classname, ['a', 'b']) + self.assertEqual(C.__name__, classname) + +class TestReplace(unittest.TestCase): + def test(self): + @dataclass(frozen=True) + class C: + x: int + y: int + + c = C(1, 2) + c1 = replace(c, x=3) + self.assertEqual(c1.x, 3) + self.assertEqual(c1.y, 2) + + def test_frozen(self): + @dataclass(frozen=True) + class C: + x: int + y: int + z: int = field(init=False, default=10) + t: int = field(init=False, default=100) + + c = C(1, 2) + c1 = replace(c, x=3) + self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) + self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) + + + with self.assertRaisesRegex(ValueError, 'init=False'): + replace(c, x=3, z=20, t=50) + with self.assertRaisesRegex(ValueError, 'init=False'): + replace(c, z=20) + replace(c, x=3, z=20, t=50) + + # Make sure the result is still frozen. + with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): + c1.x = 3 + + # Make sure we can't replace an attribute that doesn't exist, + # if we're also replacing one that does exist. Test this + # here, because setting attributes on frozen instances is + # handled slightly differently from non-frozen ones. + with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " + "keyword argument 'a'"): + c1 = replace(c, x=20, a=5) + + def test_invalid_field_name(self): + @dataclass(frozen=True) + class C: + x: int + y: int + + c = C(1, 2) + with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " + "keyword argument 'z'"): + c1 = replace(c, z=3) + + def test_invalid_object(self): + @dataclass(frozen=True) + class C: + x: int + y: int + + with self.assertRaisesRegex(TypeError, 'dataclass instance'): + replace(C, x=3) + + with self.assertRaisesRegex(TypeError, 'dataclass instance'): + replace(0, x=3) + + def test_no_init(self): + @dataclass + class C: + x: int + y: int = field(init=False, default=10) + + c = C(1) + c.y = 20 + + # Make sure y gets the default value. + c1 = replace(c, x=5) + self.assertEqual((c1.x, c1.y), (5, 10)) + + # Trying to replace y is an error. + with self.assertRaisesRegex(ValueError, 'init=False'): + replace(c, x=2, y=30) + + with self.assertRaisesRegex(ValueError, 'init=False'): + replace(c, y=30) + + def test_classvar(self): + @dataclass + class C: + x: int + y: ClassVar[int] = 1000 + + c = C(1) + d = C(2) + + self.assertIs(c.y, d.y) + self.assertEqual(c.y, 1000) + + # Trying to replace y is an error: can't replace ClassVars. + with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " + "unexpected keyword argument 'y'"): + replace(c, y=30) + + replace(c, x=5) + + def test_initvar_is_specified(self): + @dataclass + class C: + x: int + y: InitVar[int] + + def __post_init__(self, y): + self.x *= y + + c = C(1, 10) + self.assertEqual(c.x, 10) + with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be " + "specified with replace()"): + replace(c, x=3) + c = replace(c, x=3, y=5) + self.assertEqual(c.x, 15) + + def test_initvar_with_default_value(self): + @dataclass + class C: + x: int + y: InitVar[int] = None + z: InitVar[int] = 42 + + def __post_init__(self, y, z): + if y is not None: + self.x += y + if z is not None: + self.x += z + + c = C(x=1, y=10, z=1) + self.assertEqual(replace(c), C(x=12)) + self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42)) + self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1)) + + def test_recursive_repr(self): + @dataclass + class C: + f: "C" + + c = C(None) + c.f = c + self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") + + def test_recursive_repr_two_attrs(self): + @dataclass + class C: + f: "C" + g: "C" + + c = C(None, None) + c.f = c + c.g = c + self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" + ".<locals>.C(f=..., g=...)") + + def test_recursive_repr_indirection(self): + @dataclass + class C: + f: "D" + + @dataclass + class D: + f: "C" + + c = C(None) + d = D(None) + c.f = d + d.f = c + self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" + ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" + ".<locals>.D(f=...))") + + def test_recursive_repr_indirection_two(self): + @dataclass + class C: + f: "D" + + @dataclass + class D: + f: "E" + + @dataclass + class E: + f: "C" + + c = C(None) + d = D(None) + e = E(None) + c.f = d + d.f = e + e.f = c + self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" + ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" + ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" + ".<locals>.E(f=...)))") + + def test_recursive_repr_misc_attrs(self): + @dataclass + class C: + f: "C" + g: int + + c = C(None, 1) + c.f = c + self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" + ".<locals>.C(f=..., g=1)") + + ## def test_initvar(self): + ## @dataclass + ## class C: + ## x: int + ## y: InitVar[int] + + ## c = C(1, 10) + ## d = C(2, 20) + + ## # In our case, replacing an InitVar is a no-op + ## self.assertEqual(c, replace(c, y=5)) + + ## replace(c, x=5) + +class TestAbstract(unittest.TestCase): + def test_abc_implementation(self): + class Ordered(abc.ABC): + @abc.abstractmethod + def __lt__(self, other): + pass + + @abc.abstractmethod + def __le__(self, other): + pass + + @dataclass(order=True) + class Date(Ordered): + year: int + month: 'Month' + day: 'int' + + self.assertFalse(inspect.isabstract(Date)) + self.assertGreater(Date(2020,12,25), Date(2020,8,31)) + + def test_maintain_abc(self): + class A(abc.ABC): + @abc.abstractmethod + def foo(self): + pass + + @dataclass + class Date(A): + year: int + month: 'Month' + day: 'int' + + self.assertTrue(inspect.isabstract(Date)) + msg = 'class Date without an implementation for abstract method foo' + self.assertRaisesRegex(TypeError, msg, Date) + + +class TestMatchArgs(unittest.TestCase): + def test_match_args(self): + @dataclass + class C: + a: int + self.assertEqual(C(42).__match_args__, ('a',)) + + def test_explicit_match_args(self): + ma = () + @dataclass + class C: + a: int + __match_args__ = ma + self.assertIs(C(42).__match_args__, ma) + + def test_bpo_43764(self): + @dataclass(repr=False, eq=False, init=False) + class X: + a: int + b: int + c: int + self.assertEqual(X.__match_args__, ("a", "b", "c")) + + def test_match_args_argument(self): + @dataclass(match_args=False) + class X: + a: int + self.assertNotIn('__match_args__', X.__dict__) + + @dataclass(match_args=False) + class Y: + a: int + __match_args__ = ('b',) + self.assertEqual(Y.__match_args__, ('b',)) + + @dataclass(match_args=False) + class Z(Y): + z: int + self.assertEqual(Z.__match_args__, ('b',)) + + # Ensure parent dataclass __match_args__ is seen, if child class + # specifies match_args=False. + @dataclass + class A: + a: int + z: int + @dataclass(match_args=False) + class B(A): + b: int + self.assertEqual(B.__match_args__, ('a', 'z')) + + def test_make_dataclasses(self): + C = make_dataclass('C', [('x', int), ('y', int)]) + self.assertEqual(C.__match_args__, ('x', 'y')) + + C = make_dataclass('C', [('x', int), ('y', int)], match_args=True) + self.assertEqual(C.__match_args__, ('x', 'y')) + + C = make_dataclass('C', [('x', int), ('y', int)], match_args=False) + self.assertNotIn('__match__args__', C.__dict__) + + C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)}) + self.assertEqual(C.__match_args__, ('z',)) + + +class TestKeywordArgs(unittest.TestCase): + def test_no_classvar_kwarg(self): + msg = 'field a is a ClassVar but specifies kw_only' + with self.assertRaisesRegex(TypeError, msg): + @dataclass + class A: + a: ClassVar[int] = field(kw_only=True) + + with self.assertRaisesRegex(TypeError, msg): + @dataclass + class A: + a: ClassVar[int] = field(kw_only=False) + + with self.assertRaisesRegex(TypeError, msg): + @dataclass(kw_only=True) + class A: + a: ClassVar[int] = field(kw_only=False) + + def test_field_marked_as_kwonly(self): + ####################### + # Using dataclass(kw_only=True) + @dataclass(kw_only=True) + class A: + a: int + self.assertTrue(fields(A)[0].kw_only) + + @dataclass(kw_only=True) + class A: + a: int = field(kw_only=True) + self.assertTrue(fields(A)[0].kw_only) + + @dataclass(kw_only=True) + class A: + a: int = field(kw_only=False) + self.assertFalse(fields(A)[0].kw_only) + + ####################### + # Using dataclass(kw_only=False) + @dataclass(kw_only=False) + class A: + a: int + self.assertFalse(fields(A)[0].kw_only) + + @dataclass(kw_only=False) + class A: + a: int = field(kw_only=True) + self.assertTrue(fields(A)[0].kw_only) + + @dataclass(kw_only=False) + class A: + a: int = field(kw_only=False) + self.assertFalse(fields(A)[0].kw_only) + + ####################### + # Not specifying dataclass(kw_only) + @dataclass + class A: + a: int + self.assertFalse(fields(A)[0].kw_only) + + @dataclass + class A: + a: int = field(kw_only=True) + self.assertTrue(fields(A)[0].kw_only) + + @dataclass + class A: + a: int = field(kw_only=False) + self.assertFalse(fields(A)[0].kw_only) + + def test_match_args(self): + # kw fields don't show up in __match_args__. + @dataclass(kw_only=True) + class C: + a: int + self.assertEqual(C(a=42).__match_args__, ()) + + @dataclass + class C: + a: int + b: int = field(kw_only=True) + self.assertEqual(C(42, b=10).__match_args__, ('a',)) + + def test_KW_ONLY(self): + @dataclass + class A: + a: int + _: KW_ONLY + b: int + c: int + A(3, c=5, b=4) + msg = "takes 2 positional arguments but 4 were given" + with self.assertRaisesRegex(TypeError, msg): + A(3, 4, 5) + + + @dataclass(kw_only=True) + class B: + a: int + _: KW_ONLY + b: int + c: int + B(a=3, b=4, c=5) + msg = "takes 1 positional argument but 4 were given" + with self.assertRaisesRegex(TypeError, msg): + B(3, 4, 5) + + # Explicitly make a field that follows KW_ONLY be non-keyword-only. + @dataclass + class C: + a: int + _: KW_ONLY + b: int + c: int = field(kw_only=False) + c = C(1, 2, b=3) + self.assertEqual(c.a, 1) + self.assertEqual(c.b, 3) + self.assertEqual(c.c, 2) + c = C(1, b=3, c=2) + self.assertEqual(c.a, 1) + self.assertEqual(c.b, 3) + self.assertEqual(c.c, 2) + c = C(1, b=3, c=2) + self.assertEqual(c.a, 1) + self.assertEqual(c.b, 3) + self.assertEqual(c.c, 2) + c = C(c=2, b=3, a=1) + self.assertEqual(c.a, 1) + self.assertEqual(c.b, 3) + self.assertEqual(c.c, 2) + + def test_KW_ONLY_as_string(self): + @dataclass + class A: + a: int + _: 'dataclasses.KW_ONLY' + b: int + c: int + A(3, c=5, b=4) + msg = "takes 2 positional arguments but 4 were given" + with self.assertRaisesRegex(TypeError, msg): + A(3, 4, 5) + + def test_KW_ONLY_twice(self): + msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified" + + with self.assertRaisesRegex(TypeError, msg): + @dataclass + class A: + a: int + X: KW_ONLY + Y: KW_ONLY + b: int + c: int + + with self.assertRaisesRegex(TypeError, msg): + @dataclass + class A: + a: int + X: KW_ONLY + b: int + Y: KW_ONLY + c: int + + with self.assertRaisesRegex(TypeError, msg): + @dataclass + class A: + a: int + X: KW_ONLY + b: int + c: int + Y: KW_ONLY + + # But this usage is okay, since it's not using KW_ONLY. + @dataclass + class A: + a: int + _: KW_ONLY + b: int + c: int = field(kw_only=True) + + # And if inheriting, it's okay. + @dataclass + class A: + a: int + _: KW_ONLY + b: int + c: int + @dataclass + class B(A): + _: KW_ONLY + d: int + + # Make sure the error is raised in a derived class. + with self.assertRaisesRegex(TypeError, msg): + @dataclass + class A: + a: int + _: KW_ONLY + b: int + c: int + @dataclass + class B(A): + X: KW_ONLY + d: int + Y: KW_ONLY + + + def test_post_init(self): + @dataclass + class A: + a: int + _: KW_ONLY + b: InitVar[int] + c: int + d: InitVar[int] + def __post_init__(self, b, d): + raise CustomError(f'{b=} {d=}') + with self.assertRaisesRegex(CustomError, 'b=3 d=4'): + A(1, c=2, b=3, d=4) + + @dataclass + class B: + a: int + _: KW_ONLY + b: InitVar[int] + c: int + d: InitVar[int] + def __post_init__(self, b, d): + self.a = b + self.c = d + b = B(1, c=2, b=3, d=4) + self.assertEqual(asdict(b), {'a': 3, 'c': 4}) + + def test_defaults(self): + # For kwargs, make sure we can have defaults after non-defaults. + @dataclass + class A: + a: int = 0 + _: KW_ONLY + b: int + c: int = 1 + d: int + + a = A(d=4, b=3) + self.assertEqual(a.a, 0) + self.assertEqual(a.b, 3) + self.assertEqual(a.c, 1) + self.assertEqual(a.d, 4) + + # Make sure we still check for non-kwarg non-defaults not following + # defaults. + err_regex = "non-default argument 'z' follows default argument" + with self.assertRaisesRegex(TypeError, err_regex): + @dataclass + class A: + a: int = 0 + z: int + _: KW_ONLY + b: int + c: int = 1 + d: int + + def test_make_dataclass(self): + A = make_dataclass("A", ['a'], kw_only=True) + self.assertTrue(fields(A)[0].kw_only) + + B = make_dataclass("B", + ['a', ('b', int, field(kw_only=False))], + kw_only=True) + self.assertTrue(fields(B)[0].kw_only) + self.assertFalse(fields(B)[1].kw_only) + + +if __name__ == '__main__': + unittest.main() diff --git a/Tools/dump_github_issues.py b/Tools/dump_github_issues.py new file mode 100644 index 000000000..daec51c50 --- /dev/null +++ b/Tools/dump_github_issues.py @@ -0,0 +1,142 @@ +""" +Dump the GitHub issues of the current project to a file (.json.gz). + +Usage: python3 Tools/dump_github_issues.py +""" + +import configparser +import gzip +import json +import os.path + +from datetime import datetime +from urllib.request import urlopen + +GIT_CONFIG_FILE = ".git/config" + + +class RateLimitReached(Exception): + pass + + +def gen_urls(repo): + i = 0 + while True: + yield f"https://api.github.com/repos/{repo}/issues?state=all&per_page=100&page={i}" + i += 1 + + +def read_rate_limit(): + with urlopen("https://api.github.com/rate_limit") as p: + return json.load(p) + + +def parse_rate_limit(limits): + limits = limits['resources']['core'] + return limits['limit'], limits['remaining'], datetime.fromtimestamp(limits['reset']) + + +def load_url(url): + with urlopen(url) as p: + data = json.load(p) + if isinstance(data, dict) and 'rate limit' in data.get('message', ''): + raise RateLimitReached() + + assert isinstance(data, list), type(data) + return data or None # None indicates empty last page + + +def join_list_data(lists): + result = [] + for data in lists: + if not data: + break + result.extend(data) + return result + + +def output_filename(repo): + timestamp = datetime.now() + return f"github_issues_{repo.replace('/', '_')}_{timestamp.strftime('%Y%m%d_%H%M%S')}.json.gz" + + +def write_gzjson(file_name, data, indent=2): + with gzip.open(file_name, "wt", encoding='utf-8') as gz: + json.dump(data, gz, indent=indent) + + +def find_origin_url(git_config=GIT_CONFIG_FILE): + assert os.path.exists(git_config) + parser = configparser.ConfigParser() + parser.read(git_config) + return parser.get('remote "origin"', 'url') + + +def parse_repo_name(git_url): + if git_url.endswith('.git'): + git_url = git_url[:-4] + return '/'.join(git_url.split('/')[-2:]) + + +def dump_issues(repo): + """Main entry point.""" + print(f"Reading issues from repo '{repo}'") + urls = gen_urls(repo) + try: + paged_data = map(load_url, urls) + issues = join_list_data(paged_data) + except RateLimitReached: + limit, remaining, reset_time = parse_rate_limit(read_rate_limit()) + print(f"FAILURE: Rate limits ({limit}) reached, remaining: {remaining}, reset at {reset_time}") + return + + filename = output_filename(repo) + print(f"Writing {len(issues)} to {filename}") + write_gzjson(filename, issues) + + +### TESTS + +def test_join_list_data(): + assert join_list_data([]) == [] + assert join_list_data([[1,2]]) == [1,2] + assert join_list_data([[1,2], [3]]) == [1,2,3] + assert join_list_data([[0], [1,2], [3]]) == [0,1,2,3] + assert join_list_data([[0], [1,2], [[[]],[]]]) == [0,1,2,[[]],[]] + + +def test_output_filename(): + filename = output_filename("re/po") + import re + assert re.match(r"github_issues_re_po_[0-9]{8}_[0-9]{6}\.json", filename) + + +def test_find_origin_url(): + assert find_origin_url() + + +def test_parse_repo_name(): + assert parse_repo_name("https://github.com/cython/cython") == "cython/cython" + assert parse_repo_name("git+ssh://git@github.com/cython/cython.git") == "cython/cython" + assert parse_repo_name("git+ssh://git@github.com/fork/cython.git") == "fork/cython" + + +def test_write_gzjson(): + import tempfile + with tempfile.NamedTemporaryFile() as tmp: + write_gzjson(tmp.name, [{}]) + + # test JSON format + with gzip.open(tmp.name) as f: + assert json.load(f) == [{}] + + # test indentation + with gzip.open(tmp.name) as f: + assert f.read() == b'[\n {}\n]' + + +### MAIN + +if __name__ == '__main__': + repo_name = parse_repo_name(find_origin_url()) + dump_issues(repo_name) diff --git a/Tools/gen_tests_for_posix_pxds.py b/Tools/gen_tests_for_posix_pxds.py new file mode 100644 index 000000000..c92c49a6a --- /dev/null +++ b/Tools/gen_tests_for_posix_pxds.py @@ -0,0 +1,41 @@ +#!/usr/bin/python3 + +from pathlib import Path + +PROJECT_ROOT = Path(__file__) / "../.." +POSIX_PXDS_DIR = PROJECT_ROOT / "Cython/Includes/posix" +TEST_PATH = PROJECT_ROOT / "tests/compile/posix_pxds.pyx" + +def main(): + datas = [ + "# tag: posix\n" + "# mode: compile\n" + "\n" + "# This file is generated by `Tools/gen_tests_for_posix_pxds.py`.\n" + "\n" + "cimport posix\n" + ] + + filenames = sorted(map(lambda path: path.name, POSIX_PXDS_DIR.iterdir())) + + for name in filenames: + if name == "__init__.pxd": + continue + if name.endswith(".pxd"): + name = name[:-4] + else: + continue + + s = ( + "cimport posix.{name}\n" + "from posix cimport {name}\n" + "from posix.{name} cimport *\n" + ).format(name=name) + + datas.append(s) + + with open(TEST_PATH, "w", encoding="utf-8", newline="\n") as f: + f.write("\n".join(datas)) + +if __name__ == "__main__": + main() diff --git a/Tools/make_dataclass_tests.py b/Tools/make_dataclass_tests.py new file mode 100644 index 000000000..dc38eee70 --- /dev/null +++ b/Tools/make_dataclass_tests.py @@ -0,0 +1,443 @@ +# Used to generate tests/run/test_dataclasses.pyx but translating the CPython test suite +# dataclass file. Initially run using Python 3.10 - this file is not designed to be +# backwards compatible since it will be run manually and infrequently. + +import ast +import os.path +import sys + +unavailable_functions = frozenset( + { + "dataclass_textanno", # part of CPython test module + "dataclass_module_1", # part of CPython test module + "make_dataclass", # not implemented in Cython dataclasses (probably won't be implemented) + } +) + +skip_tests = frozenset( + { + # needs Cython compile + # ==================== + ("TestCase", "test_field_default_default_factory_error"), + ("TestCase", "test_two_fields_one_default"), + ("TestCase", "test_overwrite_hash"), + ("TestCase", "test_eq_order"), + ("TestCase", "test_no_unhashable_default"), + ("TestCase", "test_disallowed_mutable_defaults"), + ("TestCase", "test_classvar_default_factory"), + ("TestCase", "test_field_metadata_mapping"), + ("TestFieldNoAnnotation", "test_field_without_annotation"), + ( + "TestFieldNoAnnotation", + "test_field_without_annotation_but_annotation_in_base", + ), + ( + "TestFieldNoAnnotation", + "test_field_without_annotation_but_annotation_in_base_not_dataclass", + ), + ("TestOrdering", "test_overwriting_order"), + ("TestHash", "test_hash_rules"), + ("TestHash", "test_hash_no_args"), + ("TestFrozen", "test_inherit_nonfrozen_from_empty_frozen"), + ("TestFrozen", "test_inherit_nonfrozen_from_frozen"), + ("TestFrozen", "test_inherit_frozen_from_nonfrozen"), + ("TestFrozen", "test_overwriting_frozen"), + ("TestSlots", "test_add_slots_when_slots_exists"), + ("TestSlots", "test_cant_inherit_from_iterator_slots"), + ("TestSlots", "test_weakref_slot_without_slot"), + ("TestKeywordArgs", "test_no_classvar_kwarg"), + ("TestKeywordArgs", "test_KW_ONLY_twice"), + ("TestKeywordArgs", "test_defaults"), + # uses local variable in class definition + ("TestCase", "test_default_factory"), + ("TestCase", "test_default_factory_with_no_init"), + ("TestCase", "test_field_default"), + ("TestCase", "test_function_annotations"), + ("TestDescriptors", "test_lookup_on_instance"), + ("TestCase", "test_default_factory_not_called_if_value_given"), + ("TestCase", "test_class_attrs"), + ("TestCase", "test_hash_field_rules"), + ("TestStringAnnotations",), # almost all the texts here use local variables + # Currently unsupported + # ===================== + ( + "TestOrdering", + "test_functools_total_ordering", + ), # combination of cython dataclass and total_ordering + ("TestCase", "test_missing_default_factory"), # we're MISSING MISSING + ("TestCase", "test_missing_default"), # MISSING + ("TestCase", "test_missing_repr"), # MISSING + ("TestSlots",), # __slots__ isn't understood + ("TestMatchArgs",), + ("TestKeywordArgs", "test_field_marked_as_kwonly"), + ("TestKeywordArgs", "test_match_args"), + ("TestKeywordArgs", "test_KW_ONLY"), + ("TestKeywordArgs", "test_KW_ONLY_as_string"), + ("TestKeywordArgs", "test_post_init"), + ( + "TestCase", + "test_class_var_frozen", + ), # __annotations__ not present on cdef classes https://github.com/cython/cython/issues/4519 + ("TestCase", "test_dont_include_other_annotations"), # __annotations__ + ("TestDocString",), # don't think cython dataclasses currently set __doc__ + # either cython.dataclasses.field or cython.dataclasses.dataclass called directly as functions + # (will probably never be supported) + ("TestCase", "test_field_repr"), + ("TestCase", "test_dynamic_class_creation"), + ("TestCase", "test_dynamic_class_creation_using_field"), + # Requires inheritance from non-cdef class + ("TestCase", "test_is_dataclass_genericalias"), + ("TestCase", "test_generic_extending"), + ("TestCase", "test_generic_dataclasses"), + ("TestCase", "test_generic_dynamic"), + ("TestInit", "test_inherit_from_protocol"), + ("TestAbstract", "test_abc_implementation"), + ("TestAbstract", "test_maintain_abc"), + # Requires multiple inheritance from extension types + ("TestCase", "test_post_init_not_auto_added"), + # Refers to nonlocal from enclosing function + ( + "TestCase", + "test_post_init_staticmethod", + ), # TODO replicate the gist of the test elsewhere + # PEP487 isn't support in Cython + ("TestDescriptors", "test_non_descriptor"), + ("TestDescriptors", "test_set_name"), + ("TestDescriptors", "test_setting_field_calls_set"), + ("TestDescriptors", "test_setting_uninitialized_descriptor_field"), + # Looks up __dict__, which cdef classes don't typically have + ("TestCase", "test_init_false_no_default"), + ("TestCase", "test_init_var_inheritance"), # __dict__ again + ("TestCase", "test_base_has_init"), + ("TestInit", "test_base_has_init"), # needs __dict__ for vars + # Requires arbitrary attributes to be writeable + ("TestCase", "test_post_init_super"), + ('TestCase', 'test_init_in_order'), + # Cython being strict about argument types - expected difference + ("TestDescriptors", "test_getting_field_calls_get"), + ("TestDescriptors", "test_init_calls_set"), + ("TestHash", "test_eq_only"), + # I think an expected difference with cdef classes - the property will be in the dict + ("TestCase", "test_items_in_dicts"), + # These tests are probably fine, but the string substitution in this file doesn't get it right + ("TestRepr", "test_repr"), + ("TestCase", "test_not_in_repr"), + ('TestRepr', 'test_no_repr'), + # class variable doesn't exist in Cython so uninitialized variable appears differently - for now this is deliberate + ('TestInit', 'test_no_init'), + # I believe the test works but the ordering functions do appear in the class dict (and default slot wrappers which + # just raise NotImplementedError + ('TestOrdering', 'test_no_order'), + # not possible to add attributes on extension types + ("TestCase", "test_post_init_classmethod"), + # Cannot redefine the same field in a base dataclass (tested in dataclass_e6) + ("TestCase", "test_field_order"), + ( + "TestCase", + "test_overwrite_fields_in_derived_class", + ), + # Bugs + #====== + # not specifically a dataclass issue - a C int crashes classvar + ("TestCase", "test_class_var"), + ( + "TestFrozen", + ), # raises AttributeError, not FrozenInstanceError (may be hard to fix) + ('TestCase', 'test_post_init'), # Works except for AttributeError instead of FrozenInstanceError + ("TestReplace", "test_frozen"), # AttributeError not FrozenInstanceError + ( + "TestCase", + "test_dataclasses_qualnames", + ), # doesn't define __setattr__ and just relies on Cython to enforce readonly properties + ("TestCase", "test_compare_subclasses"), # wrong comparison + ("TestCase", "test_simple_compare"), # wrong comparison + ( + "TestCase", + "test_field_named_self", + ), # I think just an error in inspecting the signature + ( + "TestCase", + "test_init_var_default_factory", + ), # should be raising a compile error + ("TestCase", "test_init_var_no_default"), # should be raising a compile error + ("TestCase", "test_init_var_with_default"), # not sure... + ("TestReplace", "test_initvar_with_default_value"), # needs investigating + # Maybe bugs? + # ========== + # non-default argument 'z' follows default argument in dataclass __init__ - this message looks right to me! + ("TestCase", "test_class_marker"), + # cython.dataclasses.field parameter 'metadata' must be a literal value - possibly not something we can support? + ("TestCase", "test_field_metadata_custom_mapping"), + ( + "TestCase", + "test_class_var_default_factory", + ), # possibly to do with ClassVar being assigned a field + ( + "TestCase", + "test_class_var_with_default", + ), # possibly to do with ClassVar being assigned a field + ( + "TestDescriptors", + ), # mostly don't work - I think this may be a limitation of cdef classes but needs investigating + } +) + +version_specific_skips = { + # The version numbers are the first version that the test should be run on + ("TestCase", "test_init_var_preserve_type"): ( + 3, + 10, + ), # needs language support for | operator on types +} + +class DataclassInDecorators(ast.NodeVisitor): + found = False + + def visit_Name(self, node): + if node.id == "dataclass": + self.found = True + return self.generic_visit(node) + + def generic_visit(self, node): + if self.found: + return # skip + return super().generic_visit(node) + + +def dataclass_in_decorators(decorator_list): + finder = DataclassInDecorators() + for dec in decorator_list: + finder.visit(dec) + if finder.found: + return True + return False + + +class SubstituteNameString(ast.NodeTransformer): + def __init__(self, substitutions): + super().__init__() + self.substitutions = substitutions + + def visit_Constant(self, node): + # attempt to handle some difference in class names + # (note: requires Python>=3.8) + if isinstance(node.value, str): + if node.value.find("<locals>") != -1: + import re + + new_value = new_value2 = re.sub("[\w.]*<locals>", "", node.value) + for key, value in self.substitutions.items(): + new_value2 = re.sub(f"(?<![\w])[.]{key}(?![\w])", value, new_value2) + if new_value != new_value2: + node.value = new_value2 + return node + + +class SubstituteName(SubstituteNameString): + def visit_Name(self, node): + if isinstance(node.ctx, ast.Store): # don't reassign lhs + return node + replacement = self.substitutions.get(node.id, None) + if replacement is not None: + return ast.Name(id=replacement, ctx=node.ctx) + else: + return node + + +class IdentifyCdefClasses(ast.NodeVisitor): + def __init__(self): + super().__init__() + self.top_level_class = True + self.classes = {} + self.cdef_classes = set() + + def visit_ClassDef(self, node): + top_level_class, self.top_level_class = self.top_level_class, False + try: + if not top_level_class: + self.classes[node.name] = node + if dataclass_in_decorators(node.decorator_list): + self.handle_cdef_class(node) + self.generic_visit(node) # any nested classes in it? + else: + self.generic_visit(node) + finally: + self.top_level_class = top_level_class + + def visit_FunctionDef(self, node): + classes, self.classes = self.classes, {} + self.generic_visit(node) + self.classes = classes + + def handle_cdef_class(self, cls_node): + if cls_node not in self.cdef_classes: + self.cdef_classes.add(cls_node) + # go back through previous classes we've seen and pick out any first bases + if cls_node.bases and isinstance(cls_node.bases[0], ast.Name): + base0_node = self.classes.get(cls_node.bases[0].id) + if base0_node: + self.handle_cdef_class(base0_node) + + +class ExtractDataclassesToTopLevel(ast.NodeTransformer): + def __init__(self, cdef_classes_set): + super().__init__() + self.nested_name = [] + self.current_function_global_classes = [] + self.global_classes = [] + self.cdef_classes_set = cdef_classes_set + self.used_names = set() + self.collected_substitutions = {} + self.uses_unavailable_name = False + self.top_level_class = True + + def visit_ClassDef(self, node): + if not self.top_level_class: + # Include any non-toplevel class in this to be able + # to test inheritance. + + self.generic_visit(node) # any nested classes in it? + if not node.body: + node.body.append(ast.Pass) + + # First, make it a C class. + if node in self.cdef_classes_set: + node.decorator_list.append(ast.Name(id="cclass", ctx=ast.Load())) + # otherwise move it to the global scope, but don't make it cdef + # change the name + old_name = node.name + new_name = "_".join([node.name] + self.nested_name) + while new_name in self.used_names: + new_name = new_name + "_" + node.name = new_name + self.current_function_global_classes.append(node) + self.used_names.add(new_name) + # hmmmm... possibly there's a few cases where there's more than one name? + self.collected_substitutions[old_name] = node.name + + return ast.Assign( + targets=[ast.Name(id=old_name, ctx=ast.Store())], + value=ast.Name(id=new_name, ctx=ast.Load()), + lineno=-1, + ) + else: + top_level_class, self.top_level_class = self.top_level_class, False + self.nested_name.append(node.name) + if tuple(self.nested_name) in skip_tests: + self.top_level_class = top_level_class + self.nested_name.pop() + return None + self.generic_visit(node) + self.nested_name.pop() + if not node.body: + node.body.append(ast.Pass()) + self.top_level_class = top_level_class + return node + + def visit_FunctionDef(self, node): + self.nested_name.append(node.name) + if tuple(self.nested_name) in skip_tests: + self.nested_name.pop() + return None + if tuple(self.nested_name) in version_specific_skips: + version = version_specific_skips[tuple(self.nested_name)] + decorator = ast.parse( + f"skip_on_versions_below({version})", mode="eval" + ).body + node.decorator_list.append(decorator) + collected_subs, self.collected_substitutions = self.collected_substitutions, {} + uses_unavailable_name, self.uses_unavailable_name = ( + self.uses_unavailable_name, + False, + ) + current_func_globs, self.current_function_global_classes = ( + self.current_function_global_classes, + [], + ) + + # visit once to work out what the substitutions should be + self.generic_visit(node) + if self.collected_substitutions: + # replace strings in this function + node = SubstituteNameString(self.collected_substitutions).visit(node) + replacer = SubstituteName(self.collected_substitutions) + # replace any base classes + for global_class in self.current_function_global_classes: + global_class = replacer.visit(global_class) + self.global_classes.append(self.current_function_global_classes) + + self.nested_name.pop() + self.collected_substitutions = collected_subs + if self.uses_unavailable_name: + node = None + self.uses_unavailable_name = uses_unavailable_name + self.current_function_global_classes = current_func_globs + return node + + def visit_Name(self, node): + if node.id in unavailable_functions: + self.uses_unavailable_name = True + return self.generic_visit(node) + + def visit_Import(self, node): + return None # drop imports, we add these into the text ourself + + def visit_ImportFrom(self, node): + return None # drop imports, we add these into the text ourself + + def visit_Call(self, node): + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "assertRaisesRegex" + ): + # we end up with a bunch of subtle name changes that are very hard to correct for + # therefore, replace with "assertRaises" + node.func.attr = "assertRaises" + node.args.pop() + return self.generic_visit(node) + + def visit_Module(self, node): + self.generic_visit(node) + node.body[0:0] = self.global_classes + return node + + def visit_AnnAssign(self, node): + # string annotations are forward declarations but the string will be wrong + # (because we're renaming the class) + if (isinstance(node.annotation, ast.Constant) and + isinstance(node.annotation.value, str)): + # although it'd be good to resolve these declarations, for the + # sake of the tests they only need to be "object" + node.annotation = ast.Name(id="object", ctx=ast.Load) + + return node + + +def main(): + script_path = os.path.split(sys.argv[0])[0] + filename = "test_dataclasses.py" + py_module_path = os.path.join(script_path, "dataclass_test_data", filename) + with open(py_module_path, "r") as f: + tree = ast.parse(f.read(), filename) + + cdef_class_finder = IdentifyCdefClasses() + cdef_class_finder.visit(tree) + transformer = ExtractDataclassesToTopLevel(cdef_class_finder.cdef_classes) + tree = transformer.visit(tree) + + output_path = os.path.join(script_path, "..", "tests", "run", filename + "x") + with open(output_path, "w") as f: + print("# AUTO-GENERATED BY Tools/make_dataclass_tests.py", file=f) + print("# DO NOT EDIT", file=f) + print(file=f) + # the directive doesn't get applied outside the include if it's put + # in the pxi file + print("# cython: language_level=3", file=f) + # any extras Cython needs to add go in this include file + print('include "test_dataclasses.pxi"', file=f) + print(file=f) + print(ast.unparse(tree), file=f) + + +if __name__ == "__main__": + main() diff --git a/Tools/rules.bzl b/Tools/rules.bzl index cd3eed58f..c59af6a99 100644 --- a/Tools/rules.bzl +++ b/Tools/rules.bzl @@ -11,8 +11,8 @@ load("@cython//Tools:rules.bzl", "pyx_library") pyx_library(name = 'mylib', srcs = ['a.pyx', 'a.pxd', 'b.py', 'pkg/__init__.py', 'pkg/c.pyx'], - py_deps = ['//py_library/dep'], - data = ['//other/data'], + # python library deps passed to py_library + deps = ['//py_library/dep'] ) The __init__.py file must be in your srcs list so that Cython can resolve |