From 61a99bdeffe78c9c2e7e45c9802fd1db9fef3225 Mon Sep 17 00:00:00 2001 From: Matthew Peveler Date: Sat, 16 Oct 2021 07:22:03 -1000 Subject: Split off utils functions into own module (#197) --- .github/workflows/test.yml | 4 +- asciidoc/asciidoc.py | 261 +++++++-------------------------------------- asciidoc/utils.py | 185 ++++++++++++++++++++++++++++++++ tests/test_utils.py | 103 ++++++++++++++++++ 4 files changed, 331 insertions(+), 222 deletions(-) create mode 100644 asciidoc/utils.py create mode 100644 tests/test_utils.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4018e07..d8bdefc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: - name: Install Dependencies run: | python3 -m pip install -U pip - python3 -m pip install -U pytest pytest-runner flake8 + python3 -m pip install -U flake8 - name: Lint codebase run: python3 -m flake8 @@ -60,7 +60,7 @@ jobs: texlive-latex-base \ xsltproc - - run: pip install -U pytest coverage + - run: pip install -U pytest pytest-mock coverage - run: coverage run -m asciidoc.asciidoc --doctest - run: coverage run --append --source=asciidoc -m pytest diff --git a/asciidoc/asciidoc.py b/asciidoc/asciidoc.py index 9ab1dd2..4c323f1 100644 --- a/asciidoc/asciidoc.py +++ b/asciidoc/asciidoc.py @@ -21,8 +21,6 @@ import csv from functools import lru_cache import getopt import io -import locale -import math import os import re import shutil @@ -39,6 +37,7 @@ from collections import OrderedDict from .collections import AttrDict, InsensitiveDict from .exceptions import EAsciiDoc +from . import utils CONF_DIR = os.path.join(os.path.dirname(__file__), 'resources') METADATA = {} @@ -101,12 +100,12 @@ class Trace(object): if self.name_re is not None: msg = message.format(name, 'TRACE: ', self.linenos, offset=self.offset) if before != after and re.match(self.name_re, name): - if is_array(before): + if utils.is_array(before): before = '\n'.join(before) if after is None: msg += '\n%s\n' % before else: - if is_array(after): + if utils.is_array(after): after = '\n'.join(after) msg += '\n<<<\n%s\n>>>\n%s\n' % (before, after) message.stderr(msg) @@ -177,29 +176,6 @@ class Message: self.error('unsafe: '+msg) -def userdir(): - """ - Return user's home directory or None if it is not defined. - """ - result = os.path.expanduser('~') - if result == '~': - result = None - return result - - -def file_in(fname, directory): - """Return True if file fname resides inside directory.""" - assert os.path.isfile(fname) - # Empty directory (not to be confused with None) is the current directory. - if directory == '': - directory = os.getcwd() - else: - assert os.path.isdir(directory) - directory = os.path.realpath(directory) - fname = os.path.realpath(fname) - return os.path.commonprefix((directory, fname)) == directory - - def safe(): return document.safe @@ -215,8 +191,8 @@ def is_safe_file(fname, directory=None): directory = '.' return ( not safe() or - file_in(fname, directory) or - file_in(fname, CONF_DIR) + utils.file_in(fname, directory) or + utils.file_in(fname, CONF_DIR) ) @@ -235,119 +211,6 @@ def safe_filename(fname, parentdir): return fname -def assign(dst, src): - """Assign all attributes from 'src' object to 'dst' object.""" - for a, v in list(src.__dict__.items()): - setattr(dst, a, v) - - -def strip_quotes(s): - """Trim white space and, if necessary, quote characters from s.""" - s = s.strip() - # Strip quotation mark characters from quoted strings. - if len(s) >= 3 and s[0] == '"' and s[-1] == '"': - s = s[1:-1] - return s - - -def is_re(s): - """Return True if s is a valid regular expression else return False.""" - try: - re.compile(s) - except: - return False - else: - return True - - -def re_join(relist): - """Join list of regular expressions re1,re2,... to single regular - expression (re1)|(re2)|...""" - if len(relist) == 0: - return None - result = [] - # Delete named groups to avoid ambiguity. - for s in relist: - result.append(re.sub(r'\?P<\S+?>', '', s)) - result = ')|('.join(result) - result = '(' + result + ')' - return result - - -def lstrip_list(s): - """ - Return list with empty items from start of list removed. - """ - for i in range(len(s)): - if s[i]: - break - else: - return [] - return s[i:] - - -def rstrip_list(s): - """ - Return list with empty items from end of list removed. - """ - for i in range(len(s) - 1, -1, -1): - if s[i]: - break - else: - return [] - return s[:i + 1] - - -def strip_list(s): - """ - Return list with empty items from start and end of list removed. - """ - s = lstrip_list(s) - s = rstrip_list(s) - return s - - -def is_array(obj): - """ - Return True if object is list or tuple type. - """ - return isinstance(obj, list) or isinstance(obj, tuple) - - -def dovetail(lines1, lines2): - """ - Append list or tuple of strings 'lines2' to list 'lines1'. Join the last - non-blank item in 'lines1' with the first non-blank item in 'lines2' into a - single string. - """ - assert is_array(lines1) - assert is_array(lines2) - lines1 = strip_list(lines1) - lines2 = strip_list(lines2) - if not lines1 or not lines2: - return list(lines1) + list(lines2) - result = list(lines1[:-1]) - result.append(lines1[-1] + lines2[0]) - result += list(lines2[1:]) - return result - - -def dovetail_tags(stag, content, etag): - """Merge the end tag with the first content line and the last - content line with the end tag. This ensures verbatim elements don't - include extraneous opening and closing line breaks.""" - return dovetail(dovetail(stag, content), etag) - - -def py2round(n, d=0): - """Utility function to get python2 rounding in python3. Python3 changed it such that - given two equally close multiples, it'll round towards the even choice. For example, - round(42.5) == 42 instead of the expected round(42.5) == 43). This function gives us - back that functionality.""" - p = 10 ** d - return float(math.floor((n * p) + math.copysign(0.5, n))) / p - - def get_args(val): d = {} args = ast.parse("d(" + val + ")", mode='eval').body.args @@ -582,9 +445,9 @@ def parse_entry(entry, dict=None, unquote=False, unique_values=False, else: return None if unquote: - name = strip_quotes(name) + name = utils.strip_quotes(name) if value is not None: - value = strip_quotes(value) + value = utils.strip_quotes(value) else: name = name.strip() if value is not None: @@ -1077,7 +940,7 @@ def subs_attrs(lines, dictionary=None): if len(v) not in (2, 3): message.error('illegal attribute syntax: %s' % attr) s = '' - elif not is_re('^' + v[0] + '$'): + elif not utils.is_re('^' + v[0] + '$'): message.error('illegal attribute regexp: %s' % attr) s = '' else: @@ -1156,48 +1019,6 @@ def subs_attrs(lines, dictionary=None): return tuple(result) -east_asian_widths = { - 'W': 2, # Wide - 'F': 2, # Full-width (wide) - 'Na': 1, # Narrow - 'H': 1, # Half-width (narrow) - 'N': 1, # Neutral (not East Asian, treated as narrow) - 'A': 1, # Ambiguous (s/b wide in East Asian context, narrow otherwise, but that doesn't work) -} -"""Mapping of result codes from `unicodedata.east_asian_width()` to character -column widths.""" - - -def column_width(s): - width = 0 - for c in s: - width += east_asian_widths[unicodedata.east_asian_width(c)] - return width - - -def date_time_str(t): - """Convert seconds since the Epoch to formatted local date and time strings.""" - source_date_epoch = os.environ.get('SOURCE_DATE_EPOCH') - if source_date_epoch is not None: - t = time.gmtime(min(t, int(source_date_epoch))) - else: - t = time.localtime(t) - date_str = time.strftime('%Y-%m-%d', t) - time_str = time.strftime('%H:%M:%S', t) - if source_date_epoch is not None: - time_str += ' UTC' - elif time.daylight and t.tm_isdst == 1: - time_str += ' ' + time.tzname[1] - else: - time_str += ' ' + time.tzname[0] - # Attempt to convert the localtime to the output encoding. - try: - time_str = time_str.decode(locale.getdefaultlocale()[1]) - except Exception: - pass - return date_str, time_str - - class Lex: """Lexical analysis routines. Static methods and attributes only.""" prev_element = None @@ -1377,7 +1198,7 @@ class Document(object): Set implicit attributes and attributes in 'attrs'. """ t = time.time() - self.attributes['localdate'], self.attributes['localtime'] = date_time_str(t) + self.attributes['localdate'], self.attributes['localtime'] = utils.date_time_str(t) self.attributes['asciidoc-module'] = 'asciidoc' self.attributes['asciidoc-version'] = VERSION self.attributes['asciidoc-confdir'] = CONF_DIR @@ -1403,7 +1224,7 @@ class Document(object): else: t = None if t: - self.attributes['docdate'], self.attributes['doctime'] = date_time_str(t) + self.attributes['docdate'], self.attributes['doctime'] = utils.date_time_str(t) if self.infile != '': self.attributes['infile'] = self.infile self.attributes['indir'] = os.path.dirname(self.infile) @@ -2055,7 +1876,7 @@ class Title: if len(lines) < 2: return False title, ul = lines[:2] - title_len = column_width(title) + title_len = utils.column_width(title) ul_len = len(ul) if ul_len < 2: return False @@ -2118,13 +1939,13 @@ class Title: Title.dump_dict['subs'] = entries['subs'] if 'sectiontitle' in entries: pat = entries['sectiontitle'] - if not pat or not is_re(pat): + if not pat or not utils.is_re(pat): raise EAsciiDoc('malformed [titles] sectiontitle entry') Title.pattern = pat Title.dump_dict['sectiontitle'] = pat if 'blocktitle' in entries: pat = entries['blocktitle'] - if not pat or not is_re(pat): + if not pat or not utils.is_re(pat): raise EAsciiDoc('malformed [titles] blocktitle entry') BlockTitle.pattern = pat Title.dump_dict['blocktitle'] = pat @@ -2132,7 +1953,7 @@ class Title: for k in ('sect0', 'sect1', 'sect2', 'sect3', 'sect4'): if k in entries: pat = entries[k] - if not pat or not is_re(pat): + if not pat or not utils.is_re(pat): raise EAsciiDoc('malformed [titles] %s entry' % k) Title.dump_dict[k] = pat # TODO: Check we have either a Title.pattern or at least one @@ -2252,7 +2073,7 @@ class Section: if 'ascii-ids' in document.attributes: # Replace non-ASCII characters with ASCII equivalents. try: - from trans import trans + from trans import trans # pyright: reportMissingImports=false base_id = trans(base_id) except ImportError: base_id = unicodedata.normalize('NFKD', base_id).encode('ascii', 'ignore').decode('ascii') @@ -2435,7 +2256,7 @@ class AbstractBlock: v = parse_options(v, SUBS_OPTIONS, msg % (k, v)) copy(dst, k, v) elif k == 'delimiter': - if v and is_re(v): + if v and utils.is_re(v): copy(dst, k, v) else: raise EAsciiDoc(msg % (k, v)) @@ -2621,7 +2442,7 @@ class AbstractBlock: def check_array_parameter(param): # Check the parameter is a sequence type. - if not is_array(self.parameters[param]): + if not utils.is_array(self.parameters[param]): message.error('malformed %s parameter: %s' % (param, self.parameters[param])) # Revert to default value. self.parameters[param] = getattr(self, param) @@ -2724,7 +2545,7 @@ class AbstractBlocks: b.validate() if b.delimiter: delimiters.append(b.delimiter) - self.delimiters = re_join(delimiters) + self.delimiters = utils.re_join(delimiters) class Paragraph(AbstractBlock): @@ -2773,7 +2594,7 @@ class Paragraph(AbstractBlock): body = Lex.subs(body, postsubs) etag = config.section2tags(template, self.attributes, skipstart=True)[1] # Write start tag, content, end tag. - writer.write(dovetail_tags(stag, body, etag), trace='paragraph') + writer.write(utils.dovetail_tags(stag, body, etag), trace='paragraph') class Paragraphs(AbstractBlocks): @@ -3157,7 +2978,7 @@ class DelimitedBlock(AbstractBlock): body = Lex.subs(body, postsubs) # Write start tag, content, end tag. etag = config.section2tags(template, self.attributes, skipstart=True)[1] - writer.write(dovetail_tags(stag, body, etag), trace=name) + writer.write(utils.dovetail_tags(stag, body, etag), trace=name) trace(self.short_name() + ' block close', etag) if reader.eof(): self.error('missing closing delimiter', self.start) @@ -3335,7 +3156,7 @@ class Table(AbstractBlock): self.error('illegal csv separator=%s' % separator) separator = ',' else: - if not is_re(separator): + if not utils.is_re(separator): self.error('illegal regular expression: separator=%s' % separator) self.parameters.format = format self.parameters.tags = tags @@ -3433,14 +3254,14 @@ class Table(AbstractBlock): col.pcwidth = (float(col.width) / props) * 100 col.abswidth = self.abswidth * (col.pcwidth / 100) if config.pageunits in ('cm', 'mm', 'in', 'em'): - col.abswidth = '%.2f' % py2round(col.abswidth, 2) + col.abswidth = '%.2f' % utils.py2round(col.abswidth, 2) else: - col.abswidth = '%d' % py2round(col.abswidth) + col.abswidth = '%d' % utils.py2round(col.abswidth) percents += col.pcwidth col.pcwidth = int(col.pcwidth) - if py2round(percents) > 100: + if utils.py2round(percents) > 100: self.error('total width exceeds 100%%: %s' % cols, self.start) - elif py2round(percents) < 100: + elif utils.py2round(percents) < 100: self.error('total width less than 100%%: %s' % cols, self.start) def build_colspecs(self): @@ -3601,7 +3422,7 @@ class Table(AbstractBlock): text = '\n'.join(data).strip() data = [] for para in re.split(r'\n{2,}', text): - data += dovetail_tags([stag], para.split('\n'), [etag]) + data += utils.dovetail_tags([stag], para.split('\n'), [etag]) if rowtype == 'header': dtag = tags.headdata elif rowtype == 'footer': @@ -3609,7 +3430,7 @@ class Table(AbstractBlock): else: dtag = tags.bodydata stag, etag = subs_tag(dtag, self.attributes) - result = result + dovetail_tags([stag], data, [etag]) + result = result + utils.dovetail_tags([stag], data, [etag]) i += cell.span return result @@ -3995,7 +3816,7 @@ class Macro: self.name = None self.pattern = entry return - if not is_re(e[0]): + if not utils.is_re(e[0]): raise EAsciiDoc('illegal macro regular expression: %s' % e[0]) pattern, name = e if name and name[0] in ('+', '#'): @@ -4344,7 +4165,7 @@ class Reader1: return result # Clone self and set as parent (self assumes the role of child). parent = Reader1() - assign(parent, self) + utils.assign(parent, self) self.parent = parent # Set attributes in child. if 'tabsize' in attrs: @@ -4394,7 +4215,7 @@ class Reader1: # End of current file. if self.parent: self.closefile() - assign(self, self.parent) # Restore parent reader. + utils.assign(self, self.parent) # Restore parent reader. document.attributes['infile'] = self.infile document.attributes['indir'] = self.indir return Reader1.eof(self) @@ -4635,7 +4456,7 @@ class Writer: self.lines_out = self.lines_out + 1 else: for arg in args: - if is_array(arg): + if utils.is_array(arg): for s in arg: self.write_line(s) elif arg is not None: @@ -4738,7 +4559,7 @@ class Config: message.stderr('FAILED: Python %d.%d or better required' % MIN_PYTHON_VERSION) sys.exit(1) global USER_DIR - USER_DIR = userdir() + USER_DIR = utils.userdir() if USER_DIR is not None: USER_DIR = os.path.join(USER_DIR, '.asciidoc') if not os.path.isdir(USER_DIR): @@ -5168,8 +4989,8 @@ class Config: d = {} parse_entries(self.sections.get('specialsections', ()), d, unquote=True) for pat, sectname in list(d.items()): - pat = strip_quotes(pat) - if not is_re(pat): + pat = utils.strip_quotes(pat) + if not utils.is_re(pat): raise EAsciiDoc('[specialsections] entry is not a valid regular expression: %s' % pat) if sectname is None: if pat in self.specialsections: @@ -5189,14 +5010,14 @@ class Config: @staticmethod def set_replacement(pat, rep, replacements): """Add pattern and replacement to replacements dictionary.""" - pat = strip_quotes(pat) - if not is_re(pat): + pat = utils.strip_quotes(pat) + if not utils.is_re(pat): return False if rep is None: if pat in replacements: del replacements[pat] else: - replacements[pat] = strip_quotes(rep) + replacements[pat] = utils.strip_quotes(rep) return True def subs_replacements(self, s, sect='replacements'): @@ -5224,8 +5045,8 @@ class Config: else: words = reo.findall(wordlist) for word in words: - word = strip_quotes(word) - if not is_re(word): + word = utils.strip_quotes(word) + if not utils.is_re(word): raise EAsciiDoc('[specialwords] entry in %s ' 'is not a valid regular expression: %s' % (self.fname, word)) self.specialwords[word] = name @@ -5841,7 +5662,7 @@ class Tables_OLD(AbstractBlocks): b.headdata = b.bodydata if not b.footdata: b.footdata = b.bodydata - self.delimiters = re_join(delimiters) + self.delimiters = utils.re_join(delimiters) # Check table definitions are valid. for b in self.blocks: b.validate() @@ -5951,7 +5772,7 @@ class Plugin: Return plugins path (.asciidoc/filters or .asciidoc/themes) in user's home directory or None if user home not defined. """ - result = userdir() + result = utils.userdir() if result: result = os.path.join(result, '.asciidoc', Plugin.type + 's') return result diff --git a/asciidoc/utils.py b/asciidoc/utils.py new file mode 100644 index 0000000..fee2c00 --- /dev/null +++ b/asciidoc/utils.py @@ -0,0 +1,185 @@ +import locale +import math +import os +import re +import time +from typing import Optional +import unicodedata + + +def userdir() -> Optional[str]: + """ + Return user's home directory or None if it is not defined. + """ + result = os.path.expanduser('~') + if result == '~': + result = None + return result + + +def file_in(fname, directory) -> bool: + """Return True if file fname resides inside directory.""" + assert os.path.isfile(fname) + # Empty directory (not to be confused with None) is the current directory. + if directory == '': + directory = os.getcwd() + else: + assert os.path.isdir(directory) + directory = os.path.realpath(directory) + fname = os.path.realpath(fname) + return os.path.commonprefix((directory, fname)) == directory + + +def assign(dst, src): + """Assign all attributes from 'src' object to 'dst' object.""" + for a, v in list(src.__dict__.items()): + setattr(dst, a, v) + + +def strip_quotes(s): + """Trim white space and, if necessary, quote characters from s.""" + s = s.strip() + # Strip quotation mark characters from quoted strings. + if len(s) >= 3 and s[0] == '"' and s[-1] == '"': + s = s[1:-1] + return s + + +def is_re(s) -> bool: + """Return True if s is a valid regular expression else return False.""" + try: + re.compile(s) + return True + except BaseException: + return False + + +def re_join(relist): + """Join list of regular expressions re1,re2,... to single regular + expression (re1)|(re2)|...""" + if len(relist) == 0: + return None + result = [] + # Delete named groups to avoid ambiguity. + for s in relist: + result.append(re.sub(r'\?P<\S+?>', '', s)) + result = ')|('.join(result) + result = '(' + result + ')' + return result + + +def lstrip_list(s): + """ + Return list with empty items from start of list removed. + """ + for i in range(len(s)): + if s[i]: + break + else: + return [] + return s[i:] + + +def rstrip_list(s): + """ + Return list with empty items from end of list removed. + """ + for i in range(len(s) - 1, -1, -1): + if s[i]: + break + else: + return [] + return s[:i + 1] + + +def strip_list(s): + """ + Return list with empty items from start and end of list removed. + """ + s = lstrip_list(s) + s = rstrip_list(s) + return s + + +def is_array(obj) -> bool: + """ + Return True if object is list or tuple type. + """ + return isinstance(obj, list) or isinstance(obj, tuple) + + +def dovetail(lines1, lines2): + """ + Append list or tuple of strings 'lines2' to list 'lines1'. Join the last + non-blank item in 'lines1' with the first non-blank item in 'lines2' into a + single string. + """ + assert is_array(lines1) + assert is_array(lines2) + lines1 = strip_list(lines1) + lines2 = strip_list(lines2) + if not lines1 or not lines2: + return list(lines1) + list(lines2) + result = list(lines1[:-1]) + result.append(lines1[-1] + lines2[0]) + result += list(lines2[1:]) + return result + + +def dovetail_tags(stag, content, etag): + """Merge the end tag with the first content line and the last + content line with the end tag. This ensures verbatim elements don't + include extraneous opening and closing line breaks.""" + return dovetail(dovetail(stag, content), etag) + + +def py2round(n, d=0): + """Utility function to get python2 rounding in python3. Python3 changed it such that + given two equally close multiples, it'll round towards the even choice. For example, + round(42.5) == 42 instead of the expected round(42.5) == 43). This function gives us + back that functionality.""" + p = 10 ** d + return float(math.floor((n * p) + math.copysign(0.5, n))) / p + + +east_asian_widths = { + 'W': 2, # Wide + 'F': 2, # Full-width (wide) + 'Na': 1, # Narrow + 'H': 1, # Half-width (narrow) + 'N': 1, # Neutral (not East Asian, treated as narrow) + 'A': 1, # Ambiguous (s/b wide in East Asian context, narrow otherwise, but that + # doesn't work) +} +"""Mapping of result codes from `unicodedata.east_asian_width()` to character +column widths.""" + + +def column_width(s): + width = 0 + for c in s: + width += east_asian_widths[unicodedata.east_asian_width(c)] + return width + + +def date_time_str(t): + """Convert seconds since the Epoch to formatted local date and time strings.""" + source_date_epoch = os.environ.get('SOURCE_DATE_EPOCH') + if source_date_epoch is not None: + t = time.gmtime(min(t, int(source_date_epoch))) + else: + t = time.localtime(t) + date_str = time.strftime('%Y-%m-%d', t) + time_str = time.strftime('%H:%M:%S', t) + if source_date_epoch is not None: + time_str += ' UTC' + elif time.daylight and t.tm_isdst == 1: + time_str += ' ' + time.tzname[1] + else: + time_str += ' ' + time.tzname[0] + # Attempt to convert the localtime to the output encoding. + try: + time_str = time_str.decode(locale.getdefaultlocale()[1]) + except Exception: + pass + return date_str, time_str diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..eb32fbf --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,103 @@ +import pytest +from pytest_mock import MockerFixture +from typing import Optional, Tuple + +from asciidoc import utils + + +@pytest.mark.parametrize( + "input,expected", + ( + ('/home/user', '/home/user'), + ('~', None), + ) +) +def test_userdir(mocker: MockerFixture, input: str, expected: Optional[str]) -> None: + mocker.patch('os.path.expanduser', return_value=input) + assert utils.userdir() == expected + + +@pytest.mark.parametrize( + "input,expected", + ( + (' a ', 'a'), + ('"a"', 'a'), + (' "b ', '"b'), + (' b" ', 'b"'), + ('""', '""'), + ), +) +def test_strip_quotes(input: str, expected: str) -> None: + assert utils.strip_quotes(input) == expected + + +@pytest.mark.parametrize( + "input,expected", + ( + (('a', 'b'), ('a', 'b')), + (('', 'a', 'b'), ('a', 'b')), + (('a', 'b', ''), ('a', 'b', '')), + (('', 'a', 'b', ''), ('a', 'b', '')), + ), +) +def test_lstrip_list(input: Tuple[str, ...], expected: Tuple[str, ...]) -> None: + assert utils.lstrip_list(input) == expected + + +@pytest.mark.parametrize( + "input,expected", + ( + (('a', 'b'), ('a', 'b')), + (('', 'a', 'b'), ('', 'a', 'b')), + (('a', 'b', ''), ('a', 'b')), + (('', 'a', 'b', ''), ('', 'a', 'b')), + ), +) +def test_rstrip_list(input: Tuple[str, ...], expected: Tuple[str, ...]) -> None: + assert utils.rstrip_list(input) == expected + + +@pytest.mark.parametrize( + "input,expected", + ( + (('a', 'b'), ('a', 'b')), + (('', 'a', 'b'), ('a', 'b')), + (('a', 'b', ''), ('a', 'b')), + (('', 'a', 'b', ''), ('a', 'b')), + ), +) +def test_strip_list(input: Tuple[str, ...], expected: Tuple[str, ...]) -> None: + assert utils.strip_list(input) == expected + + +@pytest.mark.parametrize( + "input,expected", + ( + ((1,), True), + ([1], True), + ('a', False), + ), +) +def test_is_array(input, expected): + assert utils.is_array(input) == expected + + +@pytest.mark.parametrize( + "n,d,expected", + ( + (42.0, 0, 42), + (42.4, 0, 42), + (42.5, 0, 43), + (42.6, 0, 43), + (42.9, 0, 43), + (42.0, 2, 42), + (42.5, 2, 42.5), + (42.550, 2, 42.55), + (42.554, 2, 42.55), + (42.555, 2, 42.56), + (42.556, 2, 42.56), + (42.559, 2, 42.56), + ), +) +def test_py2round(n: float, d: int, expected: float) -> None: + assert utils.py2round(n, d) == expected -- cgit v1.2.1