diff --git a/lib/devicetrust/enroll/auto_enroll.go b/lib/devicetrust/enroll/auto_enroll.go index 4cc6663a8a9c..8bf8d4ee5ca0 100644 --- a/lib/devicetrust/enroll/auto_enroll.go +++ b/lib/devicetrust/enroll/auto_enroll.go @@ -20,6 +20,9 @@ package enroll import ( "context" + "errors" + "os" + "strconv" "github.com/gravitational/trace" @@ -27,6 +30,12 @@ import ( "github.com/gravitational/teleport/lib/devicetrust/native" ) +// ErrAutoEnrollDisabled signifies that auto-enroll is disabled in the current +// device. +// Setting the TELEPORT_DEVICE_AUTO_ENROLL_DISABLED=1 environment disables +// auto-enroll. +var ErrAutoEnrollDisabled = errors.New("auto-enroll disabled") + // AutoEnrollCeremony is the auto-enrollment version of [Ceremony]. type AutoEnrollCeremony struct { *Ceremony @@ -53,6 +62,11 @@ func AutoEnroll(ctx context.Context, devicesClient devicepb.DeviceTrustServiceCl // [devicepb.DeviceTrustServiceClient.CreateDeviceEnrollToken] and enrolls the // device using a regular [Ceremony]. func (c *AutoEnrollCeremony) Run(ctx context.Context, devicesClient devicepb.DeviceTrustServiceClient) (*devicepb.Device, error) { + const autoEnrollDisabledKey = "TELEPORT_DEVICE_AUTO_ENROLL_DISABLED" + if disabled, _ := strconv.ParseBool(os.Getenv(autoEnrollDisabledKey)); disabled { + return nil, trace.Wrap(ErrAutoEnrollDisabled) + } + cd, err := c.CollectDeviceData(native.CollectedDataAlwaysEscalate) if err != nil { return nil, trace.Wrap(err, "collecting device data") diff --git a/lib/devicetrust/enroll/auto_enroll_test.go b/lib/devicetrust/enroll/auto_enroll_test.go index 4b29db639247..d5b9e3aa1bea 100644 --- a/lib/devicetrust/enroll/auto_enroll_test.go +++ b/lib/devicetrust/enroll/auto_enroll_test.go @@ -20,6 +20,7 @@ package enroll_test import ( "context" + "os" "testing" "github.com/stretchr/testify/assert" @@ -68,3 +69,10 @@ func TestAutoEnrollCeremony_Run(t *testing.T) { }) } } + +func TestAutoEnroll_disabledByEnv(t *testing.T) { + os.Setenv("TELEPORT_DEVICE_AUTO_ENROLL_DISABLED", "1") + + _, err := enroll.AutoEnroll(context.Background(), nil /* devicesClient */) + assert.ErrorIs(t, err, enroll.ErrAutoEnrollDisabled, "AutoEnroll() error mismatch") +}