In-Text Questions
Pop Quiz 1: Memory Usage
Let A be an array with shape int8[128, 2048], sharding $A[I_{XY}, J]$, and mesh Mesh({'X': 2, 'Y': 8, 'Z': 2}) (so 32 devices total). How much memory does A use per device?
How much total memory does A use across all devices?
Per device:
[128 // (2 * 8), 2048] = [8, 2048] → 8 * 2048 * 1 byte = 16,384 bytes
Total = 32 * 16,384 bytes = 524,288 bytes
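For reference, a minimal JAX sketch of this sharding (it assumes you actually have 32 devices to form the 2x8x2 mesh, e.g. a TPU slice; otherwise it won't run):

```python
# Hypothetical check of Pop Quiz 1: build the 2x8x2 mesh and shard A[I_XY, J].
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 8, 2), ('X', 'Y', 'Z'))
# First dim sharded jointly over X and Y, second dim unsharded, replicated over Z.
A = jax.device_put(jnp.zeros((128, 2048), dtype=jnp.int8),
                   NamedSharding(mesh, P(('X', 'Y'), None)))

shard = A.addressable_shards[0]
print(shard.data.shape, shard.data.nbytes)     # (8, 2048), 16384 bytes per device
print(len(jax.devices()) * shard.data.nbytes)  # 524288 bytes across the slice
```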
Pop Quiz 2: AllGather Time
Using the numbers from Part 2, how long does it take to perform $\text{AllGather}_Y([E_Y, F]) \to [E, F]$ on a TPUv5e with a 2D mesh Mesh({'X': 8, 'Y': 4}), with $E = 2048$ and $F = 8192$, in bfloat16? What about with $E = 256$ and $F = 512$?
TPU v5e ICI bandwidth = 9e10 bytes/s bidirectional per link (4.5e10 per direction)
For E = 2048, F = 8192: each shard = [512, 8192] in bf16 = 512 * 8192 * 2 ≈ 8.4e6 bytes
We need three unidirectional hops, since the Y axis (size 4, not 16) has no wraparound links and we cannot do a fully bidirectional ring AllGather:
Time ≈ 3 * 8.4e6 / 4.5e10 ≈ 560 µs
For E = 256, F = 512: each shard = [64, 512] = 65,536 bytes, so each hop's transfer time (≈ 1.5 µs) is comparable to the ≈ 1 µs ICI hop latency
This is (roughly) latency bound, so 3 hops ≈ 3 * 1 µs ≈ 3 µs
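A rough way to combine the two regimes above (my own consolidation of the reasoning, not a formula quoted from the text): per hop we pay the larger of the transfer time and the ≈ 1 µs ICI latency, times the number of hops.

$$T_{\text{AllGather}} \;\approx\; n_{\text{hops}} \cdot \max\!\left(\frac{\text{bytes per shard}}{W_{\text{link}}},\; t_{\text{hop}}\right), \qquad n_{\text{hops}} = 3 \text{ here}, \quad t_{\text{hop}} \approx 1\,\mu\text{s}$$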
[!Failure] My Original Incorrect Answer
For part 1, I did not consider that we cannot do a full bidirectional (wraparound) ring AllGather, since the Y axis does not have size 16.
For part 2, I did not consider that it would be latency bound.
Practice Questions
Question 1: Replicated Sharding Memory Usage
An array is sharded $A[I_X, J]$ (i.e., only sharded across X), with a mesh
Mesh({'X': 4, 'Y': 8, 'Z': 2}). What is the ratio of the total number of bytes taken up by $A$ across all chips to the size of one copy of the array?
each shard = 1/4 * sizeof(A)
total copies (devices) = 4 * 8 * 2 = 64
total bytes = 64 * 1/4 * sizeof(A) = 16 * sizeof(A), so the ratio is 16
Question 2: AllGather Latency
How long should $\text{AllGather}_X([B_X, D]) \to [B, D]$ take on a TPUv4p 4x4x4 slice with mesh
Mesh({'X': 4, 'Y': 4, 'Z': 4}) if $B = 1024$ and $D = 4096$, in bfloat16?
How about $\text{AllGather}_{XY}([B_{XY}, D]) \to [B, D]$?
TPUv4p ICI bw = 9e10 bytes/s (bidi) per axis
- $\text{AllGather}_X([B_X, D])$ with mesh Mesh({'X': 4, 'Y': 4, 'Z': 4}), B = 1024 and D = 4096, in bfloat16
Shard size = 2 * (1024/4) * 4096 ≈ 2.1e6 bytes
Since the X axis (size 4) has wraparound links, a fully bidirectional ring transfer is possible, and the time depends on the full gathered array size (2 * 1024 * 4096 ≈ 8.4e6 bytes), not on the shard size
Time ≈ 2BD / W = 8.4e6 / 9e10 ≈ 93 µs
- $\text{AllGather}_{XY}([B_{XY}, D])$ with the same mesh and values
Shard size = 2 * (1024/16) * 4096 ≈ 0.52e6 bytes
Gathering over two axes (X and Y) gives twice the ICI bandwidth
Time ≈ 2BD / (2W) = 8.4e6 / 1.8e11 ≈ 47 µs
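A quick back-of-the-envelope check of these two estimates (using the assumed values B = 1024, D = 4096, bf16; latency and the (X−1)/X factor are ignored):

```python
# Rough AllGather timing estimates for Question 2.
B, D = 1024, 4096
V = B * D * 2            # full gathered array in bytes (bf16)
W = 9e10                 # TPU v4p bidirectional ICI bandwidth along one axis, bytes/s
print(V / W)             # AllGather_X  (one axis)  -> ~9.3e-05 s  (~93 us)
print(V / (2 * W))       # AllGather_XY (two axes)  -> ~4.7e-05 s  (~47 us)
```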
Question 3: Latency-Bound AllGather
Let’s say we’re performing an $\text{AllGather}_X([B_X])$ but $B$ is very small (say 128). How long should this take on a TPUv4p 4x4x4 slice with mesh
Mesh({'X': 4, 'Y': 4, 'Z': 4}) in bfloat16? Hint: you’re probably latency bound.
Latency bound ⇒ ≈ 1 µs per hop (ICI hop latency)
Wraparound links make bidirectional transfer possible ⇒ worst case 4/2 = 2 hops ⇒ ≈ 2 µs
Question 4: Matrix Multiplication Strategies
To perform $X[B, D] \cdot_D Y[D_X, F] \to Z[B, F]$, the text describes using $\text{AllGather}_X(Y[D_X, F]) \to Y[D, F]$ and multiplying the fully replicated matrices (Strategy 1).
Instead, you could multiply the local shards, $X[B, D_X] \cdot_D Y[D_X, F] \to Z[B, F]\,\{U_X\}$ (partial sums over X), and then $\text{AllReduce}_X(Z[B, F]\,\{U_X\})$ (Strategy 2).
How many FLOPs and comms does each of these strategies perform? Which is better and why?
Let’s assume bf16 (2 bytes per element)
Also, take C = per-chip bf16 FLOPs/s and W = bidirectional ICI bandwidth along one axis, with C/W ≈ 2550 (these are TPU v5p numbers: C = 4.59e14, W = 1.8e11; the TPU v5e ratio would be ≈ 2200)
Case 1: AllGather → MatMul
AllGather comms time = 2DF/W. MatMul time = 2BDF/C (where C is the per-chip FLOPs/s; after the gather, every TPU does the same full matmul in parallel)
Compute bound if 2BDF/C > 2DF/W, i.e. B > C/W
With C/W ≈ 2550: if B > 2550, compute bound
Case 2: Sharded MatMul → AllReduce
Sharded MatMul FLOPs = 2BDF/|X| per TPU, so time = 2BDF/(|X| * C). AllReduce comms time = 4BF/W (an AllReduce moves roughly twice the bytes of an AllGather of the 2BF-byte result)
Compute bound if 2BDF/(|X| * C) > 4BF/W, i.e. D/|X| > 2C/W, i.e. D > 2|X| * C/W
If D ≈ 8k (reasonable), being compute bound would require |X| < D * W / (2C) ≈ 8192/5100 ≈ 1.6, i.e. |X| < 2, which never happens
Thus Case 2 is always comms bound
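Putting the two strategies side by side (a summary of the expressions above, assuming compute and comms overlap so the total is the max of the two terms):

$$T_{\text{Strategy 1}} \approx \max\!\left(\frac{2BDF}{C},\; \frac{2DF}{W}\right), \qquad T_{\text{Strategy 2}} \approx \max\!\left(\frac{2BDF}{|X|\,C},\; \frac{4BF}{W}\right)$$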
Which is faster?
- For B < 2550 (often, but not always), both strategies are comms bound.
Then Strategy 2 is faster iff 4BF/W < 2DF/W, i.e. D > 2B; this holds for larger models with reasonable batch sizes.
- For B > 2550, Strategy 2 is comms bound while Strategy 1 is math bound.
Then Strategy 2 is faster iff 4BF/W < 2BDF/C, i.e. D > 2C/W ≈ 5100.
Why don’t we employ Strategy 2?
We do sometimes, but it is unlikely (for the matmuls we care about) that the contracting dimension of one input is sharded along a mesh axis while the other input is not sharded along that axis at all.
For example, with FSDP we shard both the parameters and the data along the same mesh axes.
My original incorrect answer
(I was not thinking in terms of total bytes when sharded, only along one dimension. I also calculated raw FLOPs explicitly, while what we are interested in is compute time, not the FLOPs per se.)
Case 1: AllGather → MatMul
AllGather = D * sizeof(D_elem) / W; MatMul FLOPs = 2BDF * |X|
Case 2: Sharded MatMul → AllReduce
Sharded MatMul FLOPs = 2B(D/|X|)F * |X| = 2BDF; AllReduce = 2 * (D * sizeof(D_elem) / W)
Question 5: Minimum Latency Matrix Multiplication
Let’s say I want to do a matmul on a TPUv5p 4x4x4 with the lowest possible latency. How should my inputs be sharded? What is the total FLOPs and comms time?
- Case 1: B < 2550
- If D > 2B and 2B < 5100, Strategy 2 (sharded matmul + AllReduce) is faster, as in Question 4
Todo
Question 6: Complex Sharding Communication
Let’s say we want to perform on TPUv5e 4x4. What communication do we perform? How much time is spent on communication vs. computation?
What about ? This is the most standard setting for training where we combine data, tensor, and zero sharding.
What about ? This is standard for inference, where we do pure tensor parallelism (+data).
Todo
Question 7: Memory-Constrained Transformer Block
A typical Transformer block has two matrices $B[D, F]$ and $C[F, D]$ where $F = 4D$. With a batch size $B$, the whole block computes $x[B, D] \cdot_D B[D, F] \cdot_F C[F, D]$.
Let’s pick $D = 8192$, $F = 32768$, and $B = 128$, and assume everything is in bfloat16. Assume we’re running on a TPUv5e 2x2 slice, but assume each TPU only has 300MB of free memory.
How should B, C, and the output be sharded to stay below the memory limit while minimizing overall time? How much time is spent on comms and FLOPs?
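A quick memory sanity check on the (reconstructed) values above: replicating the weights clearly blows the 300MB budget, so they have to be sharded across all four chips.

```python
# Rough memory accounting for Question 7 (D, F, B as assumed above; bf16 = 2 bytes).
D, F, B = 8192, 32768, 128
weights = 2 * D * F * 2            # B[D, F] + C[F, D]             ~= 1074 MB total
acts = (B * D + B * F) * 2         # x[B, D] + intermediate [B, F] ~= 10.5 MB
print(weights / 1e6, acts / 1e6)
print(weights / 4 / 1e6)           # ~268 MB per chip if the weights are sharded 4 ways
```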
Todo
Question 8: Communication Primitives Benchmarking
Using the JAX code snippet from the text as a template, allocate a sharded array and benchmark each of the 4 main communication primitives (AllGather, AllReduce, ReduceScatter, and AllToAll) using pmap or shard_map.
You will want to use `jax.lax.all_gather`, `jax.lax.psum`, `jax.lax.psum_scatter`, and `jax.lax.all_to_all`. Do you understand the semantics of these functions? How long do they take?
Todo
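A minimal sketch of one way to set this up with `shard_map` (my own array size, 1D mesh, and timing loop; not a reference solution, and it needs a multi-device runtime such as a TPU slice):

```python
# Sketch: benchmark AllGather / AllReduce / ReduceScatter / AllToAll via shard_map.
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('x',))   # simple 1D mesh over all devices

# A bfloat16 array sharded along its first dimension over the 'x' axis.
A = jax.device_put(jnp.ones((8192, 8192), dtype=jnp.bfloat16),
                   NamedSharding(mesh, P('x', None)))

def wrap(fn, out_spec):
    # Each fn sees its local shard and applies one collective over the 'x' axis.
    return jax.jit(shard_map(fn, mesh=mesh, in_specs=P('x', None),
                             out_specs=out_spec, check_rep=False))

ops = {
    # AllGather: every device ends up with the full array.
    'all_gather':   wrap(lambda x: jax.lax.all_gather(x, 'x', tiled=True), P()),
    # AllReduce: sum of all shards, replicated on every device.
    'psum':         wrap(lambda x: jax.lax.psum(x, 'x'), P()),
    # ReduceScatter: sum, then each device keeps 1/num_devices of the rows.
    'psum_scatter': wrap(lambda x: jax.lax.psum_scatter(x, 'x', tiled=True), P('x', None)),
    # AllToAll: re-distribute which device holds which rows of each local block.
    'all_to_all':   wrap(lambda x: jax.lax.all_to_all(x, 'x', 0, 0, tiled=True), P('x', None)),
}

for name, fn in ops.items():
    fn(A).block_until_ready()                  # compile / warm up
    t0 = time.perf_counter()
    for _ in range(10):
        out = fn(A)
    out.block_until_ready()
    print(f'{name:>12}: {(time.perf_counter() - t0) / 10 * 1e6:.1f} us/call')
```

Comparing the measured times against the V/W roofline estimates from the text is a good way to check you understand each primitive's data volume.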
Question 9: Alternative Strategy for Sharded MatMul
The text claims that when only one input to a matmul is sharded along its contracting dimension, we should AllGather the sharded matrix and perform the contraction locally.
Another strategy you might consider is to perform the sharded matmul and then AllReduce the result.
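As a sketch of the partial-sum structure being described (my own notation, not the text's: $D$ devices, contracting dimension of size $M$ split into $D$ equal chunks):

$$C \;=\; \sum_{d=0}^{D-1} \underbrace{A\big[:,\; d\tfrac{M}{D} : (d{+}1)\tfrac{M}{D}\big]\; B\big[d\tfrac{M}{D} : (d{+}1)\tfrac{M}{D},\; :\big]}_{\text{computed locally on device } d}$$

followed by an AllReduce of the $D$ partial products (or a ReduceScatter, if a sharded output is acceptable as in part 2).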
Answer the following:
Explicitly write out this algorithm for matrices $A$ and $B$, using indices to show exactly what computation is done on what device. Assume $A$ is sharded along its contracting dimension across $N_D$ devices, and you want your output to be replicated across all devices.
Now suppose you are ok with the final result not being replicated on each device, but instead sharded (across either the N or K dimension). How would the algorithm above change?
Looking purely at the communication cost of the strategy above (in part 2, not 1), how does this communication cost compare to the communication cost of the algorithm in which we first AllGather A and then do the matmul?
Todo
Question 10: Fun with AllToAll
In the text, it was noted that the time to perform an AllToAll is a factor of 4 lower than the time to perform an AllGather or ReduceScatter (in the regime where we are throughput-bound).
In this problem we will see where that factor of 4 comes from, and also see how this factor would change if we only had single-direction ICI links, rather than bidirectional ICI links.
Let’s start with the single-direction case first. Imagine we have $D$ devices in a ring topology. Suppose we are doing either an AllGather or a ReduceScatter on an $N \times N$ matrix $A$ which is sharded across the $D$ devices (say $D$ divides $N$ for simplicity). Describe the comms involved in these two collectives, and calculate the total number of scalars (floats or ints) which are transferred across a single ICI link during the entirety of this algorithm.
Now let’s think about an AllToAll, still in the single-directional ICI case. How does the algorithm differ from the AllGather case? Calculate the number of scalars that are transferred across a single ICI link in this algorithm.
You should have found that the ratio between your answers to part (a) and part (b) is a nice number. Explain where this factor comes from in simple terms.
Now let’s add bidirectional communication. How does this affect the total time needed in the all-gather case?
How does adding bidirectional communication affect the total time needed in the AllToAll case?
Now simply explain the ratio between AllGather time and AllToAll time in a bidirectional ring.
Todo