diff --git a/src/NodeApi/Interop/JSCollectionExtensions.cs b/src/NodeApi/Interop/JSCollectionExtensions.cs index 696b35e3..ab97ab5a 100644 --- a/src/NodeApi/Interop/JSCollectionExtensions.cs +++ b/src/NodeApi/Interop/JSCollectionExtensions.cs @@ -199,12 +199,17 @@ internal sealed class JSAsyncIterableEnumerator : IAsyncEnumerator, IDispo private readonly JSValue.To _fromJS; private readonly JSReference _iteratorReference; private JSReference? _currentReference; + private CancellationToken _cancellation; - internal JSAsyncIterableEnumerator(JSValue iterable, JSValue.To fromJS) + internal JSAsyncIterableEnumerator( + JSValue iterable, + JSValue.To fromJS, + CancellationToken cancellation) { _fromJS = fromJS; _iteratorReference = new JSReference(iterable.CallMethod(JSSymbol.AsyncIterator)); _currentReference = null; + _cancellation = cancellation; } public async ValueTask MoveNextAsync() @@ -214,7 +219,7 @@ public async ValueTask MoveNextAsync() _currentReference?.Dispose(); JSPromise nextPromise = (JSPromise)iterator.CallMethod("next"); - JSValue nextResult = await nextPromise.AsTask(); + JSValue nextResult = await nextPromise.AsTask(_cancellation); JSValue done = nextResult["done"]; if (done.IsBoolean() && (bool)done) @@ -334,10 +339,10 @@ internal JSAsyncIterableEnumerable(JSValue iterable, JSValue.To fromJS) bool IEquatable.Equals(JSValue other) => _iterableReference.Run((iterable) => iterable.Equals(other)); - public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken) + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellation) { return _iterableReference.Run( - (iterable) => new JSAsyncIterableEnumerator(iterable, _fromJS)); + (iterable) => new JSAsyncIterableEnumerator(iterable, _fromJS, cancellation)); } public void Dispose() diff --git a/src/NodeApi/JSPromiseExtensions.cs b/src/NodeApi/JSPromiseExtensions.cs index 3eb52901..43d4cf00 100644 --- a/src/NodeApi/JSPromiseExtensions.cs +++ b/src/NodeApi/JSPromiseExtensions.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Threading; using System.Threading.Tasks; namespace Microsoft.JavaScript.NodeApi; @@ -27,12 +28,41 @@ public static Task AsTask(this JSPromise promise) return completion.Task; } - public static async Task AsTask(this JSPromise promise, JSValue.To fromJS) + public static Task AsTask(this JSPromise promise, CancellationToken cancellation) + { + TaskCompletionSource completion = new(); + cancellation.Register(() => completion.TrySetCanceled(cancellation)); + promise.Then( + (JSValue value) => + { + completion.TrySetResult(value); + return default; + }, + (JSError error) => + { + completion.TrySetException(new JSException(error)); + return default; + }); + return completion.Task; + } + + public static async Task AsTask( + this JSPromise promise, + JSValue.To fromJS) { Task jsTask = promise.AsTask(); return fromJS(await jsTask); } + public static async Task AsTask( + this JSPromise promise, + JSValue.To fromJS, + CancellationToken cancellation) + { + Task jsTask = promise.AsTask(cancellation); + return fromJS(await jsTask); + } + public static JSPromise AsPromise(this Task task) { if (task.Status == TaskStatus.RanToCompletion)