library(tidyverse)
library(tidymodels)
library(stringr)
library(textrecipes)
library(themis)
library(vip)
# Fix the random seed so the random steps in this script are reproducible
# when it is re-run.
set.seed(123)
# Make ggplot2's minimal theme (13 pt base font) the default plot theme.
theme_set(theme_minimal(base_size = 13))
Slay: Predicting song artist based on lyrics
Suggested answers
Import data
# Read the lyrics data and store the outcome (`artist`) as a factor, which
# classification models require.
lyrics <- read_csv(file = "data/beyonce-swift-lyrics.csv") |>
  mutate(artist = factor(artist))

# Print the tibble to inspect it.
lyrics
# A tibble: 309 × 19
album_name track_number track_name artist lyrics danceability energy loudness
<chr> <dbl> <chr> <fct> <chr> <dbl> <dbl> <dbl>
1 RENAISSAN… 1 I'M THAT … Beyon… "Plea… 0.554 0.535 -8.96
2 RENAISSAN… 2 COZY Beyon… "This… 0.556 0.63 -8.15
3 RENAISSAN… 3 ALIEN SUP… Beyon… "Plea… 0.545 0.641 -6.40
4 RENAISSAN… 4 CUFF IT Beyon… "I fe… 0.78 0.689 -5.67
5 RENAISSAN… 5 ENERGY (f… Beyon… "On s… 0.903 0.519 -9.15
6 RENAISSAN… 6 BREAK MY … Beyon… "I'm … 0.693 0.887 -5.04
7 RENAISSAN… 7 CHURCH GI… Beyon… "(Lor… 0.792 0.919 -5.69
8 RENAISSAN… 8 PLASTIC O… Beyon… "Boy,… 0.618 0.712 -8.25
9 RENAISSAN… 9 VIRGO'S G… Beyon… "Baby… 0.683 0.85 -5.04
10 RENAISSAN… 10 MOVE (fea… Beyon… "Move… 0.876 0.628 -6.60
# ℹ 299 more rows
# ℹ 11 more variables: speechiness <dbl>, acousticness <dbl>,
# instrumentalness <dbl>, liveness <dbl>, valence <dbl>, tempo <dbl>,
# time_signature <dbl>, duration_ms <dbl>, explicit <lgl>, key_name <chr>,
# mode_name <chr>
Split the data into analysis/assessment/test sets
Your turn:
- Split the data into training/test sets with 75% allocated for training
- Split the training set into 10 cross-validation folds
# split into training/testing
# 75% of songs go to training; stratify on artist so both sets keep the
# same class balance.
set.seed(123)
lyrics_split <- initial_split(data = lyrics, strata = artist, prop = 0.75)

lyrics_train <- training(lyrics_split)
lyrics_test <- testing(lyrics_split)

# create cross-validation folds
# 10 folds on the training set only, again stratified on artist.
lyrics_folds <- vfold_cv(data = lyrics_train, v = 10, strata = artist)
Estimate the null model for a baseline comparison
Your turn: Estimate a null model to determine an appropriate baseline for evaluating a model’s performance.
# Null model specification: a baseline classifier that ignores the predictors.
null_spec <- null_model() |>
  set_engine("parsnip") |>
  set_mode("classification")

# Estimate baseline performance across the cross-validation folds.
null_spec |>
  fit_resamples(
    artist ~ .,
    resamples = lyrics_folds
  ) |>
  collect_metrics()
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.662 10 0.00385 Preprocessor1_Model1
2 roc_auc binary 0.5 10 0 Preprocessor1_Model1
Fit a random forest model
Define the feature engineering recipe
Demonstration:
- Define a feature engineering recipe to predict the song’s artist as a function of the lyrics + audio features
- Exclude the ID variables from the recipe
- Tokenize the song lyrics
- Remove stop words
- Only keep the 500 most frequently appearing tokens
- Calculate tf-idf scores for the remaining tokens
This will generate one column for every token. Each column will have the standardized name `tfidf_lyrics_*`, where `*` is the specific token. Instead we would prefer the column names simply be `*`. You can remove the `tfidf_lyrics_` prefix using:

    # Simplify these names
    step_rename_at(starts_with("tfidf_lyrics_"),
      fn = \(x) str_replace_all(
        string = x,
        pattern = "tfidf_lyrics_",
        replacement = ""
      )
    )
This does cause a conflict between the `energy` audio feature and the token `energy`. We will add a prefix to the audio features to avoid this conflict:

    # Simplify these names
    step_rename_at(
      all_predictors(), -starts_with("tfidf_lyrics_"),
      fn = \(x) str_glue("af_{x}")
    )
- Downsample the observations so there are an equal number of songs by Beyoncé and Taylor Swift in the analysis set
# define preprocessing recipe
lyrics_rec <- recipe(artist ~ ., data = lyrics_train) |>
  # exclude ID variables
  update_role(album_name, track_number, track_name, new_role = "id vars") |>
  # tokenize the lyrics, drop stop words, keep the 500 most frequent tokens,
  # and convert them to tf-idf scores
  step_tokenize(lyrics) |>
  step_stopwords(lyrics) |>
  step_tokenfilter(lyrics, max_tokens = 500) |>
  step_tfidf(lyrics) |>
  # Simplify these names
  # Prefix the non-token predictors with "af_" first, so that stripping the
  # tf-idf prefix below cannot create duplicate names (e.g. `energy`).
  step_rename_at(
    all_predictors(), -starts_with("tfidf_lyrics_"),
    fn = \(x) str_glue("af_{x}")
  ) |>
  step_rename_at(starts_with("tfidf_lyrics_"),
    fn = \(x) str_replace_all(
      string = x,
      pattern = "tfidf_lyrics_",
      replacement = ""
    )
  ) |>
  # balance the two artists
  step_downsample(artist)

lyrics_rec
Fit the model
Demonstration:
- Define a random forest model grown with 1000 trees using the `ranger` engine.
- Define a workflow using the feature engineering recipe and random forest model specification. Fit the workflow using the cross-validation folds.
  - Use `control = control_resamples(save_pred = TRUE)` to save the assessment set predictions. We need these to assess the model’s performance.
# define the model specification
ranger_spec <- rand_forest(trees = 1000) |>
  set_mode("classification") |>
  # calculate feature importance metrics using the ranger engine
  set_engine("ranger", importance = "permutation")

# define the workflow
ranger_workflow <- workflow() |>
  add_recipe(lyrics_rec) |>
  add_model(ranger_spec)

# fit the model to each of the cross-validation folds
# save_pred keeps the assessment-set predictions for the ROC curve and
# confusion matrix; save_workflow stores the workflow in the result.
ranger_cv <- ranger_workflow |>
  fit_resamples(
    resamples = lyrics_folds,
    control = control_resamples(save_pred = TRUE, save_workflow = TRUE)
  )
Evaluate model performance
Demonstration:
- Calculate the model’s accuracy and ROC AUC. How did it perform?
- Draw the ROC curve for each validation fold
- Generate the resampled confusion matrix for the model and draw it using a heatmap. How does the model perform predicting Beyoncé songs relative to Taylor Swift songs?
# extract metrics and predictions
ranger_cv_metrics <- collect_metrics(ranger_cv)
ranger_cv_predictions <- collect_predictions(ranger_cv)

# how well did the model perform?
ranger_cv_metrics
# A tibble: 2 × 6
.metric .estimator mean n std_err .config
<chr> <chr> <dbl> <int> <dbl> <chr>
1 accuracy binary 0.858 10 0.0250 Preprocessor1_Model1
2 roc_auc binary 0.929 10 0.0180 Preprocessor1_Model1
# roc curve
# Group by resample id so that one curve is drawn per fold.
ranger_cv_predictions |>
  group_by(id) |>
  roc_curve(truth = artist, .pred_Beyoncé) |>
  autoplot()

# confusion matrix
# tidy = FALSE returns a conf_mat object, which autoplot() can draw as a
# heatmap.
conf_mat_resampled(x = ranger_cv, tidy = FALSE) |>
  autoplot(type = "heatmap")
Penalized regression
Define the feature engineering recipe
Demonstration:
- Define a feature engineering recipe to predict the song’s artist as a function of the lyrics + audio features
- Exclude the ID variables from the recipe
- Tokenize the song lyrics
- Calculate all possible 1-grams, 2-grams, 3-grams, 4-grams, and 5-grams
- Remove stop words
- Only keep the 2000 most frequently appearing tokens
- Calculate tf-idf scores for the remaining tokens
- Rename audio feature and tf-idf as before
- Apply required steps for penalized regression models
  - Convert the `explicit` variable to a factor
  - Convert nominal predictors to dummy variables
  - Get rid of zero-variance predictors
  - Normalize all predictors to mean of 0 and variance of 1
- Downsample the observations so there are an equal number of songs by Beyoncé and Taylor Swift in the analysis set
# define preprocessing recipe for the penalized regression model
glmnet_rec <- recipe(artist ~ ., data = lyrics_train) |>
  # exclude ID variables
  update_role(album_name, track_number, track_name, new_role = "id vars") |>
  # tokenize and prep lyrics
  step_tokenize(lyrics) |>
  step_stopwords(lyrics) |>
  # build every 1-gram through 5-gram from the remaining tokens
  step_ngram(lyrics, num_tokens = 5L, min_num_tokens = 1L) |>
  step_tokenfilter(lyrics, max_tokens = 2000) |>
  step_tfidf(lyrics) |>
  # Simplify these names
  # Prefix the non-token predictors with "af_" first, so that stripping the
  # tf-idf prefix below cannot create duplicate names (e.g. `energy`).
  step_rename_at(
    all_predictors(), -starts_with("tfidf_lyrics_"),
    fn = \(x) str_glue("af_{x}")
  ) |>
  step_rename_at(starts_with("tfidf_lyrics_"),
    fn = \(x) str_replace_all(
      string = x,
      pattern = "tfidf_lyrics_",
      replacement = ""
    )
  ) |>
  # fix explicit variable to factor
  step_bin2factor(af_explicit) |>
  # normalize for penalized regression
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors()) |>
  # balance the two artists
  step_downsample(artist)

glmnet_rec
Tune the penalized regression model
Demonstration:
- Define the penalized regression model specification, including tuning placeholders for `penalty` and `mixture`
- Create the workflow object
- Define a tuning grid with every combination of:
penalty = 10^seq(-6, -1, length.out = 20)
mixture = c(0, 0.2, 0.4, 0.6, 0.8, 1)
- Tune the model using the cross-validation folds
- Evaluate the tuning procedure and identify the best performing models based on ROC AUC
# define the penalized regression model specification
# penalty and mixture are left as tune() placeholders, filled in by the grid.
glmnet_spec <- logistic_reg(penalty = tune(), mixture = tune()) |>
  set_mode("classification") |>
  set_engine("glmnet")

# define the new workflow
glmnet_workflow <- workflow() |>
  add_recipe(glmnet_rec) |>
  add_model(glmnet_spec)

# create the tuning grid
# every combination of 20 penalty values and 6 mixture values (120 candidates)
glmnet_grid <- expand_grid(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0, 0.2, 0.4, 0.6, 0.8, 1)
)

# tune over the model hyperparameters
glmnet_tune <- tune_grid(
  object = glmnet_workflow,
  resamples = lyrics_folds,
  grid = glmnet_grid,
  control = control_grid(save_pred = TRUE, save_workflow = TRUE)
)

# evaluate results
collect_metrics(x = glmnet_tune)
# A tibble: 240 × 8
penalty mixture .metric .estimator mean n std_err .config
<dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.000001 0 accuracy binary 0.8 10 0.0235 Preprocessor1_Mod…
2 0.000001 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Mod…
3 0.00000183 0 accuracy binary 0.8 10 0.0235 Preprocessor1_Mod…
4 0.00000183 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Mod…
5 0.00000336 0 accuracy binary 0.8 10 0.0235 Preprocessor1_Mod…
6 0.00000336 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Mod…
7 0.00000616 0 accuracy binary 0.8 10 0.0235 Preprocessor1_Mod…
8 0.00000616 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Mod…
9 0.0000113 0 accuracy binary 0.8 10 0.0235 Preprocessor1_Mod…
10 0.0000113 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Mod…
# ℹ 230 more rows
# plot the resampled metrics for the tuning results
glmnet_tune |>
  autoplot()

# identify the five best hyperparameter combinations
glmnet_tune |>
  show_best(metric = "roc_auc")
# A tibble: 5 × 8
penalty mixture .metric .estimator mean n std_err .config
<dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
1 0.000001 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Model…
2 0.00000183 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Model…
3 0.00000336 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Model…
4 0.00000616 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Model…
5 0.0000113 0 roc_auc binary 0.834 10 0.0428 Preprocessor1_Model…
Fit the best model
Your turn:
- Select the model + hyperparameter combinations that achieve the highest ROC AUC
- Fit that model using the best hyperparameters and the full training set. How well does the model perform on the test set?
# fit the best model (the random forest) to the training set
# fit_best() uses the workflow saved in ranger_cv (save_workflow = TRUE).
rf_best <- fit_best(ranger_cv)

# test set ROC AUC
# Bind the predicted class probabilities to the test set so the true artist
# and the predicted probability sit in the same data frame.
bind_cols(
  lyrics_test,
  predict(rf_best, new_data = lyrics_test, type = "prob")
) |>
  roc_auc(truth = artist, .pred_Beyoncé)
# A tibble: 1 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 roc_auc binary 0.922
Variable importance
We can examine the results of each model to evaluate which tokens were the most important in generating artist predictions. Here we use vip to calculate importance.
# extract parsnip model fit
# vi(method = "model") reads the importance scores stored in the fitted
# model (the spec requested importance = "permutation").
rf_imp <- extract_fit_parsnip(rf_best) |>
  vi(method = "model")

# clean up the data frame for visualization
rf_imp |>
  # extract 20 most important n-grams
  slice_max(order_by = Importance, n = 20) |>
  # order the bars by importance rather than alphabetically
  mutate(Variable = fct_reorder(.f = Variable, .x = Importance)) |>
  ggplot(mapping = aes(
    x = Importance,
    y = Variable
  )) +
  geom_col() +
  labs(
    y = NULL,
    title = "Most relevant features for predicting whether\na song is by Beyoncé or Taylor Swift",
    subtitle = "Random forest model"
  )
# extract parsnip model fit
# Refit the best candidate by ROC AUC, then read the coefficients at that
# candidate's penalty. The same metric is passed to fit_best() and
# select_best() so the model and the lambda always match.
glmnet_imp <- glmnet_tune |>
  fit_best(metric = "roc_auc") |>
  extract_fit_parsnip() |>
  vi(
    method = "model",
    lambda = select_best(x = glmnet_tune, metric = "roc_auc")$penalty
  )

# clean up the data frame for visualization
glmnet_imp |>
  mutate(
    # relabel the coefficient sign with the artist it points towards
    Sign = case_when(
      Sign == "NEG" ~ "More likely from Beyoncé",
      Sign == "POS" ~ "More likely from Taylor Swift"
    ),
    Importance = abs(Importance)
  ) |>
  # importance must be greater than 0
  filter(Importance > 0) |>
  # keep top 20 features for each artist
  slice_max(n = 20, order_by = Importance, by = Sign) |>
  mutate(Variable = fct_reorder(.f = Variable, .x = Importance)) |>
  ggplot(mapping = aes(
    x = Importance,
    y = Variable,
    fill = Sign
  )) +
  geom_col(show.legend = FALSE) +
  scale_fill_brewer(type = "qual") +
  facet_wrap(facets = vars(Sign), scales = "free_y") +
  labs(
    y = NULL,
    title = "Most relevant features for predicting whether\na song is by Beyoncé or Taylor Swift",
    subtitle = "Penalized regression model"
  )
# Record the R version and package versions used to render this document.
sessioninfo::session_info()
─ Session info ───────────────────────────────────────────────────────────────
setting value
version R version 4.3.1 (2023-06-16)
os macOS Ventura 13.5.2
system aarch64, darwin20
ui X11
language (EN)
collate en_US.UTF-8
ctype en_US.UTF-8
tz America/New_York
date 2023-11-17
pandoc 3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
─ Packages ───────────────────────────────────────────────────────────────────
package * version date (UTC) lib source
backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.0)
bit 4.0.5 2022-11-15 [1] CRAN (R 4.3.0)
bit64 4.0.5 2020-08-30 [1] CRAN (R 4.3.0)
broom * 1.0.5 2023-06-09 [1] CRAN (R 4.3.0)
class 7.3-22 2023-05-03 [1] CRAN (R 4.3.0)
cli 3.6.1 2023-03-23 [1] CRAN (R 4.3.0)
codetools 0.2-19 2023-02-01 [1] CRAN (R 4.3.0)
colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.0)
crayon 1.5.2 2022-09-29 [1] CRAN (R 4.3.0)
data.table 1.14.8 2023-02-17 [1] CRAN (R 4.3.0)
dials * 1.2.0 2023-04-03 [1] CRAN (R 4.3.0)
DiceDesign 1.9 2021-02-13 [1] CRAN (R 4.3.0)
digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.0)
dplyr * 1.1.3 2023-09-03 [1] CRAN (R 4.3.0)
ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.3.0)
evaluate 0.22 2023-09-29 [1] CRAN (R 4.3.1)
fansi 1.0.5 2023-10-08 [1] CRAN (R 4.3.1)
farver 2.1.1 2022-07-06 [1] CRAN (R 4.3.0)
fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.0)
forcats * 1.0.0 2023-01-29 [1] CRAN (R 4.3.0)
foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.0)
furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.0)
future 1.32.0 2023-03-07 [1] CRAN (R 4.3.0)
future.apply 1.11.0 2023-05-21 [1] CRAN (R 4.3.0)
generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.0)
ggplot2 * 3.4.2 2023-04-03 [1] CRAN (R 4.3.0)
glmnet * 4.1-7 2023-03-23 [1] CRAN (R 4.3.0)
globals 0.16.2 2022-11-21 [1] CRAN (R 4.3.0)
glue 1.6.2 2022-02-24 [1] CRAN (R 4.3.0)
gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.0)
gridExtra 2.3 2017-09-09 [1] CRAN (R 4.3.0)
gtable 0.3.3 2023-03-21 [1] CRAN (R 4.3.0)
hardhat 1.3.0 2023-03-30 [1] CRAN (R 4.3.0)
here 1.0.1 2020-12-13 [1] CRAN (R 4.3.0)
hms 1.1.3 2023-03-21 [1] CRAN (R 4.3.0)
htmltools 0.5.6.1 2023-10-06 [1] CRAN (R 4.3.1)
htmlwidgets 1.6.2 2023-03-17 [1] CRAN (R 4.3.0)
infer * 1.0.4 2022-12-02 [1] CRAN (R 4.3.0)
ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.0)
iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.0)
jsonlite 1.8.7 2023-06-29 [1] CRAN (R 4.3.0)
knitr 1.44 2023-09-11 [1] CRAN (R 4.3.0)
labeling 0.4.2 2020-10-20 [1] CRAN (R 4.3.0)
lattice 0.21-8 2023-04-05 [1] CRAN (R 4.3.0)
lava 1.7.2.1 2023-02-27 [1] CRAN (R 4.3.0)
lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.0)
lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.3.0)
listenv 0.9.0 2022-12-16 [1] CRAN (R 4.3.0)
lubridate * 1.9.3 2023-09-27 [1] CRAN (R 4.3.1)
magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
MASS 7.3-60 2023-05-04 [1] CRAN (R 4.3.0)
Matrix * 1.5-4.1 2023-05-18 [1] CRAN (R 4.3.0)
modeldata * 1.1.0 2023-01-25 [1] CRAN (R 4.3.0)
modelenv 0.1.1 2023-03-08 [1] CRAN (R 4.3.0)
munsell 0.5.0 2018-06-12 [1] CRAN (R 4.3.0)
nnet 7.3-19 2023-05-03 [1] CRAN (R 4.3.0)
parallelly 1.36.0 2023-05-26 [1] CRAN (R 4.3.0)
parsnip * 1.1.0 2023-04-12 [1] CRAN (R 4.3.0)
pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.0)
pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
prodlim 2023.03.31 2023-04-02 [1] CRAN (R 4.3.0)
purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.0)
R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
ranger * 0.15.1 2023-04-03 [1] CRAN (R 4.3.0)
RColorBrewer 1.1-3 2022-04-03 [1] CRAN (R 4.3.0)
Rcpp 1.0.10 2023-01-22 [1] CRAN (R 4.3.0)
readr * 2.1.4 2023-02-10 [1] CRAN (R 4.3.0)
recipes * 1.0.6 2023-04-25 [1] CRAN (R 4.3.0)
rlang 1.1.1 2023-04-28 [1] CRAN (R 4.3.0)
rmarkdown 2.25 2023-09-18 [1] CRAN (R 4.3.1)
ROSE 0.0-4 2021-06-14 [1] CRAN (R 4.3.0)
rpart 4.1.19 2022-10-21 [1] CRAN (R 4.3.0)
rprojroot 2.0.3 2022-04-02 [1] CRAN (R 4.3.0)
rsample * 1.1.1 2022-12-07 [1] CRAN (R 4.3.0)
rstudioapi 0.14 2022-08-22 [1] CRAN (R 4.3.0)
scales * 1.2.1 2022-08-20 [1] CRAN (R 4.3.0)
sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.0)
shape 1.4.6 2021-05-19 [1] CRAN (R 4.3.0)
SnowballC 0.7.1 2023-04-25 [1] CRAN (R 4.3.0)
stopwords * 2.3 2021-10-28 [1] CRAN (R 4.3.0)
stringi 1.7.12 2023-01-11 [1] CRAN (R 4.3.0)
stringr * 1.5.0 2022-12-02 [1] CRAN (R 4.3.0)
survival 3.5-5 2023-03-12 [1] CRAN (R 4.3.0)
textrecipes * 1.0.3 2023-04-14 [1] CRAN (R 4.3.0)
themis * 1.0.1 2023-04-14 [1] CRAN (R 4.3.0)
tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.0)
tidymodels * 1.1.0 2023-05-01 [1] CRAN (R 4.3.0)
tidyr * 1.3.0 2023-01-24 [1] CRAN (R 4.3.0)
tidyselect 1.2.0 2022-10-10 [1] CRAN (R 4.3.0)
tidyverse * 2.0.0 2023-02-22 [1] CRAN (R 4.3.0)
timechange 0.2.0 2023-01-11 [1] CRAN (R 4.3.0)
timeDate 4022.108 2023-01-07 [1] CRAN (R 4.3.0)
tokenizers 0.3.0 2022-12-22 [1] CRAN (R 4.3.0)
tune * 1.1.1 2023-04-11 [1] CRAN (R 4.3.0)
tzdb 0.4.0 2023-05-12 [1] CRAN (R 4.3.0)
utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.1)
vctrs 0.6.4 2023-10-12 [1] CRAN (R 4.3.1)
vip * 0.3.2 2020-12-17 [1] CRAN (R 4.3.0)
vroom 1.6.3 2023-04-28 [1] CRAN (R 4.3.0)
withr 2.5.2 2023-10-30 [1] CRAN (R 4.3.1)
workflows * 1.1.3 2023-02-22 [1] CRAN (R 4.3.0)
workflowsets * 1.0.1 2023-04-06 [1] CRAN (R 4.3.0)
xfun 0.40 2023-08-09 [1] CRAN (R 4.3.0)
yaml 2.3.7 2023-01-23 [1] CRAN (R 4.3.0)
yardstick * 1.2.0 2023-04-21 [1] CRAN (R 4.3.0)
[1] /Library/Frameworks/R.framework/Versions/4.3-arm64/Resources/library
──────────────────────────────────────────────────────────────────────────────