ReduceScatter
ReduceScatter is a communication primitive that performs a reduction operation (typically summation) across devices and then scatters different portions of the result to different devices.
Definition
The ReduceScatter operation removes an “unreduced” suffix from an array and introduces a sharding along a specified dimension:
ReduceScatter_X,J(A[I, J] {U_X}) → A[I, J_X]
Here, {U_X} indicates partial sums that still need to be combined across mesh axis X, and the result is sharded along dimension J.
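As a concrete illustration of these semantics, here is a minimal single-process NumPy sketch; the function name reduce_scatter_reference, the 4-device setup, and the shapes are illustrative, not a real collective implementation.
import numpy as np

def reduce_scatter_reference(partials, scatter_dim):
    # partials: one full-shape array per device, holding partial sums ({U_X}).
    # Step 1 (reduce): sum the partial results elementwise.
    total = np.sum(partials, axis=0)
    # Step 2 (scatter): split the reduced result along scatter_dim,
    # giving device i the i-th slice (the J_X sharding).
    return np.split(total, len(partials), axis=scatter_dim)

# 4 "devices", each holding a partial (8, 4) copy of A[I, J] {U_X}
partials = [np.ones((8, 4)) for _ in range(4)]
shards = reduce_scatter_reference(partials, scatter_dim=1)
print(shards[0].shape)  # (8, 1): device 0's slice of A[I, J_X]
Each "device" starts with a full-shape partial copy and ends with one reduced slice, which is exactly the A[I, J] {U_X} → A[I, J_X] transition above.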
Visual Representation
Animation showing how a ReduceScatter operation works across multiple devices. Here, bidirectional communication links are shown (true for TPUs). Source: How to Scale Your Model
Operation Details
- Partial Reduction: Each device holds partial sums that need to be combined
- Targeted Distribution: Each device receives only a portion of the final result
- Ring Algorithm: Often implemented with a ring algorithm for efficiency (a simulated ring schedule is sketched after this list)
- Dimension Selection: The operation must specify which output dimension to shard
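To make the ring idea concrete, here is a plain-Python simulation of a ring ReduceScatter schedule; it is illustrative only, since real implementations run the per-step transfers on all links in parallel and overlap them with compute.
import numpy as np

def ring_reduce_scatter(partials):
    # partials[i] is device i's full-length partial result; all have the same length.
    n = len(partials)
    chunks = [np.array_split(p, n) for p in partials]  # each device views its data as n chunks
    # n - 1 steps: every device forwards one partially reduced chunk to its right
    # neighbour, which adds it into its own copy of that chunk.
    for step in range(n - 1):
        for i in range(n):
            send_idx = (i - step - 1) % n
            dst = (i + 1) % n
            chunks[dst][send_idx] = chunks[dst][send_idx] + chunks[i][send_idx]
    # After n - 1 steps, device i holds the fully reduced chunk i.
    return [chunks[i][i] for i in range(n)]

# 4 "devices", each with a different partial copy of a length-8 vector
partials = [np.arange(8.0) + i for i in range(4)]
shards = ring_reduce_scatter(partials)
assert np.allclose(np.concatenate(shards), np.sum(partials, axis=0))
Each device sends n − 1 messages of roughly V/n bytes, so about V bytes cross each link in total, which is where the bandwidth cost in the next section comes from.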
Performance Analysis
For an array of size V bytes:
Time Complexity (Bandwidth-Bound)
T_{comms} ≈ V / W_{ICI}
where W_{ICI} is the bidirectional ICI bandwidth.
This is the same cost as an AllGather operation.
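For intuition, a back-of-the-envelope estimate with made-up numbers; the 100 GB/s figure below is an assumption for illustration, not the spec of any particular interconnect.
V = 2**30            # 1 GiB of unreduced data, in bytes
W_ICI = 100e9        # assumed bidirectional ICI bandwidth, in bytes/s
t_comms = V / W_ICI
print(f"{t_comms * 1e3:.1f} ms")  # ≈ 10.7 ms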
Use Cases in Distributed Matrix Multiplication
One key use is in Case 3 matrix multiplication with sharded output:
Steps:
1. Local partial matmul: A[I, J_X] ·_LOCAL B[J_X, K] → C[I, K] {U_X}
2. ReduceScatter_X,K(C[I, K] {U_X}) → C[I, K_X]
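A minimal JAX sketch of these two steps using shard_map; the mesh construction, shapes, and names are illustrative, and psum_scatter over the mesh axis 'X' plays the role of ReduceScatter_X,K.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('X',))

def case3_matmul(a_shard, b_shard):
    # Step 1: local partial product, C[I, K] {U_X}
    c_partial = a_shard @ b_shard
    # Step 2: ReduceScatter_X,K, giving each device its K-shard of C
    return jax.lax.psum_scatter(c_partial, 'X', scatter_dimension=1, tiled=True)

matmul = shard_map(case3_matmul, mesh=mesh,
                   in_specs=(P(None, 'X'), P('X', None)), out_specs=P(None, 'X'))

A = jnp.ones((16, 8 * jax.device_count()))  # A[I, J_X]: J sharded over 'X'
B = jnp.ones((8 * jax.device_count(), 32))  # B[J_X, K]: J sharded over 'X'
C = matmul(A, B)                            # C[I, K_X]: K sharded over 'X'
Each device computes a full-size partial product locally and keeps only its K-shard once the collective completes.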
Implementation in JAX
import jax, jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P
mesh = Mesh(jax.devices(), ('X',))
# Define a per-device array of partial results (unreduced over mesh axis 'X')
x = jnp.ones((128, 64))
# Perform ReduceScatter: sum the partials over 'X' and leave each device one slice of dimension 1
scattered_x = shard_map(lambda a: jax.lax.psum_scatter(a, 'X', scatter_dimension=1, tiled=True),
                        mesh=mesh, in_specs=P(), out_specs=P(None, 'X'))(x)
Relationship with Backpropagation
ReduceScatter is the gradient operation of AllGather:
- Forward pass: AllGather_X(A[I_X]) → A[I]
- Backward pass: ReduceScatter_X(A'[I] {U_X}) → A'[I_X]
This relationship is fundamental to understanding communication patterns in neural network training.
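A small sketch of this duality under shard_map; the quadratic loss is an arbitrary example, and the point is that JAX's autodiff transposes the all_gather in the forward pass into a psum_scatter (a ReduceScatter) in the backward pass, so each device receives only its shard of the gradient.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('X',))

def loss(a):
    # Forward pass: AllGather the shards of a into a replicated full copy
    gathered = shard_map(lambda s: jax.lax.all_gather(s, 'X', tiled=True),
                         mesh=mesh, in_specs=P('X'), out_specs=P(None))(a)
    return jnp.sum(gathered ** 2)

a = jnp.arange(4.0 * jax.device_count())
grads = jax.grad(loss)(a)  # equals 2 * a; the backward collective is a ReduceScatter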
ReduceScatter can also be combined with AllGather to implement AllReduce when needed: an AllReduce is equivalent to a ReduceScatter followed by an AllGather.
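As a sanity check of that identity, here is a shard_map sketch comparing a direct psum with psum_scatter followed by all_gather; the mesh, shapes, and names are illustrative.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('X',))

def allreduce_two_ways(x):
    # Direct AllReduce of the per-device values
    direct = jax.lax.psum(x, 'X')
    # ReduceScatter followed by AllGather produces the same fully reduced result
    scattered = jax.lax.psum_scatter(x, 'X', scatter_dimension=0, tiled=True)
    composed = jax.lax.all_gather(scattered, 'X', axis=0, tiled=True)
    return direct, composed

f = shard_map(allreduce_two_ways, mesh=mesh, in_specs=P('X'), out_specs=(P(None), P(None)))
x = jnp.arange(8.0 * jax.device_count())  # each device's block serves as its partial result
direct, composed = f(x)
assert jnp.allclose(direct, composed)
Both paths yield the same fully reduced, fully replicated result, which is why an AllReduce costs roughly twice as much as a ReduceScatter alone.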