The Wayback Machine - https://web.archive.org/web/20230427202546/https://github.com/scikit-learn/scikit-learn/pull/22063

FIX SimpleImputer uses dtype seen in fit for transform #22063

Merged
merged 10 commits into from Jun 1, 2022

Conversation

Member

@thomasjpfan thomasjpfan commented Dec 22, 2021

Reference Issues/PRs

Fixes #19572

What does this implement/fix? Explain your changes.

This PR adjusts SimpleImputer to remember the dtype it used in fit and uses the same dtype for transform.

CC @glemaitre

Contributor

@glemaitre glemaitre left a comment

LGTM. Thanks for checking it more closely. I misdiagnosed the bug :)

Member

@jjerphan jjerphan left a comment

Thank you, @thomasjpfan.

- |Fix| :class:`impute.SimpleImputer` now uses the dtype seen in `fit` for
  `transform`. :pr:`22063` by `Thomas Fan`_.

Member

This should be moved to doc/whats_new/v1.2

Contributor

Actually in 1.1.2. I assume that we will do another bug fix release at some point.

Member Author

This is likely too big of a change for 1.1.2. Currently, fitting on float64 and then transforming float32 input returns float32:

import numpy as np
from sklearn.impute import SimpleImputer
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
X = np.asarray([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=np.float64)
imp_mean.fit(X)

X_test = np.asarray([[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=np.float32)
X_trans = imp_mean.transform(X_test)

print(X_trans.dtype)
# float32

Looking at this again, I think it's better to raise an error when fitting on an object dtype but transforming a non-object dtype.

Contributor

@glemaitre glemaitre left a comment

LGTM

@glemaitre glemaitre added this to the 1.1.2 milestone May 30, 2022
@@ -278,6 +278,10 @@ def _validate_input(self, X, in_fit):
        else:
            dtype = FLOAT_DTYPES

        if not in_fit and self._fit_dtype.kind == "O":
            # Use object dtype if fitted on object dtypes
            dtype = self._fit_dtype
Member Author

I updated this PR to use the fit dtype only if the dtype seen during fit is object.

This preserves the current behavior of "fit on float64 -> transform on float32 returns float32":

import numpy as np
from sklearn.impute import SimpleImputer
imp_mean = SimpleImputer(missing_values=np.nan, strategy='mean')
X = np.asarray([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=np.float64)
imp_mean.fit(X)

X_test = np.asarray([[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=np.float32)
X_trans = imp_mean.transform(X_test)

print(X_trans.dtype)
# float32
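
For readers who want the rule in isolation, here is a minimal sketch of the dtype choice described in this comment, written outside scikit-learn. `choose_transform_dtype` is a hypothetical helper, not part of the library, and the `FLOAT_DTYPES` tuple below is an assumed stand-in for the constant referenced in the diff. It only models the numeric-strategy branch plus the object override, so treat it as an illustration rather than the actual implementation.

```python
import numpy as np

# Assumed stand-in for the FLOAT_DTYPES constant referenced in the diff.
FLOAT_DTYPES = (np.float64, np.float32, np.float16)


def choose_transform_dtype(fit_dtype, input_dtype):
    """Hypothetical helper: pick the dtype used to validate input in transform."""
    fit_dtype = np.dtype(fit_dtype)
    input_dtype = np.dtype(input_dtype)

    if fit_dtype.kind == "O":
        # Fitted on object data: keep object so non-numeric statistics fit.
        return fit_dtype

    if any(input_dtype == np.dtype(d) for d in FLOAT_DTYPES):
        # Already a supported float dtype: keep its bitness.
        return input_dtype

    # Anything else is converted to the first dtype in the list.
    return np.dtype(FLOAT_DTYPES[0])


print(choose_transform_dtype(np.float64, np.float32))  # float32
print(choose_transform_dtype(object, np.float64))      # object
```

The object check comes first so that it overrides whatever dtype the transform input happens to have, which is the only case where the fit dtype matters in this sketch.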

Contributor

This is a good point. Do you think that we should add a unit test regarding the bitness preservation since we try to have it in other cases?

Member Author

I added the test here: e97a8df (#22063).
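
The linked commit is not reproduced here. A check of this kind could look like the following sketch, which reuses the arrays from the example earlier in this thread. The function name `check_transform_preserves_float_dtype` is hypothetical, and the sketch assumes NumPy and scikit-learn are installed.

```python
import numpy as np
from sklearn.impute import SimpleImputer


def check_transform_preserves_float_dtype(dtype_test):
    """Fit on float64, transform input of dtype_test, and check the output dtype."""
    X = np.asarray([[7, 2, 3], [4, np.nan, 6], [10, 5, 9]], dtype=np.float64)
    imp = SimpleImputer(missing_values=np.nan, strategy="mean").fit(X)

    X_test = np.asarray(
        [[np.nan, 2, 3], [4, np.nan, 6], [10, np.nan, 9]], dtype=dtype_test
    )
    X_trans = imp.transform(X_test)

    # The output keeps the bitness of the transform input, not of the fit data.
    assert X_trans.dtype == dtype_test
    # Every missing entry has been filled in.
    assert not np.isnan(X_trans).any()
    return X_trans


for dtype_test in (np.float32, np.float64):
    check_transform_preserves_float_dtype(dtype_test)
```

In a pytest suite the loop would more naturally be a `pytest.mark.parametrize` over the dtypes; the plain loop keeps the sketch free of extra dependencies.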

@glemaitre glemaitre merged commit 8ea2997 into scikit-learn:main Jun 1, 2022
15 of 25 checks passed
@glemaitre
Contributor

Thanks @thomasjpfan

ogrisel pushed a commit to ogrisel/scikit-learn that referenced this pull request Jul 11, 2022
…22063)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
glemaitre added a commit to glemaitre/scikit-learn that referenced this pull request Aug 4, 2022
…22063)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
glemaitre added a commit that referenced this pull request Aug 5, 2022
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
mathijs02 pushed a commit to mathijs02/scikit-learn that referenced this pull request Dec 27, 2022
…22063)

Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>