Skip to content

Commit

Permalink
Add EF Core interceptor to use AAD auth and token caching
Browse files Browse the repository at this point in the history
  • Loading branch information
mderriey committed Feb 19, 2021
1 parent 02a373d commit 9f8e910
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using System;
using System.Data.Common;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.SqlClient;
using Microsoft.EntityFrameworkCore.Diagnostics;

namespace AzureIdentityLivestream.Web.Services.Sql
{
public class AzureAdAuthenticationDbConnectionInterceptor : DbConnectionInterceptor
{
private readonly IAzureSqlTokenProvider _azureSqlTokenProvider;

public AzureAdAuthenticationDbConnectionInterceptor(IAzureSqlTokenProvider azureSqlTokenProvider)
{
_azureSqlTokenProvider = azureSqlTokenProvider;
}

public override InterceptionResult ConnectionOpening(DbConnection connection, ConnectionEventData eventData, InterceptionResult result)
{
var sqlConnection = (SqlConnection)connection;
if (ConnectionNeedsAccessToken(sqlConnection))
{
var (token, _) = _azureSqlTokenProvider.GetAccessToken();
sqlConnection.AccessToken = token;
}

return base.ConnectionOpening(connection, eventData, result);
}

public override async ValueTask<InterceptionResult> ConnectionOpeningAsync(DbConnection connection, ConnectionEventData eventData, InterceptionResult result, CancellationToken cancellationToken = default)
{
var sqlConnection = (SqlConnection)connection;
if (ConnectionNeedsAccessToken(sqlConnection))
{
var (token, _) = await _azureSqlTokenProvider.GetAccessTokenAsync(cancellationToken);
sqlConnection.AccessToken = token;
}

return await base.ConnectionOpeningAsync(connection, eventData, result, cancellationToken);
}

private static bool ConnectionNeedsAccessToken(SqlConnection connection)
{
var connectionStringBuilder = new SqlConnectionStringBuilder(connection.ConnectionString);

return connectionStringBuilder.DataSource.Contains("database.windows.net", StringComparison.OrdinalIgnoreCase) &&
string.IsNullOrEmpty(connectionStringBuilder.UserID);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ public class AzureIdentityAzureSqlTokenProvider : IAzureSqlTokenProvider

private static readonly string[] _azureSqlScopes = new string[] { "https://database.windows.net//.default" };

public (string Token, DateTimeOffset ExpiresOn) GetAccessToken()
{
var tokenRequest = new TokenRequestContext(_azureSqlScopes);
var tokenResult = _credential.GetToken(tokenRequest, default);

return (tokenResult.Token, tokenResult.ExpiresOn);
}

public async Task<(string Token, DateTimeOffset ExpiresOn)> GetAccessTokenAsync(CancellationToken cancellationToken = default)
{
var tokenRequest = new TokenRequestContext(_azureSqlScopes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ public CacheAzureSqlTokenProvider(IAzureSqlTokenProvider inner, IMemoryCache cac
_cache = cache;
}

public (string Token, DateTimeOffset ExpiresOn) GetAccessToken()
{
return _cache.GetOrCreate(_cacheKey, entry =>
{
var (token, expiresOn) = _inner.GetAccessToken();

entry.SetAbsoluteExpiration(expiresOn);

return (token, expiresOn);
});
}

public async Task<(string Token, DateTimeOffset ExpiresOn)> GetAccessTokenAsync(CancellationToken cancellationToken)
{
return await _cache.GetOrCreateAsync(_cacheKey, async entry =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ namespace AzureIdentityLivestream.Web.Services.Sql
{
public interface IAzureSqlTokenProvider
{
(string Token, DateTimeOffset ExpiresOn) GetAccessToken();
Task<(string Token, DateTimeOffset ExpiresOn)> GetAccessTokenAsync(CancellationToken cancellationToken = default);
}
}
11 changes: 9 additions & 2 deletions src/AzureIdentityLivestream.Web/Startup.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,17 @@ public void ConfigureServices(IServiceCollection services)
{
services.AddApplicationInsightsTelemetry();

services.AddDbContext<LivestreamContext>(builder =>
services.AddMemoryCache();
services.AddSingleton<IAzureSqlTokenProvider, AzureIdentityAzureSqlTokenProvider>();
services.Decorate<IAzureSqlTokenProvider, CacheAzureSqlTokenProvider>();
services.AddSingleton<AzureAdAuthenticationDbConnectionInterceptor>();

services.AddDbContext<LivestreamContext>((provider, builder) =>
{
var sqlConnectionString = Configuration.GetValue<string>("SqlConnectionString");
builder.UseSqlServer(sqlConnectionString, options => options.EnableRetryOnFailure());
builder
.UseSqlServer(sqlConnectionString, options => options.EnableRetryOnFailure())
.AddInterceptors(provider.GetRequiredService<AzureAdAuthenticationDbConnectionInterceptor>());
});

services.AddScoped<IPersonProvider, EfCorePersonProvider>();
Expand Down
2 changes: 1 addition & 1 deletion src/AzureIdentityLivestream.Web/appsettings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"StorageConnectionString": "https://azureidentitylivestream.blob.core.windows.net",
"SqlConnectionString": "Data Source=(localdb)\\azureidentitylivestream; Initial Catalog=azure-identity-livestream; Integrated Security=SSPI",
"SqlConnectionString": "Data Source=tcp:azureidentitylivestream.database.windows.net,1433; Initial Catalog=azure-identity-livestream; Encrypt=True; TrustServerCertificate=False",
"Logging": {
"LogLevel": {
"Default": "Information",
Expand Down

0 comments on commit 9f8e910

Please sign in to comment.