summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJelmer Vernooij <jelmer@samba.org>2010-11-08 18:47:33 +0100
committerJelmer Vernooij <jelmer@samba.org>2010-11-08 18:47:33 +0100
commit9b092296f18b54999130089d503cbc0cbe4c50f8 (patch)
treec95c1a8137746e4c309676bd0fb2a651ffac26ac
parentc728544110b9e2bdada0968052bf9c228fbd8419 (diff)
downloadbzr-fastimport-9b092296f18b54999130089d503cbc0cbe4c50f8.tar.gz
Fix output stream to stdout for bzr fast-export.
-rw-r--r--__init__.py4
-rw-r--r--[-rwxr-xr-x]exporter.py (renamed from bzr_exporter.py)20
-rw-r--r--tests/__init__.py1
-rw-r--r--tests/test_exporter.py62
4 files changed, 78 insertions, 9 deletions
diff --git a/__init__.py b/__init__.py
index bd52155..2fc555c 100644
--- a/__init__.py
+++ b/__init__.py
@@ -712,11 +712,11 @@ class cmd_fast_export(Command):
import_marks=None, export_marks=None, revision=None,
plain=True):
load_fastimport()
- from bzrlib.plugins.fastimport import bzr_exporter
+ from bzrlib.plugins.fastimport import exporter
if marks:
import_marks = export_marks = marks
- exporter = bzr_exporter.BzrFastExporter(source,
+ exporter = exporter.BzrFastExporter(source,
destination=destination,
git_branch=git_branch, checkpoint=checkpoint,
import_marks_file=import_marks, export_marks_file=export_marks,
diff --git a/bzr_exporter.py b/exporter.py
index a1bd75b..df8e403 100755..100644
--- a/bzr_exporter.py
+++ b/exporter.py
@@ -43,6 +43,18 @@ from bzrlib import (
from bzrlib.plugins.fastimport import helpers, marks_file
from fastimport import commands
+from fastimport.helpers import binary_stream
+
+
+def _get_output_stream(destination):
+ if destination is None or destination == '-':
+ return binary_stream(sys.stdout)
+ elif destination.endswith('gz'):
+ import gzip
+ return gzip.open(destination, 'wb')
+ else:
+ return open(destination, 'wb')
+
class BzrFastExporter(object):
@@ -57,13 +69,7 @@ class BzrFastExporter(object):
authors, revision properties, etc.
"""
self.source = source
- if destination is None or destination == '-':
- self.outf = helpers.binary_stream(sys.stdout)
- elif destination.endswith('gz'):
- import gzip
- self.outf = gzip.open(destination, 'wb')
- else:
- self.outf = open(destination, 'wb')
+ self.outf = _get_output_stream(destination)
self.git_branch = git_branch
self.checkpoint = checkpoint
self.import_marks_file = import_marks_file
diff --git a/tests/__init__.py b/tests/__init__.py
index b9c3802..47441e6 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -42,6 +42,7 @@ FastimportFeature = _FastimportFeature()
def test_suite():
module_names = [__name__ + '.' + x for x in [
'test_commands',
+ 'test_exporter',
'test_branch_mapper',
'test_generic_processor',
'test_revision_store',
diff --git a/tests/test_exporter.py b/tests/test_exporter.py
new file mode 100644
index 0000000..fe50e3b
--- /dev/null
+++ b/tests/test_exporter.py
@@ -0,0 +1,62 @@
+# Copyright (C) 2010 Canonical Ltd
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+
+"""Test the exporter."""
+
+import os
+import tempfile
+import gzip
+
+from bzrlib import tests
+
+from bzrlib.plugins.fastimport.exporter import (
+ _get_output_stream,
+ )
+
+from bzrlib.plugins.fastimport.tests import (
+ FastimportFeature,
+ )
+
+
+class TestOutputStream(tests.TestCase):
+
+ _test_needs_features = [FastimportFeature]
+
+ def test_get_output_stream_stdout(self):
+ # - returns standard out
+ self.assertIsNot(None, _get_output_stream("-"))
+
+ def test_get_source_gz(self):
+ fd, filename = tempfile.mkstemp(suffix=".gz")
+ os.close(fd)
+ stream = _get_output_stream(filename)
+ stream.write("bla")
+ stream.close()
+ # files ending in .gz are automatically decompressed.
+ f = gzip.GzipFile(filename)
+ self.assertEquals("bla", f.read())
+ f.close()
+
+ def test_get_source_file(self):
+ # other files are opened as regular files.
+ fd, filename = tempfile.mkstemp()
+ os.close(fd)
+ stream = _get_output_stream(filename)
+ stream.write("foo")
+ stream.close()
+ f = open(filename, 'r')
+ self.assertEquals("foo", f.read())
+ f.close()