From ed489cd0fb68c92ac6bea384654a9a3fc02e0464 Mon Sep 17 00:00:00 2001
From: Dongle <29563098+dongle-the-gadget@users.noreply.github.com>
Date: Mon, 18 Nov 2024 19:30:04 +0700
Subject: [PATCH 1/3] Preserve type metadata for casts to WinRT runtime
classes.
---
.../WinRT.SourceGenerator/AotOptimizer.cs | 109 ++
src/Authoring/WinRT.SourceGenerator/Helper.cs | 23 +
.../RcwReflectionFallbackGenerator.cs | 601 ++++++-----
src/WinRT.Runtime/CastSupport.cs | 18 +
src/WinRT.Runtime/TypeNameSupport.cs | 955 +++++++++---------
5 files changed, 920 insertions(+), 786 deletions(-)
create mode 100644 src/WinRT.Runtime/CastSupport.cs
diff --git a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
index 7a4a0b4bd..8a18c3d1b 100644
--- a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
+++ b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
@@ -167,6 +167,115 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
.Collect()
.Combine(properties);
context.RegisterImplementationSourceOutput(bindableCustomPropertyAttributes, GenerateBindableCustomProperties);
+
+ // Generate type metadata for casts to WinRT runtime classes.
+ var castsToWinRTClasses = context.SyntaxProvider.CreateSyntaxProvider(
+ static (node, _) => node.IsKind(SyntaxKind.CastExpression) || node.IsKind(SyntaxKind.AsExpression) || node.IsKind(SyntaxKind.IsExpression) || node.IsKind(SyntaxKind.IsPatternExpression),
+ static (context, _) =>
+ {
+ TypeSyntax type = null;
+ ExpressionSyntax expression = null;
+
+ // Try to retrieve the type being cast to, and the expression to be cast.
+ if (context.Node is CastExpressionSyntax castExpression)
+ {
+ type = castExpression.Type;
+ expression = castExpression.Expression;
+ }
+ else if (context.Node is IsPatternExpressionSyntax patternExpression)
+ {
+ expression = patternExpression.Expression;
+ if (patternExpression.Pattern is DeclarationPatternSyntax declarationPattern)
+ {
+ type = declarationPattern.Type;
+ }
+ else if (patternExpression.Pattern is RecursivePatternSyntax recursivePattern)
+ {
+ type = recursivePattern.Type;
+ }
+ else
+ {
+ return null;
+ }
+ }
+ else if (context.Node is BinaryExpressionSyntax binaryExpression)
+ {
+ type = binaryExpression.Right as TypeSyntax;
+ expression = binaryExpression.Left;
+ }
+
+ if (type == null)
+ {
+ return null;
+ }
+
+ INamedTypeSymbol namedTypeSymbol = context.SemanticModel.GetSymbolInfo(type).Symbol as INamedTypeSymbol;
+ if (namedTypeSymbol == null)
+ {
+ return null;
+ }
+
+ // Only generate if the type being cast to is a WinRT runtime class.
+ var winrtAttributeType = context.SemanticModel.Compilation.GetTypeByMetadataName("WinRT.WindowsRuntimeTypeAttribute");
+ if (namedTypeSymbol.TypeKind != TypeKind.Class || !GeneratorHelper.HasAttributeWithType(namedTypeSymbol, winrtAttributeType))
+ {
+ return null;
+ }
+
+ // Avoid cases where the type to be cast from is unknown, or can be done purely through static metadata.
+ // That is, the type to be cast to inherits from the type of the expression to be cast,
+ // as we know the cast will always work.
+ var sourceType = context.SemanticModel.GetTypeInfo(expression).Type;
+ if (sourceType == null || GeneratorHelper.IsDerivedFromType(sourceType, namedTypeSymbol))
+ {
+ return null;
+ }
+
+ // Return the fully qualified name of the type to be cast to.
+ return namedTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+ })
+ .Where(static x => x is not null)
+ .Collect()
+ .Select(static (x, _) => x.Distinct())
+ .Combine(assemblyName);
+
+ context.RegisterImplementationSourceOutput(castsToWinRTClasses, (spc, typesAndAssemblyName) =>
+ {
+ if (typesAndAssemblyName.Left.Count() == 0)
+ {
+ // Don't generate anything if there are no casts.
+ return;
+ }
+
+ StringBuilder builder = new();
+ builder.AppendLine($$"""
+ namespace WinRT.{{typesAndAssemblyName.Right}}CastSupport
+ {
+ internal static class CastSupport
+ {
+ [global::System.Runtime.CompilerServices.ModuleInitializer]
+ internal static void InitializeCastSupport()
+ {
+ """);
+ foreach (string fullyQualifiedType in typesAndAssemblyName.Left)
+ {
+ string typeofString = fullyQualifiedType.StartsWith("global::") ? fullyQualifiedType : "global::" + fullyQualifiedType;
+ string runtimeClassName = fullyQualifiedType.StartsWith("global::") ? fullyQualifiedType[8..] : fullyQualifiedType;
+ builder.Append(" global::WinRT.CastSupport.RegisterTypeName(\"");
+ builder.Append(runtimeClassName);
+ builder.Append("\", typeof(");
+ builder.Append(typeofString);
+ builder.AppendLine("));");
+ }
+
+ builder.AppendLine("""
+ }
+ }
+ }
+ """);
+
+ spc.AddSource("WinRTCastExtensions.g.cs", builder.ToString());
+ });
}
// Restrict to non-projected classes which can be instantiated
diff --git a/src/Authoring/WinRT.SourceGenerator/Helper.cs b/src/Authoring/WinRT.SourceGenerator/Helper.cs
index 83f5e67cf..db2243582 100644
--- a/src/Authoring/WinRT.SourceGenerator/Helper.cs
+++ b/src/Authoring/WinRT.SourceGenerator/Helper.cs
@@ -1159,6 +1159,29 @@ public static string GetAbiMarshalerType(string type, string abiType, TypeKind k
throw new ArgumentException();
}
+#nullable enable
+ ///
+ /// Checks whether a given type is derived from a specified type.
+ ///
+ /// The input instance to check.
+ /// The base type to look for.
+ /// Whether derives from .
+ public static bool IsDerivedFromType(ITypeSymbol typeSymbol, ITypeSymbol baseTypeSymbol)
+ {
+ for (ITypeSymbol? currentSymbol = typeSymbol.BaseType;
+ currentSymbol is { SpecialType: not SpecialType.System_Object };
+ currentSymbol = currentSymbol.BaseType)
+ {
+ if (SymbolEqualityComparer.Default.Equals(currentSymbol, baseTypeSymbol))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+#nullable disable
+
public static string EscapeAssemblyNameForIdentifier(string typeName)
{
return Regex.Replace(typeName, """[^a-zA-Z0-9_]""", "_");
diff --git a/src/Authoring/WinRT.SourceGenerator/RcwReflectionFallbackGenerator.cs b/src/Authoring/WinRT.SourceGenerator/RcwReflectionFallbackGenerator.cs
index e0360602e..432449b93 100644
--- a/src/Authoring/WinRT.SourceGenerator/RcwReflectionFallbackGenerator.cs
+++ b/src/Authoring/WinRT.SourceGenerator/RcwReflectionFallbackGenerator.cs
@@ -1,339 +1,318 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
-using System;
-using System.Collections.Generic;
-using System.Collections.Immutable;
-using System.Linq;
-using System.Text;
-using Microsoft.CodeAnalysis;
-using WinRT.SourceGenerator;
-
-#nullable enable
-
-namespace Generator;
-
-[Generator]
-public sealed class RcwReflectionFallbackGenerator : IIncrementalGenerator
-{
- ///
- public void Initialize(IncrementalGeneratorInitializationContext context)
- {
- // Gather all PE references from the current compilation
- IncrementalValuesProvider executableReferences =
- context.CompilationProvider
- .SelectMany(static (compilation, token) =>
- {
- var executableReferences = ImmutableArray.CreateBuilder();
-
- foreach (MetadataReference metadataReference in compilation.References)
- {
- // We are only interested in PE references (not project references)
- if (metadataReference is not PortableExecutableReference executableReference)
- {
- continue;
- }
-
- executableReferences.Add(new EquatablePortableExecutableReference(executableReference, compilation));
- }
-
- return executableReferences.ToImmutable();
- });
-
- // Get whether the current project is an .exe
- IncrementalValueProvider isOutputTypeExe = context.CompilationProvider.Select(static (compilation, token) =>
- {
- return compilation.Options.OutputKind is OutputKind.ConsoleApplication or OutputKind.WindowsApplication or OutputKind.WindowsRuntimeApplication;
- });
-
- // Get whether the generator is explicitly set as opt-in
- IncrementalValueProvider isGeneratorForceOptIn = context.AnalyzerConfigOptionsProvider.Select(static (options, token) =>
- {
- return options.GetCsWinRTRcwFactoryFallbackGeneratorForceOptIn();
- });
-
- // Get whether the generator is explicitly set as opt-out
- IncrementalValueProvider isGeneratorForceOptOut = context.AnalyzerConfigOptionsProvider.Select(static (options, token) =>
- {
- return options.GetCsWinRTRcwFactoryFallbackGeneratorForceOptOut();
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Text;
+using Microsoft.CodeAnalysis;
+using WinRT.SourceGenerator;
+
+#nullable enable
+
+namespace Generator;
+
+[Generator]
+public sealed class RcwReflectionFallbackGenerator : IIncrementalGenerator
+{
+ ///
+ public void Initialize(IncrementalGeneratorInitializationContext context)
+ {
+ // Gather all PE references from the current compilation
+ IncrementalValuesProvider executableReferences =
+ context.CompilationProvider
+ .SelectMany(static (compilation, token) =>
+ {
+ var executableReferences = ImmutableArray.CreateBuilder();
+
+ foreach (MetadataReference metadataReference in compilation.References)
+ {
+ // We are only interested in PE references (not project references)
+ if (metadataReference is not PortableExecutableReference executableReference)
+ {
+ continue;
+ }
+
+ executableReferences.Add(new EquatablePortableExecutableReference(executableReference, compilation));
+ }
+
+ return executableReferences.ToImmutable();
});
- IncrementalValueProvider csWinRTAotWarningEnabled = context.AnalyzerConfigOptionsProvider.Select(static (options, token) =>
- {
- return options.GetCsWinRTAotWarningLevel() >= 1;
- });
-
- // Get whether the generator should actually run or not
- IncrementalValueProvider isGeneratorEnabled =
- isOutputTypeExe
- .Combine(isGeneratorForceOptIn)
- .Combine(isGeneratorForceOptOut)
- .Select(static (flags, token) => (flags.Left.Left || flags.Left.Right) && !flags.Right);
-
- // Bypass all items if the flag is not set
- IncrementalValuesProvider<(EquatablePortableExecutableReference Value, bool)> enabledExecutableReferences =
- executableReferences
- .Combine(isGeneratorEnabled)
- .Where(static item => item.Right);
-
- // Get all the names of the projected types to root
- IncrementalValuesProvider> executableTypeNames = enabledExecutableReferences.Select(static (executableReference, token) =>
- {
- Compilation compilation = executableReference.Value.GetCompilationUnsafe();
-
- // We only care about resolved assembly symbols (this should always be the case anyway)
- if (compilation.GetAssemblyOrModuleSymbol(executableReference.Value.Reference) is not IAssemblySymbol assemblySymbol)
- {
- return EquatableArray.FromImmutableArray(ImmutableArray.Empty);
- }
-
- // If the assembly is not an old projections assembly, we have nothing to do
- if (!GeneratorHelper.IsOldProjectionAssembly(assemblySymbol))
- {
- return EquatableArray.FromImmutableArray(ImmutableArray.Empty);
- }
-
- token.ThrowIfCancellationRequested();
-
- ITypeSymbol attributeSymbol = compilation.GetTypeByMetadataName("System.Attribute")!;
- ITypeSymbol windowsRuntimeTypeAttributeSymbol = compilation.GetTypeByMetadataName("WinRT.WindowsRuntimeTypeAttribute")!;
-
+ // Get whether the current project is an .exe
+ IncrementalValueProvider isOutputTypeExe = context.CompilationProvider.Select(static (compilation, token) =>
+ {
+ return compilation.Options.OutputKind is OutputKind.ConsoleApplication or OutputKind.WindowsApplication or OutputKind.WindowsRuntimeApplication;
+ });
+
+ // Get whether the generator is explicitly set as opt-in
+ IncrementalValueProvider isGeneratorForceOptIn = context.AnalyzerConfigOptionsProvider.Select(static (options, token) =>
+ {
+ return options.GetCsWinRTRcwFactoryFallbackGeneratorForceOptIn();
+ });
+
+ // Get whether the generator is explicitly set as opt-out
+ IncrementalValueProvider isGeneratorForceOptOut = context.AnalyzerConfigOptionsProvider.Select(static (options, token) =>
+ {
+ return options.GetCsWinRTRcwFactoryFallbackGeneratorForceOptOut();
+ });
+
+ IncrementalValueProvider csWinRTAotWarningEnabled = context.AnalyzerConfigOptionsProvider.Select(static (options, token) =>
+ {
+ return options.GetCsWinRTAotWarningLevel() >= 1;
+ });
+
+ // Get whether the generator should actually run or not
+ IncrementalValueProvider isGeneratorEnabled =
+ isOutputTypeExe
+ .Combine(isGeneratorForceOptIn)
+ .Combine(isGeneratorForceOptOut)
+ .Select(static (flags, token) => (flags.Left.Left || flags.Left.Right) && !flags.Right);
+
+ // Bypass all items if the flag is not set
+ IncrementalValuesProvider<(EquatablePortableExecutableReference Value, bool)> enabledExecutableReferences =
+ executableReferences
+ .Combine(isGeneratorEnabled)
+ .Where(static item => item.Right);
+
+ // Get all the names of the projected types to root
+ IncrementalValuesProvider> executableTypeNames = enabledExecutableReferences.Select(static (executableReference, token) =>
+ {
+ Compilation compilation = executableReference.Value.GetCompilationUnsafe();
+
+ // We only care about resolved assembly symbols (this should always be the case anyway)
+ if (compilation.GetAssemblyOrModuleSymbol(executableReference.Value.Reference) is not IAssemblySymbol assemblySymbol)
+ {
+ return EquatableArray.FromImmutableArray(ImmutableArray.Empty);
+ }
+
+ // If the assembly is not an old projections assembly, we have nothing to do
+ if (!GeneratorHelper.IsOldProjectionAssembly(assemblySymbol))
+ {
+ return EquatableArray.FromImmutableArray(ImmutableArray.Empty);
+ }
+
+ token.ThrowIfCancellationRequested();
+
+ ITypeSymbol attributeSymbol = compilation.GetTypeByMetadataName("System.Attribute")!;
+ ITypeSymbol windowsRuntimeTypeAttributeSymbol = compilation.GetTypeByMetadataName("WinRT.WindowsRuntimeTypeAttribute")!;
+
ImmutableArray.Builder executableTypeNames = ImmutableArray.CreateBuilder();
-
- // Process all type symbols in the current assembly
- foreach (INamedTypeSymbol typeSymbol in VisitNamedTypeSymbolsExceptABI(assemblySymbol))
- {
- token.ThrowIfCancellationRequested();
-
- // We only care about public or internal classes
- if (typeSymbol is not { TypeKind: TypeKind.Class, DeclaredAccessibility: Accessibility.Public or Accessibility.Internal })
- {
- continue;
- }
-
- // Ignore static types (we only care about actual RCW types we can instantiate)
- if (typeSymbol.IsStatic)
- {
- continue;
- }
-
- // Ignore attribute types (they're never instantiated like normal RCWs)
- if (IsDerivedFromType(typeSymbol, attributeSymbol))
- {
- continue;
- }
-
- // If the type is not a generated projected type, do nothing
- if (!GeneratorHelper.HasAttributeWithType(typeSymbol, windowsRuntimeTypeAttributeSymbol))
- {
- continue;
- }
-
- // Double check we can in fact access this type (or we can't reference it)
- if (!compilation.IsSymbolAccessibleWithin(typeSymbol, compilation.Assembly))
- {
- continue;
- }
-
- var typeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
-
- // These types are in the existing WinUI projection, but have been moved to the Windows SDK projection.
- // So if we see those, we want to ignore them.
+
+ // Process all type symbols in the current assembly
+ foreach (INamedTypeSymbol typeSymbol in VisitNamedTypeSymbolsExceptABI(assemblySymbol))
+ {
+ token.ThrowIfCancellationRequested();
+
+ // We only care about public or internal classes
+ if (typeSymbol is not { TypeKind: TypeKind.Class, DeclaredAccessibility: Accessibility.Public or Accessibility.Internal })
+ {
+ continue;
+ }
+
+ // Ignore static types (we only care about actual RCW types we can instantiate)
+ if (typeSymbol.IsStatic)
+ {
+ continue;
+ }
+
+ // Ignore attribute types (they're never instantiated like normal RCWs)
+ if (GeneratorHelper.IsDerivedFromType(typeSymbol, attributeSymbol))
+ {
+ continue;
+ }
+
+ // If the type is not a generated projected type, do nothing
+ if (!GeneratorHelper.HasAttributeWithType(typeSymbol, windowsRuntimeTypeAttributeSymbol))
+ {
+ continue;
+ }
+
+ // Double check we can in fact access this type (or we can't reference it)
+ if (!compilation.IsSymbolAccessibleWithin(typeSymbol, compilation.Assembly))
+ {
+ continue;
+ }
+
+ var typeName = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
+
+ // These types are in the existing WinUI projection, but have been moved to the Windows SDK projection.
+ // So if we see those, we want to ignore them.
if (typeName == "global::Windows.UI.Text.ContentLinkInfo" ||
typeName == "global::Windows.UI.Text.RichEditTextDocument" ||
typeName == "global::Windows.UI.Text.RichEditTextRange")
{
continue;
- }
-
- // Check if we are able to resolve the type using GetTypeByMetadataName. If not,
- // it indicates there are multiple definitions of this type in the references
- // and us emitting a dependency on this type would cause compiler error. So emit
- // a warning instead.
+ }
+
+ // Check if we are able to resolve the type using GetTypeByMetadataName. If not,
+ // it indicates there are multiple definitions of this type in the references
+ // and us emitting a dependency on this type would cause compiler error. So emit
+ // a warning instead.
bool hasMultipleDefinitions = compilation.GetTypeByMetadataName(GeneratorHelper.TrimGlobalFromTypeName(typeName)) is null;
executableTypeNames.Add(new RcwReflectionFallbackType(typeName, hasMultipleDefinitions));
- }
-
- token.ThrowIfCancellationRequested();
-
- return EquatableArray.FromImmutableArray(executableTypeNames.ToImmutable());
- });
-
- // Combine all names into a single sequence
- IncrementalValueProvider<(ImmutableArray, bool)> projectedTypeNamesAndAotWarningEnabled =
- executableTypeNames
- .Where(static names => !names.IsEmpty)
- .SelectMany(static (executableTypeNames, token) => executableTypeNames.AsImmutableArray())
- .Collect()
- .Combine(csWinRTAotWarningEnabled);
-
- // Generate the [DynamicDependency] attributes
- context.RegisterImplementationSourceOutput(projectedTypeNamesAndAotWarningEnabled, static (SourceProductionContext context, (ImmutableArray projectedTypeNames, bool csWinRTAotWarningEnabled) value) =>
- {
- if (value.projectedTypeNames.IsEmpty)
- {
- return;
- }
-
- StringBuilder builder = new();
-
- builder.AppendLine("""
- //
- #pragma warning disable
-
- namespace WinRT
- {
- using global::System.Runtime.CompilerServices;
- using global::System.Diagnostics.CodeAnalysis;
-
- ///
- /// Roots RCW types for assemblies referencing old projections.
- /// It is recommended to update those, to get binary size savings.
- ///
- internal static class RcwFallbackInitializer
- {
- ///
- /// Roots all dependent RCW types.
- ///
- [ModuleInitializer]
- """);
-
- bool emittedDynamicDependency = false;
- foreach (RcwReflectionFallbackType projectedTypeName in value.projectedTypeNames)
- {
- // If there are multiple definitions of the type, emitting a dependency would result in a compiler error.
- // So instead, emit a diagnostic for it.
+ }
+
+ token.ThrowIfCancellationRequested();
+
+ return EquatableArray.FromImmutableArray(executableTypeNames.ToImmutable());
+ });
+
+ // Combine all names into a single sequence
+ IncrementalValueProvider<(ImmutableArray, bool)> projectedTypeNamesAndAotWarningEnabled =
+ executableTypeNames
+ .Where(static names => !names.IsEmpty)
+ .SelectMany(static (executableTypeNames, token) => executableTypeNames.AsImmutableArray())
+ .Collect()
+ .Combine(csWinRTAotWarningEnabled);
+
+ // Generate the [DynamicDependency] attributes
+ context.RegisterImplementationSourceOutput(projectedTypeNamesAndAotWarningEnabled, static (SourceProductionContext context, (ImmutableArray projectedTypeNames, bool csWinRTAotWarningEnabled) value) =>
+ {
+ if (value.projectedTypeNames.IsEmpty)
+ {
+ return;
+ }
+
+ StringBuilder builder = new();
+
+ builder.AppendLine("""
+ //
+ #pragma warning disable
+
+ namespace WinRT
+ {
+ using global::System.Runtime.CompilerServices;
+ using global::System.Diagnostics.CodeAnalysis;
+
+ ///
+ /// Roots RCW types for assemblies referencing old projections.
+ /// It is recommended to update those, to get binary size savings.
+ ///
+ internal static class RcwFallbackInitializer
+ {
+ ///
+ /// Roots all dependent RCW types.
+ ///
+ [ModuleInitializer]
+ """);
+
+ bool emittedDynamicDependency = false;
+ foreach (RcwReflectionFallbackType projectedTypeName in value.projectedTypeNames)
+ {
+ // If there are multiple definitions of the type, emitting a dependency would result in a compiler error.
+ // So instead, emit a diagnostic for it.
if (projectedTypeName.HasMultipleDefinitions)
{
var diagnosticDescriptor = value.csWinRTAotWarningEnabled ?
WinRTRules.ClassNotAotCompatibleOldProjectionMultipleInstancesWarning : WinRTRules.ClassNotAotCompatibleOldProjectionMultipleInstancesInfo;
// We have no location to emit the diagnostic as this is just a reference we detect.
context.ReportDiagnostic(Diagnostic.Create(diagnosticDescriptor, null, GeneratorHelper.TrimGlobalFromTypeName(projectedTypeName.TypeName)));
- }
+ }
else
{
emittedDynamicDependency = true;
builder.Append(" [DynamicDependency(DynamicallyAccessedMemberTypes.NonPublicConstructors, typeof(");
builder.Append(projectedTypeName.TypeName);
builder.AppendLine("))]");
- }
- }
-
- builder.Append("""
- public static void InitializeRcwFallback()
- {
- }
- }
- }
- """);
-
+ }
+ }
+
+ builder.Append("""
+ public static void InitializeRcwFallback()
+ {
+ }
+ }
+ }
+ """);
+
if (emittedDynamicDependency)
{
context.AddSource("RcwFallbackInitializer.g.cs", builder.ToString());
- }
- });
- }
-
- ///
- /// Visits all named type symbols in a given assembly, except for ABI types.
- ///
- /// The assembly to inspect.
- /// All named type symbols in , except for ABI types.
- private static IEnumerable VisitNamedTypeSymbolsExceptABI(IAssemblySymbol assemblySymbol)
- {
- static IEnumerable Visit(INamespaceOrTypeSymbol symbol)
- {
- foreach (ISymbol memberSymbol in symbol.GetMembers())
- {
- // Visit the current symbol if it's a type symbol
- if (memberSymbol is INamedTypeSymbol typeSymbol)
- {
- yield return typeSymbol;
- }
- else if (memberSymbol is INamespaceSymbol { Name: not ("ABI" or "WinRT") } namespaceSymbol)
- {
- // If the symbol is a namespace, also recurse (ignore the ABI namespaces)
- foreach (INamedTypeSymbol nestedTypeSymbol in Visit(namespaceSymbol))
- {
- yield return nestedTypeSymbol;
- }
- }
- }
- }
-
- return Visit(assemblySymbol.GlobalNamespace);
- }
-
- ///
- /// Checks whether a given type is derived from a specified type.
- ///
- /// The input instance to check.
- /// The base type to look for.
- /// Whether derives from .
- private static bool IsDerivedFromType(ITypeSymbol typeSymbol, ITypeSymbol baseTypeSymbol)
- {
- for (ITypeSymbol? currentSymbol = typeSymbol.BaseType;
- currentSymbol is { SpecialType: not SpecialType.System_Object };
- currentSymbol = currentSymbol.BaseType)
- {
- if (SymbolEqualityComparer.Default.Equals(currentSymbol, baseTypeSymbol))
- {
- return true;
- }
- }
-
- return false;
- }
-
- ///
- /// An equatable type that weakly references a object.
- ///
- /// The object to wrap.
- /// The instance where comes from.
- public sealed class EquatablePortableExecutableReference(
- PortableExecutableReference executableReference,
- Compilation compilation) : IEquatable
- {
- ///
- /// A weak reference to the object owning .
- ///
- private readonly WeakReference Compilation = new(compilation);
-
- ///
- /// Gets the object for this instance.
- ///
- public PortableExecutableReference Reference { get; } = executableReference;
-
- ///
- /// Gets the object for .
- ///
- /// The object for .
- /// Thrown if the object has been collected.
- ///
- /// This method should only be used from incremental steps immediately following a change in the metadata reference
- /// being used, as that would guarantee that that object would be alive.
- ///
- public Compilation GetCompilationUnsafe()
- {
- if (Compilation.TryGetTarget(out Compilation? compilation))
- {
- return compilation;
- }
-
- throw new InvalidOperationException("No compilation object is available.");
- }
-
- ///
- public bool Equals(EquatablePortableExecutableReference other)
- {
- if (other is null)
- {
- return false;
- }
-
- return other.Reference.GetMetadataId() == Reference.GetMetadataId();
- }
+ }
+ });
+ }
+
+ ///
+ /// Visits all named type symbols in a given assembly, except for ABI types.
+ ///
+ /// The assembly to inspect.
+ /// All named type symbols in , except for ABI types.
+ private static IEnumerable VisitNamedTypeSymbolsExceptABI(IAssemblySymbol assemblySymbol)
+ {
+ static IEnumerable Visit(INamespaceOrTypeSymbol symbol)
+ {
+ foreach (ISymbol memberSymbol in symbol.GetMembers())
+ {
+ // Visit the current symbol if it's a type symbol
+ if (memberSymbol is INamedTypeSymbol typeSymbol)
+ {
+ yield return typeSymbol;
+ }
+ else if (memberSymbol is INamespaceSymbol { Name: not ("ABI" or "WinRT") } namespaceSymbol)
+ {
+ // If the symbol is a namespace, also recurse (ignore the ABI namespaces)
+ foreach (INamedTypeSymbol nestedTypeSymbol in Visit(namespaceSymbol))
+ {
+ yield return nestedTypeSymbol;
+ }
+ }
+ }
+ }
+
+ return Visit(assemblySymbol.GlobalNamespace);
+ }
+
+ ///
+ /// An equatable type that weakly references a object.
+ ///
+ /// The object to wrap.
+ /// The instance where comes from.
+ public sealed class EquatablePortableExecutableReference(
+ PortableExecutableReference executableReference,
+ Compilation compilation) : IEquatable
+ {
+ ///
+ /// A weak reference to the object owning .
+ ///
+ private readonly WeakReference Compilation = new(compilation);
+
+ ///
+ /// Gets the object for this instance.
+ ///
+ public PortableExecutableReference Reference { get; } = executableReference;
+
+ ///
+ /// Gets the object for .
+ ///
+ /// The object for .
+ /// Thrown if the object has been collected.
+ ///
+ /// This method should only be used from incremental steps immediately following a change in the metadata reference
+ /// being used, as that would guarantee that that object would be alive.
+ ///
+ public Compilation GetCompilationUnsafe()
+ {
+ if (Compilation.TryGetTarget(out Compilation? compilation))
+ {
+ return compilation;
+ }
+
+ throw new InvalidOperationException("No compilation object is available.");
+ }
+
+ ///
+ public bool Equals(EquatablePortableExecutableReference other)
+ {
+ if (other is null)
+ {
+ return false;
+ }
+
+ return other.Reference.GetMetadataId() == Reference.GetMetadataId();
+ }
}
internal readonly record struct RcwReflectionFallbackType(string TypeName, bool HasMultipleDefinitions);
-}
+}
diff --git a/src/WinRT.Runtime/CastSupport.cs b/src/WinRT.Runtime/CastSupport.cs
new file mode 100644
index 000000000..229525a9f
--- /dev/null
+++ b/src/WinRT.Runtime/CastSupport.cs
@@ -0,0 +1,18 @@
+using System;
+using System.ComponentModel;
+
+namespace WinRT
+{
+ [EditorBrowsable(EditorBrowsableState.Never)]
+ public static class CastSupport
+ {
+ ///
+ /// Register a runtime class in the cache.
+ ///
+ /// The runtime class name to register.
+ /// The runtime class type to be registered.
+ /// This method is only meant to be used in AotOptimizer, in order to allow casting from to a WinRT class.
+ public static void RegisterTypeName(string runtimeClassName, Type runtimeClass)
+ => TypeNameSupport.RegisterTypeName(runtimeClassName, runtimeClass);
+ }
+}
diff --git a/src/WinRT.Runtime/TypeNameSupport.cs b/src/WinRT.Runtime/TypeNameSupport.cs
index 599e086a6..c0aacd922 100644
--- a/src/WinRT.Runtime/TypeNameSupport.cs
+++ b/src/WinRT.Runtime/TypeNameSupport.cs
@@ -1,98 +1,103 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
-using System;
-using System.Collections.Concurrent;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.Diagnostics.CodeAnalysis;
-using System.Reflection;
-using System.Runtime.CompilerServices;
-using System.Text;
-
-namespace WinRT
-{
- [Flags]
- internal enum TypeNameGenerationFlags
- {
- None = 0,
-
- ///
- /// Generate the name of the type as if it was boxed in an object.
- ///
- GenerateBoxedName = 0x1,
-
- ///
- /// Don't output a type name of a custom .NET type. Generate a compatible WinRT type name if needed.
- ///
- ForGetRuntimeClassName = 0x2,
- }
-
- internal static class TypeNameSupport
- {
- private static readonly List projectionAssemblies = new List();
- private static readonly List> projectionTypeNameToBaseTypeNameMappings = new List>();
- private static readonly ConcurrentDictionary typeNameCache = new ConcurrentDictionary(StringComparer.Ordinal) { ["TrackerCollection"] = null };
- private static readonly ConcurrentDictionary baseRcwTypeCache = new ConcurrentDictionary(StringComparer.Ordinal) { ["TrackerCollection"] = null };
-
- public static void RegisterProjectionAssembly(Assembly assembly)
- {
- projectionAssemblies.Add(assembly);
- }
-
- public static void RegisterProjectionTypeBaseTypeMapping(IDictionary typeNameToBaseTypeNameMapping)
- {
- projectionTypeNameToBaseTypeNameMappings.Add(typeNameToBaseTypeNameMapping);
- }
-
- public static Type FindRcwTypeByNameCached(string runtimeClassName)
- {
- // Try to get the given type name. If it is not found, the type might have been trimmed.
- // Due to that, check if one of the base types exists and if so use that instead for the RCW type.
- var rcwType = FindTypeByNameCached(runtimeClassName);
- if (rcwType is null)
- {
- rcwType = baseRcwTypeCache.GetOrAdd(runtimeClassName,
- static (runtimeClassName) =>
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
+using System.Reflection;
+using System.Runtime.CompilerServices;
+using System.Text;
+
+namespace WinRT
+{
+ [Flags]
+ internal enum TypeNameGenerationFlags
+ {
+ None = 0,
+
+ ///
+ /// Generate the name of the type as if it was boxed in an object.
+ ///
+ GenerateBoxedName = 0x1,
+
+ ///
+ /// Don't output a type name of a custom .NET type. Generate a compatible WinRT type name if needed.
+ ///
+ ForGetRuntimeClassName = 0x2,
+ }
+
+ internal static class TypeNameSupport
+ {
+ private static readonly List projectionAssemblies = new List();
+ private static readonly List> projectionTypeNameToBaseTypeNameMappings = new List>();
+ private static readonly ConcurrentDictionary typeNameCache = new ConcurrentDictionary(StringComparer.Ordinal) { ["TrackerCollection"] = null };
+ private static readonly ConcurrentDictionary baseRcwTypeCache = new ConcurrentDictionary(StringComparer.Ordinal) { ["TrackerCollection"] = null };
+
+ public static void RegisterProjectionAssembly(Assembly assembly)
+ {
+ projectionAssemblies.Add(assembly);
+ }
+
+ public static void RegisterProjectionTypeBaseTypeMapping(IDictionary typeNameToBaseTypeNameMapping)
+ {
+ projectionTypeNameToBaseTypeNameMappings.Add(typeNameToBaseTypeNameMapping);
+ }
+
+ public static Type FindRcwTypeByNameCached(string runtimeClassName)
+ {
+ // Try to get the given type name. If it is not found, the type might have been trimmed.
+ // Due to that, check if one of the base types exists and if so use that instead for the RCW type.
+ var rcwType = FindTypeByNameCached(runtimeClassName);
+ if (rcwType is null)
+ {
+ rcwType = baseRcwTypeCache.GetOrAdd(runtimeClassName,
+ static (runtimeClassName) =>
{
// Using for loop to avoid exception from list changing when using for each.
// List is only added to and if any are added while looping, we can ignore those.
- int count = projectionTypeNameToBaseTypeNameMappings.Count;
- for (int i = 0; i < count; i++)
+ int count = projectionTypeNameToBaseTypeNameMappings.Count;
+ for (int i = 0; i < count; i++)
{
if (projectionTypeNameToBaseTypeNameMappings[i].ContainsKey(runtimeClassName))
{
return FindRcwTypeByNameCached(projectionTypeNameToBaseTypeNameMappings[i][runtimeClassName]);
}
- }
-
- return null;
- });
- }
-
- return rcwType;
- }
-
- ///
- /// Parses and loads the given type name, if not found in the cache.
- ///
- /// The runtime class name to attempt to parse.
- /// The type, if found. Null otherwise
- public static Type FindTypeByNameCached(string runtimeClassName)
- {
- return typeNameCache.GetOrAdd(runtimeClassName,
- static (runtimeClassName) =>
- {
- Type implementationType = null;
- try
- {
- implementationType = FindTypeByName(runtimeClassName.AsSpan()).type;
- }
- catch (Exception)
- {
- }
- return implementationType;
- });
+ }
+
+ return null;
+ });
+ }
+
+ return rcwType;
+ }
+
+ ///
+ /// Parses and loads the given type name, if not found in the cache.
+ ///
+ /// The runtime class name to attempt to parse.
+ /// The type, if found. Null otherwise
+ public static Type FindTypeByNameCached(string runtimeClassName)
+ {
+ return typeNameCache.GetOrAdd(runtimeClassName,
+ static (runtimeClassName) =>
+ {
+ Type implementationType = null;
+ try
+ {
+ implementationType = FindTypeByName(runtimeClassName.AsSpan()).type;
+ }
+ catch (Exception)
+ {
+ }
+ return implementationType;
+ });
+ }
+
+ public static void RegisterTypeName(string runtimeClassName, Type runtimeClass)
+ {
+ typeNameCache.TryAdd(runtimeClassName, runtimeClass);
}
// Helper to get an exception if the input type is 'IReference' when support for it is disabled
@@ -103,134 +108,134 @@ private static Exception GetExceptionForUnsupportedIReferenceType(ReadOnlySpan' projected type. " +
"This can only be used when support for 'IReference' types is enabled in the CsWinRT configuration. To enable it, " +
"make sure that the 'CsWinRTEnableIReferenceSupport' MSBuild property is not being set to 'false' anywhere.");
- }
-
- ///
- /// Parse the first full type name within the provided span.
- ///
- /// The runtime class name to attempt to parse.
- /// A tuple containing the resolved type and the index of the end of the resolved type name.
- public static (Type type, int remaining) FindTypeByName(ReadOnlySpan runtimeClassName)
- {
- // Assume that anonymous types are expando objects, whether declared 'dynamic' or not.
- // It may be necessary to detect otherwise and return System.Object.
- if (runtimeClassName.StartsWith("<>f__AnonymousType".AsSpan(), StringComparison.Ordinal))
- {
+ }
+
+ ///
+ /// Parse the first full type name within the provided span.
+ ///
+ /// The runtime class name to attempt to parse.
+ /// A tuple containing the resolved type and the index of the end of the resolved type name.
+ public static (Type type, int remaining) FindTypeByName(ReadOnlySpan runtimeClassName)
+ {
+ // Assume that anonymous types are expando objects, whether declared 'dynamic' or not.
+ // It may be necessary to detect otherwise and return System.Object.
+ if (runtimeClassName.StartsWith("<>f__AnonymousType".AsSpan(), StringComparison.Ordinal))
+ {
if (FeatureSwitches.EnableDynamicObjectsSupport)
{
return (typeof(System.Dynamic.ExpandoObject), 0);
- }
-
- throw new NotSupportedException(
- $"The requested runtime class name is '{runtimeClassName.ToString()}', which maps to a dynamic projected type. " +
- "This can only be used when support for dynamic objects is enabled in the CsWinRT configuration. To enable it, " +
- "make sure that the 'CsWinRTEnableDynamicObjectsSupport' MSBuild property is not being set to 'false' anywhere.");
- }
-
- // PropertySet and ValueSet can return IReference but Nullable is illegal
- if (runtimeClassName.CompareTo("Windows.Foundation.IReference`1".AsSpan(), StringComparison.Ordinal) == 0)
- {
+ }
+
+ throw new NotSupportedException(
+ $"The requested runtime class name is '{runtimeClassName.ToString()}', which maps to a dynamic projected type. " +
+ "This can only be used when support for dynamic objects is enabled in the CsWinRT configuration. To enable it, " +
+ "make sure that the 'CsWinRTEnableDynamicObjectsSupport' MSBuild property is not being set to 'false' anywhere.");
+ }
+
+ // PropertySet and ValueSet can return IReference but Nullable is illegal
+ if (runtimeClassName.CompareTo("Windows.Foundation.IReference`1".AsSpan(), StringComparison.Ordinal) == 0)
+ {
if (FeatureSwitches.EnableIReferenceSupport)
{
return (typeof(ABI.System.Nullable_string), 0);
- }
-
- throw GetExceptionForUnsupportedIReferenceType(runtimeClassName);
- }
-
- if (runtimeClassName.CompareTo("Windows.Foundation.IReference`1".AsSpan(), StringComparison.Ordinal) == 0)
- {
+ }
+
+ throw GetExceptionForUnsupportedIReferenceType(runtimeClassName);
+ }
+
+ if (runtimeClassName.CompareTo("Windows.Foundation.IReference`1".AsSpan(), StringComparison.Ordinal) == 0)
+ {
if (FeatureSwitches.EnableIReferenceSupport)
{
return (typeof(ABI.System.Nullable_Type), 0);
- }
-
- throw GetExceptionForUnsupportedIReferenceType(runtimeClassName);
+ }
+
+ throw GetExceptionForUnsupportedIReferenceType(runtimeClassName);
}
-
- if (runtimeClassName.CompareTo("Windows.Foundation.IReference`1".AsSpan(), StringComparison.Ordinal) == 0)
- {
+
+ if (runtimeClassName.CompareTo("Windows.Foundation.IReference`1".AsSpan(), StringComparison.Ordinal) == 0)
+ {
if (FeatureSwitches.EnableIReferenceSupport)
{
return (typeof(ABI.System.Nullable_Exception), 0);
- }
-
- throw GetExceptionForUnsupportedIReferenceType(runtimeClassName);
- }
-
+ }
+
+ throw GetExceptionForUnsupportedIReferenceType(runtimeClassName);
+ }
+
var (genericTypeName, genericTypes, remaining) = ParseGenericTypeName(runtimeClassName);
if (genericTypeName == null)
{
return (null, -1);
}
- return (FindTypeByNameCore(genericTypeName, genericTypes), remaining);
- }
-
- ///
- /// Resolve a type from the given simple type name and the provided generic parameters.
- ///
- /// The simple type name.
- /// The generic parameters.
- /// The resolved (and instantiated if generic) type.
- ///
- /// We look up the type dynamically because at this point in the stack we can't know
- /// the full type closure of the application.
- ///
-#if NET
- [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Any types which are trimmed are not used by user code and there is fallback logic to handle that.")]
-#endif
- private static Type FindTypeByNameCore(string runtimeClassName, Type[] genericTypes)
- {
- Type resolvedType = Projections.FindCustomTypeForAbiTypeName(runtimeClassName);
-
- if (resolvedType is null)
- {
- if (genericTypes is null)
- {
- Type primitiveType = ResolvePrimitiveType(runtimeClassName);
- if (primitiveType is not null)
- {
- return primitiveType;
- }
- }
-
- // Using for loop to avoid exception from list changing when using for each.
- // List is only added to and if any are added while looping, we can ignore those.
+ return (FindTypeByNameCore(genericTypeName, genericTypes), remaining);
+ }
+
+ ///
+ /// Resolve a type from the given simple type name and the provided generic parameters.
+ ///
+ /// The simple type name.
+ /// The generic parameters.
+ /// The resolved (and instantiated if generic) type.
+ ///
+ /// We look up the type dynamically because at this point in the stack we can't know
+ /// the full type closure of the application.
+ ///
+#if NET
+ [UnconditionalSuppressMessage("Trimming", "IL2026", Justification = "Any types which are trimmed are not used by user code and there is fallback logic to handle that.")]
+#endif
+ private static Type FindTypeByNameCore(string runtimeClassName, Type[] genericTypes)
+ {
+ Type resolvedType = Projections.FindCustomTypeForAbiTypeName(runtimeClassName);
+
+ if (resolvedType is null)
+ {
+ if (genericTypes is null)
+ {
+ Type primitiveType = ResolvePrimitiveType(runtimeClassName);
+ if (primitiveType is not null)
+ {
+ return primitiveType;
+ }
+ }
+
+ // Using for loop to avoid exception from list changing when using for each.
+ // List is only added to and if any are added while looping, we can ignore those.
int count = projectionAssemblies.Count;
- for (int i = 0; i < count; i++)
- {
- Type type = projectionAssemblies[i].GetType(runtimeClassName);
- if (type is not null)
- {
- resolvedType = type;
- break;
- }
- }
- }
-
- if (resolvedType is null)
- {
- foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
- {
- Type type = assembly.GetType(runtimeClassName);
- if (type is not null)
+ for (int i = 0; i < count; i++)
+ {
+ Type type = projectionAssemblies[i].GetType(runtimeClassName);
+ if (type is not null)
+ {
+ resolvedType = type;
+ break;
+ }
+ }
+ }
+
+ if (resolvedType is null)
+ {
+ foreach (var assembly in AppDomain.CurrentDomain.GetAssemblies())
+ {
+ Type type = assembly.GetType(runtimeClassName);
+ if (type is not null)
{
- resolvedType = type;
- break;
- }
- }
- }
-
- if (resolvedType is not null)
- {
- if (genericTypes != null)
+ resolvedType = type;
+ break;
+ }
+ }
+ }
+
+ if (resolvedType is not null)
+ {
+ if (genericTypes != null)
{
return ResolveGenericType(resolvedType, genericTypes, runtimeClassName);
- }
- return resolvedType;
- }
-
- Debug.WriteLine($"FindTypeByNameCore: Unable to find a type named '{runtimeClassName}'");
+ }
+ return resolvedType;
+ }
+
+ Debug.WriteLine($"FindTypeByNameCore: Unable to find a type named '{runtimeClassName}'");
return null;
#if NET
@@ -271,194 +276,194 @@ static Type ResolveGenericType(Type resolvedType, Type[] genericTypes, string ru
}
}
}
-#endif
+#endif
return resolvedType.MakeGenericType(genericTypes);
- }
- }
-
- public static Type ResolvePrimitiveType(string primitiveTypeName)
- {
- return primitiveTypeName switch
- {
- "UInt8" => typeof(byte),
- "Int8" => typeof(sbyte),
- "UInt16" => typeof(ushort),
- "Int16" => typeof(short),
- "UInt32" => typeof(uint),
- "Int32" => typeof(int),
- "UInt64" => typeof(ulong),
- "Int64" => typeof(long),
- "Boolean" => typeof(bool),
- "String" => typeof(string),
- "Char" => typeof(char),
- "Char16" => typeof(char),
- "Single" => typeof(float),
- "Double" => typeof(double),
- "Guid" => typeof(Guid),
- "Object" => typeof(object),
- "TimeSpan" => typeof(TimeSpan),
- _ => null
- };
- }
-
- ///
- /// Parses a type name from the start of a span including its generic parameters.
- ///
- /// A span starting with a type name to parse.
- /// Returns a tuple containing the simple type name of the type, and generic type parameters if they exist, and the index of the end of the type name in the span.
- private static (string genericTypeName, Type[] genericTypes, int remaining) ParseGenericTypeName(ReadOnlySpan partialTypeName)
- {
- int possibleEndOfSimpleTypeName = partialTypeName.IndexOfAny(',', '>');
- int endOfSimpleTypeName = partialTypeName.Length;
- if (possibleEndOfSimpleTypeName != -1)
- {
- endOfSimpleTypeName = possibleEndOfSimpleTypeName;
- }
- var typeName = partialTypeName.Slice(0, endOfSimpleTypeName);
-
- // If the type name doesn't contain a '`', then it isn't a generic type
- // so we can return before starting to parse the generic type list.
- if (!typeName.Contains("`".AsSpan(), StringComparison.Ordinal))
- {
- return (typeName.ToString(), null, endOfSimpleTypeName);
- }
-
- int genericTypeListStart = partialTypeName.IndexOf('<');
- var genericTypeName = partialTypeName.Slice(0, genericTypeListStart);
- var remainingTypeName = partialTypeName.Slice(genericTypeListStart + 1);
- int remainingIndex = genericTypeListStart + 1;
- List genericTypes = new List();
- while (true)
- {
- // Resolve the generic type argument at this point in the parameter list.
- var (genericType, endOfGenericArgument) = FindTypeByName(remainingTypeName);
- if (genericType == null)
- {
- return (null, null, -1);
- }
-
- remainingIndex += endOfGenericArgument;
- genericTypes.Add(genericType);
- remainingTypeName = remainingTypeName.Slice(endOfGenericArgument);
- if (remainingTypeName[0] == ',')
- {
- // Skip the comma and the space in the type name.
- remainingIndex += 2;
- remainingTypeName = remainingTypeName.Slice(2);
- continue;
- }
- else if (remainingTypeName[0] == '>')
- {
- // Skip the space after nested '>'
- var skip = (remainingTypeName.Length > 1 && remainingTypeName[1] == ' ') ? 2 : 1;
- remainingIndex += skip;
- remainingTypeName = remainingTypeName.Slice(skip);
- break;
- }
- else
- {
- throw new InvalidOperationException("The provided type name is invalid.");
- }
- }
- return (genericTypeName.ToString(), genericTypes.ToArray(), partialTypeName.Length - remainingTypeName.Length);
- }
-
- struct VisitedType
- {
- public Type Type { get; set; }
- public bool Covariant { get; set; }
+ }
+ }
+
+ public static Type ResolvePrimitiveType(string primitiveTypeName)
+ {
+ return primitiveTypeName switch
+ {
+ "UInt8" => typeof(byte),
+ "Int8" => typeof(sbyte),
+ "UInt16" => typeof(ushort),
+ "Int16" => typeof(short),
+ "UInt32" => typeof(uint),
+ "Int32" => typeof(int),
+ "UInt64" => typeof(ulong),
+ "Int64" => typeof(long),
+ "Boolean" => typeof(bool),
+ "String" => typeof(string),
+ "Char" => typeof(char),
+ "Char16" => typeof(char),
+ "Single" => typeof(float),
+ "Double" => typeof(double),
+ "Guid" => typeof(Guid),
+ "Object" => typeof(object),
+ "TimeSpan" => typeof(TimeSpan),
+ _ => null
+ };
+ }
+
+ ///
+ /// Parses a type name from the start of a span including its generic parameters.
+ ///
+ /// A span starting with a type name to parse.
+ /// Returns a tuple containing the simple type name of the type, and generic type parameters if they exist, and the index of the end of the type name in the span.
+ private static (string genericTypeName, Type[] genericTypes, int remaining) ParseGenericTypeName(ReadOnlySpan partialTypeName)
+ {
+ int possibleEndOfSimpleTypeName = partialTypeName.IndexOfAny(',', '>');
+ int endOfSimpleTypeName = partialTypeName.Length;
+ if (possibleEndOfSimpleTypeName != -1)
+ {
+ endOfSimpleTypeName = possibleEndOfSimpleTypeName;
+ }
+ var typeName = partialTypeName.Slice(0, endOfSimpleTypeName);
+
+ // If the type name doesn't contain a '`', then it isn't a generic type
+ // so we can return before starting to parse the generic type list.
+ if (!typeName.Contains("`".AsSpan(), StringComparison.Ordinal))
+ {
+ return (typeName.ToString(), null, endOfSimpleTypeName);
+ }
+
+ int genericTypeListStart = partialTypeName.IndexOf('<');
+ var genericTypeName = partialTypeName.Slice(0, genericTypeListStart);
+ var remainingTypeName = partialTypeName.Slice(genericTypeListStart + 1);
+ int remainingIndex = genericTypeListStart + 1;
+ List genericTypes = new List();
+ while (true)
+ {
+ // Resolve the generic type argument at this point in the parameter list.
+ var (genericType, endOfGenericArgument) = FindTypeByName(remainingTypeName);
+ if (genericType == null)
+ {
+ return (null, null, -1);
+ }
+
+ remainingIndex += endOfGenericArgument;
+ genericTypes.Add(genericType);
+ remainingTypeName = remainingTypeName.Slice(endOfGenericArgument);
+ if (remainingTypeName[0] == ',')
+ {
+ // Skip the comma and the space in the type name.
+ remainingIndex += 2;
+ remainingTypeName = remainingTypeName.Slice(2);
+ continue;
+ }
+ else if (remainingTypeName[0] == '>')
+ {
+ // Skip the space after nested '>'
+ var skip = (remainingTypeName.Length > 1 && remainingTypeName[1] == ' ') ? 2 : 1;
+ remainingIndex += skip;
+ remainingTypeName = remainingTypeName.Slice(skip);
+ break;
+ }
+ else
+ {
+ throw new InvalidOperationException("The provided type name is invalid.");
+ }
+ }
+ return (genericTypeName.ToString(), genericTypes.ToArray(), partialTypeName.Length - remainingTypeName.Length);
}
-#nullable enable
- ///
- /// Tracker for visited types when determining a WinRT interface to use as the type name.
- ///
- ///
- /// Only used when is called with .
- ///
- [ThreadStatic]
+ struct VisitedType
+ {
+ public Type Type { get; set; }
+ public bool Covariant { get; set; }
+ }
+
+#nullable enable
+ ///
+ /// Tracker for visited types when determining a WinRT interface to use as the type name.
+ ///
+ ///
+ /// Only used when is called with .
+ ///
+ [ThreadStatic]
private static Stack? visitedTypesInstance;
[ThreadStatic]
- private static StringBuilder? nameForTypeBuilderInstance;
-
- public static string GetNameForType(Type? type, TypeNameGenerationFlags flags)
- {
- if (type is null)
- {
- return string.Empty;
- }
-
- // Get instance for this thread
- StringBuilder? nameBuilder = nameForTypeBuilderInstance ??= new StringBuilder();
- nameBuilder.Clear();
- if (TryAppendTypeName(type, nameBuilder, flags))
- {
- return nameBuilder.ToString();
+ private static StringBuilder? nameForTypeBuilderInstance;
+
+ public static string GetNameForType(Type? type, TypeNameGenerationFlags flags)
+ {
+ if (type is null)
+ {
+ return string.Empty;
}
-
- return string.Empty;
+
+ // Get instance for this thread
+ StringBuilder? nameBuilder = nameForTypeBuilderInstance ??= new StringBuilder();
+ nameBuilder.Clear();
+ if (TryAppendTypeName(type, nameBuilder, flags))
+ {
+ return nameBuilder.ToString();
+ }
+
+ return string.Empty;
}
-#nullable restore
-
- private static bool TryAppendSimpleTypeName(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
- {
- if (type.IsPrimitive || type == typeof(string) || type == typeof(Guid) || type == typeof(TimeSpan))
- {
- if (type == typeof(byte))
- {
- builder.Append("UInt8");
- }
- else if (type == typeof(sbyte))
- {
- builder.Append("Int8");
- }
- else
- {
- builder.Append(type.Name);
- }
- }
- else if (type == typeof(object))
- {
- builder.Append("Object");
+#nullable restore
+
+ private static bool TryAppendSimpleTypeName(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
+ {
+ if (type.IsPrimitive || type == typeof(string) || type == typeof(Guid) || type == typeof(TimeSpan))
+ {
+ if (type == typeof(byte))
+ {
+ builder.Append("UInt8");
+ }
+ else if (type == typeof(sbyte))
+ {
+ builder.Append("Int8");
+ }
+ else
+ {
+ builder.Append(type.Name);
+ }
+ }
+ else if (type == typeof(object))
+ {
+ builder.Append("Object");
}
else if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0 && type.IsTypeOfType())
{
- builder.Append("Windows.UI.Xaml.Interop.TypeName");
- }
- else
- {
- var projectedAbiTypeName = Projections.FindCustomAbiTypeNameForType(type);
- if (projectedAbiTypeName is not null)
- {
- builder.Append(projectedAbiTypeName);
- }
- else if (Projections.IsTypeWindowsRuntimeType(type))
- {
- builder.Append(type.FullName);
- }
+ builder.Append("Windows.UI.Xaml.Interop.TypeName");
+ }
+ else
+ {
+ var projectedAbiTypeName = Projections.FindCustomAbiTypeNameForType(type);
+ if (projectedAbiTypeName is not null)
+ {
+ builder.Append(projectedAbiTypeName);
+ }
+ else if (Projections.IsTypeWindowsRuntimeType(type))
+ {
+ builder.Append(type.FullName);
+ }
else
- {
- if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0)
- {
+ {
+ if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0)
+ {
return TryAppendWinRTInterfaceNameForType(type, builder, flags);
}
else
{
builder.Append(type.FullName);
}
- }
- }
- return true;
- }
-
- private static bool TryAppendWinRTInterfaceNameForType(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
- {
- Debug.Assert((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0);
+ }
+ }
+ return true;
+ }
+
+ private static bool TryAppendWinRTInterfaceNameForType(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
+ {
+ Debug.Assert((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0);
Debug.Assert(!type.IsGenericTypeDefinition);
-#if NET
+#if NET
var runtimeClassNameAttribute = type.GetCustomAttribute();
if (runtimeClassNameAttribute is not null)
{
@@ -481,12 +486,12 @@ private static bool TryAppendWinRTInterfaceNameForType(Type type, StringBuilder
{
return false;
}
-#endif
-
-
- var visitedTypes = visitedTypesInstance ??= new Stack();
-
- // Manual helper to save binary size (no LINQ, no lambdas) and get better performance
+#endif
+
+
+ var visitedTypes = visitedTypesInstance ??= new Stack();
+
+ // Manual helper to save binary size (no LINQ, no lambdas) and get better performance
static bool HasAnyVisitedTypes(Stack visitedTypes, Type type)
{
foreach (VisitedType visitedType in visitedTypes)
@@ -498,25 +503,25 @@ static bool HasAnyVisitedTypes(Stack visitedTypes, Type type)
}
return false;
- }
-
- if (HasAnyVisitedTypes(visitedTypes, type))
- {
- // In this case, we've already visited the type when recursing through generic parameters.
- // Try to fall back to object if the parameter is covariant and the argument is compatable with object.
- // Otherwise there's no valid type name.
- if (visitedTypes.Peek().Covariant && !type.IsValueType)
- {
- builder.Append("Object");
- return true;
- }
- return false;
- }
- else
- {
-#if NET
+ }
+
+ if (HasAnyVisitedTypes(visitedTypes, type))
+ {
+ // In this case, we've already visited the type when recursing through generic parameters.
+ // Try to fall back to object if the parameter is covariant and the argument is compatable with object.
+ // Otherwise there's no valid type name.
+ if (visitedTypes.Peek().Covariant && !type.IsValueType)
+ {
+ builder.Append("Object");
+ return true;
+ }
+ return false;
+ }
+ else
+ {
+#if NET
[UnconditionalSuppressMessage("Trimming", "IL2070", Justification = "Updated binaries will have WinRTRuntimeClassNameAttribute which will be used instead.")]
-#endif
+#endif
static bool TryAppendWinRTInterfaceNameForTypeJit(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
{
var visitedTypes = visitedTypesInstance;
@@ -545,24 +550,24 @@ static bool TryAppendWinRTInterfaceNameForTypeJit(Type type, StringBuilder build
visitedTypes.Pop();
return success;
- }
-
- return TryAppendWinRTInterfaceNameForTypeJit(type, builder, flags);
- }
- }
-
- private static bool TryAppendTypeName(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
- {
-#if !NET
- // We can't easily determine from just the type
- // if the array is an "single dimension index from zero"-array in .NET Standard 2.0,
- // so just approximate it.
- // (Other array types will be blocked in other code-paths anyway where we have an object.)
- if (type.IsArray && type.GetArrayRank() == 1)
-#else
- if (type.IsSZArray)
-#endif
- {
+ }
+
+ return TryAppendWinRTInterfaceNameForTypeJit(type, builder, flags);
+ }
+ }
+
+ private static bool TryAppendTypeName(Type type, StringBuilder builder, TypeNameGenerationFlags flags)
+ {
+#if !NET
+ // We can't easily determine from just the type
+ // if the array is an "single dimension index from zero"-array in .NET Standard 2.0,
+ // so just approximate it.
+ // (Other array types will be blocked in other code-paths anyway where we have an object.)
+ if (type.IsArray && type.GetArrayRank() == 1)
+#else
+ if (type.IsSZArray)
+#endif
+ {
var elementType = type.GetElementType();
if (elementType.ShouldProvideIReference())
{
@@ -573,13 +578,13 @@ private static bool TryAppendTypeName(Type type, StringBuilder builder, TypeName
return true;
}
return false;
- }
+ }
else
{
return false;
- }
- }
-
+ }
+ }
+
if ((flags & TypeNameGenerationFlags.GenerateBoxedName) != 0 && type.ShouldProvideIReference())
{
builder.Append("Windows.Foundation.IReference`1<");
@@ -589,72 +594,72 @@ private static bool TryAppendTypeName(Type type, StringBuilder builder, TypeName
}
builder.Append('>');
return true;
- }
-
- if (!type.IsGenericType || type.IsGenericTypeDefinition)
- {
- return TryAppendSimpleTypeName(type, builder, flags);
- }
-
- if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0 && !Projections.IsTypeWindowsRuntimeType(type))
- {
- return TryAppendWinRTInterfaceNameForType(type, builder, flags);
- }
-
- Type definition = type.GetGenericTypeDefinition();
- if (!TryAppendSimpleTypeName(definition, builder, flags))
- {
- return false;
- }
-
- builder.Append('<');
-
- bool first = true;
-
- Type[] genericTypeArguments = type.GetGenericArguments();
- Type[] genericTypeParameters = definition.GetGenericArguments();
-
- var visitedTypes = visitedTypesInstance ??= new Stack();
-
- for (int i = 0; i < genericTypeArguments.Length; i++)
- {
- Type argument = genericTypeArguments[i];
-
- if (argument.ContainsGenericParameters)
- {
- throw new ArgumentException(nameof(type));
- }
-
- if (!first)
- {
- builder.Append(", ");
- }
- first = false;
-
- if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0)
- {
- visitedTypes.Push(new VisitedType
- {
- Type = type,
- Covariant = (genericTypeParameters[i].GenericParameterAttributes & GenericParameterAttributes.VarianceMask) == GenericParameterAttributes.Covariant
- });
- }
-
- bool success = TryAppendTypeName(argument, builder, flags & ~TypeNameGenerationFlags.GenerateBoxedName);
-
- if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0)
- {
- visitedTypes.Pop();
- }
-
- if (!success)
- {
- return false;
- }
- }
-
- builder.Append('>');
- return true;
- }
- }
-}
+ }
+
+ if (!type.IsGenericType || type.IsGenericTypeDefinition)
+ {
+ return TryAppendSimpleTypeName(type, builder, flags);
+ }
+
+ if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0 && !Projections.IsTypeWindowsRuntimeType(type))
+ {
+ return TryAppendWinRTInterfaceNameForType(type, builder, flags);
+ }
+
+ Type definition = type.GetGenericTypeDefinition();
+ if (!TryAppendSimpleTypeName(definition, builder, flags))
+ {
+ return false;
+ }
+
+ builder.Append('<');
+
+ bool first = true;
+
+ Type[] genericTypeArguments = type.GetGenericArguments();
+ Type[] genericTypeParameters = definition.GetGenericArguments();
+
+ var visitedTypes = visitedTypesInstance ??= new Stack();
+
+ for (int i = 0; i < genericTypeArguments.Length; i++)
+ {
+ Type argument = genericTypeArguments[i];
+
+ if (argument.ContainsGenericParameters)
+ {
+ throw new ArgumentException(nameof(type));
+ }
+
+ if (!first)
+ {
+ builder.Append(", ");
+ }
+ first = false;
+
+ if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0)
+ {
+ visitedTypes.Push(new VisitedType
+ {
+ Type = type,
+ Covariant = (genericTypeParameters[i].GenericParameterAttributes & GenericParameterAttributes.VarianceMask) == GenericParameterAttributes.Covariant
+ });
+ }
+
+ bool success = TryAppendTypeName(argument, builder, flags & ~TypeNameGenerationFlags.GenerateBoxedName);
+
+ if ((flags & TypeNameGenerationFlags.ForGetRuntimeClassName) != 0)
+ {
+ visitedTypes.Pop();
+ }
+
+ if (!success)
+ {
+ return false;
+ }
+ }
+
+ builder.Append('>');
+ return true;
+ }
+ }
+}
From 8e028da0e007b191a75af30ba2783d65eeed7a5f Mon Sep 17 00:00:00 2001
From: Dongle <29563098+dongle-the-gadget@users.noreply.github.com>
Date: Mon, 18 Nov 2024 22:22:41 +0700
Subject: [PATCH 2/3] Comment correctness.
---
src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
index 8a18c3d1b..a32540e3a 100644
--- a/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
+++ b/src/Authoring/WinRT.SourceGenerator/AotOptimizer.cs
@@ -223,7 +223,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
}
// Avoid cases where the type to be cast from is unknown, or can be done purely through static metadata.
- // That is, the type to be cast to inherits from the type of the expression to be cast,
+ // That is, the type of the expression inherits from the type to be cast to,
// as we know the cast will always work.
var sourceType = context.SemanticModel.GetTypeInfo(expression).Type;
if (sourceType == null || GeneratorHelper.IsDerivedFromType(sourceType, namedTypeSymbol))
From 7821f862f625a9559318ff3fb4ebfeac9168a2fe Mon Sep 17 00:00:00 2001
From: Dongle <29563098+dongle-the-gadget@users.noreply.github.com>
Date: Thu, 21 Nov 2024 22:03:21 +0700
Subject: [PATCH 3/3] Add tests for cast metadata scenario.
---
.../CastMetadata/CastMetadata.csproj | 17 ++++++++++++
.../FunctionalTests/CastMetadata/Program.cs | 3 +++
.../TestComponentCSharp/CastMetadata.cpp | 12 +++++++++
src/Tests/TestComponentCSharp/CastMetadata.h | 27 +++++++++++++++++++
.../TestComponentCSharp.idl | 15 +++++++++++
.../TestComponentCSharp.vcxproj | 2 ++
.../TestComponentCSharp.vcxproj.filters | 2 ++
src/build.cmd | 4 +--
8 files changed, 80 insertions(+), 2 deletions(-)
create mode 100644 src/Tests/FunctionalTests/CastMetadata/CastMetadata.csproj
create mode 100644 src/Tests/FunctionalTests/CastMetadata/Program.cs
create mode 100644 src/Tests/TestComponentCSharp/CastMetadata.cpp
create mode 100644 src/Tests/TestComponentCSharp/CastMetadata.h
diff --git a/src/Tests/FunctionalTests/CastMetadata/CastMetadata.csproj b/src/Tests/FunctionalTests/CastMetadata/CastMetadata.csproj
new file mode 100644
index 000000000..e7cb454f5
--- /dev/null
+++ b/src/Tests/FunctionalTests/CastMetadata/CastMetadata.csproj
@@ -0,0 +1,17 @@
+
+
+
+ Exe
+ $(FunctionalTestsBuildTFMs)
+ x86;x64
+ win-x86;win-x64
+ $(MSBuildProjectDirectory)\..\PublishProfiles\win10-$(Platform).pubxml
+
+
+
+
+
+
+
+
+
diff --git a/src/Tests/FunctionalTests/CastMetadata/Program.cs b/src/Tests/FunctionalTests/CastMetadata/Program.cs
new file mode 100644
index 000000000..959a44972
--- /dev/null
+++ b/src/Tests/FunctionalTests/CastMetadata/Program.cs
@@ -0,0 +1,3 @@
+using TestComponentCSharp.CastMetadata;
+
+Class castObject = (Class)ClassFactory.Create();
\ No newline at end of file
diff --git a/src/Tests/TestComponentCSharp/CastMetadata.cpp b/src/Tests/TestComponentCSharp/CastMetadata.cpp
new file mode 100644
index 000000000..199076515
--- /dev/null
+++ b/src/Tests/TestComponentCSharp/CastMetadata.cpp
@@ -0,0 +1,12 @@
+#include "pch.h"
+#include "CastMetadata.h"
+#include "CastMetadata.Class.g.cpp"
+#include "CastMetadata.ClassFactory.g.cpp"
+
+namespace winrt::TestComponentCSharp::CastMetadata::implementation
+{
+ winrt::Windows::Foundation::IInspectable ClassFactory::Create()
+ {
+ return winrt::make();
+ }
+}
\ No newline at end of file
diff --git a/src/Tests/TestComponentCSharp/CastMetadata.h b/src/Tests/TestComponentCSharp/CastMetadata.h
new file mode 100644
index 000000000..f118031a7
--- /dev/null
+++ b/src/Tests/TestComponentCSharp/CastMetadata.h
@@ -0,0 +1,27 @@
+#pragma once
+#include "CastMetadata.Class.g.h"
+#include "CastMetadata.ClassFactory.g.h"
+
+namespace winrt::TestComponentCSharp::CastMetadata::implementation
+{
+ struct Class : ClassT
+ {
+ Class() = default;
+ };
+
+ struct ClassFactory
+ {
+ static winrt::Windows::Foundation::IInspectable Create();
+ };
+}
+
+namespace winrt::TestComponentCSharp::CastMetadata::factory_implementation
+{
+ struct Class : ClassT
+ {
+ };
+
+ struct ClassFactory : ClassFactoryT
+ {
+ };
+}
\ No newline at end of file
diff --git a/src/Tests/TestComponentCSharp/TestComponentCSharp.idl b/src/Tests/TestComponentCSharp/TestComponentCSharp.idl
index 0d3259fd5..83cd3c704 100644
--- a/src/Tests/TestComponentCSharp/TestComponentCSharp.idl
+++ b/src/Tests/TestComponentCSharp/TestComponentCSharp.idl
@@ -757,4 +757,19 @@ And this is another one"
static Int32 StaticProperty { get; };
}
}
+
+ namespace CastMetadata
+ {
+ [default_interface]
+ runtimeclass Class
+ {
+ Class();
+ }
+
+ // Test for casting Object to a runtimeclass
+ static runtimeclass ClassFactory
+ {
+ static Object Create();
+ }
+ }
}
\ No newline at end of file
diff --git a/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj b/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj
index 0cf3eaf6b..7c74e6e57 100644
--- a/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj
+++ b/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj
@@ -81,6 +81,7 @@
+
@@ -103,6 +104,7 @@
+
diff --git a/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj.filters b/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj.filters
index a2df1ce81..a181b6e59 100644
--- a/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj.filters
+++ b/src/Tests/TestComponentCSharp/TestComponentCSharp.vcxproj.filters
@@ -25,6 +25,7 @@
+
@@ -43,6 +44,7 @@
+
diff --git a/src/build.cmd b/src/build.cmd
index e3f9605fa..d8bc37bbd 100644
--- a/src/build.cmd
+++ b/src/build.cmd
@@ -86,8 +86,8 @@ if "%cswinrt_assembly_version%"=="" set cswinrt_assembly_version=0.0.0.0
if "%cswinrt_baseline_breaking_compat_errors%"=="" set cswinrt_baseline_breaking_compat_errors=false
if "%cswinrt_baseline_assembly_version_compat_errors%"=="" set cswinrt_baseline_assembly_version_compat_errors=false
-set cswinrt_functional_tests=JsonValueFunctionCalls, ClassActivation, Structs, Events, DynamicInterfaceCasting, Collections, Async, DerivedClassActivation, DerivedClassAsBaseClass, CCW
-set cswinrt_aot_functional_tests=JsonValueFunctionCalls, ClassActivation, Structs, Events, DynamicInterfaceCasting, Collections, Async, DerivedClassActivation, DerivedClassAsBaseClass, CCW
+set cswinrt_functional_tests=CastMetadata, JsonValueFunctionCalls, ClassActivation, Structs, Events, DynamicInterfaceCasting, Collections, Async, DerivedClassActivation, DerivedClassAsBaseClass, CCW
+set cswinrt_aot_functional_tests=CastMetadata, JsonValueFunctionCalls, ClassActivation, Structs, Events, DynamicInterfaceCasting, Collections, Async, DerivedClassActivation, DerivedClassAsBaseClass, CCW
if "%cswinrt_platform%" EQU "x86" set run_functional_tests=true
if "%cswinrt_platform%" EQU "x64" (