einshard


einshard.einshard(arr, expression)

Shards a jax.Array according to the given einshard expression.

Parameters:
  • arr (jax.Array) – The array to be sharded.

  • expression (str) – The einshard expression.

Returns:

The sharded array.

Return type:

jax.Array
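
For orientation, here is a minimal sketch of what sharding a jax.Array means in plain JAX terms. It does not call einshard and makes no claim about the expression syntax or about how einshard is implemented. The mesh axis name "x" and the array shape are assumptions chosen for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D device mesh over all available devices.
# The axis name "x" is an arbitrary, illustrative choice.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))

# Split the first array axis across mesh axis "x"; replicate the second.
sharding = NamedSharding(mesh, PartitionSpec("x", None))

# The first dimension is a multiple of the device count so it divides evenly.
arr = jnp.zeros((len(devices) * 2, 4))

# Place the array on the devices according to the sharding.
sharded = jax.device_put(arr, sharding)

# Each device now holds one slice of the array.
first_shard_shape = sharded.addressable_shards[0].data.shape
```

The result is still a jax.Array with the same global shape; only its placement across devices changes, which matches the return type documented above.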

einshard.make_sharding(expression, *, n_dims=None)

Creates a sharding from an einshard expression.

Parameters:
  • expression (str) – The einshard expression.

  • n_dims (int | None) – The number of dimensions of the array to be sharded. This argument must be provided if an ellipsis is used in the einshard expression.

Returns:

The jax.sharding.NamedSharding object corresponding to the given expression.

Return type:

jax.sharding.NamedSharding
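
One plausible reason the array rank is needed is that a jax.sharding.PartitionSpec lists array axes by position, so a rule such as "shard the last axis" cannot be written down without knowing how many axes come before it. The sketch below illustrates this with plain JAX. The helper shard_last_axis and the mesh axis name "x" are hypothetical, are not part of einshard, and do not describe how make_sharding works internally.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def shard_last_axis(n_dims: int) -> NamedSharding:
    """Hypothetical helper: shard only the last of n_dims array axes."""
    # 1-D mesh over all available devices; "x" is an illustrative axis name.
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    # Leading axes are replicated (None); the last axis is split over "x".
    # The number of None entries depends on the rank of the array.
    spec = PartitionSpec(*([None] * (n_dims - 1)), "x")
    return NamedSharding(mesh, spec)


# For a rank-3 array the spec must contain exactly two leading None entries.
sharding = shard_last_axis(n_dims=3)
```

A sharding built this way can be applied to an array of matching rank with jax.device_put.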