# copyright 2003-2012 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
#
# This file is part of logilab-common.
#
# logilab-common is free software: you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation, either version 2.1 of the License, or (at your option) any
# later version.
#
# logilab-common is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License along
# with logilab-common. If not, see <http://www.gnu.org/licenses/>.
"""unit tests for logilab.common.shellutils"""

from os.path import join, dirname, abspath
from unittest.mock import patch

from logilab.common.testlib import TestCase, unittest_main

from logilab.common.shellutils import globfind, find, ProgressBar, RawInput
from logilab.common.compat import StringIO

# Fixture tree walked by the find/globfind tests.
DATA_DIR = join(dirname(abspath(__file__)), "data", "find_test")


class FindTC(TestCase):
    """Tests for the ``find`` and ``globfind`` file walkers over DATA_DIR."""

    def test_include(self):
        """``find`` returns the files matching the given extension(s)."""
        # The extension may be passed as a single string ...
        files = set(find(DATA_DIR, ".py"))
        self.assertSetEqual(
            files,
            {
                join(DATA_DIR, f)
                for f in (
                    "__init__.py",
                    "module.py",
                    "module2.py",
                    "noendingnewline.py",
                    "nonregr.py",
                    join("sub", "momo.py"),
                )
            },
        )
        # ... or as a tuple; files under the blacklisted "sub" directory are
        # then left out of the result.
        files = set(find(DATA_DIR, (".py",), blacklist=("sub",)))
        self.assertSetEqual(
            files,
            {
                join(DATA_DIR, f)
                for f in (
                    "__init__.py",
                    "module.py",
                    "module2.py",
                    "noendingnewline.py",
                    "nonregr.py",
                )
            },
        )

    def test_exclude(self):
        """With ``exclude=True`` every file NOT matching the extensions is returned."""
        files = set(find(DATA_DIR, (".py", ".pyc"), exclude=True))
        self.assertSetEqual(
            files,
            {
                join(DATA_DIR, f)
                for f in (
                    "foo.txt",
                    "newlines.txt",
                    "normal_file.txt",
                    "test.ini",
                    "test1.msg",
                    "test2.msg",
                    "spam.txt",
                    join("sub", "doc.txt"),
                    "write_protected_file.txt",
                )
            },
        )

    def test_globfind(self):
        """``globfind`` matches a glob pattern against file names, recursively."""
        files = set(globfind(DATA_DIR, "*.py"))
        self.assertSetEqual(
            files,
            {
                join(DATA_DIR, f)
                for f in (
                    "__init__.py",
                    "module.py",
                    "module2.py",
                    "noendingnewline.py",
                    "nonregr.py",
                    join("sub", "momo.py"),
                )
            },
        )
        # "sub/momo.py" matches "mo*.py", so the pattern is applied to the
        # file name, not to the path relative to DATA_DIR.
        files = set(globfind(DATA_DIR, "mo*.py"))
        self.assertSetEqual(
            files,
            {join(DATA_DIR, f) for f in ("module.py", "module2.py", join("sub", "momo.py"))},
        )
        files = set(globfind(DATA_DIR, "mo*.py", blacklist=("sub",)))
        self.assertSetEqual(files, {join(DATA_DIR, f) for f in ("module.py", "module2.py")})


class ProgressBarTC(TestCase):
    """Tests for the text rendering of ``ProgressBar`` into a stream."""

    def test_refresh(self):
        """An empty bar of the given width is drawn on the first refresh()."""
        pgb_stream = StringIO()
        expected_stream = StringIO()
        pgb = ProgressBar(20, stream=pgb_stream)
        # nothing is printed before refresh() is called
        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())
        pgb.refresh()
        expected_stream.write("\r[" + " " * 20 + "]")
        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())

    def test_refresh_g_size(self):
        """The bar width follows ``size`` when it is greater than ``nbops``."""
        pgb_stream = StringIO()
        expected_stream = StringIO()
        pgb = ProgressBar(20, 35, stream=pgb_stream)
        pgb.refresh()
        expected_stream.write("\r[" + " " * 35 + "]")
        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())

    def test_refresh_l_size(self):
        """The bar width follows ``size`` when it is lower than ``nbops``."""
        pgb_stream = StringIO()
        expected_stream = StringIO()
        pgb = ProgressBar(20, 3, stream=pgb_stream)
        pgb.refresh()
        expected_stream.write("\r[" + " " * 3 + "]")
        self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())

    def _update_test(self, nbops, expected, size=None):
        """Call ``update()`` once per entry of *expected* and check the output.

        :param nbops: total number of operations the bar is built for.
        :param expected: iterable with one entry per ``update()`` call. Each
            entry is either the expected number of ``=`` dots, in which case
            a redraw is expected only when that number changes, or a
            ``(dots, redraw)`` tuple where ``redraw`` forces (True) or
            suppresses (False) the expectation of a redraw.
        :param size: bar width passed to ProgressBar; when omitted the bar is
            built with its default width, which this helper assumes to be 20.
        """
        pgb_stream = StringIO()
        expected_stream = StringIO()
        if size is None:
            pgb = ProgressBar(nbops, stream=pgb_stream)
            size = 20
        else:
            pgb = ProgressBar(nbops, size, stream=pgb_stream)
        last = 0
        for step in expected:
            # plain ints expose __int__; (dots, redraw) tuples do not
            if hasattr(step, "__int__"):
                dots, redraw = step, None
            else:
                dots, redraw = step
            pgb.update()
            if redraw or (redraw is None and dots != last):
                last = dots
                expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
            self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())

    def test_default(self):
        """One dot per update when nbops equals the default width."""
        self._update_test(20, range(1, 21))

    def test_nbops_gt_size(self):
        """Test the progress bar for nbops > size"""

        def half(total):
            for counter in range(1, total + 1):
                yield counter // 2

        self._update_test(40, half(40))

    def test_nbops_lt_size(self):
        """Test the progress bar for nbops < size"""

        def double(total):
            for counter in range(1, total + 1):
                yield counter * 2

        self._update_test(10, double(10))

    def test_nbops_nomul_size(self):
        """Test the progress bar for size % nbops !=0 (non int number of dots per update)"""
        self._update_test(3, (6, 13, 20))

    def test_overflow(self):
        """An update past ``nbops`` still redraws a full bar."""
        self._update_test(5, (8, 16, 25, 33, 42, (42, True)), size=42)

    def test_update_exact(self):
        """``update(n, exact=True)`` sets the absolute progress to ``n``."""
        pgb_stream = StringIO()
        expected_stream = StringIO()
        size = 20
        pgb = ProgressBar(100, size, stream=pgb_stream)
        for done in range(10, 105, 15):
            pgb.update(done, exact=True)
            # 100 operations drawn on 20 characters: one dot per 5 operations
            dots = done // 5
            expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
            self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())

    def test_update_relative(self):
        """``update(n, exact=False)`` advances the progress by ``n``."""
        pgb_stream = StringIO()
        expected_stream = StringIO()
        size = 20
        pgb = ProgressBar(100, size, stream=pgb_stream)
        # ``done`` tracks the cumulative progress after each +5 step
        for done in range(5, 105, 5):
            pgb.update(5, exact=False)
            dots = done // 5
            expected_stream.write("\r[" + ("=" * dots) + (" " * (size - dots)) + "]")
            self.assertEqual(pgb_stream.getvalue(), expected_stream.getvalue())


class RawInputTC(TestCase):
    """Tests for ``RawInput.ask`` and ``RawInput.confirm`` using scripted answers."""

    def auto_input(self, *args):
        """Fake input function: record the prompt arguments, return the canned answer."""
        self.input_args = args
        return self.input_answer

    def setUp(self):
        def null_printer(msg):
            """Discard anything RawInput prints."""
            return None

        self.qa = RawInput(self.auto_input, null_printer)

    def test_ask_using_builtin_input(self):
        """Without an explicit input function, RawInput reads from builtins.input."""
        with patch("builtins.input", return_value="no"):
            qa = RawInput()
            answer = qa.ask("text", ("yes", "no"), "yes")
        self.assertEqual(answer, "no")

    def test_ask_default(self):
        """An empty or whitespace-only answer selects the default option."""
        self.input_answer = ""
        answer = self.qa.ask("text", ("yes", "no"), "yes")
        self.assertEqual(answer, "yes")
        self.input_answer = " "
        answer = self.qa.ask("text", ("yes", "no"), "yes")
        self.assertEqual(answer, "yes")

    def test_ask_case(self):
        """Answers match options case-insensitively; the option itself is returned."""
        for typed, expected in (
            ("no", "no"),
            ("No", "no"),
            ("NO", "no"),
            ("nO", "no"),
            ("YES", "yes"),
        ):
            self.input_answer = typed
            answer = self.qa.ask("text", ("yes", "no"), "yes")
            self.assertEqual(answer, expected, msg="typed %r" % typed)

    def test_ask_prompt(self):
        """The default is capitalized and multi-letter options show their tail."""
        self.input_answer = ""
        self.qa.ask("text", ("yes", "no"), "yes")
        self.assertEqual(self.input_args[0], "text [Y(es)/n(o)]: ")
        self.qa.ask("text", ("y", "n"), "y")
        self.assertEqual(self.input_args[0], "text [Y/n]: ")
        self.qa.ask("text", ("n", "y"), "y")
        self.assertEqual(self.input_args[0], "text [n/Y]: ")
        self.qa.ask("text", ("yes", "no", "maybe", "1"), "yes")
        self.assertEqual(self.input_args[0], "text [Y(es)/n(o)/m(aybe)/1]: ")

    def test_ask_ambiguous(self):
        """An answer that prefixes several options ("yes"/"yep") is rejected."""
        self.input_answer = "y"
        with self.assertRaises(Exception):
            self.qa.ask("text", ("yes", "yep"), "yes")

    def test_confirm(self):
        """``confirm`` maps y/n to True/False; an empty answer follows default_is_yes."""
        self.input_answer = "y"
        self.assertEqual(self.qa.confirm("Say yes"), True)
        self.assertEqual(self.qa.confirm("Say yes", default_is_yes=False), True)
        self.input_answer = "n"
        self.assertEqual(self.qa.confirm("Say yes"), False)
        self.assertEqual(self.qa.confirm("Say yes", default_is_yes=False), False)
        self.input_answer = ""
        self.assertEqual(self.qa.confirm("Say default"), True)
        self.assertEqual(self.qa.confirm("Say default", default_is_yes=False), False)


if __name__ == "__main__":
    unittest_main()