From f3eae28f3ecab8bc2bb60c7ae11d91341505369f Mon Sep 17 00:00:00 2001 From: James Renken Date: Wed, 9 Oct 2024 22:07:55 -0700 Subject: [PATCH] Add missing ModelToPb; add initial tests for contact --- sa/sa.go | 12 ++++++++-- sa/sa_test.go | 64 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/sa/sa.go b/sa/sa.go index b86ca17e4f7..df7d825506f 100644 --- a/sa/sa.go +++ b/sa/sa.go @@ -197,13 +197,17 @@ func (ssa *SQLStorageAuthority) UpdateRegistrationContact(ctx context.Context, r return nil, berrors.InternalServerError("no registration ID '%d' updated with new contact field", req.RegistrationID) } - updatedRegistration, err := selectRegistration(ctx, tx, "id", req.RegistrationID) + updatedRegistrationModel, err := selectRegistration(ctx, tx, "id", req.RegistrationID) if err != nil { if db.IsNoRows(err) { return nil, berrors.NotFoundError("registration with ID '%d' not found", req.RegistrationID) } return nil, err } + updatedRegistration, err := registrationModelToPb(updatedRegistrationModel) + if err != nil { + return nil, err + } return updatedRegistration, nil }) @@ -258,13 +262,17 @@ func (ssa *SQLStorageAuthority) UpdateRegistrationKey(ctx context.Context, req * return nil, berrors.InternalServerError("no registration ID '%d' updated with new jwk", req.RegistrationID) } - updatedRegistration, err := selectRegistration(ctx, tx, "id", req.RegistrationID) + updatedRegistrationModel, err := selectRegistration(ctx, tx, "id", req.RegistrationID) if err != nil { if db.IsNoRows(err) { return nil, berrors.NotFoundError("registration with ID '%d' not found", req.RegistrationID) } return nil, err } + updatedRegistration, err := registrationModelToPb(updatedRegistrationModel) + if err != nil { + return nil, err + } return updatedRegistration, nil }) diff --git a/sa/sa_test.go b/sa/sa_test.go index 74682d88109..fcee20bea5d 100644 --- a/sa/sa_test.go +++ b/sa/sa_test.go @@ -3,6 +3,8 @@ package sa import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/sha256" @@ -4858,3 +4860,65 @@ func TestGetPausedIdentifiersOnlyUnpausesOneAccount(t *testing.T) { test.AssertEquals(t, len(identifiers.Identifiers), 1) test.AssertEquals(t, identifiers.Identifiers[0].Value, "example.net") } + +func newAcctKey(t *testing.T) []byte { + key, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + jwk := &jose.JSONWebKey{Key: key.Public()} + acctKey, err := jwk.MarshalJSON() + test.AssertNotError(t, err, "failed to marshal account key") + return acctKey +} + +func TestUpdateRegistrationContact(t *testing.T) { + sa, _, cleanUp := initSA(t) + defer cleanUp() + + noContact, _ := json.Marshal("") + exampleContact, _ := json.Marshal("test@example.com") + + tests := []struct { + name string + oldContactsJSON []string + newContacts []string + expectedError error + }{ + { + name: "update a valid registration from no contacts to one email address", + oldContactsJSON: []string{string(noContact)}, + newContacts: []string{"mailto:test@example.com"}, + expectedError: nil, + }, + { + name: "update a valid registration from no contacts to two email addresses", + oldContactsJSON: []string{string(noContact)}, + newContacts: []string{"mailto:test1@example.com", "mailto:test2@example.com"}, + expectedError: nil, + }, + { + name: "update a valid registration from one email address to no contacts", + oldContactsJSON: []string{string(exampleContact)}, + newContacts: []string{}, + expectedError: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + initialIP, _ := net.ParseIP("43.34.43.34").MarshalText() + + reg, err := sa.NewRegistration(ctx, &corepb.Registration{ + Contact: tt.oldContactsJSON, + Key: newAcctKey(t), + InitialIP: initialIP, + }) + test.AssertNotError(t, err, "creating new registration") + + reg, err = sa.UpdateRegistrationContact(ctx, &sapb.UpdateRegistrationContactRequest{ + RegistrationID: reg.Id, + Contacts: tt.newContacts, + }) + test.AssertNotError(t, err, "Unexpected error for UpdateRegistrationContact()") + + test.AssertDeepEquals(t, reg.Contact, tt.newContacts) + }) + } +}