# One-hot encoding with a TensorFlow DataGenerator

In this blog post I explain how to create a DataGenerator with a one-hot encoder to encode your labels in the same way for every batch.

Some months ago, I tried training a text generator on a huge corpus of text with an LSTM model. Basically, it’s a model that predicts what the next word should be in a sentence. I had everything figured out, but tokenizing the text and one-hot encoding the many labels was an issue. After tokenizing the predictors and one-hot encoding the labels, the data set became massive, and it couldn’t even be stored in memory.

These are the two errors I kept running into:

    tensorflow/core/framework/allocator.cc:107]
    Allocation of 18970130000 exceeds 10% of system memory.

    tensorflow/core/framework/op_kernel.cc:1502] OP_REQUIRES failed at one_hot_op.cc:97 :
    Resource exhausted: OOM when allocating tensor with shape[xxxxxxx,xxx,xx] and type float
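
To get a feel for the scale involved, here is a back-of-the-envelope estimate. The sample count and vocabulary size are made-up illustration values, not the numbers behind the errors above; the only assumption is a dense float32 matrix with one column per class.

```python
def one_hot_bytes(n_samples, n_classes, bytes_per_value=4):
    """Size in bytes of a dense one-hot matrix (float32 by default)."""
    return n_samples * n_classes * bytes_per_value

# Hypothetical corpus: 1,000,000 labels drawn from a 10,000-word vocabulary.
full = one_hot_bytes(1_000_000, 10_000)
# The same labels, encoded one batch of 32 at a time.
per_batch = one_hot_bytes(32, 10_000)

print(f"all labels at once: {full / 1e9:.1f} GB")
print(f"one batch of 32:    {per_batch / 1e6:.2f} MB")
```

Encoding per batch keeps the integer labels in memory and only ever materializes the small one-hot matrix for the batch being served.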

Although it was clear to me that I should use a generator (like the ImageDataGenerator), my experience with writing custom TensorFlow code was limited. This week, however, I solved the problem: I wrote a DataGenerator class that encodes the labels properly, batch by batch. As a matter of fact, it’s not even that hard.

The DataGenerator class inherits from the Keras Sequence class (tensorflow.keras.utils.Sequence), because it’s a memory-efficient and structured way of generating batches: only the batch that is requested has to be built.

    class DataGenerator(Sequence):

The constructor accepts predictors, labels, a fitted OneHotEncoder, batch_size, the number of classes, a maximum sequence length and a boolean to shuffle.


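        def __init__(self, predictors, labels, enc, batch_size=32, n_classes=25, max_seq_len=25, shuffle=True):
            self.batch_size = batch_size
            self.labels = labels            # integer labels, not yet one-hot encoded
            self.predictors = predictors    # tokenized input sequences
            self.n_classes = n_classes      # stored, but not used by the generator itself
            self.max_seq_len = max_seq_len  # stored, but not used by the generator itself
            self.shuffle = shuffle
            self.enc = enc                  # OneHotEncoder that was fitted beforehand
            self.on_epoch_end()             # builds the (shuffled) index array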

The one-hot encoding happens in the __getitem__ method, which is the method that actually generates the batches. As you can see, the labels of each batch are transformed by the pre-fitted OneHotEncoder that I passed to the object.

        def __getitem__(self, index):
            # Positions of the samples that belong to batch number `index`
            indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

            # Select the predictors and labels of this batch
            X = self.predictors[indexes]
            y = self.labels[indexes]

            # One-hot encode only this batch, using the pre-fitted encoder
            y = self.enc.transform(y.reshape(-1,1))

            return X, y
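
The slicing step can be tried out in isolation. The following is a minimal sketch with toy arrays (all values and the `batch_indexes` helper are hypothetical); it shows that indexing both arrays with the same slice of the shuffled index array keeps every predictor paired with its label.

```python
import numpy as np

# Toy data (hypothetical): 10 samples, where label i belongs to predictor row i.
predictors = np.arange(10).reshape(-1, 1) * 10   # rows 0, 10, 20, ...
labels = np.arange(10)                           # 0, 1, 2, ...

batch_size = 4
rng = np.random.default_rng(0)
indexes = rng.permutation(len(predictors))       # a shuffled sample order

def batch_indexes(index):
    """Positions of the samples that make up batch number `index`."""
    return indexes[index * batch_size:(index + 1) * batch_size]

idx = batch_indexes(1)
X, y = predictors[idx], labels[idx]

# The same index array selects from both arrays, so rows stay paired.
assert np.array_equal(X[:, 0], y * 10)
print(X.shape, y.shape)
```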

If we put it all together, this is the class we have created.

    import numpy as np
    from tensorflow.keras.utils import Sequence

    class DataGenerator(Sequence):
        'Generates data for Keras'
        def __init__(self, predictors, labels, enc, batch_size=32, n_classes=25, max_seq_len=25, shuffle=True):
            self.batch_size = batch_size
            self.labels = labels
            self.predictors = predictors
            self.n_classes = n_classes
            self.max_seq_len = max_seq_len
            self.shuffle = shuffle
            self.enc = enc
            self.on_epoch_end()

        def __len__(self):
            # Number of full batches per epoch
            return int(np.floor(len(self.predictors) / self.batch_size))

        def __getitem__(self, index):
            # Positions of the samples that belong to batch number `index`
            indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

            # Select the predictors and labels of this batch
            X = self.predictors[indexes]
            y = self.labels[indexes]

            # One-hot encode only this batch, using the pre-fitted encoder
            y = self.enc.transform(y.reshape(-1,1))

            return X, y

        def on_epoch_end(self):
            # Reset the sample order and reshuffle it after every epoch
            self.indexes = np.arange(len(self.predictors))
            if self.shuffle:
                np.random.shuffle(self.indexes)
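
One detail worth knowing: because `__len__` rounds down, the samples that do not fill a final batch are skipped in each epoch. With shuffling enabled, a different handful is skipped every time. Here is a small sketch with made-up sizes; the `batches_per_epoch` helper and its `keep_partial` flag are hypothetical and not part of the class.

```python
import math

def batches_per_epoch(n_samples, batch_size, keep_partial=False):
    """Number of batches per epoch.

    keep_partial=False mirrors the rounding down in __len__;
    keep_partial=True is a hypothetical variant that rounds up instead.
    """
    if keep_partial:
        return math.ceil(n_samples / batch_size)
    return n_samples // batch_size

n_samples, batch_size = 1000, 32   # made-up sizes for illustration
full_batches = batches_per_epoch(n_samples, batch_size)
skipped = n_samples - full_batches * batch_size

print(full_batches, "batches,", skipped, "samples skipped per epoch")
```

Rounding up would serve those leftover samples as well, at the cost of a final batch that is smaller than `batch_size`.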

The only thing left to do is to fit a OneHotEncoder and spin up your training job. Again, it is essential to fit the OneHotEncoder outside of the DataGenerator, on all labels; if you fit it per batch, the encoding will differ from batch to batch.
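
To make that concrete, the sketch below compares per-batch and global encoding. It uses a plain NumPy stand-in rather than scikit-learn's OneHotEncoder; the `one_hot` helper and the label values are hypothetical, and the helper assumes every value occurs in the category list it is given.

```python
import numpy as np

def one_hot(values, categories):
    """One-hot encode `values` against a fixed, sorted array of categories.

    Assumes every value is present in `categories`.
    """
    columns = np.searchsorted(categories, values)
    encoded = np.zeros((len(values), len(categories)), dtype=np.float32)
    encoded[np.arange(len(values)), columns] = 1.0
    return encoded

all_labels = np.array([3, 7, 7, 1, 9, 3])   # hypothetical token ids
batch_a, batch_b = all_labels[:3], all_labels[3:]

# "Fitting" per batch: the categories come from the batch itself.
per_batch_a = one_hot(batch_a, np.unique(batch_a))   # columns mean [3, 7]
per_batch_b = one_hot(batch_b, np.unique(batch_b))   # columns mean [1, 3, 9]

# "Fitting" once: the categories come from all labels.
categories = np.unique(all_labels)                   # [1, 3, 7, 9]
global_a = one_hot(batch_a, categories)
global_b = one_hot(batch_b, categories)

print(per_batch_a.shape, per_batch_b.shape)   # widths differ between batches
print(global_a.shape, global_b.shape)         # same width, same column meaning
```

With per-batch fitting, label 3 lands in a different column in each batch, so the model's output units would have no stable meaning. With one global fit, every batch has the same width and the same column order.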

    from sklearn.preprocessing import OneHotEncoder

    # Fit the encoder once, on all labels, so every batch gets the same columns.
    # scikit-learn 1.2 renamed `sparse` to `sparse_output`; use that name on newer versions.
    enc = OneHotEncoder(handle_unknown='ignore', sparse=False)
    enc.fit(all_labels.reshape(-1,1))

    train_gen = DataGenerator(predictors, labels,
                              enc=enc,
                              batch_size=batch_size,
                              n_classes=vocab_size,
                              max_seq_len=MAX_SEQUENCE_LENGTH,
                              shuffle=True)

    val_gen = DataGenerator(predictors_test, labels_test,
                            enc=enc,
                            batch_size=batch_size,
                            n_classes=vocab_size,
                            max_seq_len=MAX_SEQUENCE_LENGTH,
                            shuffle=True)

    # fit_generator is deprecated since TensorFlow 2.1; from that version on,
    # model.fit accepts the generator directly.
    model.fit_generator(train_gen,
                        epochs=n_epochs,
                        verbose=1,
                        validation_data=val_gen,
                        validation_freq=1,
                        callbacks=[early_stopping_callback],
                        workers=1)

Great success!
