FIX Draw indices using sample_weight in Forest #31529
Conversation
```python
if sample_weight is None:
    sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap)
```
There are two options for the random draw of indices when `sample_weight=None`:

1. Convert to all ones:

```python
if sample_weight is None:
    sample_weight = np.ones(n_samples)
normalized_sample_weight = sample_weight / np.sum(sample_weight)
sample_indices = random_instance.choice(
    n_samples, n_samples_bootstrap, replace=True, p=normalized_sample_weight
)
```

2. Use the old code path when `sample_weight=None`:

```python
if sample_weight is None:
    sample_indices = random_instance.randint(0, n_samples, n_samples_bootstrap)
else:
    normalized_sample_weight = sample_weight / np.sum(sample_weight)
    sample_indices = random_instance.choice(
        n_samples,
        n_samples_bootstrap,
        replace=True,
        p=normalized_sample_weight,
    )
```

The two options use different rng functions: `choice` with uniform `p` for 1. and `randint` for 2. They are statistically equivalent, but they don't give the same deterministic output for a given random state.

The benefit of 2. is that the code is backward compatible when `sample_weight=None`: a fit without `sample_weight` reproduces the same fit as `main` for a given `random_state`.

The benefit of 1. is that `sample_weight=None` and `sample_weight=np.ones(n_samples)` give the same fit for a given `random_state`.

Here we chose 2.
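To make the determinism caveat concrete, here is a small standalone sketch (plain NumPy, outside scikit-learn) comparing the two draws under the same seed. Both are uniform over the sample indices, but they consume the generator's stream differently, so the realizations generally differ:

```python
import numpy as np

n_samples, n_samples_bootstrap = 8, 8

# Option 1: uniform probabilities through choice(..., p=...)
rng1 = np.random.RandomState(0)
p = np.ones(n_samples) / n_samples
idx_choice = rng1.choice(n_samples, n_samples_bootstrap, replace=True, p=p)

# Option 2: the legacy randint code path
rng2 = np.random.RandomState(0)
idx_randint = rng2.randint(0, n_samples, n_samples_bootstrap)

# Both draws are uniform over [0, n_samples), but for a fixed seed the
# two generators typically produce different index sequences.
print(idx_choice)
print(idx_randint)
```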
I haven't had the time to finish my review today, but this looks great: I tried running the notebook of github.com/snath-xoc/sample-weight-audit-nondet/ against this branch and I confirm the statistical tests pass for `RandomForestClassifier`/`Regressor` and `ExtraTreesClassifier`/`Regressor`.
```diff
@@ -324,13 +325,13 @@ def test_parallel_fit(global_random_seed):
 def test_sample_weight(global_random_seed):
     """Tests sample_weight parameter of VotingClassifier"""
     clf1 = LogisticRegression(random_state=global_random_seed)
-    clf2 = RandomForestClassifier(n_estimators=10, random_state=global_random_seed)
+    clf2 = GradientBoostingClassifier(n_estimators=10, random_state=global_random_seed)
```
Why this change?
```diff
@@ -1167,7 +1167,7 @@ def test_class_weights(name):
     # Iris is balanced, so no effect expected for using 'balanced' weights
     clf1 = ForestClassifier(random_state=0)
-    clf1.fit(iris.data, iris.target)
+    clf1.fit(iris.data, iris.target, sample_weight=np.ones_like(iris.target))
```
Please add an inline comment to explain why leaving `sample_weight=None` does not work for this test.
Part of #16298. Similar to #31414 (Bagging estimators) but for Forest estimators.
What does this implement/fix? Explain your changes.
When subsampling is activated (`bootstrap=True`), `sample_weight` is now used as probabilities to draw the indices. Forest estimators then pass the statistical repeated/weighted equivalence test.

Comments

This PR does not fix Forest estimators when `bootstrap=False` (no subsampling): `sample_weight` is still passed to the decision trees. Forest estimators then fail the statistical repeated/weighted equivalence test because the individual trees also fail this test (probably because of tied splits in decision trees #23728).
TODO
- `sample_weight=None` case
- `max_samples` as done in FIX Draw indices using sample_weight in Bagging #31414
- `class_weight = "balanced"` as done in Fix linear svc handling sample weights under class_weight="balanced" #30057
- `class_weight = "balanced_subsample"`