diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs index 5f0610dc11dda6..fd8d45455c7a72 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/AsyncThunks.cs @@ -1,12 +1,19 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System; using Internal.TypeSystem; namespace Internal.IL.Stubs { public static class AsyncThunkILEmitter { + // Emits a thunk that wraps an async method to return a Task or ValueTask. + // The thunk calls the async method, and if it completes synchronously, + // it returns a completed Task/ValueTask. If the async method suspends, + // it calls FinalizeTaskReturningThunk/FinalizeValueTaskReturningThunk method to get the Task/ValueTask. + + // The emitted code matches method EmitTaskReturningThunk in CoreCLR VM. public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, MethodDesc asyncMethod) { TypeSystemContext context = taskReturningMethod.Context; @@ -14,24 +21,219 @@ public static MethodIL EmitTaskReturningThunk(MethodDesc taskReturningMethod, Me var emitter = new ILEmitter(); var codestream = emitter.NewCodeStream(); - // TODO: match EmitTaskReturningThunk in CoreCLR VM + MethodSignature sig = taskReturningMethod.Signature; + TypeDesc returnType = sig.ReturnType; + + bool isValueTask = returnType.IsValueType; + + TypeDesc logicalReturnType = null; + ILLocalVariable logicalResultLocal = 0; + if (returnType.HasInstantiation) + { + // The return type is either Task or ValueTask, exactly one generic argument + logicalReturnType = returnType.Instantiation[0]; + logicalResultLocal = emitter.NewLocal(logicalReturnType); + } - MethodSignature sig = asyncMethod.Signature; - int numParams = (sig.IsStatic || sig.IsExplicitThis) ? sig.Length : sig.Length + 1; - for (int i = 0; i < numParams; i++) - codestream.EmitLdArg(i); + ILLocalVariable returnTaskLocal = emitter.NewLocal(returnType); - codestream.Emit(ILOpcode.call, emitter.NewToken(asyncMethod)); + TypeDesc executionAndSyncBlockStoreType = context.SystemModule.GetKnownType("System.Runtime.CompilerServices"u8, "ExecutionAndSyncBlockStore"u8); + ILLocalVariable executionAndSyncBlockStoreLocal = emitter.NewLocal(executionAndSyncBlockStoreType); - if (sig.ReturnType.IsVoid) + ILCodeLabel returnTaskLabel = emitter.NewCodeLabel(); + ILCodeLabel suspendedLabel = emitter.NewCodeLabel(); + ILCodeLabel finishedLabel = emitter.NewCodeLabel(); + + codestream.EmitLdLoca(executionAndSyncBlockStoreLocal); + codestream.Emit(ILOpcode.call, emitter.NewToken(executionAndSyncBlockStoreType.GetKnownMethod("Push"u8, null))); + + ILExceptionRegionBuilder tryFinallyRegion = emitter.NewFinallyRegion(); { - codestream.Emit(ILOpcode.call, emitter.NewToken(context.SystemModule.GetKnownType("System.Threading.Tasks"u8, "Task"u8).GetKnownMethod("get_CompletedTask"u8, null))); + codestream.BeginTry(tryFinallyRegion); + codestream.Emit(ILOpcode.nop); + ILExceptionRegionBuilder tryCatchRegion = emitter.NewCatchRegion(context.GetWellKnownType(WellKnownType.Object)); + { + codestream.BeginTry(tryCatchRegion); + + int localArg = 0; + if (!sig.IsStatic) + { + codestream.EmitLdArg(localArg++); + } + + for (int iArg = 0; iArg < sig.Length; iArg++) + { + codestream.EmitLdArg(localArg++); + } + + if (asyncMethod.OwningType.HasInstantiation) + { + var instantiatedType = (InstantiatedType)TypeSystemHelpers.InstantiateAsOpen(asyncMethod.OwningType); + asyncMethod = context.GetMethodForInstantiatedType(asyncMethod, instantiatedType); + } + + if (asyncMethod.HasInstantiation) + { + var inst = new TypeDesc[asyncMethod.Instantiation.Length]; + for (int i = 0; i < inst.Length; i++) + { + inst[i] = context.GetSignatureVariable(i, true); + } + asyncMethod = asyncMethod.MakeInstantiatedMethod(new Instantiation(inst)); + } + + codestream.Emit(ILOpcode.call, emitter.NewToken(asyncMethod)); + + if (logicalReturnType != null) + { + codestream.EmitStLoc(logicalResultLocal); + } + + MethodDesc asyncCallContinuationMd = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod("AsyncCallContinuation"u8, null); + + codestream.Emit(ILOpcode.call, emitter.NewToken(asyncCallContinuationMd)); + + codestream.Emit(ILOpcode.brfalse, finishedLabel); + codestream.Emit(ILOpcode.leave, suspendedLabel); + codestream.EmitLabel(finishedLabel); + + if (logicalReturnType != null) + { + codestream.EmitLdLoc(logicalResultLocal); + + MethodDesc fromResultMethod; + if (isValueTask) + { + fromResultMethod = context.SystemModule + .GetKnownType("System.Threading.Tasks"u8, "ValueTask"u8) + .GetKnownMethod("FromResult"u8, null) + .MakeInstantiatedMethod(new Instantiation(logicalReturnType)); + } + else + { + fromResultMethod = context.SystemModule + .GetKnownType("System.Threading.Tasks"u8, "Task"u8) + .GetKnownMethod("FromResult"u8, null) + .MakeInstantiatedMethod(new Instantiation(logicalReturnType)); + } + + codestream.Emit(ILOpcode.call, emitter.NewToken(fromResultMethod)); + } + else + { + MethodDesc getCompletedTaskMethod; + if (isValueTask) + { + getCompletedTaskMethod = context.SystemModule + .GetKnownType("System.Threading.Tasks"u8, "ValueTask"u8) + .GetKnownMethod("get_CompletedTask"u8, null); + } + else + { + getCompletedTaskMethod = context.SystemModule + .GetKnownType("System.Threading.Tasks"u8, "Task"u8) + .GetKnownMethod("get_CompletedTask"u8, null); + } + codestream.Emit(ILOpcode.call, emitter.NewToken(getCompletedTaskMethod)); + } + + codestream.EmitStLoc(returnTaskLocal); + codestream.Emit(ILOpcode.leave, returnTaskLabel); + + codestream.EndTry(tryCatchRegion); + } + // Catch + { + codestream.BeginHandler(tryCatchRegion); + + TypeDesc exceptionType = context.GetWellKnownType(WellKnownType.Exception); + + MethodDesc fromExceptionMd; + if (logicalReturnType != null) + { + MethodSignature fromExceptionSignature = new MethodSignature( + MethodSignatureFlags.Static, + genericParameterCount: 1, + returnType: ((MetadataType)returnType.GetTypeDefinition()).MakeInstantiatedType(context.GetSignatureVariable(0, true)), + parameters: new[] { exceptionType } + ); + + fromExceptionMd = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod(isValueTask ? "ValueTaskFromException"u8 : "TaskFromException"u8, fromExceptionSignature) + .MakeInstantiatedMethod(new Instantiation(logicalReturnType)); + } + else + { + MethodSignature fromExceptionSignature = new MethodSignature( + MethodSignatureFlags.Static, + genericParameterCount: 0, + returnType: returnType, + parameters: new[] { exceptionType } + ); + + fromExceptionMd = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod(isValueTask ? "ValueTaskFromException"u8 : "TaskFromException"u8, fromExceptionSignature); + } + + codestream.Emit(ILOpcode.call, emitter.NewToken(fromExceptionMd)); + codestream.EmitStLoc(returnTaskLocal); + codestream.Emit(ILOpcode.leave, returnTaskLabel); + codestream.EndHandler(tryCatchRegion); + } + + codestream.EmitLabel(suspendedLabel); + + MethodDesc finalizeTaskReturningThunkMd; + if (logicalReturnType != null) + { + MethodSignature finalizeReturningThunkSignature = new MethodSignature( + MethodSignatureFlags.Static, + genericParameterCount: 1, + returnType: ((MetadataType)returnType.GetTypeDefinition()).MakeInstantiatedType(context.GetSignatureVariable(0, true)), + parameters: Array.Empty() + ); + + finalizeTaskReturningThunkMd = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod(isValueTask ? "FinalizeValueTaskReturningThunk"u8 : "FinalizeTaskReturningThunk"u8, finalizeReturningThunkSignature) + .MakeInstantiatedMethod(new Instantiation(logicalReturnType)); + } + else + { + MethodSignature finalizeReturningThunkSignature = new MethodSignature( + MethodSignatureFlags.Static, + genericParameterCount: 0, + returnType: returnType, + parameters: Array.Empty() + ); + + finalizeTaskReturningThunkMd = context.SystemModule + .GetKnownType("System.Runtime.CompilerServices"u8, "AsyncHelpers"u8) + .GetKnownMethod(isValueTask ? "FinalizeValueTaskReturningThunk"u8 : "FinalizeTaskReturningThunk"u8, finalizeReturningThunkSignature); + } + + codestream.Emit(ILOpcode.call, emitter.NewToken(finalizeTaskReturningThunkMd)); + codestream.EmitStLoc(returnTaskLocal); + codestream.Emit(ILOpcode.leave, returnTaskLabel); + + codestream.EndTry(tryFinallyRegion); } - else + { - codestream.Emit(ILOpcode.call, emitter.NewToken(context.SystemModule.GetKnownType("System.Threading.Tasks"u8, "Task"u8).GetKnownMethod("FromResult"u8, null).MakeInstantiatedMethod(sig.ReturnType))); + codestream.BeginHandler(tryFinallyRegion); + + codestream.EmitLdLoca(executionAndSyncBlockStoreLocal); + codestream.Emit(ILOpcode.call, emitter.NewToken(executionAndSyncBlockStoreType.GetKnownMethod("Pop"u8, null))); + codestream.Emit(ILOpcode.endfinally); + codestream.EndHandler(tryFinallyRegion); } + codestream.EmitLabel(returnTaskLabel); + codestream.EmitLdLoc(returnTaskLocal); codestream.Emit(ILOpcode.ret); return emitter.Link(taskReturningMethod); diff --git a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs index 3625b6397f5d85..10d15a976b614b 100644 --- a/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs +++ b/src/coreclr/tools/Common/TypeSystem/IL/Stubs/ILEmitter.cs @@ -517,8 +517,15 @@ public class ILExceptionRegionBuilder internal ILCodeStream _endHandlerStream; internal int _endHandlerOffset; - internal ILExceptionRegionBuilder() + internal ILExceptionRegionKind _exceptionRegionKind; + internal TypeDesc _catchExceptionType; + + internal ILExceptionRegionBuilder(ILExceptionRegionKind exceptionRegionKind, TypeDesc catchExceptionType = null) { + _exceptionRegionKind = exceptionRegionKind; + _catchExceptionType = catchExceptionType; + Debug.Assert((exceptionRegionKind == ILExceptionRegionKind.Catch && catchExceptionType != null) + || (exceptionRegionKind != ILExceptionRegionKind.Catch && catchExceptionType == null)); } internal int TryOffset => _beginTryStream.RelativeToAbsoluteOffset(_beginTryOffset); @@ -669,7 +676,7 @@ public class ILEmitter private ArrayBuilder _codeStreams; private ArrayBuilder _locals; private ArrayBuilder _tokens; - private ArrayBuilder _finallyRegions; + private ArrayBuilder _exceptionRegions; public ILEmitter() { @@ -727,10 +734,17 @@ public ILCodeLabel NewCodeLabel() return newLabel; } + public ILExceptionRegionBuilder NewCatchRegion(TypeDesc exceptionType) + { + var region = new ILExceptionRegionBuilder(ILExceptionRegionKind.Catch, exceptionType); + _exceptionRegions.Add(region); + return region; + } + public ILExceptionRegionBuilder NewFinallyRegion() { - var region = new ILExceptionRegionBuilder(); - _finallyRegions.Add(region); + var region = new ILExceptionRegionBuilder(ILExceptionRegionKind.Finally); + _exceptionRegions.Add(region); return region; } @@ -782,21 +796,34 @@ public MethodIL Link(MethodDesc owningMethod) ILExceptionRegion[] exceptionRegions = null; - int numberOfExceptionRegions = _finallyRegions.Count; + int numberOfExceptionRegions = _exceptionRegions.Count; if (numberOfExceptionRegions > 0) { exceptionRegions = new ILExceptionRegion[numberOfExceptionRegions]; - - for (int i = 0; i < _finallyRegions.Count; i++) + for (int i = 0; i < _exceptionRegions.Count; i++) { - ILExceptionRegionBuilder region = _finallyRegions[i]; + ILExceptionRegionBuilder region = _exceptionRegions[i]; Debug.Assert(region.IsDefined); - exceptionRegions[i] = new ILExceptionRegion(ILExceptionRegionKind.Finally, + int exceptionTypeToken = (region._catchExceptionType != null) ? (int)NewToken(region._catchExceptionType) : 0; + + exceptionRegions[i] = new ILExceptionRegion(region._exceptionRegionKind, region.TryOffset, region.TryLength, region.HandlerOffset, region.HandlerLength, - classToken: 0, filterOffset: 0); + classToken: exceptionTypeToken, filterOffset: 0); } + + // Sort exception regions so that innermost (most nested) regions come first + // as this is required by the spec. + // Innermost regions have higher TryOffset and smaller TryLength. + Array.Sort(exceptionRegions, (a, b) => + { + int offsetComparison = b.TryOffset.CompareTo(a.TryOffset); + if (offsetComparison != 0) + return offsetComparison; + + return a.TryLength.CompareTo(b.TryLength); + }); } var result = new ILStubMethodIL(owningMethod, ilInstructions, _locals.ToArray(), _tokens.ToArray(), exceptionRegions, debugInfo); diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index 0f92fe053f734a..bd50ffb34649f6 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -58,7 +58,8 @@ bool MethodDesc::TryGenerateAsyncThunk(DynamicResolver** resolver, COR_ILMETHOD_ return true; } -// provided an async method, emits a Task-returning wrapper. +// Provided an async method, emits a Task-returning wrapper. +// The emitted code matches method EmitTaskReturningThunk in the Managed Type System. void MethodDesc::EmitTaskReturningThunk(MethodDesc* pAsyncOtherVariant, MetaSig& thunkMsig, ILStubLinker* pSL) { _ASSERTE(!pAsyncOtherVariant->IsAsyncThunkMethod());