AllReduce

AllReduce is a communication primitive that performs a reduction operation (typically summation) across all devices and makes the result available on all devices.

Definition

The AllReduce operation removes an “unreduced” suffix from an array, summing the partial results held across a mesh axis:

AllReduce_X(A[I, J] {U_X}) → A[I, J]

Here, {U_X} indicates that the array contains partial sums that need to be combined along the X mesh axis.
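
For example, with 4 devices along X, each device d holds its own partial A_d[I, J]; after AllReduce_X, every device holds the full sum A[I, J] = A_0 + A_1 + A_2 + A_3.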

Decomposition

AllReduce can be decomposed into two more primitive operations, a ReduceScatter followed by an AllGather:

ReduceScatter_X(A[I, J] {U_X}) → A[I, J_X]
AllGather_X(A[I, J_X]) → A[I, J]

This decomposition is important for understanding the cost and potential optimizations.
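
As a sketch in JAX (the mesh and shapes here are illustrative assumptions, not from the source), the two halves correspond to jax.lax.psum_scatter and jax.lax.all_gather inside shard_map:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))
n = len(jax.devices())

def allreduce_decomposed(blk):
    # ReduceScatter: each device ends up with one fully reduced shard
    shard = jax.lax.psum_scatter(blk, 'X', scatter_dimension=0, tiled=True)
    # AllGather: collect the reduced shards back onto every device
    return jax.lax.all_gather(shard, 'X', axis=0, tiled=True)

x = jnp.ones((n * n, 4))  # per-device (n, 4) blocks of partial sums
y = shard_map(allreduce_decomposed, mesh=mesh,
              in_specs=P('X', None), out_specs=P())(x)
# y matches jax.lax.psum applied to the same blocks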

Visual Representation

The AllReduce operation can be visualized as a combination of a ReduceScatter (reducing values and scattering results) followed by an AllGather (collecting all results):

Figure: AllReduce decomposed as a ReduceScatter followed by an AllGather. Bidirectional communication links are shown (true for TPUs). Source: How to Scale Your Model

Operation Details

  1. Reduction Operation: Typically summation, but can also be min, max, or other operations
  2. Ring Algorithm: Commonly implemented as a ring, with each device repeatedly passing chunks to its neighbor (see the simulation after this list)
  3. Complete Replication: Every device ends up with identical results
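
A minimal pure-Python simulation of the ring algorithm (NumPy stands in for real devices and links; the function name and shapes are illustrative):

import numpy as np

def ring_allreduce(device_arrays):
    # Simulate n devices in a ring, each holding a full array of partial
    # sums split into n chunks
    n = len(device_arrays)
    chunks = [np.array_split(a, n) for a in device_arrays]
    # Phase 1 (ReduceScatter): each step, device i passes one chunk to
    # device i+1, which accumulates it; after n-1 steps, device i holds
    # the fully reduced chunk (i + 1) % n
    for step in range(n - 1):
        for i in range(n):
            c = (i - step) % n
            chunks[(i + 1) % n][c] = chunks[(i + 1) % n][c] + chunks[i][c]
    # Phase 2 (AllGather): circulate the reduced chunks around the ring
    for step in range(n - 1):
        for i in range(n):
            c = (i + 1 - step) % n
            chunks[(i + 1) % n][c] = chunks[i][c]
    return [np.concatenate(c) for c in chunks]

parts = [np.arange(8.0) + i for i in range(4)]  # 4 devices' partial results
out = ring_allreduce(parts)
assert all(np.allclose(o, sum(parts)) for o in out)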

Performance Analysis

For an array of size V bytes:

Time Complexity (Bandwidth-Bound)

T_{comms} = 2V / W_{ICI}

where W_{ICI} is the bidirectional ICI bandwidth.

This is twice the cost of an AllGather or ReduceScatter because it combines both operations.
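
As a worked example with hypothetical numbers: for V = 1 GB and W_{ICI} = 100 GB/s, T_{comms} = 2 × (1 GB) / (100 GB/s) = 20 ms, versus roughly 10 ms for the AllGather or ReduceScatter alone.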

Use Cases in Distributed Matrix Multiplication

The primary use case is Case 3 matrix multiplication, where both multiplicands have sharded contracting dimensions:

Steps:

  1. A[I, J_X] ·_LOCAL B[J_X, K] → C[I, K] {U_X}
  2. AllReduce_X(C[I, K] {U_X}) → C[I, K]
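
A minimal shard_map sketch of these two steps (mesh and shapes are illustrative assumptions):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))
n = len(jax.devices())

A = jnp.ones((16, 8 * n))  # contracting dimension J sharded over X
B = jnp.ones((8 * n, 32))

def case3_matmul(a_blk, b_blk):
    c_partial = a_blk @ b_blk            # step 1: local matmul → C[I, K] {U_X}
    return jax.lax.psum(c_partial, 'X')  # step 2: AllReduce removes {U_X}

C = shard_map(case3_matmul, mesh=mesh,
              in_specs=(P(None, 'X'), P('X', None)), out_specs=P())(A, B)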

Implementation in JAX

In JAX, an AllReduce is expressed as jax.lax.psum, which must run inside a mapped context (such as shard_map or pmap) so that the axis name is bound:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))  # one mesh axis named 'X'
x = jnp.ones((len(jax.devices()) * 128, 64))   # per-device (128, 64) blocks of partial sums

# Perform AllReduce to sum the partial results across the X axis
reduced_x = shard_map(lambda blk: jax.lax.psum(blk, 'X'),
                      mesh=mesh, in_specs=P('X', None), out_specs=P())(x)
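
Note that under jax.jit with sharded inputs, the XLA compiler typically inserts the equivalent AllReduce automatically whenever a computation contracts over a sharded dimension; an explicit psum is only needed in the manual shard_map/pmap style shown here.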

Alternatives to Full AllReduce

Sometimes a full AllReduce is unnecessary. Alternative patterns include:

  1. ReduceScatter: If the result can remain sharded, stop after the reduction half and skip the AllGather, halving communication time (see the sketch below)
  2. Local Reduction: If only a subset of devices need to combine their results, reduce within that subset rather than across the whole axis
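
As a sketch of the first pattern (mesh and shapes are again illustrative assumptions), jax.lax.psum_scatter performs the reduction but leaves the result sharded over X:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))
n = len(jax.devices())

def reduce_scatter_only(blk):
    # Each device keeps only its shard of the reduced result; no AllGather
    return jax.lax.psum_scatter(blk, 'X', scatter_dimension=0, tiled=True)

x = jnp.ones((n * n, 4))  # per-device (n, 4) blocks of partial sums
y = shard_map(reduce_scatter_only, mesh=mesh,
              in_specs=P('X', None), out_specs=P('X', None))(x)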