Skip to content

Commit

Permalink
Unload root element if Part.GetStream updates the underlying value (#…
Browse files Browse the repository at this point in the history
…1760)

Fixes #1755
  • Loading branch information
twsouthwick authored Jul 26, 2024
1 parent f1fecd3 commit 522f9f3
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 5 deletions.
93 changes: 88 additions & 5 deletions src/DocumentFormat.OpenXml.Framework/Packaging/OpenXmlPart.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using DocumentFormat.OpenXml.Features;
using DocumentFormat.OpenXml.Framework;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.IO.Packaging;
using System.Threading;

namespace DocumentFormat.OpenXml.Packaging
{
Expand Down Expand Up @@ -236,9 +237,7 @@ public IEnumerable<OpenXmlPart> GetParentParts()
/// <returns>The content stream of the part. </returns>
public Stream GetStream(FileMode mode)
{
ThrowIfObjectDisposed();

return PackagePart.GetStream(mode, Features.GetRequired<IPackageFeature>().Package.FileOpenAccess);
return GetStream(mode, Features.GetRequired<IPackageFeature>().Package.FileOpenAccess);
}

/// <summary>
Expand All @@ -251,7 +250,20 @@ public Stream GetStream(FileMode mode, FileAccess access)
{
ThrowIfObjectDisposed();

return PackagePart.GetStream(mode, access);
var stream = PackagePart.GetStream(mode, access);

if (mode is FileMode.Create || stream.Length == 0)
{
UnloadRootElement();
return new UnloadingRootElementStream(this, stream);
}

if (stream.CanWrite)
{
return new UnloadingRootElementStream(this, stream);
}

return stream;
}

/// <summary>
Expand Down Expand Up @@ -605,5 +617,76 @@ internal sealed override OpenXmlPart ThisOpenXmlPart
internal MarkupCompatibilityProcessSettings? MCSettings { get; set; }

#endregion

/// <summary>
/// A <see cref="Stream"/> used by <see cref="GetStream(FileMode, FileAccess)" /> to unload the root if updated.
/// </summary>
private sealed class UnloadingRootElementStream : DelegatingStream
{
private readonly OpenXmlPart _part;

private bool _hasWritten;

public UnloadingRootElementStream(OpenXmlPart part, Stream innerStream)
: base(innerStream)
{
_part = part;
}

protected override void Dispose(bool disposing)
{
if (disposing && _hasWritten)
{
_part.UnloadRootElement();
}

base.Dispose(disposing);
}

public override void Write(byte[] buffer, int offset, int count)
{
NotifyOfWrite();
base.Write(buffer, offset, count);
}

#if NET46_OR_GREATER || NET || NETSTANDARD
public override System.Threading.Tasks.Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
NotifyOfWrite();
return base.WriteAsync(buffer, offset, count, cancellationToken);
}
#endif

#if NET6_0_OR_GREATER
public override void Write(ReadOnlySpan<byte> buffer)
{
NotifyOfWrite();
base.Write(buffer);
}

public override System.Threading.Tasks.ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
NotifyOfWrite();
return base.WriteAsync(buffer, cancellationToken);
}
#endif

public override void WriteByte(byte value)
{
NotifyOfWrite();
base.WriteByte(value);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? callback, object? state)
{
NotifyOfWrite();
return base.BeginWrite(buffer, offset, count, callback, state);
}

private void NotifyOfWrite()
{
_hasWritten = true;
}
}
}
}
72 changes: 72 additions & 0 deletions test/DocumentFormat.OpenXml.Tests/OpenXmlPartTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using DocumentFormat.OpenXml.Spreadsheet;
using System;
using System.IO;
using System.Text;
using Xunit;

namespace DocumentFormat.OpenXml.Packaging.Tests;

public class OpenXmlPartTests
{
[InlineData(FileAccess.Write)]
[InlineData(FileAccess.ReadWrite)]
[Theory]
public void GetStreamWrite(FileAccess access)
{
// Arrange
const string expected = """<x:sst xmlns:x="http://schemas.openxmlformats.org/spreadsheetml/2006/main"><x:si><x:t>Test</x:t></x:si></x:sst>""";
var stream = new MemoryStream();
{
using var package = SpreadsheetDocument.Create(stream, SpreadsheetDocumentType.Workbook);
var wb = package.AddWorkbookPart();

var part = wb.AddNewPart<SharedStringTablePart>();

part.SharedStringTable = new SharedStringTable();

using var partStream = part.GetStream(FileMode.Create, access);

var bytes = Encoding.UTF8.GetBytes(expected);
partStream.Write(bytes, 0, bytes.Length);
}

// Reopen package
stream.Position = 0;
using var spreadsheet = SpreadsheetDocument.Open(stream, isEditable: false);

// Assert
Assert.Equal(expected, spreadsheet.WorkbookPart.SharedStringTablePart.RootElement.OuterXml);
}

[InlineData(FileAccess.Write)]
[InlineData(FileAccess.ReadWrite)]
[Theory]
public void GetStreamWriteNoUpdates(FileAccess access)
{
// Arrange
const string expected = """<x:sst xmlns:x="http://schemas.openxmlformats.org/spreadsheetml/2006/main" />""";
var stream = new MemoryStream();
{
using var package = SpreadsheetDocument.Create(stream, SpreadsheetDocumentType.Workbook);
var wb = package.AddWorkbookPart();

var part = wb.AddNewPart<SharedStringTablePart>();

part.SharedStringTable = new SharedStringTable();

package.Save();

using var partStream = part.GetStream(FileMode.Open, access);
}

// Reopen package
stream.Position = 0;
using var spreadsheet = SpreadsheetDocument.Open(stream, isEditable: false);

// Assert
Assert.Equal(expected, spreadsheet.WorkbookPart.SharedStringTablePart.RootElement.OuterXml);
}
}

0 comments on commit 522f9f3

Please sign in to comment.