* Create rule S6985 * add implementation details * Address review * Update rule to include details about the weights_only parameter * Remove unnecessary example --------- Co-authored-by: ghislainpiot <ghislainpiot@users.noreply.github.com> Co-authored-by: Ghislain Piot <ghislain.piot@sonarsource.com> Co-authored-by: Sebastian Zumbrunn <sebastian.zumbrunn@sonarsource.com>
This rule raises an issue when `torch.load` is used to load a model.

== Why is this an issue?

In PyTorch, it is common to load serialized models using the `torch.load` function.
Under the hood, `torch.load` uses the `pickle` library to load the model and the weights.
If the model comes from an untrusted source, an attacker could inject a malicious payload that would be executed during deserialization.
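
The mechanism can be illustrated with the standard `pickle` module alone. The following is a minimal sketch, not taken from the rule itself: the `Payload` class is a hypothetical name, and the payload is deliberately harmless (it evaluates an arithmetic expression instead of running a command). It shows that merely deserializing the bytes invokes a callable chosen by whoever produced them.

```python
import pickle


class Payload:
    """Hypothetical object an attacker could serialize into a "model" file."""

    def __reduce__(self):
        # pickle records "call eval('6 * 7')" and performs that call on load.
        # A real attack would name a dangerous callable instead.
        return (eval, ("6 * 7",))


data = pickle.dumps(Payload())  # bytes shipped by the untrusted source
result = pickle.loads(data)     # loading alone triggers the call

# The victim does not get a Payload back, but the result of the injected call.
print(result)  # 42
```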

== How to fix it

Use a safer alternative to load the model, such as `safetensors.torch.load_model`. Alternatively, PyTorch can be instructed to only load
the weights by setting the parameter `weights_only=True`. This restricts deserialization to tensors, primitive types, and dictionaries, which is designed to prevent arbitrary code execution during loading. Note that the
use of `weights_only` requires saving only the `state_dict` of a model instead of the whole model.
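
The `state_dict` workflow described above can be sketched as follows. This is an illustrative example under stated assumptions rather than a prescribed implementation: `nn.Linear(4, 2)` stands in for the real architecture, and the temporary file name `weights.pth` is arbitrary.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in for the real model architecture.
model = nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), "weights.pth")

# Save only the parameters (the state_dict), not the whole model object.
torch.save(model.state_dict(), path)

# Rebuild the architecture in code, then load the weights with the
# restricted loader enabled by weights_only=True.
restored = nn.Linear(4, 2)
state_dict = torch.load(path, weights_only=True)
restored.load_state_dict(state_dict)
```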

=== Code examples

==== Noncompliant code example

[source,python,diff-id=1,diff-type=noncompliant]
----
import torch

model = torch.load('model.pth') # Noncompliant: torch.load is used to load the model
----

==== Compliant solution

[source,python,diff-id=1,diff-type=compliant]
----
import torch
import safetensors.torch

model = MyModel()
safetensors.torch.load_model(model, 'model.pth')
----

== Resources
=== Documentation

* PyTorch documentation: https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model[Save/Load Entire Model]

ifdef::env-github,rspecator-view[]

(visible only on this page)

== Implementation specification

All usages of `torch.load`

=== Message

Primary : Replace this call with a safe alternative

=== Issue location

Primary : name of the function call

=== Quickfix

No

endif::env-github,rspecator-view[]