
Dealing with right-censored data in machine learning: Random Survival Forests


A couple of weeks ago, I started working with survival analysis. It was fairly new to me, so I had to dig into some new methods. There was one method that captured my attention: random survival forests (RSFs). It’s one of many statistical learning techniques designed to work with right-censored survival data. In this blog post I present a condensed primer on RSFs and how you can use them in R.

Although I explain every concept or link to relevant documentation, this blog post will be more meaningful if you have prior knowledge of survival analysis, probability and tree-based machine learning methods.

The theory behind Random Survival Forests

Random forests introduce two sources of randomness into decision tree methods (CART) to counter the trees' inherent greediness. First, each decision tree is grown on its own bootstrap sample of the training data. Second, at each split only a random sample of the predictors is considered as split candidates. Because each split works with a subset of the predictors, predictors other than the strongest one get a chance to enter the model, and the trees become decorrelated.
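To make the two sources of randomness concrete, here is a small base-R sketch. It is not ranger's internal code. It draws one bootstrap sample, looks at which rows end up out-of-bag, and picks a random subset of predictors for a single split. The values n = 137 and p = 6 mirror the veteran data used later in this post, and mtry = 3 matches the ranger call there.

```r
# Sketch of the two sources of randomness in a random forest (base R only).
n    <- 137  # rows in the data (the veteran data has 137 patients)
p    <- 6    # candidate predictors (veteran has 6 besides time and status)
mtry <- 3    # predictors tried at each split, as in the ranger call later on

set.seed(19880303)

# 1) Each tree is grown on its own bootstrap sample: n rows drawn with replacement.
boot <- sample(n, n, replace = TRUE)
oob  <- setdiff(seq_len(n), boot)   # out-of-bag rows that this tree never sees

# 2) At each split, only a random subset of mtry predictors is considered.
split_candidates <- sample(p, mtry)

# Averaged over many bootstrap samples, the out-of-bag fraction is about
# (1 - 1/n)^n, which is close to exp(-1), i.e. roughly 37% of the rows.
oob_fraction <- mean(replicate(1000, {
  b <- sample(n, n, replace = TRUE)
  length(setdiff(seq_len(n), b)) / n
}))
```

The out-of-bag rows are what make it possible to estimate a forest's prediction error without a separate validation set.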

Although survival trees can be grown off-the-shelf within the CART paradigm, Ishwaran et al. (2008) developed Random Survival Forests (RSFs), a random forest method that takes into account both survival time and censoring status.
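One way a survival tree can use both survival time and censoring status when choosing a split is the log-rank splitting rule, which Ishwaran et al. describe and which is the default splitrule ('logrank') for survival forests in ranger. Below is a minimal sketch of the log-rank statistic for one candidate split; the function name logrank_split and its arguments are my own, not part of any package.

```r
# Log-rank split statistic L for one candidate split of a node (sketch).
# time, status: survival times and event indicators (1 = death) of the cases in the node
# left: logical vector, TRUE for cases sent to the left daughter node
# An RSF evaluates many candidate splits and keeps the one with the largest |L|.
logrank_split <- function(time, status, left) {
  event_times <- sort(unique(time[status == 1]))
  num <- 0  # sum of observed minus expected deaths in the left daughter
  v   <- 0  # sum of hypergeometric variances
  for (t in event_times) {
    at_risk <- time >= t
    y  <- sum(at_risk)                          # Y_j: at risk in the parent node
    y1 <- sum(at_risk & left)                   # Y_j1: at risk in the left daughter
    d  <- sum(time == t & status == 1)          # d_j: deaths in the parent node
    d1 <- sum(time == t & status == 1 & left)   # d_j1: deaths in the left daughter
    num <- num + (d1 - y1 * d / y)
    if (y > 1) {  # with a single case at risk the variance term is zero
      v <- v + (y1 / y) * (1 - y1 / y) * ((y - d) / (y - 1)) * d
    }
  }
  num / sqrt(v)
}
```

For example, logrank_split(c(1, 2, 3, 4), c(1, 1, 1, 1), c(TRUE, TRUE, FALSE, FALSE)) sends the two earliest deaths to the left daughter, giving a clearly positive statistic. Squaring L should give the two-group log-rank chi-squared statistic that survival::survdiff reports for the same grouping.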

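These five steps are at the core of random survival forests (RSFs), as described by Ishwaran et al.:

1. Draw B bootstrap samples from the original data. Each bootstrap sample leaves out, on average, about 37% of the data; these left-out cases are called out-of-bag (OOB) data.
2. Grow a survival tree on each bootstrap sample. At each node, randomly select p candidate variables, and split the node on the candidate variable that maximizes the survival difference between the daughter nodes.
3. Grow each tree to full size, under the constraint that a terminal node contains no fewer than d0 > 0 unique deaths.
4. Calculate a cumulative hazard function (CHF) for each tree, and average these to obtain the ensemble CHF.
5. Using the OOB data, calculate the prediction error of the ensemble CHF.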

An essential element of RSFs is the cumulative hazard function (CHF), H(t): the integral of the hazard function from 0 to t, where the hazard is the instantaneous rate of failure at time t given survival until time t. For continuous survival times, the CHF is linked to the survival function by S(t) = exp(-H(t)). In RSFs, the CHF for each terminal node h is estimated with the Nelson-Aalen estimator. Here's how it works: a CHF is constructed for each tree, which is grown on its own bootstrap sample, and the ensemble CHF is the average of all these tree CHFs (for the out-of-bag ensemble CHF, each case is averaged only over the trees for which it was out-of-bag). Given the ensemble CHF, the ensemble mortality of a case is estimated as the ensemble CHF summed over the observed times.
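The Nelson-Aalen estimator and the averaging step can be sketched in a few lines of base R. This is a deliberately simplified toy, not how ranger works internally: every "tree" is reduced to its root node (no splits), so each tree's CHF is just the Nelson-Aalen estimate on that tree's bootstrap sample. In a real RSF, a case would instead get the CHF of the terminal node it falls into in each tree. The function names are hypothetical.

```r
# Nelson-Aalen estimate of the cumulative hazard in one node:
#   H(t) = sum over event times t_j <= t of d_j / Y_j
# (d_j = deaths at t_j, Y_j = cases still at risk at t_j), evaluated on a time grid.
nelson_aalen <- function(time, status, grid) {
  event_times <- sort(unique(time[status == 1]))
  if (length(event_times) == 0) return(rep(0, length(grid)))  # no deaths: H(t) = 0
  d <- vapply(event_times, function(t) sum(time == t & status == 1), numeric(1))
  y <- vapply(event_times, function(t) sum(time >= t), numeric(1))
  H <- cumsum(d / y)
  # Step function: value at the last event time <= each grid point, 0 before the first
  c(0, H)[findInterval(grid, event_times) + 1]
}

# Toy "ensemble": each tree is reduced to its root node (no splits), so each
# tree's CHF is the Nelson-Aalen estimate on that tree's bootstrap sample.
ensemble_chf <- function(time, status, grid, n_trees = 100) {
  chfs <- matrix(replicate(n_trees, {
    boot <- sample(length(time), replace = TRUE)
    nelson_aalen(time[boot], status[boot], grid)
  }), nrow = length(grid))
  rowMeans(chfs)  # average the per-tree CHFs at each grid point
}

# Ensemble mortality: the ensemble CHF summed over the observed times
ensemble_mortality <- function(time, status, n_trees = 100) {
  sum(ensemble_chf(time, status, grid = time, n_trees = n_trees))
}
```

With the survival package loaded, ensemble_chf(veteran$time, veteran$status, grid = c(30, 90, 180)) would return this toy ensemble CHF at 30, 90 and 180 days.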

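To estimate a model's prediction error, Ishwaran et al. use Harrell's concordance index (C-index) and report the error as 1 minus the C-index. Unlike measures tied to a single evaluation time, the C-index does not depend on a fixed time point, and it accounts for censoring.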
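The idea behind the C-index is easiest to see in a small pairwise implementation. The sketch below is a simplified version that treats pairs with tied observed times as not comparable and counts tied predictions as half-concordant; it is meant for illustration, not as a replacement for a package implementation. A pair is usable only when the case with the shorter observed time actually died, because only then do we know who failed first. The function names are hypothetical.

```r
# Harrell's C-index for right-censored data: a minimal O(n^2) sketch.
# risk: higher value means the model expects an earlier failure (e.g. ensemble mortality)
harrell_c <- function(time, status, risk) {
  concordant <- 0
  comparable <- 0
  n <- length(time)
  for (i in seq_len(n)) {
    if (status[i] != 1) next  # a pair is usable only if the earlier time is a death
    for (j in seq_len(n)) {
      if (time[i] < time[j]) {  # i died before j's observed (death or censoring) time
        comparable <- comparable + 1
        if (risk[i] > risk[j]) {
          concordant <- concordant + 1    # higher predicted risk failed first
        } else if (risk[i] == risk[j]) {
          concordant <- concordant + 0.5  # tied predictions count as half
        }
      }
    }
  }
  concordant / comparable
}

# Prediction error in the sense of Ishwaran et al.: 1 minus the C-index
prediction_error <- function(time, status, risk) 1 - harrell_c(time, status, risk)
```

A C-index of 1 means every usable pair is ordered correctly, 0.5 is no better than random guessing, and the corresponding prediction errors are 0 and 0.5.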

In the following example, we train a Random Survival Forest with the ranger package, whose survival forests implement the method from Ishwaran et al.'s 2008 paper.

Putting Random Survival Forests to work in R

First, we load all the required packages. We get the veteran dataset from the survival package and use data.table as the general data-handling framework. The ranger package trains the RSF model, and caret builds a confusion matrix.

rm(list=ls())

library(survival)
library(data.table)
library(ranger)
library(caret)

set.seed(19880303)

data(veteran)
dt <- data.table(veteran)
rm(veteran)

Next, we split the data into a training set (70% of the rows) and a test set.

# Randomly pick 70% of the row numbers for the training set
ind <- sample(1:nrow(dt), round(nrow(dt) * 0.7, 0))

dt_train <- dt[ind, ]
dt_test <- dt[-ind, ]  # all remaining rows

Next, we use the ranger package to train the model. I also plot the estimated survival curves of two training cases (rows 20 and 21 of dt_train).

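# Train a random survival forest with ranger
r_fit <- ranger(Surv(time, status) ~ .,
                data = dt_train,
                mtry = 3,                    # predictors tried at each split
                verbose = TRUE,
                write.forest = TRUE,         # keep the forest to predict on new data
                num.trees = 1000,
                importance = 'permutation')  # permutation variable importance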

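# Estimated survival curves for rows 20 (red) and 21 (blue) of the training data
plot(r_fit$unique.death.times, r_fit$survival[20, ], type = 'l', col = 'red',
     ylim = c(0, 1), xlab = 'Days', ylab = 'Estimated survival probability')
lines(r_fit$unique.death.times, r_fit$survival[21, ], col = 'blue')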
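Because the model was trained with importance = 'permutation', the fitted object also holds permutation variable importances, and according to the ranger documentation its prediction.error field is the out-of-bag error, which for survival forests is 1 minus Harrell's C-index. Below is a minimal sketch of reading both. To keep it standalone, it refits a smaller forest on the full veteran data (an assumption made only for this example); in the workflow above you would read the same fields from r_fit.

```r
library(survival)
library(ranger)

vet <- survival::veteran
set.seed(19880303)

# Smaller forest than in the post, just to keep this example quick
fit <- ranger(Surv(time, status) ~ ., data = vet,
              num.trees = 200, mtry = 3, importance = 'permutation')

# For survival forests, prediction.error is 1 minus Harrell's C-index,
# computed on the out-of-bag data
oob_c_index <- 1 - fit$prediction.error
print(oob_c_index)

# Permutation importance of each predictor, largest first
print(sort(fit$variable.importance, decreasing = TRUE))
```

The out-of-bag C-index gives a first impression of performance without touching the test set.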

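In the following chunk of code, I calculate how accurately the model predicts survival beyond 61 days (an arbitrary cutoff, chosen purely for demonstration). Note that this simple check ignores censoring: a test patient censored before day 61 is counted as not having survived, even though their status at day 61 is unknown.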

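# Predicted survival curves for the test set (one column per unique death time)
pred_test <- predict(r_fit, data = dt_test, type = 'response')
preds <- data.table(pred_test$survival)
colnames(preds) <- as.character(pred_test$unique.death.times)

# Predicted: survival probability at day 61 above 0.5 (61 must be a unique death time)
prediction <- preds$`61` > 0.5
# Observed: follow-up of at least 61 days (censoring is ignored here)
real <- dt_test$time >= 61

# Same factor levels on both sides, in case one class is never predicted
caret::confusionMatrix(factor(prediction, levels = c(FALSE, TRUE)),
                       factor(real, levels = c(FALSE, TRUE)), positive = 'TRUE')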
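The accuracy check treats every patient with a follow-up under 61 days as a non-survivor, even if they were censored. A small helper can separate known outcomes from unknown ones first. This is a sketch of my own (status_at_horizon is a hypothetical name, not a package function). It follows S(t) = P(T > t), so a death on exactly day 61 counts as not surviving, and a patient censored at exactly day 61 counts as alive, because right-censoring at time C means T > C.

```r
# Status at a fixed horizon for right-censored data:
#   FALSE -> died at or before the horizon
#   TRUE  -> known to be alive at the horizon (followed beyond it, or censored exactly at it)
#   NA    -> censored before the horizon, so the status at the horizon is unknown
status_at_horizon <- function(time, status, horizon) {
  out <- rep(NA, length(time))
  out[time > horizon | (time == horizon & status == 0)] <- TRUE
  out[time <= horizon & status == 1] <- FALSE
  out
}
```

In the workflow above, real <- status_at_horizon(dt_test$time, dt_test$status, 61) followed by keep <- !is.na(real) lets you pass prediction[keep] and real[keep] to caret::confusionMatrix(). Dropping the censored cases is still only a rough fix, because discarding them can bias the estimate; inverse probability of censoring weighting is the more principled alternative.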

I hope you learned something!
