diff --git a/src/transformation/utils/scope.ts b/src/transformation/utils/scope.ts index fa66ee3e2..314f1d795 100644 --- a/src/transformation/utils/scope.ts +++ b/src/transformation/utils/scope.ts @@ -38,6 +38,9 @@ export interface Scope { importStatements?: lua.Statement[]; loopContinued?: LoopContinued; functionReturned?: boolean; + asyncTryHasReturn?: boolean; + asyncTryHasBreak?: boolean; + asyncTryHasContinue?: LoopContinued; } export interface HoistingResult { @@ -84,6 +87,23 @@ export function findScope(context: TransformationContext, scopeTypes: ScopeType) } } +export function findAsyncTryScopeInStack(context: TransformationContext): Scope | undefined { + for (const scope of walkScopesUp(context)) { + if (scope.type === ScopeType.Function) return undefined; + if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope; + } + return undefined; +} + +/** Like findAsyncTryScopeInStack, but also stops at Loop boundaries. */ +export function findAsyncTryScopeBeforeLoop(context: TransformationContext): Scope | undefined { + for (const scope of walkScopesUp(context)) { + if (scope.type === ScopeType.Function || scope.type === ScopeType.Loop) return undefined; + if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope; + } + return undefined; +} + export function addScopeVariableDeclaration(scope: Scope, declaration: lua.VariableDeclarationStatement) { scope.variableDeclarations ??= []; diff --git a/src/transformation/visitors/break-continue.ts b/src/transformation/visitors/break-continue.ts index 2e64a38eb..e41a7934b 100644 --- a/src/transformation/visitors/break-continue.ts +++ b/src/transformation/visitors/break-continue.ts @@ -2,10 +2,22 @@ import * as ts from "typescript"; import { LuaTarget } from "../../CompilerOptions"; import * as lua from "../../LuaAST"; import { FunctionVisitor } from "../context"; -import { findScope, LoopContinued, ScopeType } from "../utils/scope"; +import { findAsyncTryScopeBeforeLoop, findScope, LoopContinued, ScopeType } from "../utils/scope"; +import { isInAsyncFunction } from "../utils/typescript"; export const transformBreakStatement: FunctionVisitor = (breakStatement, context) => { - void context; + const tryScope = isInAsyncFunction(breakStatement) ? findAsyncTryScopeBeforeLoop(context) : undefined; + if (tryScope) { + tryScope.asyncTryHasBreak = true; + return [ + lua.createAssignmentStatement( + lua.createIdentifier("____hasBroken"), + lua.createBooleanLiteral(true), + breakStatement + ), + lua.createReturnStatement([], breakStatement), + ]; + } return lua.createBreakStatement(breakStatement); }; @@ -28,6 +40,19 @@ export const transformContinueStatement: FunctionVisitor = scope.loopContinued = continuedWith; } + const tryScope = isInAsyncFunction(statement) ? findAsyncTryScopeBeforeLoop(context) : undefined; + if (tryScope) { + tryScope.asyncTryHasContinue = continuedWith; + return [ + lua.createAssignmentStatement( + lua.createIdentifier("____hasContinued"), + lua.createBooleanLiteral(true), + statement + ), + lua.createReturnStatement([], statement), + ]; + } + const label = `__continue${scope?.id ?? ""}`; switch (continuedWith) { diff --git a/src/transformation/visitors/errors.ts b/src/transformation/visitors/errors.ts index b2c9d36d1..cc73b2ac0 100644 --- a/src/transformation/visitors/errors.ts +++ b/src/transformation/visitors/errors.ts @@ -5,7 +5,7 @@ import { FunctionVisitor, TransformationContext } from "../context"; import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics"; import { createUnpackCall } from "../utils/lua-ast"; import { transformLuaLibFunction } from "../utils/lualib"; -import { Scope, ScopeType } from "../utils/scope"; +import { findScope, LoopContinued, Scope, ScopeType } from "../utils/scope"; import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript"; import { wrapInAsyncAwaiter } from "./async-await"; import { transformScopeBlock } from "./block"; @@ -14,7 +14,7 @@ import { isInMultiReturnFunction } from "./language-extensions/multi"; import { createReturnStatement } from "./return"; const transformAsyncTry: FunctionVisitor = (statement, context) => { - const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try); + const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try); if ( (context.options.luaTarget === LuaTarget.Lua50 || context.options.luaTarget === LuaTarget.Lua51) && @@ -31,13 +31,14 @@ const transformAsyncTry: FunctionVisitor = (statement, context) return tryBlock.statements; } - // __TS__AsyncAwaiter() + // __TS__AsyncAwaiter() const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false); const awaiterIdentifier = lua.createIdentifier("____try"); const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter); - // local ____try = __TS__AsyncAwaiter() - const result: lua.Statement[] = [awaiterDefinition]; + // Transform catch/finally and collect scope info before building the result + let catchScope: Scope | undefined; + const chainCalls: lua.Statement[] = []; if (statement.finallyBlock) { const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally")); @@ -49,27 +50,88 @@ const transformAsyncTry: FunctionVisitor = (statement, context) [awaiterIdentifier, finallyFunction], statement.finallyBlock ); - // ____try.finally() - result.push(lua.createExpressionStatement(finallyCall)); + chainCalls.push(lua.createExpressionStatement(finallyCall)); } if (statement.catchClause) { - // ____try.catch() - const [catchFunction] = transformCatchClause(context, statement.catchClause); + const [catchFunction, cScope] = transformCatchClause(context, statement.catchClause); + catchScope = cScope; if (catchFunction.params) { catchFunction.params.unshift(lua.createAnonymousIdentifier()); } const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch")); const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]); - - // await ____try.catch() const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall); - result.push(lua.createExpressionStatement(promiseAwait, statement)); + chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); } else { - // await ____try const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier); - result.push(lua.createExpressionStatement(promiseAwait, statement)); + chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); + } + + const hasReturn = tryScope.asyncTryHasReturn ?? catchScope?.asyncTryHasReturn; + const hasBreak = tryScope.asyncTryHasBreak ?? catchScope?.asyncTryHasBreak; + const hasContinue = tryScope.asyncTryHasContinue ?? catchScope?.asyncTryHasContinue; + + // Build result in output order: flag declarations, awaiter, chain calls, post-checks + const result: lua.Statement[] = []; + + if (hasReturn || hasBreak || hasContinue !== undefined) { + const flagDecls: lua.Identifier[] = []; + if (hasReturn) { + flagDecls.push(lua.createIdentifier("____hasReturned")); + flagDecls.push(lua.createIdentifier("____returnValue")); + } + if (hasBreak) { + flagDecls.push(lua.createIdentifier("____hasBroken")); + } + if (hasContinue !== undefined) { + flagDecls.push(lua.createIdentifier("____hasContinued")); + } + result.push(lua.createVariableDeclarationStatement(flagDecls)); + } + + result.push(awaiterDefinition); + result.push(...chainCalls); + + if (hasReturn) { + result.push( + lua.createIfStatement( + lua.createIdentifier("____hasReturned"), + lua.createBlock([createReturnStatement(context, [lua.createIdentifier("____returnValue")], statement)]) + ) + ); + } + + if (hasBreak) { + result.push( + lua.createIfStatement(lua.createIdentifier("____hasBroken"), lua.createBlock([lua.createBreakStatement()])) + ); + } + + if (hasContinue !== undefined) { + const loopScope = findScope(context, ScopeType.Loop); + const label = `__continue${loopScope?.id ?? ""}`; + + const continueStatements: lua.Statement[] = []; + switch (hasContinue) { + case LoopContinued.WithGoto: + continueStatements.push(lua.createGotoStatement(label)); + break; + case LoopContinued.WithContinue: + continueStatements.push(lua.createContinueStatement()); + break; + case LoopContinued.WithRepeatBreak: + continueStatements.push( + lua.createAssignmentStatement(lua.createIdentifier(label), lua.createBooleanLiteral(true)) + ); + continueStatements.push(lua.createBreakStatement()); + break; + } + + result.push( + lua.createIfStatement(lua.createIdentifier("____hasContinued"), lua.createBlock(continueStatements)) + ); } return result; diff --git a/src/transformation/visitors/return.ts b/src/transformation/visitors/return.ts index 14d785166..3a61314ff 100644 --- a/src/transformation/visitors/return.ts +++ b/src/transformation/visitors/return.ts @@ -3,7 +3,7 @@ import * as lua from "../../LuaAST"; import { FunctionVisitor, TransformationContext } from "../context"; import { validateAssignment } from "../utils/assignment-validation"; import { createUnpackCall, wrapInTable } from "../utils/lua-ast"; -import { ScopeType, walkScopesUp } from "../utils/scope"; +import { findAsyncTryScopeInStack, ScopeType, walkScopesUp } from "../utils/scope"; import { transformArguments } from "./call"; import { returnsMultiType, @@ -16,11 +16,7 @@ import { import { invalidMultiFunctionReturnType } from "../utils/diagnostics"; import { isInAsyncFunction } from "../utils/typescript"; -function transformExpressionsInReturn( - context: TransformationContext, - node: ts.Expression, - insideTryCatch: boolean -): lua.Expression[] { +function transformExpressionsInReturn(context: TransformationContext, node: ts.Expression): lua.Expression[] { const expressionType = context.checker.getTypeAtLocation(node); // skip type assertions @@ -36,20 +32,7 @@ function transformExpressionsInReturn( context.diagnostics.push(invalidMultiFunctionReturnType(innerNode)); } - let returnValues = transformArguments(context, innerNode.arguments); - if (insideTryCatch) { - returnValues = [wrapInTable(...returnValues)]; // Wrap results when returning inside try/catch - } - return returnValues; - } - - // Force-wrap LuaMultiReturn when returning inside try/catch - if ( - insideTryCatch && - returnsMultiType(context, innerNode) && - !shouldMultiReturnCallBeWrapped(context, innerNode) - ) { - return [wrapInTable(context.transformExpression(node))]; + return transformArguments(context, innerNode.arguments); } } else if (isInMultiReturnFunction(context, innerNode) && isMultiReturnType(expressionType)) { // Unpack objects typed as LuaMultiReturn @@ -63,12 +46,32 @@ export function transformExpressionBodyToReturnStatement( context: TransformationContext, node: ts.Expression ): lua.Statement { - const expressions = transformExpressionsInReturn(context, node, false); + const expressions = transformExpressionsInReturn(context, node); return createReturnStatement(context, expressions, node); } +function transformReturnExpressionForTryCatch(context: TransformationContext, node: ts.Expression): lua.Expression { + const innerNode = ts.skipOuterExpressions(node, ts.OuterExpressionKinds.Assertions); + + if (ts.isCallExpression(innerNode)) { + if (isMultiFunctionCall(context, innerNode)) { + const type = context.checker.getContextualType(node); + if (type && !canBeMultiReturnType(type)) { + context.diagnostics.push(invalidMultiFunctionReturnType(innerNode)); + } + return wrapInTable(...transformArguments(context, innerNode.arguments)); + } + + if (returnsMultiType(context, innerNode) && !shouldMultiReturnCallBeWrapped(context, innerNode)) { + return wrapInTable(context.transformExpression(node)); + } + } + + return context.transformExpression(node); +} + export const transformReturnStatement: FunctionVisitor = (statement, context) => { - let results: lua.Expression[]; + const asyncTryScope = isInAsyncFunction(statement) ? findAsyncTryScopeInStack(context) : undefined; if (statement.expression) { const expressionType = context.checker.getTypeAtLocation(statement.expression); @@ -76,11 +79,32 @@ export const transformReturnStatement: FunctionVisitor = (st if (returnType) { validateAssignment(context, statement, expressionType, returnType); } + } - results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context)); - } else { - // Empty return + if (asyncTryScope) { + asyncTryScope.asyncTryHasReturn = true; + const stmts: lua.Statement[] = [ + lua.createAssignmentStatement( + lua.createIdentifier("____hasReturned"), + lua.createBooleanLiteral(true), + statement + ), + ]; + if (statement.expression) { + const returnValue = transformReturnExpressionForTryCatch(context, statement.expression); + stmts.push(lua.createAssignmentStatement(lua.createIdentifier("____returnValue"), returnValue, statement)); + } + stmts.push(lua.createReturnStatement([], statement)); + return stmts; + } + + let results: lua.Expression[]; + if (!statement.expression) { results = []; + } else if (isInTryCatch(context)) { + results = [transformReturnExpressionForTryCatch(context, statement.expression)]; + } else { + results = transformExpressionsInReturn(context, statement.expression); } return createReturnStatement(context, results, statement); diff --git a/test/unit/builtins/async-await.spec.ts b/test/unit/builtins/async-await.spec.ts index 45adb4518..7a511d7b1 100644 --- a/test/unit/builtins/async-await.spec.ts +++ b/test/unit/builtins/async-await.spec.ts @@ -815,4 +815,196 @@ describe("try/catch in async function", () => { }, }); }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return inside try with deferred promise (#1706)", () => { + util.testFunction` + let resolveLater!: (value: string) => void; + + function deferredPromise(): Promise { + return new Promise(resolve => { + resolveLater = (v) => resolve(v); + }); + } + + async function fn(): Promise { + try { + return await deferredPromise(); + } catch { + return 'caught'; + } + log('unreachable!'); + } + + const promise = fn(); + resolveLater('ok'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["ok"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return inside try in loop with deferred promise (#1706)", () => { + util.testFunction` + let resolveLater!: (value: string) => void; + + function deferredPromise(): Promise { + return new Promise(resolve => { + resolveLater = (v) => resolve(v); + }); + } + + async function fn(): Promise { + while (true) { + try { + return await deferredPromise(); + } catch { + return 'caught'; + } + log('unreachable!'); + } + } + + const promise = fn(); + resolveLater('ok'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["ok"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return from catch with deferred promise (#1706)", () => { + util.testFunction` + let rejectLater!: (reason: string) => void; + + function deferredPromise(): Promise { + return new Promise((_, reject) => { + rejectLater = (r) => reject(r); + }); + } + + async function fn(): Promise { + try { + return await deferredPromise(); + } catch (e) { + return 'caught: ' + e; + } + log('unreachable!'); + } + + const promise = fn(); + rejectLater('oops'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["caught: oops"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + util.testEachVersion( + "break inside try in async loop (#1706)", + () => util.testModule` + export let result = "not set"; + async function fn(): Promise { + while (true) { + try { + await Promise.resolve(); + break; + } catch {} + } + result = "done"; + } + fn(); + `, + { + ...util.expectEachVersionExceptJit(builder => builder.expectToEqual({ result: "done" })), + [LuaTarget.Lua50]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + [LuaTarget.Lua51]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + } + ); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + util.testEachVersion( + "continue inside try in async loop (#1706)", + () => util.testModule` + export const results: number[] = []; + async function fn(): Promise { + for (let i = 0; i < 3; i++) { + try { + await Promise.resolve(); + if (i === 1) continue; + } catch {} + results.push(i); + } + } + fn(); + `, + { + ...util.expectEachVersionExceptJit(builder => builder.expectToEqual({ results: [0, 2] })), + [LuaTarget.Lua50]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + [LuaTarget.Lua51]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + } + ); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("multi return from try in async function (#1706)", () => { + util.testFunction` + async function fn(): Promise> { + try { + await Promise.resolve(); + return $multi("foo", "bar"); + } catch { + return $multi("err", "err"); + } + } + + let result: string[] = []; + fn().then(v => { const [a, b] = v; result = [a, b]; }); + + return result; + ` + .withLanguageExtensions() + .expectToEqual(["foo", "bar"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return inside try with finally (#1706)", () => { + util.testFunction` + let resolveLater!: (value: string) => void; + + function deferredPromise(): Promise { + return new Promise(resolve => { + resolveLater = (v) => resolve(v); + }); + } + + async function fn(): Promise { + try { + return await deferredPromise(); + } finally { + log('finally'); + } + } + + const promise = fn(); + resolveLater('ok'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["finally", "ok"]); + }); });