AllToAll
AllToAll is a communication primitive that redistributes data across devices by effectively moving sharding from one dimension to another.
Definition
The AllToAll operation moves a sharding subscript from one array dimension to another:
$$\text{AllToAll}_{X, J}\; A[I_X, J] \rightarrow A[I, J_X]$$
This operation represents a sharded transposition, changing how the data is distributed without changing the total amount of data on each device.
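As a quick check of the claim that per-device data volume is unchanged, the sketch below computes local block shapes before and after moving the sharding from the first dimension to the second. The helper names and the mesh-axis size of 4 are assumptions chosen for illustration; the (128, 64) shape matches the JAX example later in this page.

```python
def local_shape(global_shape, sharded_dim, num_devices):
    """Per-device block shape when `sharded_dim` is split evenly over `num_devices`."""
    if global_shape[sharded_dim] % num_devices != 0:
        raise ValueError("sharded dimension must divide evenly across devices")
    shape = list(global_shape)
    shape[sharded_dim] //= num_devices
    return tuple(shape)


def num_elements(shape):
    """Number of elements in a block of the given shape."""
    count = 1
    for size in shape:
        count *= size
    return count


global_shape = (128, 64)  # A[I, J]
num_devices = 4           # |X|, a hypothetical mesh-axis size

before = local_shape(global_shape, 0, num_devices)  # local block of A[I_X, J]
after = local_shape(global_shape, 1, num_devices)   # local block of A[I, J_X]

print(before, num_elements(before))
print(after, num_elements(after))
print(global_shape, num_elements(global_shape))     # what an AllGather would leave on each device
```

The local block changes shape but not size, whereas an AllGather would leave the full array on every device.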
Visual Representation
Animation showing how an AllToAll operation redistributes data across devices. Here, bidirectional communication links are shown (true for TPUs). Source: How to Scale Your Model
Operation Details
- Targeted Exchange: Each device sends specific portions of its data to each other device
- No Replication: Unlike AllGather, the total data volume per device remains constant
- Dimension Swapping: Effectively swaps which dimension is sharded
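The exchange pattern described above can be simulated without any accelerator. The following is a minimal sketch in plain Python, assuming an evenly divisible 2D array stored as nested lists; the function name `all_to_all` and the toy 4x4 data are illustrative and do not correspond to any library API.

```python
def all_to_all(blocks):
    """Simulate an AllToAll over a list of per-device blocks.

    blocks[d] is device d's local block (a list of rows) of A[I_X, J].
    Returns the per-device blocks of A[I, J_X].
    """
    n = len(blocks)
    num_cols = len(blocks[0][0])
    assert num_cols % n == 0, "J must divide evenly across devices"
    chunk = num_cols // n

    out = []
    for dst in range(n):
        received = []
        for src in range(n):
            # Device `src` sends only column chunk `dst` of its rows to device `dst`.
            for row in blocks[src]:
                received.append(row[dst * chunk:(dst + 1) * chunk])
        out.append(received)
    return out


# Global 4x4 array with entry 10*i + j, sharded by rows over 2 devices.
global_array = [[10 * i + j for j in range(4)] for i in range(4)]
row_sharded = [global_array[0:2], global_array[2:4]]

col_sharded = all_to_all(row_sharded)
for device, block in enumerate(col_sharded):
    print(device, block)
```

Each simulated device ends up with every row but only its own slice of columns, and it holds the same number of elements as before the exchange.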
Performance Analysis
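For an array of V total bytes on a bidirectional ring whose links have bidirectional bandwidth W:
Time Complexity (Bandwidth-Bound)
$$T_\text{AllToAll} \approx \frac{V}{4\,W}$$
This is roughly 1/4 the cost of an AllGather operation on a bidirectional ring, which takes about V / W (or 1/2 the cost on a unidirectional ring).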
Mathematical Explanation
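The reduced cost comes from:
- No data replication: in an AllGather every shard must reach every device, whereas in an AllToAll on N devices each device sends a different 1/N slice of its shard to each target
- Shorter travel distance: each slice goes to exactly one target, so on a bidirectional ring it travels about N/4 hops on average instead of being forwarded around the whole ring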
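The 1/4 and 1/2 factors can be checked by counting how many bytes cross each ring link. The sketch below is a simplified model, not a description of any particular hardware scheduler: it assumes an even number of devices N, shortest-path routing, and that traffic for the device directly opposite is split evenly between the two directions. Traffic is expressed as a fraction of the total array size V, and `per_link_traffic` is a hypothetical helper name.

```python
from fractions import Fraction


def per_link_traffic(num_devices, bidirectional):
    """Bytes crossing each directed ring link, as a fraction of the array size V.

    Returns (all_to_all, all_gather).
    """
    n = num_devices
    assert n >= 2 and n % 2 == 0, "this model assumes an even number of devices"
    shard = Fraction(1, n)      # each device holds V / N
    chunk = Fraction(1, n * n)  # AllToAll sends V / N^2 to each peer

    if bidirectional:
        far = n // 2
        # Per direction: peers at distance 1 .. N/2 - 1, plus half of the
        # traffic for the opposite peer at distance N/2.
        a2a_hops = sum(range(1, far)) + Fraction(far, 2)
        ag_hops = (far - 1) + Fraction(1, 2)
    else:
        # Single direction: peers at distance 1 .. N - 1.
        a2a_hops = sum(range(1, n))
        ag_hops = n - 1

    # N sources spread their hop-bytes over N links per direction, so by
    # symmetry the per-link traffic equals one source's hop-bytes.
    return chunk * a2a_hops, shard * ag_hops


for n in (8, 64, 1024):
    a2a, ag = per_link_traffic(n, bidirectional=True)
    a2a_uni, ag_uni = per_link_traffic(n, bidirectional=False)
    print(n, float(a2a / ag), float(a2a_uni / ag_uni))
```

In this model the bidirectional AllToAll puts V/8 on each directed link. If each direction of a link carries half of the bidirectional bandwidth W, that gives a time of V / (4 W). The bidirectional ratio to AllGather is N / (4 (N - 1)), which approaches 1/4 as N grows, and the unidirectional ratio is exactly 1/2.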
Use Cases
- Mixture of Experts Models: Redistributing data between expert and batch dimensions
- Layout Rearrangement: Converting between different sharding strategies
- Attention Mechanisms: Switching between sequence-sharded and head-sharded layouts around the attention computation
- Transformations Between Incompatible Shardings: When operations require different sharding patterns
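To make the Mixture of Experts use case concrete, here is a toy sketch of token dispatch in plain Python. It assumes one expert per device and a precomputed routing table (`expert_of`, a hypothetical stand-in for a router's output). Production systems typically pad or cap each send buffer to a fixed capacity so that all chunks have equal size, which this sketch does not model.

```python
def dispatch_tokens(tokens_per_device, expert_of):
    """Route each token to the device that owns its expert (one expert per device).

    tokens_per_device[d] lists the tokens that start on device d (batch sharding).
    expert_of maps a token to its assigned expert index.
    Returns per-device token lists after the exchange (expert sharding).
    """
    n = len(tokens_per_device)
    # send_buffers[src][dst]: tokens that device `src` sends to device `dst`.
    send_buffers = [[[] for _ in range(n)] for _ in range(n)]
    for src, tokens in enumerate(tokens_per_device):
        for token in tokens:
            send_buffers[src][expert_of[token]].append(token)

    # The AllToAll step: device `dst` receives buffer [src][dst] from every `src`.
    return [
        [token for src in range(n) for token in send_buffers[src][dst]]
        for dst in range(n)
    ]


tokens_per_device = [["t0", "t1"], ["t2", "t3"]]      # sharded by batch
expert_of = {"t0": 1, "t1": 0, "t2": 0, "t3": 1}      # hypothetical routing decisions
tokens_per_expert = dispatch_tokens(tokens_per_device, expert_of)
print(tokens_per_expert)
```

A second AllToAll with the roles of source and destination reversed would return the expert outputs to the devices where the tokens originated.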
Implementation in JAX
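import jax, jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map  # newer JAX releases expose this as jax.shard_map

mesh = Mesh(jax.devices(), ('X',))  # 1D mesh; the device count must divide both 128 and 64
# Define a sharded array
x = jax.device_put(jnp.ones((128, 64)), NamedSharding(mesh, P('X', None)))  # sharded as [I_X, J]
# Perform AllToAll to change sharding; all_to_all needs axis 'X' to be bound, e.g. inside shard_map
def swap_sharding(block):  # block is the local shard, shape [I/|X|, J]
    return jax.lax.all_to_all(
        block,
        axis_name='X',
        split_axis=1,   # local dimension to split and scatter (becomes sharded)
        concat_axis=0,  # local dimension where received chunks are concatenated (becomes unsharded)
        tiled=True,     # chunk split_axis and concatenate along concat_axis, keeping the rank
    )
transposed_x = shard_map(swap_sharding, mesh=mesh, in_specs=P('X', None), out_specs=P(None, 'X'))(x)  # results in [I, J_X]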