Node link diagram in R using Rpart.plot and rattle

I am trying to create a node-link diagram (decision tree) using parsnip and tidymodels. I am building a decision tree model for the StackOverflow dataset using the tidymodels package with rpart as the model engine. The model should predict whether a developer will work remotely (variable remote) based on the number of years of programming experience (years_coded_job), degree of career satisfaction (career_satisfaction), whether the job title is "Data Scientist" (data_scientist), and the size of the employing company (company_size_number).

My pipeline

library(tidyverse)
library(tidymodels)
library(rpart.plot)
library(rpart)
library(rattle)

so <- read_rds(here::here("stackoverflow.rds"))

fit <- rpart(remote ~ years_coded_job + career_satisfaction + data_scientist + company_size_number,
             data = so,
             control = rpart.control(minsplit = 20, minbucket = 2))

fancyRpartPlot(fit, sub = "")

The plot I obtain

[image: the decision tree plot produced by fancyRpartPlot()]

I want to know whether this is the correct approach for determining the predictors. Since I am not actually building a model, is this the right way?

Best answer

If you are going to use tidymodels and parsnip to fit your model, it's better to use that actual fitted model for any visualizations like this. You can get the underlying engine object from a parsnip model using $fit.

library(tidyverse)
library(tidymodels)
library(rattle)
#> Loading required package: bitops
#> Rattle: A free graphical interface for data science with R.
#> Version 5.4.0 Copyright (c) 2006-2020 Togaware Pty Ltd.
#> Type 'rattle()' to shake, rattle, and roll your data.
data(kyphosis, package = "rpart")

tree_fit <- decision_tree(min_n = 20) %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Kyphosis ~ Age + Number + Start,
      data = kyphosis)

fancyRpartPlot(tree_fit$fit, sub = "")

Created on 2021-05-25 by the reprex package (v2.0.0)
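Since the question also asks about determining the predictors, here is a minimal sketch of inspecting the same engine object directly. It assumes the kyphosis example above, and that the variable.importance element rpart stores on its fitted objects is the kind of summary you are after. The name engine_fit is only for illustration.

```r
library(tidymodels)

data(kyphosis, package = "rpart")

# Same parsnip model as in the example above
tree_fit <- decision_tree(min_n = 20) %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Kyphosis ~ Age + Number + Start,
      data = kyphosis)

# The underlying engine object is an ordinary rpart fit
engine_fit <- tree_fit$fit
class(engine_fit)

# rpart records an importance score for each predictor used in the tree
# (as a primary or surrogate split)
engine_fit$variable.importance
```

Anything that works on a plain rpart object should work on tree_fit$fit in the same way.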

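For some kinds of visualizations, the plotting function needs to look up the model's original call and data. In those cases you will need to apply parsnip's repair_call() to the fitted model first.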
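A minimal sketch of that, assuming the same kyphosis model and using rpart.plot() as an example of a plotting function that reads the stored call. The name tree_fit_repaired is only for illustration.

```r
library(tidymodels)
library(rpart.plot)

data(kyphosis, package = "rpart")

tree_fit <- decision_tree(min_n = 20) %>%
  set_engine("rpart") %>%
  set_mode("classification") %>%
  fit(Kyphosis ~ Age + Number + Start,
      data = kyphosis)

# repair_call() rewrites the call stored in the engine object so that
# its data argument refers to the data set passed here
tree_fit_repaired <- repair_call(tree_fit, data = kyphosis)

# Plot the repaired engine object
rpart.plot(tree_fit_repaired$fit)
```

The data passed to repair_call() should be the same data that was used in fit().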