Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/transformation/utils/scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ export interface Scope {
importStatements?: lua.Statement[];
loopContinued?: LoopContinued;
functionReturned?: boolean;
asyncTryHasReturn?: boolean;
asyncTryHasBreak?: boolean;
asyncTryHasContinue?: LoopContinued;
}

export interface HoistingResult {
Expand Down Expand Up @@ -84,6 +87,23 @@ export function findScope(context: TransformationContext, scopeTypes: ScopeType)
}
}

export function findAsyncTryScopeInStack(context: TransformationContext): Scope | undefined {
for (const scope of walkScopesUp(context)) {
if (scope.type === ScopeType.Function) return undefined;
if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope;
}
return undefined;
}

/** Like findAsyncTryScopeInStack, but also stops at Loop boundaries. */
export function findAsyncTryScopeBeforeLoop(context: TransformationContext): Scope | undefined {
for (const scope of walkScopesUp(context)) {
if (scope.type === ScopeType.Function || scope.type === ScopeType.Loop) return undefined;
if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope;
}
return undefined;
}

export function addScopeVariableDeclaration(scope: Scope, declaration: lua.VariableDeclarationStatement) {
scope.variableDeclarations ??= [];

Expand Down
29 changes: 27 additions & 2 deletions src/transformation/visitors/break-continue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@ import * as ts from "typescript";
import { LuaTarget } from "../../CompilerOptions";
import * as lua from "../../LuaAST";
import { FunctionVisitor } from "../context";
import { findScope, LoopContinued, ScopeType } from "../utils/scope";
import { findAsyncTryScopeBeforeLoop, findScope, LoopContinued, ScopeType } from "../utils/scope";
import { isInAsyncFunction } from "../utils/typescript";

export const transformBreakStatement: FunctionVisitor<ts.BreakStatement> = (breakStatement, context) => {
void context;
const tryScope = isInAsyncFunction(breakStatement) ? findAsyncTryScopeBeforeLoop(context) : undefined;
if (tryScope) {
tryScope.asyncTryHasBreak = true;
return [
lua.createAssignmentStatement(
lua.createIdentifier("____hasBroken"),
lua.createBooleanLiteral(true),
breakStatement
),
lua.createReturnStatement([], breakStatement),
];
}
return lua.createBreakStatement(breakStatement);
};

Expand All @@ -28,6 +40,19 @@ export const transformContinueStatement: FunctionVisitor<ts.ContinueStatement> =
scope.loopContinued = continuedWith;
}

const tryScope = isInAsyncFunction(statement) ? findAsyncTryScopeBeforeLoop(context) : undefined;
if (tryScope) {
tryScope.asyncTryHasContinue = continuedWith;
return [
lua.createAssignmentStatement(
lua.createIdentifier("____hasContinued"),
lua.createBooleanLiteral(true),
statement
),
lua.createReturnStatement([], statement),
];
}

const label = `__continue${scope?.id ?? ""}`;

switch (continuedWith) {
Expand Down
90 changes: 76 additions & 14 deletions src/transformation/visitors/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { FunctionVisitor, TransformationContext } from "../context";
import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics";
import { createUnpackCall } from "../utils/lua-ast";
import { transformLuaLibFunction } from "../utils/lualib";
import { Scope, ScopeType } from "../utils/scope";
import { findScope, LoopContinued, Scope, ScopeType } from "../utils/scope";
import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript";
import { wrapInAsyncAwaiter } from "./async-await";
import { transformScopeBlock } from "./block";
Expand All @@ -14,7 +14,7 @@ import { isInMultiReturnFunction } from "./language-extensions/multi";
import { createReturnStatement } from "./return";

const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context) => {
const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);

if (
(context.options.luaTarget === LuaTarget.Lua50 || context.options.luaTarget === LuaTarget.Lua51) &&
Expand All @@ -31,13 +31,14 @@ const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context)
return tryBlock.statements;
}

// __TS__AsyncAwaiter(<catch block>)
// __TS__AsyncAwaiter(<try block>)
const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false);
const awaiterIdentifier = lua.createIdentifier("____try");
const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter);

// local ____try = __TS__AsyncAwaiter(<catch block>)
const result: lua.Statement[] = [awaiterDefinition];
// Transform catch/finally and collect scope info before building the result
let catchScope: Scope | undefined;
const chainCalls: lua.Statement[] = [];

if (statement.finallyBlock) {
const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally"));
Expand All @@ -49,27 +50,88 @@ const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context)
[awaiterIdentifier, finallyFunction],
statement.finallyBlock
);
// ____try.finally(<finally function>)
result.push(lua.createExpressionStatement(finallyCall));
chainCalls.push(lua.createExpressionStatement(finallyCall));
}

if (statement.catchClause) {
// ____try.catch(<catch function>)
const [catchFunction] = transformCatchClause(context, statement.catchClause);
const [catchFunction, cScope] = transformCatchClause(context, statement.catchClause);
catchScope = cScope;
if (catchFunction.params) {
catchFunction.params.unshift(lua.createAnonymousIdentifier());
}

const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch"));
const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]);

// await ____try.catch(<catch function>)
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall);
result.push(lua.createExpressionStatement(promiseAwait, statement));
chainCalls.push(lua.createExpressionStatement(promiseAwait, statement));
} else {
// await ____try
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier);
result.push(lua.createExpressionStatement(promiseAwait, statement));
chainCalls.push(lua.createExpressionStatement(promiseAwait, statement));
}

const hasReturn = tryScope.asyncTryHasReturn ?? catchScope?.asyncTryHasReturn;
const hasBreak = tryScope.asyncTryHasBreak ?? catchScope?.asyncTryHasBreak;
const hasContinue = tryScope.asyncTryHasContinue ?? catchScope?.asyncTryHasContinue;

// Build result in output order: flag declarations, awaiter, chain calls, post-checks
const result: lua.Statement[] = [];

if (hasReturn || hasBreak || hasContinue !== undefined) {
const flagDecls: lua.Identifier[] = [];
if (hasReturn) {
flagDecls.push(lua.createIdentifier("____hasReturned"));
flagDecls.push(lua.createIdentifier("____returnValue"));
}
if (hasBreak) {
flagDecls.push(lua.createIdentifier("____hasBroken"));
}
if (hasContinue !== undefined) {
flagDecls.push(lua.createIdentifier("____hasContinued"));
}
result.push(lua.createVariableDeclarationStatement(flagDecls));
}

result.push(awaiterDefinition);
result.push(...chainCalls);

if (hasReturn) {
result.push(
lua.createIfStatement(
lua.createIdentifier("____hasReturned"),
lua.createBlock([createReturnStatement(context, [lua.createIdentifier("____returnValue")], statement)])
)
);
}

if (hasBreak) {
result.push(
lua.createIfStatement(lua.createIdentifier("____hasBroken"), lua.createBlock([lua.createBreakStatement()]))
);
}

if (hasContinue !== undefined) {
const loopScope = findScope(context, ScopeType.Loop);
const label = `__continue${loopScope?.id ?? ""}`;

const continueStatements: lua.Statement[] = [];
switch (hasContinue) {
case LoopContinued.WithGoto:
continueStatements.push(lua.createGotoStatement(label));
break;
case LoopContinued.WithContinue:
continueStatements.push(lua.createContinueStatement());
break;
case LoopContinued.WithRepeatBreak:
continueStatements.push(
lua.createAssignmentStatement(lua.createIdentifier(label), lua.createBooleanLiteral(true))
);
continueStatements.push(lua.createBreakStatement());
break;
}

result.push(
lua.createIfStatement(lua.createIdentifier("____hasContinued"), lua.createBlock(continueStatements))
);
}

return result;
Expand Down
29 changes: 27 additions & 2 deletions src/transformation/visitors/return.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as lua from "../../LuaAST";
import { FunctionVisitor, TransformationContext } from "../context";
import { validateAssignment } from "../utils/assignment-validation";
import { createUnpackCall, wrapInTable } from "../utils/lua-ast";
import { ScopeType, walkScopesUp } from "../utils/scope";
import { findAsyncTryScopeInStack, ScopeType, walkScopesUp } from "../utils/scope";
import { transformArguments } from "./call";
import {
returnsMultiType,
Expand Down Expand Up @@ -68,6 +68,8 @@ export function transformExpressionBodyToReturnStatement(
}

export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (statement, context) => {
const asyncTryScope = isInAsyncFunction(statement) ? findAsyncTryScopeInStack(context) : undefined;

let results: lua.Expression[];

if (statement.expression) {
Expand All @@ -77,12 +79,35 @@ export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (st
validateAssignment(context, statement, expressionType, returnType);
}

results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context));
// In async try, we handle return propagation via flag variables (asyncTryHasReturn)
// rather than pcall return values (functionReturned set by isInTryCatch), so we skip
// isInTryCatch but still need insideTryCatch=true for multi-return wrapping.
results = transformExpressionsInReturn(
context,
statement.expression,
asyncTryScope ? true : isInTryCatch(context)
);
} else {
// Empty return
results = [];
}

if (asyncTryScope) {
asyncTryScope.asyncTryHasReturn = true;
const stmts: lua.Statement[] = [
lua.createAssignmentStatement(
lua.createIdentifier("____hasReturned"),
lua.createBooleanLiteral(true),
statement
),
];
if (results.length > 0) {
stmts.push(lua.createAssignmentStatement(lua.createIdentifier("____returnValue"), results[0], statement));
}
stmts.push(lua.createReturnStatement([], statement));
return stmts;
}

return createReturnStatement(context, results, statement);
};

Expand Down
Loading
Loading