diff options
-rwxr-xr-x | redis/connection.py | 16 | ||||
-rw-r--r-- | tests/test_multiprocessing.py | 154 |
2 files changed, 159 insertions, 11 deletions
diff --git a/redis/connection.py b/redis/connection.py index 0d1c394..ee0b92a 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -276,9 +276,7 @@ class PythonParser(BaseParser): def on_disconnect(self): "Called when the socket disconnects" - if self._sock is not None: - self._sock.close() - self._sock = None + self._sock = None if self._buffer is not None: self._buffer.close() self._buffer = None @@ -473,12 +471,6 @@ class Connection(object): def __repr__(self): return self.description_format % self._description_args - def __del__(self): - try: - self.disconnect() - except Exception: - pass - def register_connect_callback(self, callback): self._connect_callbacks.append(callback) @@ -582,7 +574,8 @@ class Connection(object): if self._sock is None: return try: - self._sock.shutdown(socket.SHUT_RDWR) + if os.getpid() == self.pid: + self._sock.shutdown(socket.SHUT_RDWR) self._sock.close() except socket.error: pass @@ -975,7 +968,6 @@ class ConnectionPool(object): # another thread already did the work while we waited # on the lock. return - self.disconnect() self.reset() def get_connection(self, command_name, *keys, **options): @@ -1014,6 +1006,7 @@ class ConnectionPool(object): def disconnect(self): "Disconnects all connections in the pool" + self._checkpid() all_conns = chain(self._available_connections, self._in_use_connections) for connection in all_conns: @@ -1135,5 +1128,6 @@ class BlockingConnectionPool(ConnectionPool): def disconnect(self): "Disconnects all connections in the pool." + self._checkpid() for connection in self._connections: connection.disconnect() diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py new file mode 100644 index 0000000..bb31a06 --- /dev/null +++ b/tests/test_multiprocessing.py @@ -0,0 +1,154 @@ +import pytest +import multiprocessing +import contextlib + +from redis.connection import Connection, ConnectionPool +from redis.exceptions import ConnectionError + + +@contextlib.contextmanager +def exit_callback(callback, *args): + try: + yield + finally: + callback(*args) + + +class TestMultiprocessing(object): + # Test connection sharing between forks. + # See issue #1085 for details. + + def test_close_connection_in_child(self): + """ + A connection owned by a parent and closed by a child doesn't + destroy the file descriptors so a parent can still use it. + """ + conn = Connection() + conn.send_command('ping') + assert conn.read_response() == b'PONG' + + def target(conn): + conn.send_command('ping') + assert conn.read_response() == b'PONG' + conn.disconnect() + + proc = multiprocessing.Process(target=target, args=(conn,)) + proc.start() + proc.join(3) + assert proc.exitcode is 0 + + # The connection was created in the parent but disconnected in the + # child. The child called socket.close() but did not call + # socket.shutdown() because it wasn't the "owning" process. + # Therefore the connection still works in the parent. + conn.send_command('ping') + assert conn.read_response() == b'PONG' + + def test_close_connection_in_parent(self): + """ + A connection owned by a parent is unusable by a child if the parent + (the owning process) closes the connection. + """ + conn = Connection() + conn.send_command('ping') + assert conn.read_response() == b'PONG' + + def target(conn, ev): + ev.wait() + # the parent closed the connection. because it also created the + # connection, the connection is shutdown and the child + # cannot use it. + with pytest.raises(ConnectionError): + conn.send_command('ping') + + ev = multiprocessing.Event() + proc = multiprocessing.Process(target=target, args=(conn, ev)) + proc.start() + + conn.disconnect() + ev.set() + + proc.join(3) + assert proc.exitcode is 0 + + @pytest.mark.parametrize('max_connections', [1, 2, None]) + def test_pool(self, max_connections): + """ + A child will create its own connections when using a pool created + by a parent. + """ + pool = ConnectionPool.from_url('redis://localhost', + max_connections=max_connections) + + conn = pool.get_connection('ping') + main_conn_pid = conn.pid + with exit_callback(pool.release, conn): + conn.send_command('ping') + assert conn.read_response() == b'PONG' + + def target(pool): + with exit_callback(pool.disconnect): + conn = pool.get_connection('ping') + assert conn.pid != main_conn_pid + with exit_callback(pool.release, conn): + assert conn.send_command('ping') is None + assert conn.read_response() == b'PONG' + + proc = multiprocessing.Process(target=target, args=(pool,)) + proc.start() + proc.join(3) + assert proc.exitcode is 0 + + # Check that connection is still alive after fork process has exited + # and disconnected the connections in its pool + conn = pool.get_connection('ping') + with exit_callback(pool.release, conn): + assert conn.send_command('ping') is None + assert conn.read_response() == b'PONG' + + @pytest.mark.parametrize('max_connections', [1, 2, None]) + def test_close_pool_in_main(self, max_connections): + """ + A child process that uses the same pool as its parent isn't affected + when the parent disconnects all connections within the pool. + """ + pool = ConnectionPool.from_url('redis://localhost', + max_connections=max_connections) + + conn = pool.get_connection('ping') + assert conn.send_command('ping') is None + assert conn.read_response() == b'PONG' + + def target(pool, disconnect_event): + conn = pool.get_connection('ping') + with exit_callback(pool.release, conn): + assert conn.send_command('ping') is None + assert conn.read_response() == b'PONG' + disconnect_event.wait() + assert conn.send_command('ping') is None + assert conn.read_response() == b'PONG' + + ev = multiprocessing.Event() + + proc = multiprocessing.Process(target=target, args=(pool, ev)) + proc.start() + + pool.disconnect() + ev.set() + proc.join(3) + assert proc.exitcode is 0 + + def test_redis_client(self, r): + "A redis client created in a parent can also be used in a child" + assert r.ping() is True + + def target(client): + assert client.ping() is True + del client + + proc = multiprocessing.Process(target=target, args=(r,)) + proc.start() + proc.join(3) + assert proc.exitcode is 0 + + assert r.ping() is True |