Skip to content

Commit

Permalink
cleaning. Added get cid to DtslServerTransport as well
Browse files Browse the repository at this point in the history
  • Loading branch information
topiasjokiniemi-nordic committed Jul 17, 2023
1 parent 57364cc commit 2a8bd16
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DtlsServer(
private val sessions = mutableMapOf<InetSocketAddress, DtlsState>()
private val cidSize = sslConfig.cidSupplier.next().size
val numberOfSessions get() = sessions.size
fun getCidStateFromAddress(inet: InetSocketAddress): ByteArray? {
fun getSessionCid(inet: InetSocketAddress): ByteArray? {
val dtlsState = sessions[inet] as? DtlsSession
return dtlsState?.sessionContext?.cid
}
Expand Down Expand Up @@ -220,7 +220,7 @@ class DtlsServer(
get() = DtlsSessionContext(
peerCertificateSubject = ctx.peerCertificateSubject,
authenticationContext = authenticationContext,
cid = ctx.ownCid ?: ctx.peerCid
cid = if (ctx.ownCid?.isEmpty() != true) ctx.ownCid else ctx.peerCid
)

init {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,6 @@ class DtlsServerTransport private constructor(
executor.supply {
dtlsServer.putSessionAuthenticationContext(adr, key, value)
}

fun getSessionCid(adr: InetSocketAddress) = dtlsServer.getSessionCid(adr)
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,17 +230,17 @@ class DtlsServerTest {
fun `should find session cid`() {
// given
val clientSession = clientHandshake()
val cid2 = dtlsServer.getCidStateFromAddress(localAddress(2_5684))
assert(cid2!!.isNotEmpty())
val cid = dtlsServer.getSessionCid(localAddress(2_5684))
assert(cid!!.isNotEmpty())
clientSession.close()
}

@Test
fun `shouldn't find session cid`() {
// given
val clientSession = clientHandshake()
val cid2 = dtlsServer.getCidStateFromAddress(localAddress(1234))
assertEquals(null, cid2)
val cid = dtlsServer.getSessionCid(localAddress(1234))
assertEquals(null, cid)
clientSession.close()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,18 @@ class DtlsServerTransportTest {
assertTrue(server.executor() is ScheduledThreadPoolExecutor)
}

@Test
fun `should return cid`() {
server = DtlsServerTransport.create(conf, sessionStore = sessionStore)
val client = DtlsTransmitter.connect(server, clientConfig).await()

// when
client.send("hello")

val cid = server.getSessionCid(server.localAddress())
assert(cid != null)
}

@Test
fun `should set and use session context`() {
// given
Expand Down

0 comments on commit 2a8bd16

Please sign in to comment.