diff --git a/Agoda.Frameworks.DB.Tests/DbRepositoryTest.cs b/Agoda.Frameworks.DB.Tests/DbRepositoryTest.cs index 2ba7f52..03ed623 100644 --- a/Agoda.Frameworks.DB.Tests/DbRepositoryTest.cs +++ b/Agoda.Frameworks.DB.Tests/DbRepositoryTest.cs @@ -666,28 +666,6 @@ public void QueryMultipleAsync_Retry_Failure() Assert.IsNotNull(_onQueryCompleteEvents[0].Error); Assert.IsNotNull(_onQueryCompleteEvents[1].Error); } - - [Test] - public async Task ExecuteReaderAsync_CancellationToken_Cancelled() - { - var cancellationToken = new CancellationTokenSource(TimeSpan.FromDays(1)); - cancellationToken.Cancel(); - var maxAttemptCount = 2; - Assert.ThrowsAsync(async () => await _db.ExecuteReaderAsync("mobile_ro", "db.v1.sp_foo", 1, - maxAttemptCount, - cancellationToken.Token, - new IDbDataParameter[] - { - new SqlParameter("@param1", "value1"), - new SqlParameter("@param2", "value2") - - }, reader => { return Task.FromResult("");}) - ); - - _dbResources.Verify(x => x.ChooseDb("mobile_ro").UpdateWeight(It.IsAny(), false), Times.Exactly(maxAttemptCount)); - } - - protected class FakeStoredProc : IStoredProc { public string DbName => "mobile_ro"; diff --git a/Agoda.Frameworks.DB/DbRepository.cs b/Agoda.Frameworks.DB/DbRepository.cs index 487b5fa..699dd4e 100644 --- a/Agoda.Frameworks.DB/DbRepository.cs +++ b/Agoda.Frameworks.DB/DbRepository.cs @@ -303,7 +303,7 @@ public Task ExecuteReaderAsync( string storedProc, int timeoutSecs, int maxAttemptCount, - CancellationToken token, + int taskCancellationTimeOutInMilliSecs, IDbDataParameter[] parameters, Func> callback) { @@ -311,53 +311,58 @@ public Task ExecuteReaderAsync( { var stopwatch = Stopwatch.StartNew(); Exception error = null; - try + using (var cancellationTokenSource = new CancellationTokenSource()) { - using (var connection = _generateConnection(connectionStr)) + cancellationTokenSource.CancelAfter(taskCancellationTimeOutInMilliSecs); + try { - if (connection is SqlConnection sqlConn) + using (var connection = _generateConnection(connectionStr)) { - await sqlConn.OpenAsync(); - } - else - { - connection.Open(); - } - SqlCommand sqlCommand = null; - try - { - sqlCommand = new SqlCommand(storedProc, connection as SqlConnection) + if (connection is SqlConnection sqlConn) { - CommandType = CommandType.StoredProcedure, - CommandTimeout = timeoutSecs - }; - sqlCommand.Parameters.AddRange(parameters); - using (var reader = await sqlCommand.ExecuteReaderAsync(token)) + await sqlConn.OpenAsync(cancellationTokenSource.Token); + } + else { - return await callback(reader); + connection.Open(); } - } - finally - { - if (sqlCommand != null) + SqlCommand sqlCommand = null; + try { - sqlCommand.Parameters.Clear(); - sqlCommand.Dispose(); + sqlCommand = new SqlCommand(storedProc, connection as SqlConnection) + { + CommandType = CommandType.StoredProcedure, + CommandTimeout = timeoutSecs + }; + sqlCommand.Parameters.AddRange(parameters); + using (var reader = await sqlCommand.ExecuteReaderAsync(cancellationTokenSource.Token)) + { + return await callback(reader); + } + } + finally + { + if (sqlCommand != null) + { + sqlCommand.Parameters.Clear(); + sqlCommand.Dispose(); + } } } } + catch (Exception e) + { + error = e; + throw; + } + finally + { + stopwatch.Stop(); + RaiseOnExecuteReaderComplete( + database, storedProc, stopwatch.ElapsedMilliseconds, error); + } } - catch (Exception e) - { - error = e; - throw; - } - finally - { - stopwatch.Stop(); - RaiseOnExecuteReaderComplete( - database, storedProc, stopwatch.ElapsedMilliseconds, error); - } + }, ShouldRetry(maxAttemptCount), RaiseOnError); } diff --git a/Agoda.Frameworks.DB/IDbRepository.cs b/Agoda.Frameworks.DB/IDbRepository.cs index 9729254..1d84d8a 100644 --- a/Agoda.Frameworks.DB/IDbRepository.cs +++ b/Agoda.Frameworks.DB/IDbRepository.cs @@ -141,7 +141,7 @@ Task ExecuteReaderAsync( string storedProc, int timeoutSecs, int maxAttemptCount, - CancellationToken token, + int taskCancellationTimeOutInMilliSecs, IDbDataParameter[] parameters, Func> callback);