Skip to content

Commit

Permalink
Fix StreamRefSerializer NRE bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Arkatufus committed Sep 4, 2024
1 parent b47b922 commit 1ba6014
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 35 deletions.
42 changes: 42 additions & 0 deletions src/core/Akka.Streams.Tests/Serialization/StreamRefSerializer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// -----------------------------------------------------------------------
// <copyright file="StreamRefSerializer.cs" company="Akka.NET Project">
// Copyright (C) 2009-2024 Lightbend Inc. <http://www.lightbend.com>
// Copyright (C) 2013-2024 .NET Foundation <https://github.com/akkadotnet/akka.net>
// </copyright>
// -----------------------------------------------------------------------

using System;
using Akka.Serialization;
using Akka.Streams.Implementation.StreamRef;
using FluentAssertions;
using Xunit;
using Xunit.Abstractions;
using static FluentAssertions.FluentActions;

namespace Akka.Streams.Tests.Serialization;

public class StreamRefSerializer: Akka.TestKit.Xunit2.TestKit
{
public StreamRefSerializer(ITestOutputHelper output)
: base(ActorMaterializer.DefaultConfig(), nameof(StreamRefSerializer), output)
{
}

[Fact(DisplayName = "StreamRefSerializer should not throw NRE when configuration were set before ActorSystem started")]
public void StreamsConfigBugTest()
{
var message = new SequencedOnNext(10, "test");
var serializer = (SerializerWithStringManifest)Sys.Serialization.FindSerializerFor(message);
var manifest = serializer.Manifest(message);

byte[] bytes = null;
Invoking(() =>
{
bytes = serializer.ToBinary(message); // This throws an NRE in the bug
}).Should().NotThrow<NullReferenceException>();

var deserialized = (SequencedOnNext) serializer.FromBinary(bytes, manifest);
deserialized.SeqNr.Should().Be(message.SeqNr);
deserialized.Payload.Should().Be(message.Payload);
}
}
68 changes: 33 additions & 35 deletions src/core/Akka.Streams/Serialization/StreamRefSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
using Akka.Actor;
using Akka.Serialization;
using Akka.Streams.Serialization.Proto.Msg;
using Akka.Util;
using Google.Protobuf;
using Akka.Streams.Implementation.StreamRef;
using CumulativeDemand = Akka.Streams.Implementation.StreamRef.CumulativeDemand;
Expand All @@ -19,12 +18,12 @@
using RemoteStreamFailure = Akka.Streams.Implementation.StreamRef.RemoteStreamFailure;
using SequencedOnNext = Akka.Streams.Implementation.StreamRef.SequencedOnNext;

#nullable enable
namespace Akka.Streams.Serialization
{
public sealed class StreamRefSerializer : SerializerWithStringManifest
{
private readonly ExtendedActorSystem _system;
private readonly Akka.Serialization.Serialization _serialization;

private const string SequencedOnNextManifest = "A";
private const string CumulativeDemandManifest = "B";
Expand All @@ -37,52 +36,51 @@ public sealed class StreamRefSerializer : SerializerWithStringManifest
public StreamRefSerializer(ExtendedActorSystem system) : base(system)
{
_system = system;
_serialization = system.Serialization;
}

public override string Manifest(object o)
{
switch (o)
return o switch
{
case SequencedOnNext _: return SequencedOnNextManifest;
case CumulativeDemand _: return CumulativeDemandManifest;
case OnSubscribeHandshake _: return OnSubscribeHandshakeManifest;
case RemoteStreamFailure _: return RemoteSinkFailureManifest;
case RemoteStreamCompleted _: return RemoteSinkCompletedManifest;
case SourceRefImpl _: return SourceRefManifest;
case SinkRefImpl _: return SinkRefManifest;
default: throw new ArgumentException($"Unsupported object of type {o.GetType()}", nameof(o));
}
SequencedOnNext => SequencedOnNextManifest,
CumulativeDemand => CumulativeDemandManifest,
OnSubscribeHandshake => OnSubscribeHandshakeManifest,
RemoteStreamFailure => RemoteSinkFailureManifest,
RemoteStreamCompleted => RemoteSinkCompletedManifest,
SourceRefImpl => SourceRefManifest,
SinkRefImpl => SinkRefManifest,
_ => throw new ArgumentException($"Unsupported object of type {o.GetType()}", nameof(o))
};
}

public override byte[] ToBinary(object o)
{
switch (o)
return o switch
{
case SequencedOnNext onNext: return SerializeSequencedOnNext(onNext).ToByteArray();
case CumulativeDemand demand: return SerializeCumulativeDemand(demand).ToByteArray();
case OnSubscribeHandshake handshake: return SerializeOnSubscribeHandshake(handshake).ToByteArray();
case RemoteStreamFailure failure: return SerializeRemoteStreamFailure(failure).ToByteArray();
case RemoteStreamCompleted completed: return SerializeRemoteStreamCompleted(completed).ToByteArray();
case SourceRefImpl sourceRef: return SerializeSourceRef(sourceRef).ToByteArray();
case SinkRefImpl sinkRef: return SerializeSinkRef(sinkRef).ToByteArray();
default: throw new ArgumentException($"Unsupported object of type {o.GetType()}", nameof(o));
}
SequencedOnNext onNext => SerializeSequencedOnNext(onNext).ToByteArray(),
CumulativeDemand demand => SerializeCumulativeDemand(demand).ToByteArray(),
OnSubscribeHandshake handshake => SerializeOnSubscribeHandshake(handshake).ToByteArray(),
RemoteStreamFailure failure => SerializeRemoteStreamFailure(failure).ToByteArray(),
RemoteStreamCompleted completed => SerializeRemoteStreamCompleted(completed).ToByteArray(),
SourceRefImpl sourceRef => SerializeSourceRef(sourceRef).ToByteArray(),
SinkRefImpl sinkRef => SerializeSinkRef(sinkRef).ToByteArray(),
_ => throw new ArgumentException($"Unsupported object of type {o.GetType()}", nameof(o))
};
}

public override object FromBinary(byte[] bytes, string manifest)
{
switch (manifest)
return manifest switch
{
case SequencedOnNextManifest: return DeserializeSequenceOnNext(bytes);
case CumulativeDemandManifest: return DeserializeCumulativeDemand(bytes);
case OnSubscribeHandshakeManifest: return DeserializeOnSubscribeHandshake(bytes);
case RemoteSinkFailureManifest: return DeserializeRemoteSinkFailure(bytes);
case RemoteSinkCompletedManifest: return DeserializeRemoteSinkCompleted(bytes);
case SourceRefManifest: return DeserializeSourceRef(bytes);
case SinkRefManifest: return DeserializeSinkRef(bytes);
default: throw new ArgumentException($"Unsupported manifest '{manifest}'", nameof(manifest));
}
SequencedOnNextManifest => DeserializeSequenceOnNext(bytes),
CumulativeDemandManifest => DeserializeCumulativeDemand(bytes),
OnSubscribeHandshakeManifest => DeserializeOnSubscribeHandshake(bytes),
RemoteSinkFailureManifest => DeserializeRemoteSinkFailure(bytes),
RemoteSinkCompletedManifest => DeserializeRemoteSinkCompleted(bytes),
SourceRefManifest => DeserializeSourceRef(bytes),
SinkRefManifest => DeserializeSinkRef(bytes),
_ => throw new ArgumentException($"Unsupported manifest '{manifest}'", nameof(manifest))
};
}

private SinkRefImpl DeserializeSinkRef(byte[] bytes)
Expand Down Expand Up @@ -129,7 +127,7 @@ private SequencedOnNext DeserializeSequenceOnNext(byte[] bytes)
{
var onNext = Proto.Msg.SequencedOnNext.Parser.ParseFrom(bytes);
var p = onNext.Payload;
var payload = _serialization.Deserialize(
var payload = system.Serialization.Deserialize(
p.EnclosedMessage.ToByteArray(),
p.SerializerId,
p.MessageManifest?.ToStringUtf8());
Expand Down Expand Up @@ -169,7 +167,7 @@ private ByteString SerializeCumulativeDemand(CumulativeDemand demand) =>
private ByteString SerializeSequencedOnNext(SequencedOnNext onNext)
{
var payload = onNext.Payload;
var serializer = _serialization.FindSerializerFor(payload);
var serializer = system.Serialization.FindSerializerFor(payload);
var manifest = Akka.Serialization.Serialization.ManifestFor(serializer, payload);

var p = new Payload
Expand Down

0 comments on commit 1ba6014

Please sign in to comment.