Skip to content

Commit

Permalink
Introduce token caching when using AAD auth to connect to SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
mderriey committed Feb 19, 2021
1 parent 10697f1 commit c3df01f
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk.Web">

<PropertyGroup>
<TargetFramework>net5.0</TargetFramework>
Expand All @@ -11,6 +11,7 @@
<PackageReference Include="Dapper" Version="2.0.78" />
<PackageReference Include="Microsoft.ApplicationInsights.AspNetCore" Version="2.16.0" />
<PackageReference Include="Microsoft.Data.SqlClient" Version="2.1.1" />
<PackageReference Include="Scrutor" Version="3.3.0" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;

namespace AzureIdentityLivestream.Web.Services.Sql
{
public class AzureIdentityAzureSqlTokenProvider : IAzureSqlTokenProvider
{
private static readonly TokenCredential _credential = new ChainedTokenCredential(
new ManagedIdentityCredential(),
new VisualStudioCodeCredential());

private static readonly string[] _azureSqlScopes = new string[] { "https://database.windows.net//.default" };

public async Task<(string Token, DateTimeOffset ExpiresOn)> GetAccessTokenAsync(CancellationToken cancellationToken = default)
{
var tokenRequest = new TokenRequestContext(_azureSqlScopes);
var tokenResult = await _credential.GetTokenAsync(tokenRequest, cancellationToken);

return (tokenResult.Token, tokenResult.ExpiresOn);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Caching.Memory;

namespace AzureIdentityLivestream.Web.Services.Sql
{
public class CacheAzureSqlTokenProvider : IAzureSqlTokenProvider
{
private static readonly string _cacheKey = $"{nameof(CacheAzureSqlTokenProvider)}.{nameof(GetAccessTokenAsync)}";

private readonly IAzureSqlTokenProvider _inner;
private readonly IMemoryCache _cache;

public CacheAzureSqlTokenProvider(IAzureSqlTokenProvider inner, IMemoryCache cache)
{
_inner = inner;
_cache = cache;
}

public async Task<(string Token, DateTimeOffset ExpiresOn)> GetAccessTokenAsync(CancellationToken cancellationToken)
{
return await _cache.GetOrCreateAsync(_cacheKey, async entry =>
{
var (token, expiresOn) = await _inner.GetAccessTokenAsync(cancellationToken);

entry.SetAbsoluteExpiration(expiresOn);

return (token, expiresOn);
});
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
using System;
using System.Threading;
using System.Threading.Tasks;

namespace AzureIdentityLivestream.Web.Services.Sql
{
public interface IAzureSqlTokenProvider
{
Task<(string Token, DateTimeOffset ExpiresOn)> GetAccessTokenAsync(CancellationToken cancellationToken = default);
}
}
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
using System;
using System.Threading.Tasks;
using Azure.Core;
using Azure.Identity;
using Microsoft.Data.SqlClient;

namespace AzureIdentityLivestream.Web.Services.Sql
{
public class SqlConnectionFactory
{
private readonly string _connectionString;
private readonly IAzureSqlTokenProvider _azureSqlTokenProvider;

public SqlConnectionFactory(string connectionString)
public SqlConnectionFactory(string connectionString, IAzureSqlTokenProvider azureSqlTokenProvider)
{
_connectionString = connectionString;
_azureSqlTokenProvider = azureSqlTokenProvider;
}

public async Task<SqlConnection> CreateConnection()
Expand All @@ -23,14 +23,8 @@ public async Task<SqlConnection> CreateConnection()
if (sqlConnection.DataSource.Contains("database.windows.net", StringComparison.OrdinalIgnoreCase) &&
string.IsNullOrEmpty(sqlConnectionStringBuilder.UserID))
{
var credential = new ChainedTokenCredential(
new ManagedIdentityCredential(),
new VisualStudioCodeCredential());

var tokenRequest = new TokenRequestContext(new[] { "https://database.windows.net//.default" });
var tokenResponse = await credential.GetTokenAsync(tokenRequest);

sqlConnection.AccessToken = tokenResponse.Token;
var (token, _) = await _azureSqlTokenProvider.GetAccessTokenAsync();
sqlConnection.AccessToken = token;
}

return sqlConnection;
Expand Down
12 changes: 6 additions & 6 deletions src/AzureIdentityLivestream.Web/Startup.cs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using System;
using Azure.Identity;
using Azure.Storage.Blobs;
using AzureIdentityLivestream.Web.Services;
using AzureIdentityLivestream.Web.Services.AzureBlobStorage;
using AzureIdentityLivestream.Web.Services.Sql;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
Expand All @@ -25,10 +21,14 @@ public void ConfigureServices(IServiceCollection services)
{
services.AddApplicationInsightsTelemetry();

services.AddSingleton(_ =>
services.AddMemoryCache();
services.AddSingleton<IAzureSqlTokenProvider, AzureIdentityAzureSqlTokenProvider>();
services.Decorate<IAzureSqlTokenProvider, CacheAzureSqlTokenProvider>();

services.AddSingleton(provider =>
{
var sqlConnectionString = Configuration.GetValue<string>("SqlConnectionString");
return new SqlConnectionFactory(sqlConnectionString);
return new SqlConnectionFactory(sqlConnectionString, provider.GetRequiredService<IAzureSqlTokenProvider>());
});

services.AddSingleton<IPersonProvider, DapperPersonProvider>();
Expand Down

0 comments on commit c3df01f

Please sign in to comment.