In-Text Questions

Pop Quiz 1: Memory Usage

Let A be an array with shape int8[128, 2048], sharding A[I_XY, J], and mesh Mesh({'X': 2, 'Y': 8, 'Z': 2}) (so 32 devices total).

How much memory does A use per device?

How much total memory does A use across all devices?

Per device:

[128//(2 * 8), 2048] = [8, 2048] = 16,384 bytes

Total = 32 * 16,384 bytes = 524,288 bytes (2× the 262,144-byte array itself, since A is replicated over Z)
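
As a quick sanity check, here is a small JAX sketch of this accounting (my own snippet; it assumes 32 visible devices, e.g. CPU devices forced with XLA_FLAGS=--xla_force_host_platform_device_count=32):

```python
# Sketch: build the [128, 2048] int8 array with sharding A[I_XY, J] on a
# 2x8x2 mesh and read off the per-device and total bytes.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:32]).reshape(2, 8, 2)
mesh = Mesh(devices, axis_names=('X', 'Y', 'Z'))
A = jax.device_put(
    np.zeros((128, 2048), dtype=np.int8),
    NamedSharding(mesh, P(('X', 'Y'), None)),  # shard dim 0 over X and Y
)
shard = A.addressable_shards[0]
print(shard.data.shape, shard.data.nbytes)               # (8, 2048), 16384
print(sum(s.data.nbytes for s in A.addressable_shards))  # 524288 across all 32 devices
```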

Pop Quiz 2: AllGather Time

Using the numbers from Part 2, how long does it take to perform the AllGather_Y([E_Y, F]) → [E, F] on a TPUv5e with a 2D mesh {'X': 8, 'Y': 4}, with E = 2048, F = 8192, in bfloat16?

What about with E = 256, F = 512?

TPU v5e ICI (bidi) = 9e10 bytes/s

Part 1: each shard is {512, 8192}, i.e. 512 * 8192 * 2 ≈ 8.4 MB in bfloat16. Since the Y axis has size 4 (< 16) there are no wraparound links, so we cannot do a fully bidirectional ring AllGather; the data needs three unidirectional hops along the axis.

Part 2: each shard is {64, 512}, only 64 * 512 * 2 = 65,536 bytes.

This is latency bound, so the time is roughly 3 hops × the per-hop ICI latency (~1 µs per hop) ≈ 3 µs.
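
A rough back-of-the-envelope check of the latency-vs.-bandwidth split (my own sketch; the 9e10 bytes/s figure is the one quoted above, and the ~1 µs per-hop latency is a rough assumption taken from Part 2):

```python
# Per-hop cost of forwarding one shard: either the transfer time at the quoted
# bandwidth or the per-hop latency floor, whichever is larger.
W_ICI = 9e10          # bytes/s, bidirectional figure quoted above
HOP_LATENCY = 1e-6    # seconds per hop (rough assumption)

def per_hop_time(shard_shape, bytes_per_elem=2):
    shard_bytes = shard_shape[0] * shard_shape[1] * bytes_per_elem
    return max(shard_bytes / W_ICI, HOP_LATENCY)

print(per_hop_time((512, 8192)))  # ~93 us of data per hop -> bandwidth bound
print(per_hop_time((64, 512)))    # ~0.7 us of data per hop -> latency bound (1 us floor)
```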

[! Failure] My Original Incorrect Answer

For part 1, I did not consider that we cannot do a full bidirectional ring AllGather, since the axis does not have size 16 (no wraparound links).

For part 2, I did not consider that it would be latency bound.


Practice Questions

Question 1: Replicated Sharding Memory Usage

An array A is sharded A[I_X, J] (i.e., only sharded across X), with a mesh Mesh({'X': 4, 'Y': 8, 'Z': 2}). What is the ratio of the total number of bytes taken up by A across all chips to the size of one copy of the array?

each shard = 1/4 * sizeof(A)

total shards = 64

ratio = 64 * (1/4) = 16 (i.e., the total is 16 × the size of one copy of A)
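
A tiny check of the ratio (a sketch under the assumption above that only X shards the array, so Y and Z are pure replication):

```python
from math import prod

mesh = {'X': 4, 'Y': 8, 'Z': 2}
sharded_axes = ('X',)                      # assumption: only X shards the array

n_chips = prod(mesh.values())              # 64 chips in total
fraction_per_chip = 1 / prod(mesh[a] for a in sharded_axes)  # each chip holds 1/4 of A
print(n_chips * fraction_per_chip)         # 16.0 copies' worth of bytes
```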

Question 2: AllGather Latency

How long should take on a TPUv4p 4x4x4 slice with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}) if and in bfloat16?

How about ?

How about ?

TPUv4p ICI bw = 9e10 bytes/s (bidi)

  1. with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}) if and in bfloat16

Shard size =

Since the axis has wraparound links, a full bidirectional ring transfer is possible; for the gathered dimension (X), it is the full array size (1024) that matters, not the shard size.

Time =

  2. with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}) if and in bfloat16

Time =

Shard size =

Total transfer time =

Time =
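
For reference, here is the throughput-bound estimate used above as a small helper (my sketch; the bf16 [1024, 4096] example array is hypothetical, not the question's actual sizes):

```python
# Throughput-bound AllGather over full bidirectional rings (wraparound links):
# T ~= total_bytes / (n_axes * W_ici); gathering over several mesh axes at once
# scales the available bandwidth by the number of axes.
def allgather_time(total_bytes, n_axes=1, w_ici=9e10):
    return total_bytes / (n_axes * w_ici)

nbytes = 1024 * 4096 * 2                 # hypothetical bf16 [1024, 4096] array
print(allgather_time(nbytes))            # over one axis:  ~93 us
print(allgather_time(nbytes, n_axes=2))  # over two axes:  ~47 us
```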

Question 3: Latency-Bound AllGather

Let’s say we’re performing an AllGather, but the array is very small (say 128). How long should this take on a TPUv4p 4x4x4 slice with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}) in bfloat16?

Hint: you’re probably latency bound.

Latency bound means the time is set by the per-hop ICI latency rather than by bandwidth.

The axis has wraparound links, so a bidirectional transfer is possible and the farthest shard is only 2 hops away: roughly 2 hops × ~1 µs/hop ≈ 2 µs.

Question 4: Matrix Multiplication Strategies

To perform a matmul [B, D] ·_D [D_X, F] → [B, F], the text describes using AllGather_X on the [D_X, F] operand and then multiplying the fully replicated matrices (Strategy 1).

Instead, you could multiply the local shards, [B, D_X] ·_D [D_X, F] → [B, F] {U_X}, and then AllReduce_X the partial sums (Strategy 2).

How many FLOPs and comms does each of these strategies perform? Which is better and why?

Let’s assume bf16 (2 bytes per element)

Also, for TPU v5e, take C/W ≈ 2550 (accelerator FLOPs/s over bidirectional ICI bandwidth).

Case 1: AllGather MatMul

AllGather time = 2DF/W (gathering the bf16 [D_X, F] operand); MatMul time = 2BDF/C, where C is the accelerator's FLOPs/s (each TPU redundantly does the same FLOPs in parallel).

Compute bound if 2BDF/C > 2DF/W, i.e. if B > C/W.

If B > 2550, compute bound.

Case 2: Sharded MatMul AllReduce

Sharded MatMul: 2B(D/|X|)F = 2BDF/|X| FLOPs per TPU, so MatMul time = 2BDF/(|X| * C); AllReduce time = 4BF/W.

Compute bound if 2BDF/(|X| * C) > 4BF/W, i.e. if |X| < D * W / (2C) ≈ D/5100.

If D ≈ 8k (reasonable), that requires |X| < 2, which never happens.

Thus Case 2 is always comms bound.

Which is faster?

  • For B < 2550 (often, but not always), both strategies are comms bound.

    Strategy 2 is then faster iff 4BF/W < 2DF/W, i.e. iff D > 2B (with 2B < 5100 by assumption). This holds for larger models with reasonable batch sizes.

  • For B > 2550, Strategy 2 is comms bound while Strategy 1 is math bound.

    Strategy 2 is then faster iff 4BF/W < 2BDF/C, i.e. iff D > 2C/W ≈ 5100 (see the numeric sketch below).
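
Here is the numeric sketch referenced above (my own code; C is chosen so that C/W ≈ 2550, matching the ratio used in this answer, and the B, D, F values are hypothetical examples):

```python
# Roofline comparison of the two strategies: each returns max(comms, math) time.
C = 2.3e14       # FLOPs/s, chosen so C / W_ICI ~= 2550
W_ICI = 9e10     # bytes/s

def strategy1_time(B, D, F):
    t_comms = 2 * D * F / W_ICI        # AllGather the bf16 [D_X, F] operand
    t_math = 2 * B * D * F / C         # every device redoes the full matmul
    return max(t_comms, t_math)

def strategy2_time(B, D, F, X):
    t_math = 2 * B * (D / X) * F / C   # each device multiplies its D/X slice
    t_comms = 4 * B * F / W_ICI        # AllReduce the partial [B, F] outputs
    return max(t_comms, t_math)

for B in (512, 4096):                  # hypothetical batch sizes, D = F = 8192
    print(B, strategy1_time(B, 8192, 8192), strategy2_time(B, 8192, 8192, X=4))
```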

Why don’t we employ Strategy 2?

We do, sometimes. But it is uncommon for the contracting dimension of one matmul input to be sharded along an axis that the other input is not sharded over at all.

For example, in FSDP we shard both the parameters and the data along the same axis.

My original incorrect answer

(I was not thinking in terms of total bytes when sharded, only along one dimension; I also calculated raw FLOPs, while what we care about is compute time, not the FLOPs per se.)

Case 1: AllGather MatMul

AllGather = D * sizeof(D_elem) / W; MatMul FLOPs = 2BDF * |X|

Case 2: Sharded MatMul AllReduce

Sharded MatMul FLOPs = 2B(D/|X|)F * |X| = 2BDF; AllReduce = 2 * (D * sizeof(D_elem) / W)

Question 5: Minimum Latency Matrix Multiplication

Let’s say I want to do a matmul on a TPUv5p 4x4x4 with the lowest possible latency. How should my inputs be sharded? What is the total FLOPs and comms time?

  • Case 1: B < 2550
    • If D > 2B and 2B < 5100,

Todo

Question 6: Complex Sharding Communication

Let’s say we want to perform on TPUv5e 4x4. What communication do we perform? How much time is spent on communication vs. computation?

What about ? This is the most standard setting for training where we combine data, tensor, and zero sharding.

What about ? This is standard for inference, where we do pure tensor parallelism (+data).

Todo

Question 7: Memory-Constrained Transformer Block

A typical Transformer block has two matrices B[D, F] and C[F, D], where F = 4D. With a batch size B, the whole block multiplies an input of shape [B, D] by both matrices in sequence.

Let’s pick concrete values for D, F, and B, and assume everything is in bfloat16. Assume we’re running on a TPUv5e 2x2 slice but assume each TPU only has 300MB of free memory.

How should B, C, and the output be sharded to stay below the memory limit while minimizing overall time? How much time is spent on comms and FLOPs?

Todo

Question 8: Communication Primitives Benchmarking

Using the JAX code snippet from the text as a template, allocate a sharded array and benchmark each of the 4 main communication primitives (AllGather, AllReduce, ReduceScatter, and AllToAll) using pmap or shard_map.

You will want to use jax.lax.all_gather, jax.lax.psum, jax.lax.psum_scatter, and jax.lax.all_to_all. Do you understand the semantics of these functions? How long do they take?

Todo
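
Not a benchmark yet, but here is a minimal shard_map sketch of the four collectives as a starting point (my own snippet; the axis name 'x' and array sizes are arbitrary, it assumes 8 devices, e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU, and real timing would need warm-up plus jax.block_until_ready over many runs):

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:8]), axis_names=('x',))
sharded = P('x', None)
replicated = P(None, None)

# Each function body sees its local [128, 1024] block of the global [1024, 1024] array.
@partial(shard_map, mesh=mesh, in_specs=sharded, out_specs=replicated)
def all_gather(x):
    return jax.lax.all_gather(x, 'x', tiled=True)       # -> full [1024, 1024] on each device

@partial(shard_map, mesh=mesh, in_specs=sharded, out_specs=replicated)
def all_reduce(x):
    return jax.lax.psum(x, 'x')                          # elementwise sum of the 8 blocks

@partial(shard_map, mesh=mesh, in_specs=sharded, out_specs=sharded)
def reduce_scatter(x):
    return jax.lax.psum_scatter(x, 'x', tiled=True)      # sum, then re-shard the rows

@partial(shard_map, mesh=mesh, in_specs=sharded, out_specs=sharded)
def all_to_all(x):
    # Split the local columns into 8 chunks, exchange them, concatenate along rows.
    return jax.lax.all_to_all(x, 'x', split_axis=1, concat_axis=0, tiled=True)

x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)
for name, fn in [('all_gather', all_gather), ('all_reduce', all_reduce),
                 ('reduce_scatter', reduce_scatter), ('all_to_all', all_to_all)]:
    out = jax.block_until_ready(fn(x))   # compile + run once; time subsequent calls
    print(name, out.shape)
```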

Question 9: Alternative Strategy for Sharded MatMul

The text claims that when only one input to a matmul is sharded along its contracting dimension, we should AllGather the sharded matrix and perform the contraction locally.

Another strategy you might consider is to perform the sharded matmul on the local shards and then AllReduce the result.

Answer the following:

  1. Explicitly write out this algorithm for matrices and , using indices to show exactly what computation is done on what device. Assume is sharded as across ND devices, and you want your output to be replicated across all devices.

  2. Now suppose you are ok with the final result not being replicated on each device, but instead sharded (across either the N or K dimension). How would the algorithm above change?

  3. Looking purely at the communication cost of the strategy above (in part 2, not 1), how does this communication cost compare to the communication cost of the algorithm in which we first AllGather A and then do the matmul?

Todo
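
As a warm-up for part 1, here is a toy numpy simulation of the sharded-matmul-then-AllReduce idea (my own shapes and names, not the question's: A is [M, K] sharded along its contracting dimension K over P_DEV simulated devices, B is replicated, and the AllReduce is modeled as a plain sum over partial products):

```python
import numpy as np

P_DEV, M, K, N = 4, 8, 16, 6              # toy sizes (mine); K divisible by P_DEV
rng = np.random.default_rng(0)
A = rng.normal(size=(M, K))               # sharded along K across the devices
B = rng.normal(size=(K, N))               # replicated; each device uses its K-slice

A_shards = np.split(A, P_DEV, axis=1)     # device p holds A[:, p*K/P : (p+1)*K/P]
B_slices = np.split(B, P_DEV, axis=0)     # the matching rows of the replicated B

# Each device does a local matmul over its slice of the contracting dimension,
# producing a partial [M, N] sum; the AllReduce is modeled as summing the
# partials, after which every device holds the same replicated result.
partials = [A_shards[p] @ B_slices[p] for p in range(P_DEV)]
C = sum(partials)

print(np.allclose(C, A @ B))              # True: partial sums reproduce the full matmul
```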

Question 10: Fun with AllToAll

In the text, it was noted that the time to perform an AllToAll is a factor of 4 lower than the time to perform an AllGather or ReduceScatter (in the regime where we are throughput-bound).

In this problem we will see where that factor of 4 comes from, and also see how this factor would change if we only had single-direction ICI links, rather than bidirectional ICI links.

  1. Let’s start with the single-direction case first. Imagine we have D devices in a ring topology, and we are doing either an AllGather or a ReduceScatter on an N x N matrix A that is sharded along one of its dimensions across the D devices (say D divides N for simplicity). Describe the comms involved in these two collectives, and calculate the total number of scalars (floats or ints) that are transferred across a single ICI link during the entirety of this algorithm.

  2. Now let’s think about an AllToAll, still in the single-directional ICI case. How is the algorithm different in this case than the all-gather case? Calculate the number of scalars that are transferred across a single ICI link in this algorithm.

  3. You should have found that the ratio between your answers to part (a) and part (b) is a nice number. Explain where this factor comes from in simple terms.

  4. Now let’s add bidirectional communication. How does this affect the total time needed in the all-gather case?

  5. How does adding bidirectional communication affect the total time needed in the AllToAll case?

  6. Now simply explain the ratio between AllGather time and AllToAll time in a bidirectional ring.

Todo
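
A quick counting sketch to sanity-check these ratios (my own routing model, not the text's derivation): it tallies the scalars crossing the busiest link direction of a D-device ring for an AllGather versus an AllToAll of an N x N matrix sharded across the D devices. With D = 8, N = 64 it prints a ratio of 2 for the single-direction ring and 4 for the bidirectional one.

```python
# AllGather: one copy of each shard is forwarded hop-by-hop along its path.
# AllToAll: each (src, dst) block is routed the short way around the ring,
# with antipodal blocks split half-and-half between the two directions.
def allgather_traffic(D, N, bidirectional):
    shard = N * N / D
    cw, ccw = [0.0] * D, [0.0] * D          # load per link, per direction
    for src in range(D):
        cw_hops = D // 2 if bidirectional else D - 1
        ccw_hops = (D - 1) - cw_hops
        for k in range(cw_hops):
            cw[(src + k) % D] += shard
        for k in range(1, ccw_hops + 1):
            ccw[(src - k) % D] += shard
    return max(cw + ccw)

def alltoall_traffic(D, N, bidirectional):
    block = N * N / (D * D)
    cw, ccw = [0.0] * D, [0.0] * D
    for src in range(D):
        for dst in range(D):
            fwd = (dst - src) % D
            if fwd == 0:
                continue
            go_cw = block if (not bidirectional or fwd < D / 2) else 0.0
            if bidirectional and fwd == D / 2:
                go_cw = block / 2            # antipodal block: half each way
            go_ccw = block - go_cw
            for k in range(fwd):
                cw[(src + k) % D] += go_cw
            for k in range(1, (D - fwd) + 1):
                ccw[(src - k) % D] += go_ccw
    return max(cw + ccw)

for bidi in (False, True):
    ag = allgather_traffic(8, 64, bidi)
    a2a = alltoall_traffic(8, 64, bidi)
    print(f"bidirectional={bidi}: all_gather={ag:.0f}  all_to_all={a2a:.0f}  ratio={ag / a2a:.1f}")
```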