Skip to content

Commit

Permalink
Merge pull request #197 from Hexastack/196-issue-enhance-web-socket-c…
Browse files Browse the repository at this point in the history
…onnection-security

fix: enhance web-socket connection access
  • Loading branch information
marrouchi authored Oct 12, 2024
2 parents 268663c + ff17a9d commit baf561e
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 43 deletions.
1 change: 1 addition & 0 deletions api/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"test:cov": "jest --coverage --runInBand --detectOpenHandles --forceExit",
"test:debug": "node --inspect-brk -r tsconfig-paths/register -r ts-node/register node_modules/.bin/jest --runInBand",
"test:e2e": "jest --config ./test/jest-e2e.json",
"test:clear": "jest --clearCache",
"typecheck": "tsc --noEmit",
"reset": "npm install && npm run containers:restart",
"reset:hard": "npm clean-install && npm run containers:rebuild",
Expand Down
3 changes: 3 additions & 0 deletions api/src/channel/channel.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ export class ChannelService {
);

if (!req.session?.passport?.user?.id) {
setTimeout(() => {
req.socket.client.conn.close();
}, 300);
throw new UnauthorizedException(
'Only authenticated users are allowed to use this channel',
);
Expand Down
2 changes: 2 additions & 0 deletions api/src/websocket/websocket.gateway.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ describe('WebsocketGateway', () => {
ioClient = io('http://localhost:3000', {
autoConnect: false,
transports: ['websocket', 'polling'],
// path: '/socket.io/?EIO=4&transport=websocket&channel=offline',
query: { EIO: '4', transport: 'websocket', channel: 'offline' },
});

app.listen(3000);
Expand Down
92 changes: 51 additions & 41 deletions api/src/websocket/websocket.gateway.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,60 +207,70 @@ export class WebsocketGateway
// Handle session
this.io.use((client, next) => {
this.logger.verbose('Client connected, attempting to load session.');
if (client.request.headers.cookie) {
const cookies = cookie.parse(client.request.headers.cookie);
if (cookies && config.session.name in cookies) {
const sessionID = cookieParser.signedCookie(
cookies[config.session.name],
config.session.secret,
);
if (sessionID) {
return this.loadSession(sessionID, (err, session) => {
if (err) {
this.logger.warn(
'Unable to load session, creating a new one ...',
err,
);
return this.createAndStoreSession(client, next);
}
client.data.session = session;
client.data.sessionID = sessionID;
next();
});
try {
const { searchParams } = new URL(`ws://localhost${client.request.url}`);
if (client.request.headers.cookie) {
const cookies = cookie.parse(client.request.headers.cookie);
if (cookies && config.session.name in cookies) {
const sessionID = cookieParser.signedCookie(
cookies[config.session.name],
config.session.secret,
);
if (sessionID) {
return this.loadSession(sessionID, (err, session) => {
if (err || !session) {
this.logger.warn(
'Unable to load session, creating a new one ...',
err,
);
if (searchParams.get('channel') === 'offline') {
return this.createAndStoreSession(client, next);
} else {
return next(new Error('Unauthorized: Unknown session ID'));
}
}
client.data.session = session;
client.data.sessionID = sessionID;
next();
});
} else {
return next(new Error('Unable to parse session ID from cookie'));
}
}
} else if (searchParams.get('channel') === 'offline') {
return this.createAndStoreSession(client, next);
} else {
return next(new Error('Unauthorized to connect to WS'));
}
} catch (e) {
this.logger.warn('Something unexpected happening');
return next(e);
}

return this.createAndStoreSession(client, next);
});
}

handleConnection(client: Socket, ..._args: any[]): void {
const { sockets } = this.io.sockets;
const handshake = client.handshake;
const { channel } = handshake.query;
this.logger.log(`Client id: ${client.id} connected`);
this.logger.debug(`Number of connected clients: ${sockets?.size}`);

this.eventEmitter.emit(`hook:websocket:connection`, client);
// @TODO : Revisit once we don't use anymore in frontend
if (!channel) {
const response = new SocketResponse();
client.send(
response
.setHeaders({
'access-control-allow-origin':
config.security.cors.allowOrigins.join(','),
vary: 'Origin',
'access-control-allow-credentials':
config.security.cors.allowCredentials.toString(),
})
.status(200)
.json({
success: true,
}),
);
}
const response = new SocketResponse();
client.send(
response
.setHeaders({
'access-control-allow-origin':
config.security.cors.allowOrigins.join(','),
vary: 'Origin',
'access-control-allow-credentials':
config.security.cors.allowCredentials.toString(),
})
.status(200)
.json({
success: true,
}),
);
}

async handleDisconnect(client: Socket): Promise<void> {
Expand Down
6 changes: 4 additions & 2 deletions frontend/src/app-components/widget/ChatWidget.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import UiChatWidget from "hexabot-widget/src/UiChatWidget";
import { usePathname } from "next/navigation";

import { getAvatarSrc } from "@/components/inbox/helpers/mapMessages";
import { useAuth } from "@/hooks/useAuth";
import { useConfig } from "@/hooks/useConfig";
import i18n from "@/i18n/config";
import { EntityType, RouterType } from "@/services/types";
Expand All @@ -20,9 +21,10 @@ import { ChatWidgetHeader } from "./ChatWidgetHeader";
export const ChatWidget = () => {
const pathname = usePathname();
const { apiUrl } = useConfig();
const { isAuthenticated } = useAuth();
const isVisualEditor = pathname === `/${RouterType.VISUAL_EDITOR}`;

return (
return isAuthenticated ? (
<Box
sx={{
display: isVisualEditor ? "block" : "none",
Expand All @@ -44,5 +46,5 @@ export const ChatWidget = () => {
)}
/>
</Box>
);
) : null;
};
4 changes: 4 additions & 0 deletions frontend/src/hooks/entities/auth-hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { useMutation, useQuery, useQueryClient } from "react-query";
import { EntityType, TMutationOptions } from "@/services/types";
import { ILoginAttributes } from "@/types/auth/login.types";
import { IUser, IUserAttributes, IUserStub } from "@/types/user.types";
import { useSocket } from "@/websocket/socket-hooks";

import { useFind } from "../crud/useFind";
import { useApiClient } from "../useApiClient";
Expand Down Expand Up @@ -45,10 +46,13 @@ export const useLogout = (
>,
) => {
const { apiClient } = useApiClient();
const { socket } = useSocket();

return useMutation({
...options,
async mutationFn() {
socket?.disconnect();

return await apiClient.logout();
},
onSuccess: () => {},
Expand Down

0 comments on commit baf561e

Please sign in to comment.