Modify rule S6709: Add how to fix it for Scikit-learn (#3883)

This commit is contained in:
David Kunzmann 2024-05-07 14:21:34 +02:00 committed by GitHub
parent 621b7ce90e
commit 86d6b7c75b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 23 deletions

View File

@ -91,6 +91,7 @@
* Jinja * Jinja
* lxml * lxml
* MySQL Connector/Python * MySQL Connector/Python
* Numpy
* Paramiko * Paramiko
* pyca * pyca
* PyCrypto * PyCrypto
@ -105,6 +106,7 @@
* PyYAML * PyYAML
* Requests * Requests
* Scrypt * Scrypt
* Scikit-Learn
* SignXML * SignXML
* SQLAlchemy * SQLAlchemy
* ssl * ssl

View File

@ -0,0 +1,27 @@
== How to fix it in Numpy
To fix this issue, provide a predictable seed to the random number generator.
=== Code examples
==== Noncompliant code example
[source,python,diff-id=1,diff-type=noncompliant]
----
import numpy as np
def foo():
generator = np.random.default_rng() # Noncompliant: no seed parameter is provided
x = generator.uniform()
----
==== Compliant solution
[source,python,diff-id=1,diff-type=compliant]
----
import numpy as np
def foo():
generator = np.random.default_rng(42) # Compliant
x = generator.uniform()
----

View File

@ -0,0 +1,29 @@
== How to fix it in Scikit-Learn
To fix this issue, provide a predictable seed to the estimator or the utility function.
=== Code examples
==== Noncompliant code example
[source,python,diff-id=2,diff-type=noncompliant]
----
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
X_train, _, y_train, _ = train_test_split(X, y) # Noncompliant: no seed parameter is provided
----
==== Compliant solution
[source,python,diff-id=2,diff-type=compliant]
----
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import numpy as np
rng = np.random.default_rng(42)
X, y = load_iris(return_X_y=True)
X_train, _, y_train, _ = train_test_split(X, y, random_state=rng.integers(1)) # Compliant
----

View File

@ -26,38 +26,20 @@ Note that a global seed for `RandomState` can be set using `numpy.random.seed` o
In contexts that are not related to data science and machine learning, having a predictable seed may not be the desired behavior. Therefore, this rule will only raise issues if machine learning and data science libraries are being used. In contexts that are not related to data science and machine learning, having a predictable seed may not be the desired behavior. Therefore, this rule will only raise issues if machine learning and data science libraries are being used.
== How to fix it
To fix this issue, provide a predictable seed to the random number generator. // How to fix it section
=== Code examples include::how-to-fix-it/numpy.adoc[]
==== Noncompliant code example include::how-to-fix-it/sklearn.adoc[]
[source,python,diff-id=1,diff-type=noncompliant]
----
import numpy as np
def foo():
generator = np.random.default_rng() # Noncompliant: no seed parameter is provided
x = generator.uniform()
----
==== Compliant solution
[source,python,diff-id=1,diff-type=compliant]
----
import numpy as np
def foo():
generator = np.random.default_rng(42) # Compliant
x = generator.uniform()
----
== Resources == Resources
=== Documentation === Documentation
* NumPy documentation - https://numpy.org/neps/nep-0019-rng-policy.html[NEP 19 RNG Policy] * NumPy documentation - https://numpy.org/neps/nep-0019-rng-policy.html[NEP 19 RNG Policy]
* Scikit-learn documentation - https://scikit-learn.org/stable/glossary.html#term-random_state[Glossary random_state]
* Scikit-learn documentation - https://scikit-learn.org/stable/common_pitfalls.html#controlling-randomness[Controlling randomness]
=== Standards === Standards
@ -66,3 +48,4 @@ def foo():
=== Related rules === Related rules
* S6711 - `numpy.random.Generator` should be preferred to `numpy.random.RandomState` * S6711 - `numpy.random.Generator` should be preferred to `numpy.random.RandomState`