summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rwxr-xr-xbin/ansible-connection58
-rw-r--r--lib/ansible/plugins/connection/network_cli.py60
-rw-r--r--lib/ansible/plugins/connection/persistent.py16
-rw-r--r--lib/ansible/plugins/terminal/__init__.py41
-rw-r--r--lib/ansible/plugins/terminal/ios.py51
-rw-r--r--test/units/plugins/connection/test_network_cli.py30
6 files changed, 145 insertions, 111 deletions
diff --git a/bin/ansible-connection b/bin/ansible-connection
index 10449115f6..676cd70223 100755
--- a/bin/ansible-connection
+++ b/bin/ansible-connection
@@ -45,7 +45,8 @@ from io import BytesIO
from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
-from ansible.module_utils.six.moves import cPickle, StringIO
+from ansible.module_utils.six import PY3
+from ansible.module_utils.six.moves import cPickle
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
@@ -73,11 +74,11 @@ def do_fork():
sys.exit(0)
if C.DEFAULT_LOG_PATH != '':
- out_file = file(C.DEFAULT_LOG_PATH, 'a+')
- err_file = file(C.DEFAULT_LOG_PATH, 'a+', 0)
+ out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
+ err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
else:
- out_file = file('/dev/null', 'a+')
- err_file = file('/dev/null', 'a+', 0)
+ out_file = open('/dev/null', 'ab+')
+ err_file = open('/dev/null', 'ab+', 0)
os.dup2(out_file.fileno(), sys.stdout.fileno())
os.dup2(err_file.fileno(), sys.stderr.fileno())
@@ -90,7 +91,7 @@ def do_fork():
sys.exit(1)
def send_data(s, data):
- packed_len = struct.pack('!Q',len(data))
+ packed_len = struct.pack('!Q', len(data))
return s.sendall(packed_len + data)
def recv_data(s):
@@ -101,7 +102,7 @@ def recv_data(s):
if not d:
return None
data += d
- data_len = struct.unpack('!Q',data[:header_len])[0]
+ data_len = struct.unpack('!Q', data[:header_len])[0]
data = data[header_len:]
while len(data) < data_len:
d = s.recv(data_len - len(data))
@@ -211,11 +212,9 @@ class Server():
pass
elif data.startswith(b'CONTEXT: '):
display.display("socket operation is CONTEXT", log_only=True)
- pc_data = data.split(b'CONTEXT: ')[1]
+ pc_data = data.split(b'CONTEXT: ', 1)[1]
- src = StringIO(pc_data)
- pc_data = cPickle.load(src)
- src.close()
+ pc_data = cPickle.loads(pc_data)
pc = PlayContext()
pc.deserialize(pc_data)
@@ -234,12 +233,12 @@ class Server():
display.display("socket operation completed with rc %s" % rc, log_only=True)
- send_data(s, to_bytes(str(rc)))
+ send_data(s, to_bytes(rc))
send_data(s, to_bytes(stdout))
send_data(s, to_bytes(stderr))
s.close()
except Exception as e:
- display.display(traceback.format_exec(), log_only=True)
+ display.display(traceback.format_exc(), log_only=True)
finally:
# when done, close the connection properly and cleanup
# the socket file so it can be recreated
@@ -254,21 +253,25 @@ class Server():
os.remove(self.path)
def main():
+ # Need stdin as a byte stream
+ if PY3:
+ stdin = sys.stdin.buffer
+ else:
+ stdin = sys.stdin
try:
# read the play context data via stdin, which means depickling it
# FIXME: as noted above, we will probably need to deserialize the
# connection loader here as well at some point, otherwise this
# won't find role- or playbook-based connection plugins
- cur_line = sys.stdin.readline()
- init_data = ''
- while cur_line.strip() != '#END_INIT#':
- if cur_line == '':
- raise Exception("EOL found before init data was complete")
+ cur_line = stdin.readline()
+ init_data = b''
+ while cur_line.strip() != b'#END_INIT#':
+ if cur_line == b'':
+ raise Exception("EOF found before init data was complete")
init_data += cur_line
- cur_line = sys.stdin.readline()
- src = BytesIO(to_bytes(init_data))
- pc_data = cPickle.load(src)
+ cur_line = stdin.readline()
+ pc_data = cPickle.loads(init_data)
pc = PlayContext()
pc.deserialize(pc_data)
@@ -319,10 +322,10 @@ def main():
# the connection will timeout here. Need to make this more resilient.
rc = 0
while rc == 0:
- data = sys.stdin.readline()
- if data == '':
+ data = stdin.readline()
+ if data == b'':
break
- if data.strip() == '':
+ if data.strip() == b'':
continue
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
attempts = 1
@@ -342,11 +345,10 @@ def main():
# send the play_context back into the connection so the connection
# can handle any privilege escalation activities
- pc_data = 'CONTEXT: %s' % src.getvalue()
- send_data(sf, to_bytes(pc_data))
- src.close()
+ pc_data = b'CONTEXT: %s' % init_data
+ send_data(sf, pc_data)
- send_data(sf, to_bytes(data.strip()))
+ send_data(sf, data.strip())
rc = int(recv_data(sf), 10)
stdout = recv_data(sf)
diff --git a/lib/ansible/plugins/connection/network_cli.py b/lib/ansible/plugins/connection/network_cli.py
index 137884104f..6f004c00fc 100644
--- a/lib/ansible/plugins/connection/network_cli.py
+++ b/lib/ansible/plugins/connection/network_cli.py
@@ -18,17 +18,18 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
-import re
-import socket
import json
+import logging
+import re
import signal
-import datetime
+import socket
import traceback
-import logging
+from collections import Sequence
from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure
-from ansible.module_utils.six.moves import StringIO
+from ansible.module_utils.six import BytesIO, binary_type, text_type
+from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins import terminal_loader
from ansible.plugins.connection import ensure_connect
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
@@ -113,7 +114,7 @@ class Connection(_Connection):
self._terminal.on_authorize(passwd=auth_pass)
display.display('shell successfully opened', log_only=True)
- return (0, 'ok', '')
+ return (0, b'ok', b'')
def close(self):
display.display('closing connection', log_only=True)
@@ -131,11 +132,11 @@ class Connection(_Connection):
self._shell.close()
self._shell = None
- return (0, 'ok', '')
+ return (0, b'ok', b'')
def receive(self, obj=None):
"""Handles receiving of output from command"""
- recv = StringIO()
+ recv = BytesIO()
handled = False
self._matched_prompt = None
@@ -162,30 +163,30 @@ class Connection(_Connection):
try:
command = obj['command']
self._history.append(command)
- self._shell.sendall('%s\r' % command)
+ self._shell.sendall(b'%s\r' % command)
if obj.get('sendonly'):
return
return self.receive(obj)
- except (socket.timeout, AttributeError) as exc:
+ except (socket.timeout, AttributeError):
display.display(traceback.format_exc(), log_only=True)
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())
def _strip(self, data):
"""Removes ANSI codes from device response"""
for regex in self._terminal.ansi_re:
- data = regex.sub('', data)
+ data = regex.sub(b'', data)
return data
def _handle_prompt(self, resp, obj):
"""Matches the command prompt and responds"""
- if not isinstance(obj['prompt'], list):
+ if isinstance(obj, (binary_type, text_type)) or not isinstance(obj['prompt'], Sequence):
obj['prompt'] = [obj['prompt']]
prompts = [re.compile(r, re.I) for r in obj['prompt']]
answer = obj['answer']
for regex in prompts:
match = regex.search(resp)
if match:
- self._shell.sendall('%s\r' % answer)
+ self._shell.sendall(b'%s\r' % answer)
return True
def _sanitize(self, resp, obj=None):
@@ -196,7 +197,7 @@ class Connection(_Connection):
if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line:
continue
cleaned.append(line)
- return str("\n".join(cleaned)).strip()
+ return b"\n".join(cleaned).strip()
def _find_prompt(self, response):
"""Searches the buffered response for a matching command prompt"""
@@ -225,9 +226,9 @@ class Connection(_Connection):
def exec_command(self, cmd):
"""Executes the cmd on in the shell and returns the output
- The method accepts two forms of cmd. The first form is as a
+ The method accepts two forms of cmd. The first form is as a byte
string that represents the command to be executed in the shell. The
- second form is as a JSON string with additional keyword.
+ second form is as a utf8 JSON byte string with additional keywords.
Keywords supported for cmd:
* command - the command string to execute
@@ -235,28 +236,30 @@ class Connection(_Connection):
* answer - the string to respond to the prompt with
* sendonly - bool to disable waiting for response
- :arg cmd: the string that represents the command to be executed
- which can be a single command or a json encoded string
+ :arg cmd: the byte string that represents the command to be executed
+ which can be a single command or a json encoded string.
:returns: a tuple of (return code, stdout, stderr). The return
- code is an integer and stdout and stderr are strings
+ code is an integer and stdout and stderr are byte strings
"""
try:
- obj = json.loads(cmd)
+ obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
+ obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
except (ValueError, TypeError):
- obj = {'command': str(cmd).strip()}
+ obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}
- if obj['command'] == 'close_shell()':
+ if obj['command'] == b'close_shell()':
return self.close_shell()
- elif obj['command'] == 'open_shell()':
+ elif obj['command'] == b'open_shell()':
return self.open_shell()
- elif obj['command'] == 'prompt()':
- return (0, self._matched_prompt, '')
+ elif obj['command'] == b'prompt()':
+ return (0, self._matched_prompt, b'')
try:
if self._shell is None:
self.open_shell()
except AnsibleConnectionFailure as exc:
- return (1, '', str(exc))
+ # FIXME: Feels like we should raise this rather than return it
+ return (1, b'', to_bytes(exc))
try:
if not signal.getsignal(signal.SIGALRM):
@@ -264,6 +267,7 @@ class Connection(_Connection):
signal.alarm(self._play_context.timeout)
out = self.send(obj)
signal.alarm(0)
- return (0, out, '')
+ return (0, out, b'')
except (AnsibleConnectionFailure, ValueError) as exc:
- return (1, '', str(exc))
+ # FIXME: Feels like we should raise this rather than return it
+ return (1, b'', to_bytes(exc))
diff --git a/lib/ansible/plugins/connection/persistent.py b/lib/ansible/plugins/connection/persistent.py
index 8465d66268..fc210a9766 100644
--- a/lib/ansible/plugins/connection/persistent.py
+++ b/lib/ansible/plugins/connection/persistent.py
@@ -24,7 +24,7 @@ import subprocess
import sys
from ansible.module_utils._text import to_bytes
-from ansible.module_utils.six.moves import cPickle, StringIO
+from ansible.module_utils.six.moves import cPickle
from ansible.plugins.connection import ConnectionBase
try:
@@ -52,16 +52,20 @@ class Connection(ConnectionBase):
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)
- src = StringIO()
- cPickle.dump(self._play_context.serialize(), src)
- stdin.write(src.getvalue())
- src.close()
+ # Need to force a protocol that is compatible with both py2 and py3.
+ # That would be protocol=2 or less.
+ # Also need to force a protocol that excludes certain control chars as
+ # stdin in this case is a pty and control chars will cause problems.
+ # that means only protocol=0 will work.
+ src = cPickle.dumps(self._play_context.serialize(), protocol=0)
+ stdin.write(src)
stdin.write(b'\n#END_INIT#\n')
stdin.write(to_bytes(action))
stdin.write(b'\n\n')
- stdin.close()
+
(stdout, stderr) = p.communicate()
+ stdin.close()
return (p.returncode, stdout, stderr)
diff --git a/lib/ansible/plugins/terminal/__init__.py b/lib/ansible/plugins/terminal/__init__.py
index e8e04884eb..e52ae62273 100644
--- a/lib/ansible/plugins/terminal/__init__.py
+++ b/lib/ansible/plugins/terminal/__init__.py
@@ -30,33 +30,54 @@ from ansible.module_utils.six import with_metaclass
class TerminalBase(with_metaclass(ABCMeta, object)):
'''
A base class for implementing cli connections
+
+ .. note:: Unlike most of Ansible, nearly all strings in
+ :class:`TerminalBase` plugins are byte strings. This is because of
+ how close to the underlying platform these plugins operate. Remember
+ to mark literal strings as byte string (``b"string"``) and to use
+ :func:`~ansible.module_utils._text.to_bytes` and
+ :func:`~ansible.module_utils._text.to_text` to avoid unexpected
+ problems.
'''
- # compiled regular expression as stdout
+ #: compiled bytes regular expressions as stdout
terminal_stdout_re = []
- # compiled regular expression as stderr
+ #: compiled bytes regular expressions as stderr
terminal_stderr_re = []
- # copiled regular expression to remove ANSI codes
+ #: compiled bytes regular expressions to remove ANSI codes
ansi_re = [
- re.compile(r'(\x1b\[\?1h\x1b=)'),
- re.compile(r'\x08.')
+ re.compile(br'(\x1b\[\?1h\x1b=)'),
+ re.compile(br'\x08.')
]
def __init__(self, connection):
self._connection = connection
def _exec_cli_command(self, cmd, check_rc=True):
- """Executes a CLI command on the device"""
+ """
+ Executes a CLI command on the device
+
+ :arg cmd: Byte string consisting of the command to execute
+ :kwarg check_rc: If True, the default, raise an
+ :exc:`AnsibleConnectionFailure` if the return code from the
+ command is nonzero
+ :returns: A tuple of return code, stdout, and stderr from running the
+ command. stdout and stderr are both byte strings.
+ """
rc, out, err = self._connection.exec_command(cmd)
if check_rc and rc != 0:
raise AnsibleConnectionFailure(err)
return rc, out, err
def _get_prompt(self):
- """ Returns the current prompt from the device"""
- for cmd in ['\n', 'prompt()']:
+ """
+ Returns the current prompt from the device
+
+ :returns: A byte string of the prompt
+ """
+ for cmd in (b'\n', b'prompt()'):
rc, out, err = self._exec_cli_command(cmd)
return out
@@ -82,6 +103,8 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
def on_authorize(self, passwd=None):
"""Called when privilege escalation is requested
+ :kwarg passwd: String containing the password
+
This method is called when the privilege is requested to be elevated
in the play context by setting become to True. It is the responsibility
of the terminal plugin to actually do the privilege escalation such
@@ -94,6 +117,6 @@ class TerminalBase(with_metaclass(ABCMeta, object)):
This method is called when the privilege changed from escalated
(become=True) to non escalated (become=False). It is the responsibility
- of the this method to actually perform the deauthorization procedure
+ of this method to actually perform the deauthorization procedure
"""
pass
diff --git a/lib/ansible/plugins/terminal/ios.py b/lib/ansible/plugins/terminal/ios.py
index 4ce5dc9406..fb01e79fa3 100644
--- a/lib/ansible/plugins/terminal/ios.py
+++ b/lib/ansible/plugins/terminal/ios.py
@@ -19,49 +19,52 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type
-import re
import json
+import re
-from ansible.plugins.terminal import TerminalBase
from ansible.errors import AnsibleConnectionFailure
+from ansible.module_utils._text import to_bytes
+from ansible.plugins.terminal import TerminalBase
class TerminalModule(TerminalBase):
terminal_stdout_re = [
- re.compile(r"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
- re.compile(r"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
+ re.compile(br"[\r\n]?[\w+\-\.:\/\[\]]+(?:\([^\)]+\)){,3}(?:>|#) ?$"),
+ re.compile(br"\[\w+\@[\w\-\.]+(?: [^\]])\] ?[>#\$] ?$")
]
terminal_stderr_re = [
- re.compile(r"% ?Error"),
- #re.compile(r"^% \w+", re.M),
- re.compile(r"% ?Bad secret"),
- re.compile(r"invalid input", re.I),
- re.compile(r"(?:incomplete|ambiguous) command", re.I),
- re.compile(r"connection timed out", re.I),
- re.compile(r"[^\r\n]+ not found", re.I),
- re.compile(r"'[^']' +returned error code: ?\d+"),
+ re.compile(br"% ?Error"),
+ #re.compile(br"^% \w+", re.M),
+ re.compile(br"% ?Bad secret"),
+ re.compile(br"invalid input", re.I),
+ re.compile(br"(?:incomplete|ambiguous) command", re.I),
+ re.compile(br"connection timed out", re.I),
+ re.compile(br"[^\r\n]+ not found", re.I),
+ re.compile(br"'[^']' +returned error code: ?\d+"),
]
def on_open_shell(self):
try:
- for cmd in ['terminal length 0', 'terminal width 512']:
+ for cmd in (b'terminal length 0', b'terminal width 512'):
self._exec_cli_command(cmd)
except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to set terminal parameters')
def on_authorize(self, passwd=None):
- if self._get_prompt().endswith('#'):
+ if self._get_prompt().endswith(b'#'):
return
- cmd = {'command': 'enable'}
+ cmd = {u'command': u'enable'}
if passwd:
- cmd['prompt'] = r"[\r\n]?password: $"
- cmd['answer'] = passwd
+ # Note: python-3.5 cannot combine u"" and r"" together. Thus make
+ # an r string and use to_text to ensure it's text on both py2 and py3.
+ cmd[u'prompt'] = to_text(r"[\r\n]?password: $", errors='surrogate_or_strict')
+ cmd[u'answer'] = passwd
try:
- self._exec_cli_command(json.dumps(cmd))
+ self._exec_cli_command(to_bytes(json.dumps(cmd), errors='surrogate_or_strict'))
except AnsibleConnectionFailure:
raise AnsibleConnectionFailure('unable to elevate privilege to enable mode')
@@ -71,11 +74,9 @@ class TerminalModule(TerminalBase):
# if prompt is None most likely the terminal is hung up at a prompt
return
- if '(config' in prompt:
- self._exec_cli_command('end')
- self._exec_cli_command('disable')
-
- elif prompt.endswith('#'):
- self._exec_cli_command('disable')
-
+ if b'(config' in prompt:
+ self._exec_cli_command(b'end')
+ self._exec_cli_command(b'disable')
+ elif prompt.endswith(b'#'):
+ self._exec_cli_command(b'disable')
diff --git a/test/units/plugins/connection/test_network_cli.py b/test/units/plugins/connection/test_network_cli.py
index 818376a2eb..b9e4c8e40e 100644
--- a/test/units/plugins/connection/test_network_cli.py
+++ b/test/units/plugins/connection/test_network_cli.py
@@ -117,21 +117,21 @@ class TestConnectionClass(unittest.TestCase):
mock_open_shell = MagicMock()
conn.open_shell = mock_open_shell
- mock_send = MagicMock(return_value='command response')
+ mock_send = MagicMock(return_value=b'command response')
conn.send = mock_send
# test sending a single command and converting to dict
rc, out, err = conn.exec_command('command')
- self.assertEqual(out, 'command response')
+ self.assertEqual(out, b'command response')
self.assertTrue(mock_open_shell.called)
- mock_send.assert_called_with({'command': 'command'})
+ mock_send.assert_called_with({'command': b'command'})
mock_open_shell.reset_mock()
# test sending a json string
rc, out, err = conn.exec_command(json.dumps({'command': 'command'}))
- self.assertEqual(out, 'command response')
- mock_send.assert_called_with({'command': 'command'})
+ self.assertEqual(out, b'command response')
+ mock_send.assert_called_with({'command': b'command'})
self.assertTrue(mock_open_shell.called)
mock_open_shell.reset_mock()
@@ -139,9 +139,9 @@ class TestConnectionClass(unittest.TestCase):
# test _shell already open
rc, out, err = conn.exec_command('command')
- self.assertEqual(out, 'command response')
+ self.assertEqual(out, b'command response')
self.assertFalse(mock_open_shell.called)
- mock_send.assert_called_with({'command': 'command'})
+ mock_send.assert_called_with({'command': b'command'})
def test_network_cli_send(self):
@@ -150,14 +150,14 @@ class TestConnectionClass(unittest.TestCase):
conn = network_cli.Connection(pc, new_stdin)
mock__terminal = MagicMock()
- mock__terminal.terminal_stdout_re = [re.compile('device#')]
- mock__terminal.terminal_stderr_re = [re.compile('^ERROR')]
+ mock__terminal.terminal_stdout_re = [re.compile(b'device#')]
+ mock__terminal.terminal_stderr_re = [re.compile(b'^ERROR')]
conn._terminal = mock__terminal
mock__shell = MagicMock()
conn._shell = mock__shell
- response = """device#command
+ response = b"""device#command
command response
device#
@@ -165,15 +165,15 @@ class TestConnectionClass(unittest.TestCase):
mock__shell.recv.return_value = response
- output = conn.send({'command': 'command'})
+ output = conn.send({'command': b'command'})
- mock__shell.sendall.assert_called_with('command\r')
- self.assertEqual(output, 'command response')
+ mock__shell.sendall.assert_called_with(b'command\r')
+ self.assertEqual(output, b'command response')
mock__shell.reset_mock()
- mock__shell.recv.return_value = "ERROR: error message"
+ mock__shell.recv.return_value = b"ERROR: error message"
with self.assertRaises(AnsibleConnectionFailure) as exc:
- conn.send({'command': 'command'})
+ conn.send({'command': b'command'})
self.assertEqual(str(exc.exception), 'ERROR: error message')