summaryrefslogtreecommitdiff
path: root/bzrlib/tests/test_sftp_transport.py
blob: 1c4d04b3693e4b50c78485ecd81e0ef3a9ed304f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
# Copyright (C) 2005-2011 Robey Pointer <robey@lag.net>
# Copyright (C) 2005, 2006, 2007 Canonical Ltd
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

import os
import socket
import sys
import time

from bzrlib import (
    config,
    controldir,
    errors,
    tests,
    transport as _mod_transport,
    ui,
    )
from bzrlib.osutils import (
    lexists,
    )
from bzrlib.tests import (
    features,
    TestCaseWithTransport,
    TestCase,
    TestSkipped,
    )
from bzrlib.tests.http_server import HttpServer
import bzrlib.transport.http

if features.paramiko.available():
    from bzrlib.transport import sftp as _mod_sftp
    from bzrlib.tests import stub_sftp


def set_test_transport_to_sftp(testcase):
    """Point *testcase* at the stub SFTP servers used by these tests.

    A test case may pre-set ``_get_remote_is_absolute`` to choose the server
    flavour.  When the attribute is missing (or None) it is defaulted to True
    and recorded on the test case.  A true value selects
    ``stub_sftp.SFTPAbsoluteServer``; a false value selects
    ``stub_sftp.SFTPHomeDirServer``.  The read-only server is always
    ``HttpServer``.
    """
    use_absolute = getattr(testcase, '_get_remote_is_absolute', None)
    if use_absolute is None:
        use_absolute = True
        testcase._get_remote_is_absolute = use_absolute
    if use_absolute:
        chosen_server = stub_sftp.SFTPAbsoluteServer
    else:
        chosen_server = stub_sftp.SFTPHomeDirServer
    testcase.transport_server = chosen_server
    testcase.transport_readonly_server = HttpServer


class TestCaseWithSFTPServer(TestCaseWithTransport):
    """Base class for tests that need a stub sftp server on localhost.

    Subclasses may set ``_get_remote_is_absolute`` before calling this
    setUp to choose between the absolute and home-dir stub servers.
    """

    def setUp(self):
        base = super(TestCaseWithSFTPServer, self)
        base.setUp()
        # Skip before touching stub_sftp: that module is only imported at
        # the top of this file when paramiko is available.
        self.requireFeature(features.paramiko)
        set_test_transport_to_sftp(self)


class SFTPLockTests(TestCaseWithSFTPServer):
    """Exercise write locks taken through the SFTP transport."""

    def test_sftp_locks(self):
        transport = self.get_transport()

        # Taking a write lock leaves a '<name>.write-lock' file behind.
        bogus_lock = transport.lock_write('bogus')
        self.assertPathExists('bogus.write-lock')

        # Locking does not wait: a second attempt on an already locked name
        # fails straight away with LockError.
        self.assertRaises(errors.LockError, transport.lock_write, 'bogus')

        # Releasing the lock removes its lock file.
        bogus_lock.unlock()
        self.assertFalse(lexists('bogus.write-lock'))

        # A lock file created behind the transport's back also blocks the
        # lock; once it is removed the lock can be taken normally.
        with open('something.write-lock', 'wb') as lock_file:
            lock_file.write('fake lock\n')
        self.assertRaises(errors.LockError, transport.lock_write, 'something')
        os.remove('something.write-lock')

        # Different names can be locked at the same time.
        something_lock = transport.lock_write('something')
        second_bogus_lock = transport.lock_write('bogus')

        something_lock.unlock()
        second_bogus_lock.unlock()


class SFTPTransportTestRelative(TestCaseWithSFTPServer):
    """Test how the SFTP transport maps relative paths to remote paths.

    This class does not set ``_get_remote_is_absolute``, so it runs against
    the default absolute-path stub server (``SFTPAbsoluteServer``).
    """

    def test__remote_path(self):
        if sys.platform == 'darwin':
            # This test is about sftp absolute path handling.  The expected
            # values are built from self.test_dir, but on OSX /tmp is a
            # symlink to /private/tmp and the sftp server does not (and need
            # not) resolve it, so the comparison is not meaningful there.
            # --vila 20070924
            self.knownFailure('Mac OSX symlinks /tmp to /private/tmp,'
                              ' testing against self.test_dir'
                              ' is not appropriate')
        transport = self.get_transport()
        # The expectations below need a unix-like absolute path.
        expected_root = self.test_dir
        if sys.platform == 'win32':
            # Hack suggested by John Meinel: prefix the drive-letter path
            # with '/'.
            # TODO: write another mock server for this test
            #       and use absolute path without drive letter
            expected_root = '/' + expected_root
        # A plain relative path resolves under the transport root
        # (remote path = self._abspath(relpath)).
        self.assertIsSameRealPath(expected_root + '/relative',
                                  transport._remote_path('relative'))
        # '..' must be honoured.  The parent is computed by splitting on '/'
        # rather than with os.path.join, which gives the wrong path on
        # windows.
        parent_dir = '/'.join(expected_root.split('/')[:-1])
        self.assertIsSameRealPath(parent_dir + '/sibling',
                                  transport._remote_path('../sibling'))
        # FIXME: decide whether '/' should be illegal, then test that for
        # all transports. RBC20051208


class SFTPTransportTestRelativeRoot(TestCaseWithSFTPServer):
    """Test the SFTP transport against a home-directory rooted server."""

    def setUp(self):
        # Must be set before the base setUp runs, because that setUp calls
        # set_test_transport_to_sftp, which reads this flag.  False selects
        # SFTPHomeDirServer, the only server exercised here.
        self._get_remote_is_absolute = False
        super(SFTPTransportTestRelativeRoot, self).setUp()

    def test__remote_path_relative_root(self):
        transport = self.get_transport('')
        # The URL keeps its home-directory marker ...
        self.assertEqual('/~/', transport._parsed_url.path)
        # ... while remote paths are relative to the home directory,
        # i.e. they do not begin with a '/'.
        self.assertEqual('a', transport._remote_path('a'))


class SFTPNonServerTest(TestCase):
    """SFTP tests that do not use the TestCaseWithSFTPServer fixture.

    Tests that need a server (test_abspath_root_sibling_server) start and
    stop their own.
    """

    def setUp(self):
        # Chain through super() like the other test classes in this module,
        # rather than calling TestCase.setUp directly and bypassing the MRO.
        super(SFTPNonServerTest, self).setUp()
        self.requireFeature(features.paramiko)

    def test_parse_url_with_home_dir(self):
        # Percent-escapes in the user ('%62' -> 'b') and password
        # ('%40' -> '@') are decoded, and the '~'-relative path gains a
        # trailing '/'.
        s = _mod_sftp.SFTPTransport(
            'sftp://ro%62ey:h%40t@example.com:2222/~/relative')
        self.assertEqual(s._parsed_url.host, 'example.com')
        self.assertEqual(s._parsed_url.port, 2222)
        self.assertEqual(s._parsed_url.user, 'robey')
        self.assertEqual(s._parsed_url.password, 'h@t')
        self.assertEqual(s._parsed_url.path, '/~/relative/')

    def test_relpath(self):
        # A home-relative URL is not a child of an absolute-path transport.
        s = _mod_sftp.SFTPTransport('sftp://user@host.com/abs/path')
        self.assertRaises(errors.PathNotChild, s.relpath,
                          'sftp://user@host.com/~/rel/path/sub')

    def test_get_paramiko_vendor(self):
        """Test that if no 'ssh' is available we get builtin paramiko"""
        from bzrlib.transport import ssh
        # set '.' as the only location in the path, forcing no 'ssh' to exist
        self.overrideAttr(ssh, '_ssh_vendor_manager')
        self.overrideEnv('PATH', '.')
        ssh._ssh_vendor_manager.clear_cache()
        vendor = ssh._get_ssh_vendor()
        self.assertIsInstance(vendor, ssh.ParamikoVendor)

    def test_abspath_root_sibling_server(self):
        server = stub_sftp.SFTPSiblingAbsoluteServer()
        server.start_server()
        self.addCleanup(server.stop_server)

        transport = _mod_transport.get_transport_from_url(server.get_url())
        # The root of this server is '/', not the home-dir marker '/~/'.
        self.assertFalse(transport.abspath('/').endswith('/~/'))
        self.assertTrue(transport.abspath('/').endswith('/'))
        # NOTE(review): presumably drops the transport reference before the
        # stop_server cleanup runs; kept as in the original.
        del transport


class SFTPBranchTest(TestCaseWithSFTPServer):
    """Test operations on a bzr branch accessed over sftp."""

    def test_push_support(self):
        # A local standalone working tree with one committed file.
        self.build_tree(['a/', 'a/foo'])
        tree = controldir.ControlDir.create_standalone_workingtree('a')
        local_branch = tree.branch
        tree.add('foo')
        tree.commit('foo', rev_id='a1')

        # A branch created over sftp picks up the first revision by pulling.
        remote_branch = controldir.ControlDir.create_branch_and_repo(
            self.get_url('/b'))
        remote_branch.pull(local_branch)
        self.assertEqual(remote_branch.last_revision(), 'a1')

        # A further local commit arrives with a second pull.
        with open('a/foo', 'wt') as foo_file:
            foo_file.write('something new in foo\n')
        tree.commit('new', rev_id='a2')
        remote_branch.pull(local_branch)
        self.assertEqual(remote_branch.last_revision(), 'a2')


class SSHVendorConnection(TestCaseWithSFTPServer):
    """Test that the ssh vendors can all connect.

    Verify that a full-handshake (SSH over loopback TCP) sftp connection
    works.  The vendor handed to the stub server can be:

      'loopback':        No ssh at all, just a local socket.  Most tests use
                         this to save handshaking time, so it is the default
                         here and is not tested again.
      ParamikoVendor():  paramiko's built-in ssh client and server, with
                         sftp layered on top.
      None:              if 'ssh' exists on the machine, it is spawned as a
                         child process.
    """

    def setUp(self):
        super(SSHVendorConnection, self).setUp()
        self._test_vendor = 'loopback'

        def create_server():
            """Build the stub server, applying the currently chosen vendor."""
            # SFTPFullAbsoluteServer can handle any vendor; it only needs to
            # be set between instantiation and the server being set up.
            # self._test_vendor is read here, when the factory is called, so
            # set_vendor() calls made after setUp still take effect.
            server = stub_sftp.SFTPFullAbsoluteServer()
            server._vendor = self._test_vendor
            return server

        self.vfs_transport_server = create_server
        with open('a_file', 'wb') as data_file:
            data_file.write('foobar\n')

    def set_vendor(self, vendor):
        """Choose the vendor used by servers created after this call."""
        self._test_vendor = vendor

    def test_connection_paramiko(self):
        from bzrlib.transport import ssh
        self.set_vendor(ssh.ParamikoVendor())
        transport = self.get_transport()
        self.assertEqual('foobar\n', transport.get('a_file').read())

    def test_connection_vendor(self):
        raise TestSkipped("We don't test spawning real ssh,"
                          " because it prompts for a password."
                          " Enable this test if we figure out"
                          " how to prevent this.")
        # Unreachable until the skip above is removed.
        self.set_vendor(None)
        transport = self.get_transport()
        self.assertEqual('foobar\n', transport.get('a_file').read())


class SSHVendorBadConnection(TestCaseWithTransport):
    """Test that the ssh vendors handle bad connection properly

    We don't subclass TestCaseWithSFTPServer, because we don't actually
    need an SFTP connection.
    """

    def setUp(self):
        self.requireFeature(features.paramiko)
        super(SSHVendorBadConnection, self).setUp()

        # Open a random port, so we know nobody else is using it,
        # but don't actually listen on the port: connecting to it must fail.
        s = socket.socket()
        s.bind(('localhost', 0))
        self.addCleanup(s.close)
        self.bogus_url = 'sftp://%s:%s/' % s.getsockname()

    def set_vendor(self, vendor, subprocess_stderr=None):
        """Force *vendor* as the cached ssh vendor for this test.

        :param vendor: vendor instance, or None to auto-detect.
        :param subprocess_stderr: if given, file object that subprocess
            vendors should send their stderr to.
        """
        from bzrlib.transport import ssh
        self.overrideAttr(ssh._ssh_vendor_manager, '_cached_ssh_vendor', vendor)
        if subprocess_stderr is not None:
            self.overrideAttr(ssh.SubprocessVendor, "_stderr_target",
                subprocess_stderr)

    def test_bad_connection_paramiko(self):
        """Test that a real connection attempt raises the right error"""
        from bzrlib.transport import ssh
        self.set_vendor(ssh.ParamikoVendor())
        t = _mod_transport.get_transport_from_url(self.bogus_url)
        self.assertRaises(errors.ConnectionError, t.get, 'foobar')

    def test_bad_connection_ssh(self):
        """None => auto-detect vendor"""
        # open() rather than the deprecated Python 2-only file() builtin.
        f = open(os.devnull, "wb")
        self.addCleanup(f.close)
        self.set_vendor(None, f)
        t = _mod_transport.get_transport_from_url(self.bogus_url)
        try:
            self.assertRaises(errors.ConnectionError, t.get, 'foobar')
        except NameError as e:
            if "global name 'SSHException'" in str(e):
                self.knownFailure('Known NameError bug in paramiko 1.6.1')
            raise


class SFTPLatencyKnob(TestCaseWithSFTPServer):
    """Test that the testing SFTPServer's latency knob works."""

    def test_latency_knob_slows_transport(self):
        # Set the latency knob to 500ms.  A loopback connection normally
        # takes about 40ms, so over 400ms shows the knob is in effect.
        started = time.time()
        self.get_server().add_latency = 0.5
        transport = self.get_transport()
        # Issue a request so that the connection is actually made.
        transport.has('not me')
        elapsed = time.time() - started
        self.assertTrue(elapsed > 0.4)

    def test_default(self):
        # This test is potentially brittle: under extremely high machine
        # load it could fail, but that is quite unlikely.  It is skipped
        # because it depends on wall-clock timing.
        raise TestSkipped('Timing-sensitive test')
        started = time.time()
        transport = self.get_transport()
        # Issue a request so that the connection is actually made.
        transport.has('not me')
        elapsed = time.time() - started
        self.assertTrue(elapsed < 0.5)


class FakeSocket(object):
    """In-memory stand-in for a socket, used to test the SocketDelay wrapper.

    Everything sent is appended to one buffer, and recv() hands it back in
    the order it was sent.
    """

    def __init__(self):
        self._data = ""

    def send(self, data, flags=0):
        """Append *data* to the buffer and report all of it as sent."""
        self._data += data
        return len(data)

    def sendall(self, data, flags=0):
        """Behave exactly like send(); *flags* is ignored."""
        return self.send(data, flags)

    def recv(self, size, flags=0):
        """Remove and return up to *size* characters from the buffer."""
        if size >= len(self._data):
            chunk, self._data = self._data, ""
        else:
            chunk, self._data = self._data[:size], self._data[size:]
        return chunk


class TestSocketDelay(TestCase):
    """Tests for stub_sftp.SocketDelay's simulated latency and bandwidth."""

    def setUp(self):
        # Chain through super() like the other test classes in this module,
        # rather than calling TestCase.setUp directly and bypassing the MRO.
        super(TestSocketDelay, self).setUp()
        self.requireFeature(features.paramiko)

    def test_delay(self):
        sending = FakeSocket()
        receiving = stub_sftp.SocketDelay(sending, 0.1, bandwidth=1000000,
                                          really_sleep=False)
        # check that simulated time is charged only per round-trip:
        t1 = stub_sftp.SocketDelay.simulated_time
        receiving.send("connect1")
        self.assertEqual(sending.recv(1024), "connect1")
        t2 = stub_sftp.SocketDelay.simulated_time
        # The first exchange costs one latency period.
        self.assertAlmostEqual(t2 - t1, 0.1)
        receiving.send("connect2")
        self.assertEqual(sending.recv(1024), "connect2")
        sending.send("hello")
        self.assertEqual(receiving.recv(1024), "hello")
        t3 = stub_sftp.SocketDelay.simulated_time
        # A send followed by a reply costs one more latency period.
        self.assertAlmostEqual(t3 - t2, 0.1)
        sending.send("hello")
        self.assertEqual(receiving.recv(1024), "hello")
        sending.send("hello")
        self.assertEqual(receiving.recv(1024), "hello")
        sending.send("hello")
        self.assertEqual(receiving.recv(1024), "hello")
        t4 = stub_sftp.SocketDelay.simulated_time
        # Further receives with no intervening send are not charged.
        self.assertAlmostEqual(t4, t3)

    def test_bandwidth(self):
        sending = FakeSocket()
        receiving = stub_sftp.SocketDelay(sending, 0, bandwidth=8.0/(1024*1024),
                                          really_sleep=False)
        # With zero latency, the simulated time comes only from transfer
        # cost: 7 bytes of "connect" plus 100 bytes of payload.  The
        # bandwidth value appears to be chosen so that each byte costs one
        # simulated second -- see SocketDelay for its units.
        t1 = stub_sftp.SocketDelay.simulated_time
        receiving.send("connect")
        self.assertEqual(sending.recv(1024), "connect")
        sending.send("a" * 100)
        self.assertEqual(receiving.recv(1024), "a" * 100)
        t2 = stub_sftp.SocketDelay.simulated_time
        self.assertAlmostEqual(t2 - t1, 100 + 7)


class ReadvFile(object):
    """Minimal stand-in for Paramiko's SFTPFile, supporting only readv().

    Requests are served from an in-memory string given at construction.
    """

    def __init__(self, data):
        self._data = data

    def readv(self, requests):
        """Lazily yield the slice of the data for each (start, length)."""
        for offset, size in requests:
            end = offset + size
            yield self._data[offset:end]

    def close(self):
        """Nothing to release; present for interface compatibility."""
        pass


def _null_report_activity(*a, **k):
    pass


class Test_SFTPReadvHelper(tests.TestCase):
    """Unit tests for _mod_sftp._SFTPReadvHelper request handling."""

    def _make_helper(self, offsets):
        """Build a readv helper for *offsets*, skipping without paramiko."""
        self.requireFeature(features.paramiko)
        return _mod_sftp._SFTPReadvHelper(offsets, 'artificial_test',
                                          _null_report_activity)

    def checkGetRequests(self, expected_requests, offsets):
        helper = self._make_helper(offsets)
        self.assertEqual(expected_requests, helper._get_requests())

    def test__get_requests(self):
        # Small adjacent ranges are coalesced into a single readv request.
        self.checkGetRequests([(0, 100)],
                              [(0, 20), (30, 50), (20, 10), (80, 20)])
        # Non-contiguous ranges are given as multiple requests.
        self.checkGetRequests([(0, 20), (30, 50)],
                              [(10, 10), (30, 20), (0, 10), (50, 30)])
        # Ranges larger than _max_request_size (32kB) are broken up into
        # multiple requests, even when a coalesced range spans several
        # logical requests.
        self.checkGetRequests([(0, 32768), (32768, 32768), (65536, 464)],
                              [(0, 40000), (40000, 100), (40100, 1900),
                               (42000, 24000)])

    def checkRequestAndYield(self, expected, data, offsets):
        helper = self._make_helper(offsets)
        yielded = list(helper.request_and_yield_offsets(ReadvFile(data)))
        self.assertEqual(expected, yielded)

    def test_request_and_yield_offsets(self):
        alphabet = 'abcdefghijklmnopqrstuvwxyz'
        self.checkRequestAndYield([(0, 'a'), (5, 'f'), (10, 'klm')], alphabet,
                                  [(0, 1), (5, 1), (10, 3)])
        # Adjacent requests are combined for reading, then split again.
        self.checkRequestAndYield([(0, 'a'), (1, 'b'), (10, 'klm')], alphabet,
                                  [(0, 1), (1, 1), (10, 3)])
        # Out-of-order requests are combined, but still yielded in the
        # order requested.  One of them ends where a previous range
        # ends. See bug #293746
        self.checkRequestAndYield([(0, 'a'), (10, 'k'), (4, 'efg'), (1, 'bcd')],
                                  alphabet, [(0, 1), (10, 1), (4, 3), (1, 3)])


class TestUsesAuthConfig(TestCaseWithSFTPServer):
    """Test that AuthenticationConfig can supply default usernames."""

    def get_transport_for_connection(self, set_config):
        """Return a connected sftp transport to the test server.

        :param set_config: if true, first save an authentication.conf
            section giving user 'bar' for ssh on the server's port.
        """
        port = self.get_server().port
        if set_config:
            conf = config.AuthenticationConfig()
            conf._get_config().update(
                {'sftptest': {'scheme': 'ssh', 'port': port, 'user': 'bar'}})
            conf._save()
        t = _mod_transport.get_transport_from_url(
            'sftp://localhost:%d' % port)
        # force a connection to be performed.
        t.has('foo')
        return t

    def test_sftp_uses_config(self):
        t = self.get_transport_for_connection(set_config=True)
        self.assertEqual('bar', t._get_credentials()[0])

    def test_sftp_is_none_if_no_config(self):
        t = self.get_transport_for_connection(set_config=False)
        self.assertIs(None, t._get_credentials()[0])

    def test_sftp_doesnt_prompt_username(self):
        stdout = tests.StringIOWrapper()
        ui_factory = tests.TestUIFactory(stdin='joe\nfoo\n', stdout=stdout)
        # overrideAttr restores the previous ui_factory during cleanup,
        # instead of leaving this one installed globally.
        self.overrideAttr(ui, 'ui_factory', ui_factory)
        t = self.get_transport_for_connection(set_config=False)
        self.assertIs(None, t._get_credentials()[0])
        # No prompts should've been printed, stdin shouldn't have been read
        self.assertEqual("", stdout.getvalue())
        self.assertEqual(0, ui_factory.stdin.tell())