diff --git a/src/transformation/utils/scope.ts b/src/transformation/utils/scope.ts index fa66ee3e2..314f1d795 100644 --- a/src/transformation/utils/scope.ts +++ b/src/transformation/utils/scope.ts @@ -38,6 +38,9 @@ export interface Scope { importStatements?: lua.Statement[]; loopContinued?: LoopContinued; functionReturned?: boolean; + asyncTryHasReturn?: boolean; + asyncTryHasBreak?: boolean; + asyncTryHasContinue?: LoopContinued; } export interface HoistingResult { @@ -84,6 +87,23 @@ export function findScope(context: TransformationContext, scopeTypes: ScopeType) } } +export function findAsyncTryScopeInStack(context: TransformationContext): Scope | undefined { + for (const scope of walkScopesUp(context)) { + if (scope.type === ScopeType.Function) return undefined; + if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope; + } + return undefined; +} + +/** Like findAsyncTryScopeInStack, but also stops at Loop boundaries. */ +export function findAsyncTryScopeBeforeLoop(context: TransformationContext): Scope | undefined { + for (const scope of walkScopesUp(context)) { + if (scope.type === ScopeType.Function || scope.type === ScopeType.Loop) return undefined; + if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope; + } + return undefined; +} + export function addScopeVariableDeclaration(scope: Scope, declaration: lua.VariableDeclarationStatement) { scope.variableDeclarations ??= []; diff --git a/src/transformation/visitors/break-continue.ts b/src/transformation/visitors/break-continue.ts index 2e64a38eb..e41a7934b 100644 --- a/src/transformation/visitors/break-continue.ts +++ b/src/transformation/visitors/break-continue.ts @@ -2,10 +2,22 @@ import * as ts from "typescript"; import { LuaTarget } from "../../CompilerOptions"; import * as lua from "../../LuaAST"; import { FunctionVisitor } from "../context"; -import { findScope, LoopContinued, ScopeType } from "../utils/scope"; +import { findAsyncTryScopeBeforeLoop, findScope, LoopContinued, ScopeType } from "../utils/scope"; +import { isInAsyncFunction } from "../utils/typescript"; export const transformBreakStatement: FunctionVisitor = (breakStatement, context) => { - void context; + const tryScope = isInAsyncFunction(breakStatement) ? findAsyncTryScopeBeforeLoop(context) : undefined; + if (tryScope) { + tryScope.asyncTryHasBreak = true; + return [ + lua.createAssignmentStatement( + lua.createIdentifier("____hasBroken"), + lua.createBooleanLiteral(true), + breakStatement + ), + lua.createReturnStatement([], breakStatement), + ]; + } return lua.createBreakStatement(breakStatement); }; @@ -28,6 +40,19 @@ export const transformContinueStatement: FunctionVisitor = scope.loopContinued = continuedWith; } + const tryScope = isInAsyncFunction(statement) ? findAsyncTryScopeBeforeLoop(context) : undefined; + if (tryScope) { + tryScope.asyncTryHasContinue = continuedWith; + return [ + lua.createAssignmentStatement( + lua.createIdentifier("____hasContinued"), + lua.createBooleanLiteral(true), + statement + ), + lua.createReturnStatement([], statement), + ]; + } + const label = `__continue${scope?.id ?? ""}`; switch (continuedWith) { diff --git a/src/transformation/visitors/errors.ts b/src/transformation/visitors/errors.ts index b2c9d36d1..cc73b2ac0 100644 --- a/src/transformation/visitors/errors.ts +++ b/src/transformation/visitors/errors.ts @@ -5,7 +5,7 @@ import { FunctionVisitor, TransformationContext } from "../context"; import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics"; import { createUnpackCall } from "../utils/lua-ast"; import { transformLuaLibFunction } from "../utils/lualib"; -import { Scope, ScopeType } from "../utils/scope"; +import { findScope, LoopContinued, Scope, ScopeType } from "../utils/scope"; import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript"; import { wrapInAsyncAwaiter } from "./async-await"; import { transformScopeBlock } from "./block"; @@ -14,7 +14,7 @@ import { isInMultiReturnFunction } from "./language-extensions/multi"; import { createReturnStatement } from "./return"; const transformAsyncTry: FunctionVisitor = (statement, context) => { - const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try); + const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try); if ( (context.options.luaTarget === LuaTarget.Lua50 || context.options.luaTarget === LuaTarget.Lua51) && @@ -31,13 +31,14 @@ const transformAsyncTry: FunctionVisitor = (statement, context) return tryBlock.statements; } - // __TS__AsyncAwaiter() + // __TS__AsyncAwaiter() const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false); const awaiterIdentifier = lua.createIdentifier("____try"); const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter); - // local ____try = __TS__AsyncAwaiter() - const result: lua.Statement[] = [awaiterDefinition]; + // Transform catch/finally and collect scope info before building the result + let catchScope: Scope | undefined; + const chainCalls: lua.Statement[] = []; if (statement.finallyBlock) { const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally")); @@ -49,27 +50,88 @@ const transformAsyncTry: FunctionVisitor = (statement, context) [awaiterIdentifier, finallyFunction], statement.finallyBlock ); - // ____try.finally() - result.push(lua.createExpressionStatement(finallyCall)); + chainCalls.push(lua.createExpressionStatement(finallyCall)); } if (statement.catchClause) { - // ____try.catch() - const [catchFunction] = transformCatchClause(context, statement.catchClause); + const [catchFunction, cScope] = transformCatchClause(context, statement.catchClause); + catchScope = cScope; if (catchFunction.params) { catchFunction.params.unshift(lua.createAnonymousIdentifier()); } const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch")); const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]); - - // await ____try.catch() const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall); - result.push(lua.createExpressionStatement(promiseAwait, statement)); + chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); } else { - // await ____try const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier); - result.push(lua.createExpressionStatement(promiseAwait, statement)); + chainCalls.push(lua.createExpressionStatement(promiseAwait, statement)); + } + + const hasReturn = tryScope.asyncTryHasReturn ?? catchScope?.asyncTryHasReturn; + const hasBreak = tryScope.asyncTryHasBreak ?? catchScope?.asyncTryHasBreak; + const hasContinue = tryScope.asyncTryHasContinue ?? catchScope?.asyncTryHasContinue; + + // Build result in output order: flag declarations, awaiter, chain calls, post-checks + const result: lua.Statement[] = []; + + if (hasReturn || hasBreak || hasContinue !== undefined) { + const flagDecls: lua.Identifier[] = []; + if (hasReturn) { + flagDecls.push(lua.createIdentifier("____hasReturned")); + flagDecls.push(lua.createIdentifier("____returnValue")); + } + if (hasBreak) { + flagDecls.push(lua.createIdentifier("____hasBroken")); + } + if (hasContinue !== undefined) { + flagDecls.push(lua.createIdentifier("____hasContinued")); + } + result.push(lua.createVariableDeclarationStatement(flagDecls)); + } + + result.push(awaiterDefinition); + result.push(...chainCalls); + + if (hasReturn) { + result.push( + lua.createIfStatement( + lua.createIdentifier("____hasReturned"), + lua.createBlock([createReturnStatement(context, [lua.createIdentifier("____returnValue")], statement)]) + ) + ); + } + + if (hasBreak) { + result.push( + lua.createIfStatement(lua.createIdentifier("____hasBroken"), lua.createBlock([lua.createBreakStatement()])) + ); + } + + if (hasContinue !== undefined) { + const loopScope = findScope(context, ScopeType.Loop); + const label = `__continue${loopScope?.id ?? ""}`; + + const continueStatements: lua.Statement[] = []; + switch (hasContinue) { + case LoopContinued.WithGoto: + continueStatements.push(lua.createGotoStatement(label)); + break; + case LoopContinued.WithContinue: + continueStatements.push(lua.createContinueStatement()); + break; + case LoopContinued.WithRepeatBreak: + continueStatements.push( + lua.createAssignmentStatement(lua.createIdentifier(label), lua.createBooleanLiteral(true)) + ); + continueStatements.push(lua.createBreakStatement()); + break; + } + + result.push( + lua.createIfStatement(lua.createIdentifier("____hasContinued"), lua.createBlock(continueStatements)) + ); } return result; diff --git a/src/transformation/visitors/return.ts b/src/transformation/visitors/return.ts index 14d785166..5291bb343 100644 --- a/src/transformation/visitors/return.ts +++ b/src/transformation/visitors/return.ts @@ -3,7 +3,7 @@ import * as lua from "../../LuaAST"; import { FunctionVisitor, TransformationContext } from "../context"; import { validateAssignment } from "../utils/assignment-validation"; import { createUnpackCall, wrapInTable } from "../utils/lua-ast"; -import { ScopeType, walkScopesUp } from "../utils/scope"; +import { findAsyncTryScopeInStack, ScopeType, walkScopesUp } from "../utils/scope"; import { transformArguments } from "./call"; import { returnsMultiType, @@ -68,6 +68,8 @@ export function transformExpressionBodyToReturnStatement( } export const transformReturnStatement: FunctionVisitor = (statement, context) => { + const asyncTryScope = isInAsyncFunction(statement) ? findAsyncTryScopeInStack(context) : undefined; + let results: lua.Expression[]; if (statement.expression) { @@ -77,12 +79,35 @@ export const transformReturnStatement: FunctionVisitor = (st validateAssignment(context, statement, expressionType, returnType); } - results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context)); + // In async try, we handle return propagation via flag variables (asyncTryHasReturn) + // rather than pcall return values (functionReturned set by isInTryCatch), so we skip + // isInTryCatch but still need insideTryCatch=true for multi-return wrapping. + results = transformExpressionsInReturn( + context, + statement.expression, + asyncTryScope ? true : isInTryCatch(context) + ); } else { // Empty return results = []; } + if (asyncTryScope) { + asyncTryScope.asyncTryHasReturn = true; + const stmts: lua.Statement[] = [ + lua.createAssignmentStatement( + lua.createIdentifier("____hasReturned"), + lua.createBooleanLiteral(true), + statement + ), + ]; + if (results.length > 0) { + stmts.push(lua.createAssignmentStatement(lua.createIdentifier("____returnValue"), results[0], statement)); + } + stmts.push(lua.createReturnStatement([], statement)); + return stmts; + } + return createReturnStatement(context, results, statement); }; diff --git a/test/unit/builtins/async-await.spec.ts b/test/unit/builtins/async-await.spec.ts index 45adb4518..dcbadfe92 100644 --- a/test/unit/builtins/async-await.spec.ts +++ b/test/unit/builtins/async-await.spec.ts @@ -815,4 +815,146 @@ describe("try/catch in async function", () => { }, }); }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return inside try with deferred promise (#1706)", () => { + util.testFunction` + let resolveLater!: (value: string) => void; + + function deferredPromise(): Promise { + return new Promise(resolve => { + resolveLater = (v) => resolve(v); + }); + } + + async function fn(): Promise { + try { + return await deferredPromise(); + } catch { + return 'caught'; + } + log('unreachable!'); + } + + const promise = fn(); + resolveLater('ok'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["ok"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return inside try in loop with deferred promise (#1706)", () => { + util.testFunction` + let resolveLater!: (value: string) => void; + + function deferredPromise(): Promise { + return new Promise(resolve => { + resolveLater = (v) => resolve(v); + }); + } + + async function fn(): Promise { + while (true) { + try { + return await deferredPromise(); + } catch { + return 'caught'; + } + log('unreachable!'); + } + } + + const promise = fn(); + resolveLater('ok'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["ok"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + test("return from catch with deferred promise (#1706)", () => { + util.testFunction` + let rejectLater!: (reason: string) => void; + + function deferredPromise(): Promise { + return new Promise((_, reject) => { + rejectLater = (r) => reject(r); + }); + } + + async function fn(): Promise { + try { + return await deferredPromise(); + } catch (e) { + return 'caught: ' + e; + } + log('unreachable!'); + } + + const promise = fn(); + rejectLater('oops'); + promise.then(v => log(v)); + + return allLogs; + ` + .setTsHeader(promiseTestLib) + .expectToEqual(["caught: oops"]); + }); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + util.testEachVersion( + "break inside try in async loop (#1706)", + () => util.testModule` + export let result = "not set"; + async function fn(): Promise { + while (true) { + try { + await Promise.resolve(); + break; + } catch {} + } + result = "done"; + } + fn(); + `, + { + ...util.expectEachVersionExceptJit(builder => builder.expectToEqual({ result: "done" })), + [LuaTarget.Lua50]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + [LuaTarget.Lua51]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + } + ); + + // https://github.com/TypeScriptToLua/TypeScriptToLua/issues/1706 + util.testEachVersion( + "continue inside try in async loop (#1706)", + () => util.testModule` + export const results: number[] = []; + async function fn(): Promise { + for (let i = 0; i < 3; i++) { + try { + await Promise.resolve(); + if (i === 1) continue; + } catch {} + results.push(i); + } + } + fn(); + `, + { + ...util.expectEachVersionExceptJit(builder => builder.expectToEqual({ results: [0, 2] })), + [LuaTarget.Lua50]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + [LuaTarget.Lua51]: builder => + builder.expectToHaveDiagnostics([unsupportedForTargetButOverrideAvailable.code]), + } + ); });