Source code for einshard.einshard

from math import prod

import jax
from jax import Array
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

from .parser import parse_expression

def _partition_at_ellipsis(lst: list) -> tuple[list, list]:
    '''
    Split a list around its first ellipsis element.

    Args:
        lst (list): A list that contains ``...`` (``Ellipsis``) as one of its elements.

    Returns:
        tuple[list, list]: The elements before the first ellipsis and the elements after it. The ellipsis itself is not included in either part.

    Raises:
        ValueError: If ``lst`` contains no ellipsis.
    '''
    pivot = lst.index(Ellipsis)
    before, after = lst[:pivot], lst[pivot + 1:]
    return before, after

def _exact_integer_root(value: int, degree: int) -> int | None:
    '''
    Return the positive integer ``r`` such that ``r ** degree == value``, or ``None`` if no such integer exists.

    The floating-point root is used only as an initial guess; the result is confirmed with exact integer arithmetic, because e.g. ``64 ** (1 / 3)`` evaluates to ``3.9999999999999996``.
    '''
    guess = round(value ** (1 / degree))
    for candidate in (guess - 1, guess, guess + 1):
        if candidate >= 1 and candidate ** degree == value:
            return candidate
    return None


def make_sharding(expression: str, *, n_dims: int | None = None) -> NamedSharding:
    '''
    Make sharding from einshard expression.

    The parsed expression has a left side (one element per array dimension) and a right side (one element per mesh axis). Every right-side element becomes a mesh axis named ``a0``, ``a1``, ...; each array dimension is partitioned over the mesh axis whose right-side identifier matches it. Dimensions covered by an ellipsis get a mesh axis of size 1.

    Right-side elements are either fixed (the mesh axis size is the given sharding number) or proportional (the mesh axis size is the sharding number multiplied by a common integer ratio, chosen so that the whole mesh uses all available devices).

    Args:
        expression (str): The einshard expression.
        n_dims (int | None): The number of dimensions of the array to be sharded. This argument must be provided if ellipsis is used in the einshard expression.

    Returns:
        jax.sharding.NamedSharding: The :class:`jax.sharding.Sharding` object corresponding to the given expression.

    Raises:
        ValueError: If the expression cannot be parsed, if the ellipses are inconsistent, if ``n_dims`` is missing or does not agree with the expression, or if the devices cannot be divided among the proportional axes with an integer ratio.
    '''
    n_devices = jax.device_count()

    res = parse_expression(expression, 0)
    if not res.is_success():
        idx, desc = res.error
        raise ValueError(f'Cannot parse einshard expression "{expression}", expected {desc} at position {idx}.')
    _, (elements_left, elements_right) = res.value

    n_left_ellipses = sum(element is ... for element in elements_left)
    n_right_ellipses = sum(element is ... for element in elements_right)
    if n_left_ellipses != n_right_ellipses or n_left_ellipses > 1:
        raise ValueError(
            f'Invalid einshard expression "{expression}": an ellipsis may appear at most once '
            'on each side, and must appear on both sides or on neither.'
        )

    if n_left_ellipses == 0:
        if n_dims is not None and n_dims != len(elements_left):
            raise ValueError(
                f'Einshard expression "{expression}" describes {len(elements_left)} dimension(s), '
                f'but n_dims is {n_dims}.'
            )
    else:
        if n_dims is None:
            raise ValueError(f'n_dims must be provided because einshard expression "{expression}" contains an ellipsis.')
        # The ellipsis itself occupies one slot of `elements_left`, hence the `+ 1`.
        n_dims_elided = n_dims - len(elements_left) + 1
        if n_dims_elided < 0:
            raise ValueError(
                f'Einshard expression "{expression}" names {len(elements_left) - 1} dimension(s), '
                f'but n_dims is only {n_dims}.'
            )

        # Expand the ellipsis into placeholder dimensions `?0`, `?1`, ... on both sides.
        placeholder_names = [f'?{i}' for i in range(n_dims_elided)]
        # 1: `sharding_number`, False: `is_proportional`
        placeholder_axes = [(identifier, 1, False) for identifier in placeholder_names]

        left_head, left_tail = _partition_at_ellipsis(elements_left)
        elements_left = [*left_head, *placeholder_names, *left_tail]

        right_head, right_tail = _partition_at_ellipsis(elements_right)
        elements_right = [*right_head, *placeholder_axes, *right_tail]

    sharding_numbers_fixed = [sharding_number for _, sharding_number, is_proportional in elements_right if not is_proportional]
    sharding_numbers_proportional = [sharding_number for _, sharding_number, is_proportional in elements_right if is_proportional]

    if not sharding_numbers_proportional:
        sharding_ratio = 1  # can be of whatever value because it will not be used in this case
    else:
        n_devices_needed_for_fixed = prod(sharding_numbers_fixed)
        n_devices_needed_for_proportional_base = prod(sharding_numbers_proportional)
        n_sharded_axes_proportional = len(sharding_numbers_proportional)

        if n_devices % n_devices_needed_for_fixed != 0:
            raise ValueError(
                f'The number of devices ({n_devices}) is not divisible by the product of the '
                f'fixed sharding numbers ({n_devices_needed_for_fixed}).'
            )
        n_devices_available_proportional = n_devices // n_devices_needed_for_fixed

        if n_devices_available_proportional % n_devices_needed_for_proportional_base != 0:
            raise ValueError(
                f'The number of devices left for proportional axes ({n_devices_available_proportional}) '
                f'is not divisible by the product of the proportional sharding numbers '
                f'({n_devices_needed_for_proportional_base}).'
            )
        sharding_ratio_full = n_devices_available_proportional // n_devices_needed_for_proportional_base

        # Every proportional axis is scaled by the same ratio, so the ratio is the
        # n-th root of the remaining factor. It is computed exactly to avoid
        # rejecting valid configurations because of floating-point rounding.
        sharding_ratio = _exact_integer_root(sharding_ratio_full, n_sharded_axes_proportional)
        if sharding_ratio is None:
            raise ValueError(
                f'Cannot distribute a factor of {sharding_ratio_full} evenly across '
                f'{n_sharded_axes_proportional} proportional axes.'
            )

    mesh_shape = [
        sharding_number * (sharding_ratio if is_proportional else 1)
        for _, sharding_number, is_proportional in elements_right
    ]
    axis_names = tuple(f'a{i}' for i, _ in enumerate(elements_right))
    # Map each named right-side element to its mesh axis index; elements whose identifier is None are skipped.
    axis_index_by_identifier = {
        identifier: i
        for i, (identifier, _, _) in enumerate(elements_right)
        if identifier is not None
    }
    partition_spec = tuple(f'a{axis_index_by_identifier[element_left]}' for element_left in elements_left)

    devices = mesh_utils.create_device_mesh(mesh_shape)
    mesh = Mesh(devices, axis_names=axis_names)
    return NamedSharding(mesh, P(*partition_spec))
def einshard(arr: Array, expression: str) -> Array:
    '''
    Shards a :class:`jax.Array` according to the given einshard expression.

    The sharding is built by :func:`make_sharding` with ``n_dims`` set to the number of dimensions of ``arr``, and the result is assembled with :func:`jax.make_array_from_callback`.

    Args:
        arr (jax.Array): The array to be sharded.
        expression (str): The einshard expression.

    Returns:
        jax.Array: The sharded array.
    '''
    target_sharding = make_sharding(expression, n_dims=len(arr.shape))

    def fetch_shard(index):
        # Return the part of the source array selected by `index`.
        return arr[index]

    return jax.make_array_from_callback(arr.shape, target_sharding, fetch_shard)