ReduceScatter

ReduceScatter is a communication primitive that performs a reduction operation (typically summation) across devices and then scatters different portions of the result to different devices.

Definition

The ReduceScatter operation removes an “unreduced” suffix from an array and introduces a sharding along a specified dimension:

ReduceScatter_X,J(A[I, J] {U_X}) → A[I, J_X]

Here, {U_X} indicates partial sums that still need to be combined across the devices on mesh axis X, and the result is sharded along dimension J. For example, with 4 devices each holding a full 128 × 64 array of partial sums, each device ends up with a fully reduced 128 × 16 slice.

Visual Representation

[Animation: how a ReduceScatter operation works across multiple devices. Here, bidirectional communication links are shown (true for TPUs). Source: How to Scale Your Model]

Operation Details

  1. Partial Reduction: Each device holds partial sums that need to be combined
  2. Targeted Distribution: Each device receives only a portion of the final result
  3. Ring Algorithm: Often implemented using a ring algorithm for efficiency
  4. Dimension Selection: The operation needs to specify which dimension to shard in the output
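
To make these steps concrete, here is a minimal single-process reference sketch in JAX (the helper name, shapes, and device count are illustrative, not a library API):

import jax.numpy as jnp

def reduce_scatter_reference(partials, dim):
    # Step 1: combine the per-device partial sums
    total = sum(partials)
    # Steps 2 and 4: split the reduced result along `dim`; device i keeps shard i
    return jnp.split(total, len(partials), axis=dim)

partials = [jnp.ones((128, 64)) for _ in range(4)]  # 4 devices' partial sums
shards = reduce_scatter_reference(partials, dim=1)  # four fully reduced (128, 16) slices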

Performance Analysis

For an array of size V bytes:

Time Complexity (Bandwidth-Bound)

T_{comms} = V / W_{ICI}

where W_{ICI} is the bidirectional ICI bandwidth.

This is the same cost as an AllGather operation.
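
For example, with illustrative numbers (assumed here, not taken from the source): reduce-scattering V = 1 GB of partial sums over links with W_{ICI} = 100 GB/s of bidirectional bandwidth takes T_{comms} = 10^9 / 10^{11} s = 10 ms, approximately independent of the number of devices in the ring.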

Use Cases in Distributed Matrix Multiplication

One key use is in Case 3 matrix multiplication with sharded output:

Steps:

  1. A[I, J_X] ·_LOCAL B[J_X, K] → C[I, K] {U_X}
  2. ReduceScatter_X,K(C[I, K] {U_X}) → C[I, K_X]
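
A sketch of these two steps in JAX using shard_map (the mesh layout, axis name 'X', and shapes are assumptions for illustration; the column count must divide evenly across devices):

import functools
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))

@functools.partial(shard_map, mesh=mesh,
                   in_specs=(P(None, 'X'), P('X', None)), out_specs=P(None, 'X'))
def case3_matmul(a, b):
    partial_c = a @ b  # step 1: local matmul yields C[I, K] {U_X}
    # step 2: sum the partials over 'X'; each device keeps a distinct K-slice
    return jax.lax.psum_scatter(partial_c, 'X', scatter_dimension=1, tiled=True)

a = jnp.ones((128, 8 * jax.device_count()))  # A[I, J_X]
b = jnp.ones((8 * jax.device_count(), 64))   # B[J_X, K]
c = case3_matmul(a, b)                       # C[I, K_X], global shape (128, 64)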

Implementation in JAX

# psum_scatter must run under a named mesh axis, so we wrap it in shard_map
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))

# One (128, 64) partial result per device, stacked along dimension 0
x = jnp.ones((jax.device_count() * 128, 64))

# Perform ReduceScatter: sum the partials across 'X', with each device
# keeping a distinct slice of the columns (scatter_dimension=1)
scattered_x = shard_map(
    lambda blk: jax.lax.psum_scatter(blk, 'X', scatter_dimension=1, tiled=True),
    mesh=mesh, in_specs=P('X', None), out_specs=P(None, 'X'))(x)

Relationship with Backpropagation

ReduceScatter is the gradient operation of AllGather:

  • Forward pass: AllGather_X(A[I_X]) → A[I]
  • Backward pass: ReduceScatter_X(A'[I] {U_X}) → A'[I_X]

This relationship is fundamental to understanding communication patterns in neural network training.
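
This duality can be checked directly in JAX (mesh and axis name 'X' assumed as in the sketches above): differentiating through an all_gather produces a reduce-scatter (psum_scatter) in the backward pass.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))

# Forward: AllGather_X(A[I_X]) → A[I], replicated on every device
def gather(a):
    return shard_map(lambda blk: jax.lax.all_gather(blk, 'X', tiled=True),
                     mesh=mesh, in_specs=P('X'), out_specs=P())(a)

a = jnp.ones((8 * jax.device_count(),))
y, vjp_fn = jax.vjp(gather, a)
# Backward: the cotangent A'[I] {U_X} is reduce-scattered back to A'[I_X]
(da,) = vjp_fn(jnp.ones_like(y))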

It can be combined with AllGather to implement AllReduce when needed: an AllReduce is equivalent to a ReduceScatter followed by an AllGather, which is why an AllReduce costs roughly twice as much as either primitive alone.
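
A sketch of that composition under the same assumed mesh and axis-name setup:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=('X',))

def all_reduce(blk):
    # ReduceScatter: sum the partials across 'X'; each device keeps one slice
    shard = jax.lax.psum_scatter(blk, 'X', tiled=True)
    # AllGather: replicate the fully reduced result on every device
    return jax.lax.all_gather(shard, 'X', tiled=True)

x = jnp.ones((jax.device_count() * 128,))  # one (128,) partial per device
out = shard_map(all_reduce, mesh=mesh, in_specs=P('X'), out_specs=P())(x)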