Skip to content

Commit

Permalink
Support preepmtion of DTLS sessions with CID
Browse files Browse the repository at this point in the history
  • Loading branch information
akolosov-n committed Jul 27, 2023
1 parent 3f0c028 commit 17d1575
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@ class DtlsServer(

return when {
dtlsState is DtlsHandshake -> dtlsState.step(buf)
dtlsState is DtlsSession -> dtlsState.decrypt(buf)
dtlsState is DtlsSession -> {
if (dtlsState.sessionContext.cid?.contentEquals(cid) != true) {
dtlsState.preemptSession()
} else {
dtlsState.decrypt(buf)
}
}

// no session, but dtls packet contains CID
cid != null -> ReceiveResult.CidSessionMissing(cid!!)
Expand Down Expand Up @@ -128,6 +134,7 @@ class DtlsServer(
object DecryptFailed : ReceiveResult
class Decrypted(val packet: Packet<ByteBuffer>) : ReceiveResult
class CidSessionMissing(val cid: ByteArray) : ReceiveResult
object CidSessionPreemption : ReceiveResult
}

private abstract inner class DtlsState(val peerAddress: InetSocketAddress) {
Expand Down Expand Up @@ -293,5 +300,14 @@ class DtlsServer(
private fun reportSessionFinished(reason: DtlsSessionLifecycleCallbacks.Reason, err: Throwable? = null) {
lifecycleCallbacks.sessionFinished(peerAddress, reason, err)
}

fun preemptSession(): ReceiveResult {
lifecycleCallbacks.sessionFinished(peerAddress, DtlsSessionLifecycleCallbacks.Reason.PREEMPTED)
logger.info("[{}] DTLS session (CID:{}) is preempted.", peerAddress, ctx.ownCid?.toHex())
sessions.remove(peerAddress, this)
storeAndClose()

return ReceiveResult.CidSessionPreemption
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ class DtlsServerTransport private constructor(
}
}
}

is DtlsServer.ReceiveResult.CidSessionPreemption -> {
val copyBuf = buf.copy()
receive0(adr, copyBuf, timeout)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.net.InetSocketAddress

interface DtlsSessionLifecycleCallbacks {
enum class Reason {
SUCCEEDED, FAILED, CLOSED, EXPIRED
SUCCEEDED, FAILED, CLOSED, EXPIRED, PREEMPTED
}
fun handshakeStarted(adr: InetSocketAddress) = Unit
fun handshakeFinished(adr: InetSocketAddress, hanshakeStartTimestamp: Long, reason: Reason, throwable: Throwable? = null) = Unit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.awaitility.kotlin.await
import org.junit.jupiter.api.AfterEach
import org.junit.jupiter.api.Assertions.assertEquals
import org.junit.jupiter.api.Assertions.assertFalse
import org.junit.jupiter.api.Assertions.assertNotNull
import org.junit.jupiter.api.Assertions.assertNull
import org.junit.jupiter.api.Assertions.assertTrue
import org.junit.jupiter.api.Test
Expand Down Expand Up @@ -472,6 +473,55 @@ class DtlsServerTransportTest {
client.close()
}

@Test
fun `should reuse the session interrupted by another client from the same port`() {
server = DtlsServerTransport.create(conf, sessionStore = sessionStore, expireAfter = 5.seconds).listen(echoHandler)

// Client1
val (cid1, session1) = DtlsTransmitter.connect(server, clientConfig, bindPort = 2_5684).await().use {
it.send("hello:client1")
assertEquals("hello:client1:resp", it.receiveString())

Pair(
server.getSessionCid(localAddress(2_5684)).await()!!,
it.saveSession()
)
}

// Client2
val (cid2, session2) = DtlsTransmitter.connect(server, clientConfig, bindPort = 2_5684).await().use {
it.send("hello:client2")
assertEquals("hello:client2:resp", it.receiveString())

Pair(
server.getSessionCid(localAddress(2_5684)).await()!!,
it.saveSession()
)
}

await.untilAsserted {
assertEquals(0, server.numberOfSessions())
}

assertEquals(2, sessionStore.size())

// Client1 again
DtlsTransmitter.create(server.localAddress(), clientConfig.loadSession(cid1, session1, server.localAddress()), 2_5684).use {
it.send("hello:client1")
assertEquals("hello:client1:resp", it.receiveString())
}

// Client2 again
DtlsTransmitter.create(server.localAddress(), clientConfig.loadSession(cid2, session2, server.localAddress()), 2_5684).use {
it.send("hello:client2")
assertEquals("hello:client2:resp", it.receiveString())
}

await.untilAsserted {
assertEquals(0, server.numberOfSessions())
}
}

private fun <T> Transport<T>.dropReceive(drop: (Int) -> Boolean): Transport<T> {
val underlying = this
var i = 0
Expand Down

0 comments on commit 17d1575

Please sign in to comment.