rspec/rules/S6970/python/rule.adoc
2024-04-15 15:18:53 +02:00

This rule raises an issue when a Scikit-learn estimator method that yields results is called before the estimator's `fit` method.

== Why is this an issue?

When using the Scikit-learn library, it is crucial to train the model before attempting to get results. Failing to do so can lead to incorrect results or runtime errors.
Training is done with the `fit` method, and results can then be retrieved, for example, with the `predict` method.
If `predict` is called without a prior call to `fit`, a `NotFittedError` is raised.
In this case the error is unambiguous, but in other cases the error raised can be less explicit, as in the following example:

[source,python]
----
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV
iris = datasets.load_iris()
parameters = {'kernel':('linear', 'rbf'), 'C':[1, 10]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)
results = clf.cv_results_ # raises an AttributeError
----

In the example above, failing to train the model on the iris dataset with the `fit` method results in a more cryptic error: an `AttributeError` stating that ``++cv_results_++`` is not an attribute of `GridSearchCV`. This is because the attribute is only set once `fit` has been called.
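
For comparison, a minimal sketch of the same grid search with the training step added. It assumes the iris features and labels (`iris.data` and `iris.target`) are used as training data, which the example above leaves implicit. After `fit` returns, ``++cv_results_++`` is available as a dictionary with one entry per result column:

[source,python]
----
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)

# Train first: the search results only exist once fit has run the search.
clf.fit(iris.data, iris.target)
results = clf.cv_results_  # dict: one list/array per result column
----
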
This rule raises an issue when any of the following methods is called without a prior call to `fit`:

* `predict`
* `predict_proba`
* `predict_log_proba`
* `score`
* `decision_function`
* `transform`
* `inverse_transform`
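
The sketch below illustrates that each listed method fails on an estimator that has never been trained. It assumes `LogisticRegression` as a representative classifier (it provides `predict`, `predict_proba`, `predict_log_proba`, `score` and `decision_function`) and `StandardScaler` as a representative transformer (for `transform` and `inverse_transform`); the helper `raises_not_fitted` is hypothetical:

[source,python]
----
from sklearn import datasets
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

X, y = datasets.load_iris(return_X_y=True)

def raises_not_fitted(method, *args):
    """Hypothetical helper: True if method(*args) raises NotFittedError.

    Returns False if the call succeeds; any other exception propagates.
    """
    try:
        method(*args)
    except NotFittedError:
        return True
    return False

clf = LogisticRegression()  # never fitted
scaler = StandardScaler()   # never fitted

checks = {
    "predict": raises_not_fitted(clf.predict, X),
    "predict_proba": raises_not_fitted(clf.predict_proba, X),
    "predict_log_proba": raises_not_fitted(clf.predict_log_proba, X),
    "score": raises_not_fitted(clf.score, X, y),
    "decision_function": raises_not_fitted(clf.decision_function, X),
    "transform": raises_not_fitted(scaler.transform, X),
    "inverse_transform": raises_not_fitted(scaler.inverse_transform, X),
}
print(checks)  # every value is True
----
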

== How to fix it

To fix the issue, train the model by calling its `fit` method before calling any of the methods listed above.
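
The same fix applies to the `transform` and `inverse_transform` methods. A minimal sketch, assuming `StandardScaler` as a representative transformer (the code examples in this rule use `KMeans` instead):

[source,python]
----
from sklearn import datasets
from sklearn.preprocessing import StandardScaler

X = datasets.load_iris().data

scaler = StandardScaler()
scaler.fit(X)                                    # learns the per-feature mean and scale
X_scaled = scaler.transform(X)                   # fit was called first
X_restored = scaler.inverse_transform(X_scaled)  # undoes the scaling
----
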

=== Code examples

==== Noncompliant code example

[source,python,diff-id=1,diff-type=noncompliant]
----
from sklearn import datasets
from sklearn.cluster import KMeans
iris = datasets.load_iris()
X = iris.data
kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.predict(X) # Noncompliant: raises a NotFittedError
----

==== Compliant solution

[source,python,diff-id=1,diff-type=compliant]
----
from sklearn import datasets
from sklearn.cluster import KMeans
iris = datasets.load_iris()
X = iris.data
kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.fit(X)
kmeans.predict(X) # Compliant
----
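
When it is not obvious from the surrounding code whether an estimator has already been trained, Scikit-learn's `check_is_fitted` function from `sklearn.utils.validation` raises a `NotFittedError` for an untrained estimator and can be used to decide whether `fit` still needs to be called. A sketch, where the helper `predict_with_training` is hypothetical and assumes the training data is still available at prediction time:

[source,python]
----
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted

X = datasets.load_iris().data

def predict_with_training(estimator, X_train, X_new):
    """Hypothetical helper: fit estimator on X_train unless it is already fitted."""
    try:
        check_is_fitted(estimator)  # raises NotFittedError if never trained
    except NotFittedError:
        estimator.fit(X_train)
    return estimator.predict(X_new)

kmeans = KMeans(n_clusters=3, random_state=42)
labels = predict_with_training(kmeans, X, X)
----

Calling the helper a second time skips `fit`, because `check_is_fitted` now finds the fitted attributes (such as ``++cluster_centers_++``) set by the first call.
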

== Resources

=== Documentation

* Scikit-learn Documentation - https://scikit-learn.org/stable/glossary.html#term-fit[term fit reference]
* Scikit-learn Documentation - https://scikit-learn.org/stable/modules/generated/sklearn.exceptions.NotFittedError.html#sklearn.exceptions.NotFittedError[NotFittedError reference]

ifdef::env-github,rspecator-view[]

Implementation details:

* Only when one of the methods listed above is called should the rule check whether `fit` was previously called on the same object.
* Issue location: the name of the called method (from the list above).
* Message: Call the fit method on this object before retrieving the results.
* Quickfix: not applicable (it could be too tricky, as the parameters of `fit` and `predict` may differ).

endif::env-github,rspecator-view[]