summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Brown <paul90brown@gmail.com>2021-12-16 07:35:22 +0000
committerGitHub <noreply@github.com>2021-12-16 09:35:22 +0200
commita8b8f142399a62e64c3003adda2d9563eea95ef4 (patch)
tree448e41f01ab295ad69f93be6b72d3b6d8a8f5bad
parent82bad1686177c4c543818a8bfac35c6fdfc9ddf1 (diff)
downloadredis-py-a8b8f142399a62e64c3003adda2d9563eea95ef4.tar.gz
close socket after server disconnect (#1797)
-rwxr-xr-xredis/connection.py9
-rw-r--r--tests/test_connection.py34
2 files changed, 41 insertions, 2 deletions
diff --git a/redis/connection.py b/redis/connection.py
index 6c4494b..1bb8eb5 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -717,9 +717,14 @@ class Connection:
self._parser.on_disconnect()
if self._sock is None:
return
- try:
- if os.getpid() == self.pid:
+
+ if os.getpid() == self.pid:
+ try:
self._sock.shutdown(socket.SHUT_RDWR)
+ except OSError:
+ pass
+
+ try:
self._sock.close()
except OSError:
pass
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 22f1b71..d94a815 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -3,6 +3,7 @@ from unittest import mock
import pytest
+from redis.connection import Connection
from redis.exceptions import InvalidResponse
from redis.utils import HIREDIS_AVAILABLE
@@ -40,3 +41,36 @@ def test_loading_external_modules(modclient):
# mod = j(modclient)
# mod.set("fookey", ".", d)
# assert mod.get('fookey') == d
+
+
+class TestConnection:
+ def test_disconnect(self):
+ conn = Connection()
+ mock_sock = mock.Mock()
+ conn._sock = mock_sock
+ conn.disconnect()
+ mock_sock.shutdown.assert_called_once()
+ mock_sock.close.assert_called_once()
+ assert conn._sock is None
+
+ def test_disconnect__shutdown_OSError(self):
+ """An OSError on socket shutdown will still close the socket."""
+ conn = Connection()
+ mock_sock = mock.Mock()
+ conn._sock = mock_sock
+ conn._sock.shutdown.side_effect = OSError
+ conn.disconnect()
+ mock_sock.shutdown.assert_called_once()
+ mock_sock.close.assert_called_once()
+ assert conn._sock is None
+
+ def test_disconnect__close_OSError(self):
+ """An OSError on socket close will still clear out the socket."""
+ conn = Connection()
+ mock_sock = mock.Mock()
+ conn._sock = mock_sock
+ conn._sock.close.side_effect = OSError
+ conn.disconnect()
+ mock_sock.shutdown.assert_called_once()
+ mock_sock.close.assert_called_once()
+ assert conn._sock is None