diff --git a/src/e2ee/worker/FrameCryptor.test.ts b/src/e2ee/worker/FrameCryptor.test.ts index 7d270a92ac..b72ba5b846 100644 --- a/src/e2ee/worker/FrameCryptor.test.ts +++ b/src/e2ee/worker/FrameCryptor.test.ts @@ -1,13 +1,109 @@ -import { describe, expect, it } from 'vitest'; -import { isFrameServerInjected } from './FrameCryptor'; +import { afterEach, describe, expect, it, vitest } from 'vitest'; +import { IV_LENGTH, KEY_PROVIDER_DEFAULTS } from '../constants'; +import { CryptorEvent } from '../events'; +import type { KeyProviderOptions } from '../types'; +import { createKeyMaterialFromString } from '../utils'; +import { FrameCryptor, encryptionEnabledMap, isFrameServerInjected } from './FrameCryptor'; +import { ParticipantKeyHandler } from './ParticipantKeyHandler'; + +function mockEncryptedRTCEncodedVideoFrame(keyIndex: number): RTCEncodedVideoFrame { + const trailer = mockFrameTrailer(keyIndex); + const data = new Uint8Array(trailer.length + 10); + data.set(trailer, 10); + return mockRTCEncodedVideoFrame(data); +} + +function mockRTCEncodedVideoFrame(data: Uint8Array): RTCEncodedVideoFrame { + return { + data: data.buffer, + timestamp: vitest.getMockedSystemTime()?.getTime() ?? 0, + type: 'key', + getMetadata(): RTCEncodedVideoFrameMetadata { + return {}; + }, + }; +} + +function mockFrameTrailer(keyIndex: number): Uint8Array { + const frameTrailer = new Uint8Array(2); + + frameTrailer[0] = IV_LENGTH; + frameTrailer[1] = keyIndex; + + return frameTrailer; +} + +class TestUnderlyingSource implements UnderlyingSource { + controller: ReadableStreamController; + + start(controller: ReadableStreamController): void { + this.controller = controller; + } + + write(chunk: T): void { + this.controller.enqueue(chunk as any); + } + + close(): void { + this.controller.close(); + } +} + +class TestUnderlyingSink implements UnderlyingSink { + public chunks: T[] = []; + + write(chunk: T): void { + this.chunks.push(chunk); + } +} + +function prepareParticipantTestDecoder( + participantIdentity: string, + partialKeyProviderOptions: Partial, +): { + keys: ParticipantKeyHandler; + cryptor: FrameCryptor; + input: TestUnderlyingSource; + output: TestUnderlyingSink; +} { + const keyProviderOptions = { ...KEY_PROVIDER_DEFAULTS, ...partialKeyProviderOptions }; + const keys = new ParticipantKeyHandler(participantIdentity, keyProviderOptions); + + encryptionEnabledMap.set(participantIdentity, true); + + const cryptor = new FrameCryptor({ + participantIdentity, + keys, + keyProviderOptions, + sifTrailer: new Uint8Array(), + }); + + const input = new TestUnderlyingSource(); + const output = new TestUnderlyingSink(); + cryptor.setupTransform( + 'decode', + new ReadableStream(input), + new WritableStream(output), + 'testTrack', + ); + + return { keys, cryptor, input, output }; +} describe('FrameCryptor', () => { + const participantIdentity = 'testParticipant'; + + afterEach(() => { + encryptionEnabledMap.clear(); + }); + it('identifies server injected frame correctly', () => { const frameTrailer = new TextEncoder().encode('LKROCKS'); const frameData = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, ...frameTrailer]).buffer; expect(isFrameServerInjected(frameData, frameTrailer)).toBe(true); }); + it('identifies server non server injected frame correctly', () => { const frameTrailer = new TextEncoder().encode('LKROCKS'); const frameData = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, ...frameTrailer, 10]); @@ -16,4 +112,155 @@ describe('FrameCryptor', () => { frameData.fill(0); expect(isFrameServerInjected(frameData.buffer, frameTrailer)).toBe(false); }); + + it('passthrough if participant encryption disabled', async () => { + vitest.useFakeTimers(); + try { + const { input, output } = prepareParticipantTestDecoder(participantIdentity, {}); + + // disable encryption for participant + encryptionEnabledMap.set(participantIdentity, false); + + const frame = mockEncryptedRTCEncodedVideoFrame(1); + + input.write(frame); + await vitest.advanceTimersToNextTimerAsync(); + + expect(output.chunks).toEqual([frame]); + } finally { + vitest.useRealTimers(); + } + }); + + it('passthrough for empty frame', async () => { + vitest.useFakeTimers(); + try { + const { input, output } = prepareParticipantTestDecoder(participantIdentity, {}); + + // empty frame + const frame = mockRTCEncodedVideoFrame(new Uint8Array(0)); + + input.write(frame); + await vitest.advanceTimersToNextTimerAsync(); + + expect(output.chunks).toEqual([frame]); + } finally { + vitest.useRealTimers(); + } + }); + + it('drops frames when invalid key', async () => { + vitest.useFakeTimers(); + try { + const { keys, input, output } = prepareParticipantTestDecoder(participantIdentity, { + failureTolerance: 0, + }); + + expect(keys.hasValidKey).toBe(true); + + await keys.setKey(await createKeyMaterialFromString('password'), 0); + + input.write(mockEncryptedRTCEncodedVideoFrame(1)); + await vitest.advanceTimersToNextTimerAsync(); + + expect(output.chunks).toEqual([]); + expect(keys.hasValidKey).toBe(false); + + // this should still fail as keys are all marked as invalid + input.write(mockEncryptedRTCEncodedVideoFrame(0)); + await vitest.advanceTimersToNextTimerAsync(); + + expect(output.chunks).toEqual([]); + expect(keys.hasValidKey).toBe(false); + } finally { + vitest.useRealTimers(); + } + }); + + it('marks key invalid after too many failures', async () => { + const { keys, cryptor, input } = prepareParticipantTestDecoder(participantIdentity, { + failureTolerance: 1, + }); + + expect(keys.hasValidKey).toBe(true); + + await keys.setKey(await createKeyMaterialFromString('password'), 0); + + vitest.spyOn(keys, 'getKeySet'); + vitest.spyOn(keys, 'decryptionFailure'); + + const errorListener = vitest.fn().mockImplementation((e) => { + console.log('error', e); + }); + cryptor.on(CryptorEvent.Error, errorListener); + + input.write(mockEncryptedRTCEncodedVideoFrame(1)); + + await vitest.waitFor(() => expect(keys.decryptionFailure).toHaveBeenCalled()); + expect(errorListener).toHaveBeenCalled(); + expect(keys.decryptionFailure).toHaveBeenCalledTimes(1); + expect(keys.getKeySet).toHaveBeenCalled(); + expect(keys.getKeySet).toHaveBeenLastCalledWith(1); + expect(keys.hasValidKey).toBe(true); + + vitest.clearAllMocks(); + + input.write(mockEncryptedRTCEncodedVideoFrame(1)); + + await vitest.waitFor(() => expect(keys.decryptionFailure).toHaveBeenCalled()); + expect(errorListener).toHaveBeenCalled(); + expect(keys.decryptionFailure).toHaveBeenCalledTimes(1); + expect(keys.getKeySet).toHaveBeenCalled(); + expect(keys.getKeySet).toHaveBeenLastCalledWith(1); + expect(keys.hasValidKey).toBe(false); + + vitest.clearAllMocks(); + + // this should still fail as keys are all marked as invalid + input.write(mockEncryptedRTCEncodedVideoFrame(0)); + + await vitest.waitFor(() => expect(keys.getKeySet).toHaveBeenCalled()); + // decryptionFailure() isn't called in this case + expect(keys.getKeySet).toHaveBeenCalled(); + expect(keys.getKeySet).toHaveBeenLastCalledWith(0); + expect(keys.hasValidKey).toBe(false); + }); + + it('mark as valid when a new key is set on same index', async () => { + const { keys, input } = prepareParticipantTestDecoder(participantIdentity, { + failureTolerance: 0, + }); + + const material = await createKeyMaterialFromString('password'); + await keys.setKey(material, 0); + + expect(keys.hasValidKey).toBe(true); + + input.write(mockEncryptedRTCEncodedVideoFrame(1)); + + expect(keys.hasValidKey).toBe(false); + + await keys.setKey(material, 0); + + expect(keys.hasValidKey).toBe(true); + }); + + it('mark as valid when a new key is set on new index', async () => { + const { keys, input } = prepareParticipantTestDecoder(participantIdentity, { + failureTolerance: 0, + }); + + const material = await createKeyMaterialFromString('password'); + await keys.setKey(material, 0); + + expect(keys.hasValidKey).toBe(true); + + input.write(mockEncryptedRTCEncodedVideoFrame(1)); + + expect(keys.hasValidKey).toBe(false); + + await keys.setKey(material, 1); + + expect(keys.hasValidKey).toBe(true); + }); }); diff --git a/src/e2ee/worker/FrameCryptor.ts b/src/e2ee/worker/FrameCryptor.ts index fae30e5e0b..129f254256 100644 --- a/src/e2ee/worker/FrameCryptor.ts +++ b/src/e2ee/worker/FrameCryptor.ts @@ -156,8 +156,8 @@ export class FrameCryptor extends BaseFrameCryptor { setupTransform( operation: 'encode' | 'decode', - readable: ReadableStream, - writable: WritableStream, + readable: ReadableStream, + writable: WritableStream, trackId: string, codec?: VideoCodec, ) { diff --git a/src/e2ee/worker/ParticipantKeyHandler.test.ts b/src/e2ee/worker/ParticipantKeyHandler.test.ts index 85a35e0388..43197b70fb 100644 --- a/src/e2ee/worker/ParticipantKeyHandler.test.ts +++ b/src/e2ee/worker/ParticipantKeyHandler.test.ts @@ -34,4 +34,89 @@ describe('ParticipantKeyHandler', () => { await keyHandler.setKey(materialB, 0); expect(keyHandler.getKeySet(0)?.material).toEqual(materialB); }); + + it('marks invalid if more than failureTolerance failures', async () => { + const keyHandler = new ParticipantKeyHandler(participantIdentity, { + ...KEY_PROVIDER_DEFAULTS, + failureTolerance: 2, + }); + expect(keyHandler.hasValidKey).toBe(true); + + // 1 + keyHandler.decryptionFailure(); + expect(keyHandler.hasValidKey).toBe(true); + + // 2 + keyHandler.decryptionFailure(); + expect(keyHandler.hasValidKey).toBe(true); + + // 3 + keyHandler.decryptionFailure(); + expect(keyHandler.hasValidKey).toBe(false); + }); + + it('marks valid on encryption success', async () => { + const keyHandler = new ParticipantKeyHandler(participantIdentity, { + ...KEY_PROVIDER_DEFAULTS, + failureTolerance: 0, + }); + + expect(keyHandler.hasValidKey).toBe(true); + + keyHandler.decryptionFailure(); + + expect(keyHandler.hasValidKey).toBe(false); + + keyHandler.decryptionSuccess(); + + expect(keyHandler.hasValidKey).toBe(true); + }); + + it('marks valid on new key', async () => { + const keyHandler = new ParticipantKeyHandler(participantIdentity, { + ...KEY_PROVIDER_DEFAULTS, + failureTolerance: 0, + }); + + expect(keyHandler.hasValidKey).toBe(true); + + keyHandler.decryptionFailure(); + + expect(keyHandler.hasValidKey).toBe(false); + + await keyHandler.setKey(await createKeyMaterialFromString('passwordA')); + + expect(keyHandler.hasValidKey).toBe(true); + }); + + it('updates currentKeyIndex on new key', async () => { + const keyHandler = new ParticipantKeyHandler(participantIdentity, KEY_PROVIDER_DEFAULTS); + const material = await createKeyMaterialFromString('password'); + + expect(keyHandler.getCurrentKeyIndex()).toBe(0); + + // default is zero + await keyHandler.setKey(material); + expect(keyHandler.getCurrentKeyIndex()).toBe(0); + + // should go to next index + await keyHandler.setKey(material, 1); + expect(keyHandler.getCurrentKeyIndex()).toBe(1); + + // should be able to jump ahead + await keyHandler.setKey(material, 10); + expect(keyHandler.getCurrentKeyIndex()).toBe(10); + }); + + it('allows many failures if failureTolerance is -1', async () => { + const keyHandler = new ParticipantKeyHandler(participantIdentity, { + ...KEY_PROVIDER_DEFAULTS, + failureTolerance: -1, + }); + expect(keyHandler.hasValidKey).toBe(true); + for (let i = 0; i < 100; i++) { + keyHandler.decryptionFailure(); + expect(keyHandler.hasValidKey).toBe(true); + } + }); });