AllGather

AllGather is a fundamental communication primitive in distributed computing that collects sharded tensor fragments from all devices and makes the complete tensor available on each device.

Definition

The AllGather operation removes a sharding subscript from an array, replicating the previously sharded dimension on every device:

AllGather_X(A[I_X, J]) → A[I, J]

For example, if an array is sharded along its first dimension across 4 devices, each holding a [256, 1024] shard, an AllGather produces the full [1024, 1024] array on every device.

Visual Representation

Figure: Animation showing how an AllGather operation works across multiple devices. Bidirectional communication links are shown (true for TPUs). Source: How to Scale Your Model

Operation Details

  1. Sending Pattern: Each device sends its local shard to every other device
  2. Ring Implementation: Often implemented as a ring algorithm where data flows around a logical ring of devices (see the sketch after this list)
  3. Bidirectional Communication: On systems with bidirectional links, data can flow in both directions simultaneously
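
As a minimal sketch of the ring algorithm (pure Python with simulated devices, not tied to any real interconnect), the following shows how each device's shard propagates around the ring until every device holds the full array:

# Minimal sketch of a unidirectional ring AllGather over |X| simulated devices.
# Each "device" starts with one shard; after |X| - 1 steps it holds all of them.
def ring_all_gather(shards):
    n = len(shards)                          # number of devices |X|
    out = [[None] * n for _ in range(n)]     # per-device output, one slot per shard
    for i, s in enumerate(shards):
        out[i][i] = s                        # each device begins with its own shard
    forwarding = list(range(n))              # index of the shard each device sends next
    for _ in range(n - 1):                   # |X| - 1 communication steps
        forwarding = [forwarding[(i - 1) % n] for i in range(n)]  # receive from the left neighbor
        for i in range(n):
            out[i][forwarding[i]] = shards[forwarding[i]]         # store the received shard
    return out                               # every device now holds the full array

print(ring_all_gather(["A", "B", "C", "D"]))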

Performance Analysis

For an array of size V bytes sharded across |X| devices:

Time Complexity (Bandwidth-Bound)

T_comms ≈ V / W_ici

Where W_ici is the bidirectional ICI bandwidth.
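
As a rough worked example (the numbers are purely illustrative, not taken from the source): gathering a V = 1 GB array over links with W_ici = 100 GB/s of bidirectional bandwidth takes on the order of 1 GB / 100 GB/s = 10 ms, roughly independent of how many devices the array was sharded across, since each device must receive nearly the full V bytes either way.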

Time Complexity (Latency-Bound)

For very small arrays:

T_comms ≈ (|X| / 2) · T_min

Where T_min is the minimum hop latency (~1μs for TPUs).
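
With bidirectional links, the farthest shard travels roughly |X| / 2 hops, each costing at least T_min. For instance, with T_min ≈ 1μs on a ring of |X| = 16 devices, even a tiny AllGather takes on the order of 8μs, no matter how few bytes are moved.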

Multi-Axis AllGather

When gathering over multiple axes:

T_comms ≈ V / (n_axes · W_ici)

Where n_axes is the number of axes being gathered over.
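
Each hardware mesh axis contributes its own set of links, so gathering over two axes roughly doubles the available bandwidth and halves the transfer time; the illustrative 10 ms single-axis example above would drop to about 5 ms.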

Use Cases in Distributed Matrix Multiplication

The primary use case is Case 2 matrix multiplication, where one multiplicand has a sharded contracting dimension:

Steps (a JAX sketch follows the list):

  1. AllGather_X(A[I, J_X]) → A[I, J]
  2. A[I, J] · B[J, K] → C[I, K]
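
A minimal sketch of these two steps in JAX, assuming a 1D device mesh with axis name 'X' and using shard_map with an explicit all_gather of A before a local matmul (the array shapes and names here are illustrative):

import jax, jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((jax.device_count(),), ('X',))  # illustrative 1D mesh

def case2_matmul(a_block, b):
    # a_block is the local block of A[I, J_X]; b is the full B[J, K]
    a_full = jax.lax.all_gather(a_block, 'X', axis=1, tiled=True)  # step 1: A[I, J]
    return a_full @ b                                              # step 2: C[I, K]

a = jnp.ones((128, 256))   # A[I, J], to be sharded over its contracting dimension J
b = jnp.ones((256, 64))    # B[J, K], replicated
c = shard_map(case2_matmul, mesh=mesh,
              in_specs=(P(None, 'X'), P(None, None)),
              out_specs=P(None, None))(a, b)        # C[I, K] on every device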

Implementation in JAX

import jax, jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((jax.device_count(),), ('X',))  # 1D device mesh with axis 'X'
# Define a sharded array: rows split across 'X', i.e. sharded as [I_X, J]
x = jax.device_put(jnp.ones((128, 64)), NamedSharding(mesh, P('X', None)))
# Perform AllGather inside shard_map to get the full [I, J] array on each device
full_x = shard_map(lambda a: jax.lax.all_gather(a, 'X', axis=0, tiled=True),
                   mesh=mesh, in_specs=P('X', None), out_specs=P(None, None))(x)

Relationship with Other Primitives

  • Inverse of ReduceScatter: The two operations are each other's transpose; an AllGather in the forward pass becomes a ReduceScatter in the backward pass, and vice versa
  • Component of AllReduce: AllReduce can be implemented as a ReduceScatter followed by an AllGather (see the sketch below)
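
As a rough sketch of the second point (again assuming a 1D 'X' mesh; the shapes are illustrative), an explicit psum_scatter followed by an all_gather should match a direct psum:

import jax, jax.numpy as jnp
from jax.sharding import PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = jax.make_mesh((jax.device_count(),), ('X',))   # illustrative 1D mesh
n = jax.device_count()

def allreduce_direct(x):
    return jax.lax.psum(x, 'X')                        # one-shot AllReduce

def allreduce_two_step(x):
    s = jax.lax.psum_scatter(x, 'X', scatter_dimension=0, tiled=True)  # ReduceScatter
    return jax.lax.all_gather(s, 'X', axis=0, tiled=True)              # AllGather

x = jnp.arange(n * n * 8.0).reshape(n * n, 8)          # local block per device: (n, 8)
run = lambda f: shard_map(f, mesh=mesh, in_specs=P('X', None), out_specs=P(None, None))(x)
assert jnp.allclose(run(allreduce_direct), run(allreduce_two_step))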

Performance Considerations

  1. Bandwidth vs. Latency: For large arrays, performance is bandwidth-bound; for small arrays, latency-bound (a rough crossover estimate follows this list)
  2. Ring Topology: Performance improves with wraparound links (torus topology)
  3. Memory Cost: The operation increases memory usage by a factor of |X| on each device
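
Combining the two rooflines above gives a rough crossover point: the operation is latency-bound when V / W_ici falls below (|X| / 2) · T_min. The numbers below are purely illustrative placeholders, not hardware specifications:

# Rough crossover between the bandwidth-bound and latency-bound regimes.
# All values are illustrative placeholders, not hardware specifications.
W_ici = 1e11       # assumed bidirectional ICI bandwidth, bytes/s
T_min = 1e-6       # assumed minimum hop latency, seconds
num_devices = 16   # |X|

latency_floor = (num_devices / 2) * T_min     # seconds spent on hops alone
crossover_bytes = latency_floor * W_ici       # arrays smaller than this are latency-bound
print(f"latency floor: {latency_floor * 1e6:.1f} us, "
      f"crossover size: {crossover_bytes / 1e6:.2f} MB")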