Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for .NET CancellationTokens #9127

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/Orleans.CodeGenerator/AnalyzerReleases.Unshipped.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
; Unshipped analyzer release
; Unshipped analyzer release
; https://github.com/dotnet/roslyn-analyzers/blob/main/src/Microsoft.CodeAnalysis.Analyzers/ReleaseTrackingAnalyzers.Help.md

### New Rules

Rule ID | Category | Severity | Notes
--------|----------|----------|--------------------
ORLEANS0109 | Usage | Error | Method has multiple CancellationToken parameters
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using System.Linq;
using Microsoft.CodeAnalysis;

namespace Orleans.CodeGenerator.Diagnostics;

public static class MultipleCancellationTokenParametersDiagnostic
{
public const string DiagnosticId = "ORLEANS0109";
public const string Title = "Grain method has multiple parameters of type CancellationToken";
public const string MessageFormat = "The type {0} contains method {1} which has multiple CancellationToken parameters. Only a single CancellationToken parameter is supported.";
public const string Category = "Usage";

private static readonly DiagnosticDescriptor Rule = new DiagnosticDescriptor(DiagnosticId, Title, MessageFormat, Category, DiagnosticSeverity.Error, isEnabledByDefault: true);

internal static Diagnostic CreateDiagnostic(IMethodSymbol symbol) => Diagnostic.Create(Rule, symbol.Locations.First(), symbol.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat), symbol.Name);
}
334 changes: 295 additions & 39 deletions src/Orleans.CodeGenerator/InvokableGenerator.cs

Large diffs are not rendered by default.

21 changes: 15 additions & 6 deletions src/Orleans.CodeGenerator/LibraryTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
ConstructorAttributeTypes = options.ConstructorAttributes.Select(Type).ToArray();
AliasAttribute = Type("Orleans.AliasAttribute");
IInvokable = Type("Orleans.Serialization.Invocation.IInvokable");
ICancellableInvokable = Type("Orleans.Serialization.Invocation.ICancellableInvokable");
InvokeMethodNameAttribute = Type("Orleans.InvokeMethodNameAttribute");
RuntimeHelpers = Type("System.Runtime.CompilerServices.RuntimeHelpers");
InvokableCustomInitializerAttribute = Type("Orleans.InvokableCustomInitializerAttribute");
Expand All @@ -58,6 +59,8 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
SuppressReferenceTrackingAttribute = Type("Orleans.SuppressReferenceTrackingAttribute");
OmitDefaultMemberValuesAttribute = Type("Orleans.OmitDefaultMemberValuesAttribute");
ITargetHolder = Type("Orleans.Serialization.Invocation.ITargetHolder");
ICancellationRuntime = Type("Orleans.Serialization.Invocation.ICancellationRuntime");
ICancellableInvokableGrainExtension = TypeOrDefault("Orleans.Runtime.ICancellableInvokableGrainExtension");
TypeManifestProviderAttribute = Type("Orleans.Serialization.Configuration.TypeManifestProviderAttribute");
NonSerializedAttribute = Type("System.NonSerializedAttribute");
ObsoleteAttribute = Type("System.ObsoleteAttribute");
Expand All @@ -69,6 +72,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
TypeManifestOptions = Type("Orleans.Serialization.Configuration.TypeManifestOptions");
Task = Type("System.Threading.Tasks.Task");
Task_1 = Type("System.Threading.Tasks.Task`1");
IAsyncEnumerable = Type("System.Collections.Generic.IAsyncEnumerable`1");
this.Type = Type("System.Type");
_uri = Type("System.Uri");
_int128 = TypeOrDefault("System.Int128");
Expand All @@ -77,11 +81,11 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
_dateOnly = TypeOrDefault("System.DateOnly");
_dateTimeOffset = Type("System.DateTimeOffset");
_bitVector32 = Type("System.Collections.Specialized.BitVector32");
_guid = Type("System.Guid");
_compareInfo = Type("System.Globalization.CompareInfo");
_cultureInfo = Type("System.Globalization.CultureInfo");
_version = Type("System.Version");
_timeOnly = TypeOrDefault("System.TimeOnly");
Guid = Type("System.Guid");
ICodecProvider = Type("Orleans.Serialization.Serializers.ICodecProvider");
ValueSerializer = Type("Orleans.Serialization.Serializers.IValueSerializer`1");
ValueTask = Type("System.Threading.Tasks.ValueTask");
Expand Down Expand Up @@ -124,6 +128,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
new(TypeOrDefault("System.Int128"), TypeOrDefault("Orleans.Serialization.Codecs.Int128Codec")),
new(TypeOrDefault("System.Half"), TypeOrDefault("Orleans.Serialization.Codecs.HalfCodec")),
new(Type("System.Uri"), Type("Orleans.Serialization.Codecs.UriCodec")),
new(Type("System.Threading.CancellationToken"), Type("Orleans.Serialization.Codecs.CancellationTokenCodec")),
}.Where(desc => desc.UnderlyingType is { } && desc.CodecType is { }).ToArray();
WellKnownCodecs = new WellKnownCodecDescription[]
{
Expand Down Expand Up @@ -153,7 +158,7 @@ private LibraryTypes(Compilation compilation, CodeGeneratorOptions options)
TimeSpan = Type("System.TimeSpan");
_ipAddress = Type("System.Net.IPAddress");
_ipEndPoint = Type("System.Net.IPEndPoint");
_cancellationToken = Type("System.Threading.CancellationToken");
CancellationToken = Type("System.Threading.CancellationToken");
_immutableContainerTypes = new[]
{
compilation.GetSpecialType(SpecialType.System_Nullable_T),
Expand Down Expand Up @@ -218,7 +223,10 @@ INamedTypeSymbol Type(string metadataName)
public INamedTypeSymbol IActivator_1 { get; private set; }
public INamedTypeSymbol IBufferWriter { get; private set; }
public INamedTypeSymbol IInvokable { get; private set; }
public INamedTypeSymbol ICancellableInvokable { get; private set; }
public INamedTypeSymbol ITargetHolder { get; private set; }
public INamedTypeSymbol ICancellationRuntime { get; private set; }
public INamedTypeSymbol? ICancellableInvokableGrainExtension { get; private set; }
public INamedTypeSymbol TypeManifestProviderAttribute { get; private set; }
public INamedTypeSymbol NonSerializedAttribute { get; private set; }
public INamedTypeSymbol ObsoleteAttribute { get; private set; }
Expand All @@ -230,6 +238,7 @@ INamedTypeSymbol Type(string metadataName)
public INamedTypeSymbol TypeManifestOptions { get; private set; }
public INamedTypeSymbol Task { get; private set; }
public INamedTypeSymbol Task_1 { get; private set; }
public INamedTypeSymbol IAsyncEnumerable { get; private set; }
public INamedTypeSymbol Type { get; private set; }
private INamedTypeSymbol _uri;
private INamedTypeSymbol? _dateOnly;
Expand Down Expand Up @@ -259,13 +268,13 @@ INamedTypeSymbol Type(string metadataName)
public INamedTypeSymbol SuppressReferenceTrackingAttribute { get; private set; }
public INamedTypeSymbol OmitDefaultMemberValuesAttribute { get; private set; }
public INamedTypeSymbol CopyContext { get; private set; }
public INamedTypeSymbol CancellationToken { get; private set; }
public INamedTypeSymbol Guid { get; private set; }
public Compilation Compilation { get; private set; }
public INamedTypeSymbol TimeSpan { get; private set; }
private INamedTypeSymbol _ipAddress;
private INamedTypeSymbol _ipEndPoint;
private INamedTypeSymbol _cancellationToken;
private INamedTypeSymbol[] _immutableContainerTypes;
private INamedTypeSymbol _guid;
private INamedTypeSymbol _bitVector32;
private INamedTypeSymbol _compareInfo;
private INamedTypeSymbol _cultureInfo;
Expand All @@ -280,14 +289,14 @@ INamedTypeSymbol Type(string metadataName)
_dateOnly,
_timeOnly,
_dateTimeOffset,
_guid,
Guid,
_bitVector32,
_compareInfo,
_cultureInfo,
_version,
_ipAddress,
_ipEndPoint,
_cancellationToken,
CancellationToken,
Type,
_uri,
_uInt128,
Expand Down
6 changes: 6 additions & 0 deletions src/Orleans.CodeGenerator/Model/InvokableMethodDescription.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Globalization;
using System.Linq;

namespace Orleans.CodeGenerator
{
Expand Down Expand Up @@ -206,6 +207,11 @@ static bool TryGetNamedArgument(ImmutableArray<KeyValuePair<string, TypedConstan
/// </summary>
public INamedTypeSymbol ContainingInterface { get; }

/// <summary>
/// Gets a value indicating whether this method is cancellable.
/// </summary>
public bool IsCancellable => Method.Parameters.Any(parameterSymbol => SymbolEqualityComparer.Default.Equals(CodeGenerator.LibraryTypes.CancellationToken, parameterSymbol.Type));

public bool Equals(InvokableMethodDescription other) => Key.Equals(other.Key);
public override bool Equals(object obj) => obj is InvokableMethodDescription imd && Equals(imd);
public override int GetHashCode() => Key.GetHashCode();
Expand Down
45 changes: 45 additions & 0 deletions src/Orleans.CodeGenerator/ProxyGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,51 @@ MethodDeclarationSyntax CreateProxyMethod(ProxyMethodDescription methodDescripti
.Concat(_codeGenerator.LibraryTypes.StaticCopiers)
.ToList();

// Ensure to hook up the cancellation token if the method has one
var cancellationTokenParameter = methodSymbol.Parameters.SingleOrDefault(parameter => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, parameter.Type));
if (cancellationTokenParameter is not null)
{
// Throw aggressively if cancellation is already requested
statements.Add(
ExpressionStatement(
InvocationExpression(
IdentifierName($"arg{cancellationTokenParameter.Ordinal}").Member("ThrowIfCancellationRequested"),
ArgumentList()
)
)
);

// Register for cancellation
statements.Add(
ExpressionStatement(
InvocationExpression(
koenbeuk marked this conversation as resolved.
Show resolved Hide resolved
IdentifierName($"arg{cancellationTokenParameter.Ordinal}").Member("Register"))
.WithArgumentList(
ArgumentList(SeparatedList(new[]
{
Argument(
SimpleLambdaExpression(
Parameter(Identifier("arg")),
InvocationExpression(
InvocationExpression(ThisExpression().Member("AsReference", LibraryTypes.ICancellableInvokableGrainExtension.ToTypeSyntax())).Member("CancelRemoteToken"),
ArgumentList(SeparatedList(new[]
{
Argument(
CastExpression(
ParseTypeName(_codeGenerator.LibraryTypes.Guid.ToDisplayName()),
IdentifierName("arg")
)
),
}))
)
)
),
Argument(
InvocationExpression(
IdentifierName("request").Member(IdentifierName("GetCancellableTokenId"))))
})))));
}

// Set request object fields from method parameters.
var parameterIndex = 0;
var parameters = invokable.Members.OfType<MethodParameterFieldDescription>().Select(member => new SerializableMethodMember(member));
Expand Down
56 changes: 56 additions & 0 deletions src/Orleans.CodeGenerator/SerializerGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ public ClassDeclarationSyntax Generate(ISerializableTypeDescription type)
{
members.Add(new SerializableMethodMember(methodParameter));
}
else if (member is CancellableTokenFieldDescription cancellableTokenField)
{
members.Add(new SerializableCancellableTokenMember(_codeGenerator, cancellableTokenField));
}
}

var fieldDescriptions = GetFieldDescriptions(type, members);
Expand Down Expand Up @@ -1137,6 +1141,8 @@ public SerializableMethodMember(MethodParameterFieldDescription member)

public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Parameter.Type) || _member.Parameter.HasAnyAttribute(LibraryTypes.ImmutableAttributes);

public bool IsCancellationToken => SymbolEqualityComparer.Default.Equals(LibraryTypes.CancellationToken, _member.Parameter.Type);

/// <summary>
/// Gets syntax representing the type of this field.
/// </summary>
Expand Down Expand Up @@ -1168,6 +1174,56 @@ public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax va
public FieldAccessorDescription GetSetterFieldDescription() => null;
}

internal class SerializableCancellableTokenMember : ISerializableMember
{
private readonly CodeGenerator _codeGenerator;
private readonly CancellableTokenFieldDescription _member;

public SerializableCancellableTokenMember(CodeGenerator codeGenerator, CancellableTokenFieldDescription member)
{
_codeGenerator = codeGenerator;
_member = member;
}

public IMemberDescription Member => _member;

private LibraryTypes LibraryTypes => _codeGenerator.LibraryTypes;

public bool IsShallowCopyable => LibraryTypes.IsShallowCopyable(_member.Type);

public bool IsCancellationToken => false;

/// <summary>
/// Gets syntax representing the type of this field.
/// </summary>
public TypeSyntax TypeSyntax => _member.TypeSyntax;

public bool IsValueType => _member.Type.IsValueType;

public bool IsPrimaryConstructorParameter => false;

/// <summary>
/// Returns syntax for retrieving the value of this field, deep copying it if necessary.
/// </summary>
/// <param name="instance">The instance of the containing type.</param>
/// <returns>Syntax for retrieving the value of this field.</returns>
public ExpressionSyntax GetGetter(ExpressionSyntax instance) => instance.Member(_member.FieldName);

/// <summary>
/// Returns syntax for setting the value of this field.
/// </summary>
/// <param name="instance">The instance of the containing type.</param>
/// <param name="value">Syntax for the new value.</param>
/// <returns>Syntax for setting the value of this field.</returns>
public ExpressionSyntax GetSetter(ExpressionSyntax instance, ExpressionSyntax value) => AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
instance.Member(_member.FieldName),
value);

public FieldAccessorDescription GetGetterFieldDescription() => null;
public FieldAccessorDescription GetSetterFieldDescription() => null;
}

/// <summary>
/// Represents a serializable member (field/property) of a type.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
using System;
using System.Threading.Tasks;
using Orleans.Concurrency;

namespace Orleans.Runtime;

public interface ICancellableInvokableGrainExtension : IGrainExtension
{
/// <summary>
/// Indicates that a cancellation token has been canceled.
/// </summary>
/// <param name="tokenId">
/// The token id</param>
/// <returns>
/// A <see cref="Task"/> representing the operation.
/// </returns>
[AlwaysInterleave]
Task CancelRemoteToken(Guid tokenId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
using System;
using System.Threading;
using System.Threading.Tasks;
using Orleans.Serialization.Invocation;

namespace Orleans.Runtime.Cancellation;

internal class CancellableInvokableGrainExtension : ICancellableInvokableGrainExtension, IDisposable
{
readonly ICancellationRuntime _runtime;
readonly Timer _cleanupTimer;

public CancellableInvokableGrainExtension(IGrainContext grainContext)
{
_runtime = grainContext.GetComponent<ICancellationRuntime>();
_cleanupTimer = new Timer(obj => ((CancellableInvokableGrainExtension)obj)._runtime.ExpireTokens(), this, TimeSpan.FromSeconds(30), TimeSpan.FromSeconds(30));
}

public Task CancelRemoteToken(Guid tokenId)
{
if (_runtime is not null)
{
_runtime.Cancel(tokenId, lastCall: false);
}

return Task.CompletedTask;
}

public void Dispose()
{
_cleanupTimer.Dispose();
}
}
Loading
Loading