-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* enable access token caching * addressing PR comments * fix * convenient test * addressing PR comments * changelog update * fix changelog version & using userProfile enum to get homedir
- Loading branch information
Showing
12 changed files
with
304 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
using System; | ||
using System.IO; | ||
using System.Threading.Tasks; | ||
using FluentAssertions; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace RelationalAI.Test | ||
{ | ||
[Collection("RelationalAI.Test")] | ||
public class DefaultAccessTokenHandlerTests : UnitTest | ||
{ | ||
private string TestCachePath() | ||
{ | ||
var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); | ||
|
||
return Path.Join(home, ".rai", "test_tokens.json"); | ||
} | ||
|
||
public DefaultAccessTokenHandlerTests(ITestOutputHelper output) : base(output) | ||
{ } | ||
|
||
[Fact] | ||
public async Task DefaultAccessTokenHandlerTestAsync() | ||
{ | ||
var ctx = CreateContext(GetConfig()); | ||
var creds = ctx.Credentials as ClientCredentials; | ||
var rest = new Rest(ctx); | ||
|
||
// should generate token if cache path doesn't exist | ||
var accessTokenHandler = new DefaultAccessTokenHandler(rest, "/fake/path/tokens.json", _logger); | ||
var token = await accessTokenHandler.GetAccessTokenAsync(creds); | ||
token.Should().NotBeNull(); | ||
token.Should().BeEquivalentTo(creds.AccessToken); | ||
|
||
// should generate and cache token if the path exists | ||
accessTokenHandler = new DefaultAccessTokenHandler(rest, TestCachePath(), _logger); | ||
token = await accessTokenHandler.GetAccessTokenAsync(creds); | ||
token.Should().NotBeNull(); | ||
token.Should().BeEquivalentTo(creds.AccessToken); | ||
|
||
accessTokenHandler = new DefaultAccessTokenHandler(null, TestCachePath(), _logger); | ||
var cachedToken = await accessTokenHandler.GetAccessTokenAsync(creds); | ||
token.Should().NotBeNull(); | ||
token.Should().BeEquivalentTo(cachedToken); | ||
token.Should().BeEquivalentTo(creds.AccessToken); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
/* | ||
* Copyright 2022 RelationalAI, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.Logging; | ||
using Newtonsoft.Json; | ||
|
||
namespace RelationalAI | ||
{ | ||
// This handler caches tokens in ~/.rai/tokens.json. It will attempt to load | ||
// a token from the cache file and if it is not found or has expired, it will | ||
// delegate to rest.RequestAccessTokenAsync to retrieve a new token and will save it | ||
// in the cache file. | ||
public class DefaultAccessTokenHandler : IAccessTokenHandler | ||
{ | ||
private static readonly SemaphoreSlim SemaphoreSlim = new SemaphoreSlim(1); | ||
private readonly Rest _rest; | ||
private readonly ILogger _logger; | ||
private readonly string _cachePath; | ||
|
||
public DefaultAccessTokenHandler(Rest rest, string cachePath = null, ILogger logger = null) | ||
{ | ||
_rest = rest; | ||
_cachePath = cachePath ?? DefaultCachePath(); | ||
_logger = logger ?? new LoggerFactory().CreateLogger("RAI-SDK"); | ||
} | ||
|
||
public async Task<AccessToken> GetAccessTokenAsync(ClientCredentials creds) | ||
{ | ||
var token = ReadAccessToken(creds); | ||
if (token != null && !token.IsExpired) | ||
{ | ||
creds.AccessToken = token; | ||
return creds.AccessToken; | ||
} | ||
|
||
token = await _rest.RequestAccessTokenAsync(creds); | ||
if (token != null) | ||
{ | ||
creds.AccessToken = token; | ||
await WriteAccessTokenAsync(creds, token); | ||
} | ||
|
||
return creds.AccessToken; | ||
} | ||
|
||
private string DefaultCachePath() | ||
{ | ||
var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); | ||
|
||
return Path.Join(home, ".rai", "tokens.json"); | ||
} | ||
|
||
private AccessToken ReadAccessToken(ClientCredentials creds) | ||
{ | ||
var cache = ReadTokenCache(); | ||
|
||
if (cache.ContainsKey(creds.ClientId)) | ||
{ | ||
return cache[creds.ClientId]; | ||
} | ||
|
||
return null; | ||
} | ||
|
||
private Dictionary<string, AccessToken> ReadTokenCache() | ||
{ | ||
try | ||
{ | ||
var data = File.ReadAllText(_cachePath); | ||
var cache = JsonConvert.DeserializeObject<Dictionary<string, AccessToken>>(data); | ||
return cache; | ||
} | ||
catch (IOException ex) | ||
{ | ||
_logger.LogInformation($"Unable to read from local cache, fallback to memory-based cache. {ex.Message}"); | ||
} | ||
|
||
return new Dictionary<string, AccessToken>(); | ||
} | ||
|
||
private async Task WriteAccessTokenAsync(ClientCredentials creds, AccessToken token) | ||
{ | ||
try | ||
{ | ||
await SemaphoreSlim.WaitAsync(); | ||
|
||
var dict = ReadTokenCache(); | ||
|
||
if (dict.ContainsKey(creds.ClientId)) | ||
{ | ||
dict[creds.ClientId] = token; | ||
} | ||
else | ||
{ | ||
dict.Add(creds.ClientId, token); | ||
} | ||
|
||
File.WriteAllText(_cachePath, JsonConvert.SerializeObject(dict)); | ||
} | ||
catch (IOException ex) | ||
{ | ||
_logger.LogWarning($"Unable to write to local cache. {ex.Message}"); | ||
} | ||
finally | ||
{ | ||
SemaphoreSlim.Release(); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
/* | ||
* Copyright 2022 RelationalAI, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
using System.Threading.Tasks; | ||
using Microsoft.Extensions.Logging; | ||
|
||
namespace RelationalAI | ||
{ | ||
// IAccessTokenHandler is a contract interface | ||
// for custom access token handlers implementation. | ||
// Check DefaultAccessToken for an example | ||
public interface IAccessTokenHandler | ||
{ | ||
Task<AccessToken> GetAccessTokenAsync(ClientCredentials creds); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.