From 5c27cbcfa64bd83d3563cecdc29d2615dfc66296 Mon Sep 17 00:00:00 2001 From: Sarp Centel Date: Wed, 21 Aug 2024 23:00:09 +0300 Subject: [PATCH] add test for send extension handshake --- internal/message.go | 112 +++++++++++++++++++++++++++++++ internal/stage_magnet_helpers.go | 104 ++++++++++++++++++++++++++++ internal/stage_magnet_send.go | 73 ++++++++++++++++++++ 3 files changed, 289 insertions(+) create mode 100644 internal/message.go create mode 100644 internal/stage_magnet_send.go diff --git a/internal/message.go b/internal/message.go new file mode 100644 index 0000000..1d19945 --- /dev/null +++ b/internal/message.go @@ -0,0 +1,112 @@ +// Helper methods to read and send BitTorrent messages +package internal + +import ( + "bytes" + "encoding/binary" + "io" + "net" + + logger "github.com/codecrafters-io/tester-utils/logger" + "github.com/jackpal/bencode-go" +) + +type messageID uint8 +type Message struct { + ID messageID + Payload []byte +} + +const ( + HandshakeExtendedID uint8 = 0 + RequestMetadataExtensionMsgType uint8 = 0 + DataMetadataExtensionMsgType uint8 = 1 + + MsgBitfield messageID = 5 + MsgExtended messageID = 20 +) + +func sendBitfieldMessage(conn net.Conn, payload []byte, logger *logger.Logger) (err error) { + defer logOnExit(logger, &err) + + logger.Debugln("Sending bitfield message") + req := Message { ID: MsgBitfield, Payload: payload } + serialized := req.Serialize() + _, err = conn.Write(serialized) + return err +} + +// Serialize serializes a message into a buffer of the form +// +// Interprets `nil` as a keep-alive message +func (m *Message) Serialize() []byte { + if m == nil { + return make([]byte, 4) + } + length := uint32(len(m.Payload) + 1) // +1 for id + buf := make([]byte, 4+length) + binary.BigEndian.PutUint32(buf[0:4], length) + buf[4] = byte(m.ID) + copy(buf[5:], m.Payload) + return buf +} + +func sendExtensionHandshake(conn net.Conn, metadataID uint8, metadataSize int, logger *logger.Logger) (err error) { + defer logOnExit(logger, &err) + + logger.Debugln("Sending extension handshake") + req := createExtensionHandshake(metadataID, metadataSize, logger) + serialized := req.Serialize() + _, err = conn.Write(serialized) + return err +} + +func createExtensionHandshake(metadataID uint8, metadataSize int, logger *logger.Logger) *Message { + dict := make(map[string]interface{}) + inner := make(map[string]int64) + inner["ut_metadata"] = int64(metadataID) + dict["m"] = inner + dict["metadata_size"] = metadataSize + var buf bytes.Buffer + err := bencode.Marshal(&buf, dict) + if err != nil { + logger.Errorf("Error encoding: %v", err) + } + payload := formatExtendedPayload(buf, HandshakeExtendedID) + return &Message{ID: MsgExtended, Payload: payload} +} + +func formatExtendedPayload(buf bytes.Buffer, extensionId uint8) []byte { + payload := make([]byte, 1+buf.Len()) + payload[0] = uint8(extensionId) + copy(payload[1:], buf.Bytes()) + return payload +} + +// Read parses a message from a stream. Returns `nil` on keep-alive message +func readMessage(r io.Reader) (*Message, error) { + lengthBuf := make([]byte, 4) + _, err := io.ReadFull(r, lengthBuf) + if err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(lengthBuf) + + // keep-alive message + if length == 0 { + return nil, nil + } + + messageBuf := make([]byte, length) + _, err = io.ReadFull(r, messageBuf) + if err != nil { + return nil, err + } + + m := Message{ + ID: messageID(messageBuf[0]), + Payload: messageBuf[1:], + } + + return &m, nil +} \ No newline at end of file diff --git a/internal/stage_magnet_helpers.go b/internal/stage_magnet_helpers.go index ff97c73..95c27ed 100644 --- a/internal/stage_magnet_helpers.go +++ b/internal/stage_magnet_helpers.go @@ -1,11 +1,15 @@ package internal import ( + "bytes" "encoding/hex" + "errors" "fmt" "math/rand" + "net" logger "github.com/codecrafters-io/tester-utils/logger" + "github.com/jackpal/bencode-go" ) type MagnetTestParams struct { @@ -149,3 +153,103 @@ func decodeInfoHash(infoHashStr string) ([20]byte, error) { copy(infoHash[:], decodedBytes) return infoHash, nil } + +func receiveAndAssertExtensionHandshake(conn net.Conn, logger *logger.Logger) (id uint8, err error) { + defer logOnExit(logger, &err) + + msg, err := receiveExtensionHandshake(conn, logger) + if err != nil { + return 0, fmt.Errorf("error receiving extension handshake: %v", err) + } + + metadataExtensionID, err := assertExtensionHandshake(msg, logger) + if err != nil { + return 0, err + } + + return metadataExtensionID, nil +} + +func receiveExtensionHandshake(conn net.Conn, logger *logger.Logger) (*Message, error) { + logger.Debugln("Waiting to receive extension handshake message") + msg, err := readMessage(conn) + if err != nil { + return nil, err + } + logger.Infof("Received extension handshake with payload: %s", string(msg.Payload)) + return msg, nil +} + +func assertExtensionHandshake(msg *Message, logger *logger.Logger) (uint8, error) { + if msg.ID != MsgExtended { + return 0, fmt.Errorf("expected message id: %d, actual: %d", MsgExtended, msg.ID) + } + + if len(msg.Payload) < 2 { + return 0, fmt.Errorf("expecting a larger payload size than %d", len(msg.Payload)) + } + + if msg.Payload[0] != 0 { + return 0, fmt.Errorf("expected extension handshake message id: %d, actual: %d. First byte of payload indicates extension message id and it needs to be zero for extension handshake", 0, msg.Payload[0]) + } + + metadataExtensionID, err := extractMetadataExtensionID(msg, logger) + if err != nil { + return 0, err + } + + return metadataExtensionID, nil +} + +func extractMetadataExtensionID(msg *Message, logger *logger.Logger) (id uint8, err error) { + defer logOnExit(logger, &err) + + logger.Debugln("Checking metadata extension id received") + + handshake, err := bencode.Decode(bytes.NewReader(msg.Payload[1:])) + if err != nil { + return 0, fmt.Errorf("error decoding bencoded dictionary in message payload starting at payload index 1, error message: %s", err) + } + dict, ok := handshake.(map[string]interface{}) + if !ok { + return 0, errors.New("bencoded dictionary missing or wrong type in payload, expected a dictionary with string keys") + } + inner, exists := dict["m"] + if !exists { + return 0, errors.New("dictionary under key m is missing or wrong type") + } + + innerDict, ok := inner.(map[string]interface{}) + if !ok { + return 0, errors.New("dictionary under key m is of wrong type, expected a dictionary with string keys") + } + value, exists := innerDict["ut_metadata"] + if exists { + theirMetadataExtensionID, ok := value.(int64) + if !ok { + return 0, errors.New("value for ut_metadata needs to be an integer, it's wrong type") + } + if theirMetadataExtensionID <= 0 { + return 0, errors.New("value for ut_metadata needs to be greater than zero") + } + theirMetadataExtensionIDUint8, err := safeINT64toUINT8(theirMetadataExtensionID) + if err != nil { + return 0, err + } + return theirMetadataExtensionIDUint8, nil + } else { + return 0, errors.New("ut_metadata key is missing in dictionary under key m during extension handshake") + } +} + +func safeINT64toUINT8(i any) (uint8, error) { + if value, ok := i.(int64); ok { + if value < 0 || value > 255 { + return 0, fmt.Errorf("number out of range for uint8") + } else { + return uint8(value), nil + } + } else { + return 0, fmt.Errorf("expected int64, received different type") + } +} diff --git a/internal/stage_magnet_send.go b/internal/stage_magnet_send.go new file mode 100644 index 0000000..e950644 --- /dev/null +++ b/internal/stage_magnet_send.go @@ -0,0 +1,73 @@ +package internal + +import ( + "errors" + "fmt" + "net" + + "github.com/codecrafters-io/tester-utils/test_case_harness" +) + +var handshakeChannel = make(chan bool) + +func testMagnetSendExtendedHandshake(stageHarness *test_case_harness.TestCaseHarness) error { + initRandom() + + logger := stageHarness.Logger + executable := stageHarness.Executable + + magnetLink := randomMagnetLink() + params, err := NewMagnetTestParams(magnetLink, logger) + if err != nil { + return err + } + + go listenAndServeTrackerResponse(params.toTrackerParams()) + go waitAndHandlePeerConnection(params.toPeerConnectionParams(), handleSendExtensionHandshake) + + logger.Infof("Running ./your_bittorrent.sh magnet_handshake %q", params.MagnetUrlEncoded) + result, err := executable.Run("magnet_handshake", params.MagnetUrlEncoded) + if err != nil { + return err + } + + if err = assertExitCode(result, 0); err != nil { + return err + } + + expected := fmt.Sprintf("Peer Metadata Extension ID: %d\n", params.MyMetadataExtensionID) + + if err = assertStdoutContains(result, expected); err != nil { + return err + } + + success := <-handshakeChannel + if success { + return nil + } + + return errors.New("extension handshake was not received") +} + +func handleSendExtensionHandshake(conn net.Conn, params PeerConnectionParams) { + defer conn.Close() + logger := params.logger + + if err := receiveAndSendHandshake(conn, params); err != nil { + return + } + + if err := sendBitfieldMessage(conn, params.bitfield, logger); err != nil { + return + } + + if err := sendExtensionHandshake(conn, params.myMetadataExtensionID, params.metadataSizeBytes, logger); err != nil { + return + } + + if _, err := receiveAndAssertExtensionHandshake(conn, logger); err != nil { + return + } + + handshakeChannel <- true +}