JAX study notes[8]

Contents

  • jax.typing
  • references

jax.typing

  1. Function annotations used for static type checking are becoming an integral part of Python coding practice.
  2. jax.Array is the base class for JAX arrays, covering both concrete arrays and tracers.
  3. Type annotations can be used in a Python project at three levels:
  • Level 1: Annotations as documentation
import jax

def f(x: jax.Array) -> jax.Array:  # type annotations are valid for both traced and non-traced values
  return x
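A quick way to see that Level 1 annotations are documentation only: Python stores them but never enforces them. In the minimal sketch below (the same f as above, repeated so the snippet runs standalone, assuming a recent JAX release), f accepts a string without complaint, and it also works unchanged under jax.jit, where x is a tracer rather than a concrete array.

```python
import jax
import jax.numpy as jnp

def f(x: jax.Array) -> jax.Array:
    # The annotations are stored in f.__annotations__ but never checked at runtime
    return x

x = jnp.arange(3.0)
print(f(x))               # a concrete jax.Array passes straight through
print(jax.jit(f)(x))      # under jit, x is a tracer while f is being traced
print(f("not an array"))  # no error: Python does not enforce annotations
print(f.__annotations__)  # maps 'x' and 'return' to jax.Array, as plain metadata
```

Static checkers such as mypy or pytype would flag the string call; the interpreter does not.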
  • Level 2: Annotations for intelligent autocomplete
    Many modern IDEs, such as VS Code, use type annotations to drive their intelligent code-completion systems.
  • Level 3: Annotations for static type-checking
  1. Package development with JAX has to satisfy two Python static type checkers: pytype, developed by Google, and mypy, one of the most popular static type-checking tools. Beyond that, typing JAX code raises specific challenges: array duck-typing, transformations and decorators, the lack of granularity in array annotations, and imprecise APIs inherited from NumPy.
  2. For duck-typed array objects, JAX supports two complementary mechanisms: static type annotations and runtime instance checks.
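Array duck-typing is the reason jax.typing also provides ArrayLike: JAX functions accept NumPy arrays and Python or NumPy scalars, not only jax.Array. A common pattern, sketched here with a hypothetical helper named scale, is to annotate inputs as ArrayLike, convert them with jnp.asarray, and annotate the output as jax.Array. ArrayLike deliberately excludes Python lists and tuples.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

def scale(x: ArrayLike, factor: float = 2.0) -> jax.Array:
    # Accept anything array-like, but do the actual work on a jax.Array
    x_arr = jnp.asarray(x)
    return x_arr * factor

print(scale(np.array([1.0, 2.0])))  # NumPy array input
print(scale(3))                     # Python scalar input
print(scale(jnp.arange(3.0)))       # jax.Array input
```

Loose inputs and a precise output keep callers flexible while downstream code can rely on getting a jax.Array back.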
  • Static type annotations
from typing import Union
from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnp

ArrayAnnotation = Union[Array, Tracer]

@jit
def f(x: ArrayAnnotation) -> ArrayAnnotation:
    assert isinstance(x, (Array, Tracer))  # Explicit check
    return x * 2

x = jnp.array([1.0, 2.0, 3.0])
result = f(x)
print(result)  # [2. 4. 6.] (jax.Array)

@jit
def g(x):
    return f(x)  # `x` is a Tracer here!

print(g(x))      # same output; f is traced as part of g's computation


from jax import grad

df_dx = grad(lambda x: f(x).sum())  # Works with tracers
print(df_dx(x))  # [2. 2. 2.] (gradient of x*2)

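f("invalid_input")  # raises the TypeError below: a str cannot be traced as an array
f(234)              # never reached; a Python int would be traced as a scalar array

Output:
[2. 4. 6.]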
[2. 4. 6.]
[2. 2. 2.]
Traceback (most recent call last):
  File "e:\learn\learnpy\l2.py", line 29, in <module>
    f("invalid_input")
TypeError: Error interpreting argument to <function f at 0x0000018EEFD999E0> as an abstract array. The problematic value is of type <class 'str'> and was passed to the function at path x.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
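The error message points at the fix for arguments that are not arrays: mark them static with static_argnums or static_argnames. The sketch below uses a hypothetical transform function whose mode string is static, so it stays an ordinary Python str during tracing and can drive Python control flow. Static values must be hashable, and each distinct value triggers a new compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("mode",))
def transform(x: jax.Array, mode: str) -> jax.Array:
    # `mode` is static: a plain Python str at trace time, so `if` works on it
    if mode == "double":
        return x * 2
    if mode == "square":
        return x ** 2
    raise ValueError(f"unknown mode: {mode!r}")

x = jnp.array([1.0, 2.0, 3.0])
print(transform(x, mode="double"))
print(transform(x, mode="square"))
```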
  • Runtime instance checks
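from jax import Array, jit
from jax.core import Tracer
import jax.numpy as jnp

# A tuple instead of typing.Union keeps isinstance working on Python < 3.10.
# Tracer already subclasses jax.Array, so Array alone would also suffice.
ArrayInstance = (Array, Tracer)

def f(x):  # deliberately not decorated with @jit
  return isinstance(x, ArrayInstance)


x = jnp.array([1, 2, 3])
assert f(x)       # x will be a concrete array
assert jit(f)(x)  # x will be a tracer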
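The same runtime check holds under other transformations, not just jit. In this sketch the hypothetical probe function records, at trace time, whether its input passes isinstance(x, jax.Array). Because tracers subclass jax.Array, the check succeeds for a plain call and under jit, grad and vmap alike, while a NumPy array, which is array-like but not a jax.Array, fails it. The side-effecting list is only for demonstration: under a transformation it runs once per trace, not once per call.

```python
import numpy as np
import jax
import jax.numpy as jnp

seen = []  # (type name, passed isinstance check) each time probe's body runs

def probe(x):
    # isinstance runs in Python during tracing, so it yields a plain bool
    seen.append((type(x).__name__, isinstance(x, jax.Array)))
    return jnp.sum(x * 2.0)

x = jnp.array([1.0, 2.0, 3.0])
probe(x)            # concrete jax.Array
jax.jit(probe)(x)   # tracer under jit
jax.grad(probe)(x)  # tracer under grad
jax.vmap(probe)(x)  # tracer under vmap (one scalar per batch element)

for name, ok in seen:
    print(name, ok)
print(isinstance(np.asarray(x), jax.Array))  # NumPy arrays are not jax.Arrays
```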

references

https://docs.jax.dev/
