diff --git a/src/core/version.go b/src/core/version.go index bb3b9538..873669d1 100644 --- a/src/core/version.go +++ b/src/core/version.go @@ -118,26 +118,41 @@ func (m *version_metadata) decode(r io.Reader, password []byte) error { for len(bs) >= 4 { op := binary.BigEndian.Uint16(bs[:2]) - oplen := binary.BigEndian.Uint16(bs[2:4]) - if bs = bs[4:]; len(bs) < int(oplen) { - break + oplen := int(binary.BigEndian.Uint16(bs[2:4])) + if bs = bs[4:]; len(bs) < oplen { + return ErrHandshakeInvalidLength } + field := bs[:oplen] switch op { case metaVersionMajor: - m.majorVer = binary.BigEndian.Uint16(bs[:2]) + if len(field) != 2 { + return ErrHandshakeInvalidLength + } + m.majorVer = binary.BigEndian.Uint16(field) case metaVersionMinor: - m.minorVer = binary.BigEndian.Uint16(bs[:2]) + if len(field) != 2 { + return ErrHandshakeInvalidLength + } + m.minorVer = binary.BigEndian.Uint16(field) case metaPublicKey: - m.publicKey = make(ed25519.PublicKey, ed25519.PublicKeySize) - copy(m.publicKey, bs[:ed25519.PublicKeySize]) + if len(field) != ed25519.PublicKeySize { + return ErrHandshakeInvalidLength + } + m.publicKey = append(m.publicKey[:0], field...) case metaPriority: - m.priority = bs[0] + if len(field) != 1 { + return ErrHandshakeInvalidLength + } + m.priority = field[0] } bs = bs[oplen:] } + if len(bs) != 0 { + return ErrHandshakeInvalidLength + } hasher, err := blake2b.New512(password) if err != nil { diff --git a/src/core/version_test.go b/src/core/version_test.go index 512c6e59..3ab2ceba 100644 --- a/src/core/version_test.go +++ b/src/core/version_test.go @@ -3,8 +3,11 @@ package core import ( "bytes" "crypto/ed25519" + "encoding/binary" "reflect" "testing" + + "golang.org/x/crypto/blake2b" ) func TestVersionPasswordAuth(t *testing.T) { @@ -76,3 +79,78 @@ func TestVersionRoundtrip(t *testing.T) { } } } + +func TestVersionDecodeRejectsMalformedFieldLengths(t *testing.T) { + password := []byte("pw") + for _, tt := range []struct { + name string + op uint16 + field []byte + }{ + {name: "major short", op: metaVersionMajor, field: []byte{1}}, + {name: "minor short", op: metaVersionMinor, field: []byte{1}}, + {name: "public key short", op: metaPublicKey, field: []byte{1}}, + {name: "priority empty", op: metaPriority, field: nil}, + } { + t.Run(tt.name, func(t *testing.T) { + msg := malformedVersionHandshake(t, tt.op, tt.field, password) + var decoded version_metadata + if err := decoded.decode(bytes.NewReader(msg), password); err != ErrHandshakeInvalidLength { + t.Fatalf("expected %q, got %v", ErrHandshakeInvalidLength, err) + } + }) + } +} + +func TestVersionDecodeRejectsTrailingBytes(t *testing.T) { + password := []byte("pw") + pk, sk, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + + hasher, err := blake2b.New512(password) + if err != nil { + t.Fatal(err) + } + if _, err = hasher.Write(pk); err != nil { + t.Fatal(err) + } + sig := ed25519.Sign(sk, hasher.Sum(nil)) + + body := append([]byte{1, 2, 3}, sig...) + msg := append([]byte{'m', 'e', 't', 'a', 0, 0}, body...) + binary.BigEndian.PutUint16(msg[4:6], uint16(len(body))) + var decoded version_metadata + if err := decoded.decode(bytes.NewReader(msg), password); err != ErrHandshakeInvalidLength { + t.Fatalf("expected %q, got %v", ErrHandshakeInvalidLength, err) + } +} + +func malformedVersionHandshake(t *testing.T, op uint16, field []byte, password []byte) []byte { + t.Helper() + + pk, sk, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + + hasher, err := blake2b.New512(password) + if err != nil { + t.Fatal(err) + } + if _, err = hasher.Write(pk); err != nil { + t.Fatal(err) + } + sig := ed25519.Sign(sk, hasher.Sum(nil)) + + body := make([]byte, 0, 4+len(field)+len(sig)) + body = binary.BigEndian.AppendUint16(body, op) + body = binary.BigEndian.AppendUint16(body, uint16(len(field))) + body = append(body, field...) + body = append(body, sig...) + + msg := append([]byte{'m', 'e', 't', 'a', 0, 0}, body...) + binary.BigEndian.PutUint16(msg[4:6], uint16(len(body))) + return msg +} diff --git a/src/multicast/advertisement.go b/src/multicast/advertisement.go index d0db8b5a..3b5b384f 100644 --- a/src/multicast/advertisement.go +++ b/src/multicast/advertisement.go @@ -26,14 +26,18 @@ func (m *multicastAdvertisement) MarshalBinary() ([]byte, error) { } func (m *multicastAdvertisement) UnmarshalBinary(b []byte) error { - if len(b) < ed25519.PublicKeySize+8 { + const headerLen = ed25519.PublicKeySize + 8 + if len(b) < headerLen { return fmt.Errorf("invalid multicast beacon") } m.MajorVersion = binary.BigEndian.Uint16(b[0:2]) m.MinorVersion = binary.BigEndian.Uint16(b[2:4]) m.PublicKey = append(m.PublicKey[:0], b[4:4+ed25519.PublicKeySize]...) m.Port = binary.BigEndian.Uint16(b[4+ed25519.PublicKeySize : 6+ed25519.PublicKeySize]) - dl := binary.BigEndian.Uint16(b[6+ed25519.PublicKeySize : 8+ed25519.PublicKeySize]) - m.Hash = append(m.Hash[:0], b[8+ed25519.PublicKeySize:8+ed25519.PublicKeySize+dl]...) + dl := int(binary.BigEndian.Uint16(b[6+ed25519.PublicKeySize : 8+ed25519.PublicKeySize])) + if len(b) < headerLen+dl { + return fmt.Errorf("invalid multicast beacon") + } + m.Hash = append(m.Hash[:0], b[headerLen:headerLen+dl]...) return nil } diff --git a/src/multicast/advertisement_test.go b/src/multicast/advertisement_test.go index 9541da60..89a569f4 100644 --- a/src/multicast/advertisement_test.go +++ b/src/multicast/advertisement_test.go @@ -2,6 +2,7 @@ package multicast import ( "crypto/ed25519" + "encoding/binary" "reflect" "testing" ) @@ -36,3 +37,20 @@ func TestMulticastAdvertisementRoundTrip(t *testing.T) { t.Fatalf("differences found after round-trip") } } + +func TestMulticastAdvertisementRejectsTruncatedHash(t *testing.T) { + pk, _, err := ed25519.GenerateKey(nil) + if err != nil { + t.Fatal(err) + } + + b := make([]byte, ed25519.PublicKeySize+8) + copy(b[4:], pk) + binary.BigEndian.PutUint16(b[4+ed25519.PublicKeySize:6+ed25519.PublicKeySize], 9001) + binary.BigEndian.PutUint16(b[6+ed25519.PublicKeySize:8+ed25519.PublicKeySize], 32) + + var adv multicastAdvertisement + if err := adv.UnmarshalBinary(b); err == nil { + t.Fatal("expected truncated beacon to be rejected") + } +}