Skip to content

Commit

Permalink
Remove direct usage of logical/pki's storageContext.Backend field (#2…
Browse files Browse the repository at this point in the history
…7401)

* Add method storageContext.Logger().

* Add method storageContext.System().

* Add method storageContext.CrlBuilder().

* Add method storageContext.GetUnifiedTransferStatus().

* Add method storageContext.GetPkiManagedView().

* Add method storageContext.GetCertificateCounter().

* Add method storageContext.UseLegacyBundleCaStorage().

* Add method storageContext.GetRevokeStorageLock().

* Add acmeState to acmeContext.

Make acmeState accessible from acmeContext, so that storageContext doesn't have
to be used for this purpose.

* Decouple getAndValidateAcmeRole() from storageContext.Backend.

* Don't access Backend.ciepsState through storageContext.

* Add method storageContext.GetRole().

* Change signature of getCiepsAcmeSettings for CE compatibility.
  • Loading branch information
victorr authored Jun 7, 2024
1 parent d382103 commit 8fd63b0
Show file tree
Hide file tree
Showing 20 changed files with 160 additions and 123 deletions.
8 changes: 4 additions & 4 deletions builtin/logical/pki/acme_challenge_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,13 @@ func (ace *ACMEChallengeEngine) AcceptChallenge(sc *storageContext, account stri
func (ace *ACMEChallengeEngine) VerifyChallenge(runnerSc *storageContext, id string, validationQueueRetries int, finished chan bool, config *acmeConfigEntry) {
sc, cancel := runnerSc.WithFreshTimeout(MaxChallengeTimeout)
defer cancel()
runnerSc.Backend.Logger().Debug("Starting verification of challenge", "id", id)
runnerSc.Logger().Debug("Starting verification of challenge", "id", id)

if retry, retryAfter, err := ace._verifyChallenge(sc, id, config); err != nil {
// Because verification of this challenge failed, we need to retry
// it in the future. Log the error and re-add the item to the queue
// to try again later.
sc.Backend.Logger().Error(fmt.Sprintf("ACME validation failed for %v: %v", id, err))
sc.Logger().Error(fmt.Sprintf("ACME validation failed for %v: %v", id, err))

if retry {
validationQueueRetries++
Expand All @@ -331,10 +331,10 @@ func (ace *ACMEChallengeEngine) VerifyChallenge(runnerSc *storageContext, id str
// we have a secondary check here to see if we are consistently looping within the validation
// queue that is larger than the normal retry attempts we would allow.
if validationQueueRetries > MaxRetryAttempts*2 {
sc.Backend.Logger().Warn("reached max error attempts within challenge queue: %v, giving up", id)
sc.Logger().Warn("reached max error attempts within challenge queue: %v, giving up", id)
_, _, err = ace._verifyChallengeCleanup(sc, nil, id)
if err != nil {
sc.Backend.Logger().Warn("Failed cleaning up challenge entry: %v", err)
sc.Logger().Warn("Failed cleaning up challenge entry: %v", err)
}
finished <- true
return
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/acme_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ func (a *acmeState) reloadConfigIfRequired(sc *storageContext) error {
return nil
}

config, err := sc.getAcmeConfig()
config, err := getAcmeConfig(sc)
if err != nil {
return fmt.Errorf("failed reading ACME config: %w", err)
}
Expand Down
12 changes: 7 additions & 5 deletions builtin/logical/pki/acme_wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type acmeContext struct {
baseUrl *url.URL
clusterUrl *url.URL
sc *storageContext
acmeState *acmeState
role *issuing.RoleEntry
issuer *issuing.IssuerEntry
// acmeDirectory is a string that can distinguish the various acme directories we have configured
Expand All @@ -32,7 +33,7 @@ type acmeContext struct {
}

func (c acmeContext) getAcmeState() *acmeState {
return c.sc.Backend.GetAcmeState()
return c.acmeState
}

type (
Expand Down Expand Up @@ -110,7 +111,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
return acmeErrorWrapper(func(ctx context.Context, r *logical.Request, data *framework.FieldData) (*logical.Response, error) {
sc := b.makeStorageContext(ctx, r.Storage)

config, err := sc.Backend.GetAcmeState().getConfigWithUpdate(sc)
config, err := b.GetAcmeState().getConfigWithUpdate(sc)
if err != nil {
return nil, fmt.Errorf("failed to fetch ACME configuration: %w", err)
}
Expand Down Expand Up @@ -144,7 +145,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
return nil, err
}

isCiepsEnabled, ciepsPolicy, err := getCiepsAcmeSettings(sc, opts, config, data)
isCiepsEnabled, ciepsPolicy, err := getCiepsAcmeSettings(b, sc, opts, config, data)
if err != nil {
return nil, err
}
Expand All @@ -163,6 +164,7 @@ func (b *backend) acmeWrapper(opts acmeWrapperOpts, op acmeOperation) framework.
baseUrl: acmeBaseUrl,
clusterUrl: clusterBase,
sc: sc,
acmeState: b.acmeState,
role: role,
issuer: issuer,
acmeDirectory: acmeDirectory,
Expand Down Expand Up @@ -455,7 +457,7 @@ func getAcmeRoleAndIssuer(sc *storageContext, data *framework.FieldData, config

func getAndValidateAcmeRole(sc *storageContext, requestedRole string) (*issuing.RoleEntry, error) {
var err error
role, err := sc.Backend.GetRole(sc.Context, sc.Storage, requestedRole)
role, err := sc.GetRole(requestedRole)
if err != nil {
return nil, fmt.Errorf("%w: err loading role", ErrServerInternal)
}
Expand Down Expand Up @@ -496,7 +498,7 @@ func isAcmeDisabled(sc *storageContext, config *acmeConfigEntry, policy EabPolic

disableAcme, nonFatalErr := isPublicACMEDisabledByEnv()
if nonFatalErr != nil {
sc.Backend.Logger().Warn(fmt.Sprintf("could not parse env var '%s'", disableAcmeEnvVar), "error", nonFatalErr)
sc.Logger().Warn(fmt.Sprintf("could not parse env var '%s'", disableAcmeEnvVar), "error", nonFatalErr)
}

// The OS environment if true will override any configuration option.
Expand Down
18 changes: 8 additions & 10 deletions builtin/logical/pki/ca_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,13 @@ func getGenerationParams(sc *storageContext, data *framework.FieldData) (exporte

func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.CreationBundle, randomSource io.Reader) (*certutil.ParsedCertBundle, error) {
ctx := sc.Context
b := sc.Backend

if kmsRequested(input) {
keyId, err := getManagedKeyId(input.apiData)
if err != nil {
return nil, err
}
return managed_key.GenerateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
return managed_key.GenerateManagedKeyCABundle(ctx, sc.GetPkiManagedView(), keyId, data, randomSource)
}

if existingKeyRequested(input) {
Expand All @@ -110,7 +109,7 @@ func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.Cre
if err != nil {
return nil, err
}
return managed_key.GenerateManagedKeyCABundle(ctx, b, keyId, data, randomSource)
return managed_key.GenerateManagedKeyCABundle(ctx, sc.GetPkiManagedView(), keyId, data, randomSource)
}

return certutil.CreateCertificateWithKeyGenerator(data, randomSource, existingKeyGeneratorFromBytes(keyEntry))
Expand All @@ -121,15 +120,14 @@ func generateCABundle(sc *storageContext, input *inputBundle, data *certutil.Cre

func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.CreationBundle, addBasicConstraints bool, randomSource io.Reader) (*certutil.ParsedCSRBundle, error) {
ctx := sc.Context
b := sc.Backend

if kmsRequested(input) {
keyId, err := getManagedKeyId(input.apiData)
if err != nil {
return nil, err
}

return managed_key.GenerateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
return managed_key.GenerateManagedKeyCSRBundle(ctx, sc.GetPkiManagedView(), keyId, data, addBasicConstraints, randomSource)
}

if existingKeyRequested(input) {
Expand All @@ -148,7 +146,7 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
if err != nil {
return nil, err
}
return managed_key.GenerateManagedKeyCSRBundle(ctx, b, keyId, data, addBasicConstraints, randomSource)
return managed_key.GenerateManagedKeyCSRBundle(ctx, sc.GetPkiManagedView(), keyId, data, addBasicConstraints, randomSource)
}

return certutil.CreateCSRWithKeyGenerator(data, addBasicConstraints, randomSource, existingKeyGeneratorFromBytes(key))
Expand All @@ -157,8 +155,8 @@ func generateCSRBundle(sc *storageContext, input *inputBundle, data *certutil.Cr
return certutil.CreateCSRWithRandomSource(data, addBasicConstraints, randomSource)
}

func parseCABundle(ctx context.Context, b *backend, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return issuing.ParseCABundle(ctx, b, bundle)
func parseCABundle(ctx context.Context, mkv managed_key.PkiManagedKeyView, bundle *certutil.CertBundle) (*certutil.ParsedCertBundle, error) {
return issuing.ParseCABundle(ctx, mkv, bundle)
}

func (sc *storageContext) getKeyTypeAndBitsForRole(data *framework.FieldData) (string, int, error) {
Expand Down Expand Up @@ -190,7 +188,7 @@ func (sc *storageContext) getKeyTypeAndBitsForRole(data *framework.FieldData) (s
return "", 0, errors.New("unable to determine managed key id: " + err.Error())
}

pubKeyManagedKey, err := managed_key.GetManagedKeyPublicKey(sc.Context, sc.Backend, keyId)
pubKeyManagedKey, err := managed_key.GetManagedKeyPublicKey(sc.Context, sc.GetPkiManagedView(), keyId)
if err != nil {
return "", 0, errors.New("failed to lookup public key from managed key: " + err.Error())
}
Expand Down Expand Up @@ -222,7 +220,7 @@ func (sc *storageContext) getExistingPublicKey(data *framework.FieldData) (crypt
if err != nil {
return nil, err
}
return getPublicKey(sc.Context, sc.Backend, key)
return getPublicKey(sc.Context, sc.GetPkiManagedView(), key)
}

func getKeyTypeAndBitsFromPublicKeyForRole(pubKey crypto.PublicKey) (certutil.PrivateKeyType, int, error) {
Expand Down
31 changes: 15 additions & 16 deletions builtin/logical/pki/cert_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ func (sc *storageContext) fetchCAInfo(issuerRef string, usage issuing.IssuerUsag
func (sc *storageContext) fetchCAInfoWithIssuer(issuerRef string, usage issuing.IssuerUsage) (*certutil.CAInfoBundle, issuing.IssuerID, error) {
var issuerId issuing.IssuerID

if sc.Backend.UseLegacyBundleCaStorage() {
if sc.UseLegacyBundleCaStorage() {
// We have not completed the migration so attempt to load the bundle from the legacy location
sc.Backend.Logger().Info("Using legacy CA bundle as PKI migration has not completed.")
sc.Logger().Info("Using legacy CA bundle as PKI migration has not completed.")
issuerId = legacyBundleShimID
} else {
var err error
Expand All @@ -163,7 +163,7 @@ func (sc *storageContext) fetchCAInfoWithIssuer(issuerRef string, usage issuing.
// fetchCAInfoByIssuerId will fetch the CA info, will return an error if no ca info exists for the given issuerId.
// This does support the loading using the legacyBundleShimID
func (sc *storageContext) fetchCAInfoByIssuerId(issuerId issuing.IssuerID, usage issuing.IssuerUsage) (*certutil.CAInfoBundle, error) {
return issuing.FetchCAInfoByIssuerId(sc.Context, sc.Storage, sc.Backend, issuerId, usage)
return issuing.FetchCAInfoByIssuerId(sc.Context, sc.Storage, sc.GetPkiManagedView(), issuerId, usage)
}

func fetchCertBySerialBigInt(sc *storageContext, prefix string, serial *big.Int) (*logical.StorageEntry, error) {
Expand All @@ -190,7 +190,7 @@ func fetchCertBySerial(sc *storageContext, prefix, serial string) (*logical.Stor
legacyPath = "revoked/" + colonSerial
path = "revoked/" + hyphenSerial
case serial == legacyCRLPath || serial == deltaCRLPath || serial == unifiedCRLPath || serial == unifiedDeltaCRLPath:
warnings, err := sc.Backend.CrlBuilder().rebuildIfForced(sc)
warnings, err := sc.CrlBuilder().rebuildIfForced(sc)
if err != nil {
return nil, err
}
Expand All @@ -199,7 +199,7 @@ func fetchCertBySerial(sc *storageContext, prefix, serial string) (*logical.Stor
for index, warning := range warnings {
msg = fmt.Sprintf("%v\n %d. %v", msg, index+1, warning)
}
sc.Backend.Logger().Warn(msg)
sc.Logger().Warn(msg)
}

unified := serial == unifiedCRLPath || serial == unifiedDeltaCRLPath
Expand All @@ -209,7 +209,7 @@ func fetchCertBySerial(sc *storageContext, prefix, serial string) (*logical.Stor
}

if serial == deltaCRLPath || serial == unifiedDeltaCRLPath {
if sc.Backend.UseLegacyBundleCaStorage() {
if sc.UseLegacyBundleCaStorage() {
return nil, fmt.Errorf("refusing to serve delta CRL with legacy CA bundle")
}

Expand Down Expand Up @@ -250,7 +250,7 @@ func fetchCertBySerial(sc *storageContext, prefix, serial string) (*logical.Stor

// Update old-style paths to new-style paths
certEntry.Key = path
certCounter := sc.Backend.GetCertificateCounter()
certCounter := sc.GetCertificateCounter()
certsCounted := certCounter.IsInitialized()
if err = sc.Storage.Put(sc.Context, certEntry); err != nil {
return nil, errutil.InternalError{Err: fmt.Sprintf("error saving certificate with serial %s to new location: %s", serial, err)}
Expand Down Expand Up @@ -327,7 +327,6 @@ func generateCert(sc *storageContext,
randomSource io.Reader) (*certutil.ParsedCertBundle, []string, error,
) {
ctx := sc.Context
b := sc.Backend

if input.role == nil {
return nil, nil, errutil.InternalError{Err: "no role found in data bundle"}
Expand All @@ -337,7 +336,7 @@ func generateCert(sc *storageContext,
return nil, nil, errutil.UserError{Err: "RSA keys < 2048 bits are unsafe and not supported"}
}

data, warnings, err := generateCreationBundle(b, input, caSign, nil)
data, warnings, err := generateCreationBundle(sc.System(), input, caSign, nil)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -402,7 +401,7 @@ func generateCert(sc *storageContext,
// N.B.: This is only meant to be used for generating intermediate CAs.
// It skips some sanity checks.
func generateIntermediateCSR(sc *storageContext, input *inputBundle, randomSource io.Reader) (*certutil.ParsedCSRBundle, []string, error) {
creation, warnings, err := generateCreationBundle(sc.Backend, input, nil, nil)
creation, warnings, err := generateCreationBundle(sc.System(), input, nil, nil)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -468,15 +467,15 @@ func (i SignCertInputFromDataFields) GetPermittedDomains() []string {
return i.data.Get("permitted_dns_domains").([]string)
}

func signCert(b *backend, data *inputBundle, caSign *certutil.CAInfoBundle, isCA bool, useCSRValues bool) (*certutil.ParsedCertBundle, []string, error) {
func signCert(sysView logical.SystemView, data *inputBundle, caSign *certutil.CAInfoBundle, isCA bool, useCSRValues bool) (*certutil.ParsedCertBundle, []string, error) {
if data.role == nil {
return nil, nil, errutil.InternalError{Err: "no role found in data bundle"}
}

entityInfo := issuing.NewEntityInfoFromReq(data.req)
signCertInput := NewSignCertInputFromDataFields(data.apiData, isCA, useCSRValues)

return issuing.SignCert(b.System(), data.role, entityInfo, caSign, signCertInput)
return issuing.SignCert(sysView, data.role, entityInfo, caSign, signCertInput)
}

func getOtherSANsFromX509Extensions(exts []pkix.Extension) ([]certutil.OtherNameUtf8, error) {
Expand Down Expand Up @@ -542,18 +541,18 @@ func (cb CreationBundleInputFromFieldData) GetUserIds() []string {
// generateCreationBundle is a shared function that reads parameters supplied
// from the various endpoints and generates a CreationParameters with the
// parameters that can be used to issue or sign
func generateCreationBundle(b *backend, data *inputBundle, caSign *certutil.CAInfoBundle, csr *x509.CertificateRequest) (*certutil.CreationBundle, []string, error) {
func generateCreationBundle(sysView logical.SystemView, data *inputBundle, caSign *certutil.CAInfoBundle, csr *x509.CertificateRequest) (*certutil.CreationBundle, []string, error) {
entityInfo := issuing.NewEntityInfoFromReq(data.req)
creationBundleInput := NewCreationBundleInputFromFieldData(data.apiData)

return issuing.GenerateCreationBundle(b.System(), data.role, entityInfo, creationBundleInput, caSign, csr)
return issuing.GenerateCreationBundle(sysView, data.role, entityInfo, creationBundleInput, caSign, csr)
}

// getCertificateNotAfter compute a certificate's NotAfter date based on the mount ttl, role, signing bundle and input
// api data being sent. Returns a NotAfter time, a set of warnings or an error.
func getCertificateNotAfter(b *backend, data *inputBundle, caSign *certutil.CAInfoBundle) (time.Time, []string, error) {
func getCertificateNotAfter(sysView logical.SystemView, data *inputBundle, caSign *certutil.CAInfoBundle) (time.Time, []string, error) {
input := NewCertNotAfterInputFromFieldData(data.apiData)
return issuing.GetCertificateNotAfter(b.System(), data.role, input, caSign)
return issuing.GetCertificateNotAfter(sysView, data.role, input, caSign)
}

// applyIssuerLeafNotAfterBehavior resets a certificate's notAfter time or errors out based on the
Expand Down
4 changes: 2 additions & 2 deletions builtin/logical/pki/cert_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func TestPki_MultipleOUs(t *testing.T) {
OU: []string{"Z", "E", "V"},
},
}
cb, _, err := generateCreationBundle(b, input, nil, nil)
cb, _, err := generateCreationBundle(b.System(), input, nil, nil)
if err != nil {
t.Fatalf("Error: %v", err)
}
Expand Down Expand Up @@ -245,7 +245,7 @@ func TestPki_PermitFQDNs(t *testing.T) {
name := name
testCase := testCase
t.Run(name, func(t *testing.T) {
cb, _, err := generateCreationBundle(b, testCase.input, nil, nil)
cb, _, err := generateCreationBundle(b.System(), testCase.input, nil, nil)
if err != nil {
t.Fatalf("Error: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/pki/cieps_util_oss.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ func issueAcmeCertUsingCieps(_ *backend, _ *acmeContext, _ *logical.Request, _ *
return nil, "", fmt.Errorf("cieps is an enterprise only feature")
}

func getCiepsAcmeSettings(sc *storageContext, opts acmeWrapperOpts, config *acmeConfigEntry, data *framework.FieldData) (bool, string, error) {
func getCiepsAcmeSettings(b *backend, sc *storageContext, opts acmeWrapperOpts, config *acmeConfigEntry, data *framework.FieldData) (bool, string, error) {
return false, "", nil
}
Loading

0 comments on commit 8fd63b0

Please sign in to comment.