
Unified additive step-time formula: fix bandwidth double-counting, 11.7% E2E MAPE (from 28.7%) #4

@sriumcp

Description


Summary

The current 7-term trained-roofline step-time formula double-counts HBM bandwidth by placing weight loading (T_weight) and KV cache access (T_dc_kv) in separate additive terms. At realistic batch sizes (80-128 requests), both terms are significant and they share the same HBM bus — adding them overestimates by ~2x, causing false saturation at high concurrency.

A unified 5-term additive formula that merges all memory traffic into a single bandwidth term achieves 11.7% E2E MAPE (from 28.7%) and 22.5% TTFT MAPE (from 86.3%) across 21 sub-saturation experiments on 4 models. The formula is fully linear in β — NNLS-compatible, convex, fits in seconds.


Part 1: The Formula

Current formula (7-term, broken)

step = β₁·max(T_pf_compute, T_pf_kv)
     + β₂·max(T_dc_compute, T_dc_kv)     ← KV bandwidth here
     + β₃·T_weight                         ← weight bandwidth here (SEPARATE term)
     + β₄·T_tp + β₅·L + β₆·B + β₇

T_weight (~4ms for 8B model) and T_dc_kv (~9ms at batch=80) are charged as independent costs. In vLLM's single forward pass, weight reads and KV cache reads flow through the same HBM channel layer-by-layer.

Proposed formula (unified-5)

step = β₁·T_pf_compute + β₂·T_dc_compute + β₃·(T_weight + T_pf_kv + T_dc_kv) + β₄·L + β₅·B

Key properties:

  • Linear in β → NNLS-compatible, convex optimization, fits in seconds
  • 5 parameters (down from 7) → smaller, more robust, same loss surface
  • T_weight inside T_all_memory → correctly models shared HBM bus (the fix)
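Because the formula is purely additive, a prediction is just a dot product of the five basis values with β. A minimal sanity-check sketch (the function name and feature tuple are illustrative, not the BLIS API):

```python
# Unified-5 step-time prediction: step = β·f, no max() terms.
# Feature values are in µs (f1-f3) or raw counts (f4, f5).
def predict_step_us(feats, beta):
    """feats = (T_pf_compute, T_dc_compute, T_all_memory, L, B)."""
    t_pf, t_dc, t_mem, layers, batch = feats
    b1, b2, b3, b4, b5 = beta
    return b1 * t_pf + b2 * t_dc + b3 * t_mem + b4 * layers + b5 * batch

# Illustrative feature values for one decode-heavy step with the fitted β:
beta = [0.393, 0.093, 0.910, 68.3, 12.9]
step = predict_step_us((1000.0, 50.0, 13000.0, 32, 80), beta)  # ≈ 15445 µs
```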

Basis functions (all from first principles)

| # | Feature | Computation | Source |
|---|---------|-------------|--------|
| f1 | T_pf_compute | Prefill FLOPs / peak_flops | QKV proj + causal attn t_i·(s_i+t_i/2) + 3-matrix SwiGLU MLP |
| f2 | T_dc_compute | Decode FLOPs / peak_flops | Context-dependent attention + proj + MLP (tiny for memory-bound decode) |
| f3 | T_all_memory | (T_weight + T_pf_kv + T_dc_kv) / peak_bw | Unified HBM: weights (MoE-aware nEff) + KV reads/writes |
| f4 | L | numLayers | Per-layer overhead (kernel dispatch, sync) |
| f5 | B | batchSize | Per-request overhead (scheduling, sampling, state update) |
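The f3 term is the crux of the fix: all HBM traffic is summed before dividing by bandwidth. A back-of-envelope sketch, assuming the byte counts are already computed upstream (the function name is illustrative, not the basis_functions.py API):

```python
# Unified HBM term: weights and KV share one bus, so their bytes add
# *before* dividing by peak bandwidth (the double-counting fix).
def t_all_memory_us(weight_bytes, pf_kv_bytes, dc_kv_bytes, peak_bw_bytes_per_s):
    total_bytes = weight_bytes + pf_kv_bytes + dc_kv_bytes
    return total_bytes / peak_bw_bytes_per_s * 1e6  # seconds → µs

# ~16 GB of FP16 weights for an 8B model on H100 (3.35 TB/s):
# ≈ 4.8 ms of pure weight traffic, the order of magnitude of the
# ~4 ms T_weight quoted above.
t_w = t_all_memory_us(16e9, 0.0, 0.0, 3.35e12)
```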

Fitted coefficients (BLIS-in-the-loop, 21 ground truth experiments)

beta = [0.393, 0.093, 0.910, 68.3, 12.9]   # β₁-β₅
alpha = [19615, 1850, 1.71]                   # α₀, α₁, α₂

Physical interpretation:

  • β₁=0.39: prefill compute ~39% of peak FLOPs (MFU, accounts for compute-memory overlap in additive model)
  • β₂=0.09: decode compute negligible (memory-bound regime)
  • β₃=0.91: HBM bandwidth ~91% of peak utilization
  • β₄=68.3 µs/layer: per-layer overhead (averaged over CUDA-graph and eager mode steps)
  • β₅=12.9 µs/request: per-request scheduling overhead

Part 2: Evidence

Full comparison (21 sub-saturation experiments, 4 models × 3-4 profiles × 2-3 rates)

| Backend | E2E MAPE | TTFT MAPE | ITL MAPE | TTFT worst |
|---------|----------|-----------|----------|------------|
| CrossModel | 24.5% | 25.5% | 91.2% | 100% |
| Roofline | 32.9% | 36.1% | 32.3% | 188% |
| TR-original-7 | 28.7% | 86.3% | 126.3% | 892% |
| Unified-5 (proposed) | 11.7% | 22.5% | 82.7% | 100% |

Per-model E2E MAPE

| Model | TP | CrossModel | TR-original-7 | Unified-5 |
|-------|----|------------|---------------|-----------|
| Llama-2-7B | 1 | 31.7% | 11.5% | 15.1% |
| Llama-2-70B | 4 | 15.2% | 19.4% | 3.3% |
| Mixtral-8x7B | 2 | 5.9% | 23.5% | 2.0% |
| CodeLlama-34B | 2 | 41.9% | 55.2% | 27.0% |

Optimizer robustness (3 methods, all converge to same basin)

| Optimizer | E2E MAPE | Loss | Evaluations | Time |
|-----------|----------|------|-------------|------|
| Nelder-Mead | 11.7% | 18.4 | 243 | 5.9 min |
| Differential Evolution | 12.2% | 19.2 | 768 | 18.9 min |
| Dual Annealing | 12.5% | 19.2 | 250 | 6.1 min |

An 8-parameter variant adding CUDA-graph-aware features (L_eager, T_pf_tokens, constant) was also tested. It improves ITL (74% vs 83%) but doesn't improve E2E and creates a harder optimization landscape. The 5-parameter formula is recommended.


Part 3: Known Limitations

3a. Expert Parallelism (EP) not modeled

The formula currently assumes standard tensor parallelism for MoE:

bytesFfn = L * nEff * 3 * d * dFF * bytesPerElement / tp

This is correct when enable_expert_parallel=False (the training experiments' configuration). But when EP is enabled (--enable-expert-parallel), vLLM repurposes the TP dimension for MoE layers (FusedMoEParallelConfig.make(), config.py:1019-1037):

# With EP enabled:
ep_size = tp_size      # EP takes over TP for MoE
return FusedMoEParallelConfig(tp_size=1, ep_size=ep_size)

This means:

  • Attention: still TP-sharded (each GPU has d/TP head dimensions, kvHeads/TP KV heads)
  • MoE FFN: EP-distributed (each GPU holds N/EP complete experts, not tensor-sharded)
  • Communication: all-to-all dispatch/gather per MoE layer (not TP all-reduce)

The formula needs an EP-aware path:

if ep_size > 1:
    localExperts = numExperts / ep_size
    nEffLocal = min(localExperts, max(kEff, B * kEff))
    bytesFfn = L * nEffLocal * 3 * d * dFF * bytesPerElement  # no /tp: each GPU holds whole experts
    T_all2all = L * B * kEff * d * 4 * bytesPerElement / inter_gpu_bw  # new term
else:
    bytesFfn = L * nEff * 3 * d * dFF * bytesPerElement / tp
    T_all2all = 0

This requires: (a) an ep_size parameter in the model config, (b) NVLink/inter-GPU bandwidth in hardware specs, and (c) training data with EP enabled.

3b. CUDA graph vs eager mode averaging

vLLM uses CUDA graph replay for pure decode batches (cudagraph_utils.py:216), eliminating per-layer kernel launch overhead. Mixed batches (any prefill) fall back to eager mode with ~50-100µs/layer of kernel dispatch. The current β₄ averages over both regimes. An 8-parameter variant with L_eager = L × hasPrefill showed ITL improvement (74% vs 83%) but requires more training data to fit reliably.

3c. α₀ is a compensatory parameter

The optimized α₀ = 19615 µs (vs the original 9315 µs, fitted from server-side QUEUED.ts − ARRIVED.ts) absorbs client-visible overhead not captured by the step-time model. It should be jointly optimized with β rather than independently fitted.

3d. CodeLlama-34B is the hardest model

27% E2E MAPE vs 2-3% for Mixtral/Llama-70B. This model's TP=2 configuration with 48 layers creates different overhead characteristics. More data at varied rates and workloads would help.


Part 4: Training Pipeline Design

The training pipeline has two complementary stages, each producing progressively better coefficients.

Stage 1: Analytical Regression (NNLS on teacher-forced data)

What: Fit β₁-β₅ via scipy.optimize.nnls on per-request processing time, using teacher-forced batch compositions from vLLM OTEL traces.

Why: Fast (seconds), convex (guaranteed global optimum), provides physics-grounded initialization.

Changes to current pipeline (fit_coefficients.py):

Replace the 7-column feature matrix:

# OLD: 7 features per step
[max(T_pf_compute, T_pf_kv), max(T_dc_compute, T_dc_kv), T_weight, T_tp, L, B, 1]

With 5 unified features:

# NEW: 5 features per step
[T_pf_compute, T_dc_compute, T_weight + T_pf_kv + T_dc_kv, L, B]

The stacked prefill/decode split still applies. The build_stacked_feature_matrix() function changes column construction, not structure.

Output: β₁-β₅ with ~7% per-step MAPE (teacher-forced). These are good initial coefficients but don't account for BLIS scheduling dynamics.
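A minimal sketch of the Stage 1 fit, assuming the stacked 5-column matrix X and per-step times y have already been built by the pipeline (`fit_unified5` is an illustrative name, not the exact fit_coefficients.py API):

```python
import numpy as np
from scipy.optimize import nnls

def fit_unified5(X, y):
    """Stage 1: NNLS fit of β₁-β₅ on teacher-forced per-step times.
    X: (n_steps, 5) unified feature matrix; y: (n_steps,) step times in µs.
    Convex, so the returned β ≥ 0 is the global optimum."""
    beta, _residual = nnls(X, y)
    return beta

# Synthetic check: noiseless data with a known positive β is recovered exactly.
rng = np.random.default_rng(0)
X = rng.uniform(0, 1000, size=(200, 5))
beta_true = np.array([0.4, 0.1, 0.9, 70.0, 13.0])
beta_hat = fit_unified5(X, X @ beta_true)
```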

Stage 2: BLIS-in-the-Loop Refinement

What: Optimize β₁-β₅ + α₀ by running the full BLIS simulator against each training experiment and comparing predicted E2E/TTFT against real stage_N_lifecycle_metrics.json.

Why: Captures system-level effects (batch formation, queueing, scheduling) that teacher-forced evaluation misses. This is where the E2E MAPE drops from ~15% (NNLS alone) to 11.7%.

Process:

  1. Initialize from Stage 1 β coefficients
  2. Loss = MAPE(E2E) + 0.3 × MAPE(TTFT) across sub-saturation experiments
  3. Optimizer: Nelder-Mead (5.9 min, 243 evals) — all 3 tested optimizers (NM, DE, DA) converge to the same basin
  4. BLIS runs use the exact experiment parameters: model, TP, rate, max_num_seqs, max_num_batched_tokens, token distributions

Saturation boundary: Experiments where fail_rate > 10% are excluded from the metric loss. Their role is validation-only: verify that BLIS correctly predicts queueing explosion for overloaded configurations. The model should NOT try to predict accurate E2E above saturation — just get the cliff location right.

Cliff detection: For each model/TP/workload, binary-search the rate where BLIS TTFT exceeds 3× the low-rate baseline. Compare against real saturation evidence (failure rate onset in training data). If the predicted cliff is >20% off from reality, add a saturation penalty to the loss.
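The cliff search can be sketched as a standard bisection, assuming a `ttft_at(rate)` wrapper around a full BLIS run (hypothetical helper) whose TTFT is monotone non-decreasing in rate:

```python
def find_cliff_rate(ttft_at, lo, hi, baseline_ttft, factor=3.0, tol=0.05, max_iter=20):
    """Binary-search the lowest rate where simulated TTFT exceeds
    factor × the low-rate baseline. lo/hi bracket the cliff."""
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if ttft_at(mid) > factor * baseline_ttft:
            hi = mid   # saturated: cliff is at or below mid
        else:
            lo = mid   # still sub-saturation
        if hi - lo < tol * lo:
            break
    return 0.5 * (lo + hi)

# Toy stand-in for BLIS: TTFT explodes past 12 req/s.
cliff = find_cliff_rate(lambda r: 100.0 if r < 12.0 else 1000.0,
                        lo=1.0, hi=20.0, baseline_ttft=100.0)
```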

Combined Pipeline

┌──────────────────────────────────────────────────────────┐
│ Phase 1: Data Collection (per GPU × model × workload)    │
│   → vLLM + inference-perf experiments at 3-5 rates       │
│   → OTEL traces + lifecycle metrics + KV events           │
└──────────────────────────┬───────────────────────────────┘
                           ▼
┌──────────────────────────────────────────────────────────┐
│ Phase 2: Teacher-Forced Regression (NNLS, seconds)       │
│   → reconstruct_steps.py → basis_functions.py            │
│   → 5-feature unified matrix → nnls(X, y)               │
│   → Output: β_init (per-step optimized)                  │
└──────────────────────────┬───────────────────────────────┘
                           ▼
┌──────────────────────────────────────────────────────────┐
│ Phase 3: BLIS-in-the-Loop Refinement (NM, ~6 min)       │
│   → Run BLIS with β_init against all experiments         │
│   → Compare E2E/TTFT vs real lifecycle metrics           │
│   → Nelder-Mead on (β₁-β₅, α₀)                         │
│   → Output: β_final + α₀ (system-level optimized)       │
└──────────────────────────┬───────────────────────────────┘
                           ▼
┌──────────────────────────────────────────────────────────┐
│ Phase 4: Saturation Validation                           │
│   → Binary-search saturation rate per model/TP/workload  │
│   → Compare against overload experiments                 │
│   → Report capacity ceiling alongside coefficients       │
└──────────────────────────────────────────────────────────┘

Part 5: Data Collection Plan

Current coverage

  • GPU: H100 SXM only
  • Models: 4 (Llama-2-7B TP=1, Llama-2-70B TP=4, Mixtral-8x7B TP=2, CodeLlama-34B TP=2)
  • Parallelism: TP only (no EP, no DP)
  • Workloads: 4 profiles × 1-2 rates = 13 active + 3 overload experiments
  • Precision: FP16/BF16 only

Required expansion for generalization

5a. Hardware diversity (3 GPU families)

| GPU | Peak FLOPs (FP16) | HBM BW | Memory | Priority |
|-----|-------------------|--------|--------|----------|
| H100 SXM | 989.5 TFLOP/s | 3.35 TB/s | 80 GB | Have data |
| A100 SXM | 312 TFLOP/s | 2.04 TB/s | 80 GB | High — different compute/BW ratio |
| L40S | 362 TFLOP/s | 864 GB/s | 48 GB | High — PCIe, no NVLink, different memory ceiling |

Each GPU has a different compute-to-bandwidth ratio (arithmetic intensity crossover), which determines where the roofline switches between compute-bound and memory-bound. A100 is more bandwidth-constrained than H100; L40S is severely bandwidth-limited. The β₃ (bandwidth utilization) coefficient may differ across GPUs.

Minimum per GPU: same 4 models × 4 profiles × 2-3 rates = ~40 experiments.

5b. Model architecture diversity

| Architecture | Examples | Key Feature | Current Coverage |
|--------------|----------|-------------|------------------|
| Dense MHA | Llama-2-7B, Llama-2-70B | Full attention heads | ✅ |
| Dense GQA | Llama-3.1-8B, Llama-3.3-70B, Qwen-2.5 | Grouped-query attention | ❌ Need |
| Sparse MoE | Mixtral-8x7B | 8 experts, top-2 | ✅ (TP only) |
| Sparse MoE + GQA | DeepSeek-V2, Qwen3-MoE | MoE + grouped KV | ❌ Need |
| Large MoE | Mixtral-8x22B, DeepSeek-V3 | Many experts, EP required | ❌ Need (EP path) |

GQA models (Llama-3.x) are especially important — they have 4-8× fewer KV heads than MHA, changing the KV bandwidth cost dramatically.

5c. Parallelism configurations

| Config | Description | Current | Needed |
|--------|-------------|---------|--------|
| TP=1 | Single GPU | ✅ (7B) | More models |
| TP=2 | 2-GPU tensor parallel | ✅ (Mixtral, CodeLlama) | EP variants |
| TP=4 | 4-GPU tensor parallel | ✅ (70B) | More models |
| TP=2 + EP | 2-GPU expert parallel | ❌ | Required for MoE generalization |
| TP=4 + EP | 4-GPU expert parallel | ❌ | Required for large MoE |
| TP=2 + DP=2 | Hybrid TP+DP | ❌ | Optional (same-GPU physics) |

EP experiments require --enable-expert-parallel flag and collect the all-to-all communication overhead that the current formula doesn't model.

5d. Workload profiles

| Profile | Input tokens | Output tokens | Characteristics |
|---------|--------------|---------------|-----------------|
| General (chat) | ~575 | ~210 | Moderate I/O, main use case |
| Codegen | ~590 | ~195 | Similar to general, code-specific |
| Roleplay | ~785 | ~250 | Longer input (system prompts) |
| Reasoning | ~1080 | ~1100-1450 | Very long I/O, saturates quickly |
| Long-context RAG | 4096+ | ~200 | Tests KV bandwidth scaling |
| Short-burst API | ~50 | ~20 | Tests minimum overhead floor |

5e. Rate sweep per experiment

Each model/GPU/TP/workload combination needs 3-5 rates spanning sub-saturation to near-saturation:

rates = [low, medium, high, near_saturation, at_saturation]

Where:

  • low: ~25% of estimated capacity (pure step-time validation, no queueing)
  • medium: ~50% capacity (moderate batching)
  • high: ~75% capacity (realistic production load)
  • near_saturation: ~90% capacity (queueing starts)
  • at_saturation: ~110% capacity (intentional overload — failure rate > 10%)

The at_saturation rate provides the cliff detection ground truth. Capacity can be estimated from low-rate step times before running the full sweep.
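The low-rate capacity estimate can be a simple back-of-envelope calculation (illustrative names and numbers, not a BLIS API): measure the steady-state step time and decode throughput at the low rate, then divide by the tokens each request consumes.

```python
def estimate_capacity_rps(low_rate_step_us, decode_tokens_per_step, output_tokens_per_req):
    """Rough capacity: decode token throughput / tokens per request."""
    tokens_per_s = decode_tokens_per_step / (low_rate_step_us * 1e-6)
    return tokens_per_s / output_tokens_per_req

# e.g. 15 ms steps decoding 80 tokens, ~210 output tokens per request:
cap = estimate_capacity_rps(15000, 80, 210)          # ≈ 25.4 req/s
rates = [round(cap * f, 2) for f in (0.25, 0.50, 0.75, 0.90, 1.10)]
```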

5f. Per-experiment data requirements

Each experiment must produce:

  1. traces.json: OTEL journey + step traces (for teacher-forced reconstruction)
  2. stage_N_lifecycle_metrics.json: aggregate E2E/TTFT/ITL per rate stage (for BLIS-in-the-loop)
  3. per_request_lifecycle_metrics.json: per-request timings (for distribution analysis)
  4. exp-config.yaml: exact vLLM server config (max_num_seqs, max_num_batched_tokens, etc.)
  5. kv_events.jsonl: KV cache block events (for KV cache model validation)

The lifecycle metrics are the most critical — they provide the ground truth for BLIS-in-the-loop optimization. The traces are needed for Stage 1 (teacher-forced NNLS).

5g. Train/validation/test split strategy

Split unit: experiments (not individual requests), to test generalization.

| Split | Purpose | Composition |
|-------|---------|-------------|
| Train | Fit β₁-β₅ + α₀ | ~70% of experiments: all rates for most model/GPU/workload combos |
| Validate | Tune hyperparams (loss weights, optimizer settings) | ~15%: held-out rates or workload profiles |
| Test | Final MAPE reporting | ~15%: held-out model/GPU combos (e.g., entire A100 suite or a new model) |

Key: test split should include at least one model NOT seen in training, to validate cross-model generalization. The current crossmodel backend (11.8% MAPE) achieves this via architecture features — the unified formula inherits this capability.


Part 6: Implementation Changes

6a. basis_functions.py — New 5-feature construction

Replace compute_step_basis() output from 7 fields to 5:

@dataclass
class StepBasisValues:
    t_pf_compute: float     # f1: prefill FLOPs / peak (µs)
    t_dc_compute: float     # f2: decode FLOPs / peak (µs)
    t_all_memory: float     # f3: T_weight + T_pf_kv + T_dc_kv (µs)
    num_layers: float       # f4: L
    batch_size: float       # f5: B

Individual basis computations (T_pf_compute, T_weight, T_dc_kv, etc.) stay identical. Only the grouping changes.

6b. fit_coefficients.py — Unified stacked matrix

build_stacked_feature_matrix() produces a 5-column matrix instead of 7. The stacked prefill/decode split applies as before — f1 accumulated from prefill entries, f2 from decode entries, f3/f4/f5 from both.

6c. fit_coefficients.py — Joint α₀ optimization (Stage 2)

Add a new function refine_with_blis() that:

  1. Takes NNLS-fitted β and independently-fitted α₀ as initialization
  2. Runs BLIS against each experiment's real parameters
  3. Optimizes (β₁-β₅, α₀) with Nelder-Mead against E2E + TTFT MAPE
  4. Returns system-optimized coefficients
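A sketch of what refine_with_blis() could look like, assuming a `run_blis(experiment, beta, alpha0)` wrapper that returns predicted per-request E2E and TTFT arrays (hypothetical helper; the loss matches the Stage 2 definition above):

```python
import numpy as np
from scipy.optimize import minimize

def mape(pred, real):
    """Mean absolute percentage error, in percent."""
    pred, real = np.asarray(pred, float), np.asarray(real, float)
    return float(np.mean(np.abs(pred - real) / real)) * 100

def refine_with_blis(beta_init, alpha0_init, experiments, run_blis):
    """Stage 2: jointly optimize (β₁-β₅, α₀) against real lifecycle metrics.
    Each experiment dict is assumed to carry real "e2e" and "ttft" arrays."""
    def loss(theta):
        beta, alpha0 = theta[:5], theta[5]
        e2e_err, ttft_err = [], []
        for exp in experiments:
            pred_e2e, pred_ttft = run_blis(exp, beta, alpha0)
            e2e_err.append(mape(pred_e2e, exp["e2e"]))
            ttft_err.append(mape(pred_ttft, exp["ttft"]))
        return np.mean(e2e_err) + 0.3 * np.mean(ttft_err)  # Stage 2 loss
    res = minimize(loss, np.append(beta_init, alpha0_init), method="Nelder-Mead")
    return res.x[:5], res.x[5]
```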

6d. evaluate.py — System-level evaluation

Add evaluate_system_level() that runs BLIS for each experiment and reports:

  • Per-experiment E2E/TTFT/ITL MAPE
  • Per-model aggregate MAPE
  • Saturation boundary comparison (predicted vs real cliff rate)
  • Worst-case TTFT (the metric that caught the original 892% failure)

6e. Go runtime (sim/latency/trained_roofline.go)

Prototype already exists in the investigate-trained-roofline worktree. The 5-coefficient path activates when len(betaCoeffs) == 5. Basis function code is unchanged.

6f. Future: EP-aware path

When ep_size is added to ModelHardwareConfig:

if m.epSize > 1 {
    localExperts := float64(m.numExperts) / float64(m.epSize)
    nEffLocal := math.Min(localExperts, math.Max(kEff, batchSize*kEff))
    bytesFfn = L * nEffLocal * 3 * d * dFF * bytesPerElement // no /tp: each GPU holds whole experts
    tAll2All = L * batchSize * kEff * d * 4 * bytesPerElement / m.nvlinkBwUs
} else {
    bytesFfn = L * nEff * 3 * d * dFF * bytesPerElement / tp
    tAll2All = 0
}

Part 7: Validation Checklist

Formula correctness

  • NNLS on teacher-forced data produces β ≥ 0 for all 5 coefficients
  • Per-step MAPE (teacher-forced) comparable to 7-term (~7%)
  • Coefficients are physically interpretable (β₃ ∈ [0.7, 1.1], β₁ ∈ [0.2, 0.6])

System-level accuracy

  • BLIS-in-the-loop E2E MAPE < 15% across all training experiments
  • TTFT worst case < 100% (no false saturation)
  • CodeLlama-34B at 20 req/s with max_num_seqs=128 completes without queueing explosion
  • Saturation cliff within 20% of real cliff for each model/TP/workload

Generalization

  • Cross-GPU: H100-trained coefficients produce < 20% E2E MAPE on A100 (with correct hardware specs)
  • Cross-model: coefficients trained on 4 models produce < 20% on a held-out 5th model
  • Cross-workload: general/codegen-trained coefficients produce < 25% on reasoning workloads

References

  • Discussion #522 (inference-sim/inference-sim): sim-to-real validation showing original trained-roofline failure
  • investigate-trained-roofline worktree in inference-sim: full experimental artifacts
  • vLLM V1 scheduler: vllm/v1/core/sched/scheduler.py (token budget, FCFS scheduling)
  • vLLM CUDA graphs: vllm/v1/worker/gpu/cudagraph_utils.py:216 (mixed batch → eager mode)
  • vLLM model runner: vllm/v1/worker/gpu/model_runner.py:880 (single forward pass)
  • vLLM MoE EP config: vllm/model_executor/layers/fused_moe/config.py:1019-1037 (EP repurposes TP)
  • vLLM Mixtral model: vllm/model_executor/models/mixtral.py:75 ("shards each expert across all ranks")
