diff options
-rwxr-xr-x | redis/connection.py | 9 | ||||
-rw-r--r-- | tests/test_connection.py | 34 |
2 files changed, 41 insertions, 2 deletions
diff --git a/redis/connection.py b/redis/connection.py index 6c4494b..1bb8eb5 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -717,9 +717,14 @@ class Connection: self._parser.on_disconnect() if self._sock is None: return - try: - if os.getpid() == self.pid: + + if os.getpid() == self.pid: + try: self._sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass + + try: self._sock.close() except OSError: pass diff --git a/tests/test_connection.py b/tests/test_connection.py index 22f1b71..d94a815 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -3,6 +3,7 @@ from unittest import mock import pytest +from redis.connection import Connection from redis.exceptions import InvalidResponse from redis.utils import HIREDIS_AVAILABLE @@ -40,3 +41,36 @@ def test_loading_external_modules(modclient): # mod = j(modclient) # mod.set("fookey", ".", d) # assert mod.get('fookey') == d + + +class TestConnection: + def test_disconnect(self): + conn = Connection() + mock_sock = mock.Mock() + conn._sock = mock_sock + conn.disconnect() + mock_sock.shutdown.assert_called_once() + mock_sock.close.assert_called_once() + assert conn._sock is None + + def test_disconnect__shutdown_OSError(self): + """An OSError on socket shutdown will still close the socket.""" + conn = Connection() + mock_sock = mock.Mock() + conn._sock = mock_sock + conn._sock.shutdown.side_effect = OSError + conn.disconnect() + mock_sock.shutdown.assert_called_once() + mock_sock.close.assert_called_once() + assert conn._sock is None + + def test_disconnect__close_OSError(self): + """An OSError on socket close will still clear out the socket.""" + conn = Connection() + mock_sock = mock.Mock() + conn._sock = mock_sock + conn._sock.close.side_effect = OSError + conn.disconnect() + mock_sock.shutdown.assert_called_once() + mock_sock.close.assert_called_once() + assert conn._sock is None |