github-actions[bot] 2a23d72c8f
Create rule S6982: model.eval() or model.train() should be called after loading a PyTorch model state (#3972)
* Create rule S6982

* Create rule S6982: model.eval() should be called after loading weights
of a PyTorch model

* Added implementation details

* Fix after review

---------

Co-authored-by: joke1196 <joke1196@users.noreply.github.com>
Co-authored-by: David Kunzmann <david.kunzmann@sonarsource.com>
2024-09-19 14:55:20 +02:00

This rule raises an issue when a PyTorch model state is loaded and neither `torch.nn.Module.eval()` nor `torch.nn.Module.train()` is called afterwards.

== Why is this an issue?

With PyTorch, it is common practice to save a model's state to a `.pth` file and to load it back later.
This makes it possible, for example, to instantiate an untrained model and load into it the parameters learned by a pre-trained model.

Loading a state does not change the mode of the model, and a newly instantiated `torch.nn.Module` is in training mode by default.
Once the learned parameters are loaded, it is therefore important to state the intention explicitly:
call `torch.nn.Module.eval()` to set the model to evaluation mode before running inference,
or call `torch.nn.Module.train()` to indicate that training will resume.
Failing to call `torch.nn.Module.eval()` leaves the model in training mode, in which layers such as dropout and batch normalization behave differently than in evaluation mode. This may not be the intention.
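
The effect of the mode can be observed directly. The following is a minimal sketch, not part of the rule's own examples: it assumes a small hypothetical `nn.Sequential` model with a dropout layer, and it copies the state in memory instead of reading a `.pth` file. It illustrates that `load_state_dict()` leaves the model in training mode, and that dropout stops randomizing the output only once `eval()` has been called.

```python
import torch
from torch import nn

# Hypothetical toy model; the Dropout layer makes the mode observable.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# Copy the state in memory (stand-in for a state loaded from a .pth file).
state = {name: tensor.clone() for name, tensor in model.state_dict().items()}
model.load_state_dict(state)

# load_state_dict() does not change the mode: the model is still in training mode.
still_training = model.training

# Switch to evaluation mode before inference.
model.eval()

x = torch.ones(1, 4)
with torch.no_grad():
    first = model(x)
    second = model(x)

# In evaluation mode Dropout is a no-op, so repeated calls give identical outputs.
outputs_match = torch.equal(first, second)
```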

== How to fix it

Call `torch.nn.Module.eval()` or `torch.nn.Module.train()` on the model after loading its state.

=== Code examples

==== Noncompliant code example

[source,python,diff-id=1,diff-type=noncompliant]
----
import torch
import torchvision.models as models

model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))  # Noncompliant: model.train() or model.eval() was not called.
----

==== Compliant solution

[source,python,diff-id=1,diff-type=compliant]
----
import torch
import torchvision.models as models

model = models.vgg16()
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
----
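
The compliant solution above covers inference. When the state is loaded to resume training instead, calling `train()` makes that intent explicit. The sketch below is illustrative only: the toy model, the in-memory buffer standing in for a checkpoint file, the random data, and the optimizer settings are all assumptions rather than part of the rule.

```python
import io

import torch
from torch import nn

# Hypothetical toy model used only for illustration.
model = nn.Sequential(nn.Linear(4, 2), nn.Dropout(p=0.5))

# An in-memory buffer stands in for a checkpoint file such as 'model_weights.pth'.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

model.load_state_dict(torch.load(buffer))
model.train()  # State the intent explicitly: training resumes.

# One illustrative optimization step on random data.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
inputs, targets = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```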

== Resources

=== Documentation

* PyTorch Documentation - https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.eval[eval - reference]
* PyTorch Documentation - https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.train[train - reference]
* PyTorch Documentation - https://pytorch.org/docs/stable/notes/autograd.html#evaluation-mode-nn-module-eval[Autograd - Evaluation Mode]
ifdef::env-github,rspecator-view[]

(visible only on this page)

== Implementation specification

=== Message

Primary : Set the module in training or evaluation mode.

=== Issue location

Primary : the call to model.load_state_dict

=== Quickfix

No

endif::env-github,rspecator-view[]