* Create rule S6982 * Create rule S6982: model.eval() should be called after loading weights of a PyTorch model * Added implementation details * Fix after review --------- Co-authored-by: joke1196 <joke1196@users.noreply.github.com> Co-authored-by: David Kunzmann <david.kunzmann@sonarsource.com>
69 lines
2.1 KiB
Plaintext
This rule raises an issue when a PyTorch model's state is loaded and neither `torch.nn.Module.eval()` nor `torch.nn.Module.train()` is called afterwards.

== Why is this an issue?

When using PyTorch, it is common practice to save a model's state to a `.pth` file and load it back later.
This makes it possible, for example, to instantiate an untrained model and load learned parameters from a pre-trained model.
Once the learned parameters are loaded into the model, it is important to state the intent clearly before running inference:
call `torch.nn.Module.eval()` to set the model in evaluation mode,
or call `torch.nn.Module.train()` to indicate that training will resume.
Without a call to `torch.nn.Module.eval()`, the model stays in training mode, in which layers such as dropout and batch normalization behave differently than in evaluation mode, so inference results may not be what was intended.

== How to fix it

Call the `torch.nn.Module.eval()` or `torch.nn.Module.train()` method on the model after loading its state.

=== Code examples

==== Noncompliant code example

[source,python,diff-id=1,diff-type=noncompliant]
----
import torch
import torchvision.models as models

model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth')) # Noncompliant: model.train() or model.eval() was not called.
----
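
The snippet above leaves the model in training mode. The following is a minimal sketch of what that means in practice, not part of the rule itself: it assumes a hypothetical toy model with a dropout layer (standing in for a real network) and reloads the model's own weights instead of reading a `.pth` file. In training mode dropout is active, so repeated forward passes on the same input generally differ; after `eval()` they are identical.

```python
import torch
from torch import nn

# Hypothetical toy model, used only for illustration.
model = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5))

# Stand-in for loading weights from a file: reload the model's own state dict.
model.load_state_dict(model.state_dict())

# A newly constructed module is in training mode, and loading weights does not change that.
was_training = model.training
print("training mode after load_state_dict:", was_training)

x = torch.ones(1, 64)

# Training mode: dropout randomly zeroes activations on every forward pass.
with torch.no_grad():
    train_outputs_match = torch.equal(model(x), model(x))
print("identical outputs in training mode:", train_outputs_match)

# Evaluation mode: dropout is disabled, so the forward pass is deterministic.
model.eval()
with torch.no_grad():
    eval_outputs_match = torch.equal(model(x), model(x))
print("identical outputs in evaluation mode:", eval_outputs_match)
```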

==== Compliant solution

[source,python,diff-id=1,diff-type=compliant]
----
import torch
import torchvision.models as models

model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
----
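
Calling `train()` is the other way to make the intent explicit, for cases where training resumes after the weights are loaded. The following is a minimal sketch under stated assumptions: `build_model` is a hypothetical helper, the model is a toy network, and the weights go through an in-memory buffer rather than a `.pth` file so that the example is self-contained.

```python
import io

import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical helper: a small network used only for illustration.
    return nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Dropout(p=0.2), nn.Linear(8, 1))


# Save the weights of a first model to an in-memory buffer (stand-in for a .pth file).
buffer = io.BytesIO()
torch.save(build_model().state_dict(), buffer)
buffer.seek(0)

# Load the saved weights into a new instance and state explicitly that training resumes.
resumed = build_model()
resumed.load_state_dict(torch.load(buffer))
resumed.train()  # Compliant: the mode is set explicitly after loading.

# One optimization step on random data to show that training continues.
optimizer = torch.optim.SGD(resumed.parameters(), lr=0.01)
inputs, targets = torch.randn(16, 8), torch.randn(16, 1)
loss = nn.functional.mse_loss(resumed(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```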

== Resources
=== Documentation

* PyTorch Documentation - https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.eval[eval - reference]
* PyTorch Documentation - https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.train[train - reference]
* PyTorch Documentation - https://pytorch.org/docs/stable/notes/autograd.html#evaluation-mode-nn-module-eval[Autograd - Evaluation Mode]

ifdef::env-github,rspecator-view[]

(visible only on this page)

== Implementation specification

=== Message

Primary : Set the module in training or evaluation mode.

=== Issue location

Primary : the call to model.load_state_dict

=== Quickfix

No

endif::env-github,rspecator-view[]