Diagnosing heart disease using a robust and explainable pipeline in tidyverse



One application of machine learning in health care is diagnosing a condition from descriptive information, a decision that can lead to a particular course of treatment or to no treatment at all. In such a context, health professionals and patients need to understand the process the model uses to reach a final decision or classification.

Here we use the Heart Disease UCI dataset from Kaggle, which includes 13 features as well as a target indicating the presence or absence of heart disease.

First, load the libraries and the data:

library(tidyverse)   # readr, dplyr, ggplot2, ...
library(tidymodels)  # rsample, recipes, parsnip, tune, workflows, yardstick
library(vip)         # variable importance plots
library(treeheatr)   # heat_tree(), used in the interpretability step

df <- read_csv("datasets_33180_43520_heart.csv")
## Parsed with column specification:
## cols(
##   age = col_double(),
##   sex = col_double(),
##   cp = col_double(),
##   trestbps = col_double(),
##   chol = col_double(),
##   fbs = col_double(),
##   restecg = col_double(),
##   thalach = col_double(),
##   exang = col_double(),
##   oldpeak = col_double(),
##   slope = col_double(),
##   ca = col_double(),
##   thal = col_double(),
##   target = col_double()
## )

Let's convert the categorical variables to factors and take a look:

df <- df %>%
  mutate_at(vars(sex, cp, fbs, restecg, exang, slope, ca, thal, target), as.factor)

df
## # A tibble: 303 x 14
##      age sex   cp    trestbps  chol fbs   restecg thalach exang oldpeak slope
##    <dbl> <fct> <fct>    <dbl> <dbl> <fct> <fct>     <dbl> <fct>   <dbl> <fct>
##  1    63 1     3          145   233 1     0           150 0         2.3 0    
##  2    37 1     2          130   250 0     1           187 0         3.5 0    
##  3    41 0     1          130   204 0     0           172 0         1.4 2    
##  4    56 1     1          120   236 0     1           178 0         0.8 2    
##  5    57 0     0          120   354 0     1           163 1         0.6 2    
##  6    57 1     0          140   192 0     1           148 0         0.4 1    
##  7    56 0     1          140   294 0     0           153 0         1.3 1    
##  8    44 1     1          120   263 0     1           173 0         0   2    
##  9    52 1     2          172   199 1     1           162 0         0.5 2    
## 10    57 1     2          150   168 0     1           174 0         1.6 2    
## # ... with 293 more rows, and 3 more variables: ca <fct>, thal <fct>,
## #   target <fct>

Variables are as follows:

  • age: The person’s age in years
  • sex: The person’s sex (1 = male, 0 = female)
  • cp: The chest pain experienced (Value 1: typical angina, Value 2: atypical angina, Value 3: non-anginal pain, Value 4: asymptomatic)
  • trestbps: The person’s resting blood pressure (mm Hg on admission to the hospital)
  • chol: The person’s cholesterol measurement in mg/dl
  • fbs: The person’s fasting blood sugar (> 120 mg/dl, 1 = true; 0 = false)
  • restecg: Resting electrocardiographic measurement (0 = normal, 1 = having ST-T wave abnormality, 2 = showing probable or definite left ventricular hypertrophy by Estes’ criteria)
  • thalach: The person’s maximum heart rate achieved
  • exang: Exercise induced angina (1 = yes; 0 = no)
  • oldpeak: ST depression induced by exercise relative to rest (‘ST’ relates to positions on the ECG plot. See more here)
  • slope: the slope of the peak exercise ST segment (Value 1: upsloping, Value 2: flat, Value 3: downsloping)
  • ca: The number of major vessels (0-3)
  • thal: A blood disorder called thalassemia (3 = normal; 6 = fixed defect; 7 = reversible defect)
  • target: Heart disease (0 = no, 1 = yes)
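
Before modelling, it helps to check the class balance of the target. A minimal sketch using the factor-converted data frame above:

df %>%
  count(target) %>%
  mutate(prop = n / sum(n))

If the classes are roughly balanced, accuracy and ROC AUC are both reasonable metrics for this problem.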

As indicated here “The diagnosis of heart disease is done on a combination of clinical signs and test results. The types of tests run will be chosen on the basis of what the physician thinks is going on, ranging from electrocardiograms and cardiac computerized tomography (CT) scans, to blood tests and exercise stress tests.”

Here our pipeline includes seven steps as follows:

  • Step 0

In order to implement a robust pipeline, we first create training and test splits, together with a repeated 10-fold cross-validation structure (5 repeats) on the training set.

set.seed(123)  # assumed seed, for a reproducible split
df_split <- initial_split(df)
train_data <- training(df_split)
test_data <- testing(df_split)
cv_train <- vfold_cv(train_data, v = 10, repeats = 5, strata = "target")
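
As a sanity check, one can confirm the sizes of the splits and that the stratified folds preserved the target proportions; a quick sketch:

nrow(train_data)
nrow(test_data)
train_data %>% count(target) %>% mutate(prop = n / sum(n))
cv_train   # 50 resamples: 10 folds x 5 repeats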
  • Step 1

Then we proceed with preprocessing using the recipes package, including imputation of missing data, standardisation of the numeric variables, and dummy coding of the categorical variables.

rec_obj <- recipe(target ~ ., data = train_data)  # define the recipe on the training data to avoid leakage
imputed <- rec_obj %>%
  step_knnimpute(all_predictors())
ind_vars <- imputed %>%
  step_dummy(all_predictors(), -all_numeric())
standardized <- ind_vars %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())
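
To inspect what the preprocessing produces, the recipe can be prepped on the training data and the processed training set extracted with juice(); a sketch:

prepped <- prep(standardized, training = train_data)
juice(prepped) %>% glimpse()   # numeric predictors centred and scaled, factors dummy coded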
  • Step 2

The model development step will be implemented using the tidymodels package by setting the model, the engine, and the hyperparameters to be tuned in the cross-validation procedure. This starts with setting up the model:

cores <- parallel::detectCores()  # number of threads for ranger

rf_mod <- 
  rand_forest(mtry = tune(), min_n = tune(), trees = 1000) %>% 
  set_engine("ranger", num.threads = cores) %>% 
  set_mode("classification")

rf_mod
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
## 
## Engine-Specific Arguments:
##   num.threads = cores
## 
## Computational engine: ranger
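
Before tuning, the parameters flagged with tune() can be listed; a sketch using the parameters() helper from dials:

rf_mod %>% parameters()   # mtry and min_n show up as tunable parameters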
  • Step 3

Next, we build the workflow and implement the training step with the grid-search function. Setting up the workflow:

rf_workflow <- 
  workflow() %>% 
  add_model(rf_mod) %>% 
  add_recipe(standardized)

rf_workflow
## == Workflow ===============================================================================================================================
## Preprocessor: Recipe
## Model: rand_forest()
## 
## -- Preprocessor ---------------------------------------------------------------------------------------------------------------------------
## 4 Recipe Steps
## 
## * step_knnimpute()
## * step_dummy()
## * step_center()
## * step_scale()
## 
## -- Model ----------------------------------------------------------------------------------------------------------------------------------
## Random Forest Model Specification (classification)
## 
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
## 
## Engine-Specific Arguments:
##   num.threads = cores
## 
## Computational engine: ranger

and running the tune_grid() function:

set.seed(345)
rf_res <- 
  rf_workflow %>% 
  tune_grid(grid = 25,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(roc_auc), 
            resamples = cv_train)
  • Step 4

Find the set of best-performing hyperparameters using the performance metrics (collect_metrics()) and select the best model to fit to the test data.

rf_res %>%
  collect_metrics()
rf_best <- 
  rf_res %>% 
  select_best(metric = "roc_auc")

rf_best

rf_res %>% 
  show_best(metric = "roc_auc")

autoplot(rf_res)
  • Step 5

Fit the best-performing model to the test data using the last_fit() function.

last_rf_mod <- 
  rand_forest(mtry = 1, min_n = 38, trees = 1000) %>%   # values taken from the tuning results above
  set_engine("ranger", num.threads = cores, importance = "impurity") %>% 
  set_mode("classification")

last_rf_workflow <- 
  rf_workflow %>% 
  update_model(last_rf_mod)


set.seed(345)
last_rf_fit <- 
  last_rf_workflow %>% 
  last_fit(df_split)


last_rf_fit %>% 
  collect_metrics()
## # A tibble: 2 x 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.867
## 2 roc_auc  binary         0.960
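
Because last_fit() keeps the test-set predictions, a confusion matrix is easy to add; a sketch using yardstick:

last_rf_fit %>%
  collect_predictions() %>%
  conf_mat(truth = target, estimate = .pred_class)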
  • Step 6

Interpretability analysis: first, let's visualize a simplified tree using the treeheatr package's heat_tree() function.

heat_tree(test_data, target_lab = 'target')
## Registered S3 method overwritten by 'seriation':
##   method         from 
##   reorder.hclust gclus

The vip package provides model-specific measures, such as feature importance based on node impurity for random forests:

last_rf_fit %>% 
  pluck(".workflow", 1) %>%   
  pull_workflow_fit() %>% 
  vip(num_features = 15)

Another type of measure provided by the vip package is the family of model-agnostic indices based on permutation, ICE, or PDP procedures. Permutation importance is an index for interpreting a machine-learning model: individual variables are shuffled after the model has been fit, and the resulting drop in performance is measured.

Let's take a look:

test_rec <- prep(standardized) %>% bake(new_data = test_data)

# refit the final forest directly with ranger; probability = TRUE so the
# prediction wrapper can return class probabilities, as the AUC metric requires
rfo <- ranger::ranger(target ~ ., mtry = 1, min.node.size = 38,
                      num.trees = 1000, data = test_rec,
                      importance = "permutation", probability = TRUE)

pfun <- function(object, newdata) predict(object, data = newdata)$predictions[, "0"]  # P(target = 0)

vip(rfo, method = "permute", metric = "auc", pred_wrapper = pfun,
    target = "target", reference_class = "0")
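
Since the text above also mentions PDP procedures, a partial-dependence sketch for a single predictor can round out the interpretability analysis. This assumes the pdp package, and refits ranger as a probability forest so that class probabilities can be averaged:

library(pdp)

rfo_prob <- ranger::ranger(target ~ ., mtry = 1, min.node.size = 38,
                           num.trees = 1000, data = test_rec,
                           probability = TRUE)

# average predicted probability of the second class (target = 1) over the grid
partial(rfo_prob, pred.var = "thalach",
        pred.fun = function(object, newdata)
          mean(predict(object, data = newdata)$predictions[, 2]),
        train = test_rec, plot = TRUE)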