Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ensure that cancellation token is passed in InvokeWithActivityAsync #4329

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dotnet/AutoGen.v3.ncrunchsolution
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
<SolutionConfiguration>
<Settings>
<AllowParallelTestExecution>True</AllowParallelTestExecution>
<EnableRDI>True</EnableRDI>
<RdiConfigured>True</RdiConfigured>
<SolutionConfigured>True</SolutionConfigured>
</Settings>
</SolutionConfiguration>
6 changes: 3 additions & 3 deletions dotnet/src/Microsoft.AutoGen/Abstractions/IAgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ public interface IAgentRuntime
ValueTask SendRequestAsync(IAgentBase agent, RpcRequest request, CancellationToken cancellationToken = default);
ValueTask SendMessageAsync(Message message, CancellationToken cancellationToken = default);
ValueTask PublishEventAsync(CloudEvent @event, CancellationToken cancellationToken = default);
void Update(Activity? activity, RpcRequest request);
void Update(Activity? activity, CloudEvent cloudEvent);
(string?, string?) GetTraceIDandState(IDictionary<string, string> metadata);
void Update(RpcRequest request, Activity? activity);
void Update(CloudEvent cloudEvent, Activity? activity);
(string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata);
IDictionary<string, string> ExtractMetadata(IDictionary<string, string> metadata);
}
20 changes: 18 additions & 2 deletions dotnet/src/Microsoft.AutoGen/Abstractions/IAgentState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,24 @@

namespace Microsoft.AutoGen.Abstractions;

/// <summary>
/// Interface for managing the state of an agent.
/// </summary>
public interface IAgentState
{
ValueTask<AgentState> ReadStateAsync();
ValueTask<string> WriteStateAsync(AgentState state, string eTag);
/// <summary>
/// Reads the current state of the agent asynchronously.
/// </summary>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task that represents the asynchronous read operation. The task result contains the current state of the agent.</returns>
ValueTask<AgentState> ReadStateAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Writes the specified state of the agent asynchronously.
/// </summary>
/// <param name="state">The state to write.</param>
/// <param name="eTag">The ETag for concurrency control.</param>
/// <param name="cancellationToken">A token to cancel the operation.</param>
/// <returns>A task that represents the asynchronous write operation. The task result contains the ETag of the written state.</returns>
ValueTask<string> WriteStateAsync(AgentState state, string eTag, CancellationToken cancellationToken = default);
}
20 changes: 10 additions & 10 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ protected internal async Task HandleRpcMessage(Message msg, CancellationToken ca
{
var activity = this.ExtractActivity(msg.CloudEvent.Type, msg.CloudEvent.Metadata);
await this.InvokeWithActivityAsync(
static ((AgentBase Agent, CloudEvent Item) state) => state.Agent.CallHandler(state.Item),
static ((AgentBase Agent, CloudEvent Item) state, CancellationToken _) => state.Agent.CallHandler(state.Item),
(this, msg.CloudEvent),
activity,
msg.CloudEvent.Type, cancellationToken).ConfigureAwait(false);
Expand All @@ -103,7 +103,7 @@ await this.InvokeWithActivityAsync(
{
var activity = this.ExtractActivity(msg.Request.Method, msg.Request.Metadata);
await this.InvokeWithActivityAsync(
static ((AgentBase Agent, RpcRequest Request) state) => state.Agent.OnRequestCoreAsync(state.Request),
static ((AgentBase Agent, RpcRequest Request) state, CancellationToken ct) => state.Agent.OnRequestCoreAsync(state.Request, ct),
(this, msg.Request),
activity,
msg.Request.Method, cancellationToken).ConfigureAwait(false);
Expand Down Expand Up @@ -142,8 +142,8 @@ public async Task StoreAsync(AgentState state, CancellationToken cancellationTok
}
public async Task<T> ReadAsync<T>(AgentId agentId, CancellationToken cancellationToken = default) where T : IMessage, new()
{
var agentstate = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentstate.FromAgentState<T>();
var agentState = await _context.ReadAsync(agentId, cancellationToken).ConfigureAwait(false);
return agentState.FromAgentState<T>();
}
private void OnResponseCore(RpcResponse response)
{
Expand Down Expand Up @@ -195,9 +195,9 @@ protected async Task<RpcResponse> RequestAsync(AgentId target, string method, Di
activity?.SetTag("peer.service", target.ToString());

var completion = new TaskCompletionSource<RpcResponse>(TaskCreationOptions.RunContinuationsAsynchronously);
_context.Update(activity, request);
_context.Update(request, activity);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state) =>
static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResponse>) state, CancellationToken ct) =>
{
var (self, request, completion) = state;
Expand All @@ -206,7 +206,7 @@ static async ((AgentBase Agent, RpcRequest Request, TaskCompletionSource<RpcResp
self._pendingRequests[request.RequestId] = completion;
}
await state.Agent._context.SendRequestAsync(state.Agent, state.Request).ConfigureAwait(false);
await state.Agent._context.SendRequestAsync(state.Agent, state.Request, ct).ConfigureAwait(false);
await completion.Task.ConfigureAwait(false);
},
Expand All @@ -231,11 +231,11 @@ public async ValueTask PublishEventAsync(CloudEvent item, CancellationToken canc
activity?.SetTag("peer.service", $"{item.Type}/{item.Source}");

// TODO: fix activity
_context.Update(activity, item);
_context.Update(item, activity);
await this.InvokeWithActivityAsync(
static async ((AgentBase Agent, CloudEvent Event) state) =>
static async ((AgentBase Agent, CloudEvent Event) state, CancellationToken ct) =>
{
await state.Agent._context.PublishEventAsync(state.Event).ConfigureAwait(false);
await state.Agent._context.PublishEventAsync(state.Event, ct).ConfigureAwait(false);
},
(this, item),
activity,
Expand Down
37 changes: 28 additions & 9 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentBaseExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,25 @@

namespace Microsoft.AutoGen.Agents;

/// <summary>
/// Provides extension methods for the <see cref="AgentBase"/> class.
/// </summary>
public static class AgentBaseExtensions
{
/// <summary>
/// Extracts an <see cref="Activity"/> from the given agent and metadata.
/// </summary>
/// <param name="agent">The agent from which to extract the activity.</param>
/// <param name="activityName">The name of the activity.</param>
/// <param name="metadata">The metadata containing trace information.</param>
/// <returns>The extracted <see cref="Activity"/> or null if extraction fails.</returns>
public static Activity? ExtractActivity(this AgentBase agent, string activityName, IDictionary<string, string> metadata)
{
Activity? activity;
(var traceParent, var traceState) = agent.Context.GetTraceIDandState(metadata);
var (traceParent, traceState) = agent.Context.GetTraceIdAndState(metadata);
if (!string.IsNullOrEmpty(traceParent))
{
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out ActivityContext parentContext))
if (ActivityContext.TryParse(traceParent, traceState, isRemote: true, out var parentContext))
{
// traceParent is a W3CId
activity = AgentBase.s_source.CreateActivity(activityName, ActivityKind.Server, parentContext);
Expand All @@ -33,12 +43,9 @@ public static class AgentBaseExtensions

var baggage = agent.Context.ExtractMetadata(metadata);

if (baggage is not null)
foreach (var baggageItem in baggage)
{
foreach (var baggageItem in baggage)
{
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
}
activity.AddBaggage(baggageItem.Key, baggageItem.Value);
}
}
}
Expand All @@ -49,7 +56,19 @@ public static class AgentBaseExtensions

return activity;
}
public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, Func<TState, Task> func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default)

/// <summary>
/// Invokes a function asynchronously within the context of an <see cref="Activity"/>.
/// </summary>
/// <typeparam name="TState">The type of the state parameter.</typeparam>
/// <param name="agent">The agent invoking the function.</param>
/// <param name="func">The function to invoke.</param>
/// <param name="state">The state parameter to pass to the function.</param>
/// <param name="activity">The activity within which to invoke the function.</param>
/// <param name="methodName">The name of the method being invoked.</param>
/// <param name="cancellationToken">A token to monitor for cancellation requests.</param>
/// <returns>A task representing the asynchronous operation.</returns>
public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, Func<TState, CancellationToken, Task> func, TState state, Activity? activity, string methodName, CancellationToken cancellationToken = default)
{
if (activity is not null && activity.StartTimeUtc == default)
{
Expand All @@ -63,7 +82,7 @@ public static async Task InvokeWithActivityAsync<TState>(this AgentBase agent, F

try
{
await func(state).ConfigureAwait(false);
await func(state, cancellationToken).ConfigureAwait(false);
if (activity is not null && activity.IsAllDataRequested)
{
activity.SetStatus(ActivityStatusCode.Ok);
Expand Down
6 changes: 3 additions & 3 deletions dotnet/src/Microsoft.AutoGen/Agents/AgentRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
public ILogger<AgentBase> Logger { get; } = logger;
public IAgentBase? AgentInstance { get; set; }
private DistributedContextPropagator DistributedContextPropagator { get; } = distributedContextPropagator;
public (string?, string?) GetTraceIDandState(IDictionary<string, string> metadata)
public (string?, string?) GetTraceIdAndState(IDictionary<string, string> metadata)
{
DistributedContextPropagator.ExtractTraceIdAndState(metadata,
static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable<string>? fieldValues) =>
Expand All @@ -28,11 +28,11 @@ internal sealed class AgentRuntime(AgentId agentId, IAgentWorker worker, ILogger
out var traceState);
return (traceParent, traceState);
}
public void Update(Activity? activity, RpcRequest request)
public void Update(RpcRequest request, Activity? activity = null)
{
DistributedContextPropagator.Inject(activity, request.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
}
public void Update(Activity? activity, CloudEvent cloudEvent)
public void Update(CloudEvent cloudEvent, Activity? activity = null)
{
DistributedContextPropagator.Inject(activity, cloudEvent.Metadata, static (carrier, key, value) => ((IDictionary<string, string>)carrier!)[key] = value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
using Microsoft.AutoGen.Abstractions;
using Microsoft.Extensions.AI;
namespace Microsoft.AutoGen.Agents;
public abstract class InferenceAgent<T> : AgentBase where T : IMessage, new()
public abstract class InferenceAgent<T>(
IAgentRuntime context,
EventTypes typeRegistry,
IChatClient client)
: AgentBase(context, typeRegistry)
where T : IMessage, new()
{
protected IChatClient ChatClient { get; }
public InferenceAgent(
IAgentRuntime context,
EventTypes typeRegistry, IChatClient client
) : base(context, typeRegistry)
{
ChatClient = client;
}
protected IChatClient ChatClient { get; } = client;

private Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ namespace Microsoft.AutoGen.Agents;

internal sealed class AgentStateGrain([PersistentState("state", "AgentStateStore")] IPersistentState<AgentState> state) : Grain, IAgentState
{
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag)
/// <inheritdoc />
public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag, CancellationToken cancellationToken = default)
{
// etags for optimistic concurrency control
// if the Etag is null, its a new state
Expand All @@ -27,7 +28,8 @@ public async ValueTask<string> WriteStateAsync(AgentState newState, string eTag)
return state.Etag;
}

public ValueTask<AgentState> ReadStateAsync()
/// <inheritdoc />
public ValueTask<AgentState> ReadStateAsync(CancellationToken cancellationToken = default)
{
return ValueTask.FromResult(state.State);
}
Expand Down
Loading