diff --git a/Source/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj b/Source/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj index 63f972c25..2d39e1b57 100644 --- a/Source/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj +++ b/Source/MQTTnet.AspNetCore.Tests/MQTTnet.AspNetCore.Tests.csproj @@ -9,9 +9,9 @@ - - - + + + diff --git a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj index adf62ce8e..0779befb7 100644 --- a/Source/MQTTnet.Tests/MQTTnet.Tests.csproj +++ b/Source/MQTTnet.Tests/MQTTnet.Tests.csproj @@ -12,9 +12,9 @@ - - - + + + diff --git a/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs b/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs index 533cbb2bb..194969ce1 100644 --- a/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs +++ b/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs @@ -2,13 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.VisualStudio.TestTools.UnitTesting; -using MQTTnet.Implementations; using System; using System.Net; using System.Net.Sockets; using System.Threading; using System.Threading.Tasks; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using MQTTnet.Implementations; namespace MQTTnet.Tests { @@ -18,62 +18,68 @@ public class MqttTcpChannel_Tests [TestMethod] public async Task Dispose_Channel_While_Used() { - var ct = new CancellationTokenSource(); - var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - - try + using (var ct = new CancellationTokenSource()) { - serverSocket.Bind(new IPEndPoint(IPAddress.Any, 50001)); - serverSocket.Listen(0); + using (var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp)) + { + try + { + serverSocket.Bind(new IPEndPoint(IPAddress.Any, 50001)); + serverSocket.Listen(0); #pragma warning disable 4014 - Task.Run(async () => + Task.Run( + async () => #pragma warning restore 4014 - { - while (!ct.IsCancellationRequested) - { - var client = await serverSocket.AcceptAsync(); - var data = new byte[] { 128 }; - await client.SendAsync(new ArraySegment(data), SocketFlags.None); - } - }, ct.Token); + { + while (!ct.IsCancellationRequested) + { + var client = await serverSocket.AcceptAsync(ct.Token); + var data = new byte[] { 128 }; + await client.SendAsync(new ArraySegment(data), SocketFlags.None); + } + }, + ct.Token); - var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None); + var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); + await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None); - var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); + var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); - await Task.Delay(100, ct.Token); + await Task.Delay(100, ct.Token); - var buffer = new byte[1]; - await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); + var buffer = new byte[1]; + await tcpChannel.ReadAsync(buffer, 0, 1, ct.Token); - Assert.AreEqual(128, buffer[0]); + Assert.AreEqual(128, buffer[0]); - // This block should fail after dispose. + // This block should fail after dispose. #pragma warning disable 4014 - Task.Run(() => + Task.Run( + () => #pragma warning restore 4014 - { - Task.Delay(200, ct.Token); - tcpChannel.Dispose(); - }, ct.Token); + { + Task.Delay(200, ct.Token); + tcpChannel.Dispose(); + }, + ct.Token); - try - { - await tcpChannel.ReadAsync(buffer, 0, 1, CancellationToken.None); - } - catch (Exception exception) - { - Assert.IsInstanceOfType(exception, typeof(SocketException)); - Assert.AreEqual(SocketError.OperationAborted, ((SocketException)exception).SocketErrorCode); + try + { + await tcpChannel.ReadAsync(buffer, 0, 1, CancellationToken.None); + } + catch (Exception exception) + { + Assert.IsInstanceOfType(exception, typeof(SocketException)); + Assert.AreEqual(SocketError.OperationAborted, ((SocketException)exception).SocketErrorCode); + } + } + finally + { + ct.Cancel(false); + } } } - finally - { - ct.Cancel(false); - serverSocket.Dispose(); - } } } -} +} \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs index 619760503..a298b624d 100644 --- a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs +++ b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs @@ -142,12 +142,14 @@ public int SendTimeout set => _socket.SendTimeout = value; } - public async Task AcceptAsync() + public async Task AcceptAsync(CancellationToken cancellationToken) { try { #if NET452 || NET461 var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); +#elif NET7_0_OR_GREATER + var clientSocket = await _socket.AcceptAsync(cancellationToken).ConfigureAwait(false); #else var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); #endif diff --git a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs index bb008afb2..a5784ed1b 100644 --- a/Source/MQTTnet/Implementations/MqttTcpServerListener.cs +++ b/Source/MQTTnet/Implementations/MqttTcpServerListener.cs @@ -141,7 +141,7 @@ async Task AcceptClientConnectionsAsync(CancellationToken cancellationToken) { try { - var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); + var clientSocket = await _socket.AcceptAsync(cancellationToken).ConfigureAwait(false); if (clientSocket == null) { continue;