diff options
author | rfkelly0 <rfkelly0@67cdc799-7952-0410-af00-57a81ceafa0f> | 2009-06-16 03:38:34 +0000 |
---|---|---|
committer | rfkelly0 <rfkelly0@67cdc799-7952-0410-af00-57a81ceafa0f> | 2009-06-16 03:38:34 +0000 |
commit | 9ace9f0d90b569bfcfe1d1d2cb93cf0ffa74e81f (patch) | |
tree | c3d7aaa2fd445b7169b09f35679570cbc99a6a87 /fs/sftpfs.py | |
parent | 12ca3e2c6554017726909f663c3f910278a1f407 (diff) | |
download | pyfilesystem-git-9ace9f0d90b569bfcfe1d1d2cb93cf0ffa74e81f.tar.gz |
SFTPFS: better thread-safety using a per-thread SFTPClient instance
Diffstat (limited to 'fs/sftpfs.py')
-rw-r--r-- | fs/sftpfs.py | 54 |
1 files changed, 41 insertions, 13 deletions
diff --git a/fs/sftpfs.py b/fs/sftpfs.py index 1d825b6..9dd1e70 100644 --- a/fs/sftpfs.py +++ b/fs/sftpfs.py @@ -11,6 +11,22 @@ import paramiko from fs.base import * +# SFTPClient appears to not be thread-safe, so we use an instance per thread +if hasattr(threading,"local"): + thread_local = threading.local +else: + class thread_local(object): + def __init__(self): + self._map = {} + def __getattr__(self,attr): + try: + return self._map[(threading.currentThread().ident,attr)] + except KeyError: + raise AttributeError, attr + def __setattr__(self,attr,value): + self._map[(threading.currentThread().ident,attr)] = value + + if not hasattr(paramiko.SFTPFile,"__enter__"): paramiko.SFTPFile.__enter__ = lambda self: self @@ -41,17 +57,20 @@ class SFTPFS(FS): other keyword arguments are assumed to be credentials to be used when connecting the transport. """ + self.closed = False self._owns_transport = False self._credentials = credentials + self._tlocal = thread_local() if isinstance(connection,paramiko.Channel): - self.client = paramiko.SFTPClient(connection) + self._transport = None + self._client = paramiko.SFTPClient(connection) else: if not isinstance(connection,paramiko.Transport): connection = paramiko.Transport(connection) self._owns_transport = True if not connection.is_authenticated(): connection.connect(**credentials) - self.client = paramiko.SFTPClient.from_transport(connection) + self._transport = connection self.root = abspath(normpath(root)) def __del__(self): @@ -59,28 +78,37 @@ class SFTPFS(FS): def __getstate__(self): state = super(SFTPFS,self).__getstate__() + del state["_tlocal"] if self._owns_transport: - state['client'] = self.client.get_channel().get_transport().getpeername() + state['_transport'] = self._transport.getpeername() return state def __setstate__(self,state): for (k,v) in state.iteritems(): self.__dict__[k] = v + self._tlocal = thread_local() if self._owns_transport: - t = paramiko.Transport(self.client) - t.connect(**self._credentials) - self.client = paramiko.SFTPClient.from_transport(t) + self._transport = paramiko.Transport(self._transport) + self._transport.connect(**self._credentials) + + @property + def client(self): + try: + return self._tlocal.client + except AttributeError: + if self._transport is None: + return self._client + client = paramiko.SFTPClient.from_transport(self._transport) + self._tlocal.client = client + return client def close(self): """Close the connection to the remote server.""" - if getattr(self,"client",None): - if self._owns_transport: - t = self.client.get_channel().get_transport() - self.client.close() - t.close() - else: + if not self.closed: + if self.client: self.client.close() - self.client = None + if self._owns_transport and self._transport: + self._transport.close() def _normpath(self,path): npath = pathjoin(self.root,relpath(normpath(path))) |