Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve type metadata for casts to WinRT runtime classes. #1873

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,115 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Collect()
.Combine(properties);
context.RegisterImplementationSourceOutput(bindableCustomPropertyAttributes, GenerateBindableCustomProperties);

// Generate type metadata for casts to WinRT runtime classes.
var castsToWinRTClasses = context.SyntaxProvider.CreateSyntaxProvider(
static (node, _) => node.IsKind(SyntaxKind.CastExpression) || node.IsKind(SyntaxKind.AsExpression) || node.IsKind(SyntaxKind.IsExpression) || node.IsKind(SyntaxKind.IsPatternExpression),
static (context, _) =>
{
TypeSyntax type = null;
ExpressionSyntax expression = null;

// Try to retrieve the type being cast to, and the expression to be cast.
if (context.Node is CastExpressionSyntax castExpression)
{
type = castExpression.Type;
expression = castExpression.Expression;
}
else if (context.Node is IsPatternExpressionSyntax patternExpression)
{
expression = patternExpression.Expression;
if (patternExpression.Pattern is DeclarationPatternSyntax declarationPattern)
{
type = declarationPattern.Type;
}
else if (patternExpression.Pattern is RecursivePatternSyntax recursivePattern)
{
type = recursivePattern.Type;
}
else
{
return null;
}
}
else if (context.Node is BinaryExpressionSyntax binaryExpression)
{
type = binaryExpression.Right as TypeSyntax;
expression = binaryExpression.Left;
}

if (type == null)
{
return null;
}

INamedTypeSymbol namedTypeSymbol = context.SemanticModel.GetSymbolInfo(type).Symbol as INamedTypeSymbol;
if (namedTypeSymbol == null)
{
return null;
}

// Only generate if the type being cast to is a WinRT runtime class.
var winrtAttributeType = context.SemanticModel.Compilation.GetTypeByMetadataName("WinRT.WindowsRuntimeTypeAttribute");
if (namedTypeSymbol.TypeKind != TypeKind.Class || !GeneratorHelper.HasAttributeWithType(namedTypeSymbol, winrtAttributeType))
{
return null;
}

// Avoid cases where the type to be cast from is unknown, or can be done purely through static metadata.
// That is, the type of the expression inherits from the type to be cast to,
// as we know the cast will always work.
var sourceType = context.SemanticModel.GetTypeInfo(expression).Type;
if (sourceType == null || GeneratorHelper.IsDerivedFromType(sourceType, namedTypeSymbol))
{
return null;
}

// Return the fully qualified name of the type to be cast to.
return namedTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
})
.Where(static x => x is not null)
.Collect()
.Select(static (x, _) => x.Distinct())
.Combine(assemblyName);

context.RegisterImplementationSourceOutput(castsToWinRTClasses, (spc, typesAndAssemblyName) =>
{
if (typesAndAssemblyName.Left.Count() == 0)
{
// Don't generate anything if there are no casts.
return;
}

StringBuilder builder = new();
builder.AppendLine($$"""
namespace WinRT.{{typesAndAssemblyName.Right}}CastSupport
{
internal static class CastSupport
{
[global::System.Runtime.CompilerServices.ModuleInitializer]
internal static void InitializeCastSupport()
{
""");
foreach (string fullyQualifiedType in typesAndAssemblyName.Left)
{
string typeofString = fullyQualifiedType.StartsWith("global::") ? fullyQualifiedType : "global::" + fullyQualifiedType;
string runtimeClassName = fullyQualifiedType.StartsWith("global::") ? fullyQualifiedType[8..] : fullyQualifiedType;
builder.Append(" global::WinRT.CastSupport.RegisterTypeName(\"");
builder.Append(runtimeClassName);
builder.Append("\", typeof(");
builder.Append(typeofString);
builder.AppendLine("));");
}

builder.AppendLine("""
}
}
}
""");

spc.AddSource("WinRTCastExtensions.g.cs", builder.ToString());
});
}

// Restrict to non-projected classes which can be instantiated
Expand Down
23 changes: 23 additions & 0 deletions src/Authoring/WinRT.SourceGenerator/Helper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1159,6 +1159,29 @@ public static string GetAbiMarshalerType(string type, string abiType, TypeKind k
throw new ArgumentException();
}

#nullable enable
/// <summary>
/// Checks whether a given type is derived from a specified type.
/// </summary>
/// <param name="typeSymbol">The input <see cref="ITypeSymbol"/> instance to check.</param>
/// <param name="baseTypeSymbol">The base type to look for.</param>
/// <returns>Whether <paramref name="typeSymbol"/> derives from <paramref name="baseTypeSymbol"/>.</returns>
public static bool IsDerivedFromType(ITypeSymbol typeSymbol, ITypeSymbol baseTypeSymbol)
{
for (ITypeSymbol? currentSymbol = typeSymbol.BaseType;
currentSymbol is { SpecialType: not SpecialType.System_Object };
currentSymbol = currentSymbol.BaseType)
{
if (SymbolEqualityComparer.Default.Equals(currentSymbol, baseTypeSymbol))
{
return true;
}
}

return false;
}
#nullable disable

public static string EscapeAssemblyNameForIdentifier(string typeName)
{
return Regex.Replace(typeName, """[^a-zA-Z0-9_]""", "_");
Expand Down
Loading