From f69ae2612236386b9a2f872116de4f5da683d9b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Thu, 5 Oct 2023 11:20:22 +0200 Subject: [PATCH] csi: fix concurrent use of `cryptmapper` package (#2408) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Dont error on opening already active devices * Fix concurrency issues when working with more than one device --------- Signed-off-by: Daniel Weiße --- csi/cryptmapper/cryptmapper.go | 61 ++++++++++++++++++---------- csi/cryptmapper/cryptmapper_test.go | 18 ++++++--- csi/test/mount_integration_test.go | 62 +++++++++++++++++++++++++---- 3 files changed, 105 insertions(+), 36 deletions(-) diff --git a/csi/cryptmapper/cryptmapper.go b/csi/cryptmapper/cryptmapper.go index 1572331629..90ece1df2c 100644 --- a/csi/cryptmapper/cryptmapper.go +++ b/csi/cryptmapper/cryptmapper.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io/fs" + "os" "path/filepath" "strings" "time" @@ -33,7 +34,7 @@ const ( // CryptMapper manages dm-crypt volumes. type CryptMapper struct { - mapper deviceMapper + mapper func() deviceMapper kms keyCreator getDiskFormat func(disk string) (string, error) } @@ -42,7 +43,7 @@ type CryptMapper struct { // kms is used to fetch data encryption keys for the dm-crypt volumes. func New(kms keyCreator) *CryptMapper { return &CryptMapper{ - mapper: cryptsetup.New(), + mapper: func() deviceMapper { return cryptsetup.New() }, kms: kms, getDiskFormat: getDiskFormat, } @@ -87,22 +88,35 @@ func (c *CryptMapper) CloseCryptDevice(volumeID string) error { // The key used to encrypt the volume is fetched using CryptMapper's kms client. func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID string, integrity bool) (string, error) { // Initialize the block device - free, err := c.mapper.Init(source) + mapper := c.mapper() + free, err := mapper.Init(source) if err != nil { return "", fmt.Errorf("initializing dm-crypt to map device %q: %w", source, err) } defer free() + deviceName := filepath.Join(cryptPrefix, volumeID) var passphrase []byte // Try to load LUKS headers // If this fails, the device is either not formatted at all, or already formatted with a different FS - if err := c.mapper.LoadLUKS2(); err != nil { - passphrase, err = c.formatNewDevice(ctx, volumeID, source, integrity) + if err := mapper.LoadLUKS2(); err != nil { + passphrase, err = c.formatNewDevice(ctx, mapper, volumeID, source, integrity) if err != nil { return "", fmt.Errorf("formatting device: %w", err) } } else { - uuid, err := c.mapper.GetUUID() + // Check if device is already active + // If yes, this is a no-op + // Simply return the device name + if _, err := os.Stat(deviceName); err == nil { + _, err := os.Stat(deviceName + integritySuffix) + if integrity && err != nil { + return "", fmt.Errorf("device %s already exists, but integrity device %s is missing", deviceName, deviceName+integritySuffix) + } + return deviceName, nil + } + + uuid, err := mapper.GetUUID() if err != nil { return "", err } @@ -115,26 +129,27 @@ func (c *CryptMapper) OpenCryptDevice(ctx context.Context, source, volumeID stri } } - if err := c.mapper.ActivateByPassphrase(volumeID, 0, string(passphrase), cryptsetup.ReadWriteQueueBypass); err != nil { + if err := mapper.ActivateByPassphrase(volumeID, 0, string(passphrase), cryptsetup.ReadWriteQueueBypass); err != nil { return "", fmt.Errorf("trying to activate dm-crypt volume: %w", err) } - return cryptPrefix + volumeID, nil + return deviceName, nil } // ResizeCryptDevice resizes the underlying crypt device and returns the mapped device path. func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (string, error) { - free, err := c.mapper.InitByName(volumeID) + mapper := c.mapper() + free, err := mapper.InitByName(volumeID) if err != nil { return "", fmt.Errorf("initializing device: %w", err) } defer free() - if err := c.mapper.LoadLUKS2(); err != nil { + if err := mapper.LoadLUKS2(); err != nil { return "", fmt.Errorf("loading device: %w", err) } - uuid, err := c.mapper.GetUUID() + uuid, err := mapper.GetUUID() if err != nil { return "", err } @@ -143,11 +158,11 @@ func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (s return "", fmt.Errorf("getting key: %w", err) } - if err := c.mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil { + if err := mapper.ActivateByPassphrase("", 0, string(passphrase), resizeFlags); err != nil { return "", fmt.Errorf("activating keyring for crypt device %q with passphrase: %w", volumeID, err) } - if err := c.mapper.Resize(volumeID, 0); err != nil { + if err := mapper.Resize(volumeID, 0); err != nil { return "", fmt.Errorf("resizing device: %w", err) } @@ -156,14 +171,15 @@ func (c *CryptMapper) ResizeCryptDevice(ctx context.Context, volumeID string) (s // GetDevicePath returns the device path of a mapped crypt device. func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) { + mapper := c.mapper() name := strings.TrimPrefix(volumeID, cryptPrefix) - free, err := c.mapper.InitByName(name) + free, err := mapper.InitByName(name) if err != nil { return "", fmt.Errorf("initializing device: %w", err) } defer free() - deviceName := c.mapper.GetDeviceName() + deviceName := mapper.GetDeviceName() if deviceName == "" { return "", errors.New("unable to determine device name") } @@ -172,20 +188,21 @@ func (c *CryptMapper) GetDevicePath(volumeID string) (string, error) { // closeCryptDevice closes the crypt device mapped for volumeID. func (c *CryptMapper) closeCryptDevice(source, volumeID, deviceType string) error { - free, err := c.mapper.InitByName(volumeID) + mapper := c.mapper() + free, err := mapper.InitByName(volumeID) if err != nil { return fmt.Errorf("initializing dm-%s to unmap device %q: %w", deviceType, source, err) } defer free() - if err := c.mapper.Deactivate(volumeID); err != nil { + if err := mapper.Deactivate(volumeID); err != nil { return fmt.Errorf("deactivating dm-%s volume %q for device %q: %w", deviceType, cryptPrefix+volumeID, source, err) } return nil } -func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source string, integrity bool) ([]byte, error) { +func (c *CryptMapper) formatNewDevice(ctx context.Context, mapper deviceMapper, volumeID, source string, integrity bool) ([]byte, error) { format, err := c.getDiskFormat(source) if err != nil { return nil, fmt.Errorf("determining if disk is formatted: %w", err) @@ -195,11 +212,11 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri } // Device is not formatted, so we can safely create a new LUKS2 partition - if err := c.mapper.Format(integrity); err != nil { + if err := mapper.Format(integrity); err != nil { return nil, fmt.Errorf("formatting device %q: %w", source, err) } - uuid, err := c.mapper.GetUUID() + uuid, err := mapper.GetUUID() if err != nil { return nil, err } @@ -212,7 +229,7 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri } // Add a new keyslot using the internal volume key - if err := c.mapper.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil { + if err := mapper.KeyslotAddByVolumeKey(0, "", string(passphrase)); err != nil { return nil, fmt.Errorf("adding keyslot: %w", err) } @@ -222,7 +239,7 @@ func (c *CryptMapper) formatNewDevice(ctx context.Context, volumeID, source stri fmt.Printf("Wipe in progress: %.2f%%\n", prog) } - if err := c.mapper.Wipe(volumeID, 1024*1024, 0, logProgress, 30*time.Second); err != nil { + if err := mapper.Wipe(volumeID, 1024*1024, 0, logProgress, 30*time.Second); err != nil { return nil, fmt.Errorf("wiping device: %w", err) } } diff --git a/csi/cryptmapper/cryptmapper_test.go b/csi/cryptmapper/cryptmapper_test.go index cf521e202e..fbc9dcd637 100644 --- a/csi/cryptmapper/cryptmapper_test.go +++ b/csi/cryptmapper/cryptmapper_test.go @@ -46,7 +46,7 @@ func TestCloseCryptDevice(t *testing.T) { mapper := &CryptMapper{ kms: &fakeKMS{}, - mapper: tc.mapper, + mapper: testMapper(tc.mapper), } err := mapper.closeCryptDevice("/dev/mapper/volume01", "volume01-unit-test", "crypt") if tc.wantErr { @@ -58,7 +58,7 @@ func TestCloseCryptDevice(t *testing.T) { } mapper := &CryptMapper{ - mapper: &stubCryptDevice{}, + mapper: testMapper(&stubCryptDevice{}), kms: &fakeKMS{}, getDiskFormat: getDiskFormat, } @@ -197,7 +197,7 @@ func TestOpenCryptDevice(t *testing.T) { assert := assert.New(t) mapper := &CryptMapper{ - mapper: tc.mapper, + mapper: testMapper(tc.mapper), kms: tc.kms, getDiskFormat: tc.diskInfo, } @@ -219,7 +219,7 @@ func TestOpenCryptDevice(t *testing.T) { } mapper := &CryptMapper{ - mapper: &stubCryptDevice{}, + mapper: testMapper(&stubCryptDevice{}), kms: &fakeKMS{}, getDiskFormat: getDiskFormat, } @@ -267,7 +267,7 @@ func TestResizeCryptDevice(t *testing.T) { mapper := &CryptMapper{ kms: &fakeKMS{}, - mapper: tc.device, + mapper: testMapper(tc.device), } res, err := mapper.ResizeCryptDevice(context.Background(), tc.volumeID) @@ -310,7 +310,7 @@ func TestGetDevicePath(t *testing.T) { assert := assert.New(t) mapper := &CryptMapper{ - mapper: tc.device, + mapper: testMapper(tc.device), } res, err := mapper.GetDevicePath(tc.volumeID) @@ -451,3 +451,9 @@ func (c *stubCryptDevice) Wipe(_ string, _ int, _ int, _ func(size, offset uint6 func (c *stubCryptDevice) Resize(_ string, _ uint64) error { return c.resizeErr } + +func testMapper(stub *stubCryptDevice) func() deviceMapper { + return func() deviceMapper { + return stub + } +} diff --git a/csi/test/mount_integration_test.go b/csi/test/mount_integration_test.go index 1075758c04..bfa1337218 100644 --- a/csi/test/mount_integration_test.go +++ b/csi/test/mount_integration_test.go @@ -13,6 +13,7 @@ import ( "fmt" "os" "os/exec" + "sync" "testing" "github.com/edgelesssys/constellation/v2/csi/cryptmapper" @@ -23,10 +24,10 @@ import ( const ( devicePath string = "testDevice" - deviceName string = "testdeviceName" + deviceName string = "testDeviceName" ) -func setup() { +func setup(devicePath string) { if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=64M", "count=1").Run(); err != nil { panic(err) } @@ -42,7 +43,7 @@ func cp(source, target string) error { return exec.Command("cp", source, target).Run() } -func resize() { +func resize(devicePath string) { if err := exec.Command("/bin/dd", "if=/dev/zero", fmt.Sprintf("of=%s", devicePath), "bs=32M", "count=1", "oflag=append", "conv=notrunc").Run(); err != nil { panic(err) } @@ -63,7 +64,7 @@ func TestMain(m *testing.M) { func TestOpenAndClose(t *testing.T) { assert := assert.New(t) require := require.New(t) - setup() + setup(devicePath) defer teardown(devicePath) mapper := cryptmapper.New(&fakeKMS{}) @@ -81,8 +82,13 @@ func TestOpenAndClose(t *testing.T) { _, err = os.Stat(newPath + "_dif") assert.True(os.IsNotExist(err)) + // Opening the same device should return the same path and not error + newPath2, err := mapper.OpenCryptDevice(context.Background(), devicePath, deviceName, false) + require.NoError(err) + assert.Equal(newPath, newPath2) + // Resize the device - resize() + resize(devicePath) resizedPath, err := mapper.ResizeCryptDevice(context.Background(), deviceName) require.NoError(err) @@ -103,7 +109,7 @@ func TestOpenAndClose(t *testing.T) { func TestOpenAndCloseIntegrity(t *testing.T) { assert := assert.New(t) require := require.New(t) - setup() + setup(devicePath) defer teardown(devicePath) mapper := cryptmapper.New(&fakeKMS{}) @@ -119,8 +125,13 @@ func TestOpenAndCloseIntegrity(t *testing.T) { _, err = os.Stat(newPath + "_dif") assert.NoError(err) + // Opening the same device should return the same path and not error + newPath2, err := mapper.OpenCryptDevice(context.Background(), devicePath, deviceName, true) + require.NoError(err) + assert.Equal(newPath, newPath2) + // integrity devices do not support resizing - resize() + resize(devicePath) _, err = mapper.ResizeCryptDevice(context.Background(), deviceName) assert.Error(err) @@ -142,7 +153,7 @@ func TestOpenAndCloseIntegrity(t *testing.T) { func TestDeviceCloning(t *testing.T) { assert := assert.New(t) require := require.New(t) - setup() + setup(devicePath) defer teardown(devicePath) mapper := cryptmapper.New(&dynamicKMS{}) @@ -160,6 +171,41 @@ func TestDeviceCloning(t *testing.T) { assert.NoError(mapper.CloseCryptDevice(deviceName + "-copy")) } +func TestConcurrency(t *testing.T) { + assert := assert.New(t) + setup(devicePath) + defer teardown(devicePath) + + device2 := devicePath + "-2" + setup(device2) + defer teardown(device2) + + mapper := cryptmapper.New(&fakeKMS{}) + + wg := sync.WaitGroup{} + runTest := func(path, name string) { + newPath, err := mapper.OpenCryptDevice(context.Background(), path, name, false) + assert.NoError(err) + defer func() { + _ = mapper.CloseCryptDevice(name) + }() + + // assert crypt device got created + _, err = os.Stat(newPath) + assert.NoError(err) + // assert no integrity device got created + _, err = os.Stat(newPath + "_dif") + assert.True(os.IsNotExist(err)) + assert.NoError(mapper.CloseCryptDevice(name)) + wg.Done() + } + + wg.Add(2) + go runTest(devicePath, deviceName) + go runTest(device2, deviceName+"-2") + wg.Wait() +} + type fakeKMS struct{} func (k *fakeKMS) GetDEK(_ context.Context, _ string, dekSize int) ([]byte, error) {