Skip to content

Commit

Permalink
csi: fix concurrent use of cryptmapper package (#2408)
Browse files Browse the repository at this point in the history
* Dont error on opening already active devices

* Fix concurrency issues when working with more than one device

---------

Signed-off-by: Daniel Weiße <[email protected]>
  • Loading branch information
daniel-weisse authored Oct 5, 2023
1 parent 6ba43b0 commit f69ae26
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 36 deletions.
61 changes: 39 additions & 22 deletions csi/cryptmapper/cryptmapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
Expand All @@ -33,7 +34,7 @@ const (

// CryptMapper manages dm-crypt volumes.
type CryptMapper struct {
mapper deviceMapper
mapper func() deviceMapper
kms keyCreator
getDiskFormat func(disk string) (string, error)
}
Expand All @@ -42,7 +43,7 @@ type CryptMapper struct {
// kms is used to fetch data encryption keys for the dm-crypt volumes.
func New(kms keyCreator) *CryptMapper {
return &CryptMapper{
mapper: cryptsetup.New(),
mapper: func() deviceMapper { return cryptsetup.New() },
kms: kms,
getDiskFormat: getDiskFormat,
}
Expand Down Expand Up @@ -87,22 +88,35 @@ func (c *CryptMapper) CloseCryptDevice(volumeID string) error {
// The key used to encrypt the volume is fetched using CryptMapper's kms client.
func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) {
// Initialize the block device
free, err := c.mapper.Init(source)
mapper := c.mapper()
free, err := mapper.Init(source)
if err != nil {
return "", fmt.Errorf("initializing dm-crypt to map device %q: %w", source, err)
}
defer free()

deviceName := filepath.Join(cryptPrefix, volumeID)
var passphrase []byte
// Try to load LUKS headers
// If this fails, the device is either not formatted at all, or already formatted with a different FS
if err := c.mapper.LoadLUKS2(); err != nil {
passphrase, err = c.formatNewDevice(ctx, volumeID, source, integrity)
if err := mapper.LoadLUKS2(); err != nil {
passphrase, err = c.formatNewDevice(ctx, mapper, volumeID, source, integrity)
if err != nil {
return "", fmt.Errorf("formatting device: %w", err)
}
} else {
uuid, err := c.mapper.GetUUID()
// Check if device is already active
// If yes, this is a no-op
// Simply return the device name
if _, err := os.Stat(deviceName); err == nil {
_, err := os.Stat(deviceName + integritySuffix)
if integrity && err != nil {
return "", fmt.Errorf("device %s already exists, but integrity device %s is missing", deviceName, deviceName+integritySuffix)
}
return deviceName, nil
}

uuid, err := mapper.GetUUID()
if err != nil {
return "", err
}
Expand All @@ -115,26 +129,27 @@ func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID stri
}
}

if err := c.mapper.ActivateByPassphrase(volumeID, 0, string(passphrase), cryptsetup.ReadWriteQueueBypass); err != nil {
if err := mapper.ActivateByPassphrase(volumeID, 0, string(passphrase), cryptsetup.ReadWriteQueueBypass); err != nil {
return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err)
}

return cryptPrefix + volumeID, nil
return deviceName, nil
}

// ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path.
func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (string, error) {
free, err := c.mapper.InitByName(volumeID)
mapper := c.mapper()
free, err := mapper.InitByName(volumeID)
if err != nil {
return "", fmt.Errorf("initializing device: %w", err)
}
defer free()

if err := c.mapper.LoadLUKS2(); err != nil {
if err := mapper.LoadLUKS2(); err != nil {
return "", fmt.Errorf("loading device: %w", err)
}

uuid, err := c.mapper.GetUUID()
uuid, err := mapper.GetUUID()
if err != nil {
return "", err
}
Expand All @@ -143,11 +158,11 @@ func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (s
return "", fmt.Errorf("getting key: %w", err)
}

if err := c.mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil {
if err := mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil {
return "", fmt.Errorf("activating keyring for crypt device %q with passphrase: %w", volumeID, err)
}

if err := c.mapper.Resize(volumeID, 0); err != nil {
if err := mapper.Resize(volumeID, 0); err != nil {
return "", fmt.Errorf("resizing device: %w", err)
}

Expand All @@ -156,14 +171,15 @@ func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (s

// GetDevicePath returns the device path of a mapped crypt device.
func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) {
mapper := c.mapper()
name := strings.TrimPrefix(volumeID, cryptPrefix)
free, err := c.mapper.InitByName(name)
free, err := mapper.InitByName(name)
if err != nil {
return "", fmt.Errorf("initializing device: %w", err)
}
defer free()

deviceName := c.mapper.GetDeviceName()
deviceName := mapper.GetDeviceName()
if deviceName == "" {
return "", errors.New("unable to determine device name")
}
Expand All @@ -172,20 +188,21 @@ func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) {

// closeCryptDevice closes the crypt device mapped for volumeID.
func (c *CryptMapper) closeCryptDevice(source, volumeID, deviceType string) error {
free, err := c.mapper.InitByName(volumeID)
mapper := c.mapper()
free, err := mapper.InitByName(volumeID)
if err != nil {
return fmt.Errorf("initializing dm-%s to unmap device %q: %w", deviceType, source, err)
}
defer free()

if err := c.mapper.Deactivate(volumeID); err != nil {
if err := mapper.Deactivate(volumeID); err != nil {
return fmt.Errorf("deactivating dm-%s volume %q for device %q: %w", deviceType, cryptPrefix+volumeID, source, err)
}

return nil
}

func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source string, integrity bool) ([]byte, error) {
func (c *CryptMapper) formatNewDevice(ctx context.Context, mapper deviceMapper, volumeID, source string, integrity bool) ([]byte, error) {
format, err := c.getDiskFormat(source)
if err != nil {
return nil, fmt.Errorf("determining if disk is formatted: %w", err)
Expand All @@ -195,11 +212,11 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri
}

// Device is not formatted, so we can safely create a new LUKS2 partition
if err := c.mapper.Format(integrity); err != nil {
if err := mapper.Format(integrity); err != nil {
return nil, fmt.Errorf("formatting device %q: %w", source, err)
}

uuid, err := c.mapper.GetUUID()
uuid, err := mapper.GetUUID()
if err != nil {
return nil, err
}
Expand All @@ -212,7 +229,7 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri
}

// Add a new keyslot using the internal volume key
if err := c.mapper.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil {
if err := mapper.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil {
return nil, fmt.Errorf("adding keyslot: %w", err)
}

Expand All @@ -222,7 +239,7 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri
fmt.Printf("Wipe in progress: %.2f%%\n", prog)
}

if err := c.mapper.Wipe(volumeID, 1024*1024, 0, logProgress, 30*time.Second); err != nil {
if err := mapper.Wipe(volumeID, 1024*1024, 0, logProgress, 30*time.Second); err != nil {
return nil, fmt.Errorf("wiping device: %w", err)
}
}
Expand Down
18 changes: 12 additions & 6 deletions csi/cryptmapper/cryptmapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func TestCloseCryptDevice(t *testing.T) {

mapper := &CryptMapper{
kms: &fakeKMS{},
mapper: tc.mapper,
mapper: testMapper(tc.mapper),
}
err := mapper.closeCryptDevice("/dev/mapper/volume01", "volume01-unit-test", "crypt")
if tc.wantErr {
Expand All @@ -58,7 +58,7 @@ func TestCloseCryptDevice(t *testing.T) {
}

mapper := &CryptMapper{
mapper: &stubCryptDevice{},
mapper: testMapper(&stubCryptDevice{}),
kms: &fakeKMS{},
getDiskFormat: getDiskFormat,
}
Expand Down Expand Up @@ -197,7 +197,7 @@ func TestOpenCryptDevice(t *testing.T) {
assert := assert.New(t)

mapper := &CryptMapper{
mapper: tc.mapper,
mapper: testMapper(tc.mapper),
kms: tc.kms,
getDiskFormat: tc.diskInfo,
}
Expand All @@ -219,7 +219,7 @@ func TestOpenCryptDevice(t *testing.T) {
}

mapper := &CryptMapper{
mapper: &stubCryptDevice{},
mapper: testMapper(&stubCryptDevice{}),
kms: &fakeKMS{},
getDiskFormat: getDiskFormat,
}
Expand Down Expand Up @@ -267,7 +267,7 @@ func TestResizeCryptDevice(t *testing.T) {

mapper := &CryptMapper{
kms: &fakeKMS{},
mapper: tc.device,
mapper: testMapper(tc.device),
}

res, err := mapper.ResizeCryptDevice(context.Background(), tc.volumeID)
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestGetDevicePath(t *testing.T) {
assert := assert.New(t)

mapper := &CryptMapper{
mapper: tc.device,
mapper: testMapper(tc.device),
}

res, err := mapper.GetDevicePath(tc.volumeID)
Expand Down Expand Up @@ -451,3 +451,9 @@ func (c *stubCryptDevice) Wipe(_ string, _ int, _ int, _ func(size, offset uint6
func (c *stubCryptDevice) Resize(_ string, _ uint64) error {
return c.resizeErr
}

func testMapper(stub *stubCryptDevice) func() deviceMapper {
return func() deviceMapper {
return stub
}
}
62 changes: 54 additions & 8 deletions csi/test/mount_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"os"
"os/exec"
"sync"
"testing"

"github.com/edgelesssys/constellation/v2/csi/cryptmapper"
Expand All @@ -23,10 +24,10 @@ import (

const (
devicePath string = "testDevice"
deviceName string = "testdeviceName"
deviceName string = "testDeviceName"
)

func setup() {
func setup(devicePath string) {
if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=64M", "count=1").Run(); err != nil {
panic(err)
}
Expand All @@ -42,7 +43,7 @@ func cp(source, target string) error {
return exec.Command("cp", source, target).Run()
}

func resize() {
func resize(devicePath string) {
if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=32M", "count=1", "oflag=append", "conv=notrunc").Run(); err != nil {
panic(err)
}
Expand All @@ -63,7 +64,7 @@ func TestMain(m *testing.M) {
func TestOpenAndClose(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
setup()
setup(devicePath)
defer teardown(devicePath)

mapper := cryptmapper.New(&fakeKMS{})
Expand All @@ -81,8 +82,13 @@ func TestOpenAndClose(t *testing.T) {
_, err = os.Stat(newPath + "_dif")
assert.True(os.IsNotExist(err))

// Opening the same device should return the same path and not error
newPath2, err := mapper.OpenCryptDevice(context.Background(), devicePath, deviceName, false)
require.NoError(err)
assert.Equal(newPath, newPath2)

// Resize the device
resize()
resize(devicePath)

resizedPath, err := mapper.ResizeCryptDevice(context.Background(), deviceName)
require.NoError(err)
Expand All @@ -103,7 +109,7 @@ func TestOpenAndClose(t *testing.T) {
func TestOpenAndCloseIntegrity(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
setup()
setup(devicePath)
defer teardown(devicePath)

mapper := cryptmapper.New(&fakeKMS{})
Expand All @@ -119,8 +125,13 @@ func TestOpenAndCloseIntegrity(t *testing.T) {
_, err = os.Stat(newPath + "_dif")
assert.NoError(err)

// Opening the same device should return the same path and not error
newPath2, err := mapper.OpenCryptDevice(context.Background(), devicePath, deviceName, true)
require.NoError(err)
assert.Equal(newPath, newPath2)

// integrity devices do not support resizing
resize()
resize(devicePath)
_, err = mapper.ResizeCryptDevice(context.Background(), deviceName)
assert.Error(err)

Expand All @@ -142,7 +153,7 @@ func TestOpenAndCloseIntegrity(t *testing.T) {
func TestDeviceCloning(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
setup()
setup(devicePath)
defer teardown(devicePath)

mapper := cryptmapper.New(&dynamicKMS{})
Expand All @@ -160,6 +171,41 @@ func TestDeviceCloning(t *testing.T) {
assert.NoError(mapper.CloseCryptDevice(deviceName + "-copy"))
}

func TestConcurrency(t *testing.T) {
assert := assert.New(t)
setup(devicePath)
defer teardown(devicePath)

device2 := devicePath + "-2"
setup(device2)
defer teardown(device2)

mapper := cryptmapper.New(&fakeKMS{})

wg := sync.WaitGroup{}
runTest := func(path, name string) {
newPath, err := mapper.OpenCryptDevice(context.Background(), path, name, false)
assert.NoError(err)
defer func() {
_ = mapper.CloseCryptDevice(name)
}()

// assert crypt device got created
_, err = os.Stat(newPath)
assert.NoError(err)
// assert no integrity device got created
_, err = os.Stat(newPath + "_dif")
assert.True(os.IsNotExist(err))
assert.NoError(mapper.CloseCryptDevice(name))
wg.Done()
}

wg.Add(2)
go runTest(devicePath, deviceName)
go runTest(device2, deviceName+"-2")
wg.Wait()
}

type fakeKMS struct{}

func (k *fakeKMS) GetDEK(_ context.Context, _ string, dekSize int) ([]byte, error) {
Expand Down

0 comments on commit f69ae26

Please sign in to comment.