summaryrefslogtreecommitdiff
path: root/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'test.py')
-rw-r--r--test.py112
1 files changed, 111 insertions, 1 deletions
diff --git a/test.py b/test.py
index 98debff..2433256 100644
--- a/test.py
+++ b/test.py
@@ -1,8 +1,13 @@
# -*- coding: utf-8 -*-
-
+import io
+import os
+import sys
import unittest
+from contextlib import contextmanager
+
from slugify import slugify
from slugify import smart_truncate
+from slugify.__main__ import slugify_params, parse_args
class TestSlugification(unittest.TestCase):
@@ -242,5 +247,110 @@ class TestUtils(unittest.TestCase):
self.assertEqual(r, txt)
+PY3 = sys.version_info.major == 3
+
+
+@contextmanager
+def captured_stderr():
+ backup = sys.stderr
+ sys.stderr = io.StringIO() if PY3 else io.BytesIO()
+ try:
+ yield sys.stderr
+ finally:
+ sys.stderr = backup
+
+
+@contextmanager
+def loaded_stdin(contents):
+ backup = sys.stdin
+ sys.stdin = io.StringIO(contents) if PY3 else io.BytesIO(contents)
+ try:
+ yield sys.stdin
+ finally:
+ sys.stdin = backup
+
+
+class TestCommandParams(unittest.TestCase):
+ DEFAULTS = {
+ 'entities': True,
+ 'decimal': True,
+ 'hexadecimal': True,
+ 'max_length': 0,
+ 'word_boundary': False,
+ 'save_order': False,
+ 'separator': '-',
+ 'stopwords': None,
+ 'lowercase': True,
+ 'replacements': None
+ }
+
+ def get_params_from_cli(self, *argv):
+ args = parse_args([None] + list(argv))
+ return slugify_params(args)
+
+ def make_params(self, **values):
+ return dict(self.DEFAULTS, **values)
+
+ def assertParamsMatch(self, expected, checked):
+ reduced_checked = {}
+ for key in expected.keys():
+ reduced_checked[key] = checked[key]
+ self.assertEqual(expected, reduced_checked)
+
+ def test_defaults(self):
+ params = self.get_params_from_cli()
+ self.assertParamsMatch(self.DEFAULTS, params)
+
+ def test_negative_flags(self):
+ params = self.get_params_from_cli('--no-entities', '--no-decimal', '--no-hexadecimal', '--no-lowercase')
+ expected = self.make_params(entities=False, decimal=False, hexadecimal=False, lowercase=False)
+ self.assertFalse(expected['lowercase'])
+ self.assertFalse(expected['word_boundary'])
+ self.assertParamsMatch(expected, params)
+
+ def test_affirmative_flags(self):
+ params = self.get_params_from_cli('--word-boundary', '--save-order')
+ expected = self.make_params(word_boundary=True, save_order=True)
+ self.assertParamsMatch(expected, params)
+
+ def test_valued_arguments(self):
+ params = self.get_params_from_cli('--stopwords', 'abba', 'beatles', '--max-length', '98', '--separator', '+')
+ expected = self.make_params(stopwords=['abba', 'beatles'], max_length=98, separator='+')
+ self.assertParamsMatch(expected, params)
+
+ def test_replacements_right(self):
+ params = self.get_params_from_cli('--replacements', 'A->B', 'C->D')
+ expected = self.make_params(replacements=[['A', 'B'], ['C', 'D']])
+ self.assertParamsMatch(expected, params)
+
+ def test_replacements_wrong(self):
+ with self.assertRaises(SystemExit) as err, captured_stderr() as cse:
+ self.get_params_from_cli('--replacements', 'A--B')
+ self.assertEqual(err.exception.code, 2)
+ self.assertIn("Replacements must be of the form: ORIGINAL->REPLACED", cse.getvalue())
+
+ def test_text_in_cli(self):
+ params = self.get_params_from_cli('Cool Text')
+ expected = self.make_params(text='Cool Text')
+ self.assertParamsMatch(expected, params)
+
+ def test_text_in_cli_multi(self):
+ params = self.get_params_from_cli('Cool', 'Text')
+ expected = self.make_params(text='Cool Text')
+ self.assertParamsMatch(expected, params)
+
+ def test_text_in_stdin(self):
+ with loaded_stdin("Cool Stdin"):
+ params = self.get_params_from_cli('--stdin')
+ expected = self.make_params(text='Cool Stdin')
+ self.assertParamsMatch(expected, params)
+
+ def test_two_text_sources_fails(self):
+ with self.assertRaises(SystemExit) as err, captured_stderr() as cse:
+ self.get_params_from_cli('--stdin', 'Text')
+ self.assertEqual(err.exception.code, 2)
+ self.assertIn("Input strings and --stdin cannot work together", cse.getvalue())
+
+
if __name__ == '__main__':
unittest.main()