diff --git a/core/auxiliary.go b/core/auxiliary.go index 7d3321e..d2a4737 100644 --- a/core/auxiliary.go +++ b/core/auxiliary.go @@ -2,7 +2,6 @@ package core import ( "os" - "reflect" "unsafe" ) @@ -17,7 +16,10 @@ func roundToPageSize(length int) int { } // Convert a pointer and length to a byte slice that describes that memory. -func getBytes(ptr *byte, len int) []byte { - var sl = reflect.SliceHeader{Data: uintptr(unsafe.Pointer(ptr)), Len: len, Cap: len} - return *(*[]byte)(unsafe.Pointer(&sl)) +func getBufferPart(buf []byte, offset, length int) []byte { + start := offset + if offset < 0 { + start = len(buf) + offset + } + return unsafe.Slice(&buf[start], length) } diff --git a/core/auxiliary_test.go b/core/auxiliary_test.go index 67b5432..392d2fa 100644 --- a/core/auxiliary_test.go +++ b/core/auxiliary_test.go @@ -29,7 +29,7 @@ func TestGetBytes(t *testing.T) { buffer := make([]byte, 32) // Get am alternate reference to it using our slice builder. - derived := getBytes(&buffer[0], len(buffer)) + derived := getBufferPart(buffer, 0, len(buffer)) // Check for naive equality. if !bytes.Equal(buffer, derived) { diff --git a/core/buffer.go b/core/buffer.go index 4e15ce8..630c062 100644 --- a/core/buffer.go +++ b/core/buffer.go @@ -3,12 +3,11 @@ package core import ( "errors" "sync" - - "github.com/awnumar/memcall" ) var ( - buffers = new(bufferList) + allocator = NewPageAllocator() + buffers = new(bufferList) ) // ErrNullBuffer is returned when attempting to construct a buffer of size less than one. @@ -28,22 +27,13 @@ type Buffer struct { alive bool // Signals that destruction has not come mutable bool // Mutability state of underlying memory - data []byte // Portion of memory holding the data - memory []byte // Entire allocated memory region - - preguard []byte // Guard page addressed before the data - inner []byte // Inner region between the guard pages - postguard []byte // Guard page addressed after the data - - canary []byte // Value written behind data to detect spillage + data []byte // Portion of memory holding the data } /* NewBuffer is a raw constructor for the Buffer object. */ func NewBuffer(size int) (*Buffer, error) { - var err error - if size < 1 { return nil, ErrNullBuffer } @@ -51,43 +41,12 @@ func NewBuffer(size int) (*Buffer, error) { b := new(Buffer) // Allocate the total needed memory - innerLen := roundToPageSize(size) - b.memory, err = memcall.Alloc((2 * pageSize) + innerLen) + var err error + b.data, err = allocator.Alloc(size) if err != nil { Panic(err) } - // Construct slice reference for data buffer. - b.data = getBytes(&b.memory[pageSize+innerLen-size], size) - - // Construct slice references for page sectors. - b.preguard = getBytes(&b.memory[0], pageSize) - b.inner = getBytes(&b.memory[pageSize], innerLen) - b.postguard = getBytes(&b.memory[pageSize+innerLen], pageSize) - - // Construct slice reference for canary portion of inner page. - b.canary = getBytes(&b.memory[pageSize], len(b.inner)-len(b.data)) - - // Lock the pages that will hold sensitive data. - if err := memcall.Lock(b.inner); err != nil { - Panic(err) - } - - // Initialise the canary value and reference regions. - if err := Scramble(b.canary); err != nil { - Panic(err) - } - Copy(b.preguard, b.canary) - Copy(b.postguard, b.canary) - - // Make the guard pages inaccessible. 
- if err := memcall.Protect(b.preguard, memcall.NoAccess()); err != nil { - Panic(err) - } - if err := memcall.Protect(b.postguard, memcall.NoAccess()); err != nil { - Panic(err) - } - // Set remaining properties b.alive = true b.mutable = true @@ -106,7 +65,7 @@ func (b *Buffer) Data() []byte { // Inner returns a byte slice representing the entire inner memory pages. This should NOT be used unless you have a specific need. func (b *Buffer) Inner() []byte { - return b.inner + return allocator.Inner(b.data) } // Freeze makes the underlying memory of a given buffer immutable. This will do nothing if the Buffer has been destroyed. @@ -125,7 +84,7 @@ func (b *Buffer) freeze() error { } if b.mutable { - if err := memcall.Protect(b.inner, memcall.ReadOnly()); err != nil { + if err := allocator.Protect(b.data, true); err != nil { return err } b.mutable = false @@ -150,7 +109,7 @@ func (b *Buffer) melt() error { } if !b.mutable { - if err := memcall.Protect(b.inner, memcall.ReadWrite()); err != nil { + if err := allocator.Protect(b.data, false); err != nil { return err } b.mutable = true @@ -198,42 +157,17 @@ func (b *Buffer) destroy() error { return nil } - // Make all of the memory readable and writable. - if err := memcall.Protect(b.memory, memcall.ReadWrite()); err != nil { - return err - } - b.mutable = true - - // Wipe data field. - Wipe(b.data) - - // Verify the canary - if !Equal(b.preguard, b.postguard) || !Equal(b.preguard[:len(b.canary)], b.canary) { - return errors.New(" canary verification failed; buffer overflow detected") - } - - // Wipe the memory. - Wipe(b.memory) - - // Unlock pages locked into memory. - if err := memcall.Unlock(b.inner); err != nil { - return err - } - - // Free all related memory. - if err := memcall.Free(b.memory); err != nil { - return err + // Destroy the memory content and free the space + if b.data != nil { + if err := allocator.Free(b.data); err != nil { + return err + } } // Reset the fields. b.alive = false b.mutable = false b.data = nil - b.memory = nil - b.preguard = nil - b.inner = nil - b.postguard = nil - b.canary = nil return nil } diff --git a/core/buffer_test.go b/core/buffer_test.go index 175d7f1..6010515 100644 --- a/core/buffer_test.go +++ b/core/buffer_test.go @@ -4,52 +4,32 @@ import ( "bytes" "testing" "unsafe" + + "github.com/stretchr/testify/require" ) func TestNewBuffer(t *testing.T) { // Check the error case with zero length. - b, err := NewBuffer(0) - if err != ErrNullBuffer { - t.Error("expected ErrNullBuffer; got", err) - } - if b != nil { - t.Error("expected nil buffer; got", b) - } + a, err := NewBuffer(0) + require.ErrorIs(t, err, ErrNullBuffer) + require.Nil(t, a) // Check the error case with negative length. - b, err = NewBuffer(-1) - if err != ErrNullBuffer { - t.Error("expected ErrNullBuffer; got", err) - } - if b != nil { - t.Error("expected nil buffer; got", b) - } + b, err := NewBuffer(-1) + require.ErrorIs(t, err, ErrNullBuffer) + require.Nil(t, b) // Test normal execution. 
b, err = NewBuffer(32) - if err != nil { - t.Error("expected nil err; got", err) - } - if !b.alive { - t.Error("did not expect destroyed buffer") - } - if len(b.Data()) != 32 || cap(b.Data()) != 32 { - t.Errorf("buffer has invalid length (%d) or capacity (%d)", len(b.Data()), cap(b.Data())) - } - if !b.mutable { - t.Error("buffer is not marked mutable") - } - if len(b.memory) != roundToPageSize(32)+(2*pageSize) { - t.Error("allocated incorrect length of memory") - } - if !bytes.Equal(b.Data(), make([]byte, 32)) { - t.Error("container is not zero-filled") - } + require.NoError(t, err) + require.True(t, b.alive, "did not expect destroyed buffer") + require.Lenf(t, b.Data(), 32, "buffer has invalid length (%d)", len(b.Data())) + require.Equalf(t, cap(b.Data()), 32, "buffer has invalid capacity (%d)", cap(b.Data())) + require.True(t, b.mutable, "buffer is not marked mutable") + require.EqualValues(t, make([]byte, 32), b.Data(), "container is not zero-filled") // Check if the buffer was added to the buffers list. - if !buffers.exists(b) { - t.Error("buffer not in buffers list") - } + require.True(t, buffers.exists(b), "buffer not in buffers list") // Destroy the buffer. b.Destroy() @@ -58,35 +38,17 @@ func TestNewBuffer(t *testing.T) { func TestLotsOfAllocs(t *testing.T) { for i := 1; i <= 16385; i++ { b, err := NewBuffer(i) - if err != nil { - t.Error(err) - } - if !b.alive || !b.mutable { - t.Error("invalid metadata") - } - if len(b.data) != i { - t.Error("invalid data length") - } - if len(b.memory) != roundToPageSize(i)+2*pageSize { - t.Error("memory length invalid") - } - if len(b.preguard) != pageSize || len(b.postguard) != pageSize { - t.Error("guard pages length invalid") - } - if len(b.canary) != len(b.inner)-i { - t.Error("canary length invalid") - } - if len(b.inner)%pageSize != 0 { - t.Error("inner length is not multiple of page size") - } + require.NoErrorf(t, err, "creating buffer in iteration %d", i) + require.Truef(t, b.alive, "not alive in iteration %d", i) + require.Truef(t, b.mutable, "not mutable in iteration %d", i) + require.Lenf(t, b.data, i, "invalid data length %d in iteration %d", len(b.data), i) + require.Zerof(t, len(b.Inner())%pageSize, "inner length %d is not multiple of page size in iteration %d", len(b.Inner()), i) + + // Fill data for j := range b.data { b.data[j] = 1 } - for j := range b.data { - if b.data[j] != 1 { - t.Error("region rw test failed") - } - } + require.Equalf(t, bytes.Repeat([]byte{1}, i), b.data, "region rw test failed in iteration %d", i) b.Destroy() } } @@ -184,21 +146,12 @@ func TestDestroy(t *testing.T) { if b.Data() != nil { t.Error("expected bytes buffer to be nil; got", b.Data()) } - if b.memory != nil { - t.Error("expected memory to be nil; got", b.memory) - } if b.mutable || b.alive { t.Error("buffer should be dead and immutable") } - if b.preguard != nil || b.postguard != nil { - t.Error("guard page slice references are not nil") - } - if b.inner != nil { + if b.Inner() != nil { t.Error("inner pages slice reference not nil") } - if b.canary != nil { - t.Error("canary slice reference not nil") - } // Check if the buffer was removed from the buffers list. 
if buffers.exists(b) { @@ -212,21 +165,12 @@ func TestDestroy(t *testing.T) { if b.Data() != nil { t.Error("expected bytes buffer to be nil; got", b.Data()) } - if b.memory != nil { - t.Error("expected memory to be nil; got", b.memory) - } if b.mutable || b.alive { t.Error("buffer should be dead and immutable") } - if b.preguard != nil || b.postguard != nil { - t.Error("guard page slice references are not nil") - } - if b.inner != nil { + if b.Inner() != nil { t.Error("inner pages slice reference not nil") } - if b.canary != nil { - t.Error("canary slice reference not nil") - } } func TestBufferList(t *testing.T) { diff --git a/core/exit.go b/core/exit.go index eb29ef1..c99badf 100644 --- a/core/exit.go +++ b/core/exit.go @@ -3,8 +3,6 @@ package core import ( "fmt" "os" - - "github.com/awnumar/memcall" ) /* @@ -37,18 +35,6 @@ func Purge() { } else { opErr = fmt.Errorf("%s; %s", opErr.Error(), err.Error()) } - // buffer destroy failed; wipe instead - b.Lock() - defer b.Unlock() - if !b.mutable { - if err := memcall.Protect(b.inner, memcall.ReadWrite()); err != nil { - // couldn't change it to mutable; we can't wipe it! (could this happen?) - // not sure what we can do at this point, just warn and move on - fmt.Fprintf(os.Stderr, "!WARNING: failed to wipe immutable data at address %p", &b.data) - continue // wipe in subprocess? - } - } - Wipe(b.data) } } }() diff --git a/core/exit_test.go b/core/exit_test.go index 9082872..742d98c 100644 --- a/core/exit_test.go +++ b/core/exit_test.go @@ -60,7 +60,7 @@ func TestPurge(t *testing.T) { if err != nil { t.Error(err) } - Scramble(b.inner) + Scramble(allocator.Inner(b.data)) b.Freeze() if !panics(func() { Purge() diff --git a/core/memallocator.go b/core/memallocator.go new file mode 100644 index 0000000..4fb08d3 --- /dev/null +++ b/core/memallocator.go @@ -0,0 +1,24 @@ +package core + +import ( + "errors" +) + +// Define a memory allocator +type MemAllocator interface { + Alloc(size int) ([]byte, error) + Inner(buf []byte) []byte + Protect(buf []byte, readonly bool) error + Free(buf []byte) error +} + +var ( + // ErrBufferNotOwnedByAllocator indicating that the memory region is not owned by this allocator + ErrBufferNotOwnedByAllocator = errors.New(" buffer not owned by allocator; potential double free") + // ErrBufferOverflow indicating that the memory region was tampered with + ErrBufferOverflow = errors.New(" canary verification failed; buffer overflow detected") + // ErrNullAlloc indicating that a zero length memory region was requested + ErrNullAlloc = errors.New(" zero-length allocation") + // ErrNullPointer indicating an attempted operation on a nil buffer + ErrNullPointer = errors.New(" nil buffer") +) diff --git a/core/memallocator_page.go b/core/memallocator_page.go new file mode 100644 index 0000000..c3e2a87 --- /dev/null +++ b/core/memallocator_page.go @@ -0,0 +1,229 @@ +package core + +import ( + "fmt" + "os" + "sync" + "unsafe" + + "github.com/awnumar/memcall" +) + +type pageAllocator struct { + objects map[int]*pageObject + sync.Mutex +} + +func NewPageAllocator() MemAllocator { + a := &pageAllocator{ + objects: make(map[int]*pageObject), + } + return a +} + +func (a *pageAllocator) Alloc(size int) ([]byte, error) { + if size < 1 { + return nil, ErrNullAlloc + } + o, err := a.newPageObject(size) + if err != nil { + return nil, err + } + + // Store the allocated object with the lookup key of the inner + // buffers address. 
This allows to efficiently free the buffer + // later + addr := int(uintptr(unsafe.Pointer(&o.data[0]))) + a.Lock() + a.objects[addr] = o + a.Unlock() + + return o.data, nil +} + +func (a *pageAllocator) Protect(buf []byte, readonly bool) error { + if len(buf) == 0 { + return ErrNullPointer + } + + // Determine the object belonging to the buffer + o, found := a.lookup(buf) + if !found { + Panic(ErrBufferNotOwnedByAllocator) + } + + var flag memcall.MemoryProtectionFlag + if readonly { + flag = memcall.ReadOnly() + } else { + flag = memcall.ReadWrite() + } + + return memcall.Protect(o.inner, flag) +} + +func (a *pageAllocator) Inner(buf []byte) []byte { + if len(buf) == 0 { + return nil + } + + // Determine the object belonging to the buffer + o, found := a.lookup(buf) + if !found { + Panic(ErrBufferNotOwnedByAllocator) + } + + return o.inner +} + +func (a *pageAllocator) Free(buf []byte) error { + // Determine the address of the buffer we should free + o, found := a.pop(buf) + if !found { + return ErrBufferNotOwnedByAllocator + } + + // Destroy the object's content + if err := o.wipe(); err != nil { + return err + } + + // Free the related memory + if err := memcall.Free(o.memory); err != nil { + return err + } + + return nil +} + +// *** INTERNAL FUNCTIONS *** // +func (a *pageAllocator) lookup(buf []byte) (*pageObject, bool) { + if len(buf) == 0 { + return nil, false + } + + // Determine the address of the buffer we should free + addr := int(uintptr(unsafe.Pointer(&buf[0]))) + + a.Lock() + defer a.Unlock() + o, found := a.objects[addr] + return o, found +} + +func (a *pageAllocator) pop(buf []byte) (*pageObject, bool) { + if len(buf) == 0 { + return nil, false + } + + addr := int(uintptr(unsafe.Pointer(&buf[0]))) + + a.Lock() + defer a.Unlock() + o, found := a.objects[addr] + if !found { + return nil, false + } + delete(a.objects, addr) + + return o, true +} + +// object holding each allocation +type pageObject struct { + data []byte // Portion of memory holding the data + memory []byte // Entire allocated memory region + + preguard []byte // Guard page addressed before the data + inner []byte // Inner region between the guard pages + postguard []byte // Guard page addressed after the data + + canary []byte // Value written behind data to detect spillage +} + +func (a *pageAllocator) newPageObject(size int) (*pageObject, error) { + // Round a length to a multiple of the system page size for page locking + // and protection + innerLen := roundToPageSize(size) + + // Allocate the total needed memory + memory, err := memcall.Alloc(2*pageSize + innerLen) + if err != nil { + return nil, err + } + + o := &pageObject{ + memory: memory, + // Construct slice reference for data buffer. + data: getBufferPart(memory, pageSize+innerLen-size, size), + // Construct slice references for page sectors. + preguard: getBufferPart(memory, 0, pageSize), + inner: getBufferPart(memory, pageSize, innerLen), + postguard: getBufferPart(memory, pageSize+innerLen, pageSize), + } + // Construct slice reference for canary portion of inner page. + o.canary = getBufferPart(memory, pageSize, len(o.inner)-len(o.data)) + + // Lock the pages that will hold sensitive data. + if err := memcall.Lock(o.inner); err != nil { + return nil, err + } + + // Create a random signature for the protection pages and reuse the + // fitting part for the canary + if err := Scramble(o.preguard); err != nil { + return nil, err + } + Copy(o.postguard, o.preguard) + Copy(o.canary, o.preguard) + + // Make the guard pages inaccessible. 
+ if err := memcall.Protect(o.preguard, memcall.NoAccess()); err != nil { + return nil, err + } + if err := memcall.Protect(o.postguard, memcall.NoAccess()); err != nil { + return nil, err + } + + return o, nil +} + +func (o *pageObject) wipe() error { + // Make all of the memory readable and writable. + var partialUnprotect bool + if err := memcall.Protect(o.memory, memcall.ReadWrite()); err != nil { + partialUnprotect = true + if partialErr := memcall.Protect(o.inner, memcall.ReadWrite()); partialErr != nil { + fmt.Fprintf(os.Stderr, "!WARNING: failed to wipe immutable data at address %p: %v", &o.data, partialErr) + return err + } + } + + // Wipe data field. + Wipe(o.data) + o.data = nil + + // Verify the guards and canary + if !Equal(o.preguard, o.postguard) || !Equal(o.preguard[:len(o.canary)], o.canary) { + return ErrBufferOverflow + } + + // Wipe the whole memory region if we were able to switch it to mutable. + if !partialUnprotect { + Wipe(o.memory) + } + + // Unlock pages locked into memory. + if err := memcall.Unlock(o.inner); err != nil { + return err + } + + // Reset the fields. + o.data = nil + o.inner = nil + o.preguard = nil + o.postguard = nil + o.canary = nil + + return nil +} diff --git a/core/memallocator_page_test.go b/core/memallocator_page_test.go new file mode 100644 index 0000000..eaa4172 --- /dev/null +++ b/core/memallocator_page_test.go @@ -0,0 +1,108 @@ +package core + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPageAllocAllocInvalidSize(t *testing.T) { + alloc := NewPageAllocator() + + a, err := alloc.Alloc(0) + require.Nil(t, a) + require.ErrorIs(t, err, ErrNullAlloc) + + b, err := alloc.Alloc(-1) + require.Nil(t, b) + require.ErrorIs(t, err, ErrNullAlloc) +} + +func TestPageAllocAlloc(t *testing.T) { + alloc := NewPageAllocator() + + b, err := alloc.Alloc(32) + require.NoError(t, err) + require.Lenf(t, b, 32, "invalid buffer len %d", len(b)) + + o, found := alloc.(*pageAllocator).lookup(b) + require.True(t, found) + require.Lenf(t, o.data, 32, "invalid data len %d", len(o.data)) + require.Equalf(t, cap(o.data), 32, "invalid data capacity %d", cap(o.data)) + require.Len(t, o.memory, 3*pageSize) + require.EqualValues(t, make([]byte, 32), o.data, "container is not zero-filled") + + // Destroy the buffer. + require.NoError(t, alloc.Free(b)) +} + +func TestPageAllocLotsOfAllocs(t *testing.T) { + // Create a local allocator instance + alloc := NewPageAllocator() + palloc := alloc.(*pageAllocator) + + for i := 1; i <= 16385; i++ { + b, err := alloc.Alloc(i) + require.NoErrorf(t, err, "size: %d", i) + + o, found := palloc.lookup(b) + require.True(t, found) + + require.Lenf(t, o.data, i, "size: %d", i) + require.Lenf(t, o.memory, roundToPageSize(i)+2*pageSize, "memory length invalid size: %d", i) + require.Lenf(t, o.preguard, pageSize, "pre-guard length invalid size: %d", i) + require.Lenf(t, o.postguard, pageSize, "pre-guard length invalid size: %d", i) + require.Lenf(t, o.canary, len(o.inner)-i, "canary length invalid size: %d", i) + require.Zerof(t, len(o.inner)%pageSize, "inner length is not multiple of page size size: %d", i) + + // Fill the data + for j := range o.data { + o.data[j] = 1 + } + require.EqualValuesf(t, bytes.Repeat([]byte{1}, i), o.data, "region rw test failed", "size: %d", i) + require.NoErrorf(t, alloc.Free(b), "size: %d", i) + } +} + +func TestPageAllocDestroy(t *testing.T) { + alloc := NewPageAllocator() + + // Allocate a new buffer. 
+ b, err := alloc.Alloc(32) + require.NoError(t, err) + + o, found := alloc.(*pageAllocator).lookup(b) + require.True(t, found) + + // Destroy it and check it is gone... + require.NoError(t, o.wipe()) + + // Pick apart the destruction. + require.Nil(t, o.data, "data not nil") + require.Nil(t, o.inner, "inner not nil") + require.Nil(t, o.preguard, "preguard not nil") + require.Nil(t, o.postguard, "postguard not nil") + require.Nil(t, o.canary, "canary not nil") + require.EqualValues(t, make([]byte, len(o.memory)), o.memory, "memory not zero'ed") + + // Call destroy again to check idempotency. + require.NoError(t, alloc.Free(b)) +} + +func TestPageAllocOverflow(t *testing.T) { + alloc := NewPageAllocator() + + // Allocate a new buffer. + b, err := alloc.Alloc(32) + require.NoError(t, err) + + o, found := alloc.(*pageAllocator).lookup(b) + require.True(t, found) + + // Modify the canary as if we overflow + o.canary[0] = ^o.canary[0] + + // Destroy it and check it is gone... + require.ErrorIs(t, alloc.Free(b), ErrBufferOverflow) +} diff --git a/core/memallocator_slab.go b/core/memallocator_slab.go new file mode 100644 index 0000000..d54e08c --- /dev/null +++ b/core/memallocator_slab.go @@ -0,0 +1,342 @@ +package core + +import ( + "errors" + "sort" + "sync" + "unsafe" +) + +var ( + ErrSlabNoCacheFound = errors.New("no slab cache matching request") + ErrSlabTooLarge = errors.New("requested size too large") +) + +type SlabAllocatorConfig struct { + MinCanarySize int + Sizes []int +} + +// Configuration options +type SlabOption func(*SlabAllocatorConfig) + +// WithSizes allows to overwrite the SLAB Page sizes, defaulting to +// 64, 128, 256, 512, 1024 and 2048 byte +func WithSizes(sizes []int) SlabOption { + return func(cfg *SlabAllocatorConfig) { + cfg.Sizes = sizes + } +} + +// WithMinCanarySize allows to specify the minimum canary size (default: 16 byte) +func WithMinCanarySize(size int) SlabOption { + return func(cfg *SlabAllocatorConfig) { + cfg.MinCanarySize = size + } +} + +// Memory allocator implementation +type slabAllocator struct { + maxSlabSize int + cfg *SlabAllocatorConfig + allocator *pageAllocator + slabs []*slab +} + +func NewSlabAllocator(options ...SlabOption) MemAllocator { + cfg := &SlabAllocatorConfig{ + MinCanarySize: 16, + Sizes: []int{64, 128, 256, 512, 1024, 2048}, + } + for _, o := range options { + o(cfg) + } + sort.Ints(cfg.Sizes) + + if len(cfg.Sizes) == 0 { + return nil + } + + // Setup the allocator and initialize the slabs + a := &slabAllocator{ + maxSlabSize: cfg.Sizes[len(cfg.Sizes)-1], + cfg: cfg, + slabs: make([]*slab, 0, len(cfg.Sizes)), + allocator: &pageAllocator{ + objects: make(map[int]*pageObject), + }, + } + for _, size := range cfg.Sizes { + s := &slab{ + objSize: size, + allocator: a.allocator, + } + a.slabs = append(a.slabs, s) + } + + return a +} + +func (a *slabAllocator) Alloc(size int) ([]byte, error) { + if size < 1 { + return nil, ErrNullAlloc + } + + // If the requested size is bigger than the largest slab, just malloc + // the memory. 
+ requiredSlabSize := size + a.cfg.MinCanarySize + if requiredSlabSize > a.maxSlabSize { + return a.allocator.Alloc(size) + } + + // Determine which slab to use depending on the size + var s *slab + for _, current := range a.slabs { + if requiredSlabSize <= current.objSize { + s = current + break + } + } + if s == nil { + return nil, ErrSlabNoCacheFound + } + buf, err := s.alloc(size) + if err != nil { + return nil, err + } + + // Trunc the buffer to the required size if requested + return buf, nil +} + +func (a *slabAllocator) Protect(buf []byte, readonly bool) error { + // For the slab allocator, the data-slice is not identical to a memory page. + // However, protection rules can only be applied to whole memory pages, + // therefore protection of the data-slice is not supported by the slab + // allocator. + return nil +} + +func (a *slabAllocator) Inner(buf []byte) []byte { + if len(buf) == 0 { + return nil + } + + // If the buffer size is bigger than the largest slab, just free + // the memory. + size := len(buf) + a.cfg.MinCanarySize + if size > a.maxSlabSize { + return a.allocator.Inner(buf) + } + + // Determine which slab to use depending on the size + var s *slab + for _, current := range a.slabs { + if size <= current.objSize { + s = current + break + } + } + if s == nil { + Panic(ErrSlabNoCacheFound) + } + + for _, c := range s.pages { + if offset, contained := contains(c.buffer, buf); contained { + return c.buffer[offset : offset+s.objSize] + } + } + return nil +} + +func (a *slabAllocator) Free(buf []byte) error { + size := len(buf) + a.cfg.MinCanarySize + + // If the buffer size is bigger than the largest slab, just free + // the memory. + if size > a.maxSlabSize { + return a.allocator.Free(buf) + } + + // Determine which slab to use depending on the size + var s *slab + for _, current := range a.slabs { + if size <= current.objSize { + s = current + break + } + } + if s == nil { + return ErrSlabNoCacheFound + } + + return s.free(buf) +} + +// *** INTERNAL FUNCTIONS *** // + +// Page implementation +type slabObject struct { + offset int + next *slabObject +} + +type slabPage struct { + used int + head *slabObject + canary []byte + buffer []byte +} + +func newPage(page []byte, size int) *slabPage { + if size > len(page) || size < 1 { + Panic(ErrSlabTooLarge) + } + + // Determine the number of objects fitting into the page + count := len(page) / size + + // Init the Page meta-data + c := &slabPage{ + head: &slabObject{}, + canary: page[len(page)-size:], + buffer: page, + } + + // Use the last object to create a canary prototype + if err := Scramble(c.canary); err != nil { + Panic(err) + } + + // Initialize the objects + last := c.head + offset := size + for i := 1; i < count-1; i++ { + obj := &slabObject{offset: offset} + last.next = obj + offset += size + last = obj + } + + return c +} + +// Slab is a container for all Pages serving the same size +type slab struct { + objSize int + allocator *pageAllocator + pages []*slabPage + sync.Mutex +} + +func (s *slab) alloc(size int) ([]byte, error) { + s.Lock() + defer s.Unlock() + + // Find the fullest Page that isn't completely filled + var c *slabPage + for _, current := range s.pages { + if current.head != nil && (c == nil || current.used > c.used) { + c = current + } + } + + // No Page available, create a new one + if c == nil { + // Use the page allocator to get a new guarded memory page + page, err := s.allocator.Alloc(pageSize - s.objSize) + if err != nil { + return nil, err + } + c = newPage(page, s.objSize) + s.pages = 
append(s.pages, c) + } + + // Remove the object from the free-list and increase the usage count + obj := c.head + c.head = c.head.next + c.used++ + + data := getBufferPart(c.buffer, obj.offset, size) + canary := getBufferPart(c.buffer, obj.offset+size, s.objSize-size) + + // Fill in the remaining bytes with canary + Copy(canary, c.canary) + + return data, nil +} + +func contains(buf, obj []byte) (int, bool) { + bb := uintptr(unsafe.Pointer(&buf[0])) + be := uintptr(unsafe.Pointer(&buf[len(buf)-1])) + o := uintptr(unsafe.Pointer(&obj[0])) + + if bb <= be { + return int(o - bb), bb <= o && o < be + } + return int(o - be), be <= o && o < bb +} + +func (s *slab) free(buf []byte) error { + s.Lock() + defer s.Unlock() + + // Find the Page containing the object + var c *slabPage + var cidx, offset int + for i, current := range s.pages { + diff, contained := contains(current.buffer, buf) + if contained { + c = current + cidx = i + offset = diff + break + } + } + if c == nil { + return ErrBufferNotOwnedByAllocator + } + + // Wipe the buffer including the canary check + if err := s.wipe(c, offset, len(buf)); err != nil { + return err + } + obj := &slabObject{ + offset: offset, + next: c.head, + } + c.head = obj + c.used-- + + // In case the Page is completely empty, we should remove it and + // free the underlying memory + if c.used == 0 { + err := s.allocator.Free(c.buffer) + if err != nil { + return err + } + + s.pages = append(s.pages[:cidx], s.pages[cidx+1:]...) + } + + return nil +} + +func (s *slab) wipe(page *slabPage, offset, size int) error { + canary := getBufferPart(page.buffer, -s.objSize, s.objSize) + inner := getBufferPart(page.buffer, offset, s.objSize) + data := getBufferPart(page.buffer, offset, size) + + // Wipe data field + Wipe(data) + + // Verify the canary + if !Equal(inner[len(data):], canary[:size]) { + return ErrBufferOverflow + } + + // Wipe the memory + Wipe(inner) + + return nil +} diff --git a/core/memallocator_slab_test.go b/core/memallocator_slab_test.go new file mode 100644 index 0000000..cfbcecb --- /dev/null +++ b/core/memallocator_slab_test.go @@ -0,0 +1,106 @@ +package core + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSlabAllocAllocInvalidSize(t *testing.T) { + alloc := NewSlabAllocator() + + a, err := alloc.Alloc(0) + require.Nil(t, a) + require.ErrorIs(t, err, ErrNullAlloc) + + b, err := alloc.Alloc(-1) + require.Nil(t, b) + require.ErrorIs(t, err, ErrNullAlloc) +} + +func TestSlabAllocAlloc(t *testing.T) { + alloc := NewSlabAllocator() + + b, err := alloc.Alloc(32) + require.NoError(t, err) + require.Lenf(t, b, 32, "invalid buffer len %d", len(b)) + + require.Lenf(t, b, 32, "invalid data len %d", len(b)) + require.Equalf(t, cap(b), 32, "invalid data capacity %d", cap(b)) + // require.Len(t, o.memory, 3*pageSize) + // require.EqualValues(t, make([]byte, 32), o.data, "container is not zero-filled") + + // Destroy the buffer. 
+ require.NoError(t, alloc.Free(b)) +} + +func TestSlabAllocLotsOfAllocs(t *testing.T) { + // Create a local allocator instance + alloc := NewPageAllocator() + palloc := alloc.(*pageAllocator) + + for i := 1; i <= 16385; i++ { + b, err := alloc.Alloc(i) + require.NoErrorf(t, err, "size: %d", i) + + o, found := palloc.lookup(b) + require.True(t, found) + + require.Lenf(t, o.data, i, "size: %d", i) + require.Lenf(t, o.memory, roundToPageSize(i)+2*pageSize, "memory length invalid size: %d", i) + require.Lenf(t, o.preguard, pageSize, "pre-guard length invalid size: %d", i) + require.Lenf(t, o.postguard, pageSize, "pre-guard length invalid size: %d", i) + require.Lenf(t, o.canary, len(o.inner)-i, "canary length invalid size: %d", i) + require.Zerof(t, len(o.inner)%pageSize, "inner length is not multiple of page size size: %d", i) + + // Fill the data + for j := range o.data { + o.data[j] = 1 + } + require.EqualValuesf(t, bytes.Repeat([]byte{1}, i), o.data, "region rw test failed", "size: %d", i) + require.NoErrorf(t, alloc.Free(b), "size: %d", i) + } +} + +func TestSlabAllocDestroy(t *testing.T) { + alloc := NewPageAllocator() + + // Allocate a new buffer. + b, err := alloc.Alloc(32) + require.NoError(t, err) + + o, found := alloc.(*pageAllocator).lookup(b) + require.True(t, found) + + // Destroy it and check it is gone... + require.NoError(t, o.wipe()) + + // Pick apart the destruction. + require.Nil(t, o.data, "data not nil") + require.Nil(t, o.inner, "inner not nil") + require.Nil(t, o.preguard, "preguard not nil") + require.Nil(t, o.postguard, "postguard not nil") + require.Nil(t, o.canary, "canary not nil") + require.EqualValues(t, make([]byte, len(o.memory)), o.memory, "memory not zero'ed") + + // Call destroy again to check idempotency. + require.NoError(t, alloc.Free(b)) +} + +func TestSlabAllocOverflow(t *testing.T) { + alloc := NewPageAllocator() + + // Allocate a new buffer. + b, err := alloc.Alloc(32) + require.NoError(t, err) + + o, found := alloc.(*pageAllocator).lookup(b) + require.True(t, found) + + // Modify the canary as if we overflow + o.canary[0] = ^o.canary[0] + + // Destroy it and check it is gone... 
+ require.ErrorIs(t, alloc.Free(b), ErrBufferOverflow) +} diff --git a/go.mod b/go.mod index 35b0c8d..aba08bb 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,15 @@ go 1.18 require ( github.com/awnumar/memcall v0.2.0 - golang.org/x/crypto v0.16.0 + github.com/stretchr/testify v1.8.4 + golang.org/x/crypto v0.18.0 + golang.org/x/sys v0.16.0 lukechampine.com/frand v1.4.2 ) require ( github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da // indirect - golang.org/x/sys v0.15.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1ad020c..71775e8 100644 --- a/go.sum +++ b/go.sum @@ -2,11 +2,21 @@ github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da h1:KjTM2ks9d14ZYCvmH github.com/aead/chacha20 v0.0.0-20180709150244-8b13a72661da/go.mod h1:eHEWzANqSiWQsof+nXEI9bUVUyV6F53Fp89EuCh2EAA= github.com/awnumar/memcall v0.2.0 h1:sRaogqExTOOkkNwO9pzJsL8jrOV29UuUW7teRMfbqtI= github.com/awnumar/memcall v0.2.0/go.mod h1:S911igBPR9CThzd/hYQQmTc9SWNu3ZHIlCGaWsWsoJo= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +golang.org/x/crypto v0.18.0 h1:PGVlW0xEltQnzFZ55hkuX5+KLyrMYhHld1YHO4AKcdc= +golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= golang.org/x/sys v0.0.0-20190626221950-04f50cda93cb/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191001151750-bb3f8db39f24/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= lukechampine.com/frand v1.4.2 h1:RzFIpOvkMXuPMBb9maa4ND4wjBn71E1Jpf8BzJHMaVw= lukechampine.com/frand v1.4.2/go.mod h1:4S/TM2ZgrKejMcKMbeLjISpJMO+/eZ1zu3vYX9dtj3s=
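Note for reviewers: the following is a short, hypothetical sketch (not part of the change set) of how the MemAllocator interface introduced by this patch is meant to be driven. The identifiers used — MemAllocator, NewPageAllocator, NewSlabAllocator, WithMinCanarySize, Alloc, Protect, Inner, Free, ErrBufferNotOwnedByAllocator, ErrBufferOverflow — are taken from the diff above; the function itself and its call sequence are illustrative only and assume the code lives inside package core.

package core

// allocatorUsageSketch is a hypothetical walkthrough of the new allocator
// API. It is not part of the patch; it only illustrates the intended call
// sequence under the assumptions stated above.
func allocatorUsageSketch() error {
	// Buffer is backed by the package-level `allocator` (a page allocator);
	// a slab allocator can be constructed separately for workloads with
	// many small allocations.
	var alloc MemAllocator = NewSlabAllocator(WithMinCanarySize(16))

	// Request a 32-byte region. The slab allocator serves it from the
	// smallest size class that also fits the trailing canary; requests
	// larger than the biggest size class fall back to the guarded page
	// allocator internally.
	buf, err := alloc.Alloc(32)
	if err != nil {
		return err
	}
	copy(buf, "sensitive data")

	// Protect flips the backing pages read-only for the page allocator;
	// for the slab allocator it is a documented no-op because several
	// objects share a single memory page.
	if err := alloc.Protect(buf, true); err != nil {
		return err
	}
	if err := alloc.Protect(buf, false); err != nil {
		return err
	}

	// Inner exposes the full backing region that holds buf (the locked
	// inner pages for the page allocator, the object's whole slot for the
	// slab allocator); Buffer.Inner is now a thin wrapper over this call.
	_ = alloc.Inner(buf)

	// Free wipes the region, verifies the canary (reporting
	// ErrBufferOverflow on tampering) and releases the memory. Freeing the
	// same buffer again would report ErrBufferNotOwnedByAllocator.
	return alloc.Free(buf)
}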