This rule raises an issue when a `tensorflow.function` depends on a global or free Python variable.

== Why is this an issue?

When a `tensorflow.function` is called, TensorFlow creates a `ConcreteFunction` behind the scenes every time a new value is passed as an argument.
This is not the case for Python global variables, closure variables or nonlocal variables: a change to their value does not create a new `ConcreteFunction`.

This means the state and the result of the `tensorflow.function` may not be what is expected.


[source,python]
----
import tensorflow as tf

@tf.function
def addition():
    return 1 + foo

foo = 4
addition() # tf.Tensor(5, shape=(), dtype=int32): on this first step we obtain the expected result

foo = 10
addition() # tf.Tensor(5, shape=(), dtype=int32): unexpected result of 5 instead of 11
----

As the example above shows, the second call to `addition` returns the same result as the first call.
This happens because `addition` takes no arguments, so nothing in the arguments passed to the function changed between the two calls.
As a result, a single `ConcreteFunction` is created during the first call of `addition`, with the value of `foo` set to 4.
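
One way to observe that the function body is traced only once is sketched below. It relies on the behavior described in the TensorFlow guide that Python side effects inside a `tensorflow.function` run only while the function is being traced. The list `trace_log` is a hypothetical helper introduced for illustration and is not part of the rule.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: receives one entry per trace

foo = 4

@tf.function
def addition():
    # Python side effect: executed only while TensorFlow traces the function.
    trace_log.append(foo)
    return 1 + foo

first = addition()   # traces the function; foo is read as 4
foo = 10
second = addition()  # reuses the existing trace; foo is not read again

print(len(trace_log))           # number of traces recorded
print(int(first), int(second))  # both results come from the same trace
```

If the sketch behaves as assumed, `trace_log` holds a single entry and both calls return 5, which matches the unexpected result shown above.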

This is why it is good practice to avoid using or mutating global or nonlocal variables inside a `tensorflow.function`.

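
The same behavior can be expected for closure variables. The following is a minimal sketch, assuming a factory function that captures a local variable. The names `make_adder`, `add_one`, `set_offset` and `offset` are hypothetical and only serve the illustration.

```python
import tensorflow as tf

def make_adder():
    offset = 4  # free variable captured by the closures below

    @tf.function
    def add_one():
        # The Python value of offset is read when the function is traced.
        return 1 + offset

    def set_offset(value):
        nonlocal offset
        offset = value

    return add_one, set_offset

add_one, set_offset = make_adder()

first = add_one()   # traced with offset equal to 4
set_offset(10)      # rebinding the free variable does not trigger a new trace
second = add_one()

print(int(first), int(second))
```

Under these assumptions both calls return 5, because the value of `offset` seen during the first call stays in the traced function.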
=== Exceptions

This rule will not raise an issue if the global or nonlocal variable is a `tensorflow.Variable`.


[source,python]
----
import tensorflow as tf

@tf.function
def addition():
    return 1 + foo

foo = tf.Variable(4)
addition() # tf.Tensor(5, shape=(), dtype=int32)

foo.assign(10)
addition() # tf.Tensor(11, shape=(), dtype=int32): the updated value is taken into account
----

In this case, the `ConcreteFunction` reads the current value of the `tensorflow.Variable` on each call, so the result reflects any change made to the variable.

== How to fix it

To fix this issue, refactor the code so that the Python global or nonlocal variable is passed as an argument of the `tensorflow.function`.

=== Code examples

==== Noncompliant code example


[source,python,diff-id=1,diff-type=noncompliant]
----
import tensorflow as tf

@tf.function
def addition():
    return 1 + foo # Noncompliant: the use of the global variable may not behave as expected.

foo = 4
addition()

foo = 10
addition()
----

==== Compliant solution

[source,python,diff-id=1,diff-type=compliant]
----
import tensorflow as tf

@tf.function
def addition(foo):
    return 1 + foo # Compliant

foo = 4
addition(foo)

foo = 10
addition(foo)
----
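
As a possible refinement of the compliant solution, and not something this rule requires, the value can be passed as a tensor instead of a Python number. The sketch below assumes the tracing behavior described in the TensorFlow guide: tensor arguments with the same dtype and shape share one trace, while each distinct Python value triggers a new one. The list `trace_log` is a hypothetical helper used only to count traces.

```python
import tensorflow as tf

trace_log = []  # hypothetical helper: receives one entry per trace

@tf.function
def addition(foo):
    # Python side effect: executed only while TensorFlow traces the function.
    trace_log.append("traced")
    return 1 + foo

# Tensor arguments with the same dtype and shape reuse a single trace.
r1 = addition(tf.constant(4))
r2 = addition(tf.constant(10))
traces_with_tensors = len(trace_log)

# Each distinct Python value triggers an additional trace.
r3 = addition(4)
r4 = addition(10)
traces_total = len(trace_log)

print(int(r1), int(r2), int(r3), int(r4))
print(traces_with_tensors, traces_total)
```

In both styles the result follows the argument (5, then 11). The difference is only in how often the function is retraced, which can matter when it is called with many different values.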

== Resources
=== Documentation

* TensorFlow Documentation - https://www.tensorflow.org/guide/function#depending_on_python_global_and_free_variables[Depending on Python global and free variables]
* TensorFlow Documentation - https://www.tensorflow.org/jvm/api_docs/java/org/tensorflow/ConcreteFunction?hl=en[ConcreteFunction reference]