Skip to content

Commit

Permalink
ENH: Allow to pass input file without named argument (#2576)
Browse files Browse the repository at this point in the history
  • Loading branch information
pubpub-zz authored Apr 14, 2024
1 parent ced67e1 commit eb6b21c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
38 changes: 35 additions & 3 deletions pypdf/_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def is_encrypted(self) -> bool:

def __init__(
self,
fileobj: StrByteType = "",
fileobj: Union[None, PdfReader, StrByteType, Path] = "",
clone_from: Union[None, PdfReader, StrByteType, Path] = None,
) -> None:
self._header = b"%PDF-1.3"
Expand Down Expand Up @@ -213,12 +213,41 @@ def __init__(
)
self._root = self._add_object(self._root_object)

def _get_clone_from(
fileobj: Union[None, PdfReader, str, Path, IO[Any], BytesIO],
clone_from: Union[None, PdfReader, str, Path, IO[Any], BytesIO],
) -> Union[None, PdfReader, str, Path, IO[Any], BytesIO]:
if not isinstance(fileobj, (str, Path, IO, BytesIO)) or (
fileobj != "" and clone_from is None
):
cloning = True
if not (
not isinstance(fileobj, (str, Path))
or (
Path(str(fileobj)).exists()
and Path(str(fileobj)).stat().st_size > 0
)
):
cloning = False
if isinstance(fileobj, (IO, BytesIO)):
t = fileobj.tell()
fileobj.seek(-1, 2)
if fileobj.tell() == 0:
cloning = False
fileobj.seek(t, 0)
if cloning:
clone_from = fileobj
return clone_from

clone_from = _get_clone_from(fileobj, clone_from)
# to prevent overwriting
self.temp_fileobj = fileobj
self.fileobj = ""
self.with_as_usage = False
if clone_from is not None:
if not isinstance(clone_from, PdfReader):
clone_from = PdfReader(clone_from)
self.clone_document_from_reader(clone_from)
self.fileobj = fileobj
self.with_as_usage = False

self._encryption: Optional[Encryption] = None
self._encrypt_entry: Optional[DictionaryObject] = None
Expand Down Expand Up @@ -268,7 +297,10 @@ def xmp_metadata(self, value: Optional[XmpInformation]) -> None:

def __enter__(self) -> "PdfWriter":
"""Store that writer is initialized by 'with'."""
t = self.temp_fileobj
self.__init__() # type: ignore
self.with_as_usage = True
self.fileobj = t # type: ignore
return self

def __exit__(
Expand Down
21 changes: 21 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2196,3 +2196,24 @@ def test_mime_jupyter():
writer = PdfWriter(clone_from=reader)
assert reader._repr_mimebundle_(("include",), ("exclude",)) == {}
assert writer._repr_mimebundle_(("include",), ("exclude",)) == {}


def test_init_without_named_arg():
"""Test to use file_obj argument and not clone_from"""
pdf_path = RESOURCE_ROOT / "crazyones.pdf"
reader = PdfReader(pdf_path)
writer = PdfWriter(clone_from=reader)
nb = len(writer._objects)
writer = PdfWriter(reader)
assert len(writer._objects) == nb
with open(pdf_path, "rb") as f:
writer = PdfWriter(f)
f.seek(0, 0)
by = BytesIO(f.read())
assert len(writer._objects) == nb
writer = PdfWriter(pdf_path)
assert len(writer._objects) == nb
writer = PdfWriter(str(pdf_path))
assert len(writer._objects) == nb
writer = PdfWriter(by)
assert len(writer._objects) == nb

0 comments on commit eb6b21c

Please sign in to comment.