Skip to content

Commit

Permalink
added serial number set implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Snawoot committed Oct 26, 2024
1 parent c9ae62c commit f4b9175
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 0 deletions.
75 changes: 75 additions & 0 deletions auth/cert.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package auth

import (
"bufio"
"bytes"
"encoding/hex"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"sync/atomic"
)

type CertAuth struct{}
Expand Down Expand Up @@ -32,5 +37,75 @@ func formatSerial(serial *big.Int) string {
for i := 0; i < len(x); i += 2 {
buf = append(buf, x[i], x[i+1], ':')
}
if serial.Sign() == -1 {
return "(Negative)" + string(buf[:len(buf)-1])
}
return string(buf[:len(buf)-1])
}

type serialNumberKey = [20]byte
type serialNumberSet struct {
sns atomic.Pointer[map[serialNumberKey]struct{}]
}

func normalizeSNBytes(b []byte) serialNumberKey {
var k serialNumberKey
copy(
k[max(len(k)-len(b), 0):],
b[max(len(b)-len(k), 0):],
)
return k
}

func (s *serialNumberSet) Has(serial *big.Int) bool {
key := normalizeSNBytes(serial.Bytes())
sns := s.sns.Load()
if sns == nil || *sns == nil {
return false
}
_, found := (*sns)[key]
return found
}

func (s *serialNumberSet) LoadFromReader(r io.Reader) error {
set := make(map[serialNumberKey]struct{})
scanner := bufio.NewScanner(r)
for scanner.Scan() {
line, _, _ := bytes.Cut(scanner.Bytes(), []byte{'#'})
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
serial, err := parseSerialBytes(line)
if err != nil {
continue
}
set[normalizeSNBytes(serial)] = struct{}{}
}

if err := scanner.Err(); err != nil {
return fmt.Errorf("unable to load serial number set: %w", err)
}

s.sns.Store(&set)
return nil
}

func parseSerialBytes(serial []byte) ([]byte, error) {
res := make([]byte, (len(serial)+2)/3)

var i int
for ; i < len(res) && i*3+1 < len(serial); i++ {
if _, err := hex.Decode(res[i:i+1], serial[i*3:i*3+2]); err != nil {
return nil, fmt.Errorf("parseSerialBytes() failed: %w", err)
}
if i*3+2 < len(serial) && serial[i*3+2] != ':' {
return nil, errors.New("missing colon delimiter")
}
}
if i < len(res) {
return nil, errors.New("incomplete serial number string")
}

return res, nil
}
144 changes: 144 additions & 0 deletions auth/cert_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package auth

import (
"bytes"
"fmt"
"math/big"
"strings"
"testing"
)

func mkbytes(l uint) []byte {
b := make([]byte, l)
for i := uint(0); i < l; i++ {
b[i] = byte(i)
}
return b
}

var mask *big.Int = big.NewInt(0).Add(big.NewInt(0).Lsh(big.NewInt(1), uint(8*len(serialNumberKey{}))), big.NewInt(-1))

func TestNormalizeSNBytes(t *testing.T) {
for i := uint(0); i <= 32; i++ {
t.Run(fmt.Sprintf("%d-bytes", i), func(t *testing.T) {
s := mkbytes(i)
k := normalizeSNBytes(s)
var a, b big.Int
a.SetBytes(s).And(&a, mask)
b.SetBytes(k[:])
if a.Cmp(&b) != 0 {
t.Fatalf("%d != %d", &a, &b)
}
})
}
}

type parseSerialBytesTestcase struct {
input []byte
output []byte
error bool
}

func TestParseSerialBytes(t *testing.T) {
testcases := []parseSerialBytesTestcase{
{
input: []byte(""),
output: []byte{},
},
{
input: []byte("01:02:03"),
output: []byte{1, 2, 3},
},
{
input: []byte("ff"),
output: []byte{255},
},
{
input: []byte("ff:f"),
error: true,
},
{
input: []byte("f"),
error: true,
},
{
input: []byte("fff"),
error: true,
},
{
input: []byte("---"),
error: true,
},
}
for i, testcase := range testcases {
t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) {
out, err := parseSerialBytes(testcase.input)
if (err != nil) != testcase.error {
t.Fatalf("unexpected error: %v", err)
}
if bytes.Compare(out, testcase.output) != 0 {
t.Fatalf("expected %v, got %v", testcase.output, out)
}
})
}
}

type serialNumberSetTestcase struct {
input *big.Int
output bool
}

func TestSerialNumberSetSmoke(t *testing.T) {
var s serialNumberSet
const testFile = `
01:00:00:00:00 # test
# test 2
03
03
00
01
02`
testcases := []serialNumberSetTestcase{
{
input: big.NewInt(1<<32),
output: true,
},
{
input: big.NewInt(0),
output: true,
},
{
input: big.NewInt(1),
output: true,
},
{
input: big.NewInt(2),
output: true,
},
{
input: big.NewInt(3),
output: true,
},
{
input: big.NewInt(4),
output: false,
},
{
input: big.NewInt(-2),
output: true,
},
}
err := s.LoadFromReader(strings.NewReader(testFile))
if err != nil {
t.Fatalf("unable to load test set: %v", err)
}
for i, testcase := range testcases {
t.Run(fmt.Sprintf("Testcase[%d]", i), func(t *testing.T) {
out := s.Has(testcase.input)
if out != testcase.output {
t.Fatalf("expected %v, got %v", testcase.output, out)
}
})
}
}

0 comments on commit f4b9175

Please sign in to comment.