Skip to content

Commit

Permalink
Move ReceiveAttributes onto SqsOptions
Browse files Browse the repository at this point in the history
SqsOptions are now passed down to SQSStorage to be accessible
  • Loading branch information
jamescarter-le committed Nov 14, 2023
1 parent 637e0ae commit f4306cc
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 23 deletions.
24 changes: 15 additions & 9 deletions src/AWS/Orleans.Streaming.SQS/Storage/SQSStorage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
using Orleans.Streaming.SQS;
using SQSMessage = Amazon.SQS.Model.Message;
using Orleans;
using Orleans.Configuration;

namespace OrleansAWSUtils.Storage
{
Expand All @@ -26,6 +27,7 @@ internal class SQSStorage
private const string SecretKeyPropertyName = "SecretKey";
private const string SessionTokenPropertyName = "SessionToken";
private const string ServicePropertyName = "Service";
private readonly SqsOptions sqsOptions;
private readonly ILogger Logger;
private string accessKey;
private string secretKey;
Expand All @@ -44,18 +46,22 @@ internal class SQSStorage
/// </summary>
/// <param name="loggerFactory">logger factory to use</param>
/// <param name="queueName">The name of the queue</param>
/// <param name="connectionString">The connection string</param>
/// <param name="sqsOptions">The options for the SQS connection</param>
/// <param name="serviceId">The service ID</param>
public SQSStorage(ILoggerFactory loggerFactory, string queueName, string connectionString, string serviceId = "")
public SQSStorage(ILoggerFactory loggerFactory, string queueName, SqsOptions sqsOptions, string serviceId = "")
{
if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions));
this.sqsOptions = sqsOptions;
QueueName = string.IsNullOrWhiteSpace(serviceId) ? queueName : $"{serviceId}-{queueName}";
ParseDataConnectionString(connectionString);
ParseDataConnectionString(sqsOptions.ConnectionString);
Logger = loggerFactory.CreateLogger<SQSStorage>();
CreateClient();
}

private void ParseDataConnectionString(string dataConnectionString)
{
if(string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString));

var parameters = dataConnectionString.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);

var serviceConfig = parameters.Where(p => p.Contains(ServicePropertyName)).FirstOrDefault();
Expand Down Expand Up @@ -217,13 +223,13 @@ public async Task<IEnumerable<SQSMessage>> GetMessages(int count = 1)
{
QueueUrl = queueUrl,
MaxNumberOfMessages = count <= MAX_NUMBER_OF_MESSAGE_TO_PEAK ? count : MAX_NUMBER_OF_MESSAGE_TO_PEAK,
// TODO: Move this list to Configuration
AttributeNames = new List<string> { "All" },
// TODO: Move this list to Configuration
MessageAttributeNames = new List<string> { "All" },
// TODO: Move this wait time to Configuration
WaitTimeSeconds = 20
AttributeNames = sqsOptions.ReceiveAttributes,
MessageAttributeNames = sqsOptions.ReceiveMessageAttributes,
};

if (sqsOptions.ReceiveWaitTimeSeconds.HasValue)
request.WaitTimeSeconds = sqsOptions.ReceiveWaitTimeSeconds.Value;

var response = await sqsClient.ReceiveMessageAsync(request);
return response.Messages;
}
Expand Down
13 changes: 7 additions & 6 deletions src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Threading.Tasks;
using Amazon.SQS.Model;
using Microsoft.Extensions.Logging;
using Orleans.Configuration;
using Orleans.Runtime;
using Orleans.Streaming.SQS.Streams;

Expand All @@ -15,7 +16,7 @@ internal class SQSAdapter : IQueueAdapter
{
protected readonly string ServiceId;
private readonly ISQSDataAdapter dataAdapter;
protected readonly string DataConnectionString;
protected SqsOptions sqsOptions;
private readonly IConsistentRingStreamQueueMapper streamQueueMapper;
protected readonly ConcurrentDictionary<QueueId, SQSStorage> Queues = new ConcurrentDictionary<QueueId, SQSStorage>();
private readonly ILoggerFactory loggerFactory;
Expand All @@ -24,21 +25,21 @@ internal class SQSAdapter : IQueueAdapter

public StreamProviderDirection Direction { get { return StreamProviderDirection.ReadWrite; } }

public SQSAdapter(ISQSDataAdapter dataAdapter, IConsistentRingStreamQueueMapper streamQueueMapper, ILoggerFactory loggerFactory, string dataConnectionString, string serviceId, string providerName)
public SQSAdapter(ISQSDataAdapter dataAdapter, IConsistentRingStreamQueueMapper streamQueueMapper, ILoggerFactory loggerFactory, SqsOptions sqsOptions, string serviceId, string providerName)
{
if (string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString));
if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions));
if (string.IsNullOrEmpty(serviceId)) throw new ArgumentNullException(nameof(serviceId));
this.loggerFactory = loggerFactory;
this.sqsOptions = sqsOptions;
this.dataAdapter = dataAdapter;
DataConnectionString = dataConnectionString;
this.ServiceId = serviceId;
Name = providerName;
this.streamQueueMapper = streamQueueMapper;
}

public IQueueAdapterReceiver CreateReceiver(QueueId queueId)
{
return SQSAdapterReceiver.Create(this.dataAdapter, this.loggerFactory, queueId, DataConnectionString, this.ServiceId);
return SQSAdapterReceiver.Create(this.dataAdapter, this.loggerFactory, queueId, sqsOptions, this.ServiceId);
}

public async Task QueueMessageBatchAsync<T>(StreamId streamId, IEnumerable<T> events, StreamSequenceToken token, Dictionary<string, object> requestContext)
Expand All @@ -51,7 +52,7 @@ public async Task QueueMessageBatchAsync<T>(StreamId streamId, IEnumerable<T> ev
SQSStorage queue;
if (!Queues.TryGetValue(queueId, out queue))
{
var tmpQueue = new SQSStorage(this.loggerFactory, queueId.ToString(), DataConnectionString, this.ServiceId);
var tmpQueue = new SQSStorage(this.loggerFactory, queueId.ToString(), sqsOptions, this.ServiceId);
await tmpQueue.InitQueueAsync();
queue = Queues.GetOrAdd(queueId, tmpQueue);
}
Expand Down
2 changes: 1 addition & 1 deletion src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public virtual void Init()
/// <summary>Creates the Azure Queue based adapter.</summary>
public virtual Task<IQueueAdapter> CreateAdapter()
{
var adapter = new SQSAdapter(this.dataAdapter, this.streamQueueMapper, this.loggerFactory, this.sqsOptions.ConnectionString, this.clusterOptions.ServiceId, this.providerName);
var adapter = new SQSAdapter(this.dataAdapter, this.streamQueueMapper, this.loggerFactory, this.sqsOptions, this.clusterOptions.ServiceId, this.providerName);
return Task.FromResult<IQueueAdapter>(adapter);
}

Expand Down
7 changes: 4 additions & 3 deletions src/AWS/Orleans.Streaming.SQS/Streams/SQSAdapterReceiver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Orleans.Configuration;
using Orleans.Streaming.SQS.Streams;
using SQSMessage = Amazon.SQS.Model.Message;

Expand All @@ -24,13 +25,13 @@ internal class SQSAdapterReceiver : IQueueAdapterReceiver

public QueueId Id { get; private set; }

public static IQueueAdapterReceiver Create(ISQSDataAdapter dataAdapter, ILoggerFactory loggerFactory, QueueId queueId, string dataConnectionString, string serviceId)
public static IQueueAdapterReceiver Create(ISQSDataAdapter dataAdapter, ILoggerFactory loggerFactory, QueueId queueId, SqsOptions sqsOptions, string serviceId)
{
if (queueId.IsDefault) throw new ArgumentNullException(nameof(queueId));
if (string.IsNullOrEmpty(dataConnectionString)) throw new ArgumentNullException(nameof(dataConnectionString));
if (sqsOptions is null) throw new ArgumentNullException(nameof(sqsOptions));
if (string.IsNullOrEmpty(serviceId)) throw new ArgumentNullException(nameof(serviceId));

var queue = new SQSStorage(loggerFactory, queueId.ToString(), dataConnectionString, serviceId);
var queue = new SQSStorage(loggerFactory, queueId.ToString(), sqsOptions, serviceId);
return new SQSAdapterReceiver(dataAdapter, loggerFactory, queueId, queue);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Collections.Generic;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
Expand All @@ -24,10 +24,12 @@ public static async Task DeleteAllUsedQueues(string providerName, string cluster
var queueMapper = new HashRingBasedStreamQueueMapper(new HashRingStreamQueueMapperOptions(), providerName);
List<QueueId> allQueues = queueMapper.GetAllQueues().ToList();

var sqsOptions = new SqsOptions { ConnectionString = storageConnectionString };

var deleteTasks = new List<Task>();
foreach (var queueId in allQueues)
{
var manager = new SQSStorage(loggerFactory, queueId.ToString(), storageConnectionString, clusterId);
var manager = new SQSStorage(loggerFactory, queueId.ToString(), sqsOptions, clusterId);
manager.InitQueueAsync().Wait();
deleteTasks.Add(manager.DeleteQueue());
}
Expand Down
28 changes: 27 additions & 1 deletion src/AWS/Orleans.Streaming.SQS/Streams/SqsStreamOptions.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,35 @@


using System.Collections.Generic;

namespace Orleans.Configuration
{
public class SqsOptions
{
/// <summary>
/// Specifies the connection string to use for connecting to SQS.
/// </summary>
/// <example>
/// Example for AWS: Service=eu-west-1;AccessKey=XXXXXX;SecretKey=XXXXXX;SessionToken=XXXXXX;
/// </example>
/// <example>
/// Example for LocalStack: Service=http://localhost:4566
/// </example>
[Redact]
public string ConnectionString { get; set; }

/// <summary>
/// Specifies which SQS Attributes should be retrieved about the SQS message from the Queue.
/// </summary>
public List<string> ReceiveAttributes { get; set; } = new();

/// <summary>
/// Specifies which Message Attributes should be retrieved with the SQS messages.
/// </summary>
public List<string> ReceiveMessageAttributes { get; set; } = new();

/// <summary>
/// The optional duration to long-poll for new SQS messages.
/// </summary>
public int? ReceiveWaitTimeSeconds { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public async Task SendAndReceiveFromSQS()
var options = new SqsOptions
{
ConnectionString = AWSTestConstants.SqsConnectionString,
ReceiveMessageAttributes = new[] { "StreamId" }.ToList()
};
var clusterOptions = new ClusterOptions { ServiceId = this.clusterId };
var dataAdapter = new StringOrIntSqlDataAdapter(fixture.Serializer);
Expand Down Expand Up @@ -217,9 +218,9 @@ public StringOrIntSqlDataAdapter(Serializer serializer) : base(serializer)

public override IBatchContainer GetBatchContainer(Message sqsMessage, ref long sequenceNumber)
{
// Example extracts the StreamId as an attribute instead of it being serialized in the body.
if (!sqsMessage.MessageAttributes.TryGetValue("StreamId", out var streamIdStr))
throw new DataException("SQS Message did not contain a StreamId attribute.");

var streamId = StreamId.Parse(Encoding.UTF8.GetBytes(streamIdStr.StringValue));

// Contrived example sends strings as quoted, and longs as unquoted.
Expand All @@ -241,6 +242,7 @@ public override Message ToQueueMessage<T>(StreamId streamId, IEnumerable<T> even
var serializedData = string.Join(Environment.NewLine,
events.Select(x => x is string ? $"\"{x}\"" : x.ToString()));

// Example includes the StreamId as an attribute.
return new Message
{
Attributes = new()
Expand Down

0 comments on commit f4306cc

Please sign in to comment.