
Using Keras in R: Training a model


In this blog post I will introduce you to building and training your own neural network in R through Keras and TensorFlow. If you haven’t installed Keras for R yet, please follow the instructions explained in part 1.

I have explicitly chosen to work with structured data in this blog post. Because really… who works with, say, pictures in R? The dataset contains properties of cars and their asking prices on a Belgian Craigslist-like website for second-hand cars, scraped in 2018.

I could skip explaining the first parts of my code, but I think it’s good to get familiar with data that is not loaded from the datasets API.

Preprocessing

In this first part, I load a dataset that you can find in my GitHub repo. I impute missing values using random forest imputation and dummify all factor columns, because TensorFlow/Keras only accepts numeric matrices, not data frames.

options(scipen=999)

library(keras)
library(data.table)
library(missRanger)
library(dummies)
library(ggplot2)

set.seed(19880303)
dt <- fread('dt.csv', stringsAsFactors=T)
dt[,car_model := NULL]            # drop the car model column
dt[,car_year := 2019 - car_year]  # convert model year to car age

dt <- missRanger(dt, pmm.k=3, num.trees=100) # Random forest imputation
dt <- dummy.data.frame(dt, dummy.classes = 'factor') # Dummify factors
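
A side note: the dummies package has since been archived on CRAN. If it fails to install, a rough equivalent is sketched below, assuming the mltools package (not used elsewhere in this post):

# Hypothetical alternative in case 'dummies' is unavailable:
library(mltools)
dt <- as.data.table(dt)          # one_hot() expects a data.table
dt <- one_hot(dt, cols = 'auto') # one column per level of each unordered factor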

Next, I split the data into train and test data. I convert everything to matrices and finally, do some cleaning of features that only contain one data point.

# Train/test split
ind <- sample(nrow(dt),round(nrow(dt)*0.8,0))
x_train <- dt[ind,]
x_test <- dt[-ind,]

y_train <- x_train[,1]  # the first column is the target: the asking price
x_train <- x_train[,-1]
y_test <- x_test[,1]
x_test <- x_test[,-1]
rm(dt,ind)

# To matrices
x_train <- as.matrix(x_train)
y_train <- as.matrix(y_train)
x_test <- as.matrix(x_test)
y_test <- as.matrix(y_test)

# Some removal of columns with only one car
x_train <- x_train[,!(colnames(x_train) %in% c('car_brandhummer','car_brandlamborghini'))]
x_test <- x_test[,!(colnames(x_test) %in% c('car_brandhummer','car_brandlamborghini'))]

Finally, I normalize the data. This is not strictly necessary, but it will benefit your training procedure a lot: normalized data allows for a smoother gradient descent, so the training loss converges much faster.

# Normalizing data
x_train <- scale(x_train)
x_test <- scale(x_test, center = attr(x_train, "scaled:center") , scale = attr(x_train, "scaled:scale"))

# Remove columns that became NaN after scaling (zero variance in the training set);
# filter x_test first, while x_train still flags all such columns
x_test <- x_test[,!(colSums(is.na(x_train)) > 0)]
x_train <- x_train[,!(colSums(is.na(x_train)) > 0)]

Building the neural network

Now we get to the juicy part: actually building the network. In this model I work with a sequential network and three dense layers that are ReLU-activated. There are two ways to add activation functions: (1) by specifying them in the dense layer or (2) by adding them as a separate layer. The output layer is a single-node dense layer.

There’s a dropout layer to thin the network and avoid overfitting. The optimizer is RMSprop, and since this is a regression task, I optimize for mean squared error. The metric I want to inspect visually is mean absolute error.

build_model <- function() {
  model <- keras_model_sequential() 
  model %>% 
    layer_dense(units = 64, 
                input_shape = dim(x_train)[2],
                kernel_regularizer = regularizer_l2(l = 0.001)) %>% 
    layer_activation_relu() %>% # (2) Separate ReLU layer
    layer_dense(units = 128,
                activation = 'relu') %>% # (1) Specified in the dense layer
    layer_dropout(0.6) %>%
    layer_dense(units = 64,
                activation = 'relu') %>%
    layer_dense(units = 1)
  
  model %>% compile(
    loss = "mse",
    optimizer = optimizer_rmsprop(),
    metrics = list("mean_absolute_error")
  )
  model
}

model <- build_model()
model %>% summary()

If your RStudio crashes while building the model and shows the error “R Session Aborted – R encountered a fatal error. The Session was terminated.”, you have specified parameters that do not exist. Bad news: there is no convenient way to debug your network, so you’re on your own here. Try modifying the parameters one by one, for example with the approach sketched below.
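
One way to narrow the problem down, sketched here with the first layers of the model above: rebuild the network one layer at a time and print the summary after each step, so the last line you added is the one crashing the session.

model <- keras_model_sequential()
model %>% layer_dense(units = 64, input_shape = ncol(x_train))
summary(model) # if this prints, the first layer is fine
model %>% layer_activation_relu()
summary(model) # ...and so on, adding one layer at a time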


Training the algorithm

In this step I specify two callback functions.

  • For every epoch I want to print a dot, so I can see that training is actually progressing.
  • The training needs to stop early if the validation loss does not decrease anymore.

To conclude, I specify that I want to train for 100 epochs. The training starts by invoking the fit() function on the neural network we built. It will store all training information in the history object.

In a final phase, I predict on the test dataset and plot the predicted versus the real values.

print_dot_callback <- callback_lambda(
  on_epoch_end = function(epoch, logs) {
    if (epoch %% 80 == 0) cat("\n") # start a new line every 80 epochs
    cat(".")
  }
)

early_stop <- callback_early_stopping(monitor = "val_loss", patience = 20)

epochs <- 100

# Fit the model and store training stats
history <- model %>% fit(
  x_train,
  y_train,
  epochs = epochs,
  validation_split = 0.2,
  verbose = 0, # silence the default progress bar; the dot callback reports progress instead
  callbacks = list(early_stop, print_dot_callback)
)

y_test_pred <- model %>% predict(x_test)
ggplot(data.table('test' = c(y_test),'pred' = c(y_test_pred)), 
       aes(x=test, y=pred)) + geom_point() + xlim(0,5000) + ylim(0,5000) +
  geom_smooth(method='lm')
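
You can also plot the full training history that fit() stored in the history object, using the plot method that comes with Keras for R:

plot(history) # loss and mean absolute error per epoch, for training and validation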

In this graph you can see how the loss and mean absolute error decrease for the training and validation sets.

By the way, if you’re having trouble understanding some of the code and concepts, I can highly recommend “An Introduction to Statistical Learning: with Applications in R”, which is the must-have data science bible. If you simply need an introduction to R, rather than to the data science part, I can absolutely recommend this book by Richard Cotton. Hope it helps!

In a follow-up post I will elaborate on hyperparameter tuning of the model.

Say thanks, ask questions or give feedback

Technologies get updated, syntax changes and honestly… I make mistakes too. If something is incorrect, incomplete or doesn’t work, let me know in the comments below and help thousands of visitors.