diff --git a/api/email/post.go b/api/email/post.go index 8693f97..43689e4 100644 --- a/api/email/post.go +++ b/api/email/post.go @@ -47,6 +47,8 @@ func HandlePostEmail( sourceEmailPassword := getEnv("SOURCE_EMAIL_PASSWORD") skipTlsVerify := getEnv("TEST_ONLY_SKIP_TLS_VERIFY") == "dummy string just in case" + smtpMessageMutex := sync.Mutex{} + buildMessage := func(email *requestBody) string { return fmt.Sprintf( "To: %s\r\nSubject: %s\r\n\r\n%s\r\n\r\nSent by %s", @@ -102,6 +104,8 @@ func HandlePostEmail( sendEmail := func(request *http.Request, client *smtp.Client, email *requestBody) (err error) { doneChannel := make(chan struct{}) + smtpMessageMutex.Lock() + defer smtpMessageMutex.Unlock() log.Println("[DEBUG] Setting SMTP email sender") go func() { diff --git a/api/email/post_test.go b/api/email/post_test.go index c6f3143..6983fc8 100644 --- a/api/email/post_test.go +++ b/api/email/post_test.go @@ -145,6 +145,33 @@ func TestCancellation(t *testing.T) { assert.Equal(t, expectedErrorRedirectUrl, response.Header.Get("Location")) } +func TestConcurrentRequests(t *testing.T) { + concurrentRequests := 10 + emailsReceived := 0 + smtpHandler := func(_ net.Addr, _ string, _ []string, _ []byte) error { + emailsReceived++ + return nil + } + + smtpServer, smtpServerPort := setupSmtpServer(t, smtpHandler, nil) + defer teardownSmtpServer(smtpServer) + testHttpServer, shutdownWaitGroup, triggerShutdown := setupHttpServer(context.Background(), smtpServerPort) + defer teardownHttpServer(testHttpServer, shutdownWaitGroup, triggerShutdown) + + httpRequestCompleted := make(chan struct{}) + for range concurrentRequests { + go func() { + requestPostEmail(t, testHttpServer.URL) + httpRequestCompleted <- struct{}{} + }() + } + + for range concurrentRequests { + <-httpRequestCompleted + } + assert.Equal(t, concurrentRequests, emailsReceived) +} + func setupSmtpServer(t *testing.T, handler smtpd.Handler, authHandler smtpd.AuthHandler) (*smtpd.Server, int) { smtpServer := newSmtpServer(t, handler, authHandler) smtpListener, smtpServerPort := newSmtpServerListener()