Trying to build the AST for my compiler using ANTLR and getting a null return value


I have a project where I have to build a compiler with ANTLR and Java for a calculator-like program that supports the following operations: addition, subtraction, multiplication, division, power, sine, cosine, tangent, cotangent, mod, NOT, AND, OR, XOR, NAND, logarithm, factorial, pi, root and e. I have already generated the necessary files from my grammar (.g4 file) with ANTLR, but while trying to build the AST with a custom visitor I keep getting the following error:

An error occurred: Cannot invoke "java.lang.Double.doubleValue()" because the return value of "EvaluateExpressionVisitor.visit(ExpressionNode)" is null

I haven't been able to figure out why the return value is null, so any help would be appreciated. Below are my main file, the AST nodes, EvaluateExpressionVisitor, BuildAstVisitor, the grammar and the ASTVisitor base class.

import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.tree.ParseTree;

import java.io.IOException;

public class Main {
    public static void main(String[] args) {
        try {
            // CharStreams replaces ANTLRFileStream, which is deprecated since ANTLR 4.7
            MTmathLexer lexer = new MTmathLexer(CharStreams.fromFileName("input.txt"));
            MTmathParser parser = new MTmathParser(new CommonTokenStream(lexer));
            ParseTree cst = parser.r();
            ExpressionNode ast = new BuildAstVisitor().visit(cst);
            double value = new EvaluateExpressionVisitor().visit(ast);

            System.out.println("= " + value);
        } catch (IOException e) {
            System.err.println("Error reading input file: " + e.getMessage());
        } catch (Exception e) {
            System.err.println("An error occurred: " + e.getMessage());
        }
    }
}

abstract class ExpressionNode {}

abstract class BinaryExpressionNode extends ExpressionNode {
    protected ExpressionNode left;
    protected ExpressionNode right;

    public BinaryExpressionNode(ExpressionNode left, ExpressionNode right) {
        this.left = left;
        this.right = right;
    }

    public ExpressionNode getLeft() {
        return left;
    }

    public ExpressionNode getRight() {
        return right;
    }
}

class AdditionNode extends BinaryExpressionNode {
    public AdditionNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class SubtractionNode extends BinaryExpressionNode {
    public SubtractionNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class MultiplicationNode extends BinaryExpressionNode {
    public MultiplicationNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class DivisionNode extends BinaryExpressionNode {
    public DivisionNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class PowerNode extends BinaryExpressionNode {
    public PowerNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class ModNode extends BinaryExpressionNode {
    public ModNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class NotNode extends ExpressionNode {
    private ExpressionNode expression;

    public NotNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class AndNode extends BinaryExpressionNode {
    public AndNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class OrNode extends BinaryExpressionNode {
    public OrNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class XorNode extends BinaryExpressionNode {
    public XorNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class NandNode extends BinaryExpressionNode {
    public NandNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }
}

class LogNode extends ExpressionNode {
    private ExpressionNode expression;

    public LogNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class FactorialNode extends ExpressionNode {
    private ExpressionNode expression;

    public FactorialNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class PiNode extends ExpressionNode {}

class SquareRootNode extends ExpressionNode {
    private ExpressionNode expression;

    public SquareRootNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class ENode extends ExpressionNode {}

class SinNode extends ExpressionNode {
    private ExpressionNode expression;

    public SinNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class CosNode extends ExpressionNode {
    private ExpressionNode expression;

    public CosNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class TanNode extends ExpressionNode {
    private ExpressionNode expression;

    public TanNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class CotNode extends ExpressionNode {
    private ExpressionNode expression;

    public CotNode(ExpressionNode expression) {
        this.expression = expression;
    }

    public ExpressionNode getExpression() {
        return expression;
    }
}

class NumberNode extends ExpressionNode {
    private int value;

    public NumberNode(int value) {
        this.value = value;
    }

    public int getValue() {
        return value;
    }
}

public class EvaluateExpressionVisitor extends ASTVisitor<Double> {
    @Override
    public Double visit(AdditionNode node) {
        return visit(node.getLeft()) + visit(node.getRight());
    }

    @Override
    public Double visit(SubtractionNode node) {
        return visit(node.getLeft()) - visit(node.getRight());
    }


    @Override
    public Double visit(NumberNode node) {
        return (double) node.getValue();
    }

    @Override
    public Double visit(PiNode node) {
        return Math.PI;
    }

    @Override
    public Double visit(SquareRootNode node) {
        double argument = visit(node.getExpression());
        return Math.sqrt(argument);
    }

    @Override
    public Double visit(ENode node) {
        return Math.E;
    }

    @Override
    public Double visit(MultiplicationNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        return leftValue * rightValue;
    }


    @Override
    public Double visit(DivisionNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        if (rightValue == 0) {
            throw new ArithmeticException("Division by zero");
        }
        return leftValue / rightValue;
    }

    @Override
    public Double visit(PowerNode node) {
        double baseValue = visit(node.getLeft());
        double exponentValue = visit(node.getRight());
        return Math.pow(baseValue, exponentValue);
    }

    @Override
    public Double visit(ModNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        return leftValue % rightValue;
    }

    @Override
    public Double visit(NotNode node) {
        double value = visit(node.getExpression());
        if (value == 0) {
            return 1.0; // Return 1 if the value is 0 (false)
        } else if (value == 1) {
            return 0.0; // Return 0 if the value is 1 (true)
        } else {
            throw new IllegalArgumentException("Invalid value for 'not' operation: " + value);
        }
    }


    @Override
    public Double visit(AndNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        if (leftValue == 1 && rightValue == 1) {
            return 1.0; // Return 1 if both values are true
        } else {
            return 0.0; // Return 0 otherwise
        }
    }

    @Override
    public Double visit(OrNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        if (leftValue == 1 || rightValue == 1) {
            return 1.0; // Return 1 if either value is true
        } else {
            return 0.0; // Return 0 otherwise
        }
    }

    @Override
    public Double visit(XorNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        if ((leftValue == 1 && rightValue == 0) || (leftValue == 0 && rightValue == 1)) {
            return 1.0; // Return 1 if one value is true and the other is false
        } else {
            return 0.0; // Return 0 otherwise
        }
    }

    @Override
    public Double visit(NandNode node) {
        double leftValue = visit(node.getLeft());
        double rightValue = visit(node.getRight());
        if (leftValue == 1 && rightValue == 1) {
            return 0.0; // Return 0 if both values are true
        } else {
            return 1.0; // Return 1 otherwise
        }
    }

    @Override
    public Double visit(LogNode node) {
        double value = visit(node.getExpression());
        if (value <= 0) {
            throw new IllegalArgumentException("Logarithm can only be computed for positive numbers");
        }
        return Math.log(value);
    }

    @Override
    public Double visit(FactorialNode node) {
        int value = visit(node.getExpression()).intValue();
        if (value < 0) {
            throw new IllegalArgumentException("Factorial can only be computed for non-negative integers");
        }
        double result = 1;
        for (int i = 2; i <= value; i++) {
            result *= i;
        }
        return result;
    }

    @Override
    public Double visit(CosNode node) {
        double radians = Math.toRadians(visit(node.getExpression()));
        return Math.cos(radians);
    }

    @Override
    public Double visit(TanNode node) {
        double radians = Math.toRadians(visit(node.getExpression()));
        return Math.tan(radians);
    }

    @Override
    public Double visit(CotNode node) {
        double radians = Math.toRadians(visit(node.getExpression()));
        return 1.0 / Math.tan(radians);
    }

    @Override
    public Double visit(SinNode node) {
        double radians = Math.toRadians(visit(node.getExpression()));
        return Math.sin(radians);
    }


}

public class BuildAstVisitor extends MTmathBaseVisitor<ExpressionNode> {
    @Override
    public ExpressionNode visitPlus_op(MTmathParser.Plus_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new AdditionNode(left, right);
    }

    @Override
    public ExpressionNode visitMinus_op(MTmathParser.Minus_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new SubtractionNode(left, right);
    }

    // Implement visit methods for other binary operations similarly

    //@Override
    //public ExpressionNode visitNUM(MTmathParser.NUMContext ctx) {
    //    int value = Integer.parseInt(ctx.getText());
    //    return new NumberNode(value);
    //}
    @Override
    public ExpressionNode visitMultiply_op(MTmathParser.Multiply_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new MultiplicationNode(left, right);
    }

    @Override
    public ExpressionNode visitDivide_op(MTmathParser.Divide_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new DivisionNode(left, right);
    }

    @Override
    public ExpressionNode visitPower_op(MTmathParser.Power_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new PowerNode(left, right);
    }

    @Override
    public ExpressionNode visitMod_op(MTmathParser.Mod_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new ModNode(left, right);
    }

    @Override
    public ExpressionNode visitNot_op(MTmathParser.Not_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new NotNode(expression);
    }

    @Override
    public ExpressionNode visitAnd_op(MTmathParser.And_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new AndNode(left, right);
    }

    @Override
    public ExpressionNode visitOr_op(MTmathParser.Or_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new OrNode(left, right);
    }

    @Override
    public ExpressionNode visitXor_op(MTmathParser.Xor_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new XorNode(left, right);
    }

    @Override
    public ExpressionNode visitNand_op(MTmathParser.Nand_opContext ctx) {
        ExpressionNode left = visit(ctx.NUM(0));
        ExpressionNode right = visit(ctx.NUM(1));
        return new NandNode(left, right);
    }

    @Override
    public ExpressionNode visitLog_op(MTmathParser.Log_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new LogNode(expression);
    }

    @Override
    public ExpressionNode visitFactorial_op(MTmathParser.Factorial_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new FactorialNode(expression);
    }
    @Override
    public ExpressionNode visitPi_op(MTmathParser.Pi_opContext ctx) {
        return new PiNode();
    }

    @Override
    public ExpressionNode visitRoot_op(MTmathParser.Root_opContext ctx) {
        ExpressionNode argument = visit(ctx.NUM());
        return new SquareRootNode(argument);
    }

    @Override
    public ExpressionNode visitE_op(MTmathParser.E_opContext ctx) {
        return new ENode();
    }
    @Override
    public ExpressionNode visitSine_op(MTmathParser.Sine_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new SinNode(expression);
    }

    @Override
    public ExpressionNode visitCosine_op(MTmathParser.Cosine_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new CosNode(expression);
    }

    @Override
    public ExpressionNode visitTangent_op(MTmathParser.Tangent_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new TanNode(expression);
    }

    @Override
    public ExpressionNode visitCotangent_op(MTmathParser.Cotangent_opContext ctx) {
        ExpressionNode expression = visit(ctx.NUM());
        return new CotNode(expression);
    }
}

grammar MTmath;

r : calculations + EOF ; // expression syntax

NUM : [0-9]+;
WS : [ \t\r\n]+ -> skip ; // skip spaces, tabs, newlines

ADD : '+';
MINUS : '-';
MULTIPLY : '*';
DIVIDE : '/';
POWER : '^';
MOD : '%';
NOT : '!';
AND : '&&';
OR : '||';
XOR : '^&'; // the usual symbol for XOR is "^", BUT since power already uses that symbol, we define XOR as "^&"
NAND : '!&'; // the usual symbol for NAND is "!", BUT since NOT already uses that symbol, we define NAND as "!&"
SEMICOLON : ';';
LEFTPAR : '(';
RIGHTPAR : ')';
PI : 'pi';
E : 'e';
FACTORIAL : 'f!';

plus_op : NUM ADD NUM SEMICOLON;
minus_op : NUM MINUS NUM SEMICOLON;
multiply_op : NUM MULTIPLY NUM SEMICOLON;
divide_op : NUM DIVIDE NUM SEMICOLON;
power_op : NUM POWER NUM SEMICOLON;
mod_op : NUM MOD NUM SEMICOLON;
not_op : NOT NUM SEMICOLON;
and_op : NUM AND NUM SEMICOLON;
or_op : NUM OR NUM SEMICOLON;
xor_op : NUM XOR NUM SEMICOLON;
nand_op : NUM NAND NUM SEMICOLON;
log_op : 'log' LEFTPAR NUM RIGHTPAR SEMICOLON;
factorial_op : NUM FACTORIAL SEMICOLON;
pi_op : PI SEMICOLON;
root_op : 'sqrt' LEFTPAR NUM RIGHTPAR SEMICOLON;
e_op : E SEMICOLON;

trig_op : sine_op | cosine_op | tangent_op | cotangent_op;
sine_op : 'sin' LEFTPAR NUM RIGHTPAR SEMICOLON;
cosine_op : 'cos' LEFTPAR NUM RIGHTPAR SEMICOLON;
tangent_op : 'tan' LEFTPAR NUM RIGHTPAR SEMICOLON;
cotangent_op : 'cot' LEFTPAR NUM RIGHTPAR SEMICOLON;

calculations : plus_op | minus_op | multiply_op | divide_op | power_op | mod_op | not_op | and_op | or_op | xor_op | nand_op | log_op | factorial_op | pi_op | root_op | e_op | trig_op;

public abstract class ASTVisitor<T> {
    public abstract T visit(AdditionNode node);
    public abstract T visit(SubtractionNode node);
    public abstract T visit(MultiplicationNode node);
    public abstract T visit(DivisionNode node);
    public abstract T visit(PowerNode node);
    public abstract T visit(ModNode node);
    public abstract T visit(NotNode node);
    public abstract T visit(AndNode node);
    public abstract T visit(OrNode node);
    public abstract T visit(XorNode node);
    public abstract T visit(NandNode node);
    public abstract T visit(LogNode node);
    public abstract T visit(FactorialNode node);
    public abstract T visit(PiNode node);
    public abstract T visit(ENode node);
    public abstract T visit(CosNode node);
    public abstract T visit(TanNode node);
    public abstract T visit(CotNode node);
    public abstract T visit(SinNode node);
    public abstract T visit(SquareRootNode node);
    public abstract T visit(NumberNode node);

    public T visit(ExpressionNode node) {
        if (node instanceof AdditionNode) {
            return visit((AdditionNode) node);
        } else if (node instanceof SubtractionNode) {
            return visit((SubtractionNode) node);
        } else if (node instanceof DivisionNode) {
            return visit((DivisionNode) node);
        } else if (node instanceof PowerNode) {
            return visit((PowerNode) node);
        } else if (node instanceof ModNode) {
            return visit((ModNode) node);
        } else if (node instanceof NotNode) {
            return visit((NotNode) node);
        } else if (node instanceof AndNode) {
            return visit((AndNode) node);
        } else if (node instanceof OrNode) {
            return visit((OrNode) node);
        } else if (node instanceof XorNode) {
            return visit((XorNode) node);
        } else if (node instanceof NandNode) {
            return visit((NandNode) node);
        } else if (node instanceof LogNode) {
            return visit((LogNode) node);
        } else if (node instanceof FactorialNode) {
            return visit((FactorialNode) node);
        } else if (node instanceof PiNode) {
            return visit((PiNode) node);
        } else if (node instanceof CosNode) {
            return visit((CosNode) node);
        } else if (node instanceof TanNode) {
            return visit((TanNode) node);
        } else if (node instanceof CotNode) {
            return visit((CotNode) node);
        } else if (node instanceof SinNode) {
            return visit((SinNode) node);
        } else if (node instanceof SquareRootNode) {
            return visit((SquareRootNode) node);
        } else if (node instanceof NumberNode) {
            return visit((NumberNode) node);
        } else {
            return null;
        }
    }
}

Answer by Bart Kiers

As mentioned in the comments: there is still code missing for others to reproduce the error you mention.

At the leaves of your parse tree you should stop visiting the tree and return your "atom" nodes (a NumberNode in your case). After all, when you encounter a NUM token, there is nothing left to visit.
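A minimal, self-contained sketch of why visit(ctx.NUM(0)) yields null, assuming ANTLR's default visitor behaviour: in AbstractParseTreeVisitor (the base class of the generated MTmathBaseVisitor), visiting a token ends up in visitTerminal, which returns defaultResult(), and defaultResult() returns null unless you override it. The Token and DefaultVisitor classes below are hypothetical stand-ins so the example runs without ANTLR on the classpath; they are not the ANTLR API. Unboxing that null into a double is what produces the NullPointerException from the question.

```java
public class TerminalVisitDemo {

    // Hypothetical stand-in for ANTLR's TerminalNode (e.g. a NUM token).
    static final class Token {
        final String text;

        Token(String text) {
            this.text = text;
        }
    }

    // Models AbstractParseTreeVisitor: visitTerminal returns defaultResult(),
    // and defaultResult() returns null unless a subclass overrides it.
    static class DefaultVisitor<T> {
        protected T defaultResult() {
            return null;
        }

        public T visitTerminal(Token token) {
            return defaultResult();
        }
    }

    // What visit(ctx.NUM(0)) effectively returns inside BuildAstVisitor.
    public static Double visitNumberLeaf() {
        return new DefaultVisitor<Double>().visitTerminal(new Token("1"));
    }

    // Reading the token text directly instead of visiting the leaf.
    public static double readNumberLeaf() {
        return Integer.parseInt(new Token("1").text);
    }

    // Auto-unboxing a null Double into a double throws NullPointerException,
    // which is the kind of error reported in the question.
    public static boolean unboxingThrows() {
        try {
            double value = visitNumberLeaf();
            return false;
        } catch (NullPointerException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(visitNumberLeaf()); // null
        System.out.println(readNumberLeaf());  // 1.0
        System.out.println(unboxingThrows());  // true
    }
}
```

Reading the token text, as readNumberLeaf does, is the idea behind creating a NumberNode from ctx.NUM(0).getText().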

Instead of doing:

@Override
public ExpressionNode visitPlus_op(MTmathParser.Plus_opContext ctx) {
  ExpressionNode left = visit(ctx.NUM(0));
  ExpressionNode right = visit(ctx.NUM(1));
  return new AdditionNode(left, right);
}

try something like this:

@Override
public ExpressionNode visitPlus_op(MTmathParser.Plus_opContext ctx) {
  ExpressionNode left = new NumberNode(Integer.valueOf(ctx.NUM(0).getText()));
  ExpressionNode right = new NumberNode(Integer.valueOf(ctx.NUM(1).getText()));
  return new AdditionNode(left, right);
}

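Also override visitR, so that it returns the AST built for the calculation rather than the result of the last child of r, which is the EOF token (and therefore null):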

@Override
public ExpressionNode visitR(MTmathParser.RContext ctx) {
  // Return the first expression, ignore others
  return visit(ctx.calculations(0));
}
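Why visitR matters, as a sketch: by default, visitChildren in ANTLR's AbstractParseTreeVisitor starts from defaultResult() and folds each child's result in with aggregateResult(), which simply returns the latest child's result. For r : calculations+ EOF the last child is the EOF token, so the whole visit returns null even though the calculation produced an AST. The Node and Visitor classes below are hypothetical, heavily simplified stand-ins (a "calculations" node just yields a label string); the example also shows an alternative to overriding visitR, namely overriding aggregateResult to keep the first non-null result.

```java
import java.util.Arrays;
import java.util.List;

public class AggregateResultDemo {

    // Hypothetical, simplified parse-tree node (not the ANTLR API): a leaf holds
    // token text; a rule node holds its rule name and its children.
    static class Node {
        final String text;
        final List<Node> children;

        Node(String text, Node... children) {
            this.text = text;
            this.children = Arrays.asList(children);
        }

        boolean isLeaf() {
            return children.isEmpty();
        }
    }

    // Models AbstractParseTreeVisitor: visitChildren starts from defaultResult()
    // and folds each child's result in with aggregateResult(), whose default
    // keeps only the most recent child's result.
    static class Visitor {
        String defaultResult() { return null; }

        String aggregateResult(String aggregate, String nextResult) { return nextResult; }

        String visit(Node node) {
            if (node.isLeaf()) {
                return defaultResult(); // like visitTerminal: tokens yield null
            }
            if (node.text.equals("calculations")) {
                // stand-in for building an AST node from a calculation rule
                return "AST(" + node.children.get(0).text + ")";
            }
            return visitChildren(node);
        }

        String visitChildren(Node node) {
            String result = defaultResult();
            for (Node child : node.children) {
                result = aggregateResult(result, visit(child));
            }
            return result;
        }
    }

    // Simplified parse tree for the input "1+2;" under r : calculations+ EOF
    static Node sampleTree() {
        return new Node("r", new Node("calculations", new Node("1+2;")), new Node("<EOF>"));
    }

    // Default aggregation: the EOF child is visited last, so the result is null.
    public static String defaultAggregation() {
        return new Visitor().visit(sampleTree());
    }

    // Alternative fix: keep the first non-null result instead of the last one.
    public static String keepFirstNonNull() {
        Visitor visitor = new Visitor() {
            @Override
            String aggregateResult(String aggregate, String nextResult) {
                return aggregate != null ? aggregate : nextResult;
            }
        };
        return visitor.visit(sampleTree());
    }

    public static void main(String[] args) {
        System.out.println(defaultAggregation()); // null
        System.out.println(keepFirstNonNull());   // AST(1+2;)
    }
}
```

In the real BuildAstVisitor the equivalent override would be protected ExpressionNode aggregateResult(ExpressionNode aggregate, ExpressionNode nextResult). Overriding visitR is the more explicit option, since it also makes clear that only the first calculation is used.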

You also want your expression nodes to evaluate to a value (a Double, or an Object if they should also be able to return a Boolean):

public abstract class ExpressionNode {
  abstract Double eval();
}

and then your AdditionNode node could look like this:

class AdditionNode extends BinaryExpressionNode {
    public AdditionNode(ExpressionNode left, ExpressionNode right) {
        super(left, right);
    }

    @Override
    public Double eval() { return super.left.eval() + super.right.eval(); }
}
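A self-contained sketch of the eval() approach, reduced to four node types so it compiles on its own. The class names mirror the question; nesting them inside EvalDemo and the helper methods onePlusTwo and twoTimesE exist only for this example. Because eval() is abstract, a node class that forgets to implement it fails to compile instead of silently producing null at runtime.

```java
public class EvalDemo {

    // Each node computes its own value (interpreter pattern).
    abstract static class ExpressionNode {
        abstract Double eval();
    }

    abstract static class BinaryExpressionNode extends ExpressionNode {
        protected final ExpressionNode left;
        protected final ExpressionNode right;

        BinaryExpressionNode(ExpressionNode left, ExpressionNode right) {
            this.left = left;
            this.right = right;
        }
    }

    static class NumberNode extends ExpressionNode {
        private final int value;

        NumberNode(int value) {
            this.value = value;
        }

        @Override
        Double eval() {
            return (double) value;
        }
    }

    static class AdditionNode extends BinaryExpressionNode {
        AdditionNode(ExpressionNode left, ExpressionNode right) {
            super(left, right);
        }

        @Override
        Double eval() {
            return left.eval() + right.eval();
        }
    }

    static class MultiplicationNode extends BinaryExpressionNode {
        MultiplicationNode(ExpressionNode left, ExpressionNode right) {
            super(left, right);
        }

        @Override
        Double eval() {
            return left.eval() * right.eval();
        }
    }

    static class ENode extends ExpressionNode {
        @Override
        Double eval() {
            return Math.E;
        }
    }

    // 1 + 2
    public static double onePlusTwo() {
        return new AdditionNode(new NumberNode(1), new NumberNode(2)).eval();
    }

    // 2 * e
    public static double twoTimesE() {
        return new MultiplicationNode(new NumberNode(2), new ENode()).eval();
    }

    public static void main(String[] args) {
        System.out.println(onePlusTwo()); // 3.0
        System.out.println(twoTimesE());
    }
}
```

With this design the separate EvaluateExpressionVisitor is no longer needed: evaluating the tree is just ast.eval().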

When I then run this code:

String expression = "1+2;";
MTmathLexer lexer = new MTmathLexer(CharStreams.fromString(expression));
MTmathParser parser = new MTmathParser(new CommonTokenStream(lexer));
ParseTree cst = parser.r();
ExpressionNode ast = new BuildAstVisitor().visit(cst);
double value = new EvaluateExpressionVisitor().visit(ast);
System.out.println(expression + " --> " + value);

this is printed:

1+2; --> 3.0
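Note also that the hand-written dispatch in ASTVisitor.visit(ExpressionNode) from the question has no branch for MultiplicationNode or ENode, so those nodes fall through to the final return null and would cause the same NullPointerException even once the AST is built correctly. If you prefer to keep a separate EvaluateExpressionVisitor rather than eval(), one common alternative (swapped in here, not part of the original code) is classic double dispatch: each node gets an accept method that calls the visit overload for its own type, so a missing case becomes a compile error instead of a runtime null. A minimal sketch with three node types; the names mirror the question, while accept and twoTimesE are additions for this example:

```java
public class DoubleDispatchDemo {

    // Every visitor must implement every overload; adding a node type
    // without handling it here is a compile error, not a null at runtime.
    interface ASTVisitor<T> {
        T visit(NumberNode node);
        T visit(MultiplicationNode node);
        T visit(ENode node);
    }

    abstract static class ExpressionNode {
        // Double dispatch: each subclass calls the overload for its own type.
        abstract <T> T accept(ASTVisitor<T> visitor);
    }

    static class NumberNode extends ExpressionNode {
        final int value;

        NumberNode(int value) {
            this.value = value;
        }

        @Override
        <T> T accept(ASTVisitor<T> visitor) {
            return visitor.visit(this);
        }
    }

    static class MultiplicationNode extends ExpressionNode {
        final ExpressionNode left;
        final ExpressionNode right;

        MultiplicationNode(ExpressionNode left, ExpressionNode right) {
            this.left = left;
            this.right = right;
        }

        @Override
        <T> T accept(ASTVisitor<T> visitor) {
            return visitor.visit(this);
        }
    }

    static class ENode extends ExpressionNode {
        @Override
        <T> T accept(ASTVisitor<T> visitor) {
            return visitor.visit(this);
        }
    }

    static class EvaluateExpressionVisitor implements ASTVisitor<Double> {
        @Override
        public Double visit(NumberNode node) {
            return (double) node.value;
        }

        @Override
        public Double visit(MultiplicationNode node) {
            return node.left.accept(this) * node.right.accept(this);
        }

        @Override
        public Double visit(ENode node) {
            return Math.E;
        }
    }

    // 2 * e, evaluated through the visitor
    public static double twoTimesE() {
        ExpressionNode ast = new MultiplicationNode(new NumberNode(2), new ENode());
        return ast.accept(new EvaluateExpressionVisitor());
    }

    public static void main(String[] args) {
        System.out.println(twoTimesE());
    }
}
```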