JAX study notes[16]

Contents

  • Pytrees
  • references

Pytrees

  1. In essence, JAX functions and transformations act on arrays, but in practice most operations work on collections of arrays rather than on single arrays.
  2. JAX represents such collections as pytrees. A pytree is a single abstraction that covers nested lists, tuples, dicts and other registered containers, so the same tools work on all of them instead of needing separate handling for each kind of structure.
import jax
import jax.numpy as jnp
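# A list holding ints, a string, arrays and a nested dict
params = [11, 120, 1000000000, "abcd", jnp.ones(3), {'n': 5, 'W': jnp.zeros(2)}]
print(jax.tree.structure(params))  # the container layout; each * marks a leaf
print(jax.tree.leaves(params))     # the leaves in flattening order (dict keys sorted)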
Output:
PyTreeDef([*, *, *, *, *, {'W': *, 'n': *}])
[11, 120, 1000000000, 'abcd', Array([1., 1., 1.], dtype=float32), Array([0., 0.], dtype=float32), 5]
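The structure and the leaves printed above are the two halves of a pytree. The short sketch below (reusing the same `params` list; `new_leaves` is just an illustrative name) shows that `jax.tree.flatten` returns both at once and that `jax.tree.unflatten` can rebuild a pytree of the same shape from a new list of leaves.

```python
import jax
import jax.numpy as jnp

# Same mixed container as in the notes above
params = [11, 120, 1000000000, "abcd", jnp.ones(3), {'n': 5, 'W': jnp.zeros(2)}]

# flatten splits a pytree into its leaves and a PyTreeDef describing the structure
leaves, treedef = jax.tree.flatten(params)
print(len(leaves))  # 7 leaves; dict keys are visited in sorted order ('W' before 'n')

# Replace each leaf (hypothetical string labels), then rebuild the original structure
new_leaves = [f"leaf{i}" for i in range(len(leaves))]
rebuilt = jax.tree.unflatten(treedef, new_leaves)
print(rebuilt)
```

Because the treedef records the dict keys in sorted order, `'leaf5'` ends up under `'W'` and `'leaf6'` under `'n'`.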
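  3. JAX provides plenty of utilities for working with pytrees.
  • tree.map
    builds a new pytree by calling a function on every leaf of the input pytree. When extra pytrees with the same structure are passed in *rest, the function receives the corresponding leaves of all of them.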
jax.tree.map(f, tree, *rest, is_leaf=None)
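import jax
import jax.numpy as jnp
# Two pytrees with the same structure: lists of five scalar arrays each
params1 = [x for x in jnp.arange(1, 10, 2)]   # 1, 3, 5, 7, 9
params2 = [x for x in jnp.arange(10, 1, -2)]  # 10, 8, 6, 4, 2
# Use ** for powers: in Python, ^ is bitwise XOR, not exponentiation
print(jax.tree.map(lambda a, b: jnp.sqrt(a**2 + b**2), params1, params2))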
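A common real use of `tree.map` is updating model parameters. The sketch below is a minimal example under assumed names (`params`, `loss`, `lr`, `x`, `y` are all hypothetical, for a tiny linear model). It relies on `jax.grad` returning gradients with the same pytree structure as `params`, so `tree.map` can walk both trees leaf by leaf for a plain gradient-descent step.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear-model parameters stored as a dict pytree
params = {'W': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}

def loss(params, x, y):
    # Mean squared error of a linear prediction
    pred = x @ params['W'] + params['b']
    return jnp.mean((pred - y) ** 2)

x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 1.0])

# jax.grad returns gradients with the same pytree structure as params
grads = jax.grad(loss)(params, x, y)

# tree.map visits params and grads in lockstep and applies p - lr * g to each pair
lr = 0.1
new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
print(new_params)
```

The update rule never mentions `'W'` or `'b'`, so the same line would work unchanged for a model with any number of nested layers.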
  • tree.reduce
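    combines all leaves of a pytree into a single value by applying a binary function cumulatively from left to right, like functools.reduce. Nested containers are flattened first, so the example below adds the leaves 1, 2, 3, 4 and 5.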
jax.tree.reduce(function: Callable[[T, Any], T], tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) → T
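import jax
import operator

# Two lists nested inside one list form a single pytree with leaves 1, 2, 3, 4, 5
params1 = [1, 2, 3]
params2 = [4, 5]

# operator.add is applied cumulatively over the leaves: (((1 + 2) + 3) + 4) + 5
result = jax.tree.reduce(operator.add, [params1, params2])
print(result)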
Output:
15
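Combining the two utilities gives a handy pattern: first map each array leaf to a number, then reduce those numbers. The sketch below counts the total number of parameters in a hypothetical two-layer model (the layer names `dense1`/`dense2` and the shapes are made up for illustration).

```python
import operator
import jax
import jax.numpy as jnp

# Hypothetical model parameters: a nested pytree of arrays
params = {
    'dense1': {'W': jnp.zeros((3, 4)), 'b': jnp.zeros(4)},
    'dense2': {'W': jnp.zeros((4, 2)), 'b': jnp.zeros(2)},
}

# Map every array leaf to its element count, then sum the counts with +
sizes = jax.tree.map(lambda x: x.size, params)
total = jax.tree.reduce(operator.add, sizes)
print(total)  # 12 + 4 + 8 + 2 = 26
```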

references

https://docs.jax.dev/
