From 98aa9742e757ec428e1953ba7f47c6d7c44b331a Mon Sep 17 00:00:00 2001 From: Alex Meyer <144723289+alexaryn@users.noreply.github.com> Date: Wed, 30 Oct 2024 11:59:47 -0700 Subject: [PATCH] BUG: Don't close stream passed to PdfWriter.write() (#2909) Closes #2905. --- pypdf/_writer.py | 33 ++++++++++++++++++++++++--------- tests/test_writer.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 9 deletions(-) diff --git a/pypdf/_writer.py b/pypdf/_writer.py index 3d1fac029..fbc4b29d1 100644 --- a/pypdf/_writer.py +++ b/pypdf/_writer.py @@ -63,6 +63,7 @@ StreamType, _get_max_pdf_version_header, deprecate, + deprecate_no_replacement, deprecation_with_replacement, logger_warning, ) @@ -236,10 +237,9 @@ def _get_clone_from( or Path(str(fileobj)).stat().st_size == 0 ): cloning = False - if isinstance(fileobj, (IO, BytesIO)): + if isinstance(fileobj, (IOBase, BytesIO)): t = fileobj.tell() - fileobj.seek(-1, 2) - if fileobj.tell() == 0: + if fileobj.seek(0, 2) == 0: cloning = False fileobj.seek(t, 0) if cloning: @@ -250,7 +250,8 @@ def _get_clone_from( # to prevent overwriting self.temp_fileobj = fileobj self.fileobj = "" - self.with_as_usage = False + self._with_as_usage = False + self._cloned = False # The root of our page tree node. pages = DictionaryObject() pages.update( @@ -268,6 +269,7 @@ def _get_clone_from( if not isinstance(clone_from, PdfReader): clone_from = PdfReader(clone_from) self.clone_document_from_reader(clone_from) + self._cloned = True else: self._pages = self._add_object(pages) # root object @@ -355,11 +357,23 @@ def xmp_metadata(self, value: Optional[XmpInformation]) -> None: return self.root_object.xmp_metadata # type: ignore + @property + def with_as_usage(self) -> bool: + deprecate_no_replacement("with_as_usage", "6.0") + return self._with_as_usage + + @with_as_usage.setter + def with_as_usage(self, value: bool) -> None: + deprecate_no_replacement("with_as_usage", "6.0") + self._with_as_usage = value + def __enter__(self) -> "PdfWriter": - """Store that writer is initialized by 'with'.""" + """Store how writer is initialized by 'with'.""" + c: bool = self._cloned t = self.temp_fileobj self.__init__() # type: ignore - self.with_as_usage = True + self._cloned = c + self._with_as_usage = True self.fileobj = t # type: ignore return self @@ -370,7 +384,7 @@ def __exit__( traceback: Optional[TracebackType], ) -> None: """Write data to the fileobj.""" - if self.fileobj: + if self.fileobj and not self._cloned: self.write(self.fileobj) def _repr_mimebundle_( @@ -1406,13 +1420,14 @@ def write(self, stream: Union[Path, StrByteType]) -> Tuple[bool, IO[Any]]: if isinstance(stream, (str, Path)): stream = FileIO(stream, "wb") - self.with_as_usage = True my_file = True self.write_stream(stream) - if self.with_as_usage: + if my_file: stream.close() + else: + stream.flush() return my_file, stream diff --git a/tests/test_writer.py b/tests/test_writer.py index 672f2378a..09f24244b 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -2481,3 +2481,43 @@ def test_append_pdf_with_dest_without_page(caplog): writer.append(reader) assert "/__WKANCHOR_8" not in writer.named_destinations assert len(writer.named_destinations) == 3 + + +def test_stream_not_closed(): + """Tests for #2905""" + src = RESOURCE_ROOT / "pdflatex-outline.pdf" + with NamedTemporaryFile(suffix=".pdf") as tmp: + with PdfReader(src) as reader, PdfWriter() as writer: + writer.add_page(reader.pages[0]) + writer.write(tmp) + assert not tmp.file.closed + + with NamedTemporaryFile(suffix=".pdf") as target: + with PdfWriter(target.file) as writer: + writer.add_blank_page(100, 100) + assert not target.file.closed + + with open(src, "rb") as fileobj: + with PdfWriter(fileobj) as writer: + pass + assert not fileobj.closed + + +def test_auto_write(tmp_path): + """Another test for #2905""" + target = tmp_path / "out.pdf" + with PdfWriter(target) as writer: + writer.add_blank_page(100, 100) + assert target.stat().st_size > 0 + + +def test_deprecate_with_as(): + """Yet another test for #2905""" + with PdfWriter() as writer: + with pytest.warns(DeprecationWarning) as w: + val = writer.with_as_usage + assert "with_as_usage is deprecated" in w[0].message.args[0] + assert val + with pytest.warns(DeprecationWarning) as w: + writer.with_as_usage = val # old code allowed setting this, so... + assert "with_as_usage is deprecated" in w[0].message.args[0]