From 79dad30f168330cd5e9658b434e3b06e12b5ec4d Mon Sep 17 00:00:00 2001 From: Helmi Nour Date: Wed, 12 Apr 2023 00:47:24 +0100 Subject: [PATCH] Enable access token caching (#66) * enable access token caching * addressing PR comments * fix * convenient test * addressing PR comments * changelog update * fix changelog version & using userProfile enum to get homedir --- .github/workflows/dotnet-build.yaml | 1 + CHANGELOG.md | 5 + .../DefaultAccessTokenHandlerTests.cs | 49 +++++++ RelationalAI.Test/UnitTest.cs | 18 ++- RelationalAI/AccessToken.cs | 30 ++-- RelationalAI/Client.cs | 3 +- RelationalAI/ClientCredentials.cs | 13 +- RelationalAI/Config.cs | 9 +- RelationalAI/DefaultAccessTokenHandler.cs | 128 ++++++++++++++++++ RelationalAI/IAccessTokenHandler.cs | 29 ++++ RelationalAI/RelationalAI.csproj | 2 +- RelationalAI/Rest.cs | 71 +++++----- 12 files changed, 304 insertions(+), 54 deletions(-) create mode 100644 RelationalAI.Test/DefaultAccessTokenHandlerTests.cs create mode 100644 RelationalAI/DefaultAccessTokenHandler.cs create mode 100644 RelationalAI/IAccessTokenHandler.cs diff --git a/.github/workflows/dotnet-build.yaml b/.github/workflows/dotnet-build.yaml index a53184a..04e163d 100644 --- a/.github/workflows/dotnet-build.yaml +++ b/.github/workflows/dotnet-build.yaml @@ -19,6 +19,7 @@ jobs: - name: Dotnet linter run: | + mkdir ~/.rai dotnet tool install -g dotnet-format dotnet format . --check diff --git a/CHANGELOG.md b/CHANGELOG.md index 62000f2..9979d8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +# v0.9.13-alpha +* Add `IAccessTokenHandler` interface to implement custom access Auth0 token handlers. +* `Client` class consturctor extended to accept custom access token handlers implementations. +* Add `DefaultAccessTokenHandler` to cache Auth0 tokens locally. + # v0.9.12-alpha * Add support to v2 `CreateDatabase`. # v0.9.11-alpha diff --git a/RelationalAI.Test/DefaultAccessTokenHandlerTests.cs b/RelationalAI.Test/DefaultAccessTokenHandlerTests.cs new file mode 100644 index 0000000..4ba0517 --- /dev/null +++ b/RelationalAI.Test/DefaultAccessTokenHandlerTests.cs @@ -0,0 +1,49 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using FluentAssertions; +using Xunit; +using Xunit.Abstractions; + +namespace RelationalAI.Test +{ + [Collection("RelationalAI.Test")] + public class DefaultAccessTokenHandlerTests : UnitTest + { + private string TestCachePath() + { + var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + + return Path.Join(home, ".rai", "test_tokens.json"); + } + + public DefaultAccessTokenHandlerTests(ITestOutputHelper output) : base(output) + { } + + [Fact] + public async Task DefaultAccessTokenHandlerTestAsync() + { + var ctx = CreateContext(GetConfig()); + var creds = ctx.Credentials as ClientCredentials; + var rest = new Rest(ctx); + + // should generate token if cache path doesn't exist + var accessTokenHandler = new DefaultAccessTokenHandler(rest, "/fake/path/tokens.json", _logger); + var token = await accessTokenHandler.GetAccessTokenAsync(creds); + token.Should().NotBeNull(); + token.Should().BeEquivalentTo(creds.AccessToken); + + // should generate and cache token if the path exists + accessTokenHandler = new DefaultAccessTokenHandler(rest, TestCachePath(), _logger); + token = await accessTokenHandler.GetAccessTokenAsync(creds); + token.Should().NotBeNull(); + token.Should().BeEquivalentTo(creds.AccessToken); + + accessTokenHandler = new DefaultAccessTokenHandler(null, TestCachePath(), _logger); + var cachedToken = await accessTokenHandler.GetAccessTokenAsync(creds); + token.Should().NotBeNull(); + token.Should().BeEquivalentTo(cachedToken); + token.Should().BeEquivalentTo(creds.AccessToken); + } + } +} \ No newline at end of file diff --git a/RelationalAI.Test/UnitTest.cs b/RelationalAI.Test/UnitTest.cs index 7401c95..a2c21d7 100644 --- a/RelationalAI.Test/UnitTest.cs +++ b/RelationalAI.Test/UnitTest.cs @@ -14,7 +14,7 @@ namespace RelationalAI.Test public class UnitTest : IAsyncLifetime { private static readonly RAILog4NetProvider _loggerProvider = new RAILog4NetProvider(); - private readonly ILogger _logger; + protected readonly ILogger _logger; public UnitTest() { } @@ -25,7 +25,7 @@ public UnitTest(ITestOutputHelper testOutputHelper) _logger = _loggerProvider.CreateLogger("RAI-SDK"); } - public Client CreateClient() + public Dictionary GetConfig() { Dictionary config; @@ -57,9 +57,19 @@ public Client CreateClient() config = Config.Read(new MemoryStream(Encoding.UTF8.GetBytes(configStr))); } - var customHeaders = JsonConvert.DeserializeObject>(GetEnvironmentVariable("CUSTOM_HEADERS")); + return config; + } - var ctx = new Client.Context(config); + public Client.Context CreateContext(Dictionary config) + { + return new Client.Context(config); + } + + public Client CreateClient() + { + var customHeaders = JsonConvert.DeserializeObject>(GetEnvironmentVariable("CUSTOM_HEADERS")); + var config = GetConfig(); + var ctx = CreateContext(config); Client testClient; if (_logger != null) diff --git a/RelationalAI/AccessToken.cs b/RelationalAI/AccessToken.cs index f40c524..0264a31 100644 --- a/RelationalAI/AccessToken.cs +++ b/RelationalAI/AccessToken.cs @@ -15,29 +15,43 @@ */ using System; +using Newtonsoft.Json; namespace RelationalAI { - public class AccessToken + public class AccessToken : Entity { - private readonly DateTime _createdOn; + [JsonProperty("expires_in")] private int _expiresIn; - public AccessToken(string token, int expiresIn) + public AccessToken(string token, long createdOn, int expiresIn, string scope) { Token = token; - ExpiresIn = expiresIn; - _createdOn = DateTime.Now; + Scope = scope; + CreatedOn = createdOn; + _expiresIn = expiresIn; } - public bool IsExpired => (DateTime.Now - _createdOn).TotalSeconds >= ExpiresIn - 5; // Anticipate access token expiration by 5 seconds - + [JsonProperty("access_token", Required = Required.Always)] public string Token { get; set; } + [JsonProperty("scope")] + public string Scope { get; set; } + + [JsonProperty("created_on")] + public long CreatedOn { get; set; } + public int ExpiresIn { get => _expiresIn; set => _expiresIn = value > 0 ? value : throw new ArgumentException("ExpiresIn should be greater than 0 "); } + + public double ExpiresOn + { + get => CreatedOn + _expiresIn; + } + + public bool IsExpired => new DateTimeOffset(DateTime.Now).ToUnixTimeSeconds() >= ExpiresOn - 5; // Anticipate access token expiration by 5 seconds } -} \ No newline at end of file +} diff --git a/RelationalAI/Client.cs b/RelationalAI/Client.cs index 7fd1320..ee0d042 100644 --- a/RelationalAI/Client.cs +++ b/RelationalAI/Client.cs @@ -40,11 +40,12 @@ public class Client private readonly Context _context; private readonly ILogger _logger; - public Client(Context context, ILogger logger = null) + public Client(Context context, ILogger logger = null, IAccessTokenHandler accessTokenHandler = null) { _context = context; _logger = logger ?? new LoggerFactory().CreateLogger("RAI-SDK"); _rest = new Rest(context, _logger); + _rest.AccessTokenHandler = accessTokenHandler ?? new DefaultAccessTokenHandler(_rest, logger: _logger); } public HttpClient HttpClient diff --git a/RelationalAI/ClientCredentials.cs b/RelationalAI/ClientCredentials.cs index 3c35804..5903e9d 100644 --- a/RelationalAI/ClientCredentials.cs +++ b/RelationalAI/ClientCredentials.cs @@ -23,6 +23,7 @@ public class ClientCredentials : ICredentials private const string DefaultClientCredentialsUrl = "https://login.relationalai.com/oauth/token"; private string _clientId; private string _clientSecret; + private string _audience; private string _clientCredentialsUrl = DefaultClientCredentialsUrl; public ClientCredentials(string clientId, string clientSecret) @@ -31,10 +32,11 @@ public ClientCredentials(string clientId, string clientSecret) ClientSecret = clientSecret; } - public ClientCredentials(string clientId, string clientSecret, string clientCredentialsUrl) + public ClientCredentials(string clientId, string clientSecret, string clientCredentialsUrl, string audience) : this(clientId, clientSecret) { ClientCredentialsUrl = clientCredentialsUrl; + Audience = audience; } public string ClientId @@ -57,6 +59,13 @@ public string ClientCredentialsUrl set => _clientCredentialsUrl = !string.IsNullOrEmpty(value) ? value : DefaultClientCredentialsUrl; } + public string Audience + { + get => _audience; + set => _audience = !string.IsNullOrEmpty(value) ? value : + throw new ArgumentException("Audience cannot be null or empty"); + } + public AccessToken AccessToken { get; set; } } -} \ No newline at end of file +} diff --git a/RelationalAI/Config.cs b/RelationalAI/Config.cs index 45eb52c..98b348f 100644 --- a/RelationalAI/Config.cs +++ b/RelationalAI/Config.cs @@ -66,9 +66,16 @@ private static ICredentials ReadClientCredentials(IniData data, string profile) var clientId = GetIniValue(data, profile, "client_id", null); var clientSecret = GetIniValue(data, profile, "client_secret", null); var clientCredentialsUrl = GetIniValue(data, profile, "client_credentials_url", null); + var audience = GetIniValue(data, profile, "audience", null); if (clientId != null && clientSecret != null) { - return new ClientCredentials(clientId, clientSecret, clientCredentialsUrl); + if (string.IsNullOrEmpty(audience)) + { + var host = GetIniValue(data, profile, "host", null); + audience = $"https://{host}"; + } + + return new ClientCredentials(clientId, clientSecret, clientCredentialsUrl, audience); } return null; diff --git a/RelationalAI/DefaultAccessTokenHandler.cs b/RelationalAI/DefaultAccessTokenHandler.cs new file mode 100644 index 0000000..e89900c --- /dev/null +++ b/RelationalAI/DefaultAccessTokenHandler.cs @@ -0,0 +1,128 @@ +/* + * Copyright 2022 RelationalAI, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.IO; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; + +namespace RelationalAI +{ + // This handler caches tokens in ~/.rai/tokens.json. It will attempt to load + // a token from the cache file and if it is not found or has expired, it will + // delegate to rest.RequestAccessTokenAsync to retrieve a new token and will save it + // in the cache file. + public class DefaultAccessTokenHandler : IAccessTokenHandler + { + private static readonly SemaphoreSlim SemaphoreSlim = new SemaphoreSlim(1); + private readonly Rest _rest; + private readonly ILogger _logger; + private readonly string _cachePath; + + public DefaultAccessTokenHandler(Rest rest, string cachePath = null, ILogger logger = null) + { + _rest = rest; + _cachePath = cachePath ?? DefaultCachePath(); + _logger = logger ?? new LoggerFactory().CreateLogger("RAI-SDK"); + } + + public async Task GetAccessTokenAsync(ClientCredentials creds) + { + var token = ReadAccessToken(creds); + if (token != null && !token.IsExpired) + { + creds.AccessToken = token; + return creds.AccessToken; + } + + token = await _rest.RequestAccessTokenAsync(creds); + if (token != null) + { + creds.AccessToken = token; + await WriteAccessTokenAsync(creds, token); + } + + return creds.AccessToken; + } + + private string DefaultCachePath() + { + var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + + return Path.Join(home, ".rai", "tokens.json"); + } + + private AccessToken ReadAccessToken(ClientCredentials creds) + { + var cache = ReadTokenCache(); + + if (cache.ContainsKey(creds.ClientId)) + { + return cache[creds.ClientId]; + } + + return null; + } + + private Dictionary ReadTokenCache() + { + try + { + var data = File.ReadAllText(_cachePath); + var cache = JsonConvert.DeserializeObject>(data); + return cache; + } + catch (IOException ex) + { + _logger.LogInformation($"Unable to read from local cache, fallback to memory-based cache. {ex.Message}"); + } + + return new Dictionary(); + } + + private async Task WriteAccessTokenAsync(ClientCredentials creds, AccessToken token) + { + try + { + await SemaphoreSlim.WaitAsync(); + + var dict = ReadTokenCache(); + + if (dict.ContainsKey(creds.ClientId)) + { + dict[creds.ClientId] = token; + } + else + { + dict.Add(creds.ClientId, token); + } + + File.WriteAllText(_cachePath, JsonConvert.SerializeObject(dict)); + } + catch (IOException ex) + { + _logger.LogWarning($"Unable to write to local cache. {ex.Message}"); + } + finally + { + SemaphoreSlim.Release(); + } + } + } +} diff --git a/RelationalAI/IAccessTokenHandler.cs b/RelationalAI/IAccessTokenHandler.cs new file mode 100644 index 0000000..63bbe99 --- /dev/null +++ b/RelationalAI/IAccessTokenHandler.cs @@ -0,0 +1,29 @@ +/* + * Copyright 2022 RelationalAI, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace RelationalAI +{ + // IAccessTokenHandler is a contract interface + // for custom access token handlers implementation. + // Check DefaultAccessToken for an example + public interface IAccessTokenHandler + { + Task GetAccessTokenAsync(ClientCredentials creds); + } +} diff --git a/RelationalAI/RelationalAI.csproj b/RelationalAI/RelationalAI.csproj index c9f2597..66e2f4b 100644 --- a/RelationalAI/RelationalAI.csproj +++ b/RelationalAI/RelationalAI.csproj @@ -1,7 +1,7 @@  - 0.9.12-alpha + 0.9.13-alpha netcoreapp3.1 RAI RelationalAI diff --git a/RelationalAI/Rest.cs b/RelationalAI/Rest.cs index 8b75d59..a8d8c8b 100644 --- a/RelationalAI/Rest.cs +++ b/RelationalAI/Rest.cs @@ -41,7 +41,7 @@ public class Rest private readonly ILogger _logger; - public Rest(Context context, ILogger logger) + public Rest(Context context, ILogger logger = null) { _context = context; HttpClient = new HttpClient(); @@ -50,6 +50,8 @@ public Rest(Context context, ILogger logger) public HttpClient HttpClient { get; set; } + internal IAccessTokenHandler AccessTokenHandler { get; set; } + public static string EncodeQueryString(Dictionary parameters) { if (parameters == null) @@ -122,7 +124,7 @@ public async Task RequestAsync( } } - var accessToken = await GetAccessTokenAsync(GetHost(url)); + var accessToken = await GetAccessTokenAsync(); caseInsensitiveHeaders.Add("Authorization", $"Bearer {accessToken}"); return await RequestHelperAsync(method, url, data, caseInsensitiveHeaders, parameters); } @@ -163,6 +165,34 @@ public MetadataInfo ReadMetadataProtobuf(byte[] data) return MetadataInfo.Parser.ParseFrom(data); } + public async Task RequestAccessTokenAsync(ClientCredentials creds) + { + // Form the API request body. + var data = new Dictionary + { + { "client_id", creds.ClientId }, + { "client_secret", creds.ClientSecret }, + { "audience", creds.Audience }, + { "grant_type", "client_credentials" } + }; + var resp = await RequestHelperAsync("POST", creds.ClientCredentialsUrl, data); + if (!(resp is string stringResponse)) + { + throw new InvalidResponseException( + $"Unexpected response type, expected a string but received {resp.GetType().Name}", + resp); + } + + var result = JsonConvert.DeserializeObject>(stringResponse); + + if (result == null) + { + throw new InvalidResponseException("Unexpected access token response format", resp); + } + + return new AccessToken(result["access_token"], new DateTimeOffset(DateTime.Now).ToUnixTimeSeconds(), int.Parse(result["expires_in"]), result["scope"]); + } + private static HttpContent EncodeContent(object body) { if (body == null) @@ -179,11 +209,6 @@ private static HttpContent EncodeContent(object body) return new ByteArrayContent(Encoding.UTF8.GetBytes(s)); } - private static string GetHost(string url) - { - return new Uri(url).Host; - } - private static string GetUserAgent() { return $"rai-sdk-csharp/{SdkProperties.Version}"; @@ -272,7 +297,7 @@ private static List ParseMultipartResponse(byte[] content) return output; } - private async Task GetAccessTokenAsync(string host) + private async Task GetAccessTokenAsync() { if (!(_context.Credentials is ClientCredentials creds)) { @@ -281,40 +306,12 @@ private async Task GetAccessTokenAsync(string host) if (creds.AccessToken == null || creds.AccessToken.IsExpired) { - creds.AccessToken = await RequestAccessTokenAsync(host, creds); + creds.AccessToken = await AccessTokenHandler.GetAccessTokenAsync(creds); } return creds.AccessToken.Token; } - private async Task RequestAccessTokenAsync(string host, ClientCredentials creds) - { - // Form the API request body. - var data = new Dictionary - { - { "client_id", creds.ClientId }, - { "client_secret", creds.ClientSecret }, - { "audience", $"https://{host}" }, - { "grant_type", "client_credentials" } - }; - var resp = await RequestHelperAsync("POST", creds.ClientCredentialsUrl, data); - if (!(resp is string stringResponse)) - { - throw new InvalidResponseException( - $"Unexpected response type, expected a string but received {resp.GetType().Name}", - resp); - } - - var result = JsonConvert.DeserializeObject>(stringResponse); - - if (result == null) - { - throw new InvalidResponseException("Unexpected access token response format", resp); - } - - return new AccessToken(result["access_token"], int.Parse(result["expires_in"])); - } - private async Task RequestHelperAsync( string method, string url,