diff --git a/src/transformation/utils/import.ts b/src/transformation/utils/import.ts new file mode 100644 index 000000000..43e641b8d --- /dev/null +++ b/src/transformation/utils/import.ts @@ -0,0 +1,15 @@ +import * as ts from "typescript"; +import * as lua from "../../LuaAST"; + +export function createImportsIdentifier(): lua.Identifier { + return lua.createIdentifier("____imports"); +} + +export function isSymbolImported(symbol: ts.Symbol): boolean { + return symbol.declarations?.some(d => ts.isImportSpecifier(d) || ts.isNamespaceImport(d)) ?? false; +} + +export function createImportedIdentifier(luaIdentifier: lua.Identifier, node: ts.Node): lua.AssignmentLeftHandSideExpression { + const importsTable = lua.createIdentifier("____imports"); + return lua.createTableIndexExpression(importsTable, lua.createStringLiteral(luaIdentifier.text), node); +} diff --git a/src/transformation/visitors/identifier.ts b/src/transformation/visitors/identifier.ts index efb5349bc..073045521 100644 --- a/src/transformation/visitors/identifier.ts +++ b/src/transformation/visitors/identifier.ts @@ -12,15 +12,16 @@ import { isStandardLibraryType } from "../utils/typescript"; import { getExtensionKindForNode, getExtensionKindForSymbol } from "../utils/language-extensions"; import { callExtensions } from "./language-extensions/call-extension"; import { isIdentifierExtensionValue, reportInvalidExtensionValue } from "./language-extensions/identifier"; +import { createImportedIdentifier, isSymbolImported } from "../utils/import"; export function transformIdentifier(context: TransformationContext, identifier: ts.Identifier): lua.Identifier { - return transformNonValueIdentifier(context, identifier, context.checker.getSymbolAtLocation(identifier)); + return transformNonValueIdentifier(context, identifier, context.checker.getSymbolAtLocation(identifier)) as lua.Identifier; } function transformNonValueIdentifier( context: TransformationContext, identifier: ts.Identifier, symbol: ts.Symbol | undefined -) { +): lua.Expression { if (isOptionalContinuation(identifier)) { return lua.createIdentifier(identifier.text, undefined, tempSymbolId); } @@ -52,7 +53,9 @@ function transformNonValueIdentifier( : identifier.text; const symbolId = getIdentifierSymbolId(context, identifier, symbol); - return lua.createIdentifier(text, identifier, symbolId, identifier.text); + const luaIdentifier = lua.createIdentifier(text, identifier, symbolId, identifier.text); + + return symbol && isSymbolImported(symbol) ? createImportedIdentifier(luaIdentifier, identifier) : luaIdentifier; } export function transformIdentifierWithSymbol( diff --git a/src/transformation/visitors/modules/export.ts b/src/transformation/visitors/modules/export.ts index 219e99eee..4915958bf 100644 --- a/src/transformation/visitors/modules/export.ts +++ b/src/transformation/visitors/modules/export.ts @@ -1,6 +1,6 @@ import * as ts from "typescript"; import * as lua from "../../../LuaAST"; -import { assert } from "../../../utils"; +import { assert, cast } from "../../../utils"; import { FunctionVisitor, TransformationContext } from "../../context"; import { createDefaultExportExpression, @@ -145,20 +145,21 @@ function transformExportSpecifiersFrom( // Wrap in block to prevent imports from hoisting out of `do` statement const [block] = transformScopeBlock(context, ts.factory.createBlock([importDeclaration]), ScopeType.Block); - const result = block.statements; + const [requireStatement, ...importAssignments] = block.statements; // Now the module is imported, add the imports to the export table + const assignments = []; for (const specifier of exportSpecifiers) { - result.push( + assignments.push( lua.createAssignmentStatement( createExportedIdentifier(context, transformIdentifier(context, specifier.name)), - transformIdentifier(context, specifier.name) + cast(importAssignments[assignments.length], lua.isAssignmentStatement).right ) ); } // Wrap this in a DoStatement to prevent polluting the scope. - return lua.createDoStatement(result, statement); + return lua.createDoStatement([requireStatement, ...assignments], statement); } export const getExported = (context: TransformationContext, exportSpecifiers: ts.NamedExports) => diff --git a/src/transformation/visitors/modules/import.ts b/src/transformation/visitors/modules/import.ts index 7a320d558..fe61b1715 100644 --- a/src/transformation/visitors/modules/import.ts +++ b/src/transformation/visitors/modules/import.ts @@ -45,14 +45,14 @@ function transformImportSpecifier( context: TransformationContext, importSpecifier: ts.ImportSpecifier, moduleTableName: lua.Identifier -): lua.VariableDeclarationStatement { +): lua.AssignmentStatement { const leftIdentifier = transformIdentifier(context, importSpecifier.name); const propertyName = transformPropertyName( context, importSpecifier.propertyName ? importSpecifier.propertyName : importSpecifier.name ); - return lua.createVariableDeclarationStatement( + return lua.createAssignmentStatement( leftIdentifier, lua.createTableIndexExpression(moduleTableName, propertyName), importSpecifier @@ -93,7 +93,7 @@ export const transformImportDeclaration: FunctionVisitor = if (statement.importClause.name) { if (shouldBeImported(context, statement.importClause)) { const propertyName = createDefaultExportStringLiteral(statement.importClause.name); - const defaultImportAssignmentStatement = lua.createVariableDeclarationStatement( + const defaultImportAssignmentStatement = lua.createAssignmentStatement( transformIdentifier(context, statement.importClause.name), lua.createTableIndexExpression(importUniqueName, propertyName), statement.importClause.name @@ -108,7 +108,7 @@ export const transformImportDeclaration: FunctionVisitor = // local module = require("module") if (statement.importClause.namedBindings && ts.isNamespaceImport(statement.importClause.namedBindings)) { if (context.resolver.isReferencedAliasDeclaration(statement.importClause.namedBindings)) { - const requireStatement = lua.createVariableDeclarationStatement( + const requireStatement = lua.createAssignmentStatement( transformIdentifier(context, statement.importClause.namedBindings.name), requireCall, statement diff --git a/src/transformation/visitors/sourceFile.ts b/src/transformation/visitors/sourceFile.ts index 6870d292b..acd5b0de3 100644 --- a/src/transformation/visitors/sourceFile.ts +++ b/src/transformation/visitors/sourceFile.ts @@ -29,6 +29,8 @@ export const transformSourceFileNode: FunctionVisitor = (node, co context.pushScope(ScopeType.File); statements = performHoisting(context, context.transformStatements(node.statements)); + const hasImports = context.scopeStack[0].importStatements?.some(lua.isAssignmentStatement); + context.popScope(); if (context.isModule) { @@ -40,6 +42,15 @@ export const transformSourceFileNode: FunctionVisitor = (node, co ); } + if (hasImports) { + statements.unshift( + lua.createVariableDeclarationStatement( + lua.createIdentifier("____imports"), + lua.createTableExpression() + ) + ); + } + // return ____exports statements.push(lua.createReturnStatement([createExportsIdentifier()])); } diff --git a/test/translation/__snapshots__/transformation.spec.ts.snap b/test/translation/__snapshots__/transformation.spec.ts.snap index 05bb647b8..2675350a3 100644 --- a/test/translation/__snapshots__/transformation.spec.ts.snap +++ b/test/translation/__snapshots__/transformation.spec.ts.snap @@ -39,15 +39,12 @@ do end do local ____xyz = require(\\"xyz\\") - local abc = ____xyz.abc - local def = ____xyz.def - ____exports.abc = abc - ____exports.def = def + ____exports.abc = ____xyz.abc + ____exports.def = ____xyz.def end do local ____xyz = require(\\"xyz\\") - local def = ____xyz.abc - ____exports.def = def + ____exports.def = ____xyz.abc end return ____exports" `; @@ -113,65 +110,70 @@ end" `; exports[`Transformation (modulesImportAll) 1`] = ` -"local ____exports = {} -local Test = require(\\"test\\") -local ____ = Test +"local ____imports = {} +local ____exports = {} +____imports.Test = require(\\"test\\") +local ____ = ____imports.Test return ____exports" `; exports[`Transformation (modulesImportNamed) 1`] = ` -"local ____exports = {} +"local ____imports = {} +local ____exports = {} local ____test = require(\\"test\\") -local TestClass = ____test.TestClass -local ____ = TestClass +____imports.TestClass = ____test.TestClass +local ____ = ____imports.TestClass return ____exports" `; exports[`Transformation (modulesImportNamedSpecialChars) 1`] = ` -"local ____exports = {} +"local ____imports = {} +local ____exports = {} local ____kebab_2Dmodule = require(\\"kebab-module\\") -local TestClass1 = ____kebab_2Dmodule.TestClass1 +____imports.TestClass1 = ____kebab_2Dmodule.TestClass1 local ____dollar_24module = require(\\"dollar$module\\") -local TestClass2 = ____dollar_24module.TestClass2 +____imports.TestClass2 = ____dollar_24module.TestClass2 local ____singlequote_27module = require(\\"singlequote'module\\") -local TestClass3 = ____singlequote_27module.TestClass3 +____imports.TestClass3 = ____singlequote_27module.TestClass3 local ____hash_23module = require(\\"hash#module\\") -local TestClass4 = ____hash_23module.TestClass4 +____imports.TestClass4 = ____hash_23module.TestClass4 local ____space_20module = require(\\"space module\\") -local TestClass5 = ____space_20module.TestClass5 -local ____ = TestClass1 -local ____ = TestClass2 -local ____ = TestClass3 -local ____ = TestClass4 -local ____ = TestClass5 +____imports.TestClass5 = ____space_20module.TestClass5 +local ____ = ____imports.TestClass1 +local ____ = ____imports.TestClass2 +local ____ = ____imports.TestClass3 +local ____ = ____imports.TestClass4 +local ____ = ____imports.TestClass5 return ____exports" `; exports[`Transformation (modulesImportRenamed) 1`] = ` -"local ____exports = {} +"local ____imports = {} +local ____exports = {} local ____test = require(\\"test\\") -local RenamedClass = ____test.TestClass -local ____ = RenamedClass +____imports.RenamedClass = ____test.TestClass +local ____ = ____imports.RenamedClass return ____exports" `; exports[`Transformation (modulesImportRenamedSpecialChars) 1`] = ` -"local ____exports = {} +"local ____imports = {} +local ____exports = {} local ____kebab_2Dmodule = require(\\"kebab-module\\") -local RenamedClass1 = ____kebab_2Dmodule.TestClass +____imports.RenamedClass1 = ____kebab_2Dmodule.TestClass local ____dollar_24module = require(\\"dollar$module\\") -local RenamedClass2 = ____dollar_24module.TestClass +____imports.RenamedClass2 = ____dollar_24module.TestClass local ____singlequote_27module = require(\\"singlequote'module\\") -local RenamedClass3 = ____singlequote_27module.TestClass +____imports.RenamedClass3 = ____singlequote_27module.TestClass local ____hash_23module = require(\\"hash#module\\") -local RenamedClass4 = ____hash_23module.TestClass +____imports.RenamedClass4 = ____hash_23module.TestClass local ____space_20module = require(\\"space module\\") -local RenamedClass5 = ____space_20module.TestClass -local ____ = RenamedClass1 -local ____ = RenamedClass2 -local ____ = RenamedClass3 -local ____ = RenamedClass4 -local ____ = RenamedClass5 +____imports.RenamedClass5 = ____space_20module.TestClass +local ____ = ____imports.RenamedClass1 +local ____ = ____imports.RenamedClass2 +local ____ = ____imports.RenamedClass3 +local ____ = ____imports.RenamedClass4 +local ____ = ____imports.RenamedClass5 return ____exports" `; @@ -278,8 +280,9 @@ end" `; exports[`Transformation (unusedDefaultWithNamespaceImport) 1`] = ` -"local ____exports = {} -local x = require(\\"module\\") -local ____ = x +"local ____imports = {} +local ____exports = {} +____imports.x = require(\\"module\\") +local ____ = ____imports.x return ____exports" `; diff --git a/test/transpile/__snapshots__/project.spec.ts.snap b/test/transpile/__snapshots__/project.spec.ts.snap index c2f66a1e7..d82e170ac 100644 --- a/test/transpile/__snapshots__/project.spec.ts.snap +++ b/test/transpile/__snapshots__/project.spec.ts.snap @@ -33,10 +33,11 @@ return ____exports }, Object { "filePath": "index.lua", - "lua": "local ____exports = {} + "lua": "local ____imports = {} +local ____exports = {} local ____otherFile = require(\\"otherFile\\") -local getNumber = ____otherFile.getNumber -local myNumber = getNumber(nil) +____imports.getNumber = ____otherFile.getNumber +local myNumber = ____imports.getNumber(nil) setAPIValue(myNumber * 5) return ____exports ", diff --git a/test/transpile/bundle.spec.ts b/test/transpile/bundle.spec.ts index 3834bbb2f..dedf212ba 100644 --- a/test/transpile/bundle.spec.ts +++ b/test/transpile/bundle.spec.ts @@ -90,7 +90,7 @@ describe("bundle with source maps", () => { }> = [ { file: "index", - luaPattern: "____exports.myNumber = getNumber(", + luaPattern: "____exports.myNumber = ____imports.getNumber(", typeScriptPattern: "const myNumber = getNumber(", }, { diff --git a/test/unit/modules/__snapshots__/resolution.spec.ts.snap b/test/unit/modules/__snapshots__/resolution.spec.ts.snap index a250dbd46..930ed3332 100644 --- a/test/unit/modules/__snapshots__/resolution.spec.ts.snap +++ b/test/unit/modules/__snapshots__/resolution.spec.ts.snap @@ -1,9 +1,10 @@ // Jest Snapshot v1, https://goo.gl/fbAQLP exports[`doesn't resolve paths out of root dir: code 1`] = ` -"local ____exports = {} -local module = require(\\"module\\") -local ____ = module +"local ____imports = {} +local ____exports = {} +____imports.module = require(\\"module\\") +local ____ = ____imports.module return ____exports" `; diff --git a/test/unit/modules/modules.spec.ts b/test/unit/modules/modules.spec.ts index 73f7c5be1..26e8de071 100644 --- a/test/unit/modules/modules.spec.ts +++ b/test/unit/modules/modules.spec.ts @@ -291,3 +291,22 @@ test("import expression", () => { .setOptions({ module: ts.ModuleKind.ESNext }) .expectToMatchJsResult(); }); + +test("imports table", () => { + util.testModule` + import { var1, var2 } from "./otherFile"; + const var3 = var1 + var2; + export { var1, var2, var3 } + ` + .addExtraFile( + "otherFile.ts", + ` + export const var1 = 3; + export const var2 = 5;` + ) + .expectToEqual({ + var1: 3, + var2: 5, + var3: 8 + }); +});