Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some unit testing for frame cryptor key handling #1271

Merged
merged 6 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 249 additions & 2 deletions src/e2ee/worker/FrameCryptor.test.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,109 @@
import { describe, expect, it } from 'vitest';
import { isFrameServerInjected } from './FrameCryptor';
import { afterEach, describe, expect, it, vitest } from 'vitest';
import { IV_LENGTH, KEY_PROVIDER_DEFAULTS } from '../constants';
import { CryptorEvent } from '../events';
import type { KeyProviderOptions } from '../types';
import { createKeyMaterialFromString } from '../utils';
import { FrameCryptor, encryptionEnabledMap, isFrameServerInjected } from './FrameCryptor';
import { ParticipantKeyHandler } from './ParticipantKeyHandler';

function mockEncryptedRTCEncodedVideoFrame(keyIndex: number): RTCEncodedVideoFrame {
const trailer = mockFrameTrailer(keyIndex);
const data = new Uint8Array(trailer.length + 10);
data.set(trailer, 10);
return mockRTCEncodedVideoFrame(data);
}

function mockRTCEncodedVideoFrame(data: Uint8Array): RTCEncodedVideoFrame {
return {
data: data.buffer,
timestamp: vitest.getMockedSystemTime()?.getTime() ?? 0,
type: 'key',
getMetadata(): RTCEncodedVideoFrameMetadata {
return {};
},
};
}

function mockFrameTrailer(keyIndex: number): Uint8Array {
const frameTrailer = new Uint8Array(2);

frameTrailer[0] = IV_LENGTH;
frameTrailer[1] = keyIndex;

return frameTrailer;
}

class TestUnderlyingSource<T> implements UnderlyingSource<T> {
controller: ReadableStreamController<T>;

start(controller: ReadableStreamController<T>): void {
this.controller = controller;
}

write(chunk: T): void {
this.controller.enqueue(chunk as any);
}

close(): void {
this.controller.close();
}
}

class TestUnderlyingSink<T> implements UnderlyingSink<T> {
public chunks: T[] = [];

write(chunk: T): void {
this.chunks.push(chunk);
}
}

function prepareParticipantTestDecoder(
participantIdentity: string,
partialKeyProviderOptions: Partial<KeyProviderOptions>,
): {
keys: ParticipantKeyHandler;
cryptor: FrameCryptor;
input: TestUnderlyingSource<RTCEncodedVideoFrame>;
output: TestUnderlyingSink<RTCEncodedVideoFrame>;
} {
const keyProviderOptions = { ...KEY_PROVIDER_DEFAULTS, ...partialKeyProviderOptions };
const keys = new ParticipantKeyHandler(participantIdentity, keyProviderOptions);

encryptionEnabledMap.set(participantIdentity, true);

const cryptor = new FrameCryptor({
participantIdentity,
keys,
keyProviderOptions,
sifTrailer: new Uint8Array(),
});

const input = new TestUnderlyingSource<RTCEncodedVideoFrame>();
const output = new TestUnderlyingSink<RTCEncodedVideoFrame>();
cryptor.setupTransform(
'decode',
new ReadableStream(input),
new WritableStream(output),
'testTrack',
);

return { keys, cryptor, input, output };
}

describe('FrameCryptor', () => {
const participantIdentity = 'testParticipant';

afterEach(() => {
encryptionEnabledMap.clear();
});

it('identifies server injected frame correctly', () => {
const frameTrailer = new TextEncoder().encode('LKROCKS');
const frameData = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, ...frameTrailer]).buffer;

expect(isFrameServerInjected(frameData, frameTrailer)).toBe(true);
});

it('identifies server non server injected frame correctly', () => {
const frameTrailer = new TextEncoder().encode('LKROCKS');
const frameData = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, ...frameTrailer, 10]);
Expand All @@ -16,4 +112,155 @@ describe('FrameCryptor', () => {
frameData.fill(0);
expect(isFrameServerInjected(frameData.buffer, frameTrailer)).toBe(false);
});

it('passthrough if participant encryption disabled', async () => {
vitest.useFakeTimers();
try {
const { input, output } = prepareParticipantTestDecoder(participantIdentity, {});

// disable encryption for participant
encryptionEnabledMap.set(participantIdentity, false);

const frame = mockEncryptedRTCEncodedVideoFrame(1);

input.write(frame);
await vitest.advanceTimersToNextTimerAsync();

expect(output.chunks).toEqual([frame]);
} finally {
vitest.useRealTimers();
}
});

it('passthrough for empty frame', async () => {
vitest.useFakeTimers();
try {
const { input, output } = prepareParticipantTestDecoder(participantIdentity, {});

// empty frame
const frame = mockRTCEncodedVideoFrame(new Uint8Array(0));

input.write(frame);
await vitest.advanceTimersToNextTimerAsync();

expect(output.chunks).toEqual([frame]);
} finally {
vitest.useRealTimers();
}
});

it('drops frames when invalid key', async () => {
vitest.useFakeTimers();
try {
const { keys, input, output } = prepareParticipantTestDecoder(participantIdentity, {
failureTolerance: 0,
});

expect(keys.hasValidKey).toBe(true);

await keys.setKey(await createKeyMaterialFromString('password'), 0);

input.write(mockEncryptedRTCEncodedVideoFrame(1));
await vitest.advanceTimersToNextTimerAsync();

expect(output.chunks).toEqual([]);
expect(keys.hasValidKey).toBe(false);

// this should still fail as keys are all marked as invalid
input.write(mockEncryptedRTCEncodedVideoFrame(0));
await vitest.advanceTimersToNextTimerAsync();

expect(output.chunks).toEqual([]);
expect(keys.hasValidKey).toBe(false);
} finally {
vitest.useRealTimers();
}
});

it('marks key invalid after too many failures', async () => {
const { keys, cryptor, input } = prepareParticipantTestDecoder(participantIdentity, {
failureTolerance: 1,
});

expect(keys.hasValidKey).toBe(true);

await keys.setKey(await createKeyMaterialFromString('password'), 0);

vitest.spyOn(keys, 'getKeySet');
vitest.spyOn(keys, 'decryptionFailure');

const errorListener = vitest.fn().mockImplementation((e) => {
console.log('error', e);
});
cryptor.on(CryptorEvent.Error, errorListener);

input.write(mockEncryptedRTCEncodedVideoFrame(1));

await vitest.waitFor(() => expect(keys.decryptionFailure).toHaveBeenCalled());
expect(errorListener).toHaveBeenCalled();
expect(keys.decryptionFailure).toHaveBeenCalledTimes(1);
expect(keys.getKeySet).toHaveBeenCalled();
expect(keys.getKeySet).toHaveBeenLastCalledWith(1);
expect(keys.hasValidKey).toBe(true);

vitest.clearAllMocks();

input.write(mockEncryptedRTCEncodedVideoFrame(1));

await vitest.waitFor(() => expect(keys.decryptionFailure).toHaveBeenCalled());
expect(errorListener).toHaveBeenCalled();
expect(keys.decryptionFailure).toHaveBeenCalledTimes(1);
expect(keys.getKeySet).toHaveBeenCalled();
expect(keys.getKeySet).toHaveBeenLastCalledWith(1);
expect(keys.hasValidKey).toBe(false);

vitest.clearAllMocks();

// this should still fail as keys are all marked as invalid
input.write(mockEncryptedRTCEncodedVideoFrame(0));

await vitest.waitFor(() => expect(keys.getKeySet).toHaveBeenCalled());
// decryptionFailure() isn't called in this case
expect(keys.getKeySet).toHaveBeenCalled();
expect(keys.getKeySet).toHaveBeenLastCalledWith(0);
expect(keys.hasValidKey).toBe(false);
});

it('mark as valid when a new key is set on same index', async () => {
const { keys, input } = prepareParticipantTestDecoder(participantIdentity, {
failureTolerance: 0,
});

const material = await createKeyMaterialFromString('password');
await keys.setKey(material, 0);

expect(keys.hasValidKey).toBe(true);

input.write(mockEncryptedRTCEncodedVideoFrame(1));

expect(keys.hasValidKey).toBe(false);

await keys.setKey(material, 0);

expect(keys.hasValidKey).toBe(true);
});

it('mark as valid when a new key is set on new index', async () => {
const { keys, input } = prepareParticipantTestDecoder(participantIdentity, {
failureTolerance: 0,
});

const material = await createKeyMaterialFromString('password');
await keys.setKey(material, 0);

expect(keys.hasValidKey).toBe(true);

input.write(mockEncryptedRTCEncodedVideoFrame(1));

expect(keys.hasValidKey).toBe(false);

await keys.setKey(material, 1);

expect(keys.hasValidKey).toBe(true);
});
});
4 changes: 2 additions & 2 deletions src/e2ee/worker/FrameCryptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ export class FrameCryptor extends BaseFrameCryptor {

setupTransform(
operation: 'encode' | 'decode',
readable: ReadableStream,
writable: WritableStream,
readable: ReadableStream<RTCEncodedVideoFrame | RTCEncodedAudioFrame>,
writable: WritableStream<RTCEncodedVideoFrame | RTCEncodedAudioFrame>,
trackId: string,
codec?: VideoCodec,
) {
Expand Down
85 changes: 85 additions & 0 deletions src/e2ee/worker/ParticipantKeyHandler.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,89 @@ describe('ParticipantKeyHandler', () => {
await keyHandler.setKey(materialB, 0);
expect(keyHandler.getKeySet(0)?.material).toEqual(materialB);
});

it('marks invalid if more than failureTolerance failures', async () => {
const keyHandler = new ParticipantKeyHandler(participantIdentity, {
...KEY_PROVIDER_DEFAULTS,
failureTolerance: 2,
});
expect(keyHandler.hasValidKey).toBe(true);

// 1
keyHandler.decryptionFailure();
expect(keyHandler.hasValidKey).toBe(true);

// 2
keyHandler.decryptionFailure();
expect(keyHandler.hasValidKey).toBe(true);

// 3
keyHandler.decryptionFailure();
expect(keyHandler.hasValidKey).toBe(false);
});

it('marks valid on encryption success', async () => {
const keyHandler = new ParticipantKeyHandler(participantIdentity, {
...KEY_PROVIDER_DEFAULTS,
failureTolerance: 0,
});

expect(keyHandler.hasValidKey).toBe(true);

keyHandler.decryptionFailure();

expect(keyHandler.hasValidKey).toBe(false);

keyHandler.decryptionSuccess();

expect(keyHandler.hasValidKey).toBe(true);
});

it('marks valid on new key', async () => {
const keyHandler = new ParticipantKeyHandler(participantIdentity, {
...KEY_PROVIDER_DEFAULTS,
failureTolerance: 0,
});

expect(keyHandler.hasValidKey).toBe(true);

keyHandler.decryptionFailure();

expect(keyHandler.hasValidKey).toBe(false);

await keyHandler.setKey(await createKeyMaterialFromString('passwordA'));

expect(keyHandler.hasValidKey).toBe(true);
});

it('updates currentKeyIndex on new key', async () => {
const keyHandler = new ParticipantKeyHandler(participantIdentity, KEY_PROVIDER_DEFAULTS);
const material = await createKeyMaterialFromString('password');

expect(keyHandler.getCurrentKeyIndex()).toBe(0);

// default is zero
await keyHandler.setKey(material);
expect(keyHandler.getCurrentKeyIndex()).toBe(0);

// should go to next index
await keyHandler.setKey(material, 1);
expect(keyHandler.getCurrentKeyIndex()).toBe(1);

// should be able to jump ahead
await keyHandler.setKey(material, 10);
expect(keyHandler.getCurrentKeyIndex()).toBe(10);
});

it('allows many failures if failureTolerance is -1', async () => {
const keyHandler = new ParticipantKeyHandler(participantIdentity, {
...KEY_PROVIDER_DEFAULTS,
failureTolerance: -1,
});
expect(keyHandler.hasValidKey).toBe(true);
for (let i = 0; i < 100; i++) {
keyHandler.decryptionFailure();
expect(keyHandler.hasValidKey).toBe(true);
}
});
});
Loading