tree_mod <- decision_tree(engine = "rpart") |>
  set_mode("classification")

tree_wf <- workflow() |>
  add_formula(children ~ .) |>
  add_model(tree_mod)
Tune better models to predict children in hotel bookings
Suggested answers
Your Turn 1
Fill in the blanks to return the accuracy and ROC AUC for this model using 10-fold cross-validation.
set.seed(100)
______ |>
  ______(resamples = hotels_folds) |>
  ______
Answer:
set.seed(100)
tree_wf |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.773 10 0.00567 Preprocessor1_Model1
2 roc_auc binary 0.832 10 0.00672 Preprocessor1_Model1
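The mean, n, and std_err columns above summarize one metric value per fold. The sketch below redoes that bookkeeping by hand in base R. It assumes the built-in mtcars data and a made-up one-variable threshold rule as stand-ins (neither is part of this workshop), and it computes std_err as the standard deviation divided by sqrt(n), the usual standard error of a mean. It illustrates the idea only, not what fit_resamples() does internally.

```r
# Illustration only: mtcars and the threshold rule are stand-ins,
# not the hotels data or the rpart decision tree.
set.seed(100)
k <- 10

# Randomly assign each row to one of k folds
fold_id <- sample(rep(seq_len(k), length.out = nrow(mtcars)))

fold_accuracy <- vapply(seq_len(k), function(i) {
  analysis   <- mtcars[fold_id != i, ]  # rows used to "fit"
  assessment <- mtcars[fold_id == i, ]  # held-out rows used to score

  # "Model": midpoint between the mean weights of the two transmission classes
  cut <- mean(c(mean(analysis$wt[analysis$am == 1]),
                mean(analysis$wt[analysis$am == 0])))

  # Predict manual transmission (am == 1) for cars lighter than the cut
  predicted <- as.integer(assessment$wt < cut)
  mean(predicted == assessment$am)
}, numeric(1))

cv_summary <- data.frame(
  .metric = "accuracy",
  mean    = mean(fold_accuracy),
  n       = length(fold_accuracy),
  std_err = sd(fold_accuracy) / sqrt(length(fold_accuracy))
)
cv_summary
```

Each fold is held out exactly once, so every row contributes to exactly one of the ten accuracy values being averaged.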
Your Turn 2
Create a new parsnip model called rf_mod, which will learn an ensemble of classification trees from our training data using the ranger package. Update your tree_wf with this new model.
Fit your workflow with 10-fold cross-validation and compare the ROC AUC of the random forest to that of your single decision tree model. Which one predicts the held-out folds better?
Hint: you’ll need https://www.tidymodels.org/find/parsnip/
# model
rf_mod <- _____ |>
  _____("ranger") |>
  _____("classification")

# workflow
rf_wf <- tree_wf |>
  update_model(_____)

# fit with cross-validation
set.seed(100)
_____ |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
Answer:
# model
rf_mod <- rand_forest(engine = "ranger") |>
  set_mode("classification")

# workflow
rf_wf <- tree_wf |>
  update_model(rf_mod)

# fit with cross-validation
set.seed(100)
rf_wf |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.829 10 0.00332 Preprocessor1_Model1
2 roc_auc binary 0.912 10 0.00320 Preprocessor1_Model1
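The exercise template has three blanks, while the answer passes the engine as an argument. The sketch below shows both spellings side by side. The names rf_a and rf_b are hypothetical and used only for this comparison. Building a specification does not fit anything, so no data is needed.

```r
library(parsnip)

# Engine passed as an argument (the form used in the answer above)
rf_a <- rand_forest(engine = "ranger") |>
  set_mode("classification")

# Engine set in its own step (the form the three blanks suggest)
rf_b <- rand_forest() |>
  set_engine("ranger") |>
  set_mode("classification")

# Both specifications name the same engine and mode
c(rf_a$engine, rf_b$engine)
c(rf_a$mode, rf_b$mode)
```

Either form should work with update_model(); which one to use is mostly a matter of taste.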
Your Turn 3
Challenge: Fit 3 more random forest models, using 3, 5, and 8 variables at each split, respectively. Update your rf_wf with each new model. Which value maximizes the area under the ROC curve?
rf3_mod <- rf_mod |>
  set_args(mtry = 3)

rf5_mod <- rf_mod |>
  set_args(mtry = 5)

rf8_mod <- rf_mod |>
  set_args(mtry = 8)
Do this for each model above:
_____ <- rf_wf |>
  update_model(_____)

set.seed(100)
_____ |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
Answer:
# 3
rf3_wf <- rf_wf |>
  update_model(rf3_mod)

set.seed(100)
rf3_wf |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.831 10 0.00307 Preprocessor1_Model1
2 roc_auc binary 0.910 10 0.00307 Preprocessor1_Model1
# 5
rf5_wf <- rf_wf |>
  update_model(rf5_mod)

set.seed(100)
rf5_wf |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.829 10 0.00376 Preprocessor1_Model1
2 roc_auc binary 0.912 10 0.00305 Preprocessor1_Model1
# 8
rf8_wf <- rf_wf |>
  update_model(rf8_mod)

set.seed(100)
rf8_wf |>
  fit_resamples(resamples = hotels_folds) |>
  collect_metrics()
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.828 10 0.00378 Preprocessor1_Model1
2 roc_auc binary 0.909 10 0.00362 Preprocessor1_Model1
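To answer "which value maximizes the area under the ROC curve?", the three roc_auc rows can be lined up in one table. The sketch below copies the means and standard errors from the outputs above into a small data frame. The names mtry_results and best are hypothetical.

```r
# ROC AUC means and standard errors copied from the three outputs above
mtry_results <- data.frame(
  mtry    = c(3, 5, 8),
  roc_auc = c(0.910, 0.912, 0.909),
  std_err = c(0.00307, 0.00305, 0.00362)
)

# Row with the largest cross-validated ROC AUC
best <- mtry_results[which.max(mtry_results$roc_auc), ]
best$mtry
# 5

# Gap between the best mean and each candidate,
# in units of the best model's standard error
round((best$roc_auc - mtry_results$roc_auc) / best$std_err, 2)
```

mtry = 5 has the highest mean, but the other two values sit within about one standard error of it, so the ranking is not clear-cut.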
Your Turn 4
Edit the random forest model to tune the mtry and min_n hyper-parameters; call the new model spec rf_tuner.
Update your workflow to use the tuned model.
Then use tune_grid() to find the best combination of hyper-parameters to maximize roc_auc; let tune set up the grid for you.
How does it compare to the average ROC AUC across folds from fit_resamples()?
rf_mod <- rand_forest(engine = "ranger") |>
  set_mode("classification")

rf_wf <- workflow() |>
  add_formula(children ~ .) |>
  add_model(rf_mod)

set.seed(100) # Important!
rf_results <- rf_wf |>
  fit_resamples(resamples = hotels_folds,
                metrics = metric_set(roc_auc),
                # change me to control_grid() with tune_grid
                control = control_resamples(save_workflow = TRUE))

rf_results |>
  collect_metrics()
# A tibble: 1 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 roc_auc binary 0.912 10 0.00320 Preprocessor1_Model1
Answer:
rf_tuner <- rand_forest(
  engine = "ranger",
  mtry = tune(),
  min_n = tune()
) |>
  set_mode("classification")

rf_wf <- rf_wf |>
  update_model(rf_tuner)

set.seed(100) # Important!
rf_results <- rf_wf |>
  tune_grid(resamples = hotels_folds,
            control = control_grid(save_workflow = TRUE))
i Creating pre-processing data to finalize unknown parameter: mtry
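The message above appears because the upper limit of mtry depends on how many predictors the data has, so it cannot be known until the data is seen. The sketch below shows the same idea in miniature with the dials package. It assumes the built-in mtcars data with mpg as a hypothetical outcome, and it does not claim to reproduce the exact call tune_grid() makes.

```r
library(dials)

# mtry() starts out with an unknown upper bound
mtry()

# Hypothetical predictor set: every mtcars column except the outcome `mpg`
predictors <- mtcars[, names(mtcars) != "mpg"]

# finalize() fills in the upper bound from the number of predictor columns
mtry_final <- finalize(mtry(), predictors)
range_get(mtry_final)
```

Once the range is complete, a grid of candidate mtry values can be drawn from it.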
Your Turn 5
Use fit_best() to take the best combination of hyper-parameters from rf_results and use them to predict the test set.
How does our actual test ROC AUC compare to our cross-validated estimate?
hotels_best <- fit_best(rf_results)

# cross validated ROC AUC
rf_results |>
  show_best(metric = "roc_auc", n = 5)
# A tibble: 5 × 8
mtry min_n .metric .estimator mean n std_err .config
<int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
1 3 15 roc_auc binary 0.910 10 0.00283 Preprocessor1_Model07
2 8 20 roc_auc binary 0.909 10 0.00376 Preprocessor1_Model10
3 7 36 roc_auc binary 0.908 10 0.00372 Preprocessor1_Model02
4 9 28 roc_auc binary 0.907 10 0.00381 Preprocessor1_Model01
5 12 21 roc_auc binary 0.907 10 0.00430 Preprocessor1_Model03
# test set ROC AUC
bind_cols(
  hotels_test,
  predict(hotels_best, new_data = hotels_test, type = "prob")
) |>
  roc_auc(truth = children, .pred_children)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.913
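It can help to see what the roc_auc number measures. ROC AUC is the probability that a randomly chosen event receives a higher predicted probability than a randomly chosen non-event. The base R sketch below computes it from ranks on six made-up observations. The function auc_by_rank() and the toy vectors are hypothetical. Like yardstick's default, it treats the first factor level as the event.

```r
# Toy data, made up for illustration: three events and three non-events
truth <- factor(
  c("children", "children", "children", "none", "none", "none"),
  levels = c("children", "none")
)
pred_children <- c(0.9, 0.8, 0.4, 0.7, 0.3, 0.2)

# Rank-based (Mann-Whitney) form of ROC AUC; tied scores get average ranks
auc_by_rank <- function(truth, estimate, event_level = levels(truth)[1]) {
  is_event <- truth == event_level
  n_event  <- sum(is_event)
  n_other  <- sum(!is_event)
  (sum(rank(estimate)[is_event]) - n_event * (n_event + 1) / 2) /
    (n_event * n_other)
}

auc_by_rank(truth, pred_children)
```

Of the nine possible event and non-event pairs in the toy data, eight are ordered correctly, so the result is 8/9.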
# test set ROC curve
bind_cols(
  hotels_test,
  predict(hotels_best, new_data = hotels_test, type = "prob")
) |>
  roc_curve(truth = children, .pred_children) |>
  autoplot()
Acknowledgments
- Materials derived from Tidymodels, Virtually: An Introduction to Machine Learning with Tidymodels by Alison Hill.
- Dataset and some modeling steps derived from A predictive modeling case study and licensed under a Creative Commons Attribution-ShareAlike 4.0 International (CC BY-SA) License.
sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.3.1 (2023-06-16)
os macOS Ventura 13.5.2
system aarch64, darwin20
ui X11
language (EN)
collate en_US.UTF-8
ctype en_US.UTF-8
tz America/New_York
date 2023-11-10
pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
─ Packages ───────────────────────────────────────────────────────────────────
package * version date (UTC) lib source
backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.0)
bit 4.0.5 2022-11-15 [1] CRAN (R 4.3.0)
bit64 4.0.5 2020-08-30 [1] CRAN (R 4.3.0)
broom * 1.0.5 2023-06-09 [1] CRAN (R 4.3.0)
class 7.3-22 2023-05-03 [1] CRAN (R 4.3.0)
cli 3.6.1 2023-03-23 [1] CRAN (R 4.3.0)
codetools 0.2-19 2023-02-01 [1] CRAN (R 4.3.0)
colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.0)
crayon 1.5.2 2022-09-29 [1] CRAN (R 4.3.0)
data.table 1.14.8 2023-02-17 [1] CRAN (R 4.3.0)
dials * 1.2.0 2023-04-03 [1] CRAN (R 4.3.0)
DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.3.0)
digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.0)
dplyr * 1.1.3 2023-09-03 [1] CRAN (R 4.3.0)
evaluate 0.22 2023-09-29 [1] CRAN (R 4.3.1)
fansi 1.0.5 2023-10-08 [1] CRAN (R 4.3.1)
farver 2.1.1 2022-07-06 [1] CRAN (R 4.3.0)
fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.0)
forcats * 1.0.0 2023-01-29 [1] CRAN (R 4.3.0)
foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.0)
furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.0)
future 1.32.0 2023-03-07 [1] CRAN (R 4.3.0)
future.apply 1.11.0 2023-05-21 [1] CRAN (R 4.3.0)
generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.0)
ggplot2 * 3.4.2 2023-04-03 [1] CRAN (R 4.3.0)
globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.0)
glue 1.6.2 2022-02-24 [1] CRAN (R 4.3.0)
gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.0)
gtable 0.3.3 2023-03-21 [1] CRAN (R 4.3.0)
hardhat 1.3.0 2023-03-30 [1] CRAN (R 4.3.0)
here 1.0.1 2020-12-13 [1] CRAN (R 4.3.0)
hms 1.1.3 2023-03-21 [1] CRAN (R 4.3.0)
htmltools 0.5.6.1 2023-10-06 [1] CRAN (R 4.3.1)
htmlwidgets 1.6.2 2023-03-17 [1] CRAN (R 4.3.0)
infer * 1.0.4 2022-12-02 [1] CRAN (R 4.3.0)
ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.0)
iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.0)
jsonlite 1.8.7 2023-06-29 [1] CRAN (R 4.3.0)
knitr 1.44 2023-09-11 [1] CRAN (R 4.3.0)
labeling 0.4.2 2020-10-20 [1] CRAN (R 4.3.0)
lattice 0.21-8 2023-04-05 [1] CRAN (R 4.3.0)
lava 1.7.2.1 2023-02-27 [1] CRAN (R 4.3.0)
lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.0)
lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.3.0)
listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.0)
lubridate * 1.9.3 2023-09-27 [1] CRAN (R 4.3.1)
magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
MASS 7.3-60 2023-05-04 [1] CRAN (R 4.3.0)
Matrix 1.5-4.1 2023-05-18 [1] CRAN (R 4.3.0)
modeldata * 1.1.0 2023-01-25 [1] CRAN (R 4.3.0)
modelenv 0.1.1 2023-03-08 [1] CRAN (R 4.3.0)
munsell 0.5.0 2018-06-12 [1] CRAN (R 4.3.0)
nnet 7.3-19 2023-05-03 [1] CRAN (R 4.3.0)
parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.0)
parsnip * 1.1.0 2023-04-12 [1] CRAN (R 4.3.0)
pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.0)
pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
prodlim 2023.03.31 2023-04-02 [1] CRAN (R 4.3.0)
purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
ranger * 0.15.1 2023-04-03 [1] CRAN (R 4.3.0)
Rcpp 1.0.10 2023-01-22 [1] CRAN (R 4.3.0)
readr * 2.1.4 2023-02-10 [1] CRAN (R 4.3.0)
recipes * 1.0.6 2023-04-25 [1] CRAN (R 4.3.0)
rlang 1.1.1 2023-04-28 [1] CRAN (R 4.3.0)
rmarkdown 2.25 2023-09-18 [1] CRAN (R 4.3.1)
rpart * 4.1.19 2022-10-21 [1] CRAN (R 4.3.0)
rprojroot 2.0.3 2022-04-02 [1] CRAN (R 4.3.0)
rsample * 1.1.1 2022-12-07 [1] CRAN (R 4.3.0)
rstudioapi 0.14 2022-08-22 [1] CRAN (R 4.3.0)
scales * 1.2.1 2022-08-20 [1] CRAN (R 4.3.0)
sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.0)
stringi 1.7.12 2023-01-11 [1] CRAN (R 4.3.0)
stringr * 1.5.0 2022-12-02 [1] CRAN (R 4.3.0)
survival 3.5-5 2023-03-12 [1] CRAN (R 4.3.0)
tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.0)
tidymodels * 1.1.0 2023-05-01 [1] CRAN (R 4.3.0)
tidyr * 1.3.0 2023-01-24 [1] CRAN (R 4.3.0)
tidyselect 1.2.0 2022-10-10 [1] CRAN (R 4.3.0)
tidyverse * 2.0.0 2023-02-22 [1] CRAN (R 4.3.0)
timechange 0.2.0 2023-01-11 [1] CRAN (R 4.3.0)
timeDate 4022.108 2023-01-07 [1] CRAN (R 4.3.0)
tune * 1.1.1 2023-04-11 [1] CRAN (R 4.3.0)
tzdb 0.4.0 2023-05-12 [1] CRAN (R 4.3.0)
utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.1)
vctrs 0.6.4 2023-10-12 [1] CRAN (R 4.3.1)
vroom 1.6.3 2023-04-28 [1] CRAN (R 4.3.0)
withr 2.5.2 2023-10-30 [1] CRAN (R 4.3.1)
workflows * 1.1.3 2023-02-22 [1] CRAN (R 4.3.0)
workflowsets * 1.0.1 2023-04-06 [1] CRAN (R 4.3.0)
xfun 0.40 2023-08-09 [1] CRAN (R 4.3.0)
yaml 2.3.7 2023-01-23 [1] CRAN (R 4.3.0)
yardstick * 1.2.0 2023-04-21 [1] CRAN (R 4.3.0)
[1] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
──────────────────────────────────────────────────────────────────────────────