TPU Architecture Practice Questions
Question 1: LLM Latency Bounds
Say you want to sample from a 200B parameter model in bf16 that’s split across 32 TPU v4p. How long would it take to load all the parameters from HBM into the systolic array?
Model size = 200e9 params × 2 bytes/param (bf16) = 400e9 bytes
Model size per chip = 400e9 / 32 = 12.5e9 bytes
TPU v4p HBM BW = 1.23e12 bytes/s per chip
Time = 12.5e9 / 1.23e12 ≈ 1.0e-2 s ≈ 10 ms
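As a quick sanity check, here is a minimal sketch of the same arithmetic; the 1.23e12 bytes/s figure is the assumed per-chip v4p HBM bandwidth.

```python
# Roofline estimate: time to stream all weights from HBM once.
PARAMS = 200e9            # parameters
BYTES_PER_PARAM = 2       # bf16
N_CHIPS = 32              # TPU v4p chips
HBM_BW = 1.23e12          # bytes/s per v4p chip (assumed spec)

model_bytes = PARAMS * BYTES_PER_PARAM           # 4.0e11 bytes
bytes_per_chip = model_bytes / N_CHIPS           # 1.25e10 bytes
t_load = bytes_per_chip / HBM_BW                 # ~1.0e-2 s
print(f"{t_load * 1e3:.1f} ms per weight pass")  # ≈ 10.2 ms
```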
Question 2: TPU Pod Specifications
Consider a full TPU v5e pod. How many total CPU hosts are there? How many TPU TensorCores? What is the total FLOPs/s for the whole pod? What is the total HBM?
Do the same exercise for a TPU v5p pod.
| Chip Type | Pod Size | Host Size | Total Hosts | Total FLOPs/s | Total HBM |
|---|---|---|---|---|---|
| TPU v5e | 16×16 (256 chips) | 4×2 (8 chips) | 256 / 8 = 32 | 256 × 1.97e14 ≈ 5.0e16 | 256 × 16 GB = 4 TB |
| TPU v5p | 16×20×28 (8,960 chips) | 2×2×1 (4 chips) | 8960 / 4 = 2240 | 8960 × 4.59e14 ≈ 4.1e18 | 8960 × 96 GB ≈ 860 TB |

Each v5e chip has a single TensorCore (256 total for the pod), while each v5p chip has two (17,920 total).
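A small sketch of the pod-level arithmetic, assuming the per-chip figures used above (1.97e14 / 4.59e14 bf16 FLOPs/s and 16 GB / 96 GB of HBM for v5e / v5p respectively):

```python
# Pod-level totals from assumed per-chip specs.
specs = {
    #            chips/pod,  chips/host, cores/chip, bf16 FLOPs/s, HBM bytes
    "TPU v5e": (16 * 16,      4 * 2,      1,          1.97e14,      16e9),
    "TPU v5p": (16 * 20 * 28, 2 * 2 * 1,  2,          4.59e14,      96e9),
}

for name, (chips, chips_per_host, cores, flops, hbm) in specs.items():
    print(name,
          "hosts:", chips // chips_per_host,
          "TensorCores:", chips * cores,
          "FLOPs/s: %.2e" % (chips * flops),
          "HBM: %.0f TB" % (chips * hbm / 1e12))
```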
Question 3: PCIe Operational Intensity
Imagine we’re forced to store a big weight matrix of type bf16[D, F], and a batch of activations of type bf16[B, D], in host DRAM and want to do a matrix multiplication on them. This is running on a single host, and we’re using a single TPU v6e chip attached to it.
You can assume B ≪ D and B ≪ F (we’ll see in future chapters why these are reasonable assumptions). What is the smallest batch size we need to remain FLOPs-bound over PCIe? Assume a PCIe bandwidth of 1.5e10 bytes/second.
Intensity(matmul) = FLOPs / bytes over PCIe = 2BDF / (2DF + 2BD + 2BF)
Assuming B ≪ D and B ≪ F:
Intensity(matmul) ≈ 2BDF / 2DF = B
For FLOP-bound, we need Intensity(matmul) > Intensity(v6e over PCIe) = 9.2e14 / 1.5e10 ≈ 61,000, so B ≳ 61,000
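Concretely, taking 9.2e14 FLOPs/s as the assumed v6e bf16 peak, the bound works out as sketched below; both 2-byte (bf16) and 1-byte (int8) operands are shown, since the element size changes the byte count and therefore the threshold:

```python
# Critical batch size to stay FLOPs-bound when streaming a matmul over PCIe.
V6E_FLOPS = 9.2e14     # bf16 FLOPs/s (assumed v6e peak)
PCIE_BW = 1.5e10       # bytes/s (from the question)

def min_flops_bound_batch(bytes_per_elem: int = 2) -> float:
    # Intensity(matmul) ≈ 2BDF / (bytes_per_elem * DF) = 2B / bytes_per_elem;
    # FLOPs-bound once this exceeds the accelerator intensity over PCIe.
    accel_intensity = V6E_FLOPS / PCIE_BW          # ≈ 61,000 FLOPs/byte
    return accel_intensity * bytes_per_elem / 2

print(min_flops_bound_batch(2))   # bf16 operands: B ≳ 61,000
print(min_flops_bound_batch(1))   # int8 operands: B ≳ 31,000
```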
Correct answer and approach, incorrect thought process
Since the data first has to be copied from host DRAM into HBM over PCIe, the matmul never reads its operands ‘directly’ over PCIe the way the intensity calculation above implies; the cleaner framing is to compare the compute time against the PCIe transfer time.
Operations time = 2BDF / 9.2e14 seconds; weight (and activation) read + result write time over PCIe = (2DF + 2BD + 2BF) / 1.5e10 seconds
This assumes we can overlap compute with the weight loading (either overlapping HBM reads/writes with the transfer from the host, or overlapping compute with the HBM reads/writes), so the op takes max of the two times.
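A minimal sketch of that comparison, assuming the same v6e peak and PCIe bandwidth; the layer sizes in the example call are hypothetical, chosen only to illustrate how lopsided the two times are at small batch:

```python
# Is x[B, D] @ W[D, F] (bf16, streamed over PCIe) compute- or transfer-bound?
V6E_FLOPS = 9.2e14   # assumed v6e bf16 peak, FLOPs/s
PCIE_BW = 1.5e10     # bytes/s

def pcie_matmul_time(B, D, F, bytes_per_elem=2):
    t_math = 2 * B * D * F / V6E_FLOPS
    t_comms = bytes_per_elem * (D * F + B * D + B * F) / PCIE_BW
    # With perfect overlap, the op takes the max of the two times.
    return max(t_math, t_comms), ("FLOPs-bound" if t_math >= t_comms else "PCIe-bound")

# Hypothetical layer sizes, small batch: massively PCIe-bound (~2000x).
print(pcie_matmul_time(B=32, D=8192, F=28672))   # (~0.03 s, 'PCIe-bound')
```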
Question 4: General MatMul Latency
Let’s say we want to multiply a weight matrix int8[16384, 4096] by an activation matrix of size int8[B, 4096], where B is some unknown batch size. Let’s say we’re on a single TPU v5e to start.
a) How long will this multiplication take as a function of B?
b) What if we wanted to run this operation out of VMEM? How long would it take as a function of B?
Data size:
- W = 16384 × 4096 × 1 byte (int8) = 6.7e7 bytes
- X = B × 4096 × 1 byte = 4096 × B bytes
- X*W = B × 16384 × 1 byte = 16384 × B bytes (assuming an int8 output)
Ops = 2 × B × 4096 × 16384 ≈ 1.3e8 × B FLOPs
For a v5e chip (int8: 3.94e14 OPs/s):
Time (ops) = 2 × B × 4096 × 16384 / 3.94e14 ≈ 3.4e-7 × B seconds
Out of HBM (8.1e11 bytes/s):
Time (comm) = (6.7e7 + 4096 × B + 16384 × B) / 8.1e11 ≈ 8.3e-5 seconds for small B
For FLOP-bound: Time (ops) > Time (comm), i.e. B ≳ 8.3e-5 / 3.4e-7 ≈ 240
Out of VMEM (bandwidth ≈ 22× HBM):
Time (comm) ≈ 1/22 × 8.3e-5 ≈ 3.8e-6 seconds
For FLOP-bound: B ≳ 240 / 22 ≈ 11
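A sketch covering both cases, assuming the v5e figures used above (3.94e14 int8 OPs/s, 8.1e11 bytes/s of HBM bandwidth, and VMEM at roughly 22× HBM bandwidth):

```python
# Roofline for int8[16384, 4096] x int8[B, 4096] on one TPU v5e,
# reading operands from HBM vs. from VMEM.
D, F = 4096, 16384
V5E_INT8_OPS = 3.94e14   # OPs/s (assumed)
HBM_BW = 8.1e11          # bytes/s (assumed)
VMEM_BW = 22 * HBM_BW    # assumed VMEM-to-HBM bandwidth ratio

def matmul_time(B, mem_bw):
    t_math = 2 * B * D * F / V5E_INT8_OPS
    t_comms = (D * F + B * D + B * F) / mem_bw   # int8: 1 byte/element
    return max(t_math, t_comms)

# The FLOPs/comms crossover sits near B ~ 240 for HBM and B ~ 11 for VMEM.
for B in (11, 64, 240, 1024):
    print(B, f"HBM: {matmul_time(B, HBM_BW):.2e}s", f"VMEM: {matmul_time(B, VMEM_BW):.2e}s")
```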
Question 5: ICI Bandwidth
Let’s say we have a TPU v5e 4x4 slice. Let’s say we want to send an array of type bfloat16[8, 128, 8192] from TPU{0,0} to TPU{3,3}. Let’s say the per-hop latency for TPU v5e is roughly 1µs.
a) How soon will the first byte arrive at its destination?
b) How long will the total transfer take?
For a v5e 4x4 slice there are no wrap-around links (on v5e, only an axis spanning the full 16 chips gets wraparound)
We can send the data along two paths at once: down and right (2 ICI links out of TPU{0,0})
Total bytes = 2 × 8 × 128 × 8192 ≈ 1.7e7 bytes
Bytes per second = 2 × 4.5e10 = 9e10 bytes/s (two links at 4.5e10 bytes/s each, one direction)
First byte = 6 hops × 1µs = 6e-6 s
Total time = 1.7e7 / 9e10 + 6e-6 ≈ 1.9e-4 s
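A minimal sketch of the same estimate, assuming 4.5e10 bytes/s of one-way ICI bandwidth per link and ~1µs per hop:

```python
# Point-to-point transfer of bf16[8, 128, 8192] across a v5e 4x4 slice (no wraparound).
BYTES = 2 * 8 * 128 * 8192        # ~1.7e7 bytes
ICI_BW_PER_LINK = 4.5e10          # bytes/s, one direction (assumed v5e spec)
HOP_LATENCY = 1e-6                # seconds per hop (assumed)
HOPS = 3 + 3                      # TPU{0,0} -> TPU{3,3}: 3 right + 3 down

first_byte = HOPS * HOP_LATENCY                      # ≈ 6 us
total = BYTES / (2 * ICI_BW_PER_LINK) + first_byte   # 2 outgoing ports: down + right
print(f"first byte: {first_byte*1e6:.0f} us, total: {total*1e6:.0f} us")  # ~6 us, ~192 us
```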
Question 6: Multi-Component Performance (Hard)
Imagine you have a big matrix A: int8[128 * 1024, 128 * 1024] sharded evenly across a TPU v5e 4x4 slice but offloaded to host DRAM on each chip. Let’s say you want to copy the entire array to TPU{0,0} and multiply it by a vector bf16[8, 128 * 1024]. How long will this take?
Steps (a rough numeric sketch follows this list):
- Calculate the size of the array shard on each chip
- Calculate the data transfer time from each chip (host DRAM, over PCIe and then ICI) to TPU{0,0}
- Calculate the time to read both the matrix and the vector from TPU{0,0}'s HBM
- Calculate the time to perform the multiplication
- Calculate the time to write the output back to HBM
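A rough sketch of those steps under assumed v5e numbers (PCIe 1.5e10, one-way ICI 4.5e10 per link, HBM 8.1e11 bytes/s, 1.97e14 bf16 FLOPs/s); it sums the stages sequentially, so treat it as a loose upper bound rather than a tight estimate:

```python
# Rough end-to-end estimate for copying A to TPU{0,0} and computing x @ A there.
# All hardware numbers are assumed v5e specs.
N = 128 * 1024
A_BYTES = N * N                      # int8 matrix, ~1.7e10 bytes total
X_BYTES = 2 * 8 * N                  # bf16[8, N] activations
OUT_BYTES = 2 * 8 * N                # bf16[8, N] result
CHIPS = 16                           # 4x4 slice
PCIE_BW, ICI_BW, HBM_BW = 1.5e10, 4.5e10, 8.1e11
V5E_BF16_FLOPS = 1.97e14

shard_bytes = A_BYTES / CHIPS                 # ~1.1e9 bytes per chip
t_pcie = shard_bytes / PCIE_BW                # host DRAM -> each chip's HBM, in parallel
t_ici = A_BYTES / (2 * ICI_BW)                # everything funnels into TPU{0,0}'s 2 links
t_hbm_read = (A_BYTES + X_BYTES) / HBM_BW     # read A and x from TPU{0,0} HBM
                                              # (note: ~17 GB of A actually exceeds the
                                              # chip's 16 GB HBM; ignored in this sketch)
t_math = 2 * 8 * N * N / V5E_BF16_FLOPS       # 2*B*N*N FLOPs with B = 8
t_write = OUT_BYTES / HBM_BW

total = t_pcie + t_ici + t_hbm_read + t_math + t_write
for name, t in [("pcie", t_pcie), ("ici", t_ici), ("hbm read", t_hbm_read),
                ("matmul", t_math), ("write", t_write), ("total", total)]:
    print(f"{name:>8}: {t*1e3:7.2f} ms")      # dominated by the ICI funnel into one chip
```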