The Illustrated Guide to Unified Sequence Parallelism

How to fit a 70B model with 1M context on GPUs

By Darshan Fofadiya

You want to run Llama-70B with a 1 million token context window. Sounds straightforward — just load the model and go, right?

Not quite. Before we can talk about solutions, we need to understand the problem. In this part, we'll do the actual math to see exactly why this is hard:

  1. First, we'll understand the GPU hardware — what an A100 can and cannot do
  2. Then, we'll calculate the model memory — how much space 70B parameters actually need
  3. Next, we'll compute the activation memory — the Q, K, V tensors for 1M tokens
  4. Finally, we'll see the attention explosion — why O(n²) is a killer

By the end, you'll understand exactly why we need parallelism — and which specific bottlenecks each parallelism strategy addresses.


1.1 The Hardware: NVIDIA A100-80GB

The A100 is NVIDIA's data center GPU for AI workloads. Let's understand each spec and why it matters:

1.1.1 HBM2e Memory: 80 GB

This is the GPU's main memory — High Bandwidth Memory (HBM). Everything lives here: the model weights, the activations, and the KV cache.

80 GB is our hard constraint. If our data doesn't fit, we can't run.

1.1.2 Memory Bandwidth: 2.0 TB/s

This is how fast we can move data between HBM and the compute units (tensor cores). Let's put this in perspective:

Memory bandwidth: 2.0 TB/s = 2,000 GB/s

To read 80 GB (entire memory):
  Time = 80 GB ÷ 2,000 GB/s = 0.04 seconds = 40 ms

To read 1 GB:
  Time = 1 GB ÷ 2,000 GB/s = 0.5 ms

This sounds fast, but for large models, we're constantly streaming weights from memory to compute. Memory bandwidth often becomes the bottleneck — not compute.
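These transfer times are easy to reproduce. Here's a minimal Python sketch (the helper `read_time_ms` is our own, and it assumes the peak 2.0 TB/s figure is actually achieved — real kernels rarely hit peak):

```python
# Back-of-envelope HBM read times on an A100 (peak 2.0 TB/s).
# Illustrative sketch only; sustained bandwidth is lower in practice.

HBM_BANDWIDTH_GBS = 2_000  # GB/s

def read_time_ms(gigabytes: float, bandwidth_gbs: float = HBM_BANDWIDTH_GBS) -> float:
    """Time to stream `gigabytes` from memory, in milliseconds."""
    return gigabytes / bandwidth_gbs * 1_000

print(read_time_ms(80))  # full 80 GB of HBM: 40.0 ms
print(read_time_ms(1))   # 1 GB: 0.5 ms
```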

1.1.3 Tensor Cores: 312 TFLOPS (BF16)

TFLOPS = Trillion Floating Point Operations Per Second. The A100 can do 312 trillion BF16 operations per second.

312 TFLOPS = 312 × 10¹² operations/second

For a matrix multiply of [M, K] × [K, N]:
  Operations = 2 × M × K × N (multiply-add)

Example: [8192, 8192] × [8192, 8192]
  Operations = 2 × 8192 × 8192 × 8192 ≈ 1.1 × 10¹² = 1.1 TFLOP
  Time (compute-bound) = 1.1 TFLOP ÷ 312 TFLOPS ≈ 3.5 ms

We also need to read the matrices from memory:

Data to read: 2 matrices × 8192 × 8192 × 2 bytes = 268 MB
Time (memory-bound) = 268 MB ÷ 2,000 GB/s = 134 μs

Compute time: 3.5 ms
Memory time: 134 μs  ← compute dominates by ~26×

A big square matmul like this is compute-bound: each weight we load from HBM is reused 8,192 times. Batch-1 autoregressive decoding is different. With a single token in flight, each matmul is really a matrix-vector product:

Example: [1, 8192] × [8192, 8192]
  Operations = 2 × 8192 × 8192 ≈ 134 MFLOP
  Time (compute-bound) = 134 MFLOP ÷ 312 TFLOPS ≈ 0.43 μs

  Weights to read: 8192 × 8192 × 2 bytes = 134 MB
  Time (memory-bound) = 134 MB ÷ 2,000 GB/s = 67 μs  ← ~150× slower!

Now every weight is loaded from HBM, used once, and discarded. We're memory-bound, not compute-bound.
Key insight: For transformer inference at small batch sizes, we're almost always memory-bandwidth bound. The tensor cores sit idle waiting for data. This is why techniques like FlashAttention (which reduces memory traffic) are so important.
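The decode-step arithmetic can be checked with a few lines of Python. This sketch (our own helper, using the peak spec numbers of 312 TFLOPS and 2.0 TB/s) counts only the weight reads for a single [1, 8192] × [8192, 8192] matrix-vector product:

```python
# Compare compute time vs. memory time for one [1, 8192] x [8192, 8192]
# matrix-vector product -- the shape that dominates batch-1 decoding.
# Peak A100 numbers; illustrative sketch only.

PEAK_TFLOPS = 312      # BF16 tensor-core throughput
PEAK_BW_GBS = 2_000    # HBM bandwidth, GB/s

def gemv_times_us(k: int, n: int) -> tuple[float, float]:
    flops = 2 * k * n                  # one multiply-add per weight
    weight_bytes = k * n * 2           # BF16 weights streamed from HBM
    compute_us = flops / (PEAK_TFLOPS * 1e12) * 1e6
    memory_us = weight_bytes / (PEAK_BW_GBS * 1e9) * 1e6
    return compute_us, memory_us

c, m = gemv_times_us(8192, 8192)
print(f"compute {c:.2f} us, memory {m:.2f} us, ratio {m / c:.0f}x")
```

The memory side comes out two orders of magnitude slower, which is exactly why decoding is bandwidth-limited.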

1.1.4 GPU Interconnects: NVLink vs PCIe

When we distribute work across multiple GPUs, they need to communicate. There are two main interconnects:

PCIe Gen4 x16:
  - Bandwidth: ~32 GB/s (bidirectional)
  - Latency: ~1-2 μs
  - Connects: GPU ↔ CPU, GPU ↔ GPU (different nodes)
  - Available on: All GPUs

NVLink (A100):
  - Bandwidth: 600 GB/s (bidirectional, 12 links × 50 GB/s)
  - Latency: ~0.5 μs  
  - Connects: GPU ↔ GPU (same node only)
  - Available on: Data center GPUs (A100, H100)

Let's see what this means for transferring 1 GB of data:

Transfer 1 GB over PCIe:
  Time = 1 GB ÷ 32 GB/s = 31.25 ms

Transfer 1 GB over NVLink:
  Time = 1 GB ÷ 600 GB/s = 1.67 ms

NVLink is ~19× faster than PCIe!

This is why multi-GPU training and inference strongly prefer NVLink-connected GPUs within a single node. Cross-node communication (which must use PCIe/InfiniBand) is much slower.

1.1.5 The Memory Hierarchy

Understanding where data lives and how fast we can access it:

MEMORY HIERARCHY (A100):
────────────────────────
Level           Size        Bandwidth       Latency
─────────────────────────────────────────────────────
Registers       ~20 MB      ~20 TB/s        ~1 cycle
L2 Cache        40 MB       ~5 TB/s         ~30 cycles
HBM (main)      80 GB       2 TB/s          ~300 cycles
NVLink          N/A         600 GB/s        ~500 cycles
PCIe            N/A         32 GB/s         ~1000 cycles
CPU RAM         ~1 TB       ~100 GB/s       ~10000 cycles

The key takeaway: HBM is our working memory, and 80 GB is all we have. If data doesn't fit in HBM, we have to go to much slower storage.


1.2 The Model: Llama-3-70B Architecture

Now let's understand what we're trying to fit into that 80 GB. Llama-3-70B is a decoder-only transformer with these parameters:

Parameter    Value     What it means
─────────────────────────────────────
d_model      8192      Each token is represented as a vector of 8192 numbers
n_heads      64        Query attention heads
n_kv_heads   8         Key/Value heads (GQA — 8× fewer than query heads)
d_head       128       Each head works with 128-dimensional vectors
n_layers     80        80 transformer blocks stacked sequentially
d_ff         28672     FFN hidden dimension (~3.5× d_model)
vocab_size   128,256   Number of unique tokens the model knows

1.2.1 Grouped Query Attention (GQA)

Llama-3 uses GQA, not standard multi-head attention. Instead of 64 separate K and V projections (one per head), it uses only 8 — each shared by 8 query heads. This reduces parameters and KV cache size:

Standard MHA: 64 query heads, 64 key heads, 64 value heads
Llama-3 GQA:  64 query heads, 8 key heads, 8 value heads

Each group of 8 query heads shares 1 key head and 1 value head.
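Here's a toy numpy sketch of that sharing (the shapes `seq=3, d_head=4` are scaled-down toy values): before the batched matmul, the 8 KV heads are expanded so each one serves a group of 8 query heads.

```python
import numpy as np

# Toy illustration of GQA head sharing: 8 KV heads expanded so each
# serves a group of 8 query heads. seq and d_head are toy-sized here.

n_heads, n_kv_heads, d_head, seq = 64, 8, 4, 3
group = n_heads // n_kv_heads                  # 8 query heads per KV head

k = np.random.randn(seq, n_kv_heads, d_head)   # [seq, 8, d_head]
k_expanded = np.repeat(k, group, axis=1)       # [seq, 64, d_head]

assert k_expanded.shape == (seq, n_heads, d_head)
# Query heads 0..7 all read the same underlying KV head:
assert np.array_equal(k_expanded[:, 0], k_expanded[:, 7])
```

Only the expanded view is 64-wide; the stored K and V tensors (and the KV cache) keep just 8 heads.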

1.2.2 Anatomy of One Transformer Layer

Each of the 80 layers contains these weight matrices:

ATTENTION BLOCK (with GQA):
───────────────────────────
Wq (Query projection):  [8192, 8192]     = 67M params  (64 heads × 128)
Wk (Key projection):    [8192, 1024]     = 8.4M params (8 heads × 128)
Wv (Value projection):  [8192, 1024]     = 8.4M params (8 heads × 128)
Wo (Output projection): [8192, 8192]     = 67M params

Attention total: 67 + 8.4 + 8.4 + 67 = 150.8M params per layer

FFN BLOCK (SwiGLU — 3 matrices, not 2):
───────────────────────────────────────
W_gate: [8192, 28672] = 235M params  (gate projection)
W_up:   [8192, 28672] = 235M params  (up projection)
W_down: [28672, 8192] = 235M params  (down projection)

FFN total: 235 × 3 = 705M params per layer

LAYER NORMS: ~16K params (negligible)

TOTAL PER LAYER: 150.8M + 705M ≈ 856M params

Now let's add it all up:

TRANSFORMER LAYERS:
───────────────────
856M params × 80 layers = 68.5B params

EMBEDDINGS:
───────────
Token embeddings: [128256, 8192] = 1.05B params
Output projection: [8192, 128256] = 1.05B params
(These are sometimes tied, but Llama-3 keeps them separate)

TOTAL: 68.5B + 1.05B + 1.05B ≈ 70.6B parameters ✓
Why GQA matters: By using 8 KV heads instead of 64, Llama-3 reduces KV cache memory by 8×. For long-context inference, this is huge — the KV cache is often the biggest memory bottleneck.
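You can verify the total by recomputing it from the architecture table. A quick sketch (layer norms are ignored as negligible, per the breakdown above):

```python
# Recompute the Llama-3-70B parameter count from the architecture table.
# Sketch: layer-norm parameters (~16K/layer) are omitted as negligible.

d_model, d_ff, n_layers, vocab = 8192, 28672, 80, 128256
n_heads, n_kv_heads, d_head = 64, 8, 128

attn = (d_model * n_heads * d_head        # Wq
        + d_model * n_kv_heads * d_head   # Wk (GQA: 8 heads)
        + d_model * n_kv_heads * d_head   # Wv (GQA: 8 heads)
        + n_heads * d_head * d_model)     # Wo
ffn = 3 * d_model * d_ff                  # gate, up, down (SwiGLU)
per_layer = attn + ffn
embeddings = 2 * vocab * d_model          # input + output (untied)

total = per_layer * n_layers + embeddings
print(f"{total / 1e9:.1f}B parameters")   # prints "70.6B parameters"
```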

1.2.3 Weight Memory Calculation

Now let's calculate the actual memory needed. We'll use BF16 (bfloat16), which is standard for inference:

BF16 = 16 bits = 2 bytes per parameter

Total parameters: 70 × 10⁹
Memory = 70 × 10⁹ params × 2 bytes/param
       = 140 × 10⁹ bytes
       = 140 GB
Problem #1 — Weights Don't Fit: Model weights alone require 140 GB, but the A100 only has 80 GB. We can't even load the model on a single GPU, let alone run inference.

1.2.4 What About Quantization?

You might think: "Just use INT8 or INT4 quantization!" Let's see:

INT8 (8-bit): 70B × 1 byte = 70 GB  ← Fits! But quality degrades
INT4 (4-bit): 70B × 0.5 bytes = 35 GB  ← Fits easily! More quality loss

But wait — we still need memory for:
  - Activations (Q, K, V tensors)
  - KV cache
  - Intermediate buffers

Even with INT4 weights, we'll run out of memory for long sequences.

Quantization helps, but it doesn't solve the fundamental problem for long-context inference. We'll still need parallelism.


1.3 Inference Scenario: 1 Million Token Context

Now let's see what happens when we actually run inference. We have a 1 million token input — think of a massive document, an entire codebase, or a long conversation history. Let's trace through the memory requirements step by step.

1.3.1 The Input Tensor

Our input starts as a sequence of token IDs, which get embedded into vectors:

Input tokens: [1, 1,000,000]  (batch_size=1, seq_len=1M)

After embedding lookup:
X: [batch_size, seq_len, d_model]
X: [1, 1,000,000, 8192]

Memory for X:
  Elements: 1 × 1,000,000 × 8192 = 8.192 × 10⁹
  Bytes (BF16): 8.192 × 10⁹ × 2 = 16.38 GB

Just the embedded input takes 16 GB. But this is just the beginning.

1.3.2 Q, K, V Projections

In each transformer layer, we project the input into Query, Key, and Value tensors. Remember, with GQA, K and V use only 8 heads (not 64):

Q = X @ Wq    →  [1, 1M, 8192] @ [8192, 8192] = [1, 1M, 8192]   (64 query heads × 128)
K = X @ Wk    →  [1, 1M, 8192] @ [8192, 1024] = [1, 1M, 1024]   (8 KV heads × 128, GQA)
V = X @ Wv    →  [1, 1M, 8192] @ [8192, 1024] = [1, 1M, 1024]   (8 KV heads × 128, GQA)

Q is large (64 heads), but K and V are 8× smaller thanks to GQA. Let's calculate the memory:

Q: 1 × 1,000,000 × 8192 × 2 bytes = 16.38 GB   (64 heads)
K: 1 × 1,000,000 × 1024 × 2 bytes = 2.05 GB    (8 KV heads)
V: 1 × 1,000,000 × 1024 × 2 bytes = 2.05 GB    (8 KV heads)
─────────────────────────────────────────────
Total Q+K+V: 20.48 GB (for ONE layer)
Problem #2 — Activations Are Large: Q + K + V for a single layer consumes 20.5 GB — that's 26% of the A100's memory. And this is on top of the 140 GB weights that already don't fit.
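A quick script reproduces these per-layer activation numbers:

```python
# Per-layer Q/K/V activation memory for a 1M-token sequence in BF16,
# reproducing the figures above. Sketch only.

seq_len, bytes_per = 1_000_000, 2
n_heads, n_kv_heads, d_head = 64, 8, 128

q_gb = seq_len * n_heads * d_head * bytes_per / 1e9       # 64 query heads
kv_gb = seq_len * n_kv_heads * d_head * bytes_per / 1e9   # K or V, 8 heads each

print(f"Q: {q_gb:.2f} GB, K: {kv_gb:.2f} GB, V: {kv_gb:.2f} GB")
print(f"total: {q_gb + 2 * kv_gb:.1f} GB")  # prints "total: 20.5 GB"
```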

1.3.3 Why Don't We Need 20.5 GB × 80 Layers?

You might wonder: if each layer needs 20.5 GB for Q, K, V, don't we need 20.5 × 80 = 1.6 TB total?

No — and here's why. During inference, we process one layer at a time:

Layer 1: Compute Q, K, V → Compute attention → Get output → FREE Q, K, V
Layer 2: Compute Q, K, V → Compute attention → Get output → FREE Q, K, V
...
Layer 80: Compute Q, K, V → Compute attention → Get output → FREE Q, K, V

We reuse the same memory buffer for each layer's activations. So 20.5 GB is the peak activation memory, not cumulative. But 20.5 GB is still a significant chunk of our 80 GB budget — especially when combined with weights and KV cache.

1.3.4 The KV Cache: Memory That Persists

During autoregressive generation (generating tokens one at a time), we don't want to recompute K and V for all previous tokens. Instead, we cache them:

KV Cache structure per layer (with GQA — 8 KV heads, not 64):
  K cache: [batch, seq_len, n_kv_heads, d_head] = [1, 1M, 8, 128]
  V cache: [batch, seq_len, n_kv_heads, d_head] = [1, 1M, 8, 128]

Memory per layer:
  K: 1 × 1,000,000 × 8 × 128 × 2 bytes = 2.05 GB
  V: 1 × 1,000,000 × 8 × 128 × 2 bytes = 2.05 GB
  Total: 4.1 GB per layer

All 80 layers:
  4.1 GB × 80 = 328 GB
GQA saves 8× on KV cache! Without GQA (64 KV heads), this would be 2.62 TB. With GQA (8 KV heads), it's "only" 328 GB — still 4× the A100's capacity, but much more manageable.
Problem #3 — KV Cache Still Doesn't Fit: Even with GQA's 8× reduction, the KV cache for 1M tokens across 80 layers is 328 GB — 4× the A100's capacity.

This is why long-context inference is so challenging. The KV cache grows linearly with sequence length, and we need to keep all of it in memory.
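Wrapping the formula in a function makes the linear growth easy to explore (`kv_cache_gb` is our own helper, defaulting to the Llama-3-70B GQA configuration):

```python
# KV-cache size as a function of sequence length, per the formula above.
# BF16, GQA with 8 KV heads across 80 layers; sketch only.

def kv_cache_gb(seq_len: int, n_layers: int = 80, n_kv_heads: int = 8,
                d_head: int = 128, bytes_per: int = 2) -> float:
    per_layer = 2 * seq_len * n_kv_heads * d_head * bytes_per  # K and V
    return per_layer * n_layers / 1e9

print(kv_cache_gb(1_000_000))  # ~328 GB for 1M tokens
print(kv_cache_gb(128_000))    # ~42 GB even for a 128K context
```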


1.4 The Attention Matrix: Quadratic Explosion

We've saved the worst for last. The attention mechanism computes:

Attention(Q, K, V) = softmax(Q @ Kᵀ / √d_head) @ V

The critical operation is Q @ Kᵀ — this produces the attention scores matrix, where every token attends to every other token.

1.4.1 Shape Analysis

Let's trace through the shapes carefully. First, we reshape Q and K for multi-head attention:

Original shapes:
  Q: [batch, seq_len, d_model] = [1, 1M, 8192]
  K: [batch, seq_len, d_model] = [1, 1M, 8192]

Reshape for multi-head (split d_model into n_heads × d_head):
  Q: [batch, seq_len, n_heads, d_head] = [1, 1M, 64, 128]
  K: [batch, seq_len, n_kv_heads, d_head] = [1, 1M, 8, 128]  (GQA)

For attention, each KV head is broadcast to 8 query heads:
  K_expanded: [1, 1M, 64, 128]  (each KV head repeated 8×)

Transpose for batched matmul:
  Q: [batch, n_heads, seq_len, d_head] = [1, 64, 1M, 128]
  K: [batch, n_heads, seq_len, d_head] = [1, 64, 1M, 128]  (after broadcast)
  Kᵀ: [batch, n_heads, d_head, seq_len] = [1, 64, 128, 1M]

Now the attention scores computation:

Attention Scores = Q @ Kᵀ

  [1, 64, 1M, 128] @ [1, 64, 128, 1M]
  = [1, 64, 1M, 1M]

This is a [seq_len × seq_len] matrix for each head!

1.4.2 Memory Calculation

Let's compute the memory for this attention scores matrix:

Shape: [1, 64, 1,000,000, 1,000,000]

Elements: 1 × 64 × 1,000,000 × 1,000,000
        = 64 × 10¹² 
        = 64 trillion elements

Memory (BF16): 64 × 10¹² × 2 bytes 
             = 128 × 10¹² bytes
             = 128 TB
Problem #4 — Quadratic Memory: The attention score matrix requires 128 TB — that's 1,600× the A100's memory capacity. This is the O(n²) problem in action.

1.4.3 Why O(n²) is a Killer

Let's see how attention memory scales with sequence length:

Sequence Length    Attention Memory (64 heads, BF16)
───────────────    ─────────────────────────────────
1,000              128 MB
10,000             12.8 GB
100,000            1.28 TB
1,000,000          128 TB
10,000,000         12.8 PB (petabytes!)

Every 10× increase in sequence length causes a 100× increase in attention memory. This is why long-context models are so challenging.
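The table above follows directly from n_heads × seq_len² × 2 bytes. A short generator sketch:

```python
# Attention-matrix memory vs. sequence length (64 heads, BF16),
# reproducing the scaling table above.

def attn_matrix_gb(seq_len: int, n_heads: int = 64, bytes_per: int = 2) -> float:
    return n_heads * seq_len ** 2 * bytes_per / 1e9

for n in (1_000, 10_000, 100_000, 1_000_000):
    print(f"{n:>9,} tokens -> {attn_matrix_gb(n):,.3f} GB")
```

The quadratic term `seq_len ** 2` is the whole story: every 10× in length costs 100× in memory.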

1.4.4 FlashAttention: Tiled Computation with Online Softmax

FlashAttention is a clever algorithm that computes attention without ever materializing the full attention matrix. The key insight: we don't need to store all attention scores — we can compute them in tiles and immediately use them.

Let's compare standard attention vs FlashAttention:

Standard Attention (naive):
1. Compute full S = Q @ Kᵀ           [1, 64, 1M, 1M] → 128 TB  ← Store this!
2. Apply softmax to S                [1, 64, 1M, 1M] → 128 TB  ← Store this!
3. Compute Output = S @ V            [1, 64, 1M, 128] → 16 GB

Peak memory: 128 TB (the attention matrix)

1.4.5 The Tiling Trick

FlashAttention processes Q in blocks (tiles), and for each Q block, iterates through all K,V blocks:

FlashAttention (tiled):
For each Q_block (e.g., rows 0-1023 of Q):
  Initialize: output_block = 0, running_max = -∞, running_sum = 0
  
  For each K_block, V_block:
    1. Compute tile: S_tile = Q_block @ K_blockᵀ    [1024, 1024] → 2 MB (BF16)
    2. Compute local softmax statistics
    3. Update running softmax (online algorithm)
    4. Accumulate: output_block += softmax(S_tile) @ V_block
    5. DISCARD S_tile immediately!
    
  Store output_block (final result for these Q rows)

Peak memory: O(tile_size²) ≈ 2 MB per tile (per head, BF16)

1.4.6 The Online Softmax Problem

But wait — there's a problem. Softmax requires knowing the maximum value across the entire row:

softmax(x_i) = exp(x_i - max(x)) / Σ exp(x_j - max(x))

For row i of the attention matrix:
  - We need max across ALL 1M columns
  - We need sum of exp() across ALL 1M columns
  
But we're only looking at 1024 columns at a time!

FlashAttention solves this with the "online softmax" trick — maintaining running statistics that get corrected as we see more tiles:

Online Softmax Algorithm:
─────────────────────────
After processing K_block 1:
  m₁ = max of tile 1
  l₁ = sum of exp(scores₁ - m₁)
  o₁ = exp(scores₁ - m₁) @ V_block_1       ← unnormalized for now

After processing K_block 2:
  m₂ = max(m₁, max of tile 2)              ← Update global max
  
  Correction factor for old results:
  α = exp(m₁ - m₂)                         ← Scale down old values
  
  l₂ = α × l₁ + sum of exp(scores₂ - m₂)   ← Update running sum
  o₂ = α × o₁ + exp(scores₂ - m₂) @ V_block_2  ← Correct and accumulate

After ALL K_blocks:
  Final output = o_final / l_final         ← Normalize once at the end
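The whole algorithm fits in a short numpy sketch. This is a single-head toy version (no causal mask, no 1/√d scaling, arbitrary small shapes), checked against naive full-matrix attention:

```python
import numpy as np

# Minimal online-softmax attention over K/V tiles (single head, no
# masking, no 1/sqrt(d) scaling), verified against the naive version.

rng = np.random.default_rng(0)
seq, d, tile = 64, 16, 8
Q, K, V = rng.standard_normal((3, seq, d))

# Naive reference: materialize the full [seq, seq] score matrix.
S = Q @ K.T
P = np.exp(S - S.max(axis=-1, keepdims=True))
ref = (P / P.sum(axis=-1, keepdims=True)) @ V

# Tiled version: one pass over K/V blocks with running statistics.
o = np.zeros((seq, d))
m = np.full((seq, 1), -np.inf)   # running row max
l = np.zeros((seq, 1))           # running sum of exp
for j in range(0, seq, tile):
    S_tile = Q @ K[j:j + tile].T             # [seq, tile], then discarded
    m_new = np.maximum(m, S_tile.max(axis=-1, keepdims=True))
    alpha = np.exp(m - m_new)                # rescale old statistics
    p = np.exp(S_tile - m_new)               # unnormalized tile weights
    l = alpha * l + p.sum(axis=-1, keepdims=True)
    o = alpha * o + p @ V[j:j + tile]
    m = m_new

out = o / l                                  # normalize once at the end
assert np.allclose(out, ref)
```

Note the accumulator `o` stays unnormalized until the final division by `l`; the α factor retroactively corrects earlier tiles whenever a new block raises the running max.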

1.4.7 Memory Savings

Let's quantify the memory reduction:

Standard Attention:
  Attention matrix: [1, 64, 1M, 1M] × 2 bytes = 128 TB
  
FlashAttention (tile_size = 1024):
  One tile: [1, 64, 1024, 1024] × 2 bytes = 128 MB
  Running statistics per Q_block: ~few KB
  
Memory reduction: 128 TB → 128 MB = 1,000,000× less!

But we still need:
  Q: 16.38 GB      (64 query heads)
  K: 2.05 GB       (8 KV heads, GQA)
  V: 2.05 GB       (8 KV heads, GQA)
  Output: 16.38 GB
  ─────────────────
  Total: ~37 GB (fits in 80 GB for one layer!)
FlashAttention's Impact: By never materializing the full attention matrix, FlashAttention reduces memory from O(n²) to O(n). This makes attention computation feasible even for very long sequences — the 128 TB problem disappears!

However, FlashAttention doesn't solve everything. We still need to store Q, K, V tensors (20.5 GB per layer), and for inference, the KV cache still grows linearly with sequence length. For 1M tokens, that's still 328 GB across all layers (thanks to GQA reducing it 8×).

1.4.8 Why Attention Scores Don't Accumulate Across Layers

Unlike the KV cache, attention scores are computed and immediately discarded:

Layer 1: Compute attention scores → Use them → DISCARD
Layer 2: Compute attention scores → Use them → DISCARD
...
Layer 80: Compute attention scores → Use them → DISCARD

So 128 TB (or with FlashAttention, much less) is the peak memory for one layer's attention computation, not cumulative across all 80 layers. This is a crucial distinction from the KV cache, which must persist.


1.5 Memory Budget Summary

Let's put it all together. Here's what we're trying to fit into an 80 GB GPU:

Component                          Size      vs A100 (80 GB)      Persists?
───────────────────────────────────────────────────────────────────────────
Model Weights (BF16)               140 GB    1.75× capacity ❌    Yes — always in memory
Q + K + V (1 layer)                20.5 GB   26% of capacity ⚠️   No — recomputed per layer
Attention Scores (1 layer, naive)  128 TB    1,600× capacity ❌   No — discarded after use
KV Cache (all 80 layers, GQA)      328 GB    4× capacity ❌       Yes — grows with sequence

Even with FlashAttention eliminating the 128 TB attention matrix, we still have:

Minimum memory needed:
  Weights:     140 GB
  Activations: 20.5 GB (peak, one layer at a time)
  KV Cache:    328 GB (with GQA)
  ─────────────────────
  Total:       ~489 GB

Available on one A100: 80 GB

We need at least 7 GPUs just for memory capacity!
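The lower bound on GPU count is just a ceiling division (this ignores communication buffers, fragmentation, and framework overhead, so the real requirement is higher):

```python
import math

# Minimum GPU count from raw memory capacity alone -- a lower bound,
# since buffers and fragmentation are ignored.

weights_gb, activations_gb, kv_cache_gb, gpu_gb = 140, 20.5, 328, 80

total_gb = weights_gb + activations_gb + kv_cache_gb
print(f"total: {total_gb} GB -> at least {math.ceil(total_gb / gpu_gb)} GPUs")
```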
The Bottom Line: A single GPU cannot handle this workload. We need to distribute computation across multiple GPUs — but not just any distribution. We need strategies that specifically target each bottleneck.

1.6 Three Bottlenecks → Three Solutions

We've identified three distinct memory bottlenecks. Each requires a different parallelism strategy:

Bottleneck: Weights (140 GB, 1.75× GPU capacity)
Solution:   Weight Sharding (FSDP)
How:        Split weight matrices across GPUs. Each GPU stores 1/N of the weights and gathers them when needed.

Bottleneck: Activations (20.5 GB/layer, 26% of GPU capacity)
Solution:   Sequence Sharding (Ulysses)
How:        Split the sequence across GPUs. Each GPU processes 1/N of the tokens, using all-to-all communication for attention.

Bottleneck: Attention O(n²) (1,600× GPU capacity)
Solution:   Ring Attention
How:        Compute attention in chunks, passing KV blocks around a ring of GPUs. Never materialize the full attention matrix.

1.6.1 Why We Need All Three

You might wonder: can't we just use one strategy? Let's see:

FSDP alone (8 GPUs):
  Weights: 140 GB ÷ 8 = 17.5 GB per GPU ✓
  Activations: Still 20.5 GB per GPU ❌
  KV Cache: Still 328 GB total ❌

Sequence Parallelism alone (8 GPUs):
  Weights: Still 140 GB per GPU ❌
  Activations: 20.5 GB ÷ 8 = 2.6 GB per GPU ✓
  KV Cache: 328 GB ÷ 8 = 41 GB per GPU ❌

Ring Attention alone (8 GPUs):
  Weights: Still 140 GB per GPU ❌
  Activations: Distributed ✓
  KV Cache: Distributed ✓
  But: Requires weights to fit on each GPU ❌

No single strategy solves all three problems. We need to combine them.

1.6.2 Unified Sequence Parallelism (USP)

USP combines all three strategies:

USP with 8 GPUs:
  FSDP:              Weights 140 GB ÷ 8 = 17.5 GB per GPU ✓
  Ulysses:           Activations 20.5 GB ÷ 8 = 2.6 GB per GPU ✓
  Ring Attention:    KV Cache distributed across ring ✓
  
Per GPU: 17.5 + 2.6 + 41 (KV cache ÷ 8) ≈ 61 GB — fits in 80 GB!
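Putting numbers on that budget (a sketch assuming each of the three components shards evenly across 8 GPUs):

```python
# Per-GPU memory under USP with 8 GPUs, combining the three strategies.
# Sketch using the totals computed earlier in this part.

n_gpus = 8
weights = 140 / n_gpus        # FSDP shards the weights
activations = 20.5 / n_gpus   # Ulysses shards the sequence
kv_cache = 328 / n_gpus       # Ring Attention distributes the KV cache

per_gpu = weights + activations + kv_cache
print(f"{per_gpu:.1f} GB per GPU (budget: 80 GB)")
```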

In the following parts, we'll dive deep into each strategy.


Key Numbers to Remember

A100 HBM Memory                            80 GB
A100 Memory Bandwidth                      2.0 TB/s
A100 NVLink Bandwidth                      600 GB/s
Llama-70B Weights (BF16)                   140 GB
Q+K+V for 1M tokens (1 layer)              20.5 GB
Attention Scores for 1M tokens             128 TB
KV Cache for 1M tokens (80 layers, GQA)    328 GB

What's Next

Part 2: Weight Sharding (FSDP)        How to distribute 140 GB of weights across GPUs using AllGather and ReduceScatter
Part 3: Sequence Sharding (Ulysses)   Splitting the 1M token sequence with all-to-all communication
Part 4: Ring Attention                Distributed attention without materializing O(n²)
Part 5: Putting It Together           The complete USP picture

The math doesn't lie. We need parallelism — and now we know exactly why.