Skip to content

Commit

Permalink
Enable access token caching (#66)
Browse files Browse the repository at this point in the history
* enable access token caching

* addressing PR comments

* fix

* convenient test

* addressing PR comments

* changelog update

* fix changelog version & using userProfile enum to get homedir
  • Loading branch information
NRHelmi authored Apr 11, 2023
1 parent 1787e81 commit 79dad30
Show file tree
Hide file tree
Showing 12 changed files with 304 additions and 54 deletions.
1 change: 1 addition & 0 deletions .github/workflows/dotnet-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jobs:

- name: Dotnet linter
run: |
mkdir ~/.rai
dotnet tool install -g dotnet-format
dotnet format . --check
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

# v0.9.13-alpha
* Add `IAccessTokenHandler` interface to implement custom access Auth0 token handlers.
* `Client` class consturctor extended to accept custom access token handlers implementations.
* Add `DefaultAccessTokenHandler` to cache Auth0 tokens locally.

# v0.9.12-alpha
* Add support to v2 `CreateDatabase`.
# v0.9.11-alpha
Expand Down
49 changes: 49 additions & 0 deletions RelationalAI.Test/DefaultAccessTokenHandlerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using System;
using System.IO;
using System.Threading.Tasks;
using FluentAssertions;
using Xunit;
using Xunit.Abstractions;

namespace RelationalAI.Test
{
[Collection("RelationalAI.Test")]
public class DefaultAccessTokenHandlerTests : UnitTest
{
private string TestCachePath()
{
var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);

return Path.Join(home, ".rai", "test_tokens.json");
}

public DefaultAccessTokenHandlerTests(ITestOutputHelper output) : base(output)
{ }

[Fact]
public async Task DefaultAccessTokenHandlerTestAsync()
{
var ctx = CreateContext(GetConfig());
var creds = ctx.Credentials as ClientCredentials;
var rest = new Rest(ctx);

// should generate token if cache path doesn't exist
var accessTokenHandler = new DefaultAccessTokenHandler(rest, "/fake/path/tokens.json", _logger);
var token = await accessTokenHandler.GetAccessTokenAsync(creds);
token.Should().NotBeNull();
token.Should().BeEquivalentTo(creds.AccessToken);

// should generate and cache token if the path exists
accessTokenHandler = new DefaultAccessTokenHandler(rest, TestCachePath(), _logger);
token = await accessTokenHandler.GetAccessTokenAsync(creds);
token.Should().NotBeNull();
token.Should().BeEquivalentTo(creds.AccessToken);

accessTokenHandler = new DefaultAccessTokenHandler(null, TestCachePath(), _logger);
var cachedToken = await accessTokenHandler.GetAccessTokenAsync(creds);
token.Should().NotBeNull();
token.Should().BeEquivalentTo(cachedToken);
token.Should().BeEquivalentTo(creds.AccessToken);
}
}
}
18 changes: 14 additions & 4 deletions RelationalAI.Test/UnitTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace RelationalAI.Test
public class UnitTest : IAsyncLifetime
{
private static readonly RAILog4NetProvider _loggerProvider = new RAILog4NetProvider();
private readonly ILogger _logger;
protected readonly ILogger _logger;

public UnitTest()
{ }
Expand All @@ -25,7 +25,7 @@ public UnitTest(ITestOutputHelper testOutputHelper)
_logger = _loggerProvider.CreateLogger("RAI-SDK");
}

public Client CreateClient()
public Dictionary<string, object> GetConfig()
{
Dictionary<string, object> config;

Expand Down Expand Up @@ -57,9 +57,19 @@ public Client CreateClient()
config = Config.Read(new MemoryStream(Encoding.UTF8.GetBytes(configStr)));
}

var customHeaders = JsonConvert.DeserializeObject<Dictionary<string, string>>(GetEnvironmentVariable("CUSTOM_HEADERS"));
return config;
}

var ctx = new Client.Context(config);
public Client.Context CreateContext(Dictionary<string, object> config)
{
return new Client.Context(config);
}

public Client CreateClient()
{
var customHeaders = JsonConvert.DeserializeObject<Dictionary<string, string>>(GetEnvironmentVariable("CUSTOM_HEADERS"));
var config = GetConfig();
var ctx = CreateContext(config);

Client testClient;
if (_logger != null)
Expand Down
30 changes: 22 additions & 8 deletions RelationalAI/AccessToken.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,43 @@
*/

using System;
using Newtonsoft.Json;

namespace RelationalAI
{
public class AccessToken
public class AccessToken : Entity
{
private readonly DateTime _createdOn;
[JsonProperty("expires_in")]
private int _expiresIn;

public AccessToken(string token, int expiresIn)
public AccessToken(string token, long createdOn, int expiresIn, string scope)
{
Token = token;
ExpiresIn = expiresIn;
_createdOn = DateTime.Now;
Scope = scope;
CreatedOn = createdOn;
_expiresIn = expiresIn;
}

public bool IsExpired => (DateTime.Now - _createdOn).TotalSeconds >= ExpiresIn - 5; // Anticipate access token expiration by 5 seconds

[JsonProperty("access_token", Required = Required.Always)]
public string Token { get; set; }

[JsonProperty("scope")]
public string Scope { get; set; }

[JsonProperty("created_on")]
public long CreatedOn { get; set; }

public int ExpiresIn
{
get => _expiresIn;
set => _expiresIn = value > 0 ? value : throw new ArgumentException("ExpiresIn should be greater than 0 ");
}

public double ExpiresOn
{
get => CreatedOn + _expiresIn;
}

public bool IsExpired => new DateTimeOffset(DateTime.Now).ToUnixTimeSeconds() >= ExpiresOn - 5; // Anticipate access token expiration by 5 seconds
}
}
}
3 changes: 2 additions & 1 deletion RelationalAI/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ public class Client
private readonly Context _context;
private readonly ILogger _logger;

public Client(Context context, ILogger logger = null)
public Client(Context context, ILogger logger = null, IAccessTokenHandler accessTokenHandler = null)
{
_context = context;
_logger = logger ?? new LoggerFactory().CreateLogger("RAI-SDK");
_rest = new Rest(context, _logger);
_rest.AccessTokenHandler = accessTokenHandler ?? new DefaultAccessTokenHandler(_rest, logger: _logger);
}

public HttpClient HttpClient
Expand Down
13 changes: 11 additions & 2 deletions RelationalAI/ClientCredentials.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public class ClientCredentials : ICredentials
private const string DefaultClientCredentialsUrl = "https://login.relationalai.com/oauth/token";
private string _clientId;
private string _clientSecret;
private string _audience;
private string _clientCredentialsUrl = DefaultClientCredentialsUrl;

public ClientCredentials(string clientId, string clientSecret)
Expand All @@ -31,10 +32,11 @@ public ClientCredentials(string clientId, string clientSecret)
ClientSecret = clientSecret;
}

public ClientCredentials(string clientId, string clientSecret, string clientCredentialsUrl)
public ClientCredentials(string clientId, string clientSecret, string clientCredentialsUrl, string audience)
: this(clientId, clientSecret)
{
ClientCredentialsUrl = clientCredentialsUrl;
Audience = audience;
}

public string ClientId
Expand All @@ -57,6 +59,13 @@ public string ClientCredentialsUrl
set => _clientCredentialsUrl = !string.IsNullOrEmpty(value) ? value : DefaultClientCredentialsUrl;
}

public string Audience
{
get => _audience;
set => _audience = !string.IsNullOrEmpty(value) ? value :
throw new ArgumentException("Audience cannot be null or empty");
}

public AccessToken AccessToken { get; set; }
}
}
}
9 changes: 8 additions & 1 deletion RelationalAI/Config.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,16 @@ private static ICredentials ReadClientCredentials(IniData data, string profile)
var clientId = GetIniValue(data, profile, "client_id", null);
var clientSecret = GetIniValue(data, profile, "client_secret", null);
var clientCredentialsUrl = GetIniValue(data, profile, "client_credentials_url", null);
var audience = GetIniValue(data, profile, "audience", null);
if (clientId != null && clientSecret != null)
{
return new ClientCredentials(clientId, clientSecret, clientCredentialsUrl);
if (string.IsNullOrEmpty(audience))
{
var host = GetIniValue(data, profile, "host", null);
audience = $"https://{host}";
}

return new ClientCredentials(clientId, clientSecret, clientCredentialsUrl, audience);
}

return null;
Expand Down
128 changes: 128 additions & 0 deletions RelationalAI/DefaultAccessTokenHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2022 RelationalAI, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;

namespace RelationalAI
{
// This handler caches tokens in ~/.rai/tokens.json. It will attempt to load
// a token from the cache file and if it is not found or has expired, it will
// delegate to rest.RequestAccessTokenAsync to retrieve a new token and will save it
// in the cache file.
public class DefaultAccessTokenHandler : IAccessTokenHandler
{
private static readonly SemaphoreSlim SemaphoreSlim = new SemaphoreSlim(1);
private readonly Rest _rest;
private readonly ILogger _logger;
private readonly string _cachePath;

public DefaultAccessTokenHandler(Rest rest, string cachePath = null, ILogger logger = null)
{
_rest = rest;
_cachePath = cachePath ?? DefaultCachePath();
_logger = logger ?? new LoggerFactory().CreateLogger("RAI-SDK");
}

public async Task<AccessToken> GetAccessTokenAsync(ClientCredentials creds)
{
var token = ReadAccessToken(creds);
if (token != null && !token.IsExpired)
{
creds.AccessToken = token;
return creds.AccessToken;
}

token = await _rest.RequestAccessTokenAsync(creds);
if (token != null)
{
creds.AccessToken = token;
await WriteAccessTokenAsync(creds, token);
}

return creds.AccessToken;
}

private string DefaultCachePath()
{
var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);

return Path.Join(home, ".rai", "tokens.json");
}

private AccessToken ReadAccessToken(ClientCredentials creds)
{
var cache = ReadTokenCache();

if (cache.ContainsKey(creds.ClientId))
{
return cache[creds.ClientId];
}

return null;
}

private Dictionary<string, AccessToken> ReadTokenCache()
{
try
{
var data = File.ReadAllText(_cachePath);
var cache = JsonConvert.DeserializeObject<Dictionary<string, AccessToken>>(data);
return cache;
}
catch (IOException ex)
{
_logger.LogInformation($"Unable to read from local cache, fallback to memory-based cache. {ex.Message}");
}

return new Dictionary<string, AccessToken>();
}

private async Task WriteAccessTokenAsync(ClientCredentials creds, AccessToken token)
{
try
{
await SemaphoreSlim.WaitAsync();

var dict = ReadTokenCache();

if (dict.ContainsKey(creds.ClientId))
{
dict[creds.ClientId] = token;
}
else
{
dict.Add(creds.ClientId, token);
}

File.WriteAllText(_cachePath, JsonConvert.SerializeObject(dict));
}
catch (IOException ex)
{
_logger.LogWarning($"Unable to write to local cache. {ex.Message}");
}
finally
{
SemaphoreSlim.Release();
}
}
}
}
29 changes: 29 additions & 0 deletions RelationalAI/IAccessTokenHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2022 RelationalAI, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

namespace RelationalAI
{
// IAccessTokenHandler is a contract interface
// for custom access token handlers implementation.
// Check DefaultAccessToken for an example
public interface IAccessTokenHandler
{
Task<AccessToken> GetAccessTokenAsync(ClientCredentials creds);
}
}
2 changes: 1 addition & 1 deletion RelationalAI/RelationalAI.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<Version>0.9.12-alpha</Version>
<Version>0.9.13-alpha</Version>
<TargetFramework>netcoreapp3.1</TargetFramework>
<PackageId>RAI</PackageId>
<Authors>RelationalAI</Authors>
Expand Down
Loading

0 comments on commit 79dad30

Please sign in to comment.