From 522f9f3db3268ef5722c0721e065d63ced48e77c Mon Sep 17 00:00:00 2001 From: Taylor Southwick Date: Fri, 26 Jul 2024 12:20:55 -0700 Subject: [PATCH] Unload root element if Part.GetStream updates the underlying value (#1760) Fixes #1755 --- .../Packaging/OpenXmlPart.cs | 93 ++++++++++++++++++- .../OpenXmlPartTests.cs | 72 ++++++++++++++ 2 files changed, 160 insertions(+), 5 deletions(-) create mode 100644 test/DocumentFormat.OpenXml.Tests/OpenXmlPartTests.cs diff --git a/src/DocumentFormat.OpenXml.Framework/Packaging/OpenXmlPart.cs b/src/DocumentFormat.OpenXml.Framework/Packaging/OpenXmlPart.cs index 706518853..749e483eb 100644 --- a/src/DocumentFormat.OpenXml.Framework/Packaging/OpenXmlPart.cs +++ b/src/DocumentFormat.OpenXml.Framework/Packaging/OpenXmlPart.cs @@ -2,12 +2,13 @@ // Licensed under the MIT license. See LICENSE file in the project root for full license information. using DocumentFormat.OpenXml.Features; +using DocumentFormat.OpenXml.Framework; using System; using System.Collections.Generic; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.IO; -using System.IO.Packaging; +using System.Threading; namespace DocumentFormat.OpenXml.Packaging { @@ -236,9 +237,7 @@ public IEnumerable GetParentParts() /// The content stream of the part. public Stream GetStream(FileMode mode) { - ThrowIfObjectDisposed(); - - return PackagePart.GetStream(mode, Features.GetRequired().Package.FileOpenAccess); + return GetStream(mode, Features.GetRequired().Package.FileOpenAccess); } /// @@ -251,7 +250,20 @@ public Stream GetStream(FileMode mode, FileAccess access) { ThrowIfObjectDisposed(); - return PackagePart.GetStream(mode, access); + var stream = PackagePart.GetStream(mode, access); + + if (mode is FileMode.Create || stream.Length == 0) + { + UnloadRootElement(); + return new UnloadingRootElementStream(this, stream); + } + + if (stream.CanWrite) + { + return new UnloadingRootElementStream(this, stream); + } + + return stream; } /// @@ -605,5 +617,76 @@ internal sealed override OpenXmlPart ThisOpenXmlPart internal MarkupCompatibilityProcessSettings? MCSettings { get; set; } #endregion + + /// + /// A used by to unload the root if updated. + /// + private sealed class UnloadingRootElementStream : DelegatingStream + { + private readonly OpenXmlPart _part; + + private bool _hasWritten; + + public UnloadingRootElementStream(OpenXmlPart part, Stream innerStream) + : base(innerStream) + { + _part = part; + } + + protected override void Dispose(bool disposing) + { + if (disposing && _hasWritten) + { + _part.UnloadRootElement(); + } + + base.Dispose(disposing); + } + + public override void Write(byte[] buffer, int offset, int count) + { + NotifyOfWrite(); + base.Write(buffer, offset, count); + } + +#if NET46_OR_GREATER || NET || NETSTANDARD + public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + NotifyOfWrite(); + return base.WriteAsync(buffer, offset, count, cancellationToken); + } +#endif + +#if NET6_0_OR_GREATER + public override void Write(ReadOnlySpan buffer) + { + NotifyOfWrite(); + base.Write(buffer); + } + + public override System.Threading.Tasks.ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + NotifyOfWrite(); + return base.WriteAsync(buffer, cancellationToken); + } +#endif + + public override void WriteByte(byte value) + { + NotifyOfWrite(); + base.WriteByte(value); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state) + { + NotifyOfWrite(); + return base.BeginWrite(buffer, offset, count, callback, state); + } + + private void NotifyOfWrite() + { + _hasWritten = true; + } + } } } diff --git a/test/DocumentFormat.OpenXml.Tests/OpenXmlPartTests.cs b/test/DocumentFormat.OpenXml.Tests/OpenXmlPartTests.cs new file mode 100644 index 000000000..9095c315e --- /dev/null +++ b/test/DocumentFormat.OpenXml.Tests/OpenXmlPartTests.cs @@ -0,0 +1,72 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using DocumentFormat.OpenXml.Spreadsheet; +using System; +using System.IO; +using System.Text; +using Xunit; + +namespace DocumentFormat.OpenXml.Packaging.Tests; + +public class OpenXmlPartTests +{ + [InlineData(FileAccess.Write)] + [InlineData(FileAccess.ReadWrite)] + [Theory] + public void GetStreamWrite(FileAccess access) + { + // Arrange + const string expected = """Test"""; + var stream = new MemoryStream(); + { + using var package = SpreadsheetDocument.Create(stream, SpreadsheetDocumentType.Workbook); + var wb = package.AddWorkbookPart(); + + var part = wb.AddNewPart(); + + part.SharedStringTable = new SharedStringTable(); + + using var partStream = part.GetStream(FileMode.Create, access); + + var bytes = Encoding.UTF8.GetBytes(expected); + partStream.Write(bytes, 0, bytes.Length); + } + + // Reopen package + stream.Position = 0; + using var spreadsheet = SpreadsheetDocument.Open(stream, isEditable: false); + + // Assert + Assert.Equal(expected, spreadsheet.WorkbookPart.SharedStringTablePart.RootElement.OuterXml); + } + + [InlineData(FileAccess.Write)] + [InlineData(FileAccess.ReadWrite)] + [Theory] + public void GetStreamWriteNoUpdates(FileAccess access) + { + // Arrange + const string expected = """"""; + var stream = new MemoryStream(); + { + using var package = SpreadsheetDocument.Create(stream, SpreadsheetDocumentType.Workbook); + var wb = package.AddWorkbookPart(); + + var part = wb.AddNewPart(); + + part.SharedStringTable = new SharedStringTable(); + + package.Save(); + + using var partStream = part.GetStream(FileMode.Open, access); + } + + // Reopen package + stream.Position = 0; + using var spreadsheet = SpreadsheetDocument.Open(stream, isEditable: false); + + // Assert + Assert.Equal(expected, spreadsheet.WorkbookPart.SharedStringTablePart.RootElement.OuterXml); + } +}