Skip to content

Commit

Permalink
add test for send extension handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
sarp committed Sep 6, 2024
1 parent c0d2197 commit 2223c3f
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 0 deletions.
112 changes: 112 additions & 0 deletions internal/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package internal

import (
"bytes"
"encoding/binary"
"io"
"net"

logger "github.com/codecrafters-io/tester-utils/logger"
"github.com/jackpal/bencode-go"
)

// Helper methods to read and send BitTorrent messages
type messageID uint8
type Message struct {
ID messageID
Payload []byte
}

const (
HandshakeExtendedID uint8 = 0
RequestMetadataExtensionMsgType uint8 = 0
DataMetadataExtensionMsgType uint8 = 1

MsgBitfield messageID = 5
MsgExtended messageID = 20
)

func sendBitfieldMessage(conn net.Conn, payload []byte, logger *logger.Logger) (err error) {
defer logOnExit(logger, &err)

logger.Debugln("Sending bitfield message")
req := Message { ID: MsgBitfield, Payload: payload }
serialized := req.Serialize()
_, err = conn.Write(serialized)
return err
}

// Serialize serializes a message into a buffer of the form
// <length prefix><message ID><payload>
// Interprets `nil` as a keep-alive message
func (m *Message) Serialize() []byte {
if m == nil {
return make([]byte, 4)
}
length := uint32(len(m.Payload) + 1) // +1 for id
buf := make([]byte, 4+length)
binary.BigEndian.PutUint32(buf[0:4], length)
buf[4] = byte(m.ID)
copy(buf[5:], m.Payload)
return buf
}

func sendExtensionHandshake(conn net.Conn, metadataID uint8, metadataSize int, logger *logger.Logger) (err error) {
defer logOnExit(logger, &err)

logger.Debugln("Sending extension handshake")
req := createExtensionHandshake(metadataID, metadataSize, logger)
serialized := req.Serialize()
_, err = conn.Write(serialized)
return err
}

func createExtensionHandshake(metadataID uint8, metadataSize int, logger *logger.Logger) *Message {
dict := make(map[string]interface{})
inner := make(map[string]int64)
inner["ut_metadata"] = int64(metadataID)
dict["m"] = inner
dict["metadata_size"] = metadataSize
var buf bytes.Buffer
err := bencode.Marshal(&buf, dict)
if err != nil {
logger.Errorf("Error encoding: %v", err)
}
payload := formatExtendedPayload(buf, HandshakeExtendedID)
return &Message{ID: MsgExtended, Payload: payload}
}

func formatExtendedPayload(buf bytes.Buffer, extensionId uint8) []byte {
payload := make([]byte, 1+buf.Len())
payload[0] = uint8(extensionId)
copy(payload[1:], buf.Bytes())
return payload
}

// Read parses a message from a stream. Returns `nil` on keep-alive message
func readMessage(r io.Reader) (*Message, error) {
lengthBuf := make([]byte, 4)
_, err := io.ReadFull(r, lengthBuf)
if err != nil {
return nil, err
}
length := binary.BigEndian.Uint32(lengthBuf)

// keep-alive message
if length == 0 {
return nil, nil
}

messageBuf := make([]byte, length)
_, err = io.ReadFull(r, messageBuf)
if err != nil {
return nil, err
}

m := Message{
ID: messageID(messageBuf[0]),
Payload: messageBuf[1:],
}

return &m, nil
}
104 changes: 104 additions & 0 deletions internal/stage_magnet_helpers.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package internal

import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"math/rand"
"net"

logger "github.com/codecrafters-io/tester-utils/logger"
"github.com/jackpal/bencode-go"
)

type MagnetTestParams struct {
Expand Down Expand Up @@ -149,3 +153,103 @@ func decodeInfoHash(infoHashStr string) ([20]byte, error) {
copy(infoHash[:], decodedBytes)
return infoHash, nil
}

func receiveAndAssertExtensionHandshake(conn net.Conn, logger *logger.Logger) (id uint8, err error) {
defer logOnExit(logger, &err)

msg, err := receiveExtensionHandshake(conn, logger)
if err != nil {
return 0, fmt.Errorf("error receiving extension handshake: %v", err)
}

metadataExtensionID, err := assertExtensionHandshake(msg, logger)
if err != nil {
return 0, err
}

return metadataExtensionID, nil
}

func receiveExtensionHandshake(conn net.Conn, logger *logger.Logger) (*Message, error) {
logger.Debugln("Waiting to receive extension handshake message")
msg, err := readMessage(conn)
if err != nil {
return nil, err
}
logger.Infof("Received extension handshake with payload: %s", string(msg.Payload))
return msg, nil
}

func assertExtensionHandshake(msg *Message, logger *logger.Logger) (uint8, error) {
if msg.ID != MsgExtended {
return 0, fmt.Errorf("expected message id: %d, actual: %d", MsgExtended, msg.ID)
}

if len(msg.Payload) < 2 {
return 0, fmt.Errorf("expecting a larger payload size than %d", len(msg.Payload))
}

if msg.Payload[0] != 0 {
return 0, fmt.Errorf("expected extension handshake message id: %d, actual: %d. First byte of payload indicates extension message id and it needs to be zero for extension handshake", 0, msg.Payload[0])
}

metadataExtensionID, err := extractMetadataExtensionID(msg, logger)
if err != nil {
return 0, err
}

return metadataExtensionID, nil
}

func extractMetadataExtensionID(msg *Message, logger *logger.Logger) (id uint8, err error) {
defer logOnExit(logger, &err)

logger.Debugln("Checking metadata extension id received")

handshake, err := bencode.Decode(bytes.NewReader(msg.Payload[1:]))
if err != nil {
return 0, fmt.Errorf("error decoding bencoded dictionary in message payload starting at payload index 1, error message: %s", err)
}
dict, ok := handshake.(map[string]interface{})
if !ok {
return 0, errors.New("bencoded dictionary missing or wrong type in payload, expected a dictionary with string keys")
}
inner, exists := dict["m"]
if !exists {
return 0, errors.New("dictionary under key m is missing or wrong type")
}

innerDict, ok := inner.(map[string]interface{})
if !ok {
return 0, errors.New("dictionary under key m is of wrong type, expected a dictionary with string keys")
}
value, exists := innerDict["ut_metadata"]
if exists {
theirMetadataExtensionID, ok := value.(int64)
if !ok {
return 0, errors.New("value for ut_metadata needs to be an integer, it's wrong type")
}
if theirMetadataExtensionID <= 0 {
return 0, errors.New("value for ut_metadata needs to be greater than zero")
}
theirMetadataExtensionIDUint8, err := safeINT64toUINT8(theirMetadataExtensionID)
if err != nil {
return 0, err
}
return theirMetadataExtensionIDUint8, nil
} else {
return 0, errors.New("ut_metadata key is missing in dictionary under key m during extension handshake")
}
}

func safeINT64toUINT8(i any) (uint8, error) {
if value, ok := i.(int64); ok {
if value < 0 || value > 255 {
return 0, fmt.Errorf("number out of range for uint8")
} else {
return uint8(value), nil
}
} else {
return 0, fmt.Errorf("expected int64, received different type")
}
}
73 changes: 73 additions & 0 deletions internal/stage_magnet_send.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package internal

import (
"errors"
"fmt"
"net"

"github.com/codecrafters-io/tester-utils/test_case_harness"
)

var handshakeChannel = make(chan bool)

func testMagnetSendExtendedHandshake(stageHarness *test_case_harness.TestCaseHarness) error {
initRandom()

logger := stageHarness.Logger
executable := stageHarness.Executable

magnetLink := randomMagnetLink()
params, err := NewMagnetTestParams(magnetLink, logger)
if err != nil {
return err
}

go listenAndServeTrackerResponse(params.toTrackerParams())
go waitAndHandlePeerConnection(params.toPeerConnectionParams(), handleSendExtensionHandshake)

logger.Infof("Running ./your_bittorrent.sh magnet_handshake %q", params.MagnetUrlEncoded)
result, err := executable.Run("magnet_handshake", params.MagnetUrlEncoded)
if err != nil {
return err
}

if err = assertExitCode(result, 0); err != nil {
return err
}

expected := fmt.Sprintf("Peer Metadata Extension ID: %d\n", params.MyMetadataExtensionID)

if err = assertStdoutContains(result, expected); err != nil {
return err
}

success := <-handshakeChannel
if success {
return nil
}

return errors.New("extension handshake was not received")
}

func handleSendExtensionHandshake(conn net.Conn, params PeerConnectionParams) {
defer conn.Close()
logger := params.logger

if err := receiveAndSendHandshake(conn, params); err != nil {
return
}

if err := sendBitfieldMessage(conn, params.bitfield, logger); err != nil {
return
}

if err := sendExtensionHandshake(conn, params.myMetadataExtensionID, params.metadataSizeBytes, logger); err != nil {
return
}

if _, err := receiveAndAssertExtensionHandshake(conn, logger); err != nil {
return
}

handshakeChannel <- true
}

0 comments on commit 2223c3f

Please sign in to comment.