summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicolas Chauvat <nicolas.chauvat@logilab.fr>2009-07-31 21:09:54 +0200
committerNicolas Chauvat <nicolas.chauvat@logilab.fr>2009-07-31 21:09:54 +0200
commit506b3cdacb4542801e5f5da1f2706e643194ef0e (patch)
treefd4f2f923a45f8d04af7d79e8debc1d32af64a87
parent31def7ddeaeedc058f50703e973f6fcf619a7f99 (diff)
downloadlogilab-common-506b3cdacb4542801e5f5da1f2706e643194ef0e.tar.gz
F [shellutils] new RawInput supersedes confirm
-rw-r--r--shellutils.py58
-rw-r--r--test/unittest_shellutils.py65
2 files changed, 110 insertions, 13 deletions
diff --git a/shellutils.py b/shellutils.py
index 67e2888..76998e6 100644
--- a/shellutils.py
+++ b/shellutils.py
@@ -278,18 +278,52 @@ class ProgressBar(object):
self._stream.write(self._fstr % ('.' * min(self._progress, self._size)) )
self._stream.flush()
+from logilab.common.deprecation import deprecated
+
+@deprecated('confirm() is deprecated, use RawInput.confirm() instead')
def confirm(question, default_is_yes=True):
"""ask for confirmation and return true on positive answer"""
- if default_is_yes:
- input_str = '%s [Y/n]: '
- else:
- input_str = '%s [y/N]: '
- answer = raw_input(input_str % (question)).strip().lower()
- if default_is_yes:
- if answer in ('n', 'no'):
- return False
- return True
- if answer in ('y', 'yes'):
- return True
- return False
+ return RawInput().confirm(question, default_is_yes)
+
+class RawInput(object):
+ def __init__(self, input=None, printer=None):
+ self._input = input or raw_input
+ self._print = printer
+
+ def ask(self, question, options, default):
+ assert default in options
+ choices = []
+ for option in options:
+ if option == default:
+ label = option[0].upper()
+ else:
+ label = option[0].lower()
+ if len(option) > 1:
+ label += '(%s)' % option[1:].lower()
+ choices.append((option, label))
+ prompt = "%s [%s]: " % (question,
+ '/'.join(opt[1] for opt in choices))
+ tries = 3
+ while tries > 0:
+ answer = self._input(prompt).strip().lower()
+ if answer:
+ possible = [option for option, label in choices
+ if option.lower().startswith(answer)]
+ if len(possible) == 1:
+ return possible[0]
+ else:
+ return default
+ msg = ('%s is an ambiguous answer, do you mean %s ?' % (
+ answer, ' or '.join(possible)))
+ if self._print:
+ self._print(msg)
+ else:
+ print msg
+ tries -= 1
+ raise Exception('unable to get a sensible answer')
+
+ def confirm(self, question, default_is_yes=True):
+ default = default_is_yes and 'y' or 'n'
+ answer = self.ask(question, ('y','n'), default)
+ return answer == 'y'
diff --git a/test/unittest_shellutils.py b/test/unittest_shellutils.py
index e0ba67f..20add35 100644
--- a/test/unittest_shellutils.py
+++ b/test/unittest_shellutils.py
@@ -6,7 +6,10 @@ import datetime, time
from logilab.common.testlib import TestCase, unittest_main
-from logilab.common.shellutils import globfind, find, ProgressBar, acquire_lock, release_lock
+from logilab.common.shellutils import (globfind, find, ProgressBar,
+ acquire_lock, release_lock,
+ RawInput, confirm)
+
from logilab.common.proc import NoSuchProcess
from StringIO import StringIO
@@ -163,6 +166,66 @@ class AcquireLockTC(TestCase):
os.system("touch -t %s %s" % (touch, self.lock))
self.assertRaises(UserWarning, acquire_lock, self.lock, max_try=2, delay=1)
+class RawInputTC(TestCase):
+
+ def auto_input(self, *args):
+ self.input_args = args
+ return self.input_answer
+
+ def setUp(self):
+ null_printer = lambda x: None
+ self.qa = RawInput(self.auto_input, null_printer)
+
+ def test_ask_default(self):
+ self.input_answer = ''
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'yes')
+ self.input_answer = ' '
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'yes')
+
+ def test_ask_case(self):
+ self.input_answer = 'no'
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'no')
+ self.input_answer = 'No'
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'no')
+ self.input_answer = 'NO'
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'no')
+ self.input_answer = 'nO'
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'no')
+ self.input_answer = 'YES'
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(answer, 'yes')
+
+ def test_ask_prompt(self):
+ self.input_answer = ''
+ answer = self.qa.ask('text', ('yes','no'), 'yes')
+ self.assertEquals(self.input_args[0], 'text [Y(es)/n(o)]: ')
+ answer = self.qa.ask('text', ('y','n'), 'y')
+ self.assertEquals(self.input_args[0], 'text [Y/n]: ')
+ answer = self.qa.ask('text', ('n','y'), 'y')
+ self.assertEquals(self.input_args[0], 'text [n/Y]: ')
+ answer = self.qa.ask('text', ('yes','no','maybe','1'), 'yes')
+ self.assertEquals(self.input_args[0], 'text [Y(es)/n(o)/m(aybe)/1]: ')
+
+ def test_ask_ambiguous(self):
+ self.input_answer = 'y'
+ self.assertRaises(Exception, self.qa.ask, 'text', ('yes','yep'), 'yes')
+
+ def test_confirm(self):
+ self.input_answer = 'y'
+ self.assertEquals(self.qa.confirm('Say yes'), True)
+ self.assertEquals(self.qa.confirm('Say yes', default_is_yes=False), True)
+ self.input_answer = 'n'
+ self.assertEquals(self.qa.confirm('Say yes'), False)
+ self.assertEquals(self.qa.confirm('Say yes', default_is_yes=False), False)
+ self.input_answer = ''
+ self.assertEquals(self.qa.confirm('Say default'), True)
+ self.assertEquals(self.qa.confirm('Say default', default_is_yes=False), False)
if __name__ == '__main__':
unittest_main()