AllToAll

AllToAll is a communication primitive that redistributes data across devices by effectively moving sharding from one dimension to another.

Definition

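The AllToAll operation moves a sharding subscript from one array dimension to another. For a 2-D array A whose first dimension is sharded over mesh axis X:

$$\text{AllToAll}_{X, J}\, A[I_X, J] \rightarrow A[I, J_X]$$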

This operation represents a sharded transposition, changing how the data is distributed without changing the total amount of data on each device.
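A block-level view of the same operation may help. As a sketch, assume N devices on mesh axis X and split A into an N × N grid of equal blocks A_{ij} (row block i, column block j). This block notation is introduced here for illustration only.

$$
\begin{aligned}
\text{before:}\quad & \text{device } i \text{ holds } A[I_i, J] = \begin{bmatrix} A_{i0} & A_{i1} & \cdots & A_{i,N-1} \end{bmatrix} \\
\text{after:}\quad & \text{device } j \text{ holds } A[I, J_j] = \begin{bmatrix} A_{0j} \\ A_{1j} \\ \vdots \\ A_{N-1,j} \end{bmatrix} \\
\text{exchange:}\quad & A_{ij} \text{ moves from device } i \text{ to device } j
\end{aligned}
$$

Each device holds N blocks both before and after the exchange. It keeps only A_{ii} and sends each of its other N - 1 blocks to a different device, so the per-device data volume does not change.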

Visual Representation

[Figure: Animation of an AllToAll operation redistributing data across devices. Bidirectional communication links are shown, as on TPUs. Source: How to Scale Your Model]

Operation Details

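  1. Targeted Exchange: Each device splits its local shard into N chunks along the dimension that will become sharded, keeps one chunk, and sends each of the others to a different device
  2. No Replication: Unlike AllGather, every chunk ends up on exactly one device, so the data volume per device stays constant
  3. Dimension Swapping: The previously sharded dimension becomes unsharded, and the target dimension becomes sharded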
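The three properties above can be made concrete with a small pure-Python simulation. This is a sketch that models each device's shard as a list of rows and assumes the sharded dimensions divide evenly by the device count; `split_cols` and `all_to_all_rows_to_cols` are hypothetical helpers written for illustration, not part of any library.

```python
def split_cols(block, n):
    """Split a block (a list of equal-length rows) into n equal column chunks."""
    width = len(block[0]) // n
    return [[row[k * width:(k + 1) * width] for row in block] for k in range(n)]


def all_to_all_rows_to_cols(shards):
    """Simulate AllToAll taking A[I_X, J] to A[I, J_X].

    shards[i] is device i's block of rows. Device i splits its block into n
    column chunks and sends chunk j to device j; device j stacks the chunks it
    receives in source-device order, which yields its block of columns.
    """
    n = len(shards)
    # outbox[i][j] is the chunk that device i sends to device j
    outbox = [split_cols(block, n) for block in shards]
    return [
        [row for src in range(n) for row in outbox[src][dst]]
        for dst in range(n)
    ]


# A 4 x 4 global array with A[r][c] = 10 * r + c, sharded by rows over 2 devices.
A = [[10 * r + c for c in range(4)] for r in range(4)]
row_shards = [A[0:2], A[2:4]]  # device 0: rows 0-1, device 1: rows 2-3
col_shards = all_to_all_rows_to_cols(row_shards)
print(col_shards[0])  # [[0, 1], [10, 11], [20, 21], [30, 31]]
print(col_shards[1])  # [[2, 3], [12, 13], [22, 23], [32, 33]]
```

After the exchange, device j holds columns 2j and 2j + 1 of every row, and each device still holds 8 elements, matching the "no replication" property.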

Performance Analysis

For an array of size V bytes on a bidirectional ring:

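Time Complexity (Bandwidth-Bound)

$$T_{\text{comms}} \approx \frac{V}{4 \cdot W_{\text{ici}}}$$

where W_ici is the bandwidth of each ring link, counting both directions.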

This is 1/4 the cost of an AllGather operation on a bidirectional ring (or 1/2 on a unidirectional ring).
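The ratios above can be checked numerically. The following pure-Python sketch (function names are hypothetical, chosen for illustration) counts the bytes carried by every directed link of an n-device ring. It assumes each AllToAll chunk takes the shorter direction on a bidirectional ring (ties go clockwise) and models AllGather as a forwarded broadcast, where each shard crosses each link of its broadcast tree once. Communication time is proportional to the busiest link's byte count, so bandwidth cancels in the ratio.

```python
from collections import defaultdict


def ring_path(src, dst, n, bidirectional):
    """Directed links (a, b) crossed going from src to dst on an n-device ring.

    A unidirectional ring only routes clockwise; a bidirectional ring takes the
    shorter direction, with ties going clockwise.
    """
    cw = (dst - src) % n
    if not bidirectional or cw <= n - cw:
        return [((src + k) % n, (src + k + 1) % n) for k in range(cw)]
    return [((src - k) % n, (src - k - 1) % n) for k in range(n - cw)]


def max_link_load(n, v, bidirectional, collective):
    """Bytes on the busiest directed link for an array of v bytes over n devices."""
    load = defaultdict(float)
    for src in range(n):
        if collective == "all_to_all":
            # Each device holds v / n bytes, split into n chunks of v / n**2 bytes;
            # every chunk travels to exactly one destination.
            for dst in range(n):
                for link in ring_path(src, dst, n, bidirectional):
                    load[link] += v / n**2
        elif collective == "all_gather":
            # The whole v / n shard must reach every device; because it is
            # forwarded, it crosses each link of its broadcast tree only once.
            tree = set()
            for dst in range(n):
                tree.update(ring_path(src, dst, n, bidirectional))
            for link in tree:
                load[link] += v / n
        else:
            raise ValueError(f"unknown collective: {collective}")
    return max(load.values())


n = 64
for bidi in (False, True):
    ratio = max_link_load(n, 1.0, bidi, "all_to_all") / max_link_load(n, 1.0, bidi, "all_gather")
    print("bidirectional" if bidi else "unidirectional", round(ratio, 3))
```

With n = 64 this prints a ratio of 0.5 for the unidirectional ring and 0.258 for the bidirectional one. Under this tie-breaking rule the bidirectional ratio is (n + 2) / (4n), which approaches 1/4 as n grows.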

Mathematical Explanation

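The reduced cost comes from two effects on an N-device ring:

  1. No data replication: in an AllGather every shard must reach all N devices, whereas in an AllToAll each device's V/N bytes are split into N chunks of V/N^2 bytes, and each chunk travels to exactly one destination
  2. Short paths: a chunk destined for a device k hops away crosses only k links, so chunks travel about N/2 hops on average on a unidirectional ring and about N/4 on a bidirectional ring, where each chunk can take the shorter direction. Bidirectional links therefore cut AllToAll time by about 4x but AllGather time by only about 2x, which is why the ratio drops from 1/2 to 1/4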
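The two effects can be checked with a back-of-the-envelope count of bytes per directed link. This sketch assumes N devices, an array of V bytes, N directed links on a unidirectional ring and 2N on a bidirectional ring, shortest-direction routing for AllToAll chunks, and large N so that (N - 1)/N ≈ 1; it is not an exact cost model.

$$
\begin{aligned}
\text{unidirectional:}\quad & \text{AllToAll: } \frac{1}{N}\cdot N\sum_{k=1}^{N-1} k\,\frac{V}{N^2} = \frac{(N-1)V}{2N} \approx \frac{V}{2},
\qquad \text{AllGather: } \frac{1}{N}\cdot N(N-1)\,\frac{V}{N} = \frac{(N-1)V}{N} \approx V \\
\text{bidirectional:}\quad & \text{AllToAll: } \frac{1}{2N}\cdot N\cdot 2\sum_{k=1}^{N/2} k\,\frac{V}{N^2} \approx \frac{V}{8},
\qquad \text{AllGather: } \frac{1}{2N}\cdot N(N-1)\,\frac{V}{N} \approx \frac{V}{2}
\end{aligned}
$$

Dividing the AllToAll load by the AllGather load gives 1/2 on the unidirectional ring and 1/4 on the bidirectional ring.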

Use Cases

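  1. Mixture of Experts Models: Dispatching tokens from a batch-sharded layout to the devices that host their assigned experts, and sending the results back
  2. Layout Rearrangement: Converting between different sharding strategies
  3. Attention Mechanisms: Switching activations between sequence-sharded and head-sharded layouts, as in sequence-parallel attention
  4. Transformations Between Incompatible Shardings: When consecutive operations require different sharding patterns, converting between them without gathering the full array on every device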

Implementation in JAX

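import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# 1-D mesh over all available devices, with one axis named 'X'
mesh = Mesh(np.array(jax.devices()), axis_names=('X',))

def reshard_I_to_J(block):
    # block is this device's local shard of A[I_X, J], shape (128 // N, 64).
    # tiled=True splits the local J dimension (split_axis=1) into N chunks, sends
    # chunk k to device k, and concatenates the received chunks along I (concat_axis=0).
    return jax.lax.all_to_all(block, axis_name='X', split_axis=1, concat_axis=0, tiled=True)

x = jnp.ones((128, 64))  # N, the size of mesh axis 'X', must divide both 128 and 64
f = shard_map(reshard_I_to_J, mesh=mesh, in_specs=P('X', None), out_specs=P(None, 'X'))
transposed_x = f(x)  # same global values, now sharded as [I, J_X]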