Why scikit learn's fit transform is probably not for you

Scikit-learn’s fit transform paradigm is probably not for you | Stéphan Tulkens

May 17, 2026

If you’ve ever used code from scikit-learn, you will have seen the following pattern:

import numpy as np
from sklearn.preprocessing import StandardScaler

# randn takes the dimensions as separate arguments, not as a tuple
X = np.random.randn(100, 32)

scaler = StandardScaler()
scaler.fit(X)
X_transformed = scaler.transform(X)

# Or equivalently
X_transformed = scaler.fit_transform(X)

For all scikit-learn transformers (1), the fit call sets the internal state of the object, while the transform call uses that state to transform some data into something else. (2) This paradigm is really useful because it makes chaining almost free: any sequence of transformations can be fit and applied by calling fit_transform on each transformation in turn, feeding the output of one into the next.
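To make the chaining claim concrete, here is a minimal sketch. The `Centerer` and `UnitScaler` classes and the `fit_transform_chain` helper are hypothetical stand-ins written for this illustration, not scikit-learn APIs; in scikit-learn itself, `sklearn.pipeline.Pipeline` plays this role.

```python
import numpy as np


class Centerer:
    """Hypothetical transformer: subtracts the column means seen during fit."""

    def fit(self, X):
        self.mean_ = X.mean(axis=0)
        return self

    def transform(self, X):
        return X - self.mean_

    def fit_transform(self, X):
        return self.fit(X).transform(X)


class UnitScaler:
    """Hypothetical transformer: divides by the column standard deviations seen during fit."""

    def fit(self, X):
        self.std_ = X.std(axis=0)
        return self

    def transform(self, X):
        return X / self.std_

    def fit_transform(self, X):
        return self.fit(X).transform(X)


def fit_transform_chain(steps, X):
    """Fit each step on the output of the previous one and return the final output."""
    for step in steps:
        X = step.fit_transform(X)
    return X


rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=3.0, size=(100, 4))

steps = [Centerer(), UnitScaler()]
X_out = fit_transform_chain(steps, X)
```

Because every step exposes the same three methods, the helper does not need to know anything about what the steps actually do.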

Conflation between construction and usage

The main point I’ll be making in this article is that scikit-learn’s fit transform paradigm mixes up the factory pattern, that is, an object that instantiates other objects, with the actual objects. This is used really well by scikit-learn, but probably doesn’t fit your codebase.

To illustrate, let’s reimplement the StandardScaler using numpy: (3)

from __future__ import annotations

import numpy as np


class StandardScaler:
    def __init__(self, with_mean: bool = True, with_std: bool = True) -> None:
        self.mean: None | np.ndarray = None
        self.std: None | np.ndarray = None
        self.with_mean = with_mean
        self.with_std = with_std

    def fit(self, X: np.ndarray) -> StandardScaler:
        if self.with_mean:
            self.mean = X.mean(0)
        if self.with_std:
            self.std = X.std(0)

        return self

    @property
    def _is_fit(self) -> bool:
        if self.with_mean and self.mean is None:
            return False
        if self.with_std and self.std is None:
            return False
        return True

    def transform(self, X: np.ndarray) -> np.ndarray:
        if not self._is_fit:
            raise ValueError("Standardscaler has not been fit")
        if self.with_mean:
            X = X - self.mean
        if self.with_std:
            X = X / self.std
        return X

    def fit_transform(self, X: np.ndarray) -> np.ndarray:
        self.fit(X)
        return self.transform(X)

Let’s first talk about the initializer. In a scikit-learn initializer, you are only supposed to set the so-called hyperparameters of a transformer or estimator. That is, you should only set attributes that do not depend on the data you will use to fit the model. So, in our case, with_mean and with_std determine the behavior of the StandardScaler that is produced by fitting the StandardScaler on some data; if we set with_mean to False, we actually get a different object than we would get if we set it to True.

Second, note that the fit method is destructive. It erases the original state and introduces a completely new state. From a Python perspective, however, the same object is returned; it’s only the internal state that is reset.
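A small sketch of what "destructive" means in practice. `MeanCenterer` is a hypothetical, stripped-down transformer written only for this illustration; it keeps just the mean so that the effect is easy to see.

```python
import numpy as np


class MeanCenterer:
    """Hypothetical, stripped-down transformer used only for this illustration."""

    def __init__(self) -> None:
        self.mean = None

    def fit(self, X):
        self.mean = X.mean(axis=0)  # overwrites whatever was stored before
        return self

    def transform(self, X):
        return X - self.mean


A = np.array([[0.0], [2.0]])    # column mean 1.0
B = np.array([[10.0], [30.0]])  # column mean 20.0

centerer = MeanCenterer()
fitted_on_a = centerer.fit(A)
mean_after_a = centerer.mean.copy()

fitted_on_b = centerer.fit(B)
mean_after_b = centerer.mean.copy()

# Both names point at the same object, so the state learned from A is gone.
same_object = fitted_on_a is fitted_on_b
result = fitted_on_a.transform(A)  # centered with B's mean, not A's
```

Anyone still holding `fitted_on_a` after the second `fit` call silently gets a transformer that was fit on `B`.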

Third, note that there is no need to store the hyperparameters once you’ve fit the transformer. (4)

Fourth, for a given StandardScaler, it is impossible to know from its type alone whether it has been fit or not. So, whenever you work with scikit-learn’s internals, you’ll have to continuously check whether the estimators and transformers you work with actually have their internal state set.
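The following sketch shows why this check can only happen at runtime. `Scaler` and `center_batch` are hypothetical names chosen for this example; scikit-learn's own guard for the same situation is `sklearn.utils.validation.check_is_fitted`.

```python
from typing import Optional

import numpy as np


class Scaler:
    """Hypothetical fit/transform-style scaler; `mean` stays None until fit is called."""

    def __init__(self) -> None:
        self.mean: Optional[np.ndarray] = None

    def fit(self, X: np.ndarray) -> "Scaler":
        self.mean = X.mean(axis=0)
        return self

    def transform(self, X: np.ndarray) -> np.ndarray:
        if self.mean is None:
            raise ValueError("Scaler has not been fit")
        return X - self.mean


def center_batch(scaler: Scaler, batch: np.ndarray) -> np.ndarray:
    """The annotation accepts fit and unfit scalers alike, so any failure surfaces at runtime."""
    return scaler.transform(batch)


batch = np.ones((2, 3))
unfit_scaler = Scaler()
fit_scaler = Scaler().fit(batch)

# Both objects have exactly the same type; only their attributes differ.
same_type = type(unfit_scaler) is type(fit_scaler)

try:
    center_batch(unfit_scaler, batch)
    failed_at_runtime = False
except ValueError:
    failed_at_runtime = True

centered = center_batch(fit_scaler, batch)
```

A type checker accepts both calls to `center_batch`, which is exactly the problem: the distinction between fit and unfit lives in the object's state, not in its type.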

Fifth, when you write your own transformers and estimators, it is very easy to incorrectly implement this state. (5)

Splitting out the factory

So, now on to my main thesis: all of these problems can be avoided by acknowledging that StandardScaler is both a factory and the object that is constructed by the factory. If we split it up into two separate classes, we end up with much cleaner code.

from __future__ import annotations

import numpy as np


class StandardScaler:
    def __init__(self, mean: np.ndarray | None, std: np.ndarray | None) -> None:
        self.mean: np.ndarray | None = mean
        self.std: np.ndarray | None = std

    def transform(self, X: np.ndarray) -> np.ndarray:
        if self.mean is not None:
            X = X - self.mean
        if self.std is not None:
            X = X / self.std
        return X


class StandardScalerFactory:
    def __init__(self, with_mean: bool = True, with_std: bool = True) -> None:
        self.with_mean = with_mean
        self.with_std = with_std

    def fit(self, X: np.ndarray) -> StandardScaler:
        mean, std = None, None
        if self.with_mean:
            mean = X.mean(0)
        if self.with_std:
            std = X.std(0)

        return StandardScaler(mean, std)

    def fit_transform(self, X: np.ndarray) -> tuple[StandardScaler, np.ndarray]:
        scaler = self.fit(X)
        return scaler, scaler.transform(X)

As you can see, we’ve changed the structure considerably. fit now returns an object which implements transform, and only implements transform. fit_transform returns a tuple, the first item of which is the fit object, the second of which is the transformed data. This still allows...
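A short usage sketch of the split design. The two classes are condensed copies of the ones above so that the snippet runs on its own; the data and variable names (`X_train`, `X_other`, and so on) are made up for illustration.

```python
from typing import Optional, Tuple

import numpy as np


class StandardScaler:
    """Fitted object: holds the statistics and only transforms."""

    def __init__(self, mean: Optional[np.ndarray], std: Optional[np.ndarray]) -> None:
        self.mean = mean
        self.std = std

    def transform(self, X: np.ndarray) -> np.ndarray:
        if self.mean is not None:
            X = X - self.mean
        if self.std is not None:
            X = X / self.std
        return X


class StandardScalerFactory:
    """Factory: holds the hyperparameters and only fits."""

    def __init__(self, with_mean: bool = True, with_std: bool = True) -> None:
        self.with_mean = with_mean
        self.with_std = with_std

    def fit(self, X: np.ndarray) -> StandardScaler:
        mean = X.mean(0) if self.with_mean else None
        std = X.std(0) if self.with_std else None
        return StandardScaler(mean, std)

    def fit_transform(self, X: np.ndarray) -> Tuple[StandardScaler, np.ndarray]:
        scaler = self.fit(X)
        return scaler, scaler.transform(X)


rng = np.random.default_rng(0)
X_train = rng.normal(loc=2.0, scale=4.0, size=(100, 3))
X_other = rng.normal(loc=-1.0, scale=0.5, size=(50, 3))

factory = StandardScalerFactory()
scaler_train, X_train_scaled = factory.fit_transform(X_train)
scaler_other = factory.fit(X_other)  # a new object; scaler_train is untouched

# Every StandardScaler is usable as soon as it exists, so no is-fit check is needed.
X_other_scaled = scaler_train.transform(X_other)
```

Calling `fit` a second time no longer overwrites anything: each call produces an independent scaler, and the factory itself never changes.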
