From f4b9175f3c464766254c5eb466c02da1091da740 Mon Sep 17 00:00:00 2001 From: Vladislav Yarmak Date: Sat, 26 Oct 2024 01:35:33 +0300 Subject: [PATCH] added serial number set implementation --- auth/cert.go | 75 ++++++++++++++++++++++++ auth/cert_test.go | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+) create mode 100644 auth/cert_test.go diff --git a/auth/cert.go b/auth/cert.go index 0a88282..6e3328f 100644 --- a/auth/cert.go +++ b/auth/cert.go @@ -1,10 +1,15 @@ package auth import ( + "bufio" + "bytes" "encoding/hex" + "errors" "fmt" + "io" "math/big" "net/http" + "sync/atomic" ) type CertAuth struct{} @@ -32,5 +37,75 @@ func formatSerial(serial *big.Int) string { for i := 0; i < len(x); i += 2 { buf = append(buf, x[i], x[i+1], ':') } + if serial.Sign() == -1 { + return "(Negative)" + string(buf[:len(buf)-1]) + } return string(buf[:len(buf)-1]) } + +type serialNumberKey = [20]byte +type serialNumberSet struct { + sns atomic.Pointer[map[serialNumberKey]struct{}] +} + +func normalizeSNBytes(b []byte) serialNumberKey { + var k serialNumberKey + copy( + k[max(len(k)-len(b), 0):], + b[max(len(b)-len(k), 0):], + ) + return k +} + +func (s *serialNumberSet) Has(serial *big.Int) bool { + key := normalizeSNBytes(serial.Bytes()) + sns := s.sns.Load() + if sns == nil || *sns == nil { + return false + } + _, found := (*sns)[key] + return found +} + +func (s *serialNumberSet) LoadFromReader(r io.Reader) error { + set := make(map[serialNumberKey]struct{}) + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line, _, _ := bytes.Cut(scanner.Bytes(), []byte{'#'}) + line = bytes.TrimSpace(line) + if len(line) == 0 { + continue + } + serial, err := parseSerialBytes(line) + if err != nil { + continue + } + set[normalizeSNBytes(serial)] = struct{}{} + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("unable to load serial number set: %w", err) + } + + s.sns.Store(&set) + return nil +} + +func parseSerialBytes(serial []byte) ([]byte, error) { + res := make([]byte, (len(serial)+2)/3) + + var i int + for ; i < len(res) && i*3+1 < len(serial); i++ { + if _, err := hex.Decode(res[i:i+1], serial[i*3:i*3+2]); err != nil { + return nil, fmt.Errorf("parseSerialBytes() failed: %w", err) + } + if i*3+2 < len(serial) && serial[i*3+2] != ':' { + return nil, errors.New("missing colon delimiter") + } + } + if i < len(res) { + return nil, errors.New("incomplete serial number string") + } + + return res, nil +} diff --git a/auth/cert_test.go b/auth/cert_test.go new file mode 100644 index 0000000..1091e1e --- /dev/null +++ b/auth/cert_test.go @@ -0,0 +1,144 @@ +package auth + +import ( + "bytes" + "fmt" + "math/big" + "strings" + "testing" +) + +func mkbytes(l uint) []byte { + b := make([]byte, l) + for i := uint(0); i < l; i++ { + b[i] = byte(i) + } + return b +} + +var mask *big.Int = big.NewInt(0).Add(big.NewInt(0).Lsh(big.NewInt(1), uint(8*len(serialNumberKey{}))), big.NewInt(-1)) + +func TestNormalizeSNBytes(t *testing.T) { + for i := uint(0); i <= 32; i++ { + t.Run(fmt.Sprintf("%d-bytes", i), func(t *testing.T) { + s := mkbytes(i) + k := normalizeSNBytes(s) + var a, b big.Int + a.SetBytes(s).And(&a, mask) + b.SetBytes(k[:]) + if a.Cmp(&b) != 0 { + t.Fatalf("%d != %d", &a, &b) + } + }) + } +} + +type parseSerialBytesTestcase struct { + input []byte + output []byte + error bool +} + +func TestParseSerialBytes(t *testing.T) { + testcases := []parseSerialBytesTestcase{ + { + input: []byte(""), + output: []byte{}, + }, + { + input: []byte("01:02:03"), + output: []byte{1, 2, 3}, + }, + { + input: []byte("ff"), + output: []byte{255}, + }, + { + input: []byte("ff:f"), + error: true, + }, + { + input: []byte("f"), + error: true, + }, + { + input: []byte("fff"), + error: true, + }, + { + input: []byte("---"), + error: true, + }, + } + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out, err := parseSerialBytes(testcase.input) + if (err != nil) != testcase.error { + t.Fatalf("unexpected error: %v", err) + } + if bytes.Compare(out, testcase.output) != 0 { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +} + +type serialNumberSetTestcase struct { + input *big.Int + output bool +} + +func TestSerialNumberSetSmoke(t *testing.T) { + var s serialNumberSet + const testFile = ` +01:00:00:00:00 # test +# test 2 +03 +03 + +00 + 01 +02` + testcases := []serialNumberSetTestcase{ + { + input: big.NewInt(1<<32), + output: true, + }, + { + input: big.NewInt(0), + output: true, + }, + { + input: big.NewInt(1), + output: true, + }, + { + input: big.NewInt(2), + output: true, + }, + { + input: big.NewInt(3), + output: true, + }, + { + input: big.NewInt(4), + output: false, + }, + { + input: big.NewInt(-2), + output: true, + }, + } + err := s.LoadFromReader(strings.NewReader(testFile)) + if err != nil { + t.Fatalf("unable to load test set: %v", err) + } + for i, testcase := range testcases { + t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) { + out := s.Has(testcase.input) + if out != testcase.output { + t.Fatalf("expected %v, got %v", testcase.output, out) + } + }) + } +}