einshard
- einshard.einshard(arr, expression)
Shards a jax.Array according to the given einshard expression.
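For readers new to JAX sharding, the sketch below shows what sharding an array amounts to in plain JAX, using only jax.sharding.Mesh, NamedSharding, PartitionSpec and jax.device_put. It does not call einshard and does not show einshard's expression syntax. The mesh axis name "x" and the choice to split the first dimension are assumptions made for illustration. The sketch runs on a single CPU device as well as on several devices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange all available devices in a 1-D mesh with one axis named "x"
# (the axis name is an arbitrary choice for this sketch).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Split dimension 0 across the "x" mesh axis and replicate dimension 1.
sharding = NamedSharding(mesh, P("x", None))

# Make dimension 0 divisible by the number of devices so it splits evenly.
n = len(devices)
arr = jnp.arange(n * 2 * 8, dtype=jnp.float32).reshape(n * 2, 8)
sharded = jax.device_put(arr, sharding)

# Each device holds a (2, 8) slice of the (n * 2, 8) array.
for shard in sharded.addressable_shards:
    print(shard.device, shard.data.shape)
```

The values of the array are unchanged by sharding; only their placement across devices differs.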
- einshard.make_sharding(expression, *, n_dims=None)
Makes a sharding from an einshard expression.
- Parameters:
expression (str) – The einshard expression.
n_dims (int | None) – The number of dimensions of the array to be sharded. This argument must be provided if an ellipsis (...) is used in the einshard expression.
- Returns:
The jax.sharding.Sharding object corresponding to the given expression.
- Return type:
jax.sharding.Sharding
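Because make_sharding builds a sharding from the expression alone, without seeing an array, an ellipsis cannot be expanded until the rank is known; this is why n_dims is required in that case. The sketch below illustrates the idea with a hypothetical helper, spec_with_ellipsis, which is not part of einshard. It assumes, as in einsum, that the ellipsis stands for the dimensions not named explicitly, and further assumes those dimensions are left unsharded. The resulting NamedSharding stands in for the object make_sharding returns; like any jax.sharding.Sharding, it can be passed to jax.device_put.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def spec_with_ellipsis(trailing, n_dims):
    """Hypothetical helper (not part of einshard).

    Builds a PartitionSpec whose leading dimensions are covered by an
    ellipsis and left unsharded, followed by the explicit trailing entries.
    The ellipsis can only be expanded once the total rank n_dims is known.
    """
    n_ellipsis = n_dims - len(trailing)
    if n_ellipsis < 0:
        raise ValueError("n_dims is smaller than the number of named dimensions")
    return P(*([None] * n_ellipsis), *trailing)


devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Stand-in for a sharding built from an expression containing an ellipsis:
# for a rank-3 array, shard only the last dimension across "x".
spec = spec_with_ellipsis(("x",), n_dims=3)
sharding = NamedSharding(mesh, spec)
print(spec)

# Any jax.sharding.Sharding can be passed to jax.device_put.
arr = jnp.zeros((4, 5, len(devices) * 3))
sharded = jax.device_put(arr, sharding)
print(sharded.sharding)
```

With a different n_dims the same trailing entries yield a spec of a different length, which is the information the ellipsis alone cannot supply.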