AllReduce
AllReduce is a communication primitive that performs a reduction operation (typically summation) across all devices and makes the result available on all devices.
Definition
The AllReduce operation removes an “unreduced” suffix from an array by summing the partial results held on different devices:
AllReduce_X(A[I_Y, J] {U_X}) → A[I_Y, J]
Here, {U_X} indicates that the array contains partial sums that need to be combined along the X dimension.
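To make the definition concrete, here is a minimal pure-Python sketch that simulates the semantics on a single machine. The function name `all_reduce` and the list-of-lists representation (one inner list per device along X) are assumptions chosen for illustration, not part of any library.

```python
from typing import List


def all_reduce(partials: List[List[float]]) -> List[List[float]]:
    """Simulate AllReduce: partials[d] is the unreduced array held by device d."""
    # Elementwise sum across devices (the reduction over the X axis).
    total = [sum(values) for values in zip(*partials)]
    # Every device receives its own copy of the fully reduced array.
    return [list(total) for _ in partials]


# Three hypothetical devices along X, each holding partial sums of the same array.
partials = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
result = all_reduce(partials)
print(result)
```

After the call, every simulated device holds the same summed array, which is what dropping the {U_X} suffix expresses.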
Decomposition
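AllReduce can be decomposed into two more primitive operations, applied in sequence:
- ReduceScatter: sums the partial results and leaves each device with one fully reduced shard
- AllGather: collects the shards so that every device holds the complete result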
This decomposition is important for understanding the cost and potential optimizations.
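The decomposition can be checked with a small simulation. This is a sketch under stated assumptions: the helper names `reduce_scatter` and `all_gather` are hypothetical, each device is modeled as a Python list, and the array length is assumed to be divisible by the number of devices.

```python
from typing import List


def reduce_scatter(partials: List[List[int]]) -> List[List[int]]:
    """Sum across devices, then leave device d with only the d-th chunk."""
    n = len(partials)
    chunk = len(partials[0]) // n  # assumes the length divides evenly
    total = [sum(values) for values in zip(*partials)]
    return [total[d * chunk:(d + 1) * chunk] for d in range(n)]


def all_gather(shards: List[List[int]]) -> List[List[int]]:
    """Concatenate all shards and give every device the full array."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


partials = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two hypothetical devices
shards = reduce_scatter(partials)   # each device holds one reduced chunk
gathered = all_gather(shards)       # each device holds the full reduced array
print(shards, gathered)
```

The intermediate `shards` value is the point at which a program could stop if the result is allowed to remain sharded.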
Visual Representation
The AllReduce operation can be visualized as a combination of a ReduceScatter (reducing values and scattering results) followed by an AllGather (collecting all results):
Figure: AllReduce shown as part of communication primitives. Here, bidirectional communication links are shown (as is the case for TPUs). Source: How to Scale Your Model
Operation Details
- Reduction Operation: Typically summation, but can also be min, max, or other operations
- Ring Algorithm: Commonly implemented by passing chunks of the array between neighboring devices arranged in a ring
- Complete Replication: Every device ends up with identical results
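The ring algorithm mentioned above can be sketched as a single-process simulation. This is an illustrative, unidirectional variant with one value per chunk, not the implementation used by any particular runtime. The function name `ring_all_reduce` is hypothetical, and each buffer is assumed to hold exactly as many chunks as there are devices.

```python
from typing import List


def ring_all_reduce(buffers: List[List[int]]) -> List[List[int]]:
    """Simulate a unidirectional ring AllReduce with one value per chunk."""
    n = len(buffers)
    bufs = [list(b) for b in buffers]  # copy so the inputs are not modified

    # Phase 1 (reduce-scatter): in each step, device d sends chunk (d - step)
    # to its right neighbor, which adds it to its own copy of that chunk.
    for step in range(n - 1):
        msgs = [(d, (d - step) % n, bufs[d][(d - step) % n]) for d in range(n)]
        for src, idx, val in msgs:  # all sends happen "simultaneously"
            bufs[(src + 1) % n][idx] += val

    # Phase 2 (all-gather): device d now owns the fully reduced chunk (d + 1).
    # Completed chunks are forwarded around the ring and overwrite stale ones.
    for step in range(n - 1):
        msgs = [(d, (d + 1 - step) % n, bufs[d][(d + 1 - step) % n]) for d in range(n)]
        for src, idx, val in msgs:
            bufs[(src + 1) % n][idx] = val

    return bufs


buffers = [
    [1, 2, 3, 4],
    [10, 20, 30, 40],
    [100, 200, 300, 400],
    [1000, 2000, 3000, 4000],
]
reduced = ring_all_reduce(buffers)
print(reduced)
```

Each phase takes n - 1 steps and moves one chunk per device per step, so the amount of data each device sends stays roughly constant as more devices are added.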
Performance Analysis
For an array of size V bytes:
Time Complexity (Bandwidth-Bound)
T_AllReduce ≈ 2 · V / W_{ICI}
where W_{ICI} is the bidirectional ICI bandwidth.
This is twice the cost of an AllGather or ReduceScatter because it combines both operations.
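As a quick worked example of the cost relationship, the sketch below assumes the bandwidth-bound model in which an AllGather or ReduceScatter of V bytes takes V / W_{ICI} and an AllReduce takes twice that. The array size and bandwidth are made-up round numbers, not measurements of any real hardware.

```python
def all_gather_time(v_bytes: float, w_ici: float) -> float:
    """Bandwidth-bound estimate for one AllGather (or ReduceScatter), in seconds."""
    return v_bytes / w_ici


def all_reduce_time(v_bytes: float, w_ici: float) -> float:
    """AllReduce modeled as a ReduceScatter followed by an AllGather."""
    return 2.0 * v_bytes / w_ici


V = 1e9       # hypothetical array size in bytes
W_ICI = 1e11  # hypothetical bidirectional ICI bandwidth in bytes/second

t_all_gather = all_gather_time(V, W_ICI)
t_all_reduce = all_reduce_time(V, W_ICI)
print(f"AllGather: {t_all_gather:.3f} s, AllReduce: {t_all_reduce:.3f} s")
```

This estimate ignores per-hop latency, so it is only meaningful for arrays large enough to be bandwidth-bound.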
Use Cases in Distributed Matrix Multiplication
The primary use case is Case 3 matrix multiplication, where both multiplicands are sharded along the contracting dimension:
A[I, J_X] · B[J_X, K] → C[I, K]
Steps:
1. A[I, J_X] ·_LOCAL B[J_X, K] → C[I, K] {U_X}
2. AllReduce_X(C[I, K] {U_X}) → C[I, K]
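The two steps can be traced with a small simulation in plain Python. This is a sketch only: the matrices, the two-device split, and the `matmul` helper are assumptions for illustration, and the final sum stands in for the AllReduce over X.

```python
def matmul(a, b):
    """Naive dense matrix multiply on lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


A = [[1, 2, 3, 4], [5, 6, 7, 8]]      # shape [I=2, J=4]
B = [[1, 0], [0, 1], [1, 1], [2, 0]]  # shape [J=4, K=2]

n_dev = 2                  # hypothetical number of devices along X
shard = len(B) // n_dev    # size of each device's slice of J

# Step 1: each device multiplies its local slices of A and B along J,
# producing a full-shaped but only partially summed C[I, K] {U_X}.
partials = []
for d in range(n_dev):
    j = slice(d * shard, (d + 1) * shard)
    a_local = [row[j] for row in A]
    b_local = B[j]
    partials.append(matmul(a_local, b_local))

# Step 2: the AllReduce sums the partial products elementwise across devices.
C = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
print(C)
```

The summed result matches the unsharded product `matmul(A, B)`, which is why the partial results must be combined before C is used.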
Implementation in JAX
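# Assumes a recent JAX release that provides jax.shard_map
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
mesh = Mesh(jax.devices(), ('X',))
# Define an array holding one (128, 64) block of partial results per device along axis X
x = jnp.ones((jax.device_count() * 128, 64))
# Perform AllReduce to sum the partial results; psum needs the axis name bound by shard_map
reduced_x = jax.shard_map(lambda v: jax.lax.psum(v, axis_name='X'), mesh=mesh, in_specs=P('X', None), out_specs=P())(x)
Alternatives to Full AllReduce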
Sometimes a full AllReduce is unnecessary. Alternative patterns include:
- ReduceScatter: If the result can remain sharded
- Local Reduction: If only a subset of devices need to combine their results