Skip to content

Commit

Permalink
Fix for #4701 - OnBeforeTokenRequest can change the URI
Browse files Browse the repository at this point in the history
  • Loading branch information
bgavrilMS committed Apr 9, 2024
1 parent 7faa358 commit 894758f
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ internal async Task<T> ExecuteRequestAsync<T>(
{
var requestData = new OnBeforeTokenRequestData(_bodyParameters, _headers, endpointUri, requestContext.UserCancellationToken);
await onBeforePostRequestData(requestData).ConfigureAwait(false);
endpointUri = requestData.RequestUri;
}

response = await _httpManager.SendPostAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
using System.Threading.Tasks;
using Microsoft.Identity.Client;
using Microsoft.Identity.Client.Cache;
using Microsoft.Identity.Client.Extensibility;
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.Internal.ClientCredential;
using Microsoft.Identity.Client.OAuth2;
Expand All @@ -23,6 +22,7 @@
using Microsoft.Identity.Test.Common.Core.Mocks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using NSubstitute;
using Microsoft.Identity.Client.Extensibility;

namespace Microsoft.Identity.Test.Unit.PublicApiTests
{
Expand Down Expand Up @@ -491,77 +491,6 @@ private enum CredentialType
return (app, handler);
}

private static Task ModifyRequestAsync(OnBeforeTokenRequestData requestData)
{
Assert.AreEqual("https://login.microsoftonline.com/tid/oauth2/v2.0/token", requestData.RequestUri.AbsoluteUri);
requestData.BodyParameters.Add("param1", "val1");
requestData.BodyParameters.Add("param2", "val2");

requestData.Headers.Add("header1", "hval1");
requestData.Headers.Add("header2", "hval2");

return Task.CompletedTask;
}

[TestMethod]
public async Task CertificateOverrideAsync()
{
using (var httpManager = new MockHttpManager())
{
httpManager.AddInstanceDiscoveryMockHandler();

var app = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId)
.WithAuthority("https://login.microsoftonline.com/tid/")
.WithExperimentalFeatures(true)
.WithHttpManager(httpManager)
.Build();

MockHttpMessageHandler handler = httpManager.AddMockHandlerSuccessfulClientCredentialTokenResponseMessage();

var result = await app.AcquireTokenForClient(TestConstants.s_scope.ToArray())
.WithProofOfPosessionKeyId("key1")
.OnBeforeTokenRequest(ModifyRequestAsync)
.ExecuteAsync()
.ConfigureAwait(false);

Assert.AreEqual("Bearer", result.TokenType);

Assert.AreEqual("val1", handler.ActualRequestPostData["param1"]);
Assert.AreEqual("val2", handler.ActualRequestPostData["param2"]);
Assert.AreEqual("hval1", handler.ActualRequestHeaders.GetValues("header1").Single());
Assert.AreEqual("hval2", handler.ActualRequestHeaders.GetValues("header2").Single());
Assert.IsFalse(handler.ActualRequestPostData.ContainsKey(OAuth2Parameter.ClientAssertion));
Assert.IsFalse(handler.ActualRequestPostData.ContainsKey(OAuth2Parameter.ClientAssertionType));
Assert.AreEqual("key1", (app.AppTokenCache as ITokenCacheInternal).Accessor.GetAllAccessTokens().Single().KeyId);

result = await app.AcquireTokenForClient(TestConstants.s_scope.ToArray())
.WithProofOfPosessionKeyId("key1")
.OnBeforeTokenRequest(ModifyRequestAsync)
.ExecuteAsync()
.ConfigureAwait(false);

Assert.AreEqual("Bearer", result.TokenType);
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);

Assert.AreEqual(
"key1",
(app.AppTokenCache as ITokenCacheInternal).Accessor.GetAllAccessTokens().Single().KeyId);

httpManager.AddMockHandlerSuccessfulClientCredentialTokenResponseMessage();
result = await app.AcquireTokenForClient(TestConstants.s_scope.ToArray())
.OnBeforeTokenRequest(ModifyRequestAsync)
.ExecuteAsync()
.ConfigureAwait(false);

Assert.AreEqual("Bearer", result.TokenType);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
IReadOnlyList<Client.Cache.Items.MsalAccessTokenCacheItem> ats = (app.AppTokenCache as ITokenCacheInternal).Accessor.GetAllAccessTokens();
Assert.AreEqual(2, ats.Count);
Assert.IsTrue(ats.Single(at => at.KeyId == "key1") != null);
Assert.IsTrue(ats.Single(at => at.KeyId == null) != null);
}
}

[TestMethod]
public async Task ConfidentialClientUsingCertificateTestAsync()
{
Expand Down Expand Up @@ -1860,108 +1789,6 @@ public async Task ValidateGetAccountAsyncWithNullEmptyAccountIdAsync(string acco

Assert.IsNull(acc);
}
}

[TestMethod]
public async Task ValidateAppTokenProviderAsync()
{
using (var harness = base.CreateTestHarness())
{
harness.HttpManager.AddInstanceDiscoveryMockHandler();

bool usingClaims = false;
string differentScopesForAt = string.Empty;
int callbackInvoked = 0;
var app = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId)
.WithAppTokenProvider((AppTokenProviderParameters parameters) =>
{
Assert.IsNotNull(parameters.Scopes);
Assert.IsNotNull(parameters.CorrelationId);
Assert.IsNotNull(parameters.TenantId);
Assert.IsNotNull(parameters.CancellationToken);
if (usingClaims)
{
Assert.IsNotNull(parameters.Claims);
}
Interlocked.Increment(ref callbackInvoked);
return Task.FromResult(GetAppTokenProviderResult(differentScopesForAt));
})
.WithHttpManager(harness.HttpManager)
.BuildConcrete();

// AcquireToken from app provider
AuthenticationResult result = await app.AcquireTokenForClient(TestConstants.s_scope)
.ExecuteAsync(new CancellationToken()).ConfigureAwait(false);

Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TestConstants.DefaultAccessToken, result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
Assert.AreEqual(1, callbackInvoked);

var tokens = app.AppTokenCacheInternal.Accessor.GetAllAccessTokens();

Assert.AreEqual(1, tokens.Count);

var token = tokens.FirstOrDefault();
Assert.IsNotNull(token);
Assert.AreEqual(TestConstants.DefaultAccessToken, token.Secret);

// AcquireToken from cache
result = await app.AcquireTokenForClient(TestConstants.s_scope)
.ExecuteAsync(new CancellationToken()).ConfigureAwait(false);

Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TestConstants.DefaultAccessToken, result.AccessToken);
Assert.AreEqual(TokenSource.Cache, result.AuthenticationResultMetadata.TokenSource);
Assert.AreEqual(1, callbackInvoked);

// Expire token
TokenCacheHelper.ExpireAllAccessTokens(app.AppTokenCacheInternal);

// Acquire token from app provider with expired token
result = await app.AcquireTokenForClient(TestConstants.s_scope)
.ExecuteAsync(new CancellationToken()).ConfigureAwait(false);

Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TestConstants.DefaultAccessToken, result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
Assert.AreEqual(2, callbackInvoked);

differentScopesForAt = "new scope";

// Acquire token from app provider with new scopes
result = await app.AcquireTokenForClient(new[] { differentScopesForAt })
.ExecuteAsync(new CancellationToken()).ConfigureAwait(false);

Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TestConstants.DefaultAccessToken + differentScopesForAt, result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
Assert.AreEqual(app.AppTokenCacheInternal.Accessor.GetAllAccessTokens().Count, 2);
Assert.AreEqual(3, callbackInvoked);

// Acquire token from app provider with claims. Should not use cache
result = await app.AcquireTokenForClient(TestConstants.s_scope)
.WithClaims(TestConstants.Claims)
.ExecuteAsync(new CancellationToken()).ConfigureAwait(false);

Assert.IsNotNull(result.AccessToken);
Assert.AreEqual(TestConstants.DefaultAccessToken + differentScopesForAt, result.AccessToken);
Assert.AreEqual(TokenSource.IdentityProvider, result.AuthenticationResultMetadata.TokenSource);
Assert.AreEqual(4, callbackInvoked);
}
}

private AppTokenProviderResult GetAppTokenProviderResult(string differentScopesForAt = "", long? refreshIn = 1000)
{
var token = new AppTokenProviderResult();
token.AccessToken = TestConstants.DefaultAccessToken + differentScopesForAt; //Used to indicate that there is a new access token for a different set of scopes
token.ExpiresInSeconds = 3600;
token.RefreshInSeconds = refreshIn;

return token;
}
}
}
}
Loading

0 comments on commit 894758f

Please sign in to comment.