This rule raises an issue when a Scikit-learn method that yields results is called without a prior call to the `fit` method.

== Why is this an issue?

When using the Scikit-learn library, it is crucial to train the model before
attempting to get results from it. Failing to do so can lead to incorrect results or runtime errors.
Training is done with the `fit` method, and results can be retrieved with methods such as `predict`.

If the `predict` method is called without a prior call to the `fit` method, a `NotFittedError` is raised.
In this case the error is unambiguous, but in other cases the error raised can be less explicit.

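As a minimal sketch, assuming the `svm.SVC` estimator and iris data used elsewhere in this description, the code below catches the `NotFittedError` raised by `predict` on an untrained model. `NotFittedError` lives in `sklearn.exceptions` and derives from both `ValueError` and `AttributeError`, so generic handlers for either of those also catch it.

[source,python]
----
from sklearn import datasets, svm
from sklearn.exceptions import NotFittedError

iris = datasets.load_iris()
X = iris.data

svc = svm.SVC()  # created but never trained with fit

try:
    svc.predict(X)
    error_raised = None
except NotFittedError as exc:
    # The default message names the estimator and asks for a call to fit
    error_raised = exc

# NotFittedError derives from both ValueError and AttributeError
print(type(error_raised).__name__)
print(issubclass(NotFittedError, ValueError), issubclass(NotFittedError, AttributeError))
----
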
[source,python]
----
from sklearn import svm, datasets
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)

results = clf.cv_results_ # raises an AttributeError
----

In the example above, the model is never trained on the iris dataset with the
`fit` method, which results in a more cryptic `AttributeError` stating that ``++cv_results_++`` is not an
attribute of `GridSearchCV`. This happens because the attribute is only set once `fit`
has been called.

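For comparison, the sketch below trains the same grid search before reading ``++cv_results_++``. Passing `iris.target` as the labels is an assumption, since the original example does not show the `fit` arguments; with two kernels and two values of `C`, the search evaluates four parameter combinations.

[source,python]
----
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

iris = datasets.load_iris()
parameters = {'kernel': ('linear', 'rbf'), 'C': [1, 10]}
svc = svm.SVC()
clf = GridSearchCV(svc, parameters)

# Train first: fit runs the cross-validated search and sets cv_results_
clf.fit(iris.data, iris.target)

results = clf.cv_results_  # no error: the attribute now exists
print(len(results['params']))  # one entry per parameter combination
print(clf.best_params_)
----
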
This rule will raise an issue when any of the following methods is called without a prior call to `fit`:

* `predict`
* `predict_proba`
* `predict_log_proba`
* `score`
* `decision_function`
* `transform`
* `inverse_transform`

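As a rough illustration, not the rule's implementation, the sketch below calls each listed method on untrained estimators and records whether it raises `NotFittedError`. Using `LogisticRegression` for the prediction methods and `StandardScaler` for the transformation methods is an assumption made for the example; between them, they expose every method in the list.

[source,python]
----
from sklearn import datasets
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

iris = datasets.load_iris()
X, y = iris.data, iris.target

classifier = LogisticRegression()  # never fitted
scaler = StandardScaler()          # never fitted

# Each entry maps a listed method name to a call of that method on an unfitted estimator
calls = {
    "predict": lambda: classifier.predict(X),
    "predict_proba": lambda: classifier.predict_proba(X),
    "predict_log_proba": lambda: classifier.predict_log_proba(X),
    "score": lambda: classifier.score(X, y),
    "decision_function": lambda: classifier.decision_function(X),
    "transform": lambda: scaler.transform(X),
    "inverse_transform": lambda: scaler.inverse_transform(X),
}

raised = {}
for name, call in calls.items():
    try:
        call()
        raised[name] = False
    except NotFittedError:
        raised[name] = True

print(raised)
----
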
== How to fix it

To fix the issue, train the model with the `fit` method before calling any of the methods listed above.

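Because Scikit-learn's `fit` returns the fitted estimator itself, training and prediction can also be chained in a single expression, and some estimators, such as `KMeans`, offer `fit_predict` to do both in one call. The sketch below shows both forms; passing `n_init=10` explicitly is an assumption made only to avoid version-dependent default warnings.

[source,python]
----
from sklearn import datasets
from sklearn.cluster import KMeans

iris = datasets.load_iris()
X = iris.data

# fit returns the fitted estimator, so predict can be chained directly
chained_labels = KMeans(n_clusters=3, random_state=42, n_init=10).fit(X).predict(X)

# fit_predict trains the model and returns the cluster labels in one call
kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)
combined_labels = kmeans.fit_predict(X)

print(chained_labels.shape, combined_labels.shape)
----
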
=== Code examples

==== Noncompliant code example

[source,python,diff-id=1,diff-type=noncompliant]
----
from sklearn import datasets
from sklearn.cluster import KMeans

iris = datasets.load_iris()
X = iris.data

kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.predict(X) # Noncompliant: raises a NotFittedError
----

==== Compliant solution

[source,python,diff-id=1,diff-type=compliant]
----
from sklearn import datasets
from sklearn.cluster import KMeans

iris = datasets.load_iris()
X = iris.data

kmeans = KMeans(n_clusters=3, random_state=42)
kmeans.fit(X)
kmeans.predict(X) # Compliant
----

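When an estimator is created in one place and used in another, `check_is_fitted` from `sklearn.utils.validation` can confirm that it has been trained before results are requested; it raises `NotFittedError` otherwise. The helper `predict_clusters` below is hypothetical and only sketches where such a guard could sit, reusing the `KMeans` setup from the compliant solution.

[source,python]
----
from sklearn import datasets
from sklearn.cluster import KMeans
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


def predict_clusters(model, X):
    """Hypothetical helper: predict only if the model has already been fitted."""
    check_is_fitted(model)  # raises NotFittedError if fit was never called
    return model.predict(X)


iris = datasets.load_iris()
X = iris.data

kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)

try:
    predict_clusters(kmeans, X)
    guarded = False
except NotFittedError:
    guarded = True  # the guard detected the missing call to fit

kmeans.fit(X)
labels = predict_clusters(kmeans, X)  # succeeds once the model is trained
print(guarded, labels.shape)
----
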
== Resources
=== Documentation

* Scikit-learn Documentation - https://scikit-learn.org/stable/glossary.html#term-fit[term fit reference]
* Scikit-learn Documentation - https://scikit-learn.org/stable/modules/generated/sklearn.exceptions.NotFittedError.html#sklearn.exceptions.NotFittedError[NotFittedError reference]

ifdef::env-github,rspecator-view[]

Implementation details:

Only when one of the methods listed above is called should we check that the `fit` method has been called on the same object.
Issue location: the name of the method (from the list above)
Message: Call the fit method on this object before retrieving the results.
Quickfix: Not applicable (it could be too tricky, as the parameters of `fit` and `predict` can differ)

endif::env-github,rspecator-view[]