Skip to content

Commit

Permalink
Allow resolving cache based on DbConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
moozzyk committed Jul 29, 2018
1 parent d5506af commit 24898cb
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 74 deletions.
31 changes: 16 additions & 15 deletions EFCache/CacheTransactionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,24 @@ namespace EFCache

public class CacheTransactionHandler : IDbTransactionInterceptor
{
private readonly ConcurrentDictionary<DbTransaction, List<string>> _affectedSetsInTransaction
private readonly ConcurrentDictionary<DbTransaction, List<string>> _affectedSetsInTransaction
= new ConcurrentDictionary<DbTransaction, List<string>>();
private readonly ICache _cache;

public CacheTransactionHandler(ICache cache)
{
if (cache == null)
{
throw new ArgumentNullException("cache");
}
_cache = cache ?? throw new ArgumentNullException(nameof(cache));
}

_cache = cache;
protected CacheTransactionHandler()
{
}

public virtual bool GetItem(DbTransaction transaction, string key, out object value)
public virtual bool GetItem(DbTransaction transaction, string key, DbConnection connection, out object value)
{
if (transaction == null)
{
return _cache.GetItem(key, out value);
return ResolveCache(connection).GetItem(key, out value);
}

value = null;
Expand All @@ -39,19 +38,19 @@ public virtual bool GetItem(DbTransaction transaction, string key, out object va
}

public virtual void PutItem(DbTransaction transaction, string key, object value, IEnumerable<string> dependentEntitySets, TimeSpan slidingExpiration,
DateTimeOffset absoluteExpiration)
DateTimeOffset absoluteExpiration, DbConnection connection)
{
if (transaction == null)
{
_cache.PutItem(key, value, dependentEntitySets, slidingExpiration, absoluteExpiration);
ResolveCache(connection).PutItem(key, value, dependentEntitySets, slidingExpiration, absoluteExpiration);
}
}

public virtual void InvalidateSets(DbTransaction transaction, IEnumerable<string> entitySets)
public virtual void InvalidateSets(DbTransaction transaction, IEnumerable<string> entitySets, DbConnection connection)
{
if (transaction == null)
{
_cache.InvalidateSets(entitySets);
ResolveCache(connection).InvalidateSets(entitySets);
}
else
{
Expand All @@ -77,9 +76,8 @@ protected void AddAffectedEntitySets(DbTransaction transaction, IEnumerable<stri

private IEnumerable<string> RemoveAffectedEntitySets(DbTransaction transaction)
{
List<string> affectedEntitySets;

_affectedSetsInTransaction.TryRemove(transaction, out affectedEntitySets);
_affectedSetsInTransaction.TryRemove(transaction, out List<string> affectedEntitySets);

return affectedEntitySets;
}
Expand All @@ -90,7 +88,7 @@ public void Committed(DbTransaction transaction, DbTransactionInterceptionContex

if (entitySets != null)
{
_cache.InvalidateSets(entitySets.Distinct());
ResolveCache(transaction.Connection).InvalidateSets(entitySets.Distinct());
}
}

Expand Down Expand Up @@ -130,5 +128,8 @@ public void RolledBack(DbTransaction transaction, DbTransactionInterceptionConte
public void RollingBack(DbTransaction transaction, DbTransactionInterceptionContext interceptionContext)
{
}

protected virtual ICache ResolveCache(DbConnection connection)
=> _cache ?? throw new InvalidOperationException("Cannot resolve cache because it has not been initialized.");
}
}
25 changes: 15 additions & 10 deletions EFCache/CachingCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)

if (!_commandTreeFacts.IsQuery)
{
_cacheTransactionHandler.InvalidateSets(Transaction, _commandTreeFacts.AffectedEntitySets.Select(s => s.Name));
_cacheTransactionHandler.InvalidateSets(Transaction, _commandTreeFacts.AffectedEntitySets.Select(s => s.Name),
DbConnection);
}

return result;
Expand All @@ -188,7 +189,7 @@ protected override DbDataReader ExecuteDbDataReader(CommandBehavior behavior)
var key = CreateKey();

object value;
if (_cacheTransactionHandler.GetItem(Transaction, key, out value))
if (_cacheTransactionHandler.GetItem(Transaction, key, DbConnection, out value))
{
return new CachingReader((CachedResults)value);
}
Expand Down Expand Up @@ -217,7 +218,7 @@ protected async override Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBeha

if (!_commandTreeFacts.IsQuery)
{
_cacheTransactionHandler.InvalidateSets(Transaction, _commandTreeFacts.AffectedEntitySets.Select(s => s.Name));
_cacheTransactionHandler.InvalidateSets(Transaction, _commandTreeFacts.AffectedEntitySets.Select(s => s.Name), DbConnection);
}

return result;
Expand All @@ -226,7 +227,7 @@ protected async override Task<DbDataReader> ExecuteDbDataReaderAsync(CommandBeha
var key = CreateKey();

object value;
if (_cacheTransactionHandler.GetItem(Transaction, key, out value))
if (_cacheTransactionHandler.GetItem(Transaction, key, DbConnection, out value))
{
return new CachingReader((CachedResults)value);
}
Expand Down Expand Up @@ -270,7 +271,8 @@ private DbDataReader HandleCaching(DbDataReader reader, string key, List<object[
cachedResults,
_commandTreeFacts.AffectedEntitySets.Select(s => s.Name),
slidingExpiration,
absoluteExpiration);
absoluteExpiration,
DbConnection);
}

return new CachingReader(cachedResults);
Expand Down Expand Up @@ -321,7 +323,8 @@ private void InvalidateSetsForNonQuery(int recordsAffected)
{
if (recordsAffected > 0 && _commandTreeFacts.AffectedEntitySets.Any())
{
_cacheTransactionHandler.InvalidateSets(Transaction, _commandTreeFacts.AffectedEntitySets.Select(s => s.Name));
_cacheTransactionHandler.InvalidateSets(Transaction, _commandTreeFacts.AffectedEntitySets.Select(s => s.Name),
DbConnection);
}
}

Expand All @@ -336,7 +339,7 @@ public override object ExecuteScalar()

object value;

if (_cacheTransactionHandler.GetItem(Transaction, key, out value))
if (_cacheTransactionHandler.GetItem(Transaction, key, DbConnection, out value))
{
return value;
}
Expand All @@ -353,7 +356,8 @@ public override object ExecuteScalar()
value,
_commandTreeFacts.AffectedEntitySets.Select(s => s.Name),
slidingExpiration,
absoluteExpiration);
absoluteExpiration,
DbConnection);

return value;
}
Expand All @@ -370,7 +374,7 @@ public async override Task<object> ExecuteScalarAsync(CancellationToken cancella

object value;

if (_cacheTransactionHandler.GetItem(Transaction, key, out value))
if (_cacheTransactionHandler.GetItem(Transaction, key, DbConnection, out value))
{
return value;
}
Expand All @@ -387,7 +391,8 @@ public async override Task<object> ExecuteScalarAsync(CancellationToken cancella
value,
_commandTreeFacts.AffectedEntitySets.Select(s => s.Name),
slidingExpiration,
absoluteExpiration);
absoluteExpiration,
DbConnection);

return value;
}
Expand Down
93 changes: 84 additions & 9 deletions EFCacheTests/CacheTransactionHandlerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace EFCache
{
using Moq;
using Moq.Protected;
using System;
using System.Collections.Generic;
using System.Data.Common;
Expand All @@ -28,7 +29,7 @@ public void GetItem_calls_GetItem_on_cache_if_transaction_is_null()
mockCache.Setup(c => c.GetItem(It.IsAny<string>(), out value)).Returns(true);

Assert.True(
new CacheTransactionHandler(mockCache.Object).GetItem(null, "key", out value));
new CacheTransactionHandler(mockCache.Object).GetItem(null, "key", Mock.Of<DbConnection>(), out value));

mockCache.Verify(c => c.GetItem("key", out value), Times.Once());
}
Expand All @@ -42,7 +43,8 @@ public void GetItem_does_not_calls_GetItem_on_cache_if_transaction_is_not_null()
mockCache.Setup(c => c.GetItem(It.IsAny<string>(), out value)).Returns(true);

Assert.False(
new CacheTransactionHandler(mockCache.Object).GetItem(Mock.Of<DbTransaction>(), "key", out value));
new CacheTransactionHandler(mockCache.Object)
.GetItem(Mock.Of<DbTransaction>(), "key", Mock.Of<DbConnection>(), out value));

mockCache.Verify(c => c.GetItem(It.IsAny<string>(), out value), Times.Never());
}
Expand All @@ -59,7 +61,7 @@ public void PutItem_calls_PutItem_on_cache_if_transaction_is_null()
var dateTime = new DateTime(1499);

new CacheTransactionHandler(mockCache.Object)
.PutItem(null, key, value, sets, timeSpan, dateTime);
.PutItem(null, key, value, sets, timeSpan, dateTime, Mock.Of<DbConnection>());

mockCache.Verify(c => c.PutItem(key, value, sets, timeSpan, dateTime), Times.Once());
}
Expand All @@ -70,7 +72,7 @@ public void PutItem_doe_not_call_PutItem_on_cache_if_transaction_is_not_null()
var mockCache = new Mock<ICache>();

new CacheTransactionHandler(mockCache.Object)
.PutItem(Mock.Of<DbTransaction>(), "key", new object(), new string[0], new TimeSpan(), new DateTime());
.PutItem(Mock.Of<DbTransaction>(), "key", new object(), new string[0], new TimeSpan(), new DateTime(), Mock.Of<DbConnection>());

mockCache.Verify(
c => c.PutItem(It.IsAny<string>(), It.IsAny<object>(), It.IsAny<IEnumerable<string>>(),
Expand All @@ -85,7 +87,7 @@ public void InvalidateSets_calls_InvalidateSets_on_cache_if_transaction_null()
var sets = new string[0];

new CacheTransactionHandler(mockCache.Object)
.InvalidateSets(null, sets);
.InvalidateSets(null, sets, Mock.Of<DbConnection>());

mockCache.Verify(c => c.InvalidateSets(sets), Times.Once());
}
Expand All @@ -97,8 +99,8 @@ public void Committed_invalidate_sets_collected_during_transaction()
var transactionHandler = new CacheTransactionHandler(mockCache.Object);

var transaction = Mock.Of<DbTransaction>();
transactionHandler.InvalidateSets(transaction, new[] {"ES1", "ES2"});
transactionHandler.InvalidateSets(transaction, new[] {"ES3", "ES2"});
transactionHandler.InvalidateSets(transaction, new[] {"ES1", "ES2"}, Mock.Of<DbConnection>());
transactionHandler.InvalidateSets(transaction, new[] {"ES3", "ES2"}, Mock.Of<DbConnection>());

transactionHandler.Committed(transaction, Mock.Of<DbTransactionInterceptionContext>());

Expand All @@ -112,14 +114,87 @@ public void RolledBack_clears_affected_sets_collected_during_transaction()
var transactionHandler = new CacheTransactionHandler(mockCache.Object);

var transaction = Mock.Of<DbTransaction>();
transactionHandler.InvalidateSets(transaction, new[] { "ES1", "ES2" });
transactionHandler.InvalidateSets(transaction, new[] { "ES3", "ES2" });
transactionHandler.InvalidateSets(transaction, new[] { "ES1", "ES2" }, Mock.Of<DbConnection>());
transactionHandler.InvalidateSets(transaction, new[] { "ES3", "ES2" }, Mock.Of<DbConnection>());

transactionHandler.RolledBack(transaction, Mock.Of<DbTransactionInterceptionContext>());
transactionHandler.Committed(transaction, Mock.Of<DbTransactionInterceptionContext>());

mockCache.Verify(c => c.InvalidateSets(It.IsAny<IEnumerable<string>>()), Times.Never());
}

[Fact]
public void ResolveCache_throws_for_uninitialized_cache()
{
var transactionHandler = new Mock<CacheTransactionHandler>{ CallBase = true }.Object;

Assert.Throws<InvalidOperationException>(() =>
transactionHandler.GetItem(null, "key", Mock.Of<DbConnection>(), out _));
}

[Fact]
public void GetItem_resolves_cache()
{
var mockTransactionHandler = new Mock<CacheTransactionHandler> { CallBase = true };
mockTransactionHandler.Protected()
.Setup<ICache>("ResolveCache", ItExpr.IsAny<DbConnection>())
.Returns(Mock.Of<ICache>());
var dbConnection = Mock.Of<DbConnection>();

mockTransactionHandler.Object.GetItem(null, "key", dbConnection, out _);

mockTransactionHandler.Protected()
.Verify("ResolveCache", Times.Once(), dbConnection);
}

[Fact]
public void PutItem_resolves_cache()
{
var mockTransactionHandler = new Mock<CacheTransactionHandler> { CallBase = true };
mockTransactionHandler.Protected()
.Setup<ICache>("ResolveCache", ItExpr.IsAny<DbConnection>())
.Returns(Mock.Of<ICache>());
var dbConnection = Mock.Of<DbConnection>();

mockTransactionHandler.Object.PutItem(null, "key", new object(), new string[0], TimeSpan.MaxValue,
DateTimeOffset.MaxValue, dbConnection);

mockTransactionHandler.Protected()
.Verify("ResolveCache", Times.Once(), dbConnection);
}

[Fact]
public void InvalidateSets_resolves_cache()
{
var mockTransactionHandler = new Mock<CacheTransactionHandler> { CallBase = true };
mockTransactionHandler.Protected()
.Setup<ICache>("ResolveCache", ItExpr.IsAny<DbConnection>())
.Returns(Mock.Of<ICache>());
var dbConnection = Mock.Of<DbConnection>();

mockTransactionHandler.Object.InvalidateSets(null, new string[0], dbConnection);

mockTransactionHandler.Protected()
.Verify("ResolveCache", Times.Once(), dbConnection);
}

[Fact]
public void Committed_resolves_cache()
{
var mockTransactionHandler = new Mock<CacheTransactionHandler> { CallBase = true };
mockTransactionHandler.Protected()
.Setup<ICache>("ResolveCache", ItExpr.IsAny<DbConnection>())
.Returns(Mock.Of<ICache>());
var dbConnection = Mock.Of<DbConnection>();
var mockTransaction = new Mock<DbTransaction>();
mockTransaction.Protected().SetupGet<DbConnection>("DbConnection").Returns(dbConnection);
var entitySets = new[] { "ES1" };

mockTransactionHandler.Object.InvalidateSets(mockTransaction.Object, entitySets, dbConnection);
mockTransactionHandler.Object.Committed(mockTransaction.Object, Mock.Of<DbTransactionInterceptionContext>());

mockTransactionHandler.Protected()
.Verify("ResolveCache", Times.Once(), dbConnection);
}
}
}
Loading

0 comments on commit 24898cb

Please sign in to comment.