Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/transformation/utils/scope.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ export interface Scope {
importStatements?: lua.Statement[];
loopContinued?: LoopContinued;
functionReturned?: boolean;
asyncTryHasReturn?: boolean;
asyncTryHasBreak?: boolean;
asyncTryHasContinue?: LoopContinued;
}

export interface HoistingResult {
Expand Down Expand Up @@ -84,6 +87,23 @@ export function findScope(context: TransformationContext, scopeTypes: ScopeType)
}
}

export function findAsyncTryScopeInStack(context: TransformationContext): Scope | undefined {
for (const scope of walkScopesUp(context)) {
if (scope.type === ScopeType.Function) return undefined;
if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope;
}
return undefined;
}

/** Like findAsyncTryScopeInStack, but also stops at Loop boundaries. */
export function findAsyncTryScopeBeforeLoop(context: TransformationContext): Scope | undefined {
for (const scope of walkScopesUp(context)) {
if (scope.type === ScopeType.Function || scope.type === ScopeType.Loop) return undefined;
if (scope.type === ScopeType.Try || scope.type === ScopeType.Catch) return scope;
}
return undefined;
}

export function addScopeVariableDeclaration(scope: Scope, declaration: lua.VariableDeclarationStatement) {
scope.variableDeclarations ??= [];

Expand Down
29 changes: 27 additions & 2 deletions src/transformation/visitors/break-continue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,22 @@ import * as ts from "typescript";
import { LuaTarget } from "../../CompilerOptions";
import * as lua from "../../LuaAST";
import { FunctionVisitor } from "../context";
import { findScope, LoopContinued, ScopeType } from "../utils/scope";
import { findAsyncTryScopeBeforeLoop, findScope, LoopContinued, ScopeType } from "../utils/scope";
import { isInAsyncFunction } from "../utils/typescript";

export const transformBreakStatement: FunctionVisitor<ts.BreakStatement> = (breakStatement, context) => {
void context;
const tryScope = isInAsyncFunction(breakStatement) ? findAsyncTryScopeBeforeLoop(context) : undefined;
if (tryScope) {
tryScope.asyncTryHasBreak = true;
return [
lua.createAssignmentStatement(
lua.createIdentifier("____hasBroken"),
lua.createBooleanLiteral(true),
breakStatement
),
lua.createReturnStatement([], breakStatement),
];
}
return lua.createBreakStatement(breakStatement);
};

Expand All @@ -28,6 +40,19 @@ export const transformContinueStatement: FunctionVisitor<ts.ContinueStatement> =
scope.loopContinued = continuedWith;
}

const tryScope = isInAsyncFunction(statement) ? findAsyncTryScopeBeforeLoop(context) : undefined;
if (tryScope) {
tryScope.asyncTryHasContinue = continuedWith;
return [
lua.createAssignmentStatement(
lua.createIdentifier("____hasContinued"),
lua.createBooleanLiteral(true),
statement
),
lua.createReturnStatement([], statement),
];
}

const label = `__continue${scope?.id ?? ""}`;

switch (continuedWith) {
Expand Down
90 changes: 76 additions & 14 deletions src/transformation/visitors/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { FunctionVisitor, TransformationContext } from "../context";
import { unsupportedForTarget, unsupportedForTargetButOverrideAvailable } from "../utils/diagnostics";
import { createUnpackCall } from "../utils/lua-ast";
import { transformLuaLibFunction } from "../utils/lualib";
import { Scope, ScopeType } from "../utils/scope";
import { findScope, LoopContinued, Scope, ScopeType } from "../utils/scope";
import { isInAsyncFunction, isInGeneratorFunction } from "../utils/typescript";
import { wrapInAsyncAwaiter } from "./async-await";
import { transformScopeBlock } from "./block";
Expand All @@ -14,7 +14,7 @@ import { isInMultiReturnFunction } from "./language-extensions/multi";
import { createReturnStatement } from "./return";

const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context) => {
const [tryBlock] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);
const [tryBlock, tryScope] = transformScopeBlock(context, statement.tryBlock, ScopeType.Try);

if (
(context.options.luaTarget === LuaTarget.Lua50 || context.options.luaTarget === LuaTarget.Lua51) &&
Expand All @@ -31,13 +31,14 @@ const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context)
return tryBlock.statements;
}

// __TS__AsyncAwaiter(<catch block>)
// __TS__AsyncAwaiter(<try block>)
const awaiter = wrapInAsyncAwaiter(context, tryBlock.statements, false);
const awaiterIdentifier = lua.createIdentifier("____try");
const awaiterDefinition = lua.createVariableDeclarationStatement(awaiterIdentifier, awaiter);

// local ____try = __TS__AsyncAwaiter(<catch block>)
const result: lua.Statement[] = [awaiterDefinition];
// Transform catch/finally and collect scope info before building the result
let catchScope: Scope | undefined;
const chainCalls: lua.Statement[] = [];

if (statement.finallyBlock) {
const awaiterFinally = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("finally"));
Expand All @@ -49,27 +50,88 @@ const transformAsyncTry: FunctionVisitor<ts.TryStatement> = (statement, context)
[awaiterIdentifier, finallyFunction],
statement.finallyBlock
);
// ____try.finally(<finally function>)
result.push(lua.createExpressionStatement(finallyCall));
chainCalls.push(lua.createExpressionStatement(finallyCall));
}

if (statement.catchClause) {
// ____try.catch(<catch function>)
const [catchFunction] = transformCatchClause(context, statement.catchClause);
const [catchFunction, cScope] = transformCatchClause(context, statement.catchClause);
catchScope = cScope;
if (catchFunction.params) {
catchFunction.params.unshift(lua.createAnonymousIdentifier());
}

const awaiterCatch = lua.createTableIndexExpression(awaiterIdentifier, lua.createStringLiteral("catch"));
const catchCall = lua.createCallExpression(awaiterCatch, [awaiterIdentifier, catchFunction]);

// await ____try.catch(<catch function>)
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, catchCall);
result.push(lua.createExpressionStatement(promiseAwait, statement));
chainCalls.push(lua.createExpressionStatement(promiseAwait, statement));
} else {
// await ____try
const promiseAwait = transformLuaLibFunction(context, LuaLibFeature.Await, statement, awaiterIdentifier);
result.push(lua.createExpressionStatement(promiseAwait, statement));
chainCalls.push(lua.createExpressionStatement(promiseAwait, statement));
}

const hasReturn = tryScope.asyncTryHasReturn ?? catchScope?.asyncTryHasReturn;
const hasBreak = tryScope.asyncTryHasBreak ?? catchScope?.asyncTryHasBreak;
const hasContinue = tryScope.asyncTryHasContinue ?? catchScope?.asyncTryHasContinue;

// Build result in output order: flag declarations, awaiter, chain calls, post-checks
const result: lua.Statement[] = [];

if (hasReturn || hasBreak || hasContinue !== undefined) {
const flagDecls: lua.Identifier[] = [];
if (hasReturn) {
flagDecls.push(lua.createIdentifier("____hasReturned"));
flagDecls.push(lua.createIdentifier("____returnValue"));
}
if (hasBreak) {
flagDecls.push(lua.createIdentifier("____hasBroken"));
}
if (hasContinue !== undefined) {
flagDecls.push(lua.createIdentifier("____hasContinued"));
}
result.push(lua.createVariableDeclarationStatement(flagDecls));
}

result.push(awaiterDefinition);
result.push(...chainCalls);

if (hasReturn) {
result.push(
lua.createIfStatement(
lua.createIdentifier("____hasReturned"),
lua.createBlock([createReturnStatement(context, [lua.createIdentifier("____returnValue")], statement)])
)
);
}

if (hasBreak) {
result.push(
lua.createIfStatement(lua.createIdentifier("____hasBroken"), lua.createBlock([lua.createBreakStatement()]))
);
}

if (hasContinue !== undefined) {
const loopScope = findScope(context, ScopeType.Loop);
const label = `__continue${loopScope?.id ?? ""}`;

const continueStatements: lua.Statement[] = [];
switch (hasContinue) {
case LoopContinued.WithGoto:
continueStatements.push(lua.createGotoStatement(label));
break;
case LoopContinued.WithContinue:
continueStatements.push(lua.createContinueStatement());
break;
case LoopContinued.WithRepeatBreak:
continueStatements.push(
lua.createAssignmentStatement(lua.createIdentifier(label), lua.createBooleanLiteral(true))
);
continueStatements.push(lua.createBreakStatement());
break;
}

result.push(
lua.createIfStatement(lua.createIdentifier("____hasContinued"), lua.createBlock(continueStatements))
);
}

return result;
Expand Down
74 changes: 49 additions & 25 deletions src/transformation/visitors/return.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as lua from "../../LuaAST";
import { FunctionVisitor, TransformationContext } from "../context";
import { validateAssignment } from "../utils/assignment-validation";
import { createUnpackCall, wrapInTable } from "../utils/lua-ast";
import { ScopeType, walkScopesUp } from "../utils/scope";
import { findAsyncTryScopeInStack, ScopeType, walkScopesUp } from "../utils/scope";
import { transformArguments } from "./call";
import {
returnsMultiType,
Expand All @@ -16,11 +16,7 @@ import {
import { invalidMultiFunctionReturnType } from "../utils/diagnostics";
import { isInAsyncFunction } from "../utils/typescript";

function transformExpressionsInReturn(
context: TransformationContext,
node: ts.Expression,
insideTryCatch: boolean
): lua.Expression[] {
function transformExpressionsInReturn(context: TransformationContext, node: ts.Expression): lua.Expression[] {
const expressionType = context.checker.getTypeAtLocation(node);

// skip type assertions
Expand All @@ -36,20 +32,7 @@ function transformExpressionsInReturn(
context.diagnostics.push(invalidMultiFunctionReturnType(innerNode));
}

let returnValues = transformArguments(context, innerNode.arguments);
if (insideTryCatch) {
returnValues = [wrapInTable(...returnValues)]; // Wrap results when returning inside try/catch
}
return returnValues;
}

// Force-wrap LuaMultiReturn when returning inside try/catch
if (
insideTryCatch &&
returnsMultiType(context, innerNode) &&
!shouldMultiReturnCallBeWrapped(context, innerNode)
) {
return [wrapInTable(context.transformExpression(node))];
return transformArguments(context, innerNode.arguments);
}
} else if (isInMultiReturnFunction(context, innerNode) && isMultiReturnType(expressionType)) {
// Unpack objects typed as LuaMultiReturn
Expand All @@ -63,24 +46,65 @@ export function transformExpressionBodyToReturnStatement(
context: TransformationContext,
node: ts.Expression
): lua.Statement {
const expressions = transformExpressionsInReturn(context, node, false);
const expressions = transformExpressionsInReturn(context, node);
return createReturnStatement(context, expressions, node);
}

function transformReturnExpressionForTryCatch(context: TransformationContext, node: ts.Expression): lua.Expression {
const innerNode = ts.skipOuterExpressions(node, ts.OuterExpressionKinds.Assertions);

if (ts.isCallExpression(innerNode)) {
if (isMultiFunctionCall(context, innerNode)) {
const type = context.checker.getContextualType(node);
if (type && !canBeMultiReturnType(type)) {
context.diagnostics.push(invalidMultiFunctionReturnType(innerNode));
}
return wrapInTable(...transformArguments(context, innerNode.arguments));
}

if (returnsMultiType(context, innerNode) && !shouldMultiReturnCallBeWrapped(context, innerNode)) {
return wrapInTable(context.transformExpression(node));
}
}

return context.transformExpression(node);
}

export const transformReturnStatement: FunctionVisitor<ts.ReturnStatement> = (statement, context) => {
let results: lua.Expression[];
const asyncTryScope = isInAsyncFunction(statement) ? findAsyncTryScopeInStack(context) : undefined;

if (statement.expression) {
const expressionType = context.checker.getTypeAtLocation(statement.expression);
const returnType = context.checker.getContextualType(statement.expression);
if (returnType) {
validateAssignment(context, statement, expressionType, returnType);
}
}

results = transformExpressionsInReturn(context, statement.expression, isInTryCatch(context));
} else {
// Empty return
if (asyncTryScope) {
asyncTryScope.asyncTryHasReturn = true;
const stmts: lua.Statement[] = [
lua.createAssignmentStatement(
lua.createIdentifier("____hasReturned"),
lua.createBooleanLiteral(true),
statement
),
];
if (statement.expression) {
const returnValue = transformReturnExpressionForTryCatch(context, statement.expression);
stmts.push(lua.createAssignmentStatement(lua.createIdentifier("____returnValue"), returnValue, statement));
}
stmts.push(lua.createReturnStatement([], statement));
return stmts;
}

let results: lua.Expression[];
if (!statement.expression) {
results = [];
} else if (isInTryCatch(context)) {
results = [transformReturnExpressionForTryCatch(context, statement.expression)];
} else {
results = transformExpressionsInReturn(context, statement.expression);
}

return createReturnStatement(context, results, statement);
Expand Down
Loading
Loading