AllGather
AllGather is a fundamental communication primitive in distributed computing that collects sharded tensor fragments from all devices and makes the complete tensor available on each device.
Definition
The AllGather operation removes a sharding subscript from an array, turning a sharded dimension into a replicated one:
AllGather_X(A[I_X, J]) → A[I, J]
For example, if an array is sharded along its first dimension across 4 devices, each holding a [256, 1024] shard, an AllGather produces the full [1024, 1024] array on every device.
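The semantics can be mimicked locally: concatenating every device's shard along the formerly sharded axis yields the array each device ends up holding. A minimal NumPy sketch using the shapes above:

import numpy as np

shards = [np.ones((256, 1024)) for _ in range(4)]  # one [256, 1024] shard per device
full = np.concatenate(shards, axis=0)              # the [1024, 1024] array every device receives
assert full.shape == (1024, 1024)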
Visual Representation
Figure: animation of an AllGather operation across multiple devices; bidirectional communication links are shown (true for TPUs). Source: How to Scale Your Model
Operation Details
- Sending Pattern: Each device sends its local shard to every other device
- Ring Implementation: Often implemented as a ring algorithm in which data flows around a logical ring of devices (see the toy model after this list)
- Bidirectional Communication: On systems with bidirectional links, data flows in both directions simultaneously, halving the number of hops required
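As an illustration of these three points, here is a toy hop-count model of a bidirectional ring AllGather. It is a sketch that only tracks which shard reaches which device at each step, not a real implementation:

import math
import numpy as np

def ring_all_gather(shards):
    # At step s, device d receives shard (d - s) mod n from its left neighbor
    # and shard (d + s) mod n from its right neighbor, so every shard arrives
    # within ceil((n - 1) / 2) steps instead of the n - 1 steps a
    # unidirectional ring would need
    n = len(shards)
    have = [{d: shards[d]} for d in range(n)]           # each device starts with its own shard
    for s in range(1, math.ceil((n - 1) / 2) + 1):
        for d in range(n):
            have[d][(d - s) % n] = shards[(d - s) % n]  # arrives from the left
            have[d][(d + s) % n] = shards[(d + s) % n]  # arrives from the right
    return [np.concatenate([have[d][i] for i in range(n)]) for d in range(n)]

outs = ring_all_gather([np.full((2, 3), d) for d in range(4)])
assert all((o == outs[0]).all() for o in outs)          # every device holds the same [8, 3] array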
Performance Analysis
For an array of V total bytes sharded across |X| devices along one axis:
Time Complexity (Bandwidth-Bound)
T_comms ≈ V / W_ici
Where W_ici is the bidirectional ICI bandwidth. Each device sends V/|X| bytes per hop in each direction, and the farthest shard travels |X|/2 hops, so the total time is roughly the full array size divided by the link bandwidth, independent of |X|.
Time Complexity (Latency-Bound)
For very small arrays, each hop is dominated by fixed link latency rather than by the bytes transferred:
T_comms ≈ (|X| / 2) · T_min
Where T_min is the minimum hop latency (~1μs for TPUs).
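As a rough worked example (the numbers are illustrative assumptions, in the ballpark of a single TPU ICI axis): gathering V = 2^30 bytes with W_ici = 9 × 10^10 bytes/s gives
T_comms ≈ 2^30 / (9 × 10^10) s ≈ 12 ms
while the latency floor for |X| = 8 devices is only (8/2) · 1μs = 4μs, so a transfer of this size is firmly bandwidth-bound.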
Multi-Axis AllGather
When gathering over multiple mesh axes, each axis contributes its own set of links, so the available bandwidth scales with the number of axes:
T_comms ≈ V / (n_axes · W_ici)
Where n_axes is the number of axes being gathered over.
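Combining these formulas, a small cost model can estimate AllGather time as the maximum of the bandwidth and latency terms. This is a sketch; the w_ici and t_min defaults are illustrative assumptions, not hardware constants:

def all_gather_time(v_bytes, n_devices, n_axes=1, w_ici=9e10, t_min=1e-6):
    # Bandwidth term: total bytes over the bidirectional bandwidth,
    # multiplied across the n_axes axes being gathered over
    bandwidth_bound = v_bytes / (n_axes * w_ici)
    # Latency term: the farthest shard travels n_devices / 2 hops,
    # each costing at least t_min seconds
    latency_bound = (n_devices / 2) * t_min
    return max(bandwidth_bound, latency_bound)

print(all_gather_time(2**30, 8))  # ~1.2e-2 s: bandwidth-bound
print(all_gather_time(2**12, 8))  # 4e-6 s: latency-bound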
Use Cases in Distributed Matrix Multiplication
The primary use case is Case 2 matrix multiplication, where one multiplicand has a sharded contracting dimension:
Steps:
1. AllGather_X(A[I, J_X]) → A[I, J]
2. A[I, J] · B[J, K] → C[I, K]
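A minimal JAX sketch of this pattern, assuming a 1D device mesh with an axis named 'X' and a contracting dimension that divides evenly across the devices (the names case2_matmul, a, b, c are illustrative):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('X',))

def case2_matmul(a_shard, b):
    # a_shard is A[I, J_X]: each device holds one shard of the contracting dim
    a_full = jax.lax.all_gather(a_shard, 'X', axis=1, tiled=True)  # AllGather_X -> A[I, J]
    return a_full @ b                                              # local matmul -> C[I, K]

a = jax.device_put(jnp.ones((8, 64)), NamedSharding(mesh, P(None, 'X')))  # A[I, J_X]
b = jnp.ones((64, 16))                                                    # B[J, K], replicated
c = shard_map(case2_matmul, mesh=mesh,
              in_specs=(P(None, 'X'), P(None, None)), out_specs=P(None, None))(a, b)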
Implementation in JAX
import jax, jax.numpy as jnp, numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('X',))
# Define a sharded array: [I_X, J]
x = jax.device_put(jnp.ones((128, 64)), NamedSharding(mesh, P('X', None)))
# Perform AllGather to get the full array on each device; all_gather runs inside
# shard_map, and tiled=True concatenates the shards along axis 0
full_x = shard_map(lambda s: jax.lax.all_gather(s, 'X', axis=0, tiled=True),
                   mesh=mesh, in_specs=P('X', None), out_specs=P(None, None))(x)  # [I, J]
Relationship with Other Primitives
- Inverse of ReduceScatter: An AllGather in the forward pass corresponds to a ReduceScatter in the backward pass (and vice versa)
- Component of AllReduce: An AllReduce can be implemented as a ReduceScatter followed by an AllGather, as sketched below
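A minimal sketch of that decomposition using JAX collectives, assuming it runs inside shard_map over a mesh axis named 'X' whose size divides the leading dimension of x; it computes the same result as jax.lax.psum(x, 'X'):

import jax

def all_reduce(x):
    # ReduceScatter: sum across devices, leaving each device with one
    # summed shard of the leading dimension
    scattered = jax.lax.psum_scatter(x, 'X', scatter_dimension=0, tiled=True)
    # AllGather: reassemble the summed shards so every device holds the full sum
    return jax.lax.all_gather(scattered, 'X', axis=0, tiled=True)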
Performance Considerations
- Bandwidth vs. Latency: For large arrays, performance is bandwidth-bound; for small arrays, latency-bound
- Ring Topology: Performance improves with wraparound links (torus topology)
- Memory Cost: The operation increases per-device memory usage by a factor of |X| (each device goes from holding a V/|X|-byte shard to the full V-byte array)