diff --git a/build-logic/src/main/kotlin/rsocketbuild.multiplatform-base.gradle.kts b/build-logic/src/main/kotlin/rsocketbuild.multiplatform-base.gradle.kts index 898f147c..f3abd98d 100644 --- a/build-logic/src/main/kotlin/rsocketbuild.multiplatform-base.gradle.kts +++ b/build-logic/src/main/kotlin/rsocketbuild.multiplatform-base.gradle.kts @@ -46,6 +46,7 @@ kotlin { // rsocket related optIn(OptIns.TransportApi) + optIn(OptIns.RSocketTransportApi) optIn(OptIns.ExperimentalMetadataApi) optIn(OptIns.ExperimentalStreamsApi) optIn(OptIns.RSocketLoggingApi) diff --git a/build-logic/src/main/kotlin/rsocketbuild/OptIns.kt b/build-logic/src/main/kotlin/rsocketbuild/OptIns.kt index aa9e7229..5c21a6e3 100644 --- a/build-logic/src/main/kotlin/rsocketbuild/OptIns.kt +++ b/build-logic/src/main/kotlin/rsocketbuild/OptIns.kt @@ -23,6 +23,7 @@ object OptIns { const val DelicateCoroutinesApi = "kotlinx.coroutines.DelicateCoroutinesApi" const val TransportApi = "io.rsocket.kotlin.TransportApi" + const val RSocketTransportApi = "io.rsocket.kotlin.transport.RSocketTransportApi" const val ExperimentalMetadataApi = "io.rsocket.kotlin.ExperimentalMetadataApi" const val ExperimentalStreamsApi = "io.rsocket.kotlin.ExperimentalStreamsApi" const val RSocketLoggingApi = "io.rsocket.kotlin.RSocketLoggingApi" diff --git a/rsocket-core/api/rsocket-core.api b/rsocket-core/api/rsocket-core.api index d1a9ac27..c2d5a82b 100644 --- a/rsocket-core/api/rsocket-core.api +++ b/rsocket-core/api/rsocket-core.api @@ -194,6 +194,7 @@ public abstract interface class io/rsocket/kotlin/core/MimeTypeWithName : io/rso public final class io/rsocket/kotlin/core/RSocketConnector { public final fun connect (Lio/rsocket/kotlin/transport/ClientTransport;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun connect (Lio/rsocket/kotlin/transport/RSocketClientTarget;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class io/rsocket/kotlin/core/RSocketConnectorBuilder { @@ -228,6 +229,8 @@ public final class io/rsocket/kotlin/core/RSocketConnectorBuilderKt { public final class io/rsocket/kotlin/core/RSocketServer { public final fun bind (Lio/rsocket/kotlin/transport/ServerTransport;Lio/rsocket/kotlin/ConnectionAcceptor;)Ljava/lang/Object; public final fun bindIn (Lkotlinx/coroutines/CoroutineScope;Lio/rsocket/kotlin/transport/ServerTransport;Lio/rsocket/kotlin/ConnectionAcceptor;)Ljava/lang/Object; + public final fun createHandler (Lio/rsocket/kotlin/ConnectionAcceptor;)Lio/rsocket/kotlin/transport/RSocketConnectionHandler; + public final fun startServer (Lio/rsocket/kotlin/transport/RSocketServerTarget;Lio/rsocket/kotlin/ConnectionAcceptor;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class io/rsocket/kotlin/core/RSocketServerBuilder { @@ -760,7 +763,71 @@ public final class io/rsocket/kotlin/transport/ClientTransportKt { public static final fun ClientTransport (Lkotlin/coroutines/CoroutineContext;Lio/rsocket/kotlin/transport/ClientTransport;)Lio/rsocket/kotlin/transport/ClientTransport; } +public abstract interface class io/rsocket/kotlin/transport/RSocketClientTarget : kotlinx/coroutines/CoroutineScope { + public abstract fun connectClient (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;)Lkotlinx/coroutines/Job; +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketConnection { +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketConnectionHandler { + public abstract fun handleConnection (Lio/rsocket/kotlin/transport/RSocketConnection;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketMultiplexedConnection : io/rsocket/kotlin/transport/RSocketConnection { + public abstract fun acceptStream (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun createStream (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketMultiplexedConnection$Stream : java/io/Closeable { + public abstract fun close ()V + public abstract fun isClosedForSend ()Z + public abstract fun receiveFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun sendFrame (Lio/ktor/utils/io/core/ByteReadPacket;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun setSendPriority (I)V +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketSequentialConnection : io/rsocket/kotlin/transport/RSocketConnection { + public abstract fun isClosedForSend ()Z + public abstract fun receiveFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun sendFrame (ILio/ktor/utils/io/core/ByteReadPacket;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketServerInstance : kotlinx/coroutines/CoroutineScope { +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketServerTarget : kotlinx/coroutines/CoroutineScope { + public abstract fun startServer (Lio/rsocket/kotlin/transport/RSocketConnectionHandler;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketTransport : kotlinx/coroutines/CoroutineScope { +} + +public abstract interface annotation class io/rsocket/kotlin/transport/RSocketTransportApi : java/lang/annotation/Annotation { +} + +public abstract interface class io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun buildTransport (Lkotlin/coroutines/CoroutineContext;)Lio/rsocket/kotlin/transport/RSocketTransport; +} + +public abstract class io/rsocket/kotlin/transport/RSocketTransportFactory { + public fun (Lkotlin/jvm/functions/Function0;)V + public final fun getCreateBuilder ()Lkotlin/jvm/functions/Function0; + public final fun invoke (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;)Lio/rsocket/kotlin/transport/RSocketTransport; + public static synthetic fun invoke$default (Lio/rsocket/kotlin/transport/RSocketTransportFactory;Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)Lio/rsocket/kotlin/transport/RSocketTransport; +} + public abstract interface class io/rsocket/kotlin/transport/ServerTransport { public abstract fun start (Lkotlinx/coroutines/CoroutineScope;Lkotlin/jvm/functions/Function3;)Ljava/lang/Object; } +public final class io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue { + public fun (I)V + public final fun cancel ()V + public final fun close ()V + public final fun dequeueFrame (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun enqueueFrame (ILio/ktor/utils/io/core/ByteReadPacket;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public final fun isClosedForSend ()Z + public final fun tryDequeueFrame ()Lio/ktor/utils/io/core/ByteReadPacket; +} + diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt index 8cb03268..e38d2149 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt @@ -17,9 +17,6 @@ package io.rsocket.kotlin import io.ktor.utils.io.core.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.internal.io.* import kotlinx.coroutines.* /** @@ -30,12 +27,3 @@ public interface Connection : CoroutineScope { public suspend fun send(packet: ByteReadPacket) public suspend fun receive(): ByteReadPacket } - -@OptIn(TransportApi::class) -internal suspend inline fun Connection.receiveFrame(pool: BufferPool, block: (frame: Frame) -> T): T = - receive().readFrame(pool).closeOnError(block) - -@OptIn(TransportApi::class) -internal suspend fun Connection.sendFrame(pool: BufferPool, frame: Frame) { - frame.toPacket(pool).closeOnError { send(it) } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt new file mode 100644 index 00000000..27492d98 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/Connection.kt @@ -0,0 +1,151 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.operation.* +import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* +import kotlin.coroutines.* + +// TODO: rename to just `Connection` after root `Connection` will be dropped +@RSocketTransportApi +internal abstract class Connection2( + protected val frameCodec: FrameCodec, + // requestContext + final override val coroutineContext: CoroutineContext, +) : RSocket, Closeable { + + // connection establishment part + + abstract suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig + + // setup completed, start handling requests + abstract suspend fun handleConnection(inbound: ConnectionInbound) + + // connection part + + protected abstract suspend fun sendConnectionFrame(frame: ByteReadPacket) + private suspend fun sendConnectionFrame(frame: Frame): Unit = sendConnectionFrame(frameCodec.encodeFrame(frame)) + + suspend fun sendError(cause: Throwable) { + sendConnectionFrame(ErrorFrame(0, cause)) + } + + private suspend fun sendMetadataPush(metadata: ByteReadPacket) { + sendConnectionFrame(MetadataPushFrame(metadata)) + } + + suspend fun sendKeepAlive(respond: Boolean, data: ByteReadPacket, lastPosition: Long) { + sendConnectionFrame(KeepAliveFrame(respond, lastPosition, data)) + } + + // operations part + + protected abstract fun launchRequest(requestPayload: Payload, operation: RequesterOperation): Job + private suspend fun ensureActiveOrClose(closeable: Closeable) { + currentCoroutineContext().ensureActive { closeable.close() } + coroutineContext.ensureActive { closeable.close() } + } + + final override suspend fun metadataPush(metadata: ByteReadPacket) { + ensureActiveOrClose(metadata) + sendMetadataPush(metadata) + } + + final override suspend fun fireAndForget(payload: Payload) { + ensureActiveOrClose(payload) + + suspendCancellableCoroutine { cont -> + val requestJob = launchRequest( + requestPayload = payload, + operation = RequesterFireAndForgetOperation(cont) + ) + cont.invokeOnCancellation { cause -> + requestJob.cancel("Request was cancelled", cause) + } + } + } + + final override suspend fun requestResponse(payload: Payload): Payload { + ensureActiveOrClose(payload) + + val responseDeferred = CompletableDeferred() + + val requestJob = launchRequest( + requestPayload = payload, + operation = RequesterRequestResponseOperation(responseDeferred) + ) + + try { + responseDeferred.join() + } catch (cause: Throwable) { + requestJob.cancel("Request was cancelled", cause) + throw cause + } + return responseDeferred.await() + } + + @OptIn(ExperimentalStreamsApi::class) + final override fun requestStream( + payload: Payload, + ): Flow = payloadFlow { strategy, initialRequest -> + ensureActiveOrClose(payload) + + val responsePayloads = PayloadChannel() + + val requestJob = launchRequest( + requestPayload = payload, + operation = RequesterRequestStreamOperation(initialRequest, responsePayloads) + ) + + throw try { + responsePayloads.consumeInto(this, strategy) + } catch (cause: Throwable) { + requestJob.cancel("Request was cancelled", cause) + throw cause + } ?: return@payloadFlow + } + + @OptIn(ExperimentalStreamsApi::class) + final override fun requestChannel( + initPayload: Payload, + payloads: Flow, + ): Flow = payloadFlow { strategy, initialRequest -> + ensureActiveOrClose(initPayload) + + val responsePayloads = PayloadChannel() + + val requestJob = launchRequest( + initPayload, + RequesterRequestChannelOperation(initialRequest, payloads, responsePayloads) + ) + + throw try { + responsePayloads.consumeInto(this, strategy) + } catch (cause: Throwable) { + requestJob.cancel("Request was cancelled", cause) + throw cause + } ?: return@payloadFlow + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt new file mode 100644 index 00000000..0ff8ed5c --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentContext.kt @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.frame.io.* +import io.rsocket.kotlin.keepalive.* +import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.transport.* + +// send/receive setup, resume, resume ok, lease, error +@RSocketTransportApi +internal abstract class ConnectionEstablishmentContext( + private val frameCodec: FrameCodec, +) { + protected abstract suspend fun receiveFrameRaw(): ByteReadPacket? + protected abstract suspend fun sendFrame(frame: ByteReadPacket) + private suspend fun sendFrame(frame: Frame): Unit = sendFrame(frameCodec.encodeFrame(frame)) + + // only setup|lease|resume|resume_ok|error frames + suspend fun receiveFrame(): Frame = frameCodec.decodeFrame( + expectedStreamId = 0, + frame = receiveFrameRaw() ?: error("Expected frame during connection establishment but nothing was received") + ) + + suspend fun sendSetup( + version: Version, + honorLease: Boolean, + keepAlive: KeepAlive, + resumeToken: ByteReadPacket?, + payloadMimeType: PayloadMimeType, + payload: Payload, + ): Unit = sendFrame(SetupFrame(version, honorLease, keepAlive, resumeToken, payloadMimeType, payload)) +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentHandler.kt new file mode 100644 index 00000000..532fdb19 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionEstablishmentHandler.kt @@ -0,0 +1,107 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.core.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.keepalive.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal abstract class ConnectionEstablishmentHandler( + private val isClient: Boolean, + private val frameCodec: FrameCodec, + private val connectionAcceptor: ConnectionAcceptor, + private val interceptors: Interceptors, + private val requesterDeferred: CompletableDeferred?, +) : RSocketConnectionHandler { + abstract suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig + + private suspend fun wrapConnection( + connection: RSocketConnection, + requestContext: CoroutineContext, + ): Connection2 = when (connection) { + is RSocketMultiplexedConnection -> { + val initialStream = when { + isClient -> connection.createStream() + else -> connection.acceptStream() ?: error("Initial stream should be received") + } + initialStream.setSendPriority(0) + MultiplexedConnection(isClient, frameCodec, requestContext, connection, initialStream) + } + + is RSocketSequentialConnection -> { + SequentialConnection(isClient, frameCodec, requestContext, connection) + } + } + + @Suppress("SuspendFunctionOnCoroutineScope") + private suspend fun CoroutineScope.handleConnection(connection: Connection2) { + try { + val connectionConfig = connection.establishConnection(this@ConnectionEstablishmentHandler) + try { + val requester = interceptors.wrapRequester(connection) + val responder = interceptors.wrapResponder( + with(interceptors.wrapAcceptor(connectionAcceptor)) { + ConnectionAcceptorContext(connectionConfig, requester).accept() + } + ) + + // link completing of requester, connection and requestHandler + requester.coroutineContext.job.invokeOnCompletion { + coroutineContext.job.cancel("Requester cancelled", it) + } + responder.coroutineContext.job.invokeOnCompletion { + coroutineContext.job.cancel("Responder cancelled", it) + } + coroutineContext.job.invokeOnCompletion { cause -> + // the responder is not linked to `coroutineContext` + responder.cancel("Connection closed", cause) + } + + requesterDeferred?.complete(requester) + + val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive, connection, this) + connection.handleConnection( + ConnectionInbound(connection.coroutineContext, responder, keepAliveHandler) + ) + } catch (cause: Throwable) { + connectionConfig.setupPayload.close() + throw cause + } + } catch (cause: Throwable) { + connection.close() + withContext(NonCancellable) { + connection.sendError( + when (cause) { + is RSocketError -> cause + else -> RSocketError.ConnectionError(cause.message ?: "Connection failed") + } + ) + } + throw cause + } + } + + final override suspend fun handleConnection(connection: RSocketConnection): Unit = coroutineScope { + handleConnection(wrapConnection(connection, coroutineContext.supervisorContext())) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt new file mode 100644 index 00000000..9d584bc0 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/ConnectionInbound.kt @@ -0,0 +1,72 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.keepalive.* +import io.rsocket.kotlin.operation.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class ConnectionInbound( + // requestContext + override val coroutineContext: CoroutineContext, + private val responder: RSocket, + private val keepAliveHandler: KeepAliveHandler, +) : CoroutineScope { + fun handleFrame(frame: Frame): Unit = when (frame) { + is MetadataPushFrame -> receiveMetadataPush(frame.metadata) + is KeepAliveFrame -> receiveKeepAlive(frame.respond, frame.data, frame.lastPosition) + is ErrorFrame -> receiveError(frame.throwable) + is LeaseFrame -> receiveLease(frame.ttl, frame.numberOfRequests, frame.metadata) + // ignore other frames + else -> frame.close() + } + + private fun receiveMetadataPush(metadata: ByteReadPacket) { + launch { + responder.metadataPush(metadata) + }.invokeOnCompletion { metadata.close() } + } + + @Suppress("UNUSED_PARAMETER") // will be used later + private fun receiveKeepAlive(respond: Boolean, data: ByteReadPacket, lastPosition: Long) { + keepAliveHandler.receive(data, respond) + } + + @Suppress("UNUSED_PARAMETER") // will be used later + private fun receiveLease(ttl: Int, numberOfRequests: Int, metadata: ByteReadPacket?) { + metadata?.close() + error("Lease is not supported") + } + + private fun receiveError(cause: Throwable) { + throw cause // TODO? + } + + fun createOperation(type: FrameType, requestJob: Job): ResponderOperation = when (type) { + FrameType.RequestFnF -> ResponderFireAndForgetOperation(requestJob, responder) + FrameType.RequestResponse -> ResponderRequestResponseOperation(requestJob, responder) + FrameType.RequestStream -> ResponderRequestStreamOperation(requestJob, responder) + FrameType.RequestChannel -> ResponderRequestChannelOperation(requestJob, responder) + else -> error("should happen") + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt new file mode 100644 index 00000000..597f2fa9 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/LoggingConnection.kt @@ -0,0 +1,113 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.logging.* +import io.rsocket.kotlin.transport.* + +@RSocketLoggingApi +@RSocketTransportApi +internal fun RSocketConnectionHandler.logging(logger: Logger, bufferPool: BufferPool): RSocketConnectionHandler { + if (!logger.isLoggable(LoggingLevel.DEBUG)) return this + + return RSocketConnectionHandler { + handleConnection( + when (it) { + is RSocketSequentialConnection -> SequentialLoggingConnection(it, logger, bufferPool) + is RSocketMultiplexedConnection -> MultiplexedLoggingConnection(it, logger, bufferPool) + } + ) + } +} + +@RSocketLoggingApi +@RSocketTransportApi +private class SequentialLoggingConnection( + private val delegate: RSocketSequentialConnection, + private val logger: Logger, + private val bufferPool: BufferPool, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = delegate.isClosedForSend + + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + logger.debug { "Send: ${dumpFrameToString(frame, bufferPool)}" } + delegate.sendFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return delegate.receiveFrame()?.also { frame -> + logger.debug { "Receive: ${dumpFrameToString(frame, bufferPool)}" } + } + } + +} + +private fun dumpFrameToString(frame: ByteReadPacket, bufferPool: BufferPool): String { + val length = frame.remaining + return frame.copy().use { it.readFrame(bufferPool).use { it.dump(length) } } +} + +@RSocketLoggingApi +@RSocketTransportApi +private class MultiplexedLoggingConnection( + private val delegate: RSocketMultiplexedConnection, + private val logger: Logger, + private val bufferPool: BufferPool, +) : RSocketMultiplexedConnection { + override suspend fun createStream(): RSocketMultiplexedConnection.Stream { + return MultiplexedLoggingStream(delegate.createStream(), logger, bufferPool) + } + + override suspend fun acceptStream(): RSocketMultiplexedConnection.Stream? { + return delegate.acceptStream()?.let { + MultiplexedLoggingStream(it, logger, bufferPool) + } + } +} + +@RSocketLoggingApi +@RSocketTransportApi +private class MultiplexedLoggingStream( + private val delegate: RSocketMultiplexedConnection.Stream, + private val logger: Logger, + private val bufferPool: BufferPool, +) : RSocketMultiplexedConnection.Stream { + override val isClosedForSend: Boolean get() = delegate.isClosedForSend + + override fun setSendPriority(priority: Int) { + delegate.setSendPriority(priority) + } + + override suspend fun sendFrame(frame: ByteReadPacket) { + logger.debug { "Send: ${dumpFrameToString(frame, bufferPool)}" } + delegate.sendFrame(frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return delegate.receiveFrame()?.also { frame -> + logger.debug { "Receive: ${dumpFrameToString(frame, bufferPool)}" } + } + } + + override fun close() { + delegate.close() + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt new file mode 100644 index 00000000..ebc37c5f --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/MultiplexedConnection.kt @@ -0,0 +1,233 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.operation.* +import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class MultiplexedConnection( + isClient: Boolean, + frameCodec: FrameCodec, + requestContext: CoroutineContext, + private val connection: RSocketMultiplexedConnection, + private val initialStream: RSocketMultiplexedConnection.Stream, +) : Connection2(frameCodec, requestContext) { + private val storage = StreamDataStorage(isClient) + + override fun close() { + storage.clear() + } + + override suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig { + return handler.establishConnection(EstablishmentContext()) + } + + private inner class EstablishmentContext : ConnectionEstablishmentContext(frameCodec) { + override suspend fun sendFrame(frame: ByteReadPacket): Unit = initialStream.sendFrame(frame) + override suspend fun receiveFrameRaw(): ByteReadPacket? = initialStream.receiveFrame() + } + + override suspend fun handleConnection(inbound: ConnectionInbound) = coroutineScope { + launch { + while (true) { + val frame = frameCodec.decodeFrame( + expectedStreamId = 0, + frame = initialStream.receiveFrame() ?: break + ) + inbound.handleFrame(frame) + } + } + + while (true) if (!acceptRequest(inbound)) break + } + + override suspend fun sendConnectionFrame(frame: ByteReadPacket) { + initialStream.sendFrame(frame) + } + + @OptIn(ExperimentalCoroutinesApi::class) + override fun launchRequest( + requestPayload: Payload, + operation: RequesterOperation, + ): Job = launch(start = CoroutineStart.ATOMIC) { + operation.handleExecutionFailure(requestPayload) { + ensureActive() // because of atomic start + val stream = connection.createStream() + val streamId = storage.createStream(Unit) + try { + execute(streamId, stream, requestPayload, operation) + } finally { + storage.removeStream(streamId) + stream.close() + } + } + } + + @OptIn(ExperimentalCoroutinesApi::class) + private fun acceptRequest( + connectionInbound: ConnectionInbound, + stream: RSocketMultiplexedConnection.Stream, + ): Job = launch(start = CoroutineStart.ATOMIC) { + try { + ensureActive() // because of atomic start + val ( + streamId, + type, + initialRequest, + requestPayload, + complete, + ) = receiveRequest(stream) + try { + val operation = connectionInbound.createOperation(type, coroutineContext.job) + operation.handleExecutionFailure(requestPayload) { + if (operation.shouldReceiveFrame(FrameType.RequestN)) + operation.receiveRequestNFrame(initialRequest) + if (operation.shouldReceiveFrame(FrameType.Payload) && complete) + operation.receivePayloadFrame(null, true) + execute(streamId, stream, requestPayload, operation) + } + } finally { + storage.removeStream(streamId) + } + } finally { + stream.close() + } + } + + private suspend fun acceptRequest(inbound: ConnectionInbound): Boolean { + val stream = connection.acceptStream() ?: return false + acceptRequest(inbound, stream) + return true + } + + private suspend fun receiveRequest(stream: RSocketMultiplexedConnection.Stream): ResponderOperationData { + val initialFrame = frameCodec.decodeFrame( + frame = stream.receiveFrame() ?: error("Expected initial frame for stream") + ) + val streamId = initialFrame.streamId + + if (streamId == 0) { + initialFrame.close() + error("expected stream id != 0") + } + if (initialFrame !is RequestFrame || !initialFrame.type.isRequestType) { + initialFrame.close() + error("expected request frame type") + } + if (!storage.acceptStream(streamId, Unit)) { + initialFrame.close() + error("invalid stream id") + } + + val complete: Boolean + val requestPayload: Payload + val assembler = PayloadAssembler() + try { + if (initialFrame.follows) { + assembler.appendFragment(initialFrame.payload) + while (true) { + val frame = frameCodec.decodeFrame( + expectedStreamId = streamId, + frame = stream.receiveFrame() ?: error("Unexpected stream closure") + ) + when (frame) { + // request is cancelled during fragmentation + is CancelFrame -> error("Request was cancelled by remote party") + is RequestFrame -> { + // TODO: extract assembly logic? + when { + // for RC, it could contain the complete flag + // complete+follows=complete, "complete" overrides "follows" flag + frame.complete -> check(frame.next) { "next flag should be set" } + frame.next && !frame.follows -> {} // last fragment + else -> { + assembler.appendFragment(frame.payload) + continue // await more fragments + } + } + complete = frame.complete + requestPayload = assembler.assemblePayload(frame.payload) + break + } + + else -> { + frame.close() + error("unexpected frame: ${frame.type}") + } + } + } + } else { + complete = initialFrame.complete + requestPayload = initialFrame.payload + } + } catch (cause: Throwable) { + assembler.close() + throw cause + } + + return ResponderOperationData( + streamId = streamId, + requestType = initialFrame.type, + initialRequest = initialFrame.initialRequest, + requestPayload = requestPayload, + complete = complete + ) + } + + private suspend fun execute( + streamId: Int, + stream: RSocketMultiplexedConnection.Stream, + requestPayload: Payload, + operation: Operation, + ): Unit = coroutineScope { + val outbound = Outbound(streamId, stream) + val receiveJob = launch { + val handler = OperationFrameHandler(operation) + try { + while (true) { + val frame = frameCodec.decodeFrame( + expectedStreamId = streamId, + frame = stream.receiveFrame() ?: break + ) + handler.handleFrame(frame) + } + handler.handleDone() + } finally { + handler.close() + } + } + operation.execute(outbound, requestPayload) + receiveJob.cancel() // stop receiving + } + + private inner class Outbound( + streamId: Int, + private val stream: RSocketMultiplexedConnection.Stream, + ) : OperationOutbound(streamId, frameCodec) { + override val isClosed: Boolean get() = stream.isClosedForSend + override suspend fun sendFrame(frame: ByteReadPacket): Unit = stream.sendFrame(frame) + } + +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt new file mode 100644 index 00000000..9fe11bea --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/OldConnection.kt @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +@TransportApi +@RSocketTransportApi +internal suspend fun RSocketConnectionHandler.handleConnection(connection: Connection): Unit = coroutineScope { + val outboundQueue = PrioritizationFrameQueue(Channel.BUFFERED) + + val senderJob = launch { + while (true) connection.send(outboundQueue.dequeueFrame() ?: break) + }.onCompletion { outboundQueue.cancel() } + + try { + handleConnection(OldConnection(outboundQueue, connection)) + } finally { + outboundQueue.close() + withContext(NonCancellable) { + senderJob.join() + } + } +} + +@TransportApi +@RSocketTransportApi +private class OldConnection( + private val outboundQueue: PrioritizationFrameQueue, + private val connection: Connection, +) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outboundQueue.isClosedForSend + + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outboundQueue.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? = try { + connection.receive() + } catch (cause: Throwable) { + currentCoroutineContext().ensureActive() + null + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt new file mode 100644 index 00000000..398d5e56 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/SequentialConnection.kt @@ -0,0 +1,189 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.operation.* +import io.rsocket.kotlin.payload.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@RSocketTransportApi +internal class SequentialConnection( + isClient: Boolean, + frameCodec: FrameCodec, + requestContext: CoroutineContext, + private val connection: RSocketSequentialConnection, +) : Connection2(frameCodec, requestContext) { + private val storage = StreamDataStorage(isClient) + + override fun close() { + storage.clear().forEach { it.close() } + } + + override suspend fun establishConnection(handler: ConnectionEstablishmentHandler): ConnectionConfig { + return handler.establishConnection(EstablishmentContext()) + } + + private inner class EstablishmentContext : ConnectionEstablishmentContext(frameCodec) { + override suspend fun sendFrame(frame: ByteReadPacket): Unit = connection.sendFrame(streamId = 0, frame) + override suspend fun receiveFrameRaw(): ByteReadPacket? = connection.receiveFrame() + } + + override suspend fun handleConnection(inbound: ConnectionInbound) { + while (true) { + val frame = frameCodec.decodeFrame( + frame = connection.receiveFrame() ?: break + ) + when (frame.streamId) { + 0 -> inbound.handleFrame(frame) + else -> receiveFrame(inbound, frame) + } + } + } + + override suspend fun sendConnectionFrame(frame: ByteReadPacket) { + connection.sendFrame(0, frame) + } + + @OptIn(ExperimentalCoroutinesApi::class) + override fun launchRequest( + requestPayload: Payload, + operation: RequesterOperation, + ): Job = launch(start = CoroutineStart.ATOMIC) { + operation.handleExecutionFailure(requestPayload) { + ensureActive() // because of atomic start + val streamId = storage.createStream(OperationFrameHandler(operation)) + try { + operation.execute(Outbound(streamId), requestPayload) + } finally { + storage.removeStream(streamId)?.close() + } + } + } + + @OptIn(ExperimentalCoroutinesApi::class) + private fun acceptRequest( + connectionInbound: ConnectionInbound, + operationData: ResponderOperationData, + ): ResponderOperation { + val requestJob = Job(coroutineContext.job) + val operation = connectionInbound.createOperation(operationData.requestType, requestJob) + launch(requestJob, start = CoroutineStart.ATOMIC) { + val ( + streamId, + _, + initialRequest, + requestPayload, + complete, + ) = operationData + operation.handleExecutionFailure(requestPayload) { + ensureActive() // because of atomic start + try { + if (operation.shouldReceiveFrame(FrameType.RequestN)) + operation.receiveRequestNFrame(initialRequest) + if (operation.shouldReceiveFrame(FrameType.Payload) && complete) + operation.receivePayloadFrame(null, true) + operation.execute(Outbound(streamId), requestPayload) + } finally { + storage.removeStream(streamId)?.close() + } + } + } + requestJob.complete() + return operation + } + + private fun receiveFrame(connectionInbound: ConnectionInbound, frame: Frame) { + val streamId = frame.streamId + if (frame is RequestFrame && frame.type.isRequestType) { + if (storage.isValidForAccept(streamId)) { + val operationData = ResponderOperationData( + streamId = streamId, + requestType = frame.type, + initialRequest = frame.initialRequest, + requestPayload = frame.payload, + complete = frame.complete + ) + val handler = OperationFrameHandler( + when { + frame.follows -> ResponderInboundWrapper(connectionInbound, operationData) + else -> acceptRequest(connectionInbound, operationData) + } + ) + if (storage.acceptStream(streamId, handler)) { + // for fragmentation + if (frame.follows) handler.handleFrame(frame) + } else { + frame.close() + handler.close() + } + } else { + frame.close() // ignore + } + } else { + storage.getStream(streamId)?.handleFrame(frame) ?: frame.close() + } + } + + private inner class Outbound(streamId: Int) : OperationOutbound(streamId, frameCodec) { + override val isClosed: Boolean get() = !isActive || connection.isClosedForSend + override suspend fun sendFrame(frame: ByteReadPacket): Unit = connection.sendFrame(streamId, frame) + } + + private inner class ResponderInboundWrapper( + private val connectionInbound: ConnectionInbound, + private val operationData: ResponderOperationData, + ) : OperationInbound { + + override fun shouldReceiveFrame(frameType: FrameType): Boolean { + return frameType.isRequestType || frameType === FrameType.Payload || frameType === FrameType.Cancel + } + + override fun receivePayloadFrame(payload: Payload?, complete: Boolean) { + if (payload != null) { + val operation = acceptRequest( + connectionInbound = connectionInbound, + operationData = ResponderOperationData( + streamId = operationData.streamId, + requestType = operationData.requestType, + initialRequest = operationData.initialRequest, + requestPayload = payload, + complete = complete + ) + ) + // close old handler + storage.replaceStream(operationData.streamId, OperationFrameHandler(operation))?.close() + } else { + // should not happen really + storage.removeStream(operationData.streamId)?.close() + } + } + + override fun receiveCancelFrame() { + storage.removeStream(operationData.streamId)?.close() + } + + override fun receiveDone() { + // if for some reason it happened... + storage.removeStream(operationData.streamId)?.close() + } + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/StreamDataStorage.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/StreamDataStorage.kt new file mode 100644 index 00000000..c159fa79 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/StreamDataStorage.kt @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.connection + +import io.rsocket.kotlin.internal.* +import kotlinx.atomicfu.locks.* + +internal class StreamDataStorage(private val isClient: Boolean) { + private val lock = SynchronizedObject() + private val streamIdGenerator: StreamIdGenerator = StreamIdGenerator(isClient) + private val storage: IntMap = IntMap() + + fun createStream(value: T): Int = synchronized(lock) { + val streamId = streamIdGenerator.next(storage) + storage[streamId] = value + streamId + } + + // false if not valid + fun isValidForAccept(id: Int): Boolean { + if (isClient.xor(id % 2 == 0)) return false + return synchronized(lock) { id !in storage } + } + + // false if not valid + // TODO: implement IntMap.putIfAbsent + fun acceptStream(id: Int, value: T): Boolean { + if (isClient.xor(id % 2 == 0)) return false + return synchronized(lock) { + val isValid = id !in storage + if (isValid) storage[id] = value + isValid + } + } + + // for responder part of sequential connection + fun replaceStream(id: Int, value: T): T? = synchronized(lock) { storage.set(id, value) } + + fun removeStream(id: Int): T? = synchronized(lock) { storage.remove(id) } + fun getStream(id: Int): T? = synchronized(lock) { storage[id] } + + fun clear(): List = synchronized(lock) { + val values = storage.values() + storage.clear() + values + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamId.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/StreamIdGenerator.kt similarity index 71% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamId.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/StreamIdGenerator.kt index 25c3a9f6..26025458 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamId.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/connection/StreamIdGenerator.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,11 +14,12 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.connection +import io.rsocket.kotlin.internal.* import kotlinx.atomicfu.* -internal class StreamId(streamId: Int) { +internal class StreamIdGenerator(streamId: Int) { private val streamId = atomic(streamId) fun next(streamIds: IntMap<*>): Int { @@ -33,9 +34,10 @@ internal class StreamId(streamId: Int) { companion object { private const val MASK = 0x7FFFFFFF - fun client(): StreamId = StreamId(-1) - fun server(): StreamId = StreamId(0) - operator fun invoke(isServer: Boolean): StreamId = if (isServer) server() else client() + fun client(): StreamIdGenerator = StreamIdGenerator(-1) + fun server(): StreamIdGenerator = StreamIdGenerator(0) + + operator fun invoke(isClient: Boolean): StreamIdGenerator = if (isClient) client() else server() } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt index de4858c0..d9efe6eb 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt @@ -17,17 +17,18 @@ package io.rsocket.kotlin.core import io.rsocket.kotlin.* +import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.frame.io.* -import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* +import kotlin.coroutines.* -@OptIn(TransportApi::class, RSocketLoggingApi::class) +@OptIn(TransportApi::class, RSocketTransportApi::class, RSocketLoggingApi::class) public class RSocketConnector internal constructor( - private val loggerFactory: LoggerFactory, + loggerFactory: LoggerFactory, private val maxFragmentSize: Int, private val interceptors: Interceptors, private val connectionConfigProvider: () -> ConnectionConfig, @@ -35,55 +36,63 @@ public class RSocketConnector internal constructor( private val reconnectPredicate: ReconnectPredicate?, private val bufferPool: BufferPool, ) { + private val connectionLogger = loggerFactory.logger("io.rsocket.kotlin.connection") + private val frameLogger = loggerFactory.logger("io.rsocket.kotlin.frame") - public suspend fun connect(transport: ClientTransport): RSocket = when (reconnectPredicate) { - //TODO current coroutineContext job is overriden by transport coroutineContext jov - null -> withContext(transport.coroutineContext) { connectOnce(transport) } + public suspend fun connect(transport: ClientTransport): RSocket = connect(object : RSocketClientTarget { + override val coroutineContext: CoroutineContext get() = transport.coroutineContext + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + handler.handleConnection(interceptors.wrapConnection(transport.connect())) + } + }) + + public suspend fun connect(transport: RSocketClientTarget): RSocket = when (reconnectPredicate) { + null -> connectOnce(transport) else -> connectWithReconnect( transport.coroutineContext, - loggerFactory.logger("io.rsocket.kotlin.connection"), + connectionLogger, { connectOnce(transport) }, reconnectPredicate, ) } - private suspend fun connectOnce(transport: ClientTransport): RSocket { - val connection = transport.connect().wrapConnection() - val connectionConfig = try { - connectionConfigProvider() - } catch (cause: Throwable) { - connection.cancel("Connection config provider failed", cause) - throw cause - } - val setupFrame = SetupFrame( - version = Version.Current, - honorLease = false, - keepAlive = connectionConfig.keepAlive, - resumeToken = null, - payloadMimeType = connectionConfig.payloadMimeType, - payload = connectionConfig.setupPayload.copy() //copy needed, as it can be used in acceptor - ) - try { - val requester = connect( - connection = connection, - isServer = false, - maxFragmentSize = maxFragmentSize, - interceptors = interceptors, - connectionConfig = connectionConfig, - acceptor = acceptor, - bufferPool = bufferPool - ) - connection.sendFrame(bufferPool, setupFrame) - return requester + private suspend fun connectOnce(transport: RSocketClientTarget): RSocket { + val requesterDeferred = CompletableDeferred() + val connectJob = transport.connectClient( + SetupConnection(requesterDeferred).logging(frameLogger, bufferPool) + ).onCompletion { if (it != null) requesterDeferred.completeExceptionally(it) } + return try { + requesterDeferred.await() } catch (cause: Throwable) { - connectionConfig.setupPayload.close() - setupFrame.close() - connection.cancel("Connection establishment failed", cause) + connectJob.cancel("RSocketConnector.connect was cancelled", cause) throw cause } } - private fun Connection.wrapConnection(): Connection = - interceptors.wrapConnection(this) - .logging(loggerFactory.logger("io.rsocket.kotlin.frame"), bufferPool) + private inner class SetupConnection(requesterDeferred: CompletableDeferred) : ConnectionEstablishmentHandler( + isClient = true, + frameCodec = FrameCodec(bufferPool, maxFragmentSize), + connectionAcceptor = acceptor, + interceptors = interceptors, + requesterDeferred = requesterDeferred + ) { + override suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig { + val connectionConfig = connectionConfigProvider() + try { + context.sendSetup( + version = Version.Current, + honorLease = false, + keepAlive = connectionConfig.keepAlive, + resumeToken = null, + payloadMimeType = connectionConfig.payloadMimeType, + // copy needed, as it can be used in acceptor + payload = connectionConfig.setupPayload.copy() + ) + } catch (cause: Throwable) { + connectionConfig.setupPayload.close() + throw cause + } + return connectionConfig + } + } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt index 0641d839..94bc3c9b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt @@ -17,21 +17,22 @@ package io.rsocket.kotlin.core import io.rsocket.kotlin.* +import io.rsocket.kotlin.connection.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.frame.io.* -import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* -@OptIn(TransportApi::class, RSocketLoggingApi::class) +@OptIn(TransportApi::class, RSocketTransportApi::class, RSocketLoggingApi::class) public class RSocketServer internal constructor( - private val loggerFactory: LoggerFactory, + loggerFactory: LoggerFactory, private val maxFragmentSize: Int, private val interceptors: Interceptors, private val bufferPool: BufferPool, ) { + private val frameLogger = loggerFactory.logger("io.rsocket.kotlin.frame") @DelicateCoroutinesApi public fun bind( @@ -43,48 +44,51 @@ public class RSocketServer internal constructor( scope: CoroutineScope, transport: ServerTransport, acceptor: ConnectionAcceptor, - ): T = with(transport) { - scope.start { - it.wrapConnection().bind(acceptor).join() - } - } - - private suspend fun Connection.bind(acceptor: ConnectionAcceptor): Job = receiveFrame(bufferPool) { setupFrame -> - when { - setupFrame !is SetupFrame -> failSetup(RSocketError.Setup.Invalid("Invalid setup frame: ${setupFrame.type}")) - setupFrame.version != Version.Current -> failSetup(RSocketError.Setup.Unsupported("Unsupported version: ${setupFrame.version}")) - setupFrame.honorLease -> failSetup(RSocketError.Setup.Unsupported("Lease is not supported")) - setupFrame.resumeToken != null -> failSetup(RSocketError.Setup.Unsupported("Resume is not supported")) - else -> try { - connect( - connection = this, - isServer = true, - maxFragmentSize = maxFragmentSize, - interceptors = interceptors, - connectionConfig = ConnectionConfig( - keepAlive = setupFrame.keepAlive, - payloadMimeType = setupFrame.payloadMimeType, - setupPayload = setupFrame.payload - ), - acceptor = acceptor, - bufferPool = bufferPool - ) - coroutineContext.job - } catch (e: Throwable) { - failSetup(RSocketError.Setup.Rejected(e.message ?: "Rejected by server acceptor")) + ): T { + val handler = createHandler(acceptor) + return with(transport) { + scope.start { + handler.handleConnection(interceptors.wrapConnection(it)) } } } - @Suppress("SuspendFunctionOnCoroutineScope") - private suspend fun Connection.failSetup(error: RSocketError.Setup): Nothing { - sendFrame(bufferPool, ErrorFrame(0, error)) - cancel("Connection establishment failed", error) - throw error - } + public suspend fun startServer( + transport: RSocketServerTarget, + acceptor: ConnectionAcceptor, + ): T = transport.startServer(createHandler(acceptor)) - private fun Connection.wrapConnection(): Connection = - interceptors.wrapConnection(this) - .logging(loggerFactory.logger("io.rsocket.kotlin.frame"), bufferPool) + @RSocketTransportApi + public fun createHandler(acceptor: ConnectionAcceptor): RSocketConnectionHandler = + AcceptConnection(acceptor).logging(frameLogger, bufferPool) + private inner class AcceptConnection(acceptor: ConnectionAcceptor) : ConnectionEstablishmentHandler( + isClient = false, + frameCodec = FrameCodec(bufferPool, maxFragmentSize), + connectionAcceptor = acceptor, + interceptors = interceptors, + requesterDeferred = null + ) { + override suspend fun establishConnection(context: ConnectionEstablishmentContext): ConnectionConfig { + val setupFrame = context.receiveFrame() + return try { + when { + setupFrame !is SetupFrame -> throw RSocketError.Setup.Invalid("Invalid setup frame: ${setupFrame.type}") + setupFrame.version != Version.Current -> throw RSocketError.Setup.Unsupported("Unsupported version: ${setupFrame.version}") + setupFrame.honorLease -> throw RSocketError.Setup.Unsupported("Lease is not supported") + setupFrame.resumeToken != null -> throw RSocketError.Setup.Unsupported("Resume is not supported") + else -> { + ConnectionConfig( + keepAlive = setupFrame.keepAlive, + payloadMimeType = setupFrame.payloadMimeType, + setupPayload = setupFrame.payload + ) + } + } + } catch (cause: Throwable) { + setupFrame.close() + throw cause + } + } + } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt similarity index 94% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt index e48c4ec5..6968654e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.core import io.ktor.utils.io.core.* import io.rsocket.kotlin.* @@ -54,6 +54,7 @@ internal suspend fun connectWithReconnect( value.rSocket.coroutineContext.job.join() //await for connection completion logger.debug { "Connection closed. Reconnecting..." } } + is ReconnectState.Failed -> child.cancel("Reconnect failed", value.error) //reconnect failed, fail job ReconnectState.Connecting -> Unit //skip, still waiting for new connection } @@ -86,7 +87,14 @@ private class ReconnectableRSocket( suspend fun currentRSocket(): RSocket = state.value.current() ?: state.mapNotNull { it.current() }.first() - private suspend fun currentRSocket(closeable: Closeable): RSocket = closeable.closeOnError { currentRSocket() } + private suspend fun currentRSocket(closeable: Closeable): RSocket { + return try { + currentRSocket() + } catch (cause: Throwable) { + closeable.close() + throw cause + } + } private fun ReconnectState.current(): RSocket? = when (this) { is ReconnectState.Connected -> rSocket.takeIf(RSocket::isActive) //connection is ready to handle requests diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameCodec.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameCodec.kt new file mode 100644 index 00000000..0d6ff001 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/frame/FrameCodec.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.frame + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.internal.io.* + +internal class FrameCodec( + val bufferPool: BufferPool, + // affects encoding only, + val maxFragmentSize: Int, +) { + fun decodeFrame(frame: ByteReadPacket): Frame = frame.readFrame(bufferPool) + fun decodeFrame(expectedStreamId: Int, frame: ByteReadPacket): Frame = decodeFrame(frame).also { + if (it.streamId != expectedStreamId) { + it.close() + error("Invalid stream id, expected '$expectedStreamId', actual '${it.streamId}'") + } + } + + fun encodeFrame(frame: Frame): ByteReadPacket = frame.toPacket(bufferPool) + + // TODO: move fragmentation logic here or into separate class? +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt deleted file mode 100644 index 3d2ece0a..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.core.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.io.* -import kotlinx.coroutines.* - -@OptIn(TransportApi::class) -internal suspend inline fun connect( - connection: Connection, - isServer: Boolean, - maxFragmentSize: Int, - interceptors: Interceptors, - connectionConfig: ConnectionConfig, - acceptor: ConnectionAcceptor, - bufferPool: BufferPool, -): RSocket { - val prioritizer = Prioritizer() - val frameSender = FrameSender(prioritizer, bufferPool, maxFragmentSize) - val streamsStorage = StreamsStorage(isServer) - val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive, frameSender) - - val requestJob = SupervisorJob(connection.coroutineContext[Job]) - val requestContext = connection.coroutineContext + requestJob - - requestJob.invokeOnCompletion { - prioritizer.close(it) - streamsStorage.cleanup(it) - connectionConfig.setupPayload.close() - } - - val requester = interceptors.wrapRequester( - RSocketRequester( - requestContext + CoroutineName("rSocket-requester"), - frameSender, - streamsStorage, - ) - ) - val requestHandler = interceptors.wrapResponder( - with(interceptors.wrapAcceptor(acceptor)) { - ConnectionAcceptorContext(connectionConfig, requester).accept() - } - ) - val responder = RSocketResponder( - requestContext + CoroutineName("rSocket-responder"), - frameSender, - requestHandler - ) - - // link completing of requester, connection and requestHandler - requester.coroutineContext[Job]?.invokeOnCompletion { - connection.cancel("Requester cancelled", it) - } - requestHandler.coroutineContext[Job]?.invokeOnCompletion { - if (it != null) connection.cancel("Request handler failed", it) - } - connection.coroutineContext[Job]?.invokeOnCompletion { - requester.cancel("Connection closed", it) - requestHandler.cancel("Connection closed", it) - } - - // start keepalive ticks - (connection + CoroutineName("rSocket-connection-keep-alive")).launch { - while (isActive) keepAliveHandler.tick() - } - - // start sending frames to connection - (connection + CoroutineName("rSocket-connection-send")).launch { - while (isActive) connection.sendFrame(bufferPool, prioritizer.receive()) - } - - // start frame handling - (connection + CoroutineName("rSocket-connection-receive")).launch { - while (isActive) connection.receiveFrame(bufferPool) { frame -> - when (frame.streamId) { - 0 -> when (frame) { - is MetadataPushFrame -> responder.handleMetadataPush(frame.metadata) - is ErrorFrame -> connection.cancel("Error frame received on 0 stream", frame.throwable) - is KeepAliveFrame -> keepAliveHandler.mark(frame) - is LeaseFrame -> frame.close().also { error("lease isn't implemented") } - else -> frame.close() - } - else -> streamsStorage.handleFrame(frame, responder) - } - } - } - - return requester -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/IntMap.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/IntMap.kt index bfb00956..a518837e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/IntMap.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/IntMap.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,6 @@ package io.rsocket.kotlin.internal -import kotlinx.atomicfu.locks.* import kotlin.math.* private fun safeFindNextPositivePowerOfTwo(value: Int): Int = when { @@ -25,15 +24,18 @@ private fun safeFindNextPositivePowerOfTwo(value: Int): Int = when { else -> 1 shl 32 - (value - 1).countLeadingZeroBits() } -//TODO decide, is it needed, or can be replaced by simple map, or concurrent map on JVM? -// do benchmarks +// TODO: may be move to `internal-io` (and rename to just `rsocket-internal`) +// and use in prioritization queue to support more granular prioritization for streams +// +// TODO decide, is it needed, or can be replaced by simple map, or concurrent map on JVM? +// do benchmarks /** * IntMap implementation based on Netty IntObjectHashMap. */ internal class IntMap( initialCapacity: Int = 8, - private val loadFactor: Float = 0.5f -) : SynchronizedObject() { + private val loadFactor: Float = 0.5f, +) { init { require(loadFactor > 0.0f && loadFactor <= 1.0f) { "loadFactor must be > 0 and <= 1" } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt deleted file mode 100644 index 0ac632b8..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/LoggingConnection.kt +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -@file:OptIn(TransportApi::class, RSocketLoggingApi::class) - -package io.rsocket.kotlin.internal - -import io.ktor.utils.io.core.* -import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.logging.* - -internal fun Connection.logging(logger: Logger, bufferPool: BufferPool): Connection = - if (logger.isLoggable(LoggingLevel.DEBUG)) LoggingConnection(this, logger, bufferPool) else this - -private class LoggingConnection( - private val delegate: Connection, - private val logger: Logger, - private val bufferPool: BufferPool, -) : Connection by delegate { - - private fun ByteReadPacket.dumpFrameToString(): String { - val length = remaining - return copy().use { it.readFrame(bufferPool).use { it.dump(length) } } - } - - override suspend fun send(packet: ByteReadPacket) { - logger.debug { "Send: ${packet.dumpFrameToString()}" } - delegate.send(packet) - } - - override suspend fun receive(): ByteReadPacket { - val packet = delegate.receive() - logger.debug { "Receive: ${packet.dumpFrameToString()}" } - return packet - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadAssembler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadAssembler.kt new file mode 100644 index 00000000..3fefe658 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadAssembler.kt @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* +import io.rsocket.kotlin.payload.* + +// TODO: make metadata should be fully transmitted before data +internal class PayloadAssembler : Closeable { + // TODO: better name + var hasPayload: Boolean = false + private set + private var hasMetadata: Boolean = false + + private val data = BytePacketBuilder(NoPool) + private val metadata = BytePacketBuilder(NoPool) + + fun appendFragment(fragment: Payload) { + hasPayload = true + data.writePacket(fragment.data) + + val meta = fragment.metadata ?: return + hasMetadata = true + metadata.writePacket(meta) + } + + fun assemblePayload(fragment: Payload): Payload { + if (!hasPayload) return fragment + + appendFragment(fragment) + + val payload = Payload( + data = data.build(), + metadata = when { + hasMetadata -> metadata.build() + else -> null + } + ) + hasMetadata = false + hasPayload = false + return payload + } + + override fun close() { + data.close() + metadata.close() + } + + @Suppress("DEPRECATION") + private object NoPool : ObjectPool { + override val capacity: Int get() = error("should not be called") + + override fun borrow(): ChunkBuffer { + error("should not be called") + } + + override fun dispose() { + error("should not be called") + } + + override fun recycle(instance: ChunkBuffer) { + error("should not be called") + } + } +} \ No newline at end of file diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadChannel.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadChannel.kt new file mode 100644 index 00000000..656e34f4 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadChannel.kt @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.flow.* + +internal class PayloadChannel { + // TODO: capacity should be configurable + private val payloads = channelForCloseable(Channel.UNLIMITED) + private val requestNs = Channel(Channel.UNLIMITED) + + suspend fun nextRequestN(): Int? = requestNs.receiveCatching().getOrNull() + + @OptIn(DelicateCoroutinesApi::class) + val isActive: Boolean get() = !payloads.isClosedForSend + + fun trySend(payload: Payload) { + if (payloads.trySend(payload).isFailure) payload.close() + } + + @ExperimentalStreamsApi + suspend fun consumeInto(collector: FlowCollector, strategy: RequestStrategy.Element): Throwable? { + // TODO: requestNs should be cancelled on success path? + payloads.consume { + while (true) { + payloads + .receiveCatching() + .onClosed { return it } + .getOrThrow() // will never throw + .also { collector.emit(it) } // emit frame + + @OptIn(DelicateCoroutinesApi::class) + if (requestNs.isClosedForSend) continue + + val next = strategy.nextRequest() + if (next <= 0) continue + + // if this fails, it's means that requests no longer possible; + // next payloads.receiveCatching() should return a closed state + requestNs.trySend(next) + } + } + } + + fun close(cause: Throwable?) { + requestNs.cancel() + payloads.close(cause) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadFlow.kt new file mode 100644 index 00000000..180ce1bf --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadFlow.kt @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.payload.* +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* + +@ExperimentalStreamsApi +internal inline fun payloadFlow( + crossinline block: suspend FlowCollector.(strategy: RequestStrategy.Element, initialRequest: Int) -> Unit, +): Flow = object : PayloadFlow() { + override suspend fun collect(collector: FlowCollector, strategy: RequestStrategy.Element, initialRequest: Int) { + return collector.block(strategy, initialRequest) + } +} + +@ExperimentalStreamsApi +internal abstract class PayloadFlow : Flow { + private val consumed = atomic(false) + + override suspend fun collect(collector: FlowCollector) { + check(!consumed.getAndSet(true)) { "RequestFlow can be collected just once" } + + val strategy = currentCoroutineContext().requestStrategy() + val initialRequest = strategy.firstRequest() + collect(collector, strategy, initialRequest) + } + + abstract suspend fun collect(collector: FlowCollector, strategy: RequestStrategy.Element, initialRequest: Int) +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadLimiter.kt similarity index 55% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadLimiter.kt index 523f19cf..011373f8 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/PayloadLimiter.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,18 +20,21 @@ import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* import kotlinx.atomicfu.locks.* import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* -import kotlin.coroutines.* internal suspend inline fun Flow.collectLimiting( - limiter: Limiter, - crossinline action: suspend (value: Payload) -> Unit + limiter: PayloadLimiter, + crossinline action: suspend (value: Payload) -> Unit, ) { collect { payload -> - payload.closeOnError { + try { limiter.useRequest() - action(it) + } catch (cause: Throwable) { + payload.close() + throw cause } + action(payload) } } @@ -47,40 +50,20 @@ internal suspend inline fun Flow.collectLimiting( * The coroutine is resumed when [updateRequests] is called. * */ -internal class Limiter(initial: Int) : SynchronizedObject() { - private val requests: AtomicLong = atomic(initial.toLong()) - private var awaiter: CancellableContinuation? = null +internal class PayloadLimiter(initial: Int) : SynchronizedObject() { + private val requests = atomic(initial) + private val requestNs = Channel(Channel.UNLIMITED) fun updateRequests(n: Int) { if (n <= 0) return - synchronized(this) { - val updatedRequests = requests.value + n.toLong() - if (updatedRequests < 0) { - requests.value = Long.MAX_VALUE - } else { - requests.value = updatedRequests - } - - if (awaiter?.isActive == true) { - awaiter?.resume(Unit) - awaiter = null - } - } + requestNs.trySend(n) } suspend fun useRequest() { - if (requests.decrementAndGet() >= 0) { + if (requests.decrementAndGet() > 0) { currentCoroutineContext().ensureActive() } else { - suspendCancellableCoroutine { continuation -> - synchronized(this) { - if (requests.value >= 0 && continuation.isActive) { - continuation.resume(Unit) - } else { - this.awaiter = continuation - } - } - } + requests.value = requestNs.receive() } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt deleted file mode 100644 index 59d55511..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2015-2023 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.io.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.selects.* - -private val selectFrame: suspend (Frame) -> Frame = { it } - -internal class Prioritizer { - private val priorityChannel = channelForCloseable(Channel.UNLIMITED) - private val commonChannel = channelForCloseable(Channel.UNLIMITED) - - suspend fun send(frame: Frame) { - currentCoroutineContext().ensureActive() - val channel = if (frame.streamId == 0) priorityChannel else commonChannel - channel.send(frame) - } - - suspend fun receive(): Frame { - priorityChannel.tryReceive().onSuccess { return it } - commonChannel.tryReceive().onSuccess { return it } - return select { - priorityChannel.onReceive(selectFrame) - commonChannel.onReceive(selectFrame) - } - } - - fun close(error: Throwable?) { - priorityChannel.cancelWithCause(error) - commonChannel.cancelWithCause(error) - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt deleted file mode 100644 index e3b91e32..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.ktor.utils.io.core.* -import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.handler.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.flow.* -import kotlin.coroutines.* - -//TODO may be need to move all calls on transport dispatcher -@OptIn(ExperimentalStreamsApi::class) -internal class RSocketRequester( - override val coroutineContext: CoroutineContext, - private val sender: FrameSender, - private val streamsStorage: StreamsStorage, -) : RSocket { - - override suspend fun metadataPush(metadata: ByteReadPacket) { - ensureActiveOrRelease(metadata) - metadata.closeOnError { - sender.sendMetadataPush(metadata) - } - } - - override suspend fun fireAndForget(payload: Payload) { - ensureActiveOrRelease(payload) - - val id = streamsStorage.nextId() - try { - sender.sendRequestPayload(FrameType.RequestFnF, id, payload) - } catch (cause: Throwable) { - payload.close() - if (isActive) sender.sendCancel(id) //if cancelled during fragmentation - throw cause - } - } - - override suspend fun requestResponse(payload: Payload): Payload { - ensureActiveOrRelease(payload) - - val id = streamsStorage.nextId() - - val deferred = CompletableDeferred() - val handler = RequesterRequestResponseFrameHandler(id, streamsStorage, deferred) - streamsStorage.save(id, handler) - - return handler.receiveOrCancel(id, payload) { - sender.sendRequestPayload(FrameType.RequestResponse, id, payload) - deferred.await() - } - } - - override fun requestStream(payload: Payload): Flow = requestFlow { strategy, initialRequest -> - ensureActiveOrRelease(payload) - - val id = streamsStorage.nextId() - - val channel = channelForCloseable(Channel.UNLIMITED) - val handler = RequesterRequestStreamFrameHandler(id, streamsStorage, channel) - streamsStorage.save(id, handler) - - handler.receiveOrCancel(id, payload) { - sender.sendRequestPayload(FrameType.RequestStream, id, payload, initialRequest) - emitAllWithRequestN(channel, strategy) { sender.sendRequestN(id, it) } - } - } - - override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = - requestFlow { strategy, initialRequest -> - ensureActiveOrRelease(initPayload) - - val id = streamsStorage.nextId() - - val channel = channelForCloseable(Channel.UNLIMITED) - val limiter = Limiter(0) - val payloadsJob = Job(this@RSocketRequester.coroutineContext.job) - val handler = RequesterRequestChannelFrameHandler(id, streamsStorage, limiter, payloadsJob, channel) - streamsStorage.save(id, handler) - - handler.receiveOrCancel(id, initPayload) { - sender.sendRequestPayload(FrameType.RequestChannel, id, initPayload, initialRequest) - //TODO lazy? - launch(payloadsJob) { - handler.sendOrFail(id) { - payloads.collectLimiting(limiter) { sender.sendNextPayload(id, it) } - sender.sendCompletePayload(id) - } - } - emitAllWithRequestN(channel, strategy) { sender.sendRequestN(id, it) } - } - } - - private suspend inline fun SendFrameHandler.sendOrFail(id: Int, block: () -> Unit) { - try { - block() - onSendComplete() - } catch (cause: Throwable) { - val isFailed = onSendFailed(cause) - if (isActive && isFailed) sender.sendError(id, cause) - throw cause - } - } - - private suspend inline fun ReceiveFrameHandler.receiveOrCancel(id: Int, payload: Payload, block: () -> T): T { - try { - val result = block() - onReceiveComplete() - return result - } catch (cause: Throwable) { - payload.close() - val isCancelled = onReceiveCancelled(cause) - if (isActive && isCancelled) sender.sendCancel(id) - throw cause - } - } - - private fun ensureActiveOrRelease(closeable: Closeable) { - if (isActive) return - closeable.close() - ensureActive() - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt deleted file mode 100644 index 7eddd007..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright 2015-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.ktor.utils.io.core.* -import io.rsocket.kotlin.* -import io.rsocket.kotlin.internal.handler.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* -import kotlin.coroutines.* - -@OptIn(ExperimentalStreamsApi::class) -internal class RSocketResponder( - override val coroutineContext: CoroutineContext, - private val sender: FrameSender, - private val requestHandler: RSocket -) : CoroutineScope { - - fun handleMetadataPush(metadata: ByteReadPacket): Job = launch { - requestHandler.metadataPush(metadata) - }.closeOnCompletion(metadata) - - fun handleFireAndForget(payload: Payload, handler: ResponderFireAndForgetFrameHandler): Job = launch { - try { - requestHandler.fireAndForget(payload) - } finally { - handler.onSendComplete() - } - }.closeOnCompletion(payload) - - fun handleRequestResponse(payload: Payload, id: Int, handler: ResponderRequestResponseFrameHandler): Job = launch { - handler.sendOrFail(id, payload) { - val response = requestHandler.requestResponse(payload) - sender.sendNextCompletePayload(id, response) - } - }.closeOnCompletion(payload) - - fun handleRequestStream(payload: Payload, id: Int, handler: ResponderRequestStreamFrameHandler): Job = launch { - handler.sendOrFail(id, payload) { - requestHandler.requestStream(payload).collectLimiting(handler.limiter) { sender.sendNextPayload(id, it) } - sender.sendCompletePayload(id) - } - }.closeOnCompletion(payload) - - fun handleRequestChannel(payload: Payload, id: Int, handler: ResponderRequestChannelFrameHandler): Job = launch { - val payloads = requestFlow { strategy, initialRequest -> - handler.receiveOrCancel(id) { - sender.sendRequestN(id, initialRequest) - emitAllWithRequestN(handler.channel, strategy) { sender.sendRequestN(id, it) } - } - } - handler.sendOrFail(id, payload) { - requestHandler.requestChannel(payload, payloads) - .collectLimiting(handler.limiter) { sender.sendNextPayload(id, it) } - sender.sendCompletePayload(id) - } - }.closeOnCompletion(payload) - - private suspend inline fun SendFrameHandler.sendOrFail(id: Int, payload: Payload, block: () -> Unit) { - try { - block() - onSendComplete() - } catch (cause: Throwable) { - val isFailed = onSendFailed(cause) - if (currentCoroutineContext().isActive && isFailed) sender.sendError(id, cause) - throw cause - } finally { - payload.close() - } - } - - private suspend inline fun ReceiveFrameHandler.receiveOrCancel(id: Int, block: () -> Unit) { - try { - block() - onReceiveComplete() - } catch (cause: Throwable) { - val isCancelled = onReceiveCancelled(cause) - if (isActive && isCancelled) sender.sendCancel(id) - throw cause - } - } - - private fun Job.closeOnCompletion(closeable: Closeable): Job { - invokeOnCompletion { - closeable.close() - } - return this - } - -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RequestFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RequestFlow.kt deleted file mode 100644 index 31a34e96..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RequestFlow.kt +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.payload.* -import kotlinx.atomicfu.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.flow.* - -@ExperimentalStreamsApi -internal inline fun requestFlow( - crossinline block: suspend FlowCollector.(strategy: RequestStrategy.Element, initialRequest: Int) -> Unit -): Flow = object : RequestFlow() { - override suspend fun collect( - collector: FlowCollector, - strategy: RequestStrategy.Element, - initialRequest: Int - ) { - collector.block(strategy, initialRequest) - } -} - -@ExperimentalStreamsApi -internal suspend inline fun FlowCollector.emitAllWithRequestN( - channel: ReceiveChannel, - strategy: RequestStrategy.Element, - crossinline onRequest: suspend (n: Int) -> Unit, -) { - val collector = object : RequestFlowCollector(this, strategy) { - override suspend fun onRequest(n: Int) { - @OptIn(DelicateCoroutinesApi::class) - if (!channel.isClosedForReceive) onRequest(n) - } - } - collector.emitAll(channel) -} - -@ExperimentalStreamsApi -internal abstract class RequestFlow : Flow { - private val consumed = atomic(false) - - override suspend fun collect(collector: FlowCollector) { - check(!consumed.getAndSet(true)) { "RequestFlow can be collected just once" } - - val strategy = currentCoroutineContext().requestStrategy() - val initial = strategy.firstRequest() - collect(collector, strategy, initial) - } - - abstract suspend fun collect( - collector: FlowCollector, - strategy: RequestStrategy.Element, - initialRequest: Int - ) -} - -@ExperimentalStreamsApi -internal abstract class RequestFlowCollector( - private val collector: FlowCollector, - private val strategy: RequestStrategy.Element, -) : FlowCollector { - override suspend fun emit(value: Payload): Unit = value.closeOnError { - collector.emit(value) - val next = strategy.nextRequest() - if (next > 0) onRequest(next) - } - - abstract suspend fun onRequest(n: Int) -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt deleted file mode 100644 index e6956a9f..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.handler.* -import kotlinx.atomicfu.locks.* - -internal class StreamsStorage( - private val isServer: Boolean, -) : SynchronizedObject() { - private val streamId: StreamId = StreamId(isServer) - private val handlers: IntMap = IntMap() - - fun nextId(): Int = synchronized(this) { streamId.next(handlers) } - fun save(id: Int, handler: FrameHandler) = synchronized(this) { handlers[id] = handler } - fun remove(id: Int): FrameHandler? = synchronized(this) { handlers.remove(id) }?.also(FrameHandler::close) - fun contains(id: Int): Boolean = synchronized(this) { id in handlers } - private fun get(id: Int): FrameHandler? = synchronized(this) { handlers[id] } - - fun cleanup(error: Throwable?) { - val values = synchronized(this) { - val values = handlers.values() - handlers.clear() - values - } - values.forEach { - it.cleanup(error) - it.close() - } - } - - fun handleFrame(frame: Frame, responder: RSocketResponder) { - val id = frame.streamId - when (frame) { - is RequestNFrame -> get(id)?.handleRequestN(frame.requestN) - is CancelFrame -> get(id)?.handleCancel() - is ErrorFrame -> get(id)?.handleError(frame.throwable) - is RequestFrame -> when { - frame.type == FrameType.Payload -> get(id)?.handleRequest(frame) - ?: frame.close() // release on unknown stream id - isServer.xor(id % 2 != 0) -> frame.close() // request frame on wrong stream id - else -> { - val initialRequest = frame.initialRequest - val handler = when (frame.type) { - FrameType.RequestFnF -> ResponderFireAndForgetFrameHandler(id, this, responder) - FrameType.RequestResponse -> ResponderRequestResponseFrameHandler(id, this, responder) - FrameType.RequestStream -> ResponderRequestStreamFrameHandler( - id, - this, - responder, - initialRequest, - ) - FrameType.RequestChannel -> ResponderRequestChannelFrameHandler( - id, - this, - responder, - initialRequest, - ) - else -> error("Wrong request frame type") // should never happen - } - save(id, handler) - handler.handleRequest(frame) - } - } - else -> frame.close() - } - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt deleted file mode 100644 index 6d78f54d..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/FrameHandler.kt +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.ktor.utils.io.core.* -import io.ktor.utils.io.core.internal.* -import io.ktor.utils.io.pool.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* - -internal abstract class FrameHandler : Closeable { - private val data = BytePacketBuilder(NoPool) - private val metadata = BytePacketBuilder(NoPool) - private var hasMetadata: Boolean = false - - fun handleRequest(frame: RequestFrame) { - if (frame.next || frame.type.isRequestType) handleNextFragment(frame) - if (frame.complete) handleComplete() - } - - private fun handleNextFragment(frame: RequestFrame) { - data.writePacket(frame.payload.data) - when (val meta = frame.payload.metadata) { - null -> Unit - else -> { - hasMetadata = true - metadata.writePacket(meta) - } - } - if (frame.follows && !frame.complete) return - - val payload = Payload(data.build(), if (hasMetadata) metadata.build() else null) - hasMetadata = false - handleNext(payload) - } - - protected abstract fun handleNext(payload: Payload) - protected abstract fun handleComplete() - abstract fun handleError(cause: Throwable) - abstract fun handleCancel() - abstract fun handleRequestN(n: Int) - - abstract fun cleanup(cause: Throwable?) - - override fun close() { - data.close() - metadata.close() - } -} - -internal interface ReceiveFrameHandler { - fun onReceiveComplete() - fun onReceiveCancelled(cause: Throwable): Boolean // if true, then request is cancelled -} - -internal interface SendFrameHandler { - fun onSendComplete() - fun onSendFailed(cause: Throwable): Boolean // if true, then request is failed -} - -internal abstract class RequesterFrameHandler : FrameHandler(), ReceiveFrameHandler { - override fun handleCancel() { - //should be called only for RC - } - - override fun handleRequestN(n: Int) { - //should be called only for RC - } -} - -internal abstract class ResponderFrameHandler : FrameHandler(), SendFrameHandler { - protected var job: Job? = null - - protected abstract fun start(payload: Payload): Job - - final override fun handleNext(payload: Payload) { - if (job == null) job = start(payload) - else handleNextPayload(payload) - } - - protected open fun handleNextPayload(payload: Payload) { - //should be called only for RC - } - - override fun handleComplete() { - //should be called only for RC - } - - override fun handleError(cause: Throwable) { - //should be called only for RC - } -} - -@Suppress("DEPRECATION") -private object NoPool : ObjectPool { - override val capacity: Int - get() = error("should not be called") - - override fun borrow(): ChunkBuffer { - error("should not be called") - } - - override fun dispose() { - error("should not be called") - } - - override fun recycle(instance: ChunkBuffer) { - error("should not be called") - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestChannelFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestChannelFrameHandler.kt deleted file mode 100644 index 27b9b3e0..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestChannelFrameHandler.kt +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* - -internal class RequesterRequestChannelFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val limiter: Limiter, - private val sender: Job, - private val channel: Channel, -) : RequesterFrameHandler(), SendFrameHandler { - - override fun handleNext(payload: Payload) { - channel.safeTrySend(payload) - } - - override fun handleComplete() { - channel.close() - } - - override fun handleError(cause: Throwable) { - streamsStorage.remove(id) - channel.cancelWithCause(cause) - sender.cancel("Request failed", cause) - } - - override fun handleCancel() { - sender.cancel("Request cancelled") - } - - override fun handleRequestN(n: Int) { - limiter.updateRequests(n) - } - - override fun cleanup(cause: Throwable?) { - channel.cancelWithCause(cause) - sender.cancel("Connection closed", cause) - } - - override fun onReceiveComplete() { - if (!sender.isActive) streamsStorage.remove(id) - } - - override fun onReceiveCancelled(cause: Throwable): Boolean { - val isCancelled = streamsStorage.remove(id) != null - if (isCancelled) sender.cancel("Request cancelled", cause) - return isCancelled - } - - @OptIn(DelicateCoroutinesApi::class) - override fun onSendComplete() { - if (channel.isClosedForSend) streamsStorage.remove(id) - } - - override fun onSendFailed(cause: Throwable): Boolean { - if (sender.isCancelled) return false - - val isFailed = streamsStorage.remove(id) != null - if (isFailed) channel.cancelWithCause(cause) - return isFailed - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandler.kt deleted file mode 100644 index dfe734e7..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandler.kt +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* - -internal class RequesterRequestResponseFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val deferred: CompletableDeferred, -) : RequesterFrameHandler() { - override fun handleNext(payload: Payload) { - deferred.complete(payload) - } - - override fun handleComplete() { - //ignore - } - - override fun handleError(cause: Throwable) { - streamsStorage.remove(id) - deferred.completeExceptionally(cause) - } - - override fun cleanup(cause: Throwable?) { - deferred.cancel("Connection closed", cause) - } - - override fun onReceiveComplete() { - streamsStorage.remove(id) - } - - override fun onReceiveCancelled(cause: Throwable): Boolean = streamsStorage.remove(id) != null -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestStreamFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestStreamFrameHandler.kt deleted file mode 100644 index 227fd3c3..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestStreamFrameHandler.kt +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.channels.* - -internal class RequesterRequestStreamFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val channel: Channel, -) : RequesterFrameHandler() { - - override fun handleNext(payload: Payload) { - channel.safeTrySend(payload) - } - - override fun handleComplete() { - channel.close() - } - - override fun handleError(cause: Throwable) { - streamsStorage.remove(id) - channel.cancelWithCause(cause) - } - - override fun cleanup(cause: Throwable?) { - channel.cancelWithCause(cause) - } - - override fun onReceiveComplete() { - streamsStorage.remove(id) - } - - override fun onReceiveCancelled(cause: Throwable): Boolean = streamsStorage.remove(id) != null -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderFireAndForgetFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderFireAndForgetFrameHandler.kt deleted file mode 100644 index abef9352..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderFireAndForgetFrameHandler.kt +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* - -internal class ResponderFireAndForgetFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val responder: RSocketResponder, -) : ResponderFrameHandler() { - - override fun start(payload: Payload): Job = responder.handleFireAndForget(payload, this) - - override fun handleCancel() { - streamsStorage.remove(id) - job?.cancel("Request cancelled") - } - - override fun handleRequestN(n: Int) { - //ignore - } - - override fun cleanup(cause: Throwable?) { - //ignore - } - - override fun onSendComplete() { - streamsStorage.remove(id) - } - - override fun onSendFailed(cause: Throwable): Boolean = false -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestChannelFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestChannelFrameHandler.kt deleted file mode 100644 index ac98537b..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestChannelFrameHandler.kt +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.internal.io.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* - -internal class ResponderRequestChannelFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val responder: RSocketResponder, - initialRequest: Int, -) : ResponderFrameHandler(), ReceiveFrameHandler { - val limiter = Limiter(initialRequest) - val channel = channelForCloseable(Channel.UNLIMITED) - - @OptIn(ExperimentalStreamsApi::class) - override fun start(payload: Payload): Job = responder.handleRequestChannel(payload, id, this) - - override fun handleNextPayload(payload: Payload) { - channel.safeTrySend(payload) - } - - override fun handleComplete() { - channel.close() - } - - override fun handleError(cause: Throwable) { - streamsStorage.remove(id) - channel.cancelWithCause(cause) - } - - override fun handleCancel() { - streamsStorage.remove(id) - val cancelError = CancellationException("Request cancelled") - channel.cancelWithCause(cancelError) - job?.cancel(cancelError) - } - - override fun handleRequestN(n: Int) { - limiter.updateRequests(n) - } - - override fun cleanup(cause: Throwable?) { - channel.cancelWithCause(cause) - } - - override fun onSendComplete() { - @OptIn(DelicateCoroutinesApi::class) - if (channel.isClosedForSend) streamsStorage.remove(id) - } - - override fun onSendFailed(cause: Throwable): Boolean { - val isFailed = streamsStorage.remove(id) != null - if (isFailed) channel.cancelWithCause(cause) - return isFailed - } - - override fun onReceiveComplete() { - val job = this.job!! //always not null here - if (!job.isActive) streamsStorage.remove(id) - } - - override fun onReceiveCancelled(cause: Throwable): Boolean { - val job = this.job!! //always not null here - if (!streamsStorage.contains(id) && job.isActive) job.cancel("Request handling failed [Error frame]", cause) - return !job.isCancelled - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestResponseFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestResponseFrameHandler.kt deleted file mode 100644 index 61084100..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestResponseFrameHandler.kt +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* - -internal class ResponderRequestResponseFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val responder: RSocketResponder, -) : ResponderFrameHandler() { - - override fun start(payload: Payload): Job = responder.handleRequestResponse(payload, id, this) - - override fun handleCancel() { - streamsStorage.remove(id) - job?.cancel("Request cancelled") - } - - override fun handleRequestN(n: Int) { - //ignore - } - - override fun cleanup(cause: Throwable?) { - //ignore - } - - override fun onSendComplete() { - streamsStorage.remove(id) - } - - override fun onSendFailed(cause: Throwable): Boolean = streamsStorage.remove(id) != null -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestStreamFrameHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestStreamFrameHandler.kt deleted file mode 100644 index f7932b46..00000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/handler/ResponderRequestStreamFrameHandler.kt +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2015-2024 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.handler - -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* - -internal class ResponderRequestStreamFrameHandler( - private val id: Int, - private val streamsStorage: StreamsStorage, - private val responder: RSocketResponder, - initialRequest: Int, -) : ResponderFrameHandler() { - val limiter = Limiter(initialRequest) - - override fun start(payload: Payload): Job = responder.handleRequestStream(payload, id, this) - - override fun handleCancel() { - streamsStorage.remove(id) - job?.cancel("Request cancelled") - } - - override fun handleRequestN(n: Int) { - limiter.updateRequests(n) - } - - override fun cleanup(cause: Throwable?) { - //ignore - } - - override fun onSendComplete() { - streamsStorage.remove(id) - } - - override fun onSendFailed(cause: Throwable): Boolean = streamsStorage.remove(id) != null -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/keepalive/KeepAliveHandler.kt similarity index 53% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/keepalive/KeepAliveHandler.kt index 15c94b1e..c30670c4 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/keepalive/KeepAliveHandler.kt @@ -14,35 +14,45 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.keepalive import io.ktor.utils.io.core.* import io.rsocket.kotlin.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.keepalive.* +import io.rsocket.kotlin.connection.* +import io.rsocket.kotlin.transport.* import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlin.time.* +@RSocketTransportApi internal class KeepAliveHandler( private val keepAlive: KeepAlive, - private val sender: FrameSender, + private val connection2: Connection2, + private val connectionScope: CoroutineScope, ) { private val initial = TimeSource.Monotonic.markNow() private fun currentDelayMillis() = initial.elapsedNow().inWholeMilliseconds private val lastMark = atomic(currentDelayMillis()) // mark initial timestamp for keepalive - suspend fun mark(frame: KeepAliveFrame) { - lastMark.value = currentDelayMillis() - if (frame.respond) sender.sendKeepAlive(false, 0, frame.data) - } + init { + // this could be moved to a function like `run` or `start` + connectionScope.launch { + while (true) { + delay(keepAlive.intervalMillis.toLong()) + if (currentDelayMillis() - lastMark.value >= keepAlive.maxLifetimeMillis) + throw RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms") - suspend fun tick() { - delay(keepAlive.intervalMillis.toLong()) - if (currentDelayMillis() - lastMark.value >= keepAlive.maxLifetimeMillis) - throw RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms") + connection2.sendKeepAlive(true, ByteReadPacket.Empty, 0) + } + } + } - sender.sendKeepAlive(true, 0, ByteReadPacket.Empty) + fun receive(data: ByteReadPacket, respond: Boolean) { + lastMark.value = currentDelayMillis() + // in most cases it will be possible to not suspend at all + if (respond) connectionScope.launch(start = CoroutineStart.UNDISPATCHED) { + connection2.sendKeepAlive(false, data, 0) + } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/Operation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/Operation.kt new file mode 100644 index 00000000..1197c62d --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/Operation.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.payload.* + +internal interface Operation : OperationInbound { + // after `execute` is completed, no other interactions with the operation are possible + suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) + + // for requester only + responder RC + // should not throw + fun operationFailure(cause: Throwable) {} +} + +internal inline fun Operation.handleExecutionFailure(requestPayload: Payload, block: () -> Unit) { + try { + block() + } catch (cause: Throwable) { + operationFailure(cause) + requestPayload.close() + throw cause + } +} + +internal data class ResponderOperationData( + val streamId: Int, + val requestType: FrameType, + val initialRequest: Int, + val requestPayload: Payload, + val complete: Boolean, +) + +// just marker interface +internal interface RequesterOperation : Operation + +// just marker interface +internal interface ResponderOperation : Operation diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationInbound.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationInbound.kt new file mode 100644 index 00000000..e3eceff5 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationInbound.kt @@ -0,0 +1,101 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* + +internal interface OperationInbound { + fun shouldReceiveFrame(frameType: FrameType): Boolean + + // payload is null when `next` flag was not set + fun receivePayloadFrame(payload: Payload?, complete: Boolean): Unit = error("Payload frame is not expected to be received") + fun receiveRequestNFrame(requestN: Int): Unit = error("RequestN frame is not expected to be received") + fun receiveErrorFrame(cause: Throwable): Unit = error("Error frame is not expected to be received") + fun receiveCancelFrame(): Unit = error("Cancel frame is not expected to be received") + + // for streaming case, when stream will not receive any more frames + fun receiveDone() {} +} + +// TODO: merge into OperationInbound? +internal class OperationFrameHandler(private val inbound: OperationInbound) { + private val assembler = PayloadAssembler() + + fun close() { + assembler.close() + } + + fun handleDone() { + inbound.receiveDone() + } + + fun handleFrame(frame: Frame) { + if (!inbound.shouldReceiveFrame(frame.type)) { + // TODO: replace with logging + println("unexpected frame: $frame") + return frame.close() + } + + when (frame) { + is CancelFrame -> inbound.receiveCancelFrame() + is ErrorFrame -> inbound.receiveErrorFrame(frame.throwable) + is RequestNFrame -> inbound.receiveRequestNFrame(frame.requestN) + is RequestFrame -> { + // TODO: split frames + if (frame.initialRequest != 0) inbound.receiveRequestNFrame(frame.initialRequest) + + val payload = when { + // complete+follows=complete + frame.complete -> when { + frame.next -> assembler.assemblePayload(frame.payload) + // TODO - what if we previously received fragment? + else -> { + check(!assembler.hasPayload) { "wrong combination of frames" } + null + } + } + + frame.next -> when { + // if follows - then it's not the last fragment + frame.follows -> { + assembler.appendFragment(frame.payload) + return + } + + else -> assembler.assemblePayload(frame.payload) + } + + else -> error("wrong flags") + } + + inbound.receivePayloadFrame(payload, frame.complete) +// +// +// // TODO: recheck notes +// // TODO: if there are no fragments saved and there are no following - we can ignore going through buffer +// // TODO: really, fragment could be NULL when `complete` is true, but `next` is false +// if (frame.next || frame.type.isRequestType) appendFragment(frame.payload) +// if (frame.complete) inbound.receivePayloadFrame(assemblePayload(), complete = true) +// else if (!frame.follows) inbound.receivePayloadFrame(assemblePayload(), complete = false) + } + + else -> error("should not happen") + } + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/FrameSender.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt similarity index 52% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/FrameSender.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt index 240936bb..0b67e014 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/FrameSender.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/OperationOutbound.kt @@ -14,14 +14,12 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.operation import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.frame.io.* -import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* import kotlin.math.* private const val lengthSize = 3 @@ -29,57 +27,63 @@ private const val headerSize = 6 private const val fragmentOffset = lengthSize + headerSize private const val fragmentOffsetWithMetadata = fragmentOffset + lengthSize -internal class FrameSender( - private val prioritizer: Prioritizer, - private val pool: BufferPool, - private val maxFragmentSize: Int, +internal abstract class OperationOutbound( + protected val streamId: Int, + private val frameCodec: FrameCodec, ) { + // TODO: decide on it + // private var firstRequestFrameSent: Boolean = false - suspend fun sendKeepAlive(respond: Boolean, lastPosition: Long, data: ByteReadPacket): Unit = - prioritizer.send(KeepAliveFrame(respond, lastPosition, data)) + abstract val isClosed: Boolean - suspend fun sendMetadataPush(metadata: ByteReadPacket): Unit = prioritizer.send(MetadataPushFrame(metadata)) + protected abstract suspend fun sendFrame(frame: ByteReadPacket) + private suspend fun sendFrame(frame: Frame): Unit = sendFrame(frameCodec.encodeFrame(frame)) - suspend fun sendCancel(id: Int): Unit = withContext(NonCancellable) { prioritizer.send(CancelFrame(id)) } - suspend fun sendError(id: Int, throwable: Throwable): Unit = - withContext(NonCancellable) { prioritizer.send(ErrorFrame(id, throwable)) } + suspend fun sendError(cause: Throwable) { + return sendFrame(ErrorFrame(streamId, cause)) + } - suspend fun sendRequestN(id: Int, n: Int): Unit = prioritizer.send(RequestNFrame(id, n)) + suspend fun sendCancel() { + return sendFrame(CancelFrame(streamId)) + } - suspend fun sendRequestPayload(type: FrameType, streamId: Int, payload: Payload, initialRequest: Int = 0) { - sendFragmented(type, streamId, payload, false, false, initialRequest) + suspend fun sendRequestN(requestN: Int) { + return sendFrame(RequestNFrame(streamId, requestN)) } - suspend fun sendNextPayload(streamId: Int, payload: Payload) { - sendFragmented(FrameType.Payload, streamId, payload, false, true, 0) + suspend fun sendComplete() { + return sendFrame( + RequestFrame( + type = FrameType.Payload, + streamId = streamId, + follows = false, + complete = true, + next = false, + initialRequest = 0, + payload = Payload.Empty + ) + ) } - suspend fun sendNextCompletePayload(streamId: Int, payload: Payload) { - sendFragmented(FrameType.Payload, streamId, payload, true, true, 0) + suspend fun sendNext(payload: Payload, complete: Boolean) { + return sendRequestPayload(FrameType.Payload, payload, complete, initialRequest = 0) } - suspend fun sendCompletePayload(streamId: Int) { - sendFragmented(FrameType.Payload, streamId, Payload.Empty, true, false, 0) + suspend fun sendRequest(type: FrameType, payload: Payload, complete: Boolean, initialRequest: Int) { + return sendRequestPayload(type, payload, complete, initialRequest) } - private suspend fun sendFragmented( - type: FrameType, - streamId: Int, - payload: Payload, - complete: Boolean, - next: Boolean, - initialRequest: Int - ) { - //TODO release on fail ? + // TODO rework/simplify later + // TODO release on fail ? + private suspend fun sendRequestPayload(type: FrameType, payload: Payload, complete: Boolean, initialRequest: Int) { if (!payload.isFragmentable(type.hasInitialRequest)) { - prioritizer.send(RequestFrame(type, streamId, false, complete, next, initialRequest, payload)) - return + return sendFrame(RequestFrame(type, streamId, false, complete, true, initialRequest, payload)) } val data = payload.data val metadata = payload.metadata - val fragmentSize = maxFragmentSize - fragmentOffset - (if (type.hasInitialRequest) Int.SIZE_BYTES else 0) + val fragmentSize = frameCodec.maxFragmentSize - fragmentOffset - (if (type.hasInitialRequest) Int.SIZE_BYTES else 0) var first = true var remaining = fragmentSize @@ -90,13 +94,13 @@ internal class FrameSender( if (!first) remaining -= lengthSize val length = min(metadata.remaining.toInt(), remaining) remaining -= length - metadata.readPacket(pool, length) + metadata.readPacket(frameCodec.bufferPool, length) } else null val dataFragment = if (remaining > 0 && data.isNotEmpty) { val length = min(data.remaining.toInt(), remaining) remaining -= length - data.readPacket(pool, length) + data.readPacket(frameCodec.bufferPool, length) } else { ByteReadPacket.Empty } @@ -104,7 +108,7 @@ internal class FrameSender( val fType = if (first && type.isRequestType) type else FrameType.Payload val fragment = Payload(dataFragment, metadataFragment) val follows = metadata != null && metadata.isNotEmpty || data.isNotEmpty - prioritizer.send( + sendFrame( RequestFrame( type = fType, streamId = streamId, @@ -120,11 +124,11 @@ internal class FrameSender( } while (follows) } - private fun Payload.isFragmentable(hasInitialRequest: Boolean) = when (maxFragmentSize) { - 0 -> false + private fun Payload.isFragmentable(hasInitialRequest: Boolean) = when (frameCodec.maxFragmentSize) { + 0 -> false else -> when (val meta = metadata) { - null -> data.remaining > maxFragmentSize - fragmentOffset - (if (hasInitialRequest) Int.SIZE_BYTES else 0) - else -> data.remaining + meta.remaining > maxFragmentSize - fragmentOffsetWithMetadata - (if (hasInitialRequest) Int.SIZE_BYTES else 0) + null -> data.remaining > frameCodec.maxFragmentSize - fragmentOffset - (if (hasInitialRequest) Int.SIZE_BYTES else 0) + else -> data.remaining + meta.remaining > frameCodec.maxFragmentSize - fragmentOffsetWithMetadata - (if (hasInitialRequest) Int.SIZE_BYTES else 0) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt new file mode 100644 index 00000000..ad290734 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterFireAndForgetOperation.kt @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +internal class RequesterFireAndForgetOperation( + private val requestSentCont: CancellableContinuation, +) : RequesterOperation { + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + outbound.sendRequest( + type = FrameType.RequestFnF, + payload = requestPayload, + complete = false, + initialRequest = 0 + ) + requestSentCont.resume(Unit) + } catch (cause: Throwable) { + if (requestSentCont.isActive) requestSentCont.resumeWithException(cause) + if (!outbound.isClosed) withContext(NonCancellable) { outbound.sendCancel() } + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = false + + override fun operationFailure(cause: Throwable) { + if (requestSentCont.isActive) requestSentCont.resumeWithException(cause) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt new file mode 100644 index 00000000..6dc02b36 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestChannelOperation.kt @@ -0,0 +1,111 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* + +internal class RequesterRequestChannelOperation( + private val initialRequestN: Int, + private val requestPayloads: Flow, + private val responsePayloads: PayloadChannel, +) : RequesterOperation { + private val limiter = PayloadLimiter(0) + private var senderJob: Job? by atomic(null) + private var failure: Throwable? = null + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + coroutineScope { + outbound.sendRequest( + type = FrameType.RequestChannel, + payload = requestPayload, + complete = false, + initialRequest = initialRequestN + ) + + senderJob = launch { + try { + requestPayloads.collectLimiting(limiter) { payload -> + outbound.sendNext(payload, complete = false) + } + outbound.sendComplete() + } catch (cause: Throwable) { + // senderJob could be cancelled + if (isActive) failure = cause + throw cause // failing senderJob here will fail request + } + } + + try { + while (true) outbound.sendRequestN(responsePayloads.nextRequestN() ?: break) + } catch (cause: Throwable) { + if (!currentCoroutineContext().isActive || !outbound.isClosed) throw cause + } + } + } catch (cause: Throwable) { + if (!outbound.isClosed) withContext(NonCancellable) { + when (val error = failure) { + null -> outbound.sendCancel() + else -> outbound.sendError(error) + } + } + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = when { + responsePayloads.isActive -> frameType == FrameType.Payload || frameType == FrameType.Error + else -> false + } || when { + // TODO: handle cancel, when `senderJob` is not started + senderJob == null || senderJob?.isActive == true -> frameType == FrameType.RequestN || frameType == FrameType.Cancel + else -> false + } + + override fun receiveRequestNFrame(requestN: Int) { + limiter.updateRequests(requestN) + } + + override fun receivePayloadFrame(payload: Payload?, complete: Boolean) { + if (payload != null) responsePayloads.trySend(payload) + if (complete) responsePayloads.close(null) + } + + override fun receiveCancelFrame() { + senderJob?.cancel("Request payloads cancelled") + } + + override fun receiveErrorFrame(cause: Throwable) { + responsePayloads.close(cause) + senderJob?.cancel("Error received from remote", cause) + } + + override fun receiveDone() { + if (responsePayloads.isActive) responsePayloads.close( + IllegalStateException("Unexpected end of stream") + ) + } + + override fun operationFailure(cause: Throwable) { + if (responsePayloads.isActive) responsePayloads.close(cause) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt new file mode 100644 index 00000000..3722a9aa --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperation.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class RequesterRequestResponseOperation( + private val responseDeferred: CompletableDeferred, +) : RequesterOperation { + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + outbound.sendRequest( + type = FrameType.RequestResponse, + payload = requestPayload, + complete = false, + initialRequest = 0 + ) + responseDeferred.join() + } catch (cause: Throwable) { + // TODO: we don't need to send cancel if we have sent no frames + if (!outbound.isClosed) withContext(NonCancellable) { outbound.sendCancel() } + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = when { + responseDeferred.isActive -> frameType === FrameType.Payload || frameType === FrameType.Error + else -> false + } + + override fun receivePayloadFrame(payload: Payload?, complete: Boolean) { + if (payload != null) { + if (!responseDeferred.complete(payload)) payload.close() + } else { + responseDeferred.completeExceptionally( + IllegalStateException("Unexpected request completion: payload should be present for RequestResponse") + ) + } + } + + override fun receiveErrorFrame(cause: Throwable) { + responseDeferred.completeExceptionally(cause) + } + + override fun receiveDone() { + if (responseDeferred.isActive) responseDeferred.completeExceptionally( + IllegalStateException("Unexpected request completion: no response received") + ) + } + + override fun operationFailure(cause: Throwable) { + if (responseDeferred.isActive) responseDeferred.completeExceptionally(cause) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt new file mode 100644 index 00000000..4e0d559f --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/RequesterRequestStreamOperation.kt @@ -0,0 +1,71 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class RequesterRequestStreamOperation( + private val initialRequestN: Int, + private val responsePayloads: PayloadChannel, +) : RequesterOperation { + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + outbound.sendRequest( + type = FrameType.RequestStream, + payload = requestPayload, + complete = false, + initialRequest = initialRequestN + ) + try { + while (true) outbound.sendRequestN(responsePayloads.nextRequestN() ?: break) + } catch (cause: Throwable) { + if (!currentCoroutineContext().isActive || !outbound.isClosed) throw cause + } + } catch (cause: Throwable) { + if (!outbound.isClosed) withContext(NonCancellable) { outbound.sendCancel() } + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = when { + responsePayloads.isActive -> frameType == FrameType.Payload || frameType == FrameType.Error + else -> false + } + + override fun receivePayloadFrame(payload: Payload?, complete: Boolean) { + if (payload != null) responsePayloads.trySend(payload) + if (complete) responsePayloads.close(null) + } + + override fun receiveErrorFrame(cause: Throwable) { + responsePayloads.close(cause) + } + + override fun receiveDone() { + if (responsePayloads.isActive) responsePayloads.close( + IllegalStateException("Unexpected end of stream") + ) + } + + override fun operationFailure(cause: Throwable) { + if (responsePayloads.isActive) responsePayloads.close(cause) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderFireAndForgetOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderFireAndForgetOperation.kt new file mode 100644 index 00000000..dd8f8c95 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderFireAndForgetOperation.kt @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderFireAndForgetOperation( + private val requestJob: Job, + private val responder: RSocket, +) : ResponderOperation { + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + responder.fireAndForget(requestPayload) + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = frameType === FrameType.Cancel + + override fun receiveCancelFrame() { + requestJob.cancel("Request was cancelled by remote party") + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt new file mode 100644 index 00000000..28828bad --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestChannelOperation.kt @@ -0,0 +1,102 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderRequestChannelOperation( + private val requestJob: Job, + private val responder: RSocket, +) : ResponderOperation { + private val limiter = PayloadLimiter(0) + private val requestPayloads = PayloadChannel() + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + coroutineScope { + @OptIn(ExperimentalStreamsApi::class) + val requestFlow = payloadFlow { strategy, initialRequest -> + // if requestPayloads flow is consumed after the request is completed - we should fail + ensureActive() + val senderJob = launch { + try { + outbound.sendRequestN(initialRequest) + while (true) outbound.sendRequestN(requestPayloads.nextRequestN() ?: break) + } catch (cause: Throwable) { + // ignore error if outbound was closed - TODO: recheck + if (this@coroutineScope.isActive && outbound.isClosed) return@launch + // send cancel only if the operation is active + if (this@coroutineScope.isActive) withContext(NonCancellable) { outbound.sendCancel() } + throw cause + } + } + + throw try { + requestPayloads.consumeInto(this, strategy) + } catch (cause: Throwable) { + senderJob.cancel() + throw cause + } ?: return@payloadFlow + } + + responder.requestChannel(requestPayload, requestFlow).collectLimiting(limiter) { responsePayload -> + outbound.sendNext(responsePayload, complete = false) + } + outbound.sendComplete() + } + } catch (cause: Throwable) { + requestPayloads.close(cause) + if (currentCoroutineContext().isActive) outbound.sendError(cause) + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = + frameType === FrameType.Cancel || when { + requestPayloads.isActive -> frameType === FrameType.Payload || frameType === FrameType.Error + else -> false + } || frameType === FrameType.RequestN // TODO + + override fun receiveRequestNFrame(requestN: Int) { + limiter.updateRequests(requestN) + } + + override fun receivePayloadFrame(payload: Payload?, complete: Boolean) { + if (payload != null) requestPayloads.trySend(payload) + if (complete) requestPayloads.close(null) + } + + override fun receiveErrorFrame(cause: Throwable) { + requestPayloads.close(cause) + } + + override fun receiveCancelFrame() { + requestJob.cancel("Request was cancelled by remote party") + } + + override fun operationFailure(cause: Throwable) { + requestPayloads.close(cause) + } + + override fun receiveDone() { + if (requestPayloads.isActive) requestPayloads.close(IllegalStateException("Unexpected end of stream")) + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestResponseOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestResponseOperation.kt new file mode 100644 index 00000000..00e20b03 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestResponseOperation.kt @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderRequestResponseOperation( + private val requestJob: Job, + private val responder: RSocket, +) : ResponderOperation { + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + val response = responder.requestResponse(requestPayload) + outbound.sendNext(response, complete = true) + } catch (cause: Throwable) { + if (currentCoroutineContext().isActive) outbound.sendError(cause) + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = frameType === FrameType.Cancel + + override fun receiveCancelFrame() { + requestJob.cancel("Request was cancelled by remote party") + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestStreamOperation.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestStreamOperation.kt new file mode 100644 index 00000000..0191d438 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/operation/ResponderRequestStreamOperation.kt @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.operation + +import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* +import io.rsocket.kotlin.payload.* +import kotlinx.coroutines.* + +internal class ResponderRequestStreamOperation( + private val requestJob: Job, + private val responder: RSocket, +) : ResponderOperation { + private val limiter = PayloadLimiter(0) + + override suspend fun execute(outbound: OperationOutbound, requestPayload: Payload) { + try { + responder.requestStream(requestPayload).collectLimiting(limiter) { responsePayload -> + outbound.sendNext(responsePayload, complete = false) + } + outbound.sendComplete() + } catch (cause: Throwable) { + if (currentCoroutineContext().isActive) outbound.sendError(cause) + throw cause + } + } + + override fun shouldReceiveFrame(frameType: FrameType): Boolean = + frameType === FrameType.RequestN || frameType === FrameType.Cancel + + override fun receiveRequestNFrame(requestN: Int) { + limiter.updateRequests(requestN) + } + + override fun receiveCancelFrame() { + requestJob.cancel("Request was cancelled by remote party") + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt new file mode 100644 index 00000000..7b62af0a --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketConnection.kt @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport + +import io.ktor.utils.io.core.* + +// all methods can be called from any thread/context at any time +// should be accessed only internally +// should be implemented only by transports +@RSocketTransportApi +public sealed interface RSocketConnection + +@RSocketTransportApi +public fun interface RSocketConnectionHandler { + public suspend fun handleConnection(connection: RSocketConnection) +} + +@RSocketTransportApi +public interface RSocketSequentialConnection : RSocketConnection { + // TODO: is it needed for connection? + public val isClosedForSend: Boolean + + // throws if frame not sent + // streamId=0 should be sent earlier + public suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) + + // null if no more frames could be received + public suspend fun receiveFrame(): ByteReadPacket? +} + +@RSocketTransportApi +public interface RSocketMultiplexedConnection : RSocketConnection { + public suspend fun createStream(): Stream + public suspend fun acceptStream(): Stream? + + public interface Stream : Closeable { + public val isClosedForSend: Boolean + + // 0 - highest priority + // Int.MAX - lowest priority + public fun setSendPriority(priority: Int) + + // throws if frame not sent + public suspend fun sendFrame(frame: ByteReadPacket) + + // null if no more frames could be received + public suspend fun receiveFrame(): ByteReadPacket? + + // closing stream will send buffered frames (if needed) + // sending/receiving frames will be not possible after it + // should not throw + override fun close() + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt new file mode 100644 index 00000000..0a5e73b1 --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransport.kt @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport + +import kotlinx.coroutines.* +import kotlin.coroutines.* + +@SubclassOptInRequired(RSocketTransportApi::class) +public abstract class RSocketTransportFactory>( + @PublishedApi internal val createBuilder: () -> Builder, +) { + @OptIn(RSocketTransportApi::class) + public inline operator fun invoke( + context: CoroutineContext, + configure: Builder.() -> Unit = {}, + ): Transport = createBuilder().apply(configure).buildTransport(context) +} + +@SubclassOptInRequired(RSocketTransportApi::class) +public interface RSocketTransportBuilder { + @RSocketTransportApi + public fun buildTransport(context: CoroutineContext): Transport +} + +@SubclassOptInRequired(RSocketTransportApi::class) +public interface RSocketTransport : CoroutineScope { + // transports should have methods like: + // `fun target(address: SocketAddress): RSocketClientTarget` +} + +@SubclassOptInRequired(RSocketTransportApi::class) +public interface RSocketClientTarget : CoroutineScope { + // cancelling Job will cancel connection + // Job will be completed when the connection is finished + @RSocketTransportApi + public fun connectClient(handler: RSocketConnectionHandler): Job +} + +@SubclassOptInRequired(RSocketTransportApi::class) +public interface RSocketServerTarget : CoroutineScope { + // handler will be called for all new connections + @RSocketTransportApi + public suspend fun startServer(handler: RSocketConnectionHandler): Instance +} + +// cancelling it will cancel server +@SubclassOptInRequired(RSocketTransportApi::class) +public interface RSocketServerInstance : CoroutineScope { + // graceful closing API should be here +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransportApi.kt similarity index 54% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransportApi.kt index 83a228e6..05fc3651 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/RSocketTransportApi.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2023 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,20 +14,13 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.transport -import io.ktor.utils.io.core.* -import kotlinx.coroutines.channels.* - -internal inline fun T.closeOnError(block: (T) -> R): R { - try { - return block(this) - } catch (e: Throwable) { - close() - throw e - } -} - -internal fun SendChannel.safeTrySend(element: E) { - trySend(element).onFailure { element.close() } -} +@RequiresOptIn( + level = RequiresOptIn.Level.ERROR, + message = """ + This is an API which is used to implement transport for RSocket, such as WS or TCP. + This API should not be used from general code + """ +) +public annotation class RSocketTransportApi diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt new file mode 100644 index 00000000..3b59c76c --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueue.kt @@ -0,0 +1,73 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.internal + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.selects.* + +private val selectFrame: suspend (ChannelResult) -> ChannelResult = { it } + +@RSocketTransportApi +public class PrioritizationFrameQueue(buffersCapacity: Int) { + private val priorityFrames = channelForCloseable(buffersCapacity) + private val normalFrames = channelForCloseable(buffersCapacity) + + private val priorityOnReceive = priorityFrames.onReceiveCatching + private val normalOnReceive = normalFrames.onReceiveCatching + + // priorityFrames is closed/cancelled first, no need to check `normalFrames` + @OptIn(DelicateCoroutinesApi::class) + public val isClosedForSend: Boolean get() = priorityFrames.isClosedForSend + + private fun channel(streamId: Int): SendChannel = when (streamId) { + 0 -> priorityFrames + else -> normalFrames + } + + public suspend fun enqueueFrame(streamId: Int, frame: ByteReadPacket): Unit = channel(streamId).send(frame) + + public fun tryDequeueFrame(): ByteReadPacket? { + // priority is first + priorityFrames.tryReceive().onSuccess { return it } + normalFrames.tryReceive().onSuccess { return it } + return null + } + + // TODO: recheck, that it works fine in case priority channel is closed, but normal channel has other frames to send + public suspend fun dequeueFrame(): ByteReadPacket? { + tryDequeueFrame()?.let { return it } + return select { + priorityOnReceive(selectFrame) + normalOnReceive(selectFrame) + }.getOrNull() + } + + // TODO: document + public fun close() { + priorityFrames.close() + normalFrames.close() + } + + public fun cancel() { + priorityFrames.cancel() + normalFrames.cancel() + } +} diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt index 69eee240..c5e44ced 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt @@ -24,9 +24,26 @@ import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* +import kotlin.coroutines.* import kotlin.test.* class ConnectionEstablishmentTest : SuspendTest, TestWithLeakCheck { + + private class TestInstance(val deferred: Deferred) : RSocketServerInstance { + override val coroutineContext: CoroutineContext get() = deferred + } + + private class TestServer( + override val coroutineContext: CoroutineContext, + private val connection: RSocketConnection, + ) : RSocketServerTarget { + override suspend fun startServer(handler: RSocketConnectionHandler): TestInstance { + return TestInstance(async { + handler.handleConnection(connection) + }) + } + } + @Ignore // it will be rewritten anyway @Test fun responderRejectSetup() = test { @@ -35,14 +52,12 @@ class ConnectionEstablishmentTest : SuspendTest, TestWithLeakCheck { val connection = TestConnection() - val serverTransport = ServerTransport { accept -> - GlobalScope.async { accept(connection) } - } + val serverTransport = TestServer(Dispatchers.Unconfined, connection) - val deferred = TestServer().bind(serverTransport) { + val deferred = TestServer().startServer(serverTransport) { sendingRSocket.complete(requester) error(errorMessage) - } + }.deferred connection.sendToReceiver( SetupFrame( diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt index 74f89c76..107e2cfe 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/TestConnection.kt @@ -30,27 +30,33 @@ import kotlin.test.* import kotlin.time.* import kotlin.time.Duration.Companion.seconds -class TestConnection : Connection, ClientTransport { - override val coroutineContext: CoroutineContext = - Job() + Dispatchers.Unconfined + TestExceptionHandler +class TestConnection : RSocketSequentialConnection, RSocketClientTarget { + private val job = Job() + override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined + TestExceptionHandler private val sendChannel = channelForCloseable(Channel.UNLIMITED) private val receiveChannel = channelForCloseable(Channel.UNLIMITED) init { coroutineContext.job.invokeOnCompletion { - sendChannel.close(it) + sendChannel.cancelWithCause(it) receiveChannel.cancelWithCause(it) } } - override suspend fun connect(): Connection = this + override fun connectClient(handler: RSocketConnectionHandler): Job = launch { + handler.handleConnection(this@TestConnection) + }.onCompletion { + if (it != null) job.completeExceptionally(it) + } + + override val isClosedForSend: Boolean get() = sendChannel.isClosedForSend - override suspend fun send(packet: ByteReadPacket) { - sendChannel.send(packet) + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + sendChannel.send(frame) } - override suspend fun receive(): ByteReadPacket { + override suspend fun receiveFrame(): ByteReadPacket { return receiveChannel.receive() } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/StreamIdTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/connection/StreamIdGeneratorTest.kt similarity index 84% rename from rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/StreamIdTest.kt rename to rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/connection/StreamIdGeneratorTest.kt index 40732282..b0c5e891 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/StreamIdTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/connection/StreamIdGeneratorTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,16 +14,17 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.connection +import io.rsocket.kotlin.internal.* import kotlin.test.* -class StreamIdTest { +class StreamIdGeneratorTest { private val map = IntMap() @Test fun testClientSequence() { - val streamId = StreamId.client() + val streamId = StreamIdGenerator.client() assertEquals(1, streamId.next(map)) assertEquals(3, streamId.next(map)) assertEquals(5, streamId.next(map)) @@ -31,7 +32,7 @@ class StreamIdTest { @Test fun testServerSequence() { - val streamId = StreamId.server() + val streamId = StreamIdGenerator.server() assertEquals(2, streamId.next(map)) assertEquals(4, streamId.next(map)) assertEquals(6, streamId.next(map)) @@ -39,7 +40,7 @@ class StreamIdTest { @Test fun testClientIsValid() { - val streamId = StreamId.client() + val streamId = StreamIdGenerator.client() assertFalse(streamId.isBeforeOrCurrent(1)) assertFalse(streamId.isBeforeOrCurrent(3)) streamId.next(map) @@ -57,7 +58,7 @@ class StreamIdTest { @Test fun testServerIsValid() { - val streamId = StreamId.server() + val streamId = StreamIdGenerator.server() assertFalse(streamId.isBeforeOrCurrent(2)) assertFalse(streamId.isBeforeOrCurrent(4)) streamId.next(map) @@ -75,7 +76,7 @@ class StreamIdTest { @Test fun testWrapOdd() { - val streamId = StreamId(Int.MAX_VALUE - 3) + val streamId = StreamIdGenerator(Int.MAX_VALUE - 3) assertEquals(2147483646, streamId.next(map)) assertEquals(2, streamId.next(map)) assertEquals(4, streamId.next(map)) @@ -83,7 +84,7 @@ class StreamIdTest { @Test fun testWrapEven() { - val streamId = StreamId(Int.MAX_VALUE - 2) + val streamId = StreamIdGenerator(Int.MAX_VALUE - 2) assertEquals(2147483647, streamId.next(map)) assertEquals(1, streamId.next(map)) assertEquals(3, streamId.next(map)) @@ -91,7 +92,7 @@ class StreamIdTest { @Test fun testSkipFound() { - val streamId = StreamId.client() + val streamId = StreamIdGenerator.client() map[5] = "" map[9] = "" assertEquals(1, streamId.next(map)) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index a8fb0fac..2c70f22b 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -26,10 +26,60 @@ import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* import kotlin.test.* import kotlin.time.Duration.Companion.seconds -class RSocketTest : SuspendTest, TestWithLeakCheck { +class OldLocalRSocketTest : RSocketTest({ context, acceptor -> + val localServer = TestServer().bindIn( + CoroutineScope(context), + LocalServerTransport(), + acceptor + ) + + TestConnector { + connectionConfig { + keepAlive = KeepAlive(1000.seconds, 1000.seconds) + } + }.connect(localServer) +}) + +class SequentialLocalRSocketTest : RSocketTest({ context, acceptor -> + val localServer = TestServer().startServer( + LocalServerTransport(context) { sequential() }.target(), + acceptor + ) + + TestConnector { + connectionConfig { + keepAlive = KeepAlive(1000.seconds, 1000.seconds) + } + }.connect( + LocalClientTransport(context).target(localServer.serverName) + ) +}) + +class MultiplexedLocalRSocketTest : RSocketTest({ context, acceptor -> + val localServer = TestServer().startServer( + LocalServerTransport(context) { multiplexed() }.target(), + acceptor + ) + + TestConnector { + connectionConfig { + keepAlive = KeepAlive(1000.seconds, 1000.seconds) + } + }.connect( + LocalClientTransport(context).target(localServer.serverName) + ) +}) + +abstract class RSocketTest( + private val connect: suspend ( + context: CoroutineContext, + acceptor: ConnectionAcceptor, + ) -> RSocket, +) : SuspendTest, TestWithLeakCheck { private val testJob: Job = Job() @@ -39,10 +89,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } private suspend fun start(handler: RSocket? = null): RSocket { - val localServer = TestServer().bindIn( - CoroutineScope(Dispatchers.Unconfined + testJob + TestExceptionHandler), - LocalServerTransport() - ) { + return connect(Dispatchers.Unconfined + testJob + TestExceptionHandler) { handler ?: RSocketRequestHandler { requestResponse { it } requestStream { @@ -51,17 +98,15 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } requestChannel { init, payloads -> init.close() - payloads.onEach { it.close() }.launchIn(this) - flow { repeat(10) { emitOrClose(payload("server got -> [$it]")) } } + flow { + coroutineScope { + payloads.onEach { it.close() }.launchIn(this) + repeat(10) { emitOrClose(payload("server got -> [$it]")) } + } + } } } } - - return TestConnector { - connectionConfig { - keepAlive = KeepAlive(1000.seconds, 1000.seconds) - } - }.connect(localServer) } @Test @@ -132,15 +177,29 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { fun testStreamRequesterError() = test { val requester = start(RSocketRequestHandler { requestStream { - (0..100).asFlow().map { - payload(it.toString()) + it.close() + flow { + repeat(100) { + val payload = payload(it.toString()) + try { + emit(payload) + } catch (cause: Throwable) { + payload.close() + throw cause + } + } } } }) requester.requestStream(payload("HELLO")) .flowOn(PrefetchStrategy(10, 0)) .withIndex() - .onEach { if (it.index == 23) error("oops") } + .onEach { + if (it.index == 23) { + it.value.close() + error("oops") + } + } .map { it.value } .test { repeat(23) { @@ -156,8 +215,17 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { fun testStreamCancel() = test { val requester = start(RSocketRequestHandler { requestStream { - (0..100).asFlow().map { - payload(it.toString()) + it.close() + flow { + repeat(100) { + val payload = payload(it.toString()) + try { + emit(payload) + } catch (cause: Throwable) { + payload.close() + throw cause + } + } } } }) @@ -176,8 +244,16 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { fun testStreamCancelWithChannel() = test { val requester = start(RSocketRequestHandler { requestStream { - (0..100).asFlow().map { - payload(it.toString()) + flow { + repeat(100) { + val payload = payload(it.toString()) + try { + emit(payload) + } catch (cause: Throwable) { + payload.close() + throw cause + } + } } } }) @@ -196,6 +272,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { fun testStreamInitialMaxValue() = test { val requester = start(RSocketRequestHandler { requestStream { + it.close() (0..9).asFlow().map { payload(it.toString()) } @@ -215,6 +292,8 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { fun testStreamRequestN() = test { start(RSocketRequestHandler { requestStream { + // TODO: we should really call close here - + it.close() (0..9).asFlow().map { payload(it.toString()) } } }) @@ -251,10 +330,10 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } }) val request = flow { error("test") } - //TODO - kotlin.runCatching { + // TODO: should requester fail if there was a failure in `request`? + assertFailsWith(IllegalStateException::class) { requester.requestChannel(Payload.Empty, request).collect() - }.also(::println) + } val e = error.await() assertTrue(e is RSocketError.ApplicationError) assertEquals("test", e.message) @@ -376,10 +455,10 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { requesterReceiveChannel.cancel() delay(1000) - assertTrue(requesterSendChannel.isClosedForSend) - assertTrue(responderSendChannel.isClosedForSend) - assertTrue(requesterReceiveChannel.isClosedForReceive) - assertTrue(responderReceiveChannel.isClosedForReceive) + assertTrue(requesterSendChannel.isClosedForSend, "requesterSendChannel.isClosedForSend") + assertTrue(responderSendChannel.isClosedForSend, "responderSendChannel.isClosedForSend") + assertTrue(requesterReceiveChannel.isClosedForReceive, "requesterReceiveChannel.isClosedForReceive") + assertTrue(responderReceiveChannel.isClosedForReceive, "responderReceiveChannel.isClosedForReceive") } private suspend fun initRequestChannel( @@ -413,7 +492,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { private suspend inline fun cancel( requesterChannel: SendChannel, - responderChannel: ReceiveChannel + responderChannel: ReceiveChannel, ) { responderChannel.cancel() delay(100) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt index 635150c3..606d69dc 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ package io.rsocket.kotlin.core import app.cash.turbine.* import io.ktor.utils.io.core.* import io.rsocket.kotlin.* -import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/PrioritizerTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/PrioritizerTest.kt deleted file mode 100644 index 18a31cbb..00000000 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/PrioritizerTest.kt +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright 2015-2022 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal - -import io.ktor.utils.io.core.* -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.test.* -import kotlinx.coroutines.* -import kotlin.test.* - -class PrioritizerTest : SuspendTest, TestWithLeakCheck { - private val prioritizer = Prioritizer() - - @Test - fun testOrdering() = test { - prioritizer.send(CancelFrame(1)) - prioritizer.send(CancelFrame(2)) - prioritizer.send(CancelFrame(3)) - - assertEquals(1, prioritizer.receive().streamId) - assertEquals(2, prioritizer.receive().streamId) - assertEquals(3, prioritizer.receive().streamId) - } - - @Test - fun testOrderingPriority() = test { - prioritizer.send(MetadataPushFrame(ByteReadPacket.Empty)) - prioritizer.send(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) - - assertTrue(prioritizer.receive() is MetadataPushFrame) - assertTrue(prioritizer.receive() is KeepAliveFrame) - } - - @Test - fun testPrioritization() = test { - prioritizer.send(CancelFrame(5)) - prioritizer.send(MetadataPushFrame(ByteReadPacket.Empty)) - prioritizer.send(CancelFrame(1)) - prioritizer.send(MetadataPushFrame(ByteReadPacket.Empty)) - - assertEquals(0, prioritizer.receive().streamId) - assertEquals(0, prioritizer.receive().streamId) - assertEquals(5, prioritizer.receive().streamId) - assertEquals(1, prioritizer.receive().streamId) - } - - @Test - fun testAsyncReceive() = test { - val deferred = CompletableDeferred() - launch(anotherDispatcher) { - deferred.complete(prioritizer.receive()) - } - delay(100) - prioritizer.send(CancelFrame(5)) - assertTrue(deferred.await() is CancelFrame) - } - - @Test - fun testPrioritizationAndOrdering() = test { - prioritizer.send(RequestNFrame(1, 1)) - prioritizer.send(MetadataPushFrame(ByteReadPacket.Empty)) - prioritizer.send(CancelFrame(1)) - prioritizer.send(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) - - assertTrue(prioritizer.receive() is MetadataPushFrame) - assertTrue(prioritizer.receive() is KeepAliveFrame) - assertTrue(prioritizer.receive() is RequestNFrame) - assertTrue(prioritizer.receive() is CancelFrame) - } - - @Test - fun testReleaseOnClose() = test { - val packet = packet("metadata") - val payload = payload("data") - prioritizer.send(MetadataPushFrame(packet)) - prioritizer.send(NextPayloadFrame(1, payload)) - - assertTrue(packet.isNotEmpty) - assertTrue(payload.data.isNotEmpty) - - prioritizer.close(null) - - assertTrue(packet.isEmpty) - assertTrue(payload.data.isEmpty) - } -} diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index bd2b7499..df3c69dd 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -290,7 +290,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { @Test fun testRequestReplyWithCancel() = test { connection.test { - withTimeoutOrNull(100) { requester.requestResponse(Payload.Empty) } + assertTrue(withTimeoutOrNull(100) { requester.requestResponse(Payload.Empty) } == null) awaitFrame { assertTrue(it is RequestFrame) } awaitFrame { assertTrue(it is CancelFrame) } @@ -413,6 +413,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { assertTrue(frame is RequestFrame) } requester.cancel() //cancel requester + awaitFrame { assertTrue(it is ErrorFrame) } awaitError() } delay(100) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt index 532d2151..5f98631b 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketResponderRequestNTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2023 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,19 +27,31 @@ import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* import kotlin.test.* class RSocketResponderRequestNTest : TestWithLeakCheck, TestWithConnection() { private val testJob: Job = Job() - private suspend fun start(handler: RSocket) { - val serverTransport = ServerTransport { accept -> - GlobalScope.async { accept(connection) } + private class TestInstance(val deferred: Deferred) : RSocketServerInstance { + override val coroutineContext: CoroutineContext get() = deferred + } + + private class TestServer( + override val coroutineContext: CoroutineContext, + private val connection: RSocketConnection, + ) : RSocketServerTarget { + override suspend fun startServer(handler: RSocketConnectionHandler): TestInstance { + return TestInstance(async { + handler.handleConnection(connection) + }) } + } - val scope = CoroutineScope(Dispatchers.Unconfined + testJob + TestExceptionHandler) - @Suppress("DeferredResultUnused") - TestServer().bindIn(scope, serverTransport) { + private suspend fun start(handler: RSocket) { + TestServer().startServer( + TestServer(Dispatchers.Unconfined + testJob + TestExceptionHandler, connection) + ) { config.setupPayload.close() handler } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index 312541d3..3339bc21 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -28,7 +28,7 @@ import kotlin.time.Duration.Companion.seconds class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { private suspend fun requester( - keepAlive: KeepAlive = KeepAlive(100.milliseconds, 1.seconds) + keepAlive: KeepAlive = KeepAlive(100.milliseconds, 1.seconds), ): RSocket = TestConnector { connectionConfig { this.keepAlive = keepAlive @@ -100,7 +100,8 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { fun rSocketCanceledOnMissingKeepAliveTicks() = test { val rSocket = requester() connection.test { - while (rSocket.isActive) kotlin.runCatching { awaitItem() } + while (rSocket.isActive) awaitFrame { it is KeepAliveFrame } + awaitError() } @OptIn(InternalCoroutinesApi::class) assertTrue(rSocket.coroutineContext.job.getCancellationException().cause is RSocketError.ConnectionError) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/FrameSenderTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt similarity index 64% rename from rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/FrameSenderTest.kt rename to rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt index b472377c..a57f5b81 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/FrameSenderTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/OperationOutboundTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2022 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,28 +14,41 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal +package io.rsocket.kotlin.operation +import io.ktor.utils.io.core.* import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.io.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* +import kotlinx.coroutines.channels.* import kotlin.test.* -class FrameSenderTest : SuspendTest, TestWithLeakCheck { +class OperationOutboundTest : SuspendTest, TestWithLeakCheck { + private class Outbound( + streamId: Int, + maxFragmentSize: Int, + ) : OperationOutbound(streamId, FrameCodec(InUseTrackingPool, maxFragmentSize)) { + val frames = channelForCloseable(Channel.BUFFERED) + override val isClosed: Boolean get() = frames.isClosedForSend - private val prioritizer = Prioritizer() - private fun sender(maxFragmentSize: Int) = FrameSender(prioritizer, InUseTrackingPool, maxFragmentSize) + override suspend fun sendFrame(frame: ByteReadPacket) { + frames.send(frame) + } + } + + private fun sender(maxFragmentSize: Int) = Outbound(1, maxFragmentSize) @Test fun testFrameFragmented() = test { val sender = sender(99) - sender.sendNextPayload(1, buildPayload { + sender.sendNext(buildPayload { data("1234567890".repeat(50)) - }) + }, false) repeat(6) { - val frame = prioritizer.receive() + val frame = sender.frames.receive().readFrame(InUseTrackingPool) assertIs(frame) assertTrue(frame.next) assertNull(frame.payload.metadata) @@ -53,12 +66,12 @@ class FrameSenderTest : SuspendTest, TestWithLeakCheck { fun testFrameFragmentedFully() = test { val sender = sender(99) - sender.sendNextPayload(1, buildPayload { + sender.sendNext(buildPayload { data("1234567890".repeat(18)) - }) + }, false) repeat(2) { - val frame = prioritizer.receive() + val frame = sender.frames.receive().readFrame(InUseTrackingPool) assertIs(frame) assertTrue(frame.next) assertNull(frame.payload.metadata) diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandlerTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperationTest.kt similarity index 59% rename from rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandlerTest.kt rename to rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperationTest.kt index 00557af6..1196b584 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/handler/RequesterRequestResponseFrameHandlerTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/operation/RequesterRequestResponseOperationTest.kt @@ -14,51 +14,47 @@ * limitations under the License. */ -package io.rsocket.kotlin.internal.handler +package io.rsocket.kotlin.operation import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* import kotlinx.coroutines.* import kotlin.test.* -class RequesterRequestResponseFrameHandlerTest : SuspendTest, TestWithLeakCheck { - private val storage = StreamsStorage(true) +// TODO: write better tests +class RequesterRequestResponseOperationTest : SuspendTest, TestWithLeakCheck { private val deferred = CompletableDeferred() - private val handler = - RequesterRequestResponseFrameHandler(1, storage, deferred).also { storage.save(1, it) } + private val operation = RequesterRequestResponseOperation(deferred) + private val handler = OperationFrameHandler(operation) @Test fun testCompleteOnPayloadReceive() = test { - handler.handleRequest(RequestFrame(FrameType.Payload, 1, false, false, true, 0, payload("hello"))) + handler.handleFrame(RequestFrame(FrameType.Payload, 1, false, false, true, 0, payload("hello"))) assertTrue(deferred.isCompleted) assertEquals("hello", deferred.await().data.readText()) - handler.onReceiveComplete() - assertFalse(storage.contains(1)) } @Test fun testFailOnPayloadReceive() = test { - handler.handleError(RSocketError.ApplicationError("failed")) + handler.handleFrame(ErrorFrame(1, RSocketError.ApplicationError("failed"))) assertTrue(deferred.isCompleted) assertFailsWith(RSocketError.ApplicationError::class, "failed") { deferred.await() } - assertFalse(storage.contains(1)) } @Test - fun testFailOnCleanup() = test { - handler.cleanup(IllegalStateException("failed")) + fun testFailOnFailure() = test { + operation.operationFailure(IllegalStateException("failed")) assertTrue(deferred.isCompleted) - assertFailsWith(CancellationException::class, "Connection closed") { deferred.await() } + assertFailsWith(IllegalStateException::class, "failed") { deferred.await() } } @Test fun testReassembly() = test { - handler.handleRequest(RequestFrame(FrameType.Payload, 1, true, false, true, 0, payload("hello"))) + handler.handleFrame(RequestFrame(FrameType.Payload, 1, true, false, true, 0, payload("hello"))) assertFalse(deferred.isCompleted) - handler.handleRequest(RequestFrame(FrameType.Payload, 1, false, false, true, 0, payload(" world"))) + handler.handleFrame(RequestFrame(FrameType.Payload, 1, false, false, true, 0, payload(" world"))) assertTrue(deferred.isCompleted) assertEquals("hello world", deferred.await().data.readText()) } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt new file mode 100644 index 00000000..e8c9120a --- /dev/null +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/transport/internal/PrioritizationFrameQueueTest.kt @@ -0,0 +1,93 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.internal + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.test.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlin.test.* + +class PrioritizationFrameQueueTest : SuspendTest, TestWithLeakCheck { + private val queue = PrioritizationFrameQueue(Channel.BUFFERED) + + @Test + fun testOrdering() = test { + queue.enqueueFrame(1, packet("1")) + queue.enqueueFrame(2, packet("2")) + queue.enqueueFrame(3, packet("3")) + + assertEquals("1", queue.dequeueFrame()?.readText()) + assertEquals("2", queue.dequeueFrame()?.readText()) + assertEquals("3", queue.dequeueFrame()?.readText()) + } + + @Test + fun testOrderingPriority() = test { + queue.enqueueFrame(0, packet("1")) + queue.enqueueFrame(0, packet("2")) + + assertEquals("1", queue.dequeueFrame()?.readText()) + assertEquals("2", queue.dequeueFrame()?.readText()) + } + + @Test + fun testPrioritization() = test { + queue.enqueueFrame(5, packet("1")) + queue.enqueueFrame(0, packet("2")) + queue.enqueueFrame(1, packet("3")) + queue.enqueueFrame(0, packet("4")) + + assertEquals("2", queue.dequeueFrame()?.readText()) + assertEquals("4", queue.dequeueFrame()?.readText()) + + assertEquals("1", queue.dequeueFrame()?.readText()) + assertEquals("3", queue.dequeueFrame()?.readText()) + } + + @Test + fun testAsyncReceive() = test { + val deferred = CompletableDeferred() + launch(anotherDispatcher) { + deferred.complete(queue.dequeueFrame()) + } + delay(100) + queue.enqueueFrame(5, packet("1")) + assertEquals("1", deferred.await()?.readText()) + } + + @Test + fun testReleaseOnCancel() = test { + val p1 = packet("1") + val p2 = packet("2") + queue.enqueueFrame(0, p1) + queue.enqueueFrame(1, p2) + + assertTrue(p1.isNotEmpty) + assertTrue(p2.isNotEmpty) + + queue.close() + + assertTrue(p1.isNotEmpty) + assertTrue(p2.isNotEmpty) + + queue.cancel() + + assertTrue(p1.isEmpty) + assertTrue(p2.isEmpty) + } +} diff --git a/rsocket-internal-io/api/rsocket-internal-io.api b/rsocket-internal-io/api/rsocket-internal-io.api index 2245d547..9b4bef18 100644 --- a/rsocket-internal-io/api/rsocket-internal-io.api +++ b/rsocket-internal-io/api/rsocket-internal-io.api @@ -22,6 +22,13 @@ public final class io/rsocket/kotlin/internal/io/ChannelsKt { public static final fun channelForCloseable (I)Lkotlinx/coroutines/channels/Channel; } +public final class io/rsocket/kotlin/internal/io/ContextKt { + public static final fun childContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; + public static final fun ensureActive (Lkotlin/coroutines/CoroutineContext;Lkotlin/jvm/functions/Function0;)V + public static final fun onCompletion (Lkotlinx/coroutines/Job;Lkotlin/jvm/functions/Function1;)Lkotlinx/coroutines/Job; + public static final fun supervisorContext (Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext; +} + public final class io/rsocket/kotlin/internal/io/Int24Kt { public static final fun readInt24 (Lio/ktor/utils/io/core/ByteReadPacket;)I public static final fun writeInt24 (Lio/ktor/utils/io/core/BytePacketBuilder;I)V diff --git a/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt new file mode 100644 index 00000000..dc9b9780 --- /dev/null +++ b/rsocket-internal-io/src/commonMain/kotlin/io/rsocket/kotlin/internal/io/Context.kt @@ -0,0 +1,34 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.internal.io + +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public fun CoroutineContext.supervisorContext(): CoroutineContext = plus(SupervisorJob(get(Job))) +public fun CoroutineContext.childContext(): CoroutineContext = plus(Job(get(Job))) + +public fun T.onCompletion(handler: CompletionHandler): T { + invokeOnCompletion(handler) + return this +} + +public inline fun CoroutineContext.ensureActive(onInactive: () -> Unit) { + if (isActive) return + onInactive() // should not throw + ensureActive() // will throw +} diff --git a/rsocket-ktor/rsocket-ktor-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketConnectionTest.kt b/rsocket-ktor/rsocket-ktor-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketConnectionTest.kt index ecd95027..4f9a8b75 100644 --- a/rsocket-ktor/rsocket-ktor-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketConnectionTest.kt +++ b/rsocket-ktor/rsocket-ktor-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketConnectionTest.kt @@ -26,7 +26,6 @@ import io.rsocket.kotlin.ktor.client.* import io.rsocket.kotlin.ktor.server.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.test.* -import io.rsocket.kotlin.transport.tests.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlin.test.* @@ -38,7 +37,6 @@ import io.rsocket.kotlin.ktor.client.RSocketSupport as ClientRSocketSupport import io.rsocket.kotlin.ktor.server.RSocketSupport as ServerRSocketSupport class WebSocketConnectionTest : SuspendTest, TestWithLeakCheck { - private val port = PortProvider.next() private val client = HttpClient(ClientCIO) { install(ClientWebSockets) install(ClientRSocketSupport) { @@ -52,7 +50,7 @@ class WebSocketConnectionTest : SuspendTest, TestWithLeakCheck { private var responderJob: Job? = null - private val server = embeddedServer(ServerCIO, port) { + private val server = embeddedServer(ServerCIO, port = 0) { install(ServerWebSockets) install(ServerRSocketSupport) { server = TestServer() @@ -87,6 +85,7 @@ class WebSocketConnectionTest : SuspendTest, TestWithLeakCheck { @Test fun testWorks() = test { + val port = server.resolvedConnectors().single().port val rSocket = client.rSocket(port = port) val requesterJob = rSocket.coroutineContext.job diff --git a/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt b/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt index 98e7f000..b852064f 100644 --- a/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt +++ b/rsocket-transport-tests/src/commonMain/kotlin/io/rsocket/kotlin/transport/tests/TransportTest.kt @@ -1,5 +1,5 @@ /* - * Copyright 2015-2023 the original author or authors. + * Copyright 2015-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,6 +46,12 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { protected fun startServer(serverTransport: ServerTransport): T = SERVER.bindIn(testScope, serverTransport, ACCEPTOR) + protected suspend fun connectClient(clientTransport: RSocketClientTarget): RSocket = + CONNECTOR.connect(clientTransport) + + protected suspend fun startServer(serverTransport: RSocketServerTarget): T = + SERVER.startServer(serverTransport, ACCEPTOR) + override suspend fun after() { client.coroutineContext.job.cancelAndJoin() testJob.cancelAndJoin() @@ -83,8 +89,11 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { @Test fun requestChannel1() = test(10.seconds) { - val list = client.requestChannel(payload(0), flowOf(payload(0))).onEach { it.close() }.toList() - assertEquals(1, list.size) + val count = + client.requestChannel(payload(0), flowOf(payload(0))) + .onEach { it.close() } + .count() + assertEquals(1, count) } @Test @@ -92,9 +101,12 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { val request = flow { repeat(3) { emit(payload(it)) } } - val list = - client.requestChannel(payload(0), request).flowOn(PrefetchStrategy(3, 0)).onEach { it.close() }.toList() - assertEquals(3, list.size) + val count = + client.requestChannel(payload(0), request) + .flowOn(PrefetchStrategy(3, 0)) + .onEach { it.close() } + .count() + assertEquals(3, count) } @Test @@ -102,12 +114,12 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { val request = flow { repeat(200) { emit(requesterLargePayload) } } - val list = + val count = client.requestChannel(requesterLargePayload, request) .flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)) .onEach { it.close() } - .toList() - assertEquals(200, list.size) + .count() + assertEquals(200, count) } @Test @@ -116,11 +128,11 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { val request = flow { repeat(20_000) { emit(payload(7)) } } - val list = client.requestChannel(payload(7), request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { + val count = client.requestChannel(payload(7), request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { assertEquals(requesterData, it.data.readText()) assertEquals(requesterMetadata, it.metadata?.readText()) - }.toList() - assertEquals(20_000, list.size) + }.count() + assertEquals(20_000, count) } @Test @@ -129,9 +141,12 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { val request = flow { repeat(200_000) { emit(payload(it)) } } - val list = - client.requestChannel(payload(0), request).flowOn(PrefetchStrategy(10000, 0)).onEach { it.close() }.toList() - assertEquals(200_000, list.size) + val count = + client.requestChannel(payload(0), request) + .flowOn(PrefetchStrategy(10000, 0)) + .onEach { it.close() } + .count() + assertEquals(200_000, count) } @Test @@ -143,9 +158,9 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { } } (0..16).map { - async(Dispatchers.Default) { - val list = client.requestChannel(payload(0), request).onEach { it.close() }.toList() - assertEquals(256, list.size) + async { + val count = client.requestChannel(payload(0), request).onEach { it.close() }.count() + assertEquals(256, count) } }.awaitAll() } @@ -159,20 +174,42 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { } } (0..256).map { - async(Dispatchers.Default) { - val list = client.requestChannel(payload(0), request).onEach { it.close() }.toList() - assertEquals(512, list.size) + async { + val count = client.requestChannel(payload(0), request).onEach { it.close() }.count() + assertEquals(512, count) + } + }.awaitAll() + } + + @Test + @IgnoreNative // slow test + fun requestStreamX16() = test { + (0..16).map { + async { + val count = client.requestStream(payload(0)).onEach { it.close() }.count() + assertEquals(8192, count) } }.awaitAll() } @Test @Ignore //flaky, ignore for now + fun requestStreamX256() = test { + (0..256).map { + async { + val count = client.requestStream(payload(0)).onEach { it.close() }.count() + assertEquals(8192, count) + } + }.awaitAll() + } + + @Test + @IgnoreNative //flaky, ignore for now fun requestChannel500NoLeak() = test { val request = flow { repeat(10_000) { emitOrClose(payload(3)) } } - val list = + val count = client .requestChannel(payload(3), request) .flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)) @@ -180,9 +217,9 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { .onEach { assertEquals(requesterData, it.data.readText()) assertEquals(requesterMetadata, it.metadata?.readText()) - }.toList() - assertEquals(500, list.size) - delay(1000) //TODO: leak check + } + .count() + assertEquals(500, count) } @Test @@ -206,7 +243,7 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { } @Test - @Ignore //flaky, ignore for now + @IgnoreNative //flaky, ignore for now fun requestResponse10000() = test { (1..10000).map { async { client.requestResponse(payload(3)).let(Companion::checkPayload) } }.awaitAll() } @@ -219,29 +256,29 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { @Test fun requestStream5() = test { - val list = - client.requestStream(payload(3)).flowOn(PrefetchStrategy(5, 0)).take(5).onEach { checkPayload(it) }.toList() - assertEquals(5, list.size) + val count = + client.requestStream(payload(3)).flowOn(PrefetchStrategy(5, 0)).take(5).onEach { checkPayload(it) }.count() + assertEquals(5, count) } @Test - fun requestStream10000() = test { - val list = client.requestStream(payload(3)).onEach { checkPayload(it) }.toList() - assertEquals(10000, list.size) + @IgnoreNative + fun requestStream8K() = test { + val count = client.requestStream(payload(3)).onEach { checkPayload(it) }.count() + assertEquals(8192, count) // TODO } @Test - @Ignore //flaky, ignore for now + @IgnoreNative fun requestStream500NoLeak() = test { - val list = + val count = client .requestStream(payload(3)) .flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)) .take(500) .onEach { checkPayload(it) } - .toList() - assertEquals(500, list.size) - delay(1000) //TODO: leak check + .count() + assertEquals(500, count) } companion object { @@ -292,7 +329,7 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { override fun requestStream(payload: Payload): Flow = flow { payload.close() - repeat(10000) { + repeat(8192) { emitOrClose(Payload(packet(responderData), packet(responderMetadata))) } } diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt index f7065fbf..95667dfb 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpServerTest.kt @@ -19,16 +19,17 @@ package io.rsocket.kotlin.transport.ktor.tcp import io.ktor.network.sockets.* import io.rsocket.kotlin.* import io.rsocket.kotlin.test.* -import io.rsocket.kotlin.transport.tests.* import kotlinx.coroutines.* import kotlin.test.* class TcpServerTest : SuspendTest, TestWithLeakCheck { private val testJob = Job() private val testContext = testJob + TestExceptionHandler - private val address = InetSocketAddress("0.0.0.0", PortProvider.next()) - private val serverTransport = TcpServerTransport(address) - private val clientTransport = TcpClientTransport(address, testContext) + private val serverTransport = TcpServerTransport() + private suspend fun clientTransport(server: TcpServer) = TcpClientTransport( + server.serverSocket.await().localAddress as InetSocketAddress, + testContext + ) override suspend fun after() { testJob.cancelAndJoin() @@ -50,7 +51,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { payload(text) } } - }.connect(clientTransport) + }.connect(clientTransport(server)) val client1 = newClient("ok") client1.requestResponse(payload("ok")).close() @@ -86,7 +87,7 @@ class TcpServerTest : SuspendTest, TestWithLeakCheck { }.also { handlers += it } }.also { it.serverSocket.await() } - suspend fun newClient() = TestConnector().connect(clientTransport) + suspend fun newClient() = TestConnector().connect(clientTransport(server)) val client1 = newClient() diff --git a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt index d2410604..a55440e5 100644 --- a/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt +++ b/rsocket-transports/ktor-tcp/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/tcp/TcpTransportTest.kt @@ -21,8 +21,7 @@ import io.rsocket.kotlin.transport.tests.* class TcpTransportTest : TransportTest() { override suspend fun before() { - val address = InetSocketAddress("0.0.0.0", PortProvider.next()) - startServer(TcpServerTransport(address)).serverSocket.await() - client = connectClient(TcpClientTransport(address, testContext)) + val serverSocket = startServer(TcpServerTransport()).serverSocket.await() + client = connectClient(TcpClientTransport(serverSocket.localAddress as InetSocketAddress, testContext)) } } diff --git a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt index 08f5f81d..850b3cd0 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/CIOWebSocketTransportTest.kt @@ -16,31 +16,7 @@ package io.rsocket.kotlin.transport.ktor.websocket.server -import io.rsocket.kotlin.test.* -import kotlin.test.* -import kotlin.time.* -import kotlin.time.Duration.Companion.minutes import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.server.cio.CIO as ServerCIO -class CIOWebSocketTransportTest : WebSocketTransportTest(ClientCIO, ServerCIO) { - //on native we need more time here - override val testTimeout: Duration = 5.minutes - - //tests are ignored, because current CIO:native websockets implementation is unstable when working with large frames - @Test - @IgnoreNative - override fun largePayloadFireAndForget10() = super.largePayloadFireAndForget10() - - @Test - @IgnoreNative - override fun largePayloadMetadataPush10() = super.largePayloadMetadataPush10() - - @Test - @IgnoreNative - override fun largePayloadRequestChannel200() = super.largePayloadRequestChannel200() - - @Test - @IgnoreNative - override fun largePayloadRequestResponse100() = super.largePayloadRequestResponse100() -} +class CIOWebSocketTransportTest : WebSocketTransportTest(ClientCIO, ServerCIO) diff --git a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt index d47c9a7d..ffd26b12 100644 --- a/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt +++ b/rsocket-transports/ktor-websocket-server/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/websocket/server/WebSocketTransportTest.kt @@ -26,12 +26,12 @@ abstract class WebSocketTransportTest( private val serverEngine: ApplicationEngineFactory<*, *>, ) : TransportTest() { override suspend fun before() { - val port = PortProvider.next() - startServer( - WebSocketServerTransport(serverEngine, port = port) + val engine = startServer( + WebSocketServerTransport(serverEngine, port = 0) ) + val connector = engine.resolvedConnectors().single() client = connectClient( - WebSocketClientTransport(clientEngine, port = port, context = testContext) + WebSocketClientTransport(clientEngine, port = connector.port, context = testContext) ) } } diff --git a/rsocket-transports/local/api/rsocket-transport-local.api b/rsocket-transports/local/api/rsocket-transport-local.api index 6d416bc0..10c48ca0 100644 --- a/rsocket-transports/local/api/rsocket-transport-local.api +++ b/rsocket-transports/local/api/rsocket-transport-local.api @@ -1,9 +1,44 @@ +public abstract interface class io/rsocket/kotlin/transport/local/LocalClientTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/local/LocalClientTransport$Factory; + public abstract fun target (Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketClientTarget; +} + +public final class io/rsocket/kotlin/transport/local/LocalClientTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/local/LocalClientTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V + public fun inheritDispatcher ()V +} + public final class io/rsocket/kotlin/transport/local/LocalServer : io/rsocket/kotlin/transport/ClientTransport { public fun connect (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun getCoroutineContext ()Lkotlin/coroutines/CoroutineContext; } +public abstract interface class io/rsocket/kotlin/transport/local/LocalServerInstance : io/rsocket/kotlin/transport/RSocketServerInstance { + public abstract fun getServerName ()Ljava/lang/String; +} + public final class io/rsocket/kotlin/transport/local/LocalServerKt { public static final fun LocalServerTransport ()Lio/rsocket/kotlin/transport/ServerTransport; } +public abstract interface class io/rsocket/kotlin/transport/local/LocalServerTransport : io/rsocket/kotlin/transport/RSocketTransport { + public static final field Factory Lio/rsocket/kotlin/transport/local/LocalServerTransport$Factory; + public abstract fun target ()Lio/rsocket/kotlin/transport/RSocketServerTarget; + public abstract fun target (Ljava/lang/String;)Lio/rsocket/kotlin/transport/RSocketServerTarget; +} + +public final class io/rsocket/kotlin/transport/local/LocalServerTransport$Factory : io/rsocket/kotlin/transport/RSocketTransportFactory { +} + +public abstract interface class io/rsocket/kotlin/transport/local/LocalServerTransportBuilder : io/rsocket/kotlin/transport/RSocketTransportBuilder { + public abstract fun dispatcher (Lkotlin/coroutines/CoroutineContext;)V + public fun inheritDispatcher ()V + public abstract fun multiplexed (II)V + public static synthetic fun multiplexed$default (Lio/rsocket/kotlin/transport/local/LocalServerTransportBuilder;IIILjava/lang/Object;)V + public abstract fun sequential (I)V + public static synthetic fun sequential$default (Lio/rsocket/kotlin/transport/local/LocalServerTransportBuilder;IILjava/lang/Object;)V +} + diff --git a/rsocket-transports/local/build.gradle.kts b/rsocket-transports/local/build.gradle.kts index edb2dc3d..06689a41 100644 --- a/rsocket-transports/local/build.gradle.kts +++ b/rsocket-transports/local/build.gradle.kts @@ -18,6 +18,7 @@ import rsocketbuild.* plugins { id("rsocketbuild.multiplatform-library") + id("kotlinx-atomicfu") } description = "rsocket-kotlin Local transport implementation" diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt new file mode 100644 index 00000000..c81e3c24 --- /dev/null +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalClientTransport.kt @@ -0,0 +1,69 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.local + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public sealed interface LocalClientTransport : RSocketTransport { + public fun target(serverName: String): RSocketClientTarget + + public companion object Factory : + RSocketTransportFactory(::LocalClientTransportBuilderImpl) +} + +public sealed interface LocalClientTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) +} + +private class LocalClientTransportBuilderImpl : LocalClientTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): LocalClientTransport = LocalClientTransportImpl( + coroutineContext = context.supervisorContext() + dispatcher, + ) +} + +private class LocalClientTransportImpl( + override val coroutineContext: CoroutineContext, +) : LocalClientTransport { + override fun target(serverName: String): RSocketClientTarget = LocalClientTargetImpl( + coroutineContext = coroutineContext.supervisorContext(), + server = LocalServerRegistry.get(serverName) + ) +} + +private class LocalClientTargetImpl( + override val coroutineContext: CoroutineContext, + private val server: LocalServerInstanceImpl, +) : RSocketClientTarget { + + @RSocketTransportApi + override fun connectClient(handler: RSocketConnectionHandler): Job { + coroutineContext.ensureActive() + return server.connect(clientScope = this, clientHandler = handler) + } +} diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt index d52cab25..f9f57a3a 100644 --- a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt @@ -26,7 +26,9 @@ import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlin.coroutines.* +import kotlin.js.* +@JsName("LocalServerTransport2") // for compatibility with new API public fun LocalServerTransport(): ServerTransport = ServerTransport { accept -> val connections = Channel() val handlerJob = launch { diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt new file mode 100644 index 00000000..d5ea6a0e --- /dev/null +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerConnector.kt @@ -0,0 +1,184 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.local + +import io.ktor.utils.io.core.* +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import io.rsocket.kotlin.transport.internal.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* + +internal sealed class LocalServerConnector { + @RSocketTransportApi + abstract fun connect( + clientScope: CoroutineScope, + clientHandler: RSocketConnectionHandler, + serverScope: CoroutineScope, + serverHandler: RSocketConnectionHandler, + ): Job + + internal class Sequential( + private val prioritizationQueueBuffersCapacity: Int, + ) : LocalServerConnector() { + + @RSocketTransportApi + override fun connect( + clientScope: CoroutineScope, + clientHandler: RSocketConnectionHandler, + serverScope: CoroutineScope, + serverHandler: RSocketConnectionHandler, + ): Job { + val clientToServer = PrioritizationFrameQueue(prioritizationQueueBuffersCapacity) + val serverToClient = PrioritizationFrameQueue(prioritizationQueueBuffersCapacity) + + launchLocalConnection(serverScope, serverToClient, clientToServer, serverHandler) + return launchLocalConnection(clientScope, clientToServer, serverToClient, clientHandler) + } + + @RSocketTransportApi + private fun launchLocalConnection( + scope: CoroutineScope, + outbound: PrioritizationFrameQueue, + inbound: PrioritizationFrameQueue, + handler: RSocketConnectionHandler, + ): Job = scope.launch { + handler.handleConnection(Connection(outbound, inbound)) + }.onCompletion { + outbound.close() + inbound.cancel() + } + + @RSocketTransportApi + private class Connection( + private val outbound: PrioritizationFrameQueue, + private val inbound: PrioritizationFrameQueue, + ) : RSocketSequentialConnection { + override val isClosedForSend: Boolean get() = outbound.isClosedForSend + + override suspend fun sendFrame(streamId: Int, frame: ByteReadPacket) { + return outbound.enqueueFrame(streamId, frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.dequeueFrame() + } + } + } + + // TODO: better parameters naming + class Multiplexed( + private val streamsQueueCapacity: Int, + private val streamBufferCapacity: Int, + ) : LocalServerConnector() { + @RSocketTransportApi + override fun connect( + clientScope: CoroutineScope, + clientHandler: RSocketConnectionHandler, + serverScope: CoroutineScope, + serverHandler: RSocketConnectionHandler, + ): Job { + val streams = Channels>(streamsQueueCapacity) + + launchLocalConnection(serverScope, streams.serverToClient, streams.clientToServer, serverHandler) + return launchLocalConnection(clientScope, streams.clientToServer, streams.serverToClient, clientHandler) + } + + @RSocketTransportApi + private fun launchLocalConnection( + scope: CoroutineScope, + outbound: SendChannel>, + inbound: ReceiveChannel>, + handler: RSocketConnectionHandler, + ): Job = scope.launch { + handler.handleConnection(Connection(SupervisorJob(coroutineContext.job), outbound, inbound, streamBufferCapacity)) + }.onCompletion { + outbound.close() + inbound.cancel() + } + + @RSocketTransportApi + private class Connection( + private val streamsJob: Job, + private val outbound: SendChannel>, + private val inbound: ReceiveChannel>, + private val streamBufferCapacity: Int, + ) : RSocketMultiplexedConnection { + override suspend fun createStream(): RSocketMultiplexedConnection.Stream { + val frames = Channels(streamBufferCapacity) + + outbound.send(frames) + + return Stream( + parentJob = streamsJob, + outbound = frames.clientToServer, + inbound = frames.serverToClient + ) + } + + override suspend fun acceptStream(): RSocketMultiplexedConnection.Stream? { + val frames = inbound.receiveCatching().getOrNull() ?: return null + + return Stream( + parentJob = streamsJob, + outbound = frames.serverToClient, + inbound = frames.clientToServer + ) + } + } + + @RSocketTransportApi + private class Stream( + parentJob: Job, + private val outbound: SendChannel, + private val inbound: ReceiveChannel, + ) : RSocketMultiplexedConnection.Stream { + private val streamJob = Job(parentJob).onCompletion { + outbound.close() + inbound.cancel() + } + + override fun close() { + streamJob.complete() + } + + @OptIn(DelicateCoroutinesApi::class) + override val isClosedForSend: Boolean get() = outbound.isClosedForSend + + override fun setSendPriority(priority: Int) {} + + override suspend fun sendFrame(frame: ByteReadPacket) { + return outbound.send(frame) + } + + override suspend fun receiveFrame(): ByteReadPacket? { + return inbound.receiveCatching().getOrNull() + } + } + + private class Channels(bufferCapacity: Int) : Closeable { + val clientToServer = channelForCloseable(bufferCapacity) + val serverToClient = channelForCloseable(bufferCapacity) + + // only for undelivered element case + override fun close() { + clientToServer.cancel() + serverToClient.cancel() + } + } + } +} diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerInstanceImpl.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerInstanceImpl.kt new file mode 100644 index 00000000..2892d83f --- /dev/null +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerInstanceImpl.kt @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.local + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +internal class LocalServerInstanceImpl @RSocketTransportApi constructor( + override val serverName: String, + override val coroutineContext: CoroutineContext, + private val serverHandler: RSocketConnectionHandler, + private val connector: LocalServerConnector, +) : LocalServerInstance { + private val serverScope = CoroutineScope(coroutineContext.supervisorContext()) + + init { + LocalServerRegistry.register(serverName, this) + } + + @RSocketTransportApi + fun connect( + clientScope: CoroutineScope, + clientHandler: RSocketConnectionHandler, + ): Job { + coroutineContext.ensureActive() + + return connector.connect( + clientScope = clientScope, + clientHandler = clientHandler, + serverScope = serverScope, + serverHandler = serverHandler + ) + } +} diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt new file mode 100644 index 00000000..b1246df9 --- /dev/null +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerRegistry.kt @@ -0,0 +1,39 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.local + +import kotlinx.atomicfu.locks.* +import kotlinx.coroutines.* + +internal object LocalServerRegistry { + private val lock = SynchronizedObject() + private val instances = mutableMapOf() + + fun register(name: String, target: LocalServerInstanceImpl) { + synchronized(lock) { + check(name !in instances) { "Already registered: $name" } + instances[name] = target + } + target.coroutineContext.job.invokeOnCompletion { + synchronized(lock) { instances.remove(name) } + } + } + + fun get(name: String): LocalServerInstanceImpl = synchronized(lock) { + checkNotNull(instances[name]) { "Cannot find $name" } + } +} diff --git a/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt new file mode 100644 index 00000000..ce1fe323 --- /dev/null +++ b/rsocket-transports/local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServerTransport.kt @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin.transport.local + +import io.rsocket.kotlin.internal.io.* +import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* +import kotlin.random.* + +// TODO: rename to inprocess and more to another module/package later +public sealed interface LocalServerInstance : RSocketServerInstance { + public val serverName: String +} + +public sealed interface LocalServerTransport : RSocketTransport { + public fun target(): RSocketServerTarget + public fun target(serverName: String): RSocketServerTarget + + public companion object Factory : + RSocketTransportFactory(::LocalServerTransportBuilderImpl) +} + +public sealed interface LocalServerTransportBuilder : RSocketTransportBuilder { + public fun dispatcher(context: CoroutineContext) + public fun inheritDispatcher(): Unit = dispatcher(EmptyCoroutineContext) + + public fun sequential( + prioritizationQueueBuffersCapacity: Int = Channel.BUFFERED, + ) + + public fun multiplexed( + streamsQueueCapacity: Int = Channel.BUFFERED, + streamBufferCapacity: Int = Channel.BUFFERED, + ) +} + +private class LocalServerTransportBuilderImpl : LocalServerTransportBuilder { + private var dispatcher: CoroutineContext = Dispatchers.Default + private var connector: LocalServerConnector? = null + + override fun dispatcher(context: CoroutineContext) { + check(context[Job] == null) { "Dispatcher shouldn't contain job" } + this.dispatcher = context + } + + override fun sequential(prioritizationQueueBuffersCapacity: Int) { + connector = LocalServerConnector.Sequential(prioritizationQueueBuffersCapacity) + } + + override fun multiplexed(streamsQueueCapacity: Int, streamBufferCapacity: Int) { + connector = LocalServerConnector.Multiplexed(streamsQueueCapacity, streamBufferCapacity) + } + + @RSocketTransportApi + override fun buildTransport(context: CoroutineContext): LocalServerTransport = LocalServerTransportImpl( + coroutineContext = context.supervisorContext() + dispatcher, + connector = connector ?: LocalServerConnector.Sequential(Channel.BUFFERED) + ) +} + +private class LocalServerTransportImpl( + override val coroutineContext: CoroutineContext, + private val connector: LocalServerConnector, +) : LocalServerTransport { + override fun target(serverName: String): RSocketServerTarget = LocalServerTargetImpl( + serverName = serverName, + coroutineContext = coroutineContext.supervisorContext(), + connector = connector + ) + + @OptIn(ExperimentalStdlibApi::class) + override fun target(): RSocketServerTarget = target( + Random.nextBytes(16).toHexString(HexFormat.UpperCase) + ) +} + +private class LocalServerTargetImpl( + override val coroutineContext: CoroutineContext, + private val serverName: String, + private val connector: LocalServerConnector, +) : RSocketServerTarget { + @RSocketTransportApi + override suspend fun startServer(handler: RSocketConnectionHandler): LocalServerInstance { + currentCoroutineContext().ensureActive() + coroutineContext.ensureActive() + + return LocalServerInstanceImpl( + serverName = serverName, + coroutineContext = coroutineContext.childContext(), + serverHandler = handler, + connector = connector + ) + } +} diff --git a/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt b/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt index 384e1944..606d3a80 100644 --- a/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt +++ b/rsocket-transports/local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt @@ -17,10 +17,42 @@ package io.rsocket.kotlin.transport.local import io.rsocket.kotlin.transport.tests.* +import kotlinx.coroutines.channels.* -class LocalTransportTest : TransportTest() { +class OldLocalTransportTest : TransportTest() { override suspend fun before() { val server = startServer(LocalServerTransport()) client = connectClient(server) } } + +abstract class LocalTransportTest( + private val configure: LocalServerTransportBuilder.() -> Unit, +) : TransportTest() { + override suspend fun before() { + val server = startServer(LocalServerTransport(testContext, configure).target()) + client = connectClient(LocalClientTransport(testContext).target(server.serverName)) + } +} + +class SequentialBufferedLocalTransportTest : LocalTransportTest({ + sequential(prioritizationQueueBuffersCapacity = Channel.BUFFERED) +}) + +class SequentialUnlimitedLocalTransportTest : LocalTransportTest({ + sequential(prioritizationQueueBuffersCapacity = Channel.UNLIMITED) +}) + +class MultiplexedBufferedLocalTransportTest : LocalTransportTest({ + multiplexed( + streamsQueueCapacity = Channel.BUFFERED, + streamBufferCapacity = Channel.BUFFERED + ) +}) + +class MultiplexedUnlimitedLocalTransportTest : LocalTransportTest({ + multiplexed( + streamsQueueCapacity = Channel.UNLIMITED, + streamBufferCapacity = Channel.UNLIMITED + ) +})