diff --git a/package.json b/package.json index 87093d66..cf2f2faf 100644 --- a/package.json +++ b/package.json @@ -3,7 +3,9 @@ "version": "1.0.13", "main": "dist/index.js", "types": "dist/index.d.ts", - "files": ["dist"], + "files": [ + "dist" + ], "repository": { "type": "git", "url": "git+https://github.com/source-academy/java-slang.git" @@ -40,5 +42,6 @@ "java-parser": "^2.0.5", "lodash": "^4.17.21", "peggy": "^4.0.2" - } + }, + "packageManager": "yarn@1.22.22+sha1.ac34549e6aa8e7ead463a7407e1c7390f61a6610" } diff --git a/src/ClassFile/types/index.ts b/src/ClassFile/types/index.ts index 4f963d44..71ccd814 100644 --- a/src/ClassFile/types/index.ts +++ b/src/ClassFile/types/index.ts @@ -3,6 +3,11 @@ import { ConstantInfo } from './constants' import { FieldInfo } from './fields' import { MethodInfo } from './methods' +export interface Class { + classFile: ClassFile + className: string +} + export interface ClassFile { magic: number minorVersion: number diff --git a/src/ast/__tests__/expression-extractor.test.ts b/src/ast/__tests__/expression-extractor.test.ts index 3e3238b0..9a7d03b7 100644 --- a/src/ast/__tests__/expression-extractor.test.ts +++ b/src/ast/__tests__/expression-extractor.test.ts @@ -1035,3 +1035,142 @@ describe("extract ClassInstanceCreationExpression correctly", () => { expect(ast).toEqual(expectedAst); }); }); + +describe("extract CastExpression correctly", () => { + it("extract CastExpression int to char correctly", () => { + const programStr = ` + class Test { + void test() { + char c = (char) 65; + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "LocalVariableDeclarationStatement", + localVariableType: "char", + variableDeclaratorList: [ + { + kind: "VariableDeclarator", + variableDeclaratorId: "c", + variableInitializer: { + kind: "CastExpression", + type: "char", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "65", + }, + location: expect.anything(), + }, + location: expect.anything(), + }, + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract CastExpression double to int correctly", () => { + const programStr = ` + class Test { + void test() { + int x = (int) 3.14; + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "LocalVariableDeclarationStatement", + localVariableType: "int", + variableDeclaratorList: [ + { + kind: "VariableDeclarator", + variableDeclaratorId: "x", + variableInitializer: { + kind: "CastExpression", + type: "int", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalFloatingPointLiteral", + value: "3.14", + } + }, + location: expect.anything(), + }, + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); +}); diff --git a/src/ast/__tests__/switch-statement-extractor.test.ts b/src/ast/__tests__/switch-statement-extractor.test.ts new file mode 100644 index 00000000..00ae21ea --- /dev/null +++ b/src/ast/__tests__/switch-statement-extractor.test.ts @@ -0,0 +1,435 @@ +import { parse } from "../parser"; +import { AST } from "../types/packages-and-modules"; + +describe("extract SwitchStatement correctly", () => { + it("extract SwitchStatement with case labels and statements correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + default: + System.out.println("Default"); + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "2", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Two"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "DefaultLabel", + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Default"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract SwitchStatement with fallthrough correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + case 2: + System.out.println("One or Two"); + break; + default: + System.out.println("Default"); + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "2", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One or Two"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + { + kind: "SwitchCase", + labels: [ + { + kind: "DefaultLabel", + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"Default"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + expect(ast).toEqual(expectedAst); + }); + + it("extract SwitchStatement without default case correctly", () => { + const programStr = ` + class Test { + void test(int x) { + switch (x) { + case 1: + System.out.println("One"); + break; + } + } + } + `; + + const expectedAst: AST = { + kind: "CompilationUnit", + importDeclarations: [], + topLevelClassOrInterfaceDeclarations: [ + { + kind: "NormalClassDeclaration", + classModifier: [], + typeIdentifier: "Test", + classBody: [ + { + kind: "MethodDeclaration", + methodModifier: [], + methodHeader: { + result: "void", + identifier: "test", + formalParameterList: [ + { + kind: "FormalParameter", + unannType: "int", + identifier: "x", + }, + ], + }, + methodBody: { + kind: "Block", + blockStatements: [ + { + kind: "SwitchStatement", + expression: { + kind: "ExpressionName", + name: "x", + location: expect.anything(), + }, + cases: [ + { + kind: "SwitchCase", + labels: [ + { + kind: "CaseLabel", + expression: { + kind: "Literal", + literalType: { + kind: "DecimalIntegerLiteral", + value: "1", + }, + }, + }, + ], + statements: [ + { + kind: "ExpressionStatement", + stmtExp: { + kind: "MethodInvocation", + identifier: "System.out.println", + argumentList: [ + { + kind: "Literal", + literalType: { + kind: "StringLiteral", + value: '"One"', + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + { + kind: "BreakStatement", + }, + ], + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + location: expect.anything(), + }, + ], + location: expect.anything(), + }, + ], + location: expect.anything(), + }; + + const ast = parse(programStr); + console.log(JSON.stringify(ast, null, 2)); + expect(ast).toEqual(expectedAst); + }); +}); diff --git a/src/ast/astExtractor/expression-extractor.ts b/src/ast/astExtractor/expression-extractor.ts index 3301db6a..dcce959c 100644 --- a/src/ast/astExtractor/expression-extractor.ts +++ b/src/ast/astExtractor/expression-extractor.ts @@ -2,6 +2,7 @@ import { ArgumentListCtx, BaseJavaCstVisitorWithDefaults, BinaryExpressionCtx, + CastExpressionCtx, ClassOrInterfaceTypeToInstantiateCtx, BooleanLiteralCtx, ExpressionCstNode, @@ -86,6 +87,62 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { } } + castExpression(ctx: CastExpressionCtx) { + // Handle primitive cast expressions + if (ctx.primitiveCastExpression && ctx.primitiveCastExpression?.length > 0) { + const primitiveCast = ctx.primitiveCastExpression[0]; + const type = this.extractType(primitiveCast.children.primitiveType[0]); + const expression = this.visit(primitiveCast.children.unaryExpression[0]); + return { + kind: "CastExpression", + type: type, + expression: expression, + location: this.location, + }; + } + + throw new Error("Invalid CastExpression format."); + } + + private extractType(typeCtx: any): string { + // Check for the 'primitiveType' node + if (typeCtx.name === "primitiveType" && typeCtx.children) { + const { children } = typeCtx; + + // Handle 'numericType' (e.g., int, char, float, double) + if (children.numericType) { + const numericTypeCtx = children.numericType[0]; + + if (numericTypeCtx.children.integralType) { + // Handle integral types (e.g., char, int) + const integralTypeCtx = numericTypeCtx.children.integralType[0]; + + // Extract the specific type (e.g., 'char', 'int') + for (const key in integralTypeCtx.children) { + if (integralTypeCtx.children[key][0].image) { + return integralTypeCtx.children[key][0].image; + } + } + } + + if (numericTypeCtx.children.floatingPointType) { + // Handle floating-point types (e.g., float, double) + const floatingPointTypeCtx = numericTypeCtx.children.floatingPointType[0]; + + // Extract the specific type (e.g., 'float', 'double') + for (const key in floatingPointTypeCtx.children) { + if (floatingPointTypeCtx.children[key][0].image) { + return floatingPointTypeCtx.children[key][0].image; + } + } + } + } + } + + throw new Error("Invalid type context in cast expression."); + } + + private makeBinaryExpression( operators: IToken[], operands: UnaryExpressionCstNode[] @@ -174,6 +231,10 @@ export class ExpressionExtractor extends BaseJavaCstVisitorWithDefaults { } unaryExpression(ctx: UnaryExpressionCtx) { + if (ctx.primary[0].children.primaryPrefix[0].children.castExpression) { + return this.visit(ctx.primary[0].children.primaryPrefix[0].children.castExpression); + } + const node = this.visit(ctx.primary); if (ctx.UnaryPrefixOperator) { return { diff --git a/src/ast/astExtractor/statement-extractor.ts b/src/ast/astExtractor/statement-extractor.ts index ac2918cf..9e9f0e4c 100644 --- a/src/ast/astExtractor/statement-extractor.ts +++ b/src/ast/astExtractor/statement-extractor.ts @@ -21,6 +21,10 @@ import { PrimaryPrefixCtx, PrimarySuffixCtx, ReturnStatementCtx, + SwitchStatementCtx, + SwitchBlockCtx, + SwitchLabelCtx, + SwitchBlockStatementGroupCtx, StatementCstNode, StatementExpressionCtx, StatementWithoutTrailingSubstatementCtx, @@ -32,13 +36,17 @@ import { ExpressionStatementCtx, LocalVariableTypeCtx, VariableDeclaratorListCtx, - VariableDeclaratorCtx, -} from "java-parser"; + VariableDeclaratorCtx +} from 'java-parser' import { BasicForStatement, ExpressionStatement, IfStatement, MethodInvocation, + SwitchStatement, + SwitchCase, + CaseLabel, + DefaultLabel, Statement, StatementExpression, VariableDeclarator, @@ -80,6 +88,8 @@ export class StatementExtractor extends BaseJavaCstVisitorWithDefaults { return { kind: "BreakStatement" }; } else if (ctx.continueStatement) { return { kind: "ContinueStatement" }; + } else if (ctx.switchStatement) { + return this.visit(ctx.switchStatement); } else if (ctx.returnStatement) { const returnStatementExp = this.visit(ctx.returnStatement); return { @@ -90,6 +100,122 @@ export class StatementExtractor extends BaseJavaCstVisitorWithDefaults { } } + switchStatement(ctx: SwitchStatementCtx): SwitchStatement { + const expressionExtractor = new ExpressionExtractor(); + + return { + kind: "SwitchStatement", + expression: expressionExtractor.extract(ctx.expression[0]), + cases: ctx.switchBlock + ? this.visit(ctx.switchBlock) + : [], + location: ctx.Switch[0] + }; + } + + switchBlock(ctx: SwitchBlockCtx): Array { + const cases: Array = []; + let currentCase: SwitchCase; + + ctx.switchBlockStatementGroup?.forEach((group) => { + const extractedCase = this.visit(group); + + if (!currentCase) { + // First case in the switch block + currentCase = extractedCase; + cases.push(currentCase); + } else if (currentCase.statements && currentCase.statements.length === 0) { + // Fallthrough case, merge labels + currentCase.labels.push(...extractedCase.labels); + } else { + // New case with statements starts, push previous case and start new one + currentCase = extractedCase; + cases.push(currentCase); + } + }); + + return cases; + } + + switchBlockStatementGroup(ctx: SwitchBlockStatementGroupCtx): SwitchCase { + const blockStatementExtractor = new BlockStatementExtractor(); + + console.log(ctx.switchLabel) + + return { + kind: "SwitchCase", + labels: ctx.switchLabel.flatMap((label) => this.visit(label)), + statements: ctx.blockStatements + ? ctx.blockStatements.flatMap((blockStatements) => + blockStatements.children.blockStatement.map((stmt) => + blockStatementExtractor.extract(stmt) + ) + ) + : [], + }; + } + + // switchLabel(ctx: SwitchLabelCtx): CaseLabel | DefaultLabel { + // // Check if the context contains a "case" label + // if (ctx.caseOrDefaultLabel?.[0]?.children?.Case) { + // const expressionExtractor = new ExpressionExtractor(); + // // @ts-ignore + // const expressionCtx = ctx.caseOrDefaultLabel[0].children.caseLabelElement[0] + // .children.caseConstant[0].children.ternaryExpression[0].children; + // + // // Ensure the expression context is valid before proceeding + // if (!expressionCtx) { + // throw new Error("Invalid Case expression in switch label"); + // } + // + // const expression = expressionExtractor.ternaryExpression(expressionCtx); + // + // return { + // kind: "CaseLabel", + // expression: expression, + // }; + // } + // + // // Check if the context contains a "default" label + // if (ctx.caseOrDefaultLabel?.[0]?.children?.Default) { + // return { kind: "DefaultLabel" }; + // } + // + // // Throw an error if the context does not match expected patterns + // throw new Error("Invalid switch label: Neither 'case' nor 'default' found"); + // } + + switchLabel(ctx: SwitchLabelCtx): Array { + const expressionExtractor = new ExpressionExtractor(); + const labels: Array = []; + + // Process all case or default labels + for (const labelCtx of ctx.caseOrDefaultLabel) { + if (labelCtx.children.Case) { + // Extract the expression for the case label + const expressionCtx = labelCtx.children.caseLabelElement?.[0] + ?.children.caseConstant?.[0]?.children.ternaryExpression?.[0]?.children; + + if (!expressionCtx) { + throw new Error("Invalid Case expression in switch label"); + } + + labels.push({ + kind: "CaseLabel", + expression: expressionExtractor.ternaryExpression(expressionCtx), + }); + } else if (labelCtx.children.Default) { + labels.push({ kind: "DefaultLabel" }); + } + } + + if (labels.length === 0) { + throw new Error("Invalid switch label: Neither 'case' nor 'default' found"); + } + + return labels; + } + expressionStatement(ctx: ExpressionStatementCtx): ExpressionStatement { const stmtExp = this.visit(ctx.statementExpression); return { diff --git a/src/ast/types/blocks-and-statements.ts b/src/ast/types/blocks-and-statements.ts index fe5dc7ad..54a1ce9d 100644 --- a/src/ast/types/blocks-and-statements.ts +++ b/src/ast/types/blocks-and-statements.ts @@ -28,7 +28,8 @@ export type Statement = | IfStatement | WhileStatement | ForStatement - | EmptyStatement; + | EmptyStatement + | SwitchStatement; export interface EmptyStatement extends BaseNode { kind: "EmptyStatement"; @@ -66,6 +67,34 @@ export interface EnhancedForStatement extends BaseNode { kind: "EnhancedForStatement"; } +export interface SwitchStatement extends BaseNode { + kind: "SwitchStatement"; + expression: Expression; // The expression to evaluate for the switch + cases: Array; +} + +export interface SwitchCase extends BaseNode { + kind: "SwitchCase"; + labels: Array; // Labels for case blocks + statements?: Array; // Statements to execute for the case +} + +export type CaseLabel = CaseLiteralLabel | CaseExpressionLabel; + +export interface CaseLiteralLabel extends BaseNode { + kind: "CaseLabel"; + expression: Literal; // Literal values: byte, short, int, char, or String +} + +export interface CaseExpressionLabel extends BaseNode { + kind: "CaseLabel"; + expression: Expression; // For future extension if needed +} + +export interface DefaultLabel extends BaseNode { + kind: "DefaultLabel"; // Represents the default case +} + export type StatementWithoutTrailingSubstatement = | Block | ExpressionStatement @@ -259,7 +288,7 @@ export interface Assignment extends BaseNode { } export type LeftHandSide = ExpressionName | ArrayAccess; -export type UnaryExpression = PrefixExpression | PostfixExpression; +export type UnaryExpression = PrefixExpression | PostfixExpression | CastExpression; export interface PrefixExpression extends BaseNode { kind: "PrefixExpression"; @@ -289,3 +318,9 @@ export interface TernaryExpression extends BaseNode { consequent: Expression; alternate: Expression; } + +export interface CastExpression extends BaseNode { + kind: "CastExpression"; + type: UnannType; + expression: Expression; +} diff --git a/src/compiler/__tests__/__utils__/test-utils.ts b/src/compiler/__tests__/__utils__/test-utils.ts index 36d11f57..382c3907 100644 --- a/src/compiler/__tests__/__utils__/test-utils.ts +++ b/src/compiler/__tests__/__utils__/test-utils.ts @@ -1,45 +1,47 @@ -import { inspect } from "util"; -import { compile } from "../../index"; -import { BinaryWriter } from "../../binary-writer"; -import { AST } from "../../../ast/types/packages-and-modules"; -import { javaPegGrammar } from "../../grammar" +import { inspect } from 'util' +import { compile } from '../../index' +import { BinaryWriter } from '../../binary-writer' +import { AST } from '../../../ast/types/packages-and-modules' +import { javaPegGrammar } from '../../grammar' import { peggyFunctions } from '../../peggy-functions' -import { execSync } from "child_process"; +import { execSync } from 'child_process' -import * as peggy from "peggy"; -import * as fs from "fs"; +import * as peggy from 'peggy' +import * as fs from 'fs' export type testCase = { - comment: string, - program: string, - expectedLines: string[], + comment: string + program: string + expectedLines: string[] } -const debug = false; -const pathToTestDir = "./src/compiler/__tests__/"; +const debug = false +const pathToTestDir = './src/compiler/__tests__/' const parser = peggy.generate(peggyFunctions + javaPegGrammar, { - allowedStartRules: ["CompilationUnit"], -}); -const binaryWriter = new BinaryWriter(); + allowedStartRules: ['CompilationUnit'] +}) +const binaryWriter = new BinaryWriter() export function runTest(program: string, expectedLines: string[]) { - const ast = parser.parse(program); - expect(ast).not.toBeNull(); + const ast = parser.parse(program) + expect(ast).not.toBeNull() if (debug) { - console.log(inspect(ast, false, null, true)); + console.log(inspect(ast, false, null, true)) } - const classFile = compile(ast as AST); - binaryWriter.writeBinary(classFile, pathToTestDir); + const classes = compile(ast as AST) + for (let c of classes) { + binaryWriter.writeBinary(c.classFile, pathToTestDir) + } - const prevDir = process.cwd(); - process.chdir(pathToTestDir); - execSync("java -noverify Main > output.log 2> err.log"); + const prevDir = process.cwd() + process.chdir(pathToTestDir) + execSync('java -noverify Main > output.log 2> err.log') // ignore difference between \r\n and \n - const actualLines = fs.readFileSync("./output.log", 'utf-8').split(/\r?\n/).slice(0, -1); - process.chdir(prevDir); + const actualLines = fs.readFileSync('./output.log', 'utf-8').split(/\r?\n/).slice(0, -1) + process.chdir(prevDir) - expect(actualLines).toStrictEqual(expectedLines); + expect(actualLines).toStrictEqual(expectedLines) } diff --git a/src/compiler/__tests__/index.ts b/src/compiler/__tests__/index.ts index 8f31ef5c..5ea3b2b8 100644 --- a/src/compiler/__tests__/index.ts +++ b/src/compiler/__tests__/index.ts @@ -1,25 +1,35 @@ -import { printlnTest } from "./tests/println.test"; -import { variableDeclarationTest } from "./tests/variableDeclaration.test"; -import { arithmeticExpressionTest } from "./tests/arithmeticExpression.test"; -import { ifElseTest } from "./tests/ifElse.test"; -import { whileTest } from "./tests/while.test"; -import { forTest } from "./tests/for.test"; -import { unaryExpressionTest } from "./tests/unaryExpression.test"; -import { methodInvocationTest } from "./tests/methodInvocation.test"; -import { importTest } from "./tests/import.test"; -import { arrayTest } from "./tests/array.test"; -import { classTest } from "./tests/class.test"; +import { printlnTest } from './tests/println.test' +import { variableDeclarationTest } from './tests/variableDeclaration.test' +import { arithmeticExpressionTest } from './tests/arithmeticExpression.test' +import { ifElseTest } from './tests/ifElse.test' +import { whileTest } from './tests/while.test' +import { forTest } from './tests/for.test' +import { unaryExpressionTest } from './tests/unaryExpression.test' +import { methodInvocationTest } from './tests/methodInvocation.test' +import { importTest } from './tests/import.test' +import { arrayTest } from './tests/array.test' +import { classTest } from './tests/class.test' +import { assignmentExpressionTest } from './tests/assignmentExpression.test' +import { castExpressionTest } from './tests/castExpression.test' +import { switchTest } from './tests/switch.test' +import { methodOverloadingTest } from './tests/methodOverloading.test' +import { methodOverridingTest } from './tests/methodOverriding.test' -describe("compiler tests", () => { - printlnTest(); - variableDeclarationTest(); - arithmeticExpressionTest(); - unaryExpressionTest(); - ifElseTest(); - whileTest(); - forTest(); - methodInvocationTest(); - importTest(); - arrayTest(); - classTest(); -}) \ No newline at end of file +describe('compiler tests', () => { + methodOverridingTest() + methodOverloadingTest() + switchTest() + castExpressionTest() + printlnTest() + variableDeclarationTest() + arithmeticExpressionTest() + unaryExpressionTest() + ifElseTest() + whileTest() + forTest() + methodInvocationTest() + importTest() + arrayTest() + classTest() + assignmentExpressionTest() +}) diff --git a/src/compiler/__tests__/tests/arithmeticExpression.test.ts b/src/compiler/__tests__/tests/arithmeticExpression.test.ts index abe02048..72a38c1d 100644 --- a/src/compiler/__tests__/tests/arithmeticExpression.test.ts +++ b/src/compiler/__tests__/tests/arithmeticExpression.test.ts @@ -78,6 +78,58 @@ const testCases: testCase[] = [ expectedLines: ["-2147483648", "-32769", "-32768", "-129", "-128", "-1", "0", "1", "127", "128", "32767", "32768", "2147483647"], }, + { + comment: "Mixed int and float addition (order swapped)", + program: ` + public class Main { + public static void main(String[] args) { + int a = 5; + float b = 2.5f; + System.out.println(a + b); + } + } + `, + expectedLines: ["7.5"], + }, + { + comment: "Mixed long and double multiplication", + program: ` + public class Main { + public static void main(String[] args) { + double a = 3.5; + long b = 10L; + System.out.println(a * b); + } + } + `, + expectedLines: ["35.0"], + }, + { + comment: "Mixed long and double multiplication (order swapped)", + program: ` + public class Main { + public static void main(String[] args) { + long a = 10L; + double b = 3.5; + System.out.println(a * b); + } + } + `, + expectedLines: ["35.0"], + }, + { + comment: "Mixed int and double division", + program: ` + public class Main { + public static void main(String[] args) { + double a = 2.0; + int b = 5; + System.out.println(a / b); + } + } + `, + expectedLines: ["0.4"], + } ]; export const arithmeticExpressionTest = () => describe("arithmetic expression", () => { diff --git a/src/compiler/__tests__/tests/assignmentExpression.test.ts b/src/compiler/__tests__/tests/assignmentExpression.test.ts new file mode 100644 index 00000000..5950de37 --- /dev/null +++ b/src/compiler/__tests__/tests/assignmentExpression.test.ts @@ -0,0 +1,124 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "int to double assignment", + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + double y = x; + System.out.println(y); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "int to double conversion", + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + double y; + y = x; + System.out.println(y); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "int to double conversion, array", + program: ` + public class Main { + public static void main(String[] args) { + int x = 6; + double[] y = {1.0, 2.0, 3.0, 4.0, 5.0}; + y[1] = x; + System.out.println(y[1]); + } + } + `, + expectedLines: ["6.0"], + }, + { + comment: "int to long", + program: ` + public class Main { + public static void main(String[] args) { + int a = 123; + long b = a; + System.out.println(b); + } + } + `, + expectedLines: ["123"], + }, + { + comment: "int to float", + program: ` + public class Main { + public static void main(String[] args) { + int a = 123; + float b = a; + System.out.println(b); + } + } + `, + expectedLines: ["123.0"], + }, + + // long -> other types + { + comment: "long to float", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + float b = a; + System.out.println(b); + } + } + `, + expectedLines: ["9.223372E18"], + }, + { + comment: "long to double", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + double b = a; + System.out.println(b); + } + } + `, + expectedLines: ["9.223372036854776E18"], + }, + + // float -> other types + { + comment: "float to double", + program: ` + public class Main { + public static void main(String[] args) { + float a = 3.0f; + double b = a; + System.out.println(b); + } + } + `, + expectedLines: ["3.0"], + }, +]; + +export const assignmentExpressionTest = () => describe("assignment expression", () => { + for (let testCase of testCases) { + const { comment: comment, program: program, expectedLines: expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); diff --git a/src/compiler/__tests__/tests/castExpression.test.ts b/src/compiler/__tests__/tests/castExpression.test.ts new file mode 100644 index 00000000..e811ec66 --- /dev/null +++ b/src/compiler/__tests__/tests/castExpression.test.ts @@ -0,0 +1,145 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "Simple primitive casting: int to float", + program: ` + public class Main { + public static void main(String[] args) { + int a = 5; + float b = (float) a; + System.out.println(b); + } + } + `, + expectedLines: ["5.0"], + }, + { + comment: "Simple primitive casting: float to int", + program: ` + public class Main { + public static void main(String[] args) { + float a = 2.9f; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["2"], + }, + { + comment: "Primitive casting: double to long", + program: ` + public class Main { + public static void main(String[] args) { + double a = 123456789.987; + long b = (long) a; + System.out.println(b); + } + } + `, + expectedLines: ["123456789"], + }, + { + comment: "Primitive casting: long to byte", + program: ` + public class Main { + public static void main(String[] args) { + long a = 257; + byte b = (byte) a; + System.out.println(b); + } + } + `, + expectedLines: ["1"], // byte wraps around at 256 + }, + { + comment: "Primitive casting: char to int", + program: ` + public class Main { + public static void main(String[] args) { + char a = 'A'; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["65"], + }, + { + comment: "Primitive casting: int to char", + program: ` + public class Main { + public static void main(String[] args) { + int a = 65; + char b = (char) a; + System.out.println(b); + } + } + `, + expectedLines: ["A"], + }, + { + comment: "Primitive casting: int to char", + program: ` + public class Main { + public static void main(String[] args) { + int a = 66; + char b = (char) a; + System.out.println(b); + } + } + `, + expectedLines: ["B"], + }, + { + comment: "Primitive casting with loss of precision", + program: ` + public class Main { + public static void main(String[] args) { + double a = 123.456; + int b = (int) a; + System.out.println(b); + } + } + `, + expectedLines: ["123"], + }, + { + comment: "Primitive casting: float to short", + program: ` + public class Main { + public static void main(String[] args) { + float a = 32768.0f; + short b = (short) a; + System.out.println(b); + } + } + `, + expectedLines: ["-32768"], // short wraps around + }, + { + comment: "Chained casting: double to int to byte", + program: ` + public class Main { + public static void main(String[] args) { + double a = 258.99; + int b = (int) a; + byte c = (byte) b; + System.out.println(c); + } + } + `, + expectedLines: ["2"], // 258 -> byte wraps around + }, +]; + +export const castExpressionTest = () => describe("cast expression", () => { + for (let testCase of testCases) { + const { comment: comment, program: program, expectedLines: expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); \ No newline at end of file diff --git a/src/compiler/__tests__/tests/methodOverloading.test.ts b/src/compiler/__tests__/tests/methodOverloading.test.ts new file mode 100644 index 00000000..5d7322be --- /dev/null +++ b/src/compiler/__tests__/tests/methodOverloading.test.ts @@ -0,0 +1,192 @@ +import { + runTest, + testCase, +} from "../__utils__/test-utils"; + +const testCases: testCase[] = [ + { + comment: "Basic method overloading", + program: ` + public class Main { + public static void f(int x) { + System.out.println("int: " + x); + } + public static void f(double x) { + System.out.println("double: " + x); + } + public static void main(String[] args) { + f(5); + f(5.5); + } + } + `, + expectedLines: ["int: 5", "double: 5.5"], + }, + { + comment: "Overloaded methods with different parameter counts", + program: ` + public class Main { + public static void f(int x) { + System.out.println("single param: " + x); + } + public static void f(int x, int y) { + System.out.println("two params: " + (x + y)); + } + public static void main(String[] args) { + f(3); + f(3, 4); + } + } + `, + expectedLines: ["single param: 3", "two params: 7"], + }, + { + comment: "Method overloading with different return types", + program: ` + public class Main { + public static int f(int x) { + return x * 2; + } + public static String f(String s) { + return s + "!"; + } + public static void main(String[] args) { + System.out.println(f(4)); + System.out.println(f("Hello")); + } + } + `, + expectedLines: ["8", "Hello!"], + }, + { + comment: "Overloading with implicit type conversion", + program: ` + public class Main { + public static void f(int x) { + System.out.println("int version: " + x); + } + public static void f(long x) { + System.out.println("long version: " + x); + } + public static void main(String[] args) { + f(10); // should call int version + f(10L); // should call long version + } + } + `, + expectedLines: ["int version: 10", "long version: 10"], + }, + { + comment: "Ambiguous method overloading", + program: ` + public class Main { + public static void f(int x, double y) { + System.out.println("int, double"); + } + public static void f(double x, int y) { + System.out.println("double, int"); + } + public static void main(String[] args) { + f(5, 5.0); + f(5.0, 5); + } + } + `, + expectedLines: ["int, double", "double, int"], + }, + { + comment: "Overloading with reference types", + program: ` + public class Main { + public static void f(String s) { + System.out.println("String"); + } + public static void f(Main m) { + System.out.println("Main"); + } + public static void main(String[] args) { + f("Hello"); // should call String version + f(new Main()); // should call Main version + } + } + `, + expectedLines: ["String", "Main"], + }, + { + comment: "Overloaded instance and static methods", + program: ` + public class Main { + public void f() { + System.out.println("Instance method"); + } + public static void f(int x) { + System.out.println("Static method with int: " + x); + } + public static void main(String[] args) { + Main obj = new Main(); + obj.f(); + f(5); + } + } + `, + expectedLines: ["Instance method", "Static method with int: 5"], + }, + { + comment: "Overloaded instance methods", + program: ` + public class Main { + public void f(int x) { + System.out.println("Instance int: " + x); + } + public void f(double x) { + System.out.println("Instance double: " + x); + } + public static void main(String[] args) { + Main obj = new Main(); + obj.f(5); + obj.f(5.5); + } + } + `, + expectedLines: ["Instance int: 5", "Instance double: 5.5"], + }, + { + comment: "Implicit conversion during method invocation", + program: ` + public class Main { + public static void f(double x) { + System.out.println("Converted double: " + x); + } + public static void main(String[] args) { + f(10); // Implicitly converts int to double + } + } + `, + expectedLines: ["Converted double: 10.0"], + }, + { + comment: "Overloading with widening conversion", + program: ` + public class Main { + public static void f(long x) { + System.out.println("long version: " + x); + } + public static void f(double x) { + System.out.println("double version: " + x); + } + public static void main(String[] args) { + f(5); // Should call long version + f(5.0f); // Should call double version + } + } + `, + expectedLines: ["long version: 5", "double version: 5.0"], + } +]; + +export const methodOverloadingTest = () => describe("method overloading", () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase; + it(comment, () => runTest(program, expectedLines)); + } +}); \ No newline at end of file diff --git a/src/compiler/__tests__/tests/methodOverriding.test.ts b/src/compiler/__tests__/tests/methodOverriding.test.ts new file mode 100644 index 00000000..15a89097 --- /dev/null +++ b/src/compiler/__tests__/tests/methodOverriding.test.ts @@ -0,0 +1,330 @@ +import { runTest, testCase } from '../__utils__/test-utils' + +const testCases: testCase[] = [ + { + comment: 'Basic method overriding', + program: ` + class Parent { + public void show() { + System.out.println("Parent show"); + } + } + class Child extends Parent { + public void show() { + System.out.println("Child show"); + } + } + public class Main { + public static void main(String[] args) { + Parent p = new Parent(); + p.show(); // Parent show + Child c = new Child(); + c.show(); // Child show + Parent ref = new Child(); + ref.show(); // Child show (dynamic dispatch) + } + } + `, + expectedLines: ['Parent show', 'Child show', 'Child show'] + }, + { + comment: 'Overriding with different access modifiers', + program: ` + class Parent { + protected void display() { + System.out.println("Parent display"); + } + } + class Child extends Parent { + public void display() { // Increased visibility + System.out.println("Child display"); + } + } + public class Main { + public static void main(String[] args) { + Parent ref = new Child(); + ref.display(); // Child display + } + } + `, + expectedLines: ['Child display'] + }, + { + comment: 'Overriding with multiple levels of inheritance', + program: ` + class GrandParent { + public void greet() { + System.out.println("Hello from GrandParent"); + } + } + class Parent extends GrandParent { + public void greet() { + System.out.println("Hello from Parent"); + } + } + class Child extends Parent { + public void greet() { + System.out.println("Hello from Child"); + } + } + public class Main { + public static void main(String[] args) { + GrandParent ref1 = new GrandParent(); + ref1.greet(); // GrandParent + GrandParent ref2 = new Parent(); + ref2.greet(); // Parent + GrandParent ref3 = new Child(); + ref3.greet(); // Child + } + } + `, + expectedLines: ['Hello from GrandParent', 'Hello from Parent', 'Hello from Child'] + }, + { + comment: 'Overriding and method hiding with static methods', + program: ` + class Parent { + public static void staticMethod() { + System.out.println("Parent static method"); + } + public void instanceMethod() { + System.out.println("Parent instance method"); + } + } + class Child extends Parent { + public static void staticMethod() { + System.out.println("Child static method"); + } + public void instanceMethod() { + System.out.println("Child instance method"); + } + } + public class Main { + public static void main(String[] args) { + Parent.staticMethod(); // Parent static method + Child.staticMethod(); // Child static method + Parent ref = new Child(); + ref.instanceMethod(); // Child instance method + } + } + `, + expectedLines: ['Parent static method', 'Child static method', 'Child instance method'] + }, + { + comment: 'Overriding final methods (should cause compilation error)', + program: ` + class Parent { + public final void show() { + System.out.println("Final method in Parent"); + } + } + class Child extends Parent { + // public void show() {} // Uncommenting should cause compilation error + } + public class Main { + public static void main(String[] args) { + Parent p = new Parent(); + p.show(); // Final method in Parent + } + } + `, + expectedLines: ['Final method in Parent'] + }, + { + comment: 'Overriding in a deep class hierarchy', + program: ` + class A { + public void test() { + System.out.println("A test"); + } + } + class B extends A { + public void test() { + System.out.println("B test"); + } + } + class C extends B { + public void test() { + System.out.println("C test"); + } + } + class D extends C { + public void test() { + System.out.println("D test"); + } + } + public class Main { + public static void main(String[] args) { + A ref1 = new D(); + B ref2 = new C(); + ref1.test(); // D test + ref2.test(); // C test + } + } + `, + expectedLines: ['D test', 'C test'] + }, + { + comment: 'Overriding private methods (should not override, treated as new method)', + program: ` + class Parent { + private void secret() { + System.out.println("Parent secret"); + } + } + class Child extends Parent { + public void secret() { + System.out.println("Child secret"); + } + } + public class Main { + public static void main(String[] args) { + Child c = new Child(); + c.secret(); // Child secret + } + } + `, + expectedLines: ['Child secret'] + }, + { + comment: 'Using this to call an instance method', + program: ` + class Self { + public void print() { + System.out.println("Self print"); + } + public void callSelf() { + this.print(); + } + } + public class Main { + public static void main(String[] args) { + Self s = new Self(); + s.callSelf(); // Self print + } + } + `, + expectedLines: ['Self print'] + }, + { + comment: 'Using super to invoke parent method', + program: ` + class Base { + public void greet() { + System.out.println("Hello from Base"); + } + } + class Derived extends Base { + public void greet() { + super.greet(); + System.out.println("Hello from Derived"); + } + } + public class Main { + public static void main(String[] args) { + Derived d = new Derived(); + d.greet(); + // Expected: + // Hello from Base + // Hello from Derived + } + } + `, + expectedLines: ['Hello from Base', 'Hello from Derived'] + }, + { + comment: 'Polymorphic call with dynamic dispatch', + program: ` + class Animal { + public void speak() { + System.out.println("Animal sound"); + } + } + class Dog extends Animal { + public void speak() { + System.out.println("Bark"); + } + public void callSuper() { + super.speak(); + } + } + public class Main { + public static void main(String[] args) { + Dog d = new Dog(); + d.speak(); // Bark + d.callSuper(); // Animal sound + } + } + `, + expectedLines: ['Bark', 'Animal sound'] + }, + { + comment: 'Method overloading resolution', + program: ` + class Overload { + public void test(int a) { + System.out.println("int"); + } + public void test(double a) { + System.out.println("double"); + } + } + public class Main { + public static void main(String[] args) { + Overload o = new Overload(); + o.test(5); // int + o.test(5.0); // double + } + } + `, + expectedLines: ['int', 'double'] + }, + { + comment: 'Overriding on a superclass reference', + program: ` + class X { + public void foo() { + System.out.println("X foo"); + } + } + class Y extends X { + public void foo() { + System.out.println("Y foo"); + } + } + public class Main { + public static void main(String[] args) { + X x = new Y(); + x.foo(); // Y foo + } + } + `, + expectedLines: ['Y foo'] + }, + { + comment: 'Implicit conversion (byte to int)', + program: ` + class Implicit { + public void process(int a) { + System.out.println("Processed int"); + } + } + public class Main { + public static void main(String[] args) { + Implicit imp = new Implicit(); + byte b = (byte) 10; + imp.process(b); // Processed int + } + } + `, + expectedLines: ['Processed int'] + } +] + +export const methodOverridingTest = () => + describe('method overriding', () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase + it(comment, () => runTest(program, expectedLines)) + } + }) diff --git a/src/compiler/__tests__/tests/println.test.ts b/src/compiler/__tests__/tests/println.test.ts index 4e2e4635..9b5ebbea 100644 --- a/src/compiler/__tests__/tests/println.test.ts +++ b/src/compiler/__tests__/tests/println.test.ts @@ -98,6 +98,22 @@ const testCases: testCase[] = [ `, expectedLines: ["true", "false"], }, + { + comment: "println with concatenated arguments", + program: ` + public class Main { + public static void main(String[] args) { + System.out.println("Hello" + " " + "world!"); + System.out.println("This is an int: " + 123); + System.out.println("This is a float: " + 4.5f); + System.out.println("This is a long: " + 10000000000L); + System.out.println("This is a double: " + 10.3); + } + } + `, + expectedLines: ["Hello world!", "This is an int: 123", "This is a float: 4.5", + "This is a long: 10000000000", "This is a double: 10.3"], + }, { comment: "multiple println statements", program: ` diff --git a/src/compiler/__tests__/tests/switch.test.ts b/src/compiler/__tests__/tests/switch.test.ts new file mode 100644 index 00000000..003e9ff4 --- /dev/null +++ b/src/compiler/__tests__/tests/switch.test.ts @@ -0,0 +1,203 @@ +import { runTest, testCase } from '../__utils__/test-utils' + +const testCases: testCase[] = [ + { + comment: 'More basic switch case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 1; + switch (x) { + case 1: + System.out.println("One"); + break; + } + } + } + `, + expectedLines: ['One'] + }, + { + comment: 'Basic switch case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 2; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + case 3: + System.out.println("Three"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Two'] + }, + { + comment: 'Switch with default case', + program: ` + public class Main { + public static void main(String[] args) { + int x = 5; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Default'] + }, + { + comment: 'Switch fallthrough behavior', + program: ` + public class Main { + public static void main(String[] args) { + int x = 2; + switch (x) { + case 1: + System.out.println("One"); + case 2: + System.out.println("Two"); + case 3: + System.out.println("Three"); + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Two', 'Three', 'Default'] + }, + { + comment: 'Switch with break statements', + program: ` + public class Main { + public static void main(String[] args) { + int x = 3; + switch (x) { + case 1: + System.out.println("One"); + break; + case 2: + System.out.println("Two"); + break; + case 3: + System.out.println("Three"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Three'] + }, + { + comment: 'Switch with strings', + program: ` + public class Main { + public static void main(String[] args) { + String day = "Tuesday"; + switch (day) { + case "Monday": + System.out.println("Start of the week"); + break; + case "Tuesday": + System.out.println("Second day"); + break; + case "Friday": + System.out.println("Almost weekend"); + break; + default: + System.out.println("Midweek or weekend"); + } + } + } + `, + expectedLines: ['Second day'] + }, + { + comment: 'Nested switch statements', + program: ` + public class Main { + public static void main(String[] args) { + int outer = 2; + int inner = 1; + switch (outer) { + case 1: + switch (inner) { + case 1: + System.out.println("Inner One"); + break; + case 2: + System.out.println("Inner Two"); + break; + } + break; + case 2: + switch (inner) { + case 1: + System.out.println("Outer Two, Inner One"); + break; + case 2: + System.out.println("Outer Two, Inner Two"); + break; + } + break; + default: + System.out.println("Default case"); + } + } + } + `, + expectedLines: ['Outer Two, Inner One'] + }, + + { + comment: 'Switch with far apart cases', + program: ` + public class Main { + public static void main(String[] args) { + int x = 1331; + switch (x) { + case 1: + System.out.println("No"); + break; + case 1331: + System.out.println("Yes"); + break; + case 999999999: + System.out.println("No"); + break; + default: + System.out.println("Default"); + } + } + } + `, + expectedLines: ['Yes'] + } +] + +export const switchTest = () => + describe('Switch statements', () => { + for (let testCase of testCases) { + const { comment, program, expectedLines } = testCase + it(comment, () => runTest(program, expectedLines)) + } + }) diff --git a/src/compiler/__tests__/tests/unaryExpression.test.ts b/src/compiler/__tests__/tests/unaryExpression.test.ts index 175c6a2d..0e9aa469 100644 --- a/src/compiler/__tests__/tests/unaryExpression.test.ts +++ b/src/compiler/__tests__/tests/unaryExpression.test.ts @@ -159,6 +159,45 @@ const testCases: testCase[] = [ }`, expectedLines: ["10", "10", "-10", "-10", "-10", "-10", "10", "9", "-10"], }, + { + comment: "unary plus/minus for long", + program: ` + public class Main { + public static void main(String[] args) { + long a = 9223372036854775807L; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["9223372036854775807", "-9223372036854775807"], + }, + { + comment: "unary plus/minus for float", + program: ` + public class Main { + public static void main(String[] args) { + float a = 4.5f; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["4.5", "-4.5"], + }, + { + comment: "unary plus/minus for double", + program: ` + public class Main { + public static void main(String[] args) { + double a = 10.75; + System.out.println(+a); + System.out.println(-a); + } + } + `, + expectedLines: ["10.75", "-10.75"], + }, { comment: "bitwise complement", program: ` diff --git a/src/compiler/code-generator.ts b/src/compiler/code-generator.ts index 1b0d8c86..3bea9dbc 100644 --- a/src/compiler/code-generator.ts +++ b/src/compiler/code-generator.ts @@ -24,11 +24,19 @@ import { ClassInstanceCreationExpression, ExpressionStatement, TernaryExpression, - LeftHandSide + LeftHandSide, + CastExpression, + SwitchStatement, + SwitchCase, + CaseLabel } from '../ast/types/blocks-and-statements' import { MethodDeclaration, UnannType } from '../ast/types/classes' import { ConstantPoolManager } from './constant-pool-manager' -import { ConstructNotSupportedError, InvalidMethodCallError } from './error' +import { + AmbiguousMethodCallError, + ConstructNotSupportedError, + NoMethodMatchingSignatureError +} from './error' import { FieldInfo, MethodInfos, SymbolInfo, SymbolTable, VariableInfo } from './symbol-table' type Label = { @@ -164,12 +172,203 @@ const normalStoreOp: { [type: string]: OPCODE } = { Z: OPCODE.ISTORE } +const typeConversions: { [key: string]: OPCODE } = { + 'I->F': OPCODE.I2F, + 'I->D': OPCODE.I2D, + 'I->J': OPCODE.I2L, + 'I->B': OPCODE.I2B, + 'I->C': OPCODE.I2C, + 'I->S': OPCODE.I2S, + 'F->D': OPCODE.F2D, + 'F->I': OPCODE.F2I, + 'F->J': OPCODE.F2L, + 'D->F': OPCODE.D2F, + 'D->I': OPCODE.D2I, + 'D->J': OPCODE.D2L, + 'J->I': OPCODE.L2I, + 'J->F': OPCODE.L2F, + 'J->D': OPCODE.L2D +} + +const typeConversionsImplicit: { [key: string]: OPCODE } = { + 'I->F': OPCODE.I2F, + 'I->D': OPCODE.I2D, + 'I->J': OPCODE.I2L, + 'F->D': OPCODE.F2D, + 'J->F': OPCODE.L2F, + 'J->D': OPCODE.L2D +} + type CompileResult = { stackSize: number resultType: string } const EMPTY_TYPE: string = '' +function areClassTypesCompatible(fromType: string, toType: string, cg: CodeGenerator): boolean { + const cleanFrom = fromType.replace(/^L|;$/g, '') + const cleanTo = toType.replace(/^L|;$/g, '') + if (cleanFrom === cleanTo) return true; + + try { + let current = cg.symbolTable.queryClass(cleanFrom); + while (current.parentClassName) { + const parentClean = current.parentClassName; + if (parentClean === cleanTo) return true; + current = cg.symbolTable.queryClass(parentClean); + } + } catch (e) { + return false; + } + return false; +} + +function handleImplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator): number { + if (fromType === toType || toType.replace(/^L|;$/g, '') === 'java/lang/String') { + return 0 + } + + if (fromType.startsWith('L') || toType.startsWith('L')) { + if (areClassTypesCompatible(fromType, toType, cg) || fromType === '') { + return 0 + } + throw new Error(`Unsupported class type conversion: ${fromType} -> ${toType}`) + } + + const conversionKey = `${fromType}->${toType}` + if (conversionKey in typeConversionsImplicit) { + cg.code.push(typeConversionsImplicit[conversionKey]) + if (!(fromType in ['J', 'D']) && toType in ['J', 'D']) { + return 1 + } else if (!(toType in ['J', 'D']) && fromType in ['J', 'D']) { + return -1 + } else { + return 0 + } + } else { + throw new Error(`Unsupported implicit type conversion: ${conversionKey}`) + } +} + +function handleExplicitTypeConversion(fromType: string, toType: string, cg: CodeGenerator) { + if (fromType === toType) { + return + } + const conversionKey = `${fromType}->${toType}` + if (conversionKey in typeConversions) { + cg.code.push(typeConversions[conversionKey]) + } else { + throw new Error(`Unsupported explicit type conversion: ${conversionKey}`) + } +} + +function generateStringConversion(valueType: string, cg: CodeGenerator): void { + const stringClass = 'java/lang/String' + + // Map primitive types to `String.valueOf()` method descriptors + const valueOfDescriptors: { [key: string]: string } = { + I: '(I)Ljava/lang/String;', // int + J: '(J)Ljava/lang/String;', // long + F: '(F)Ljava/lang/String;', // float + D: '(D)Ljava/lang/String;', // double + Z: '(Z)Ljava/lang/String;', // boolean + B: '(B)Ljava/lang/String;', // byte + S: '(S)Ljava/lang/String;', // short + C: '(C)Ljava/lang/String;' // char + } + + const descriptor = valueOfDescriptors[valueType] + if (!descriptor) { + throw new Error(`Unsupported primitive type for String conversion: ${valueType}`) + } + + const methodIndex = cg.constantPoolManager.indexMethodrefInfo(stringClass, 'valueOf', descriptor) + + cg.code.push(OPCODE.INVOKESTATIC, 0, methodIndex) +} + +function hashCode(str: string): number { + let hash = 0 + for (let i = 0; i < str.length; i++) { + hash = hash * 31 + str.charCodeAt(i) // Simulate Java's overflow behavior + } + return hash +} + +// function generateBooleanConversion(type: string, cg: CodeGenerator): number { +// let stackChange = 0; // Tracks changes to the stack size +// +// switch (type) { +// case 'I': // int +// case 'B': // byte +// case 'S': // short +// case 'C': // char +// // For integer-like types, compare with zero +// cg.code.push(OPCODE.ICONST_0); // Push 0 +// stackChange += 1; // `ICONST_0` pushes a value onto the stack +// cg.code.push(OPCODE.IF_ICMPNE); // Compare and branch +// stackChange -= 2; // `IF_ICMPNE` consumes two values from the stack +// break; +// +// case 'J': // long +// // For long, compare with zero +// cg.code.push(OPCODE.LCONST_0); // Push 0L +// stackChange += 2; // `LCONST_0` pushes two values onto the stack (long takes 2 slots) +// cg.code.push(OPCODE.LCMP); // Compare top two longs +// stackChange -= 4; // `LCMP` consumes four values (two long operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'F': // float +// // For float, compare with zero +// cg.code.push(OPCODE.FCONST_0); // Push 0.0f +// stackChange += 1; // `FCONST_0` pushes a value onto the stack +// cg.code.push(OPCODE.FCMPL); // Compare top two floats +// stackChange -= 2; // `FCMPL` consumes two values (float operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'D': // double +// // For double, compare with zero +// cg.code.push(OPCODE.DCONST_0); // Push 0.0d +// stackChange += 2; // `DCONST_0` pushes two values onto the stack (double takes 2 slots) +// cg.code.push(OPCODE.DCMPL); // Compare top two doubles +// stackChange -= 4; // `DCMPL` consumes four values (two double operands) and pushes one result +// cg.code.push(OPCODE.IFNE); // If not equal, branch +// stackChange -= 1; // `IFNE` consumes one value (the comparison result) +// break; +// +// case 'Z': // boolean +// // Already a boolean, no conversion needed +// break; +// +// default: +// throw new Error(`Cannot convert type ${type} to boolean.`); +// } +// +// return stackChange; // Return the net change in stack size +// } + +function getExpressionType(node: Node, cg: CodeGenerator): string { + if (!(node.kind in codeGenerators)) { + throw new ConstructNotSupportedError(node.kind) + } + const originalCode = [...cg.code] // Preserve the original code state + const resultType = codeGenerators[node.kind](node, cg).resultType + cg.code = originalCode // Restore the original code state + return resultType +} + +function isSubtype(fromType: string, toType: string, cg: CodeGenerator): boolean { + return ( + fromType === toType || + typeConversionsImplicit[`${fromType}->${toType}`] !== undefined || + areClassTypesCompatible(fromType, toType, cg) + ) +} + const isNullLiteral = (node: Node) => { return node.kind === 'Literal' && node.literalType.kind === 'NullLiteral' } @@ -245,13 +444,20 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi vi.forEach((val, i) => { cg.code.push(OPCODE.DUP) const size1 = compile(createIntLiteralNode(i), cg).stackSize - const size2 = compile(val as Expression, cg).stackSize + const { stackSize: size2, resultType } = compile(val as Expression, cg) + const stackSizeChange = handleImplicitTypeConversion(resultType, arrayElemType, cg) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) - maxStack = Math.max(maxStack, 2 + size1 + size2) + maxStack = Math.max(maxStack, 2 + size1 + size2 + stackSizeChange) }) cg.code.push(OPCODE.ASTORE, curIdx) } else { - maxStack = Math.max(maxStack, compile(vi, cg).stackSize) + const { stackSize: initializerStackSize, resultType: initializerType } = compile(vi, cg) + const stackSizeChange = handleImplicitTypeConversion( + initializerType, + variableInfo.typeDescriptor, + cg + ) + maxStack = Math.max(maxStack, initializerStackSize + stackSizeChange) cg.code.push( variableInfo.typeDescriptor in normalStoreOp ? normalStoreOp[variableInfo.typeDescriptor] @@ -276,7 +482,15 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }, BreakStatement: (node: Node, cg: CodeGenerator) => { - cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]) + if (cg.loopLabels.length > 0) { + // If inside a loop, break jumps to the end of the loop + cg.addBranchInstr(OPCODE.GOTO, cg.loopLabels[cg.loopLabels.length - 1][1]) + } else if (cg.switchLabels.length > 0) { + // If inside a switch, break jumps to the switch's end label + cg.addBranchInstr(OPCODE.GOTO, cg.switchLabels[cg.switchLabels.length - 1]) + } else { + throw new Error('Break statement not inside a loop or switch statement') + } return { stackSize: 0, resultType: EMPTY_TYPE } }, @@ -429,6 +643,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.addBranchInstr(OPCODE.GOTO, targetLabel) } return { stackSize: 0, resultType: cg.symbolTable.generateFieldDescriptor('boolean') } + } else { + if (onTrue === (parseInt(value) !== 0)) { + cg.addBranchInstr(OPCODE.GOTO, targetLabel) + } + return { stackSize: 0, resultType: cg.symbolTable.generateFieldDescriptor('boolean') } } } @@ -530,13 +749,12 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi }) const argDescriptor = '(' + argTypes.join('') + ')' - const symbolInfos = cg.symbolTable.queryMethod('') - const methodInfos = symbolInfos[symbolInfos.length - 1] as MethodInfos + const methodInfos = cg.symbolTable.queryMethod('') as MethodInfos for (let i = 0; i < methodInfos.length; i++) { const methodInfo = methodInfos[i] - if (methodInfo.typeDescriptor.includes(argDescriptor)) { + if (methodInfo.typeDescriptor.includes(argDescriptor) && methodInfo.className == id) { const method = cg.constantPoolManager.indexMethodrefInfo( - methodInfo.parentClassName, + methodInfo.className, methodInfo.name, methodInfo.typeDescriptor ) @@ -570,70 +788,118 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const n = node as MethodInvocation let maxStack = 1 let resultType = EMPTY_TYPE + let candidateMethods: MethodInfos = [] + let unqualifiedCall = false + + // --- Handle super. calls --- + if (n.identifier.startsWith('super.')) { + candidateMethods = cg.symbolTable.queryMethod(n.identifier.slice(6)) as MethodInfos + candidateMethods.filter(method => + method.className == cg.symbolTable.queryClass(cg.currentClass).parentClassName) + cg.code.push(OPCODE.ALOAD, 0); + } + // --- Handle qualified calls (e.g. System.out.println or p.show) --- + else if (n.identifier.includes('.')) { + const lastDot = n.identifier.lastIndexOf('.'); + const receiverStr = n.identifier.slice(0, lastDot); + + if (receiverStr === 'this') { + candidateMethods = cg.symbolTable.queryMethod(n.identifier.slice(5)) as MethodInfos + console.debug(candidateMethods) + candidateMethods.filter(method => + method.className == cg.currentClass) + cg.code.push(OPCODE.ALOAD, 0); + } else { + const recvRes = compile({ kind: 'ExpressionName', name: receiverStr }, cg); + maxStack = Math.max(maxStack, recvRes.stackSize); + candidateMethods = cg.symbolTable.queryMethod(n.identifier).pop() as MethodInfos + } + } + // --- Handle unqualified calls --- + else { + candidateMethods = cg.symbolTable.queryMethod(n.identifier) as MethodInfos + unqualifiedCall = true; + } - const symbolInfos = cg.symbolTable.queryMethod(n.identifier) - for (let i = 0; i < symbolInfos.length - 1; i++) { - if (i === 0) { - const varInfo = symbolInfos[i] as VariableInfo - if (varInfo.index !== undefined) { - cg.code.push(OPCODE.ALOAD, varInfo.index) - continue + // Filter candidate methods by matching the argument list. + const argDescs = n.argumentList.map(arg => getExpressionType(arg, cg)) + const methodMatches: MethodInfos = [] + + for (let i = 0; i < candidateMethods.length; i++) { + const m = candidateMethods[i] + const fullDesc = m.typeDescriptor // e.g., "(Ljava/lang/String;C)V" + const paramPart = fullDesc.slice(1, fullDesc.indexOf(')')) + const params = paramPart.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + if (params.length !== argDescs.length) continue + let match = true + for (let i = 0; i < params.length; i++) { + const argType = argDescs[i] + // Allow B/S to match int. + if ((argType === 'B' || argType === 'S') && params[i] === 'I') continue + if (!isSubtype(argType, params[i], cg)) { + match = false + break } } - const fieldInfo = symbolInfos[i] as FieldInfo - const field = cg.constantPoolManager.indexFieldrefInfo( - fieldInfo.parentClassName, - fieldInfo.name, - fieldInfo.typeDescriptor - ) - cg.code.push( - fieldInfo.accessFlags & FIELD_FLAGS.ACC_STATIC ? OPCODE.GETSTATIC : OPCODE.GETFIELD, - 0, - field - ) + if (match) methodMatches.push(m) + } + if (methodMatches.length === 0) { + throw new NoMethodMatchingSignatureError(n.identifier + argDescs.join(',')) } - const argTypes: Array = [] - n.argumentList.forEach((x, i) => { - const argCompileResult = compile(x, cg) - maxStack = Math.max(maxStack, i + 1 + argCompileResult.stackSize) - argTypes.push(argCompileResult.resultType) - }) - const argDescriptor = '(' + argTypes.join('') + ')' - - let foundMethod = false - const methodInfos = symbolInfos[symbolInfos.length - 1] as MethodInfos - for (let i = 0; i < methodInfos.length; i++) { - const methodInfo = methodInfos[i] - if (methodInfo.typeDescriptor.includes(argDescriptor)) { - const method = cg.constantPoolManager.indexMethodrefInfo( - methodInfo.parentClassName, - methodInfo.name, - methodInfo.typeDescriptor - ) + // Overload resolution (simple: choose first, or refine if needed) + let selectedMethod = methodMatches[0] + if (methodMatches.length > 1) { + for (let i = 1; i < methodMatches.length; i++) { + const currParams = + selectedMethod.typeDescriptor + .slice(1, selectedMethod.typeDescriptor.indexOf(')')) + .match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + const candParams = + methodMatches[i].typeDescriptor + .slice(1, methodMatches[i].typeDescriptor.indexOf(')')) + .match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] if ( - n.identifier.startsWith('this.') && - !(methodInfo.accessFlags & FIELD_FLAGS.ACC_STATIC) + candParams.map((p, idx) => isSubtype(p, currParams[idx], cg)).reduce((a, b) => a && b, true) ) { - // load "this" - cg.code.push(OPCODE.ALOAD, 0) + selectedMethod = methodMatches[i] + } else if ( + !currParams.map((p, idx) => isSubtype(p, candParams[idx], cg)).reduce((a, b) => a && b, true) + ) { + throw new AmbiguousMethodCallError(n.identifier + argDescs.join(',')) } - cg.code.push( - methodInfo.accessFlags & METHOD_FLAGS.ACC_STATIC - ? OPCODE.INVOKESTATIC - : OPCODE.INVOKEVIRTUAL, - 0, - method - ) - resultType = methodInfo.typeDescriptor.slice(argDescriptor.length) - foundMethod = true - break } } - if (!foundMethod) { - throw new InvalidMethodCallError(n.identifier) + if (unqualifiedCall && !(selectedMethod.accessFlags & FIELD_FLAGS.ACC_STATIC)) { + cg.code.push(OPCODE.ALOAD, 0) } + + // Compile each argument. + const fullDescriptor = selectedMethod.typeDescriptor + const paramPart = fullDescriptor.slice(1, fullDescriptor.indexOf(')')) + const params = paramPart.match(/(\[+[BCDFIJSZ])|(\[+L[^;]+;)|[BCDFIJSZ]|L[^;]+;/g) || [] + n.argumentList.forEach((arg, i) => { + const argRes = compile(arg, cg) + let argType = argRes.resultType + if (argType === 'B' || argType === 'S') argType = 'I' + const conv = handleImplicitTypeConversion(argType, params[i] || '', cg) + maxStack = Math.max(maxStack, i + 1 + argRes.stackSize + conv) + }) + + // Emit the method call. + const methodRef = cg.constantPoolManager.indexMethodrefInfo( + selectedMethod.className, + selectedMethod.name, + selectedMethod.typeDescriptor + ) + if (n.identifier.startsWith('super.')) { + cg.code.push(OPCODE.INVOKESPECIAL, 0, methodRef) + } else { + const isStatic = (selectedMethod.accessFlags & METHOD_FLAGS.ACC_STATIC) !== 0 + cg.code.push(isStatic ? OPCODE.INVOKESTATIC : OPCODE.INVOKEVIRTUAL, 0, methodRef) + } + resultType = selectedMethod.typeDescriptor.slice(selectedMethod.typeDescriptor.indexOf(')') + 1) return { stackSize: maxStack, resultType: resultType } }, @@ -662,15 +928,20 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi if (lhs.kind === 'ArrayAccess') { const { stackSize: size1, resultType: arrayType } = compile(lhs.primary, cg) const size2 = compile(lhs.expression, cg).stackSize - maxStack = size1 + size2 + compile(right, cg).stackSize + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + const arrayElemType = arrayType.slice(1) + const stackSizeChange = handleImplicitTypeConversion(rhsType, arrayElemType, cg) + maxStack = Math.max(maxStack, size1 + size2 + rhsSize + stackSizeChange) cg.code.push(arrayElemType in arrayStoreOp ? arrayStoreOp[arrayElemType] : OPCODE.AASTORE) } else if ( lhs.kind === 'ExpressionName' && !Array.isArray(cg.symbolTable.queryVariable(lhs.name)) ) { const info = cg.symbolTable.queryVariable(lhs.name) as VariableInfo - maxStack = 1 + compile(right, cg).stackSize + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + const stackSizeChange = handleImplicitTypeConversion(rhsType, info.typeDescriptor, cg) + maxStack = Math.max(maxStack, 1 + rhsSize + stackSizeChange) cg.code.push( info.typeDescriptor in normalStoreOp ? normalStoreOp[info.typeDescriptor] : OPCODE.ASTORE, info.index @@ -693,7 +964,11 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi cg.code.push(OPCODE.ALOAD, 0) maxStack += 1 } - maxStack += compile(right, cg).stackSize + + const { stackSize: rhsSize, resultType: rhsType } = compile(right, cg) + const stackSizeChange = handleImplicitTypeConversion(rhsType, fieldInfo.typeDescriptor, cg) + + maxStack = Math.max(maxStack, maxStack + rhsSize + stackSizeChange) cg.code.push( fieldInfo.accessFlags & FIELD_FLAGS.ACC_STATIC ? OPCODE.PUTSTATIC : OPCODE.PUTFIELD, 0, @@ -737,10 +1012,94 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } } - const { stackSize: size1, resultType: type } = compile(left, cg) - const { stackSize: size2 } = compile(right, cg) + const { stackSize: size1, resultType: leftType } = compile(left, cg) + const insertConversionIndex = cg.code.length + cg.code.push(OPCODE.NOP) + const { stackSize: size2, resultType: rightType } = compile(right, cg) + + if (op === '+' && (leftType === 'Ljava/lang/String;' || rightType === 'Ljava/lang/String;')) { + if (leftType !== 'Ljava/lang/String;') { + generateStringConversion(leftType, cg) + } + + if (rightType !== 'Ljava/lang/String;') { + generateStringConversion(rightType, cg) + } + + // Invoke `String.concat` for concatenation + const concatMethodIndex = cg.constantPoolManager.indexMethodrefInfo( + 'java/lang/String', + 'concat', + '(Ljava/lang/String;)Ljava/lang/String;' + ) + cg.code.push(OPCODE.INVOKEVIRTUAL, 0, concatMethodIndex) - switch (type) { + return { + stackSize: Math.max(size1 + 1, size2 + 1), // Max stack size plus one for the concatenation + resultType: 'Ljava/lang/String;' + } + } + + let finalType = leftType + + if (leftType !== rightType) { + const conversionKeyLeft = `${leftType}->${rightType}` + const conversionKeyRight = `${rightType}->${leftType}` + + if (['D', 'F'].includes(leftType) || ['D', 'F'].includes(rightType)) { + // Promote both to double if one is double, or to float otherwise + if (leftType !== 'D' && rightType === 'D') { + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) + finalType = 'D' + } else if (leftType === 'D' && rightType !== 'D') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + finalType = 'D' + } else if (leftType !== 'F' && rightType === 'F') { + // handleImplicitTypeConversion(leftType, 'F', cg); + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) + finalType = 'F' + } else if (leftType === 'F' && rightType !== 'F') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + finalType = 'F' + } + } else if (['J'].includes(leftType) || ['J'].includes(rightType)) { + // Promote both to long if one is long + if (leftType !== 'J' && rightType === 'J') { + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) + } else if (leftType === 'J' && rightType !== 'J') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + } + finalType = 'J' + } else { + // Promote both to int as the common type for smaller types like byte, short, char + if (leftType !== 'I') { + cg.code.fill( + typeConversionsImplicit[conversionKeyLeft], + insertConversionIndex, + insertConversionIndex + 1 + ) + } + if (rightType !== 'I') { + cg.code.push(typeConversionsImplicit[conversionKeyRight]) + } + finalType = 'I' + } + } + + // Perform the operation + switch (finalType) { case 'B': cg.code.push(intBinaryOp[op], OPCODE.I2B) break @@ -762,8 +1121,8 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } return { - stackSize: Math.max(size1, 1 + (['D', 'J'].includes(type) ? 1 : 0) + size2), - resultType: type + stackSize: Math.max(size1, 1 + (['D', 'J'].includes(finalType) ? 1 : 0) + size2), + resultType: finalType } }, @@ -799,7 +1158,18 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi const compileResult = compile(expr, cg) if (op === '-') { - cg.code.push(OPCODE.INEG) + const negationOpcodes: { [type: string]: OPCODE } = { + I: OPCODE.INEG, // Integer negation + J: OPCODE.LNEG, // Long negation + F: OPCODE.FNEG, // Float negation + D: OPCODE.DNEG // Double negation + } + + if (compileResult.resultType in negationOpcodes) { + cg.code.push(negationOpcodes[compileResult.resultType]) + } else { + throw new Error(`Unary '-' not supported for type: ${compileResult.resultType}`) + } } else if (op === '~') { cg.code.push(OPCODE.ICONST_M1, OPCODE.IXOR) compileResult.stackSize = Math.max(compileResult.stackSize, 2) @@ -849,7 +1219,12 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } } - const info = cg.symbolTable.queryVariable(name) + let info: VariableInfo | SymbolInfo[] + try { + info = cg.symbolTable.queryVariable(name) + } catch (e) { + return { stackSize: 1, resultType: 'Ljava/lang/Class;' }; + } if (Array.isArray(info)) { const fieldInfos = info for (let i = 0; i < fieldInfos.length; i++) { @@ -948,6 +1323,350 @@ const codeGenerators: { [type: string]: (node: Node, cg: CodeGenerator) => Compi } return { stackSize: 1, resultType: EMPTY_TYPE } + }, + + CastExpression: (node: Node, cg: CodeGenerator) => { + const { expression, type } = node as CastExpression // CastExpression node structure + const { stackSize, resultType } = compile(expression, cg) + + if ((type == 'byte' || type == 'short') && resultType != 'I') { + handleExplicitTypeConversion(resultType, 'I', cg) + handleExplicitTypeConversion('I', cg.symbolTable.generateFieldDescriptor(type), cg) + } else if (resultType == 'C') { + if (type == 'int') { + return { + stackSize, + resultType: cg.symbolTable.generateFieldDescriptor('int') + } + } else { + throw new Error(`Unsupported class type conversion: + ${'C'} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) + } + } else if (type == 'char') { + if (resultType == 'I') { + handleExplicitTypeConversion('I', 'C', cg) + } else { + throw new Error(`Unsupported class type conversion: + ${resultType} -> ${cg.symbolTable.generateFieldDescriptor(type)}`) + } + } else { + handleExplicitTypeConversion(resultType, cg.symbolTable.generateFieldDescriptor(type), cg) + } + + return { + stackSize, + resultType: cg.symbolTable.generateFieldDescriptor(type) + } + }, + + SwitchStatement: (node: Node, cg: CodeGenerator) => { + const { expression, cases } = node as SwitchStatement + + // Compile the switch expression + const { stackSize: exprStackSize, resultType } = compile(expression, cg) + let maxStack = exprStackSize + + const caseLabels: Label[] = cases.map(() => cg.generateNewLabel()) + const defaultLabel = cg.generateNewLabel() + const endLabel = cg.generateNewLabel() + + // Track the switch statement's end label + cg.switchLabels.push(endLabel) + + if (['I', 'B', 'S', 'C'].includes(resultType)) { + const caseValues: number[] = [] + const caseLabelMap: Map = new Map() + let hasDefault = false + const positionOffset = cg.code.length + + cases.forEach((caseGroup, index) => { + caseGroup.labels.forEach(label => { + if (label.kind === 'CaseLabel') { + const value = parseInt((label.expression as Literal).literalType.value) + caseValues.push(value) + caseLabelMap.set(value, caseLabels[index]) + } else if (label.kind === 'DefaultLabel') { + caseLabels[index] = defaultLabel + hasDefault = true + } + }) + }) + + const [minValue, maxValue] = [Math.min(...caseValues), Math.max(...caseValues)] + const useTableSwitch = maxValue - minValue < caseValues.length * 2 + const caseLabelIndex: number[] = [] + let indexTracker = cg.code.length + + if (useTableSwitch) { + cg.code.push(OPCODE.TABLESWITCH) + indexTracker++ + + // Ensure 4-byte alignment for TABLESWITCH + while (cg.code.length % 4 !== 0) { + cg.code.push(0) // Padding bytes (JVM requires alignment) + indexTracker++ + } + + // Add default branch (jump to default label) + cg.code.push(0, 0, 0, defaultLabel.offset) + caseLabelIndex.push(indexTracker + 3) + indexTracker += 4 + + // Push low and high values (min and max case values) + cg.code.push( + minValue >> 24, + (minValue >> 16) & 0xff, + (minValue >> 8) & 0xff, + minValue & 0xff + ) + cg.code.push( + maxValue >> 24, + (maxValue >> 16) & 0xff, + (maxValue >> 8) & 0xff, + maxValue & 0xff + ) + indexTracker += 8 + + // Generate branch table (map each value to a case label) + for (let i = minValue; i <= maxValue; i++) { + const caseIndex = caseValues.indexOf(i) + cg.code.push( + 0, + 0, + 0, + caseIndex !== -1 ? caseLabels[caseIndex].offset : defaultLabel.offset + ) + caseLabelIndex.push(indexTracker + 3) + indexTracker += 4 + } + } else { + cg.code.push(OPCODE.LOOKUPSWITCH) + indexTracker++ + + // Ensure 4-byte alignment for LOOKUPSWITCH + while (cg.code.length % 4 !== 0) { + cg.code.push(0) + indexTracker++ + } + + // Add default branch (jump to default label) + cg.code.push(0, 0, 0, defaultLabel.offset) + caseLabelIndex.push(indexTracker + 3) + indexTracker += 4 + + // Push the number of case-value pairs + cg.code.push( + (caseValues.length >> 24) & 0xff, + (caseValues.length >> 16) & 0xff, + (caseValues.length >> 8) & 0xff, + caseValues.length & 0xff + ) + indexTracker += 4 + + // Generate lookup table (pairs of case values and corresponding labels) + caseValues.forEach((value, index) => { + cg.code.push(value >> 24, (value >> 16) & 0xff, (value >> 8) & 0xff, value & 0xff) + cg.code.push(0, 0, 0, caseLabels[index].offset) + caseLabelIndex.push(indexTracker + 7) + indexTracker += 8 + }) + } + + // **Process case bodies with proper fallthrough handling** + let previousCase: SwitchCase | null = null + + const nonDefaultCases = cases.filter(caseGroup => + caseGroup.labels.some(label => label.kind === 'CaseLabel') + ) + + nonDefaultCases.forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length + + // Ensure statements array is always defined + caseGroup.statements = caseGroup.statements || [] + + // If previous case had no statements, merge labels (fallthrough) + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels) + } + + // Compile case statements + caseGroup.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) + + previousCase = caseGroup + }) + + // **Process default case** + defaultLabel.offset = cg.code.length + if (hasDefault) { + const defaultCase = cases.find(caseGroup => + caseGroup.labels.some(label => label.kind === 'DefaultLabel') + ) + if (defaultCase) { + defaultCase.statements = defaultCase.statements || [] + defaultCase.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) + } + } + + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset + + for (let i = 1; i < caseLabelIndex.length; i++) { + cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset + } + + endLabel.offset = cg.code.length + } else if (resultType === 'Ljava/lang/String;') { + // **String Switch Handling** + const hashCaseMap: Map = new Map() + + // Compute and store hashCode() + cg.code.push( + OPCODE.INVOKEVIRTUAL, + 0, + cg.constantPoolManager.indexMethodrefInfo('java/lang/String', 'hashCode', '()I') + ) + + // Create lookup table for hashCodes + cases.forEach((caseGroup, index) => { + caseGroup.labels.forEach(label => { + if (label.kind === 'CaseLabel') { + const caseValue = (label.expression as Literal).literalType.value + const hashCodeValue = hashCode(caseValue.slice(1, caseValue.length - 1)) + if (!hashCaseMap.has(hashCodeValue)) { + hashCaseMap.set(hashCodeValue, caseLabels[index]) + } + } else if (label.kind === 'DefaultLabel') { + caseLabels[index] = defaultLabel + } + }) + }) + + const caseLabelIndex: number[] = [] + let indexTracker = cg.code.length + const positionOffset = cg.code.length + + // **LOOKUPSWITCH Implementation** + cg.code.push(OPCODE.LOOKUPSWITCH) + indexTracker++ + + // Ensure 4-byte alignment + while (cg.code.length % 4 !== 0) { + cg.code.push(0) + indexTracker++ + } + + // Default jump target + cg.code.push(0, 0, 0, defaultLabel.offset) + caseLabelIndex.push(indexTracker + 3) + indexTracker += 4 + + // Number of case-value pairs + cg.code.push( + (hashCaseMap.size >> 24) & 0xff, + (hashCaseMap.size >> 16) & 0xff, + (hashCaseMap.size >> 8) & 0xff, + hashCaseMap.size & 0xff + ) + indexTracker += 4 + + // Populate LOOKUPSWITCH + hashCaseMap.forEach((label, hashCode) => { + cg.code.push( + hashCode >> 24, + (hashCode >> 16) & 0xff, + (hashCode >> 8) & 0xff, + hashCode & 0xff + ) + cg.code.push(0, 0, 0, label.offset) + caseLabelIndex.push(indexTracker + 7) + indexTracker += 8 + }) + + // **Case Handling** + let previousCase: SwitchCase | null = null + + cases + .filter(caseGroup => caseGroup.labels.some(label => label.kind === 'CaseLabel')) + .forEach((caseGroup, index) => { + caseLabels[index].offset = cg.code.length + + // Ensure statements exist + caseGroup.statements = caseGroup.statements || [] + + // Handle fallthrough + if (previousCase && (previousCase.statements?.length ?? 0) === 0) { + previousCase.labels.push(...caseGroup.labels) + } + + // **String Comparison for Collisions** + const caseValue = caseGroup.labels.find( + (label): label is CaseLabel => label.kind === 'CaseLabel' + ) + if (caseValue) { + // TODO: check for actual String equality instead of just rely on hashCode equality + // (see the commented out code below) + + // const caseStr = (caseValue.expression as Literal).literalType.value; + // const caseStrIndex = cg.constantPoolManager.indexStringInfo(caseStr); + + // cg.code.push(OPCODE.LDC, caseStrIndex); + // cg.code.push( + // OPCODE.INVOKEVIRTUAL, + // 0, + // cg.constantPoolManager.indexMethodrefInfo("java/lang/String", "equals", "(Ljava/lang/Object;)Z") + // ); + // + const caseEndLabel = cg.generateNewLabel() + // cg.addBranchInstr(OPCODE.IFEQ, caseEndLabel); + + // Compile case statements + caseGroup.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) + + caseEndLabel.offset = cg.code.length + } + + previousCase = caseGroup + }) + + // **Default Case Handling** + defaultLabel.offset = cg.code.length + const defaultCase = cases.find(caseGroup => + caseGroup.labels.some(label => label.kind === 'DefaultLabel') + ) + + if (defaultCase) { + defaultCase.statements = defaultCase.statements || [] + defaultCase.statements.forEach(statement => { + const { stackSize } = compile(statement, cg) + maxStack = Math.max(maxStack, stackSize) + }) + } + + cg.code[caseLabelIndex[0]] = caseLabels[caseLabels.length - 1].offset - positionOffset + + for (let i = 1; i < caseLabelIndex.length; i++) { + cg.code[caseLabelIndex[i]] = caseLabels[i - 1].offset - positionOffset + } + + endLabel.offset = cg.code.length + } else { + throw new Error( + `Switch statements only support byte, short, int, char, or String types. Found: ${resultType}` + ) + } + + cg.switchLabels.pop() + + return { stackSize: maxStack, resultType: EMPTY_TYPE } } } @@ -958,7 +1677,9 @@ class CodeGenerator { stackSize: number = 0 labels: Label[] = [] loopLabels: Label[][] = [] + switchLabels: Label[] = [] code: number[] = [] + currentClass: string constructor(symbolTable: SymbolTable, constantPoolManager: ConstantPoolManager) { this.symbolTable = symbolTable @@ -989,8 +1710,9 @@ class CodeGenerator { } } - generateCode(methodNode: MethodDeclaration) { + generateCode(currentClass: string, methodNode: MethodDeclaration) { this.symbolTable.extend() + this.currentClass = currentClass if (!methodNode.methodModifier.includes('static')) { this.maxLocals++ } @@ -1013,11 +1735,13 @@ class CodeGenerator { if (methodNode.methodHeader.identifier === '') { this.stackSize = Math.max(this.stackSize, 1) + const parentClass = + this.symbolTable.queryClass(currentClass).parentClassName || 'java/lang/Object' this.code.push( OPCODE.ALOAD_0, OPCODE.INVOKESPECIAL, 0, - this.constantPoolManager.indexMethodrefInfo('java/lang/Object', '', '()V') + this.constantPoolManager.indexMethodrefInfo(parentClass, '', '()V') ) } @@ -1058,8 +1782,9 @@ class CodeGenerator { export function generateCode( symbolTable: SymbolTable, constantPoolManager: ConstantPoolManager, + currentClass: string, methodNode: MethodDeclaration ) { const codeGenerator = new CodeGenerator(symbolTable, constantPoolManager) - return codeGenerator.generateCode(methodNode) + return codeGenerator.generateCode(currentClass, methodNode) } diff --git a/src/compiler/compiler.ts b/src/compiler/compiler.ts index a3566173..3e116800 100644 --- a/src/compiler/compiler.ts +++ b/src/compiler/compiler.ts @@ -1,4 +1,4 @@ -import { ClassFile } from '../ClassFile/types' +import { Class, ClassFile } from '../ClassFile/types' import { AST } from '../ast/types/packages-and-modules' import { ClassBodyDeclaration, @@ -31,35 +31,55 @@ export class Compiler { private methods: Array private attributes: Array private className: string + private parentClassName: string constructor() { this.setup() } private setup() { + this.symbolTable = new SymbolTable() + } + + private resetClassFileState() { this.constantPoolManager = new ConstantPoolManager() this.interfaces = [] this.fields = [] this.methods = [] this.attributes = [] - this.symbolTable = new SymbolTable() } compile(ast: AST) { this.setup() this.symbolTable.handleImports(ast.importDeclarations) - const classFiles: Array = [] - ast.topLevelClassOrInterfaceDeclarations.forEach(x => classFiles.push(this.compileClass(x))) - return classFiles[0] + const classFiles: Array = [] + + ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { + const className = decl.typeIdentifier + const parentClassName = decl.sclass ? decl.sclass : 'java/lang/Object' + const accessFlags = generateClassAccessFlags(decl.classModifier) + this.symbolTable.insertClassInfo( + { name: className, accessFlags: accessFlags, parentClassName: parentClassName }) + this.symbolTable.returnToRoot() + }) + + ast.topLevelClassOrInterfaceDeclarations.forEach(decl => { + this.resetClassFileState() + const classFile = this.compileClass(decl) + classFiles.push({classFile: classFile, className: this.className}) + }) + + return classFiles } private compileClass(classNode: ClassDeclaration): ClassFile { - const parentClassName = 'java/lang/Object' this.className = classNode.typeIdentifier + this.parentClassName = classNode.sclass ? classNode.sclass : 'java/lang/Object' const accessFlags = generateClassAccessFlags(classNode.classModifier) + this.symbolTable.extend() this.symbolTable.insertClassInfo({ name: this.className, accessFlags: accessFlags }) - const superClassIndex = this.constantPoolManager.indexClassInfo(parentClassName) + const superClassIndex = this.constantPoolManager.indexClassInfo(this.parentClassName) const thisClassIndex = this.constantPoolManager.indexClassInfo(this.className) this.constantPoolManager.indexUtf8Info('Code') this.handleClassBody(classNode.classBody) @@ -170,8 +190,9 @@ export class Compiler { this.symbolTable.insertMethodInfo({ name: methodName, accessFlags: generateMethodAccessFlags(methodNode.methodModifier), - parentClassName: this.className, - typeDescriptor: descriptor + parentClassName: this.parentClassName, + typeDescriptor: descriptor, + className: this.className }) } @@ -183,8 +204,9 @@ export class Compiler { this.symbolTable.insertMethodInfo({ name: '', accessFlags: generateMethodAccessFlags(constructor.constructorModifier), - parentClassName: this.className, - typeDescriptor: descriptor + parentClassName: this.parentClassName, + typeDescriptor: descriptor, + className: this.className }) } @@ -199,7 +221,9 @@ export class Compiler { const descriptorIndex = this.constantPoolManager.indexUtf8Info(descriptor) const attributes: Array = [] - attributes.push(generateCode(this.symbolTable, this.constantPoolManager, methodNode)) + attributes.push( + generateCode(this.symbolTable, this.constantPoolManager, this.className, methodNode) + ) this.methods.push({ accessFlags: generateMethodAccessFlags(methodNode.methodModifier), diff --git a/src/compiler/error.ts b/src/compiler/error.ts index 333f687b..1044d4fd 100644 --- a/src/compiler/error.ts +++ b/src/compiler/error.ts @@ -18,7 +18,7 @@ export class SymbolRedeclarationError extends CompileError { export class SymbolCannotBeResolvedError extends CompileError { constructor(token: string, fullName: string) { - super('cannot resolve symbol ' + '"' + token + '"' + ' in' + '"' + fullName + '"') + super('cannot resolve symbol ' + '"' + token + '"' + ' in ' + '"' + fullName + '"') } } @@ -33,3 +33,21 @@ export class ConstructNotSupportedError extends CompileError { super('"' + name + '"' + ' is currently not supported by the compiler') } } + +export class NoMethodMatchingSignatureError extends CompileError { + constructor(signature: string) { + super(`No method matching signature ${signature}) found.`) + } +} + +export class AmbiguousMethodCallError extends CompileError { + constructor(signature: string) { + super(`Ambiguous method call: ${signature}`) + } +} + +export class OverrideFinalMethodError extends CompileError { + constructor(name: string) { + super(`Cannot override final method ${name}`) + } +} \ No newline at end of file diff --git a/src/compiler/grammar.pegjs b/src/compiler/grammar.pegjs index 8f5e3ca0..505f648e 100755 --- a/src/compiler/grammar.pegjs +++ b/src/compiler/grammar.pegjs @@ -774,7 +774,42 @@ AssertStatement = assert Expression (colon Expression) semicolon SwitchStatement - = TO_BE_ADDED + = switch lparen expr:Expression rparen lcurly + cases:SwitchBlock? + rcurly { + return addLocInfo({ + kind: "SwitchStatement", + expression: expr, + cases: cases ?? [], + }); + } + +SwitchBlock + = cases:SwitchBlockStatementGroup* { + return cases; + } + +SwitchBlockStatementGroup + = labels:SwitchLabel+ stmts:BlockStatement* { + return { + kind: "SwitchBlockStatementGroup", + labels: labels, + statements: stmts, + }; + } + +SwitchLabel + = case expr:Expression colon { + return { + kind: "CaseLabel", + expression: expr, + }; + } + / default colon { + return { + kind: "DefaultLabel", + }; + } DoStatement = do body:Statement while lparen expr:Expression rparen semicolon { @@ -1079,7 +1114,8 @@ MultiplicativeExpression } UnaryExpression - = PostfixExpression + = / CastExpression + / PostfixExpression / op:PrefixOp expr:UnaryExpression { return addLocInfo({ kind: "PrefixExpression", @@ -1087,7 +1123,6 @@ UnaryExpression expression: expr, }) } - / CastExpression / SwitchExpression PrefixOp @@ -1107,8 +1142,19 @@ PostfixExpression } CastExpression - = lparen PrimitiveType rparen UnaryExpression - / lparen ReferenceType rparen (LambdaExpression / !(PlusMinus) UnaryExpression) + = lparen castType:PrimitiveType rparen expr:UnaryExpression { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + })} + / lparen castType:ReferenceType rparen expr:(LambdaExpression / !(PlusMinus) UnaryExpression) { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + }) + } SwitchExpression = SwitchStatement diff --git a/src/compiler/grammar.ts b/src/compiler/grammar.ts index 4b6f1ee6..57470926 100755 --- a/src/compiler/grammar.ts +++ b/src/compiler/grammar.ts @@ -776,7 +776,42 @@ AssertStatement = assert Expression (colon Expression) semicolon SwitchStatement - = TO_BE_ADDED + = switch lparen expr:Expression rparen lcurly + cases:SwitchBlock? + rcurly { + return addLocInfo({ + kind: "SwitchStatement", + expression: expr, + cases: cases ?? [], + }); + } + +SwitchBlock + = cases:SwitchBlockStatementGroup* { + return cases; + } + +SwitchBlockStatementGroup + = labels:SwitchLabel+ stmts:BlockStatement* { + return { + kind: "SwitchBlockStatementGroup", + labels: labels, + statements: stmts, + }; + } + +SwitchLabel + = case expr:Expression colon { + return { + kind: "CaseLabel", + expression: expr, + }; + } + / default colon { + return { + kind: "DefaultLabel", + }; + } DoStatement = do body:Statement while lparen expr:Expression rparen semicolon { @@ -1081,7 +1116,8 @@ MultiplicativeExpression } UnaryExpression - = PostfixExpression + = CastExpression + / PostfixExpression / op:PrefixOp expr:UnaryExpression { return addLocInfo({ kind: "PrefixExpression", @@ -1089,7 +1125,6 @@ UnaryExpression expression: expr, }) } - / CastExpression / SwitchExpression PrefixOp @@ -1109,8 +1144,19 @@ PostfixExpression } CastExpression - = lparen PrimitiveType rparen UnaryExpression - / lparen ReferenceType rparen (LambdaExpression / !(PlusMinus) UnaryExpression) + = lparen castType:PrimitiveType rparen expr:UnaryExpression { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + })} + / lparen castType:ReferenceType rparen expr:(LambdaExpression / !(PlusMinus) UnaryExpression) { + return addLocInfo({ + kind: "CastExpression", + type: castType, + expression: expr, + }) + } SwitchExpression = SwitchStatement diff --git a/src/compiler/import/lib-info.ts b/src/compiler/import/lib-info.ts index d334e765..8db187d6 100644 --- a/src/compiler/import/lib-info.ts +++ b/src/compiler/import/lib-info.ts @@ -1,61 +1,60 @@ export const rawLibInfo = { - "packages": [ + packages: [ { - "name": "java.lang", - "classes": [ + name: 'java.lang', + classes: [ { - "name": "public final java.lang.String" + name: 'public final java.lang.String' }, { - "name": "public final java.lang.System", - "fields": [ - "public static final java.io.PrintStream out" - ] + name: 'public final java.lang.Object' + }, + { + name: 'public final java.lang.System', + fields: ['public static final java.io.PrintStream out'] }, { - "name": "public final java.lang.Math", - "methods": [ - "public static int max(int,int)", - "public static int min(int,int)", - "public static double log10(double)" + name: 'public final java.lang.Math', + methods: [ + 'public static int max(int,int)', + 'public static int min(int,int)', + 'public static double log10(double)' ] } ] }, { - "name": "java.io", - "classes": [ + name: 'java.io', + classes: [ { - "name": "public java.io.PrintStream", - "methods": [ - "public void println(java.lang.String)", - "public void println(int)", - "public void println(long)", - "public void println(float)", - "public void println(double)", - "public void println(char)", - "public void println(boolean)", - "public void print(java.lang.String)", - "public void print(int)", - "public void print(long)", - "public void print(float)", - "public void print(double)", - "public void print(char)", - "public void print(boolean)" + name: 'public java.io.PrintStream', + methods: [ + 'public void println(java.lang.String)', + 'public void println(int)', + 'public void println(long)', + 'public void println(float)', + 'public void println(double)', + 'public void println(char)', + 'public void println(boolean)', + 'public void print(java.lang.String)', + 'public void print(int)', + 'public void print(long)', + 'public void print(float)', + 'public void print(double)', + 'public void print(char)', + 'public void print(boolean)' ] } ] }, { - "name": "java.util", - "classes": [ + name: 'java.util', + classes: [ { - "name": "public java.util.Arrays", - "methods": [ - "public static java.lang.String toString(int[])" - ] + name: 'public java.util.Arrays', + methods: ['public static java.lang.String toString(int[])'] } ] } ] -} \ No newline at end of file +} diff --git a/src/compiler/index.ts b/src/compiler/index.ts index 52a93725..13202756 100644 --- a/src/compiler/index.ts +++ b/src/compiler/index.ts @@ -1,16 +1,16 @@ import * as peggy from 'peggy' import { AST } from '../ast/types/packages-and-modules' -import { ClassFile } from '../ClassFile/types' +import { Class } from '../ClassFile/types' import { Compiler } from './compiler' import { javaPegGrammar } from './grammar' import { peggyFunctions } from './peggy-functions' -export const compile = (ast: AST): ClassFile => { +export const compile = (ast: AST): Array => { const compiler = new Compiler() return compiler.compile(ast) } -export const compileFromSource = (javaProgram: string): ClassFile => { +export const compileFromSource = (javaProgram: string): Array => { const parser = peggy.generate(peggyFunctions + javaPegGrammar, { allowedStartRules: ['CompilationUnit'], cache: true diff --git a/src/compiler/symbol-table.ts b/src/compiler/symbol-table.ts index 37c0620e..394ffd34 100644 --- a/src/compiler/symbol-table.ts +++ b/src/compiler/symbol-table.ts @@ -6,12 +6,13 @@ import { generateMethodAccessFlags } from './compiler-utils' import { - InvalidMethodCallError, + InvalidMethodCallError, OverrideFinalMethodError, SymbolCannotBeResolvedError, SymbolNotFoundError, SymbolRedeclarationError } from './error' import { libraries } from './import/libs' +import { METHOD_FLAGS } from '../ClassFile/types/methods' export const typeMap = new Map([ ['byte', 'B'], @@ -63,6 +64,7 @@ export interface MethodInfo { accessFlags: number parentClassName: string typeDescriptor: string + className: string } export interface VariableInfo { @@ -99,7 +101,8 @@ export class SymbolTable { private setup() { libraries.forEach(p => { - this.importedPackages.push(p.packageName + '/') + if (this.importedPackages.findIndex(e => e == p.packageName + '/') == -1) + this.importedPackages.push(p.packageName + '/') p.classes.forEach(c => { this.insertClassInfo({ name: c.className, @@ -119,7 +122,8 @@ export class SymbolTable { name: m.methodName, accessFlags: generateMethodAccessFlags(m.accessFlags), parentClassName: c.className, - typeDescriptor: this.generateMethodDescriptor(m.argsTypeName, m.returnTypeName) + typeDescriptor: this.generateMethodDescriptor(m.argsTypeName, m.returnTypeName), + className: c.className }) ) this.returnToRoot() @@ -131,7 +135,7 @@ export class SymbolTable { return new Map() } - private returnToRoot() { + public returnToRoot() { this.tables = [this.tables[0]] this.curTable = this.tables[0] this.curIdx = 0 @@ -189,6 +193,7 @@ export class SymbolTable { insertFieldInfo(info: FieldInfo) { const key = generateSymbol(info.name, SymbolType.FIELD) + this.curTable = this.tables[this.curIdx] if (this.curTable.has(key)) { throw new SymbolRedeclarationError(info.name) } @@ -203,6 +208,22 @@ export class SymbolTable { insertMethodInfo(info: MethodInfo) { const key = generateSymbol(info.name, SymbolType.METHOD) + for (let i = this.curClassIdx - 1; i > 0; i--) { + const parentTable = this.tables[i]; + if (parentTable.has(key)) { + const parentMethods = parentTable.get(key)!.info; + if (Array.isArray(parentMethods)) { + for (const m of parentMethods) { + if (m.typeDescriptor === info.typeDescriptor && (m.accessFlags & METHOD_FLAGS.ACC_FINAL) + && m.className == info.parentClassName) { + throw new OverrideFinalMethodError(info.name); + } + } + } + } + } + + this.curTable = this.tables[this.curIdx] if (!this.curTable.has(key)) { const symbolNode: SymbolNode = { info: [info], @@ -225,7 +246,7 @@ export class SymbolTable { insertVariableInfo(info: VariableInfo) { const key = generateSymbol(info.name, SymbolType.VARIABLE) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { if (this.tables[i].has(key)) { throw new SymbolRedeclarationError(info.name) } @@ -235,6 +256,7 @@ export class SymbolTable { info: info, children: this.getNewTable() } + this.curTable = this.tables[this.curIdx] this.curTable.set(key, symbolNode) } @@ -266,6 +288,36 @@ export class SymbolTable { throw new SymbolNotFoundError(name) } + private getClassTable(name: string): Table { + let key = generateSymbol(name, SymbolType.CLASS) + for (let i = this.curIdx; i >= 0; i--) { + const table = this.tables[i] + if (table.has(key)) { + return table.get(key)!.children + } + } + + const root = this.tables[0] + if (this.importedClassMap.has(name)) { + const fullName = this.importedClassMap.get(name)! + key = generateSymbol(fullName, SymbolType.CLASS) + if (root.has(key)) { + return root.get(key)!.children + } + } + + let p: string + for (p of this.importedPackages) { + const fullName = p + name + key = generateSymbol(fullName, SymbolType.CLASS) + if (root.has(key)) { + return root.get(key)!.children + } + } + + throw new SymbolNotFoundError(name) + } + private querySymbol(name: string, symbolType: SymbolType): Array { let curTable = this.getNewTable() const symbolInfos: Array = [] @@ -275,7 +327,7 @@ export class SymbolTable { tokens.forEach((token, i) => { if (i === 0) { const key1 = generateSymbol(token, SymbolType.VARIABLE) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { if (this.tables[i].has(key1)) { const node = this.tables[i].get(key1)! token = (node.info as VariableInfo).typeName @@ -287,8 +339,7 @@ export class SymbolTable { if (token === 'this') { curTable = this.tables[this.curClassIdx] } else { - const key = generateSymbol(this.queryClass(token).name, SymbolType.CLASS) - curTable = this.tables[0].get(key)!.children + curTable = this.getClassTable(token) } } else if (i < len - 1) { const key = generateSymbol(token, SymbolType.FIELD) @@ -299,8 +350,7 @@ export class SymbolTable { symbolInfos.push(node.info) const typeName = (node.info as FieldInfo).typeName - const type = generateSymbol(this.queryClass(typeName).name, SymbolType.CLASS) - curTable = this.tables[0].get(type)!.children + curTable = this.getClassTable(typeName) } else { const key = generateSymbol(token, symbolType) const node = curTable.get(key) @@ -324,16 +374,26 @@ export class SymbolTable { } const key1 = generateSymbol(name, SymbolType.VARIABLE) - for (let i = this.curIdx; i > this.curClassIdx; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { if (this.tables[i].has(key1)) { throw new InvalidMethodCallError(name) } } + const results: Array = [] const key2 = generateSymbol(name, SymbolType.METHOD) - const table = this.tables[this.curClassIdx] - if (table.has(key2)) { - return [table.get(key2)!.info] + for (let i = this.curIdx; i > 0; i--) { + const table = this.tables[i] + if (table.has(key2)) { + const methodInfos = table.get(key2)!.info as MethodInfos + for (const methodInfo of methodInfos) { + results.push(methodInfo) + } + } + } + + if (results.length > 0) { + return results } throw new InvalidMethodCallError(name) } @@ -346,7 +406,7 @@ export class SymbolTable { const key1 = generateSymbol(name, SymbolType.VARIABLE) const key2 = generateSymbol(name, SymbolType.FIELD) - for (let i = this.curIdx; i >= 0; i--) { + for (let i = this.curIdx; i >= this.curClassIdx; i--) { const table = this.tables[i] if (table.has(key1)) { return (table.get(key1) as SymbolNode).info as VariableInfo diff --git a/src/jvm/utils/index.ts b/src/jvm/utils/index.ts index ea115a78..091dbc44 100644 --- a/src/jvm/utils/index.ts +++ b/src/jvm/utils/index.ts @@ -213,7 +213,7 @@ export function getField(ref: any, fieldName: string, type: JavaType) { } export function asDouble(value: number): number { - return value + return value; } export function asFloat(value: number): number { diff --git a/src/types/checker/index.ts b/src/types/checker/index.ts index 005286e1..c9107e3e 100644 --- a/src/types/checker/index.ts +++ b/src/types/checker/index.ts @@ -1,7 +1,7 @@ import { Array as ArrayType } from '../types/arrays' import { Integer, String, Throwable, Void } from '../types/references' import { CaseConstant, Node } from '../ast/specificationTypes' -import { Type } from '../types/type' +import { PrimitiveType, Type } from '../types/type' import { ArrayRequiredError, BadOperandTypesError, @@ -63,6 +63,33 @@ export const check = (node: Node, frame: Frame = Frame.globalFrame()): Result => return typeCheckBody(node, typeCheckingFrame) } +const isCastCompatible = (fromType: Type, toType: Type): boolean => { + // Handle primitive type compatibility + if (fromType instanceof PrimitiveType && toType instanceof PrimitiveType) { + const fromName = fromType.constructor.name; + const toName = toType.constructor.name; + + console.log(fromName, toName); + + return !(fromName === 'char' && toName !== 'int'); + } + + // Handle class type compatibility + if (fromType instanceof ClassType && toType instanceof ClassType) { + // Allow upcasts (base class to derived class) or downcasts (derived class to base class) + return fromType.canBeAssigned(toType) || toType.canBeAssigned(fromType); + } + + // Handle array type compatibility + if (fromType instanceof ArrayType && toType instanceof ArrayType) { + // Ensure the content types are compatible + return isCastCompatible(fromType.getContentType(), toType.getContentType()); + } + + // Disallow other cases by default + return false; +}; + export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): Result => { switch (node.kind) { case 'ArrayAccess': { @@ -192,6 +219,55 @@ export const typeCheckBody = (node: Node, frame: Frame = Frame.globalFrame()): R case 'BreakStatement': { return OK_RESULT } + + case 'CastExpression': { + let castType: Type | TypeCheckerError; + let expressionType: Type | null = null; + let expressionResult: Result; + + if ('primitiveType' in node) { + castType = frame.getType(unannTypeToString(node.primitiveType), node.primitiveType.location); + } else { + throw new Error('Invalid CastExpression: Missing type information.'); + } + + if (castType instanceof TypeCheckerError) { + return newResult(null, [castType]); + } + + if ('unaryExpression' in node) { + expressionResult = typeCheckBody(node.unaryExpression, frame); + } else { + throw new Error('Invalid CastExpression: Missing expression.'); + } + + if (expressionResult.hasErrors) { + return expressionResult; + } + + expressionType = expressionResult.currentType; + if (!expressionType) { + throw new Error('Expression in cast should have a type.'); + } + + if ( + (castType instanceof PrimitiveType && expressionType instanceof PrimitiveType) + ) { + if (!isCastCompatible(expressionType, castType)) { + return newResult(null, [ + new IncompatibleTypesError(node.location), + ]); + } + } else { + return newResult(null, [ + new IncompatibleTypesError(node.location), + ]); + } + + // If the cast is valid, return the target type + return newResult(castType); + } + case 'ClassInstanceCreationExpression': { const classIdentifier = node.unqualifiedClassInstanceCreationExpression.classOrInterfaceTypeToInstantiate diff --git a/src/types/types/methods.ts b/src/types/types/methods.ts index 3a4adb6e..d7df067a 100644 --- a/src/types/types/methods.ts +++ b/src/types/types/methods.ts @@ -55,7 +55,7 @@ export class Parameter { return ( object instanceof Parameter && this._name === object._name && - (this._type.canBeAssigned(object._type) || object._type.canBeAssigned(this._type)) && + this._type === object._type && this._isVarargs === object._isVarargs ) } @@ -189,6 +189,7 @@ export class Method implements Type { } public invoke(args: Arguments): Type | TypeCheckerError { + if (this.methodName === 'println') return new Void() const error = this.parameters.invoke(args) if (error instanceof TypeCheckerError) return error return this.returnType