FIX SimpleImputer uses dtype seen in fit for transform #22063
LGTM. Thanks for checking it more closely. I misdiagnosed the bug :)
Thank you, @thomasjpfan.
doc/whats_new/v1.1.rst
Outdated
- |Fix| :class:`impute.SimpleImputer` now uses the dtype seen in `fit` for
  `transform`. :pr:`22063` by `Thomas Fan`_.
This should be moved to doc/whats_new/v1.2
Actually, it should go in 1.1.2. I assume that we will do another bug-fix release at some point.
This is likely too big of a change for 1.1.2. Currently, fitting on float64 and then transforming float32 data returns float32:
import numpy as np
from sklearn.impute import SimpleImputer
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
X = np.asarray([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=np.float64)
imp_mean.fit(X)
X_test = np.asarray([[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=np.float32)
X_trans = imp_mean.transform(X_test)
print(X_trans.dtype)
# float32
Looking at this again, I think it's better to error when fitting on an object dtype but transforming a non-object dtype.
LGTM
@@ -278,6 +278,10 @@ def _validate_input(self, X, in_fit):
         else:
             dtype = FLOAT_DTYPES

+        if not in_fit and self._fit_dtype.kind == "O":
+            # Use object dtype if fitted on object dtypes
+            dtype = self._fit_dtype
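For readers without the full file open, the dtype choice in this hunk can be sketched in isolation. This is only an illustration: `pick_validation_dtype` is a hypothetical helper, not a scikit-learn function, and both the `strategy` branch and the contents of `FLOAT_DTYPES` are assumptions about the code around the hunk.

```python
import numpy as np

# Assumed to mirror scikit-learn's FLOAT_DTYPES tuple.
FLOAT_DTYPES = (np.float64, np.float32, np.float16)


def pick_validation_dtype(strategy, in_fit, fit_dtype=None):
    """Hypothetical stand-in for the dtype selection in _validate_input."""
    # Assumption: non-numeric strategies let validation keep the input dtype.
    if strategy in ("most_frequent", "constant"):
        dtype = None
    else:
        dtype = FLOAT_DTYPES

    # The lines added by the hunk: reuse the fit dtype only when it was object.
    if not in_fit and fit_dtype is not None and fit_dtype.kind == "O":
        dtype = fit_dtype
    return dtype


# Fitted on object data: transform validates against the object dtype.
print(pick_validation_dtype("most_frequent", False, np.dtype(object)))
# Fitted on float64 data: transform still accepts any float dtype.
print(pick_validation_dtype("mean", False, np.dtype(np.float64)))
```

Because the override applies only to object dtypes, numeric inputs keep going through the float dtype list, so a float32 input to `transform` is not upcast.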
I updated this PR to use the fit dtype only if the dtype during fit is object.
This preserves the current behavior of "fit on float64 -> transform on float32 returns float32":
import numpy as np
from sklearn.impute import SimpleImputer
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
X = np.asarray([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=np.float64)
imp_mean.fit(X)
X_test = np.asarray([[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=np.float32)
X_trans = imp_mean.transform(X_test)
print(X_trans.dtype)
# float32
This is a good point. Do you think that we should add a unit test regarding the bitness preservation since we try to have it in other cases?
I added the test here: e97a8df (#22063).
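A check of this kind might look roughly like the sketch below. It is not the test from the referenced commit: the helper name `check_transform_keeps_float_dtype` is made up for illustration, and the data is reused from the snippet earlier in this thread.

```python
import numpy as np
from sklearn.impute import SimpleImputer


def check_transform_keeps_float_dtype(dtype_test):
    """Fit on float64, transform `dtype_test` data, and check the output dtype."""
    X = np.asarray([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=np.float64)
    imp = SimpleImputer(missing_values=np.nan, strategy="mean").fit(X)

    X_test = np.asarray(
        [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=dtype_test
    )
    X_trans = imp.transform(X_test)

    # The bitness of the transform input should survive imputation.
    assert X_trans.dtype == dtype_test
    return X_trans


for dtype_test in (np.float32, np.float64):
    check_transform_keeps_float_dtype(dtype_test)
```

In a pytest suite, the loop would more naturally be written with `pytest.mark.parametrize` over the dtypes.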
Thanks @thomasjpfan
…22063) Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>


Reference Issues/PRs
Fixes #19572
What does this implement/fix? Explain your changes.
This PR adjusts `SimpleImputer` to remember the dtype it used in `fit` and uses the same dtype for `transform`.
CC @glemaitre