
# Using scikit’s OneHotEncoder only on categorical variables of a data frame

Today I was building a machine learning model and ran into an error when I tried to dummify (one-hot encode) my categorical predictors. It turned out I didn’t really understand how scikit-learn’s OneHotEncoder works. Now I do, and I want to share it with you.

I had a pandas data frame with some categorical columns (of dtype object), and I needed to preprocess them so I could feed them to a random forest classifier.
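Since the categorical predictors here are simply the non-numeric columns, one way to list them without typing the names by hand is pandas' `select_dtypes`. This is a sketch: the toy frame is hypothetical (`column_1` and all values are made up), and it assumes no numeric column, such as integer codes, is meant to be treated as categorical.

```python
import pandas as pd

# Hypothetical toy frame standing in for the real X
X = pd.DataFrame({
    "column_1": [1.0, 2.5, 3.2],
    "column_2": ["red", "green", "red"],
    "column_3": ["S", "M", "L"],
})

# Treat every non-numeric column as categorical
categorical_cols = X.select_dtypes(exclude="number").columns.tolist()
print(categorical_cols)  # ['column_2', 'column_3']
```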

```python
from sklearn.preprocessing import OneHotEncoder

model_OHE = OneHotEncoder(categories = ['column_2', 'column_3'])
dummified = model_OHE.fit_transform(X)
```

The error I ran into was:

```
Shape mismatch: if categories is an array, it has to be of shape (n_features,).
```

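Apparently, the `categories` argument isn’t used to tell OneHotEncoder which columns are categorical. Instead, it explicitly declares which values each column may take: it expects one list of allowed values per input column, so its length must equal the number of columns in X (that is what `(n_features,)` in the error means). By default (`categories='auto'`), the values are simply learned from the data.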
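For comparison, here is a minimal sketch of what `categories` is actually for, assuming a small hypothetical two-column array of colours and sizes. Each inner list fixes the allowed values for one column, in the order the output columns will appear, so a value that never occurs in the data (`'blue'` here) still gets its own column.

```python
import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Hypothetical data: two categorical features, one column each
X_cat = np.array([["red", "S"], ["green", "M"], ["red", "L"]], dtype=object)

# One list of allowed values per input column (2 columns -> 2 lists)
enc = OneHotEncoder(categories=[["red", "green", "blue"], ["S", "M", "L"]])
encoded = enc.fit_transform(X_cat).toarray()  # default output is sparse

print(encoded.shape)  # (3, 6): 3 colours + 3 sizes
```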

What I should have done instead is use a ColumnTransformer. It applies different transformers to different subsets of a data frame’s columns, so you can select exactly the columns you want to one-hot encode. By default, columns that no transformer selects are dropped; if you set the `remainder` argument to `'passthrough'`, they are kept and concatenated after the transformed columns.

```python
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer

model_OHE = ColumnTransformer(
    [('OHE', OneHotEncoder(), ['column_2', 'column_3'])],
    remainder='passthrough'
)
dummified = model_OHE.fit_transform(X)
```
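To see what `remainder='passthrough'` does, here is a sketch on a tiny hypothetical data frame (`column_1` and all values are made up for illustration). For this data the result is a NumPy array whose first columns are the one-hot encodings and whose last column is the untouched `column_1`; with many categories it can instead come back as a sparse matrix, depending on the `sparse_threshold` parameter.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

# Hypothetical toy data: one numeric and two categorical columns
X = pd.DataFrame({
    "column_1": [1.0, 2.5, 3.2],
    "column_2": ["red", "green", "red"],
    "column_3": ["S", "M", "L"],
})

model_OHE = ColumnTransformer(
    [("OHE", OneHotEncoder(), ["column_2", "column_3"])],
    remainder="passthrough",  # keep column_1 instead of dropping it
)
dummified = model_OHE.fit_transform(X)

# One-hot columns first (2 colours + 3 sizes), then column_1 at the end
print(dummified.shape)
print(dummified[:, -1])
```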

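By the way, `fit_transform` returns a NumPy array (or a SciPy sparse matrix), not a data frame, so your column names are lost after one-hot encoding. Here’s a helper I put together to restore sensible column names. It takes the fitted ColumnTransformer, the original data frame and the list of encoded columns: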

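```python
import numpy as np
import pandas as pd

def dummify(OHE, x, columns):
    # Transform with the fitted ColumnTransformer; densify a sparse result
    transformed_array = OHE.transform(x)
    if hasattr(transformed_array, "toarray"):
        transformed_array = transformed_array.toarray()
    # Passthrough columns keep their original order, so avoid set() here
    initial_colnames_keep = [c for c in x.columns if c not in columns]
    # One name per learned category, in the encoder's output order
    new_colnames = np.concatenate(OHE.named_transformers_['OHE'].categories_).tolist()
    all_colnames = new_colnames + initial_colnames_keep
    return pd.DataFrame(transformed_array, index=x.index, columns=all_colnames)
```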

Great success!
