How would I find the gradient of a best-fit line for input values on a scatterplot?

Based on this example, the code below plots a scatterplot to which double values can be added via spinners. I'd like to plot a line of best fit on the scatterplot in order to calculate its gradient, and use that gradient as an average for the values. I'd like the result to be output to a text box when a button is pressed. Help would be greatly appreciated.

package Grava;

import javafx.application.Application;
import javafx.geometry.Pos;
import javafx.scene.Scene;
import javafx.scene.control.Button;
import javafx.scene.control.TextField;
import javafx.scene.control.*;
import javafx.scene.image.Image;
import javafx.scene.layout.BorderPane;
import javafx.scene.layout.HBox;
import javafx.stage.Stage;
import org.jfree.chart.ChartFactory;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.fx.ChartViewer;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.statistics.Regression;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

public class ScatterAdd extends Application {


    private final XYSeries series = new XYSeries("Voltage");
    private final XYSeries trend = new XYSeries("Trend");
    private final XYSeriesCollection dataset = new XYSeriesCollection(series);

    ChoiceBox<String> domainLabels = new ChoiceBox<>();
    ChoiceBox<String> rangeLabels = new ChoiceBox<>();

    private JFreeChart createChart() throws Exception {
        /*
        JFreeChart chart = createChart();
        domainLabels.getSelectionModel().selectedItemProperty().addListener((ov, s0, s1) -> {
            chart.getXYPlot().getDomainAxis().setLabel(s1);
        });
        rangeLabels.getSelectionModel().selectedItemProperty().addListener((ov, s0, s1) -> {
            chart.getXYPlot().getRangeAxis().setLabel(s1);
        });
        XYPlot plot = createChart().getXYPlot();
        XYLineAndShapeRenderer r = (XYLineAndShapeRenderer) plot.getRenderer();
        r.setSeriesLinesVisible(1, Boolean.TRUE);
        r.setSeriesShapesVisible(1, Boolean.FALSE);
        This part gives me an error when I create it in the createChart() method.
        */
         
        return ChartFactory.createScatterPlot("VI Characteristics", "Current", "Voltage", dataset);
        
    }

    @Override
    public void start(Stage stage) throws Exception {

        Image image = new Image("Grava.logo.png");
        stage.getIcons().add(image);

        var equation = new TextField();



        series.addChangeListener((event) -> {
            double[] coefficients = Regression.getOLSRegression(dataset, 0);
            double b = coefficients[0]; // intercept
            double m = coefficients[1]; // slope
            double x = series.getDataItem(0).getXValue();
            trend.add(x, m * x + b);
            x = series.getDataItem(series.getItemCount() - 1).getXValue();
            trend.add(x, m * x + b);
            dataset.addSeries(trend);
            equation.setText("y = " + m + " x + " + b);

        });




        domainLabels.getItems().addAll("Current", "Seconds");
        domainLabels.setValue("Current");

        rangeLabels.getItems().addAll("Voltage", "Metres");
        rangeLabels.setValue("Voltage");

        var xSpin = new Spinner<Double>(-10000000.000, 10000000.000, 0, 0.1);
        xSpin.setEditable(true);
        xSpin.setPromptText("Xvalue");

        var ySpin = new Spinner<Double>(-10000000.000, 10000000.000, 0, 0.1);
        ySpin.setEditable(true);
        ySpin.setPromptText("Yvalue");

        var button = new Button("Add");
        button.setOnAction(ae -> series.add(xSpin.getValue(), ySpin.getValue()));


        HBox xBox = new HBox();
        xBox.getChildren().addAll(domainLabels);

        HBox yBox = new HBox();
        yBox.getChildren().addAll(rangeLabels);

        var enter = new ToolBar(xBox, xSpin, yBox, ySpin, button, equation);
        BorderPane.setAlignment(enter, Pos.CENTER);

        BorderPane root = new BorderPane();
        root.setCenter(new ChartViewer(createChart()));
        root.setBottom(enter);

        stage.setTitle("ScatterAdd");
        stage.setScene(new Scene(root, 720, 480));
        stage.show();

    }

    public static void main(String[] args) {
        launch(args);
    }
}

Answer by trashgod

As shown here, call Regression.getOLSRegression() and display the result. You can do this in your Add button handler, as suggested in your question, or in a SeriesChangeListener, as shown below. This fragment simply prints the equation of the line, but you can update the display as desired.

private final XYSeries series = new XYSeries("Voltage");
private final XYSeries trend = new XYSeries("Trend");
private final XYSeriesCollection dataset = new XYSeriesCollection(series);
…
// https://stackoverflow.com/a/61398612/230513
series.addChangeListener((event) -> {
    double[] coefficients = Regression.getOLSRegression(dataset, 0);
    double b = coefficients[0]; // intercept
    double m = coefficients[1]; // slope
    System.out.println("y = " + m + " x + " + b);
});
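For reference, the two coefficients come from an ordinary least-squares (OLS) fit. The following is a plain-Java sketch of such a fit, returning {intercept, slope} in the same order as coefficients in the fragment above. It does not use JFreeChart, and the names OlsSketch and fit are hypothetical. It should agree with Regression.getOLSRegression() up to floating-point rounding, although the library's internal formula may be arranged differently.

```java
/**
 * Plain-Java sketch of an ordinary least-squares line fit.
 * Hypothetical helper, not part of JFreeChart.
 */
public class OlsSketch {

    // Fits y = m*x + b to the points (xs[i], ys[i]); returns {b, m}.
    static double[] fit(double[] xs, double[] ys) {
        int n = xs.length;
        if (n != ys.length) {
            throw new IllegalArgumentException("xs and ys differ in length");
        }
        if (n < 2) {
            throw new IllegalArgumentException("need at least two points");
        }
        double sumX = 0, sumY = 0;
        for (int i = 0; i < n; i++) {
            sumX += xs[i];
            sumY += ys[i];
        }
        double meanX = sumX / n;
        double meanY = sumY / n;
        // Centered sums: Sxx = sum((x - meanX)^2), Sxy = sum((x - meanX)(y - meanY))
        double sxx = 0, sxy = 0;
        for (int i = 0; i < n; i++) {
            double dx = xs[i] - meanX;
            sxx += dx * dx;
            sxy += dx * (ys[i] - meanY);
        }
        if (sxx == 0) {
            throw new IllegalArgumentException("all x values are equal; slope is undefined");
        }
        double m = sxy / sxx;          // slope (gradient)
        double b = meanY - m * meanX;  // intercept
        return new double[] {b, m};
    }

    public static void main(String[] args) {
        // Points lying exactly on y = 2x + 1
        double[] c = fit(new double[] {0, 1, 2, 3}, new double[] {1, 3, 5, 7});
        System.out.println("y = " + c[1] + " x + " + c[0]); // prints "y = 2.0 x + 1.0"
    }
}
```

The gradient the question asks for is c[1], which is the same quantity as m in the listener.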

Note that the series will need at least two values before a line can be fitted, and the trend series should be added to the collection only once. In particular,

  • Don't recreate the dataset in createChart()

  • Don't recreate the choice boxes in start()

  • Add a TextField to the toolbar and invoke setText() in the change listener; the displayed equation will update as points are added.

  • Create an XYSeries named trend and add it to the dataset once, then update that series in the change listener.

      var equation = new TextField();
      series.addChangeListener((event) -> {
          …
          equation.setText("y = " + m + "x + " + b);
          double x = series.getDataItem(0).getXValue();
          trend.clear();
          trend.add(x, m * x + b);
          x = series.getDataItem(series.getItemCount() - 1).getXValue();
          trend.add(x, m * x + b);
      });
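The steps in the list above can be modelled without JavaFX or JFreeChart, which makes the listener's logic easy to check in isolation. This is a hypothetical stand-in, not the answer's code: plain lists take the place of series and trend, a String takes the place of the TextField, and the early return for fewer than two points is an added guard. It finds the minimum and maximum x explicitly; with JFreeChart, getDataItem(0) and the last item give the same x values, because an XYSeries created with the single-argument constructor keeps its items sorted by x.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical plain-Java model of the change listener:
 * refit after each added point and rebuild a two-point trend line.
 */
public class TrendSketch {

    private final List<double[]> points = new ArrayList<>(); // stands in for "series"
    final List<double[]> trend = new ArrayList<>();          // stands in for "trend"
    String equation = "";                                     // stands in for the TextField

    void add(double x, double y) {
        points.add(new double[] {x, y});
        onChange(); // like the SeriesChangeListener firing after series.add()
    }

    private void onChange() {
        int n = points.size();
        if (n < 2) {
            return; // a line needs at least two points
        }
        double sumX = 0, sumY = 0;
        for (double[] p : points) { sumX += p[0]; sumY += p[1]; }
        double meanX = sumX / n, meanY = sumY / n;
        double sxx = 0, sxy = 0;
        for (double[] p : points) {
            double dx = p[0] - meanX;
            sxx += dx * dx;
            sxy += dx * (p[1] - meanY);
        }
        if (sxx == 0) {
            return; // all x values equal: vertical data, no finite gradient
        }
        double m = sxy / sxx;          // slope
        double b = meanY - m * meanX;  // intercept
        equation = "y = " + m + "x + " + b;

        double minX = Double.POSITIVE_INFINITY, maxX = Double.NEGATIVE_INFINITY;
        for (double[] p : points) {
            minX = Math.min(minX, p[0]);
            maxX = Math.max(maxX, p[0]);
        }
        trend.clear(); // replace the old endpoints instead of appending
        trend.add(new double[] {minX, m * minX + b});
        trend.add(new double[] {maxX, m * maxX + b});
    }

    public static void main(String[] args) {
        TrendSketch s = new TrendSketch();
        s.add(3, 7); // one point: no fit yet
        s.add(1, 3);
        s.add(2, 5);
        System.out.println(s.equation + ", trend points: " + s.trend.size());
        // prints "y = 2.0x + 1.0, trend points: 2"
    }
}
```

Clearing and refilling trend on every change keeps it at exactly two points, which is why the trend series only ever needs to be added to the collection once.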
    
