diff --git a/README.md b/README.md index 45f4c14..023e3a6 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# thread9-test-bed +# threat9-test-bed ## Installation ```bash diff --git a/tests/service_mocks/test_tcp_service_mock.py b/tests/service_mocks/test_tcp_service_mock.py index 2e6a640..a7f72ba 100644 --- a/tests/service_mocks/test_tcp_service_mock.py +++ b/tests/service_mocks/test_tcp_service_mock.py @@ -3,21 +3,18 @@ from threat9_test_bed.service_mocks import TCPServiceMock -def test_tcp_service_mock_add_banner(): +def test_tcp_service_mock_get_command_mock(): with TCPServiceMock("127.0.0.1", 8023) as target: assert target.host == "127.0.0.1" assert target.port == 8023 + mocked_scoo = target.get_command_mock(b"scoo") + mocked_scoo.return_value = b"bee" mocked_doo = target.get_command_mock(b"doo") mocked_doo.return_value = b"where are you?" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.connect((target.host, target.port)) s.send(b"doo") assert s.recv(1024) == b"where are you?" - - mocked_scoo = target.get_command_mock(b"scoo") - mocked_scoo.return_value = b"bee" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.connect((target.host, target.port)) s.send(b"scoo") assert s.recv(1024) == b"bee" diff --git a/threat9_test_bed/tcp_service/tcp_server.py b/threat9_test_bed/tcp_service/tcp_server.py index 4e741b2..5a4b8a1 100644 --- a/threat9_test_bed/tcp_service/tcp_server.py +++ b/threat9_test_bed/tcp_service/tcp_server.py @@ -7,6 +7,7 @@ class TCPServer(socketserver.ThreadingTCPServer): allow_reuse_address = True + daemon_threads = True def __init__( self, @@ -27,6 +28,7 @@ def get_command_mock(self, command: bytes) -> mock.Mock: class TCPHandler(socketserver.BaseRequestHandler): def handle(self): - data = self.request.recv(1024) - handler = self.server.handlers[data] - self.request.sendall(handler()) + while True: + data = self.request.recv(1024) + handler = self.server.handlers.get(data, lambda: b"") + self.request.sendall(handler())