diff --git a/client/client_test.go b/client/client_test.go index 39e6079..099d6a6 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -1890,11 +1890,18 @@ func (stp *testSetup) recv() pkts.Packet { rawPacket = rawPacket[:n] var h pkts.Header - h.Unpack(rawPacket) - p := pkts1.NewPacketWithHeader(h) - p.Unpack(rawPacket[h.HeaderLength():]) + if err := h.Unpack(rawPacket); err != nil { + stp.t.Fatal(err) + } + pkt, err := pkts1.NewPacketWithHeader(h) + if err != nil { + stp.t.Fatal(err) + } + if err := pkt.Unpack(rawPacket[h.HeaderLength():]); err != nil { + stp.t.Fatal(err) + } - return p + return pkt } func testRead(conn net.Conn, timeout time.Duration) ([]byte, error) { diff --git a/gateway/gateway_test.go b/gateway/gateway_test.go index c7940c3..bf7115c 100644 --- a/gateway/gateway_test.go +++ b/gateway/gateway_test.go @@ -1341,11 +1341,18 @@ func (stp *testSetup) snRecv() snPkts.Packet { rawPacket = rawPacket[:n] var h snPkts.Header - h.Unpack(rawPacket) - p := snPkts1.NewPacketWithHeader(h) - p.Unpack(rawPacket[h.HeaderLength():]) + if err := h.Unpack(rawPacket); err != nil { + stp.t.Fatal(err) + } + pkt, err := snPkts1.NewPacketWithHeader(h) + if err != nil { + stp.t.Fatal(err) + } + if err := pkt.Unpack(rawPacket[h.HeaderLength():]); err != nil { + stp.t.Fatal(err) + } - return p + return pkt } func (stp *testSetup) mqttSend(pkt mqPkts.ControlPacket, setMsgID bool) { diff --git a/packets1/packets1.go b/packets1/packets1.go index d640d39..2b8d90d 100644 --- a/packets1/packets1.go +++ b/packets1/packets1.go @@ -2,7 +2,6 @@ package packets1 import ( - "errors" "fmt" "io" @@ -74,20 +73,20 @@ func (c ReturnCode) String() string { func ReadPacket(r io.Reader) (pkt pkts.Packet, err error) { var h pkts.Header - packet := make([]byte, MaxPacketLen) - n, err := r.Read(packet) + rawPacket := make([]byte, MaxPacketLen) + n, err := r.Read(rawPacket) if err != nil { return nil, err } - packet = packet[:n] - if err := h.Unpack(packet); err != nil { + rawPacket = rawPacket[:n] + if err := h.Unpack(rawPacket); err != nil { return nil, err } - pkt = NewPacketWithHeader(h) - if pkt == nil { - return nil, errors.New("invalid MQTT-SN packet type") + pkt, err = NewPacketWithHeader(h) + if err != nil { + return nil, err } - if err := pkt.Unpack(packet[h.HeaderLength():]); err != nil { + if err := pkt.Unpack(rawPacket[h.HeaderLength():]); err != nil { return nil, err } @@ -96,7 +95,7 @@ func ReadPacket(r io.Reader) (pkt pkts.Packet, err error) { // NewPacketWithHeader returns a particular packet struct with a given header. // The struct type is determined by h.msgType. -func NewPacketWithHeader(h pkts.Header) (pkt pkts.Packet) { +func NewPacketWithHeader(h pkts.Header) (pkt pkts.Packet, err error) { switch h.PacketType() { case pkts.ADVERTISE: pkt = &Advertise{Header: h} @@ -154,6 +153,8 @@ func NewPacketWithHeader(h pkts.Header) (pkt pkts.Packet) { pkt = &WillMsgUpd{Header: h} case pkts.WILLMSGRESP: pkt = &WillMsgResp{Header: h} + default: + err = fmt.Errorf("invalid MQTT-SN 1.2 packet type: %d", h.PacketType()) } return } diff --git a/packets1/packets1_test.go b/packets1/packets1_test.go index af03a33..909bda8 100644 --- a/packets1/packets1_test.go +++ b/packets1/packets1_test.go @@ -42,6 +42,6 @@ func TestUnmarshalInvalidPacketType(t *testing.T) { }) _, err := ReadPacket(buff) if assert.Error(t, err) { - assert.Equal(t, err.Error(), "invalid MQTT-SN packet type") + assert.Equal(t, err.Error(), "invalid MQTT-SN 1.2 packet type: 25") } }