A research framework for LLM compression that treats pruning as combinatorial optimization — combining hierarchical importance analysis, architecture-aware structured pruning, tabu search, and knowledge distillation in a single pipeline.
> [!IMPORTANT]
> This is research software. It is designed for researchers and engineers studying LLM compression, not for one-click model optimization. The repository implements measurement, analysis, pruning, search, and recovery as a connected pipeline so compression strategies can be studied end-to-end.
- Hierarchical importance analysis — score neurons at four levels (neuron, channel, sublayer, block) using Hessian sensitivity, activation magnitude, and SVD-based signal-to-noise
- Three structured pruning techniques — block removal, MLP channel pruning (SwiGLU-aligned), and N:M structured sparsity with bit-packed masks
- Tabu search over pruning configurations — escape local optima via adaptive tenure, elite memory, and apply-restore evaluation (no model copies)
- Knowledge distillation — CE + KL + SquareHead hidden-state matching, with flex mode for cross-architecture teacher-student pairs
- GPTQ and SmoothQuant — post-training quantization adapted from llmcompressor
- Single YAML config — one file controls model loading, data, analysis thresholds, pruning budgets, search, and distillation
- CLI and library — use as `model-opt analyze|prune|optimize|distill --config ...` or import modules directly
Prerequisites: Python >= 3.10, PyTorch >= 2.0, Transformers >= 4.35
```bash
git clone https://github.com/ddickmann/llm-opt.git
cd llm-opt
pip install -e .
```

Install with all optional dependencies:

```bash
pip install -e ".[all,dev]"
```

Or pick what you need:

```bash
pip install -e ".[data]"       # + datasets, safetensors
pip install -e ".[analytics]"  # + scipy, sklearn, tabulate
pip install -e ".[viz]"        # + matplotlib, seaborn
pip install -e ".[distill]"    # + deepspeed
```

Run the test suite:

```bash
make test
```

The framework implements a five-stage pipeline. Each stage is independently executable and produces serialized artifacts consumed by downstream stages.
```
┌─────────────┐    ┌─────────────┐    ┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│   COLLECT   │───>│   ANALYZE   │───>│    PRUNE     │───>│   OPTIMIZE   │───>│   DISTILL    │
│             │    │             │    │              │    │              │    │              │
│ Hessian     │    │ 4-level     │    │ Block        │    │ Tabu search  │    │ CE + KL +    │
│ accumulation│    │ importance  │    │ removal,     │    │ over pruning │    │ SquareHead   │
│ activation  │    │ scoring,    │    │ MLP channel, │    │ configs,     │    │ hidden-state │
│ capture,    │    │ neuron      │    │ N:M          │    │ adaptive     │    │ matching     │
│ SVD-based   │    │ classifi-   │    │ structured   │    │ tenure,      │    │              │
│ SNR         │    │ cation      │    │ sparsity     │    │ elite memory │    │              │
└─────────────┘    └─────────────┘    └──────────────┘    └──────────────┘    └──────────────┘
```
```bash
model-opt analyze --config configs/example.yaml
model-opt prune --config configs/example.yaml
model-opt optimize --config configs/example.yaml
model-opt distill --config configs/example.yaml
```

```python
from model_opt.analysis.collector import run_collection
from model_opt.analysis.tracker import load_tracker
from model_opt.analytics.importance import analyze_model, MetricsConfig
from model_opt.pruning.block import iterative_block_pruning
from model_opt.pruning.mlp import LayerAlignedMLPPruner
from model_opt.optimization.tabu_search import TabuSearchPruner
```

Each stage produces serialized outputs (pickle or safetensors) that can be inspected, modified, or fed into custom downstream code.
All parameters are managed through nested dataclasses with YAML serialization and type validation. See `configs/example.yaml` for a complete reference.
```yaml
model:
  name_or_path: "meta-llama/Llama-3.1-8B-Instruct"
  torch_dtype: "bfloat16"
data:
  dataset_name: "HuggingFaceH4/ultrachat_200k"
  nsamples: 128
  seqlen: 2048
pruning:
  block:
    max_blocks_to_prune: 4
    protected_blocks: [0, 1, 30, 31]
  channel:
    num_groups: 32
  sparsity: { n: 2, m: 4, kl_threshold: 0.01 }
optimization:
  max_iterations: 500
  tabu_tenure: 25
  neighborhood_size: 8
distillation:
  teacher: { name_or_path: "meta-llama/Llama-3.1-8B-Instruct" }
  temperature: 3.0
  hardness_ce: 1.0
  hardness_kldiv: 1.5
  hardness_squarehead: 4.0
```

Register forward hooks on every linear sublayer. During calibration-set inference, accumulate per-neuron activation statistics, blockwise Hessian matrices (via activation outer products), and SVD-based signal-to-noise ratios. All data is captured to CPU and serialized in a `DataTracker` for offline analysis.
Key implementation detail: Hessian inversion uses regularized blockwise Cholesky decomposition with 8 fallback strategies of increasing regularization strength. TF32 is disabled during Hessian computation for numerical reproducibility.
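The two mechanics described above — outer-product Hessian accumulation and damped Cholesky fallback — can be sketched as follows. This is a minimal illustration, not the repository's exact code; the function names are hypothetical.

```python
import torch

def accumulate_hessian(H, n_seen, x):
    """Running estimate H ≈ (2/N) · Σ x xᵀ from calibration activations.

    H: (d, d) float32 accumulator; n_seen: samples so far; x: (tokens, d) inputs.
    """
    x = x.float()
    n_new = x.shape[0]
    H *= n_seen / (n_seen + n_new)            # down-weight earlier contributions
    xs = x * (2.0 / (n_seen + n_new)) ** 0.5  # scale so H += xsᵀxs adds (2/N)·Σ x xᵀ
    H += xs.t() @ xs
    return H, n_seen + n_new

def robust_cholesky_inverse(H, damp_base=0.01, tries=8):
    """Invert H via Cholesky, escalating diagonal damping on each failure."""
    eye = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    damp = damp_base * H.diagonal().mean()
    for _ in range(tries):                    # fallbacks of increasing strength
        try:
            L = torch.linalg.cholesky(H + damp * eye)
            return torch.cholesky_inverse(L)
        except torch.linalg.LinAlgError:
            damp *= 10.0
    raise RuntimeError("Hessian not positive definite even at maximum damping")
```

The damping scale is tied to the mean of the diagonal so the same relative regularization applies regardless of activation magnitude.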
Score every neuron at four hierarchical levels, each informing a different pruning strategy:
| Level | Metrics | Downstream use |
|---|---|---|
| Neuron | Hessian sensitivity, activation magnitude (robust/rank/outlier), SNR | N:M sparsity patterns |
| Channel | Group importance, adjacency influence propagation | MLP channel pruning positions |
| Sublayer | Aggregated q/k/v/gate/up/down projection scores | Per-sublayer pruning budget |
| Block | Full-layer importance + positional weighting | Block removal candidates |
Importance scoring uses a weighted harmonic mean of per-neuron metrics, coefficient-of-variation consistency, and first-order adjacency propagation for network-level impact estimation.
Neuron classification uses robust z-score outlier detection with adaptive thresholding (auto-adjusts to maintain 10–30% outlier rate regardless of distribution shape):
- Important — high outliers; preserved during pruning
- Redundant — low outliers; highest pruning priority
- Ambiguous — resolved via block-hierarchy labels from the analysis stage
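The adaptive thresholding described above can be sketched roughly as follows — a hypothetical, simplified version using median/MAD robust z-scores, widening or tightening the cutoffs until the outlier rate lands in the 10–30% band:

```python
import torch

def classify_neurons(scores, z_hi=2.0, z_lo=-2.0, target=(0.10, 0.30), steps=10):
    """Robust z-score classification with adaptive thresholds (illustrative).

    Returns a LongTensor: 1 = important, -1 = redundant, 0 = ambiguous.
    """
    med = scores.median()
    mad = (scores - med).abs().median().clamp_min(1e-12)
    z = 0.6745 * (scores - med) / mad   # 0.6745 ≈ Φ⁻¹(0.75), normal-consistency factor
    for _ in range(steps):
        rate = ((z > z_hi) | (z < z_lo)).float().mean().item()
        if rate < target[0]:            # too few outliers -> relax thresholds
            z_hi *= 0.9
            z_lo *= 0.9
        elif rate > target[1]:          # too many -> tighten
            z_hi *= 1.1
            z_lo *= 1.1
        else:
            break
    labels = torch.zeros_like(scores, dtype=torch.long)
    labels[z > z_hi] = 1                # high outliers: preserve
    labels[z < z_lo] = -1               # low outliers: prune first
    return labels
```

Because the thresholds scale with the observed rate rather than assuming normality, the same routine behaves sensibly on heavy-tailed activation distributions.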
Block pruning — Remove entire transformer layers. Acceptance is gated by perplexity increase and KL divergence against precomputed teacher logits. First and last blocks are protectable via configuration.
MLP channel pruning — Zero-out neuron channels. The LayerAlignedMLPPruner enforces that gate_proj and up_proj share identical pruning positions within each layer (required by SwiGLU-style architectures), then structurally removes the corresponding down_proj input dimensions.
N:M structured sparsity — Hessian-guided weight selection within N:M blocks (e.g. 2:4). Sparsity masks are bit-packed into int64 tensors (~64x compression vs. float masks) and enforced via post-optimizer hooks. Layer acceptance is gated by per-task KL divergence thresholds.
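The bit-packing idea is straightforward: 64 boolean mask entries fit in one `int64` word. A minimal sketch (not the repository's exact layout):

```python
import torch

def pack_mask(mask_bool):
    """Pack a boolean sparsity mask into int64 words (~64x smaller than float32)."""
    flat = mask_bool.flatten().to(torch.int64)
    pad = (-flat.numel()) % 64
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])  # zero-pad to a word boundary
    bits = flat.view(-1, 64)
    shifts = torch.arange(64, dtype=torch.int64)
    words = bits << shifts            # each column carries one bit position
    packed = words[:, 0]
    for i in range(1, 64):            # OR-reduce: bit positions are disjoint
        packed = packed | words[:, i]
    return packed

def unpack_mask(packed, numel):
    """Recover the boolean mask (flattened) from packed int64 words."""
    shifts = torch.arange(64, dtype=torch.int64)
    bits = (packed.unsqueeze(1) >> shifts) & 1
    return bits.flatten()[:numel].bool()
```

The round trip is lossless, so a post-optimizer hook only needs the packed words to re-zero pruned weights after each gradient step.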
The greedy solution from Stage 3 is a local optimum. The tabu search explores the neighborhood of that solution to find better configurations:
- Move generators — block swap (add/remove transformer layers), channel redistribution (shift pruning positions between layers)
- Adaptive tabu tenure — automatically intensifies search near promising regions, diversifies when stagnating
- LRU state cache — avoids re-evaluating previously seen configurations
- Elite solution archive — bounded archive of the k-best solutions found during the search
- Apply-restore evaluation — evaluates candidate states by temporarily modifying weights in-place and restoring them, avoiding an O(GB) `deepcopy` per candidate
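The apply-restore pattern can be illustrated with a small context manager — a hedged sketch, not the framework's implementation: back up only the touched tensor, evaluate, then restore, so no full model copy is ever made.

```python
import contextlib
import torch

@contextlib.contextmanager
def apply_restore(module, param_name, new_value):
    """Temporarily overwrite one parameter in-place, restoring it on exit."""
    param = getattr(module, param_name)
    backup = param.detach().clone()      # back up only this tensor, not the model
    with torch.no_grad():
        param.copy_(new_value)
    try:
        yield module
    finally:
        with torch.no_grad():
            param.copy_(backup)          # restore even if evaluation raises
```

A candidate pruning state could then be scored with something like `with apply_restore(layer, "weight", layer.weight * mask): score = evaluate(model)` — the weights are guaranteed back in place afterwards.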
Post-pruning performance recovery via three complementary loss signals:
| Loss | Signal | When to use |
|---|---|---|
| Cross-entropy | Standard LM loss on training tokens | Always |
| KL divergence | Temperature-scaled logit distribution matching | Always (with teacher) |
| SquareHead | Intermediate hidden state matching | Architecture-dependent |
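The CE and KL terms combine as sketched below — an illustrative version, with weights mirroring the `hardness_ce` / `hardness_kldiv` config fields above. The `T²` factor keeps soft-target gradients on the same scale as CE, following standard distillation practice (Hinton et al.).

```python
import torch
import torch.nn.functional as F

def distill_loss(student_logits, teacher_logits, labels, T=3.0, w_ce=1.0, w_kl=1.5):
    """Weighted CE + temperature-scaled KL distillation loss (sketch)."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),   # student log-probs
        F.log_softmax(teacher_logits / T, dim=-1),   # teacher log-probs
        reduction="batchmean",
        log_target=True,
    ) * (T * T)                                       # rescale soft-target gradients
    return w_ce * ce + w_kl * kl
```

The SquareHead term (omitted here) would add MSE between selected teacher and student hidden states, normalized per layer.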
SquareHead operates in two modes:
- Standard — 1:1 layer mapping for same-architecture teacher-student pairs (e.g. pruned LLaMA -> original LLaMA)
- Flex — relative-depth layer mapping with adaptive pooling for cross-architecture distillation (e.g. pruned SmolLM -> Mistral teacher), supporting mismatched hidden dimensions and layer counts
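The relative-depth idea behind flex mode can be sketched in a few lines — hypothetical helper, not the repository's API; hidden-size mismatch would additionally need pooling or a projection, omitted here:

```python
def flex_layer_map(num_student_layers, num_teacher_layers):
    """Pair each student layer with the teacher layer at the same relative depth.

    Student layer i sits at depth i/(S-1); it is matched to the teacher layer
    nearest to that fraction of the teacher's depth.
    """
    if num_student_layers == 1:
        return {0: num_teacher_layers - 1}
    scale = (num_teacher_layers - 1) / (num_student_layers - 1)
    return {i: round(i * scale) for i in range(num_student_layers)}
```

For example, a 4-layer student against an 8-layer teacher maps layers `{0: 0, 1: 2, 2: 5, 3: 7}`, so first and last layers always align and the interior is spread evenly.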
```
model_opt/
├── config.py             # Nested dataclass config with YAML serialization
├── analysis/             # Stage 1: Data collection
│   ├── collector.py      # Catcher-based forward-hook activation capture
│   ├── scores.py         # Hessian accumulator + activation scoring
│   └── tracker.py        # Per-neuron metric storage and serialization
├── analytics/            # Stage 2: Hierarchical importance analysis
│   ├── importance.py     # Neuron -> channel -> sublayer -> block scoring
│   └── visualizer.py     # Heatmaps, distributions, sparsity profiles
├── pruning/              # Stage 3: Pruning algorithms
│   ├── block.py          # Greedy block removal with PPL+KL gating
│   ├── channel.py        # MLP channel pruning (greedy + evaluation)
│   ├── mlp.py            # Layer-aligned structural MLP pruner
│   ├── sparsity.py       # N:M structured sparsity manager
│   ├── masks.py          # Bit-packed mask compression + optimizer hooks
│   └── prep.py           # Neuron categorization + position selection
├── optimization/         # Stage 4: Combinatorial search
│   └── tabu_search.py    # Tabu search with adaptive tenure + elite memory
├── quantization/         # Post-training quantization
│   ├── gptq.py           # GPTQ: Hessian-weighted weight quantization
│   └── smooth_quant.py   # SmoothQuant: activation-aware weight smoothing
├── distillation/         # Stage 5: Knowledge distillation
│   ├── loss.py           # KL + standard/flex SquareHead losses
│   ├── trainer.py        # HuggingFace Trainer-based distillation loop
│   ├── distributed.py    # torch.distributed setup + model factories
│   ├── checkpoint.py     # RNG-preserving checkpoint save/load
│   ├── training_utils.py # Convergence detection + ETA estimation
│   ├── data.py           # Pretrain + SFT dataloaders with BOS/EOS control
│   └── packing.py        # Sequence packing with block-diagonal attention
├── utils/                # Shared infrastructure
│   ├── model_loader.py   # HuggingFace model loading + device mapping
│   ├── data_utils.py     # Calibration dataset preparation
│   ├── eval_utils.py     # Perplexity + KL divergence computation
│   ├── model_utils.py    # Architecture introspection + layer navigation
│   ├── materialize.py    # Lazy tensor materialization + chunked transfers
│   └── logging.py        # Tee-logger (terminal + file)
└── cli/                  # Command-line interface
    ├── analyze.py        # model-opt analyze --config ...
    ├── prune.py          # model-opt prune --config ...
    ├── optimize.py       # model-opt optimize --config ...
    └── distill.py        # model-opt distill --config ...
```
`model-opt` is not a loose collection of pruning utilities. The core idea is that LLM compression should be treated as an optimization problem over structure, not just as a one-shot thresholding step.
This repository demonstrates three things:
- Research framing: pruning quality depends on the search procedure, not only on the saliency metric.
- Algorithm design: hierarchical scoring and combinatorial search can produce stronger pruning decisions than single-pass heuristics alone.
- Systems engineering: large-model compression research needs reliable infrastructure for artifact flow, architecture-aware transformations, distributed recovery, and reproducible experimentation.
It is intended for:
- Researchers in pruning and distillation who want an executable framework for studying saliency, structural pruning, search, and recovery in one place.
- Engineers working on LLM efficiency who need architecture-aware utilities for modern transformer models and want to inspect the full artifact flow.
- Readers interested in operations research ideas in ML systems who want to see a concrete application of tabu search to a large, structured model optimization problem.
It is less suitable for users looking for a push-button compression library with pre-baked recipes for every architecture.
| Decision | Rationale |
|---|---|
| Platform-agnostic | No hard dependency on DeepSpeed, NCCL, or HPC infrastructure. Distributed training uses `torch.distributed`; DeepSpeed is available as an optional extra. |
| Pruning as search | Greedy pruning provides an initial feasible solution, but not necessarily a strong one. Tabu search explores the neighborhood to find better configurations at the same compression budget. |
| Bit-packed masks | Sparsity masks compressed to int64 bit-fields (~64x reduction). Enforced via post-optimizer hooks that survive gradient updates. |
| Apply-restore evaluation | Tabu search evaluates candidate pruning states by temporarily modifying weights in-place and restoring them, avoiding O(GB) model copies. |
| Hierarchical importance | Four-level analysis (neuron -> channel -> sublayer -> block) enables each pruning technique to operate on the abstraction level natural to it. |
| Cross-architecture distillation | Flex SquareHead loss uses relative-depth layer mapping + adaptive pooling, enabling teacher-student pairs with different hidden dimensions and layer counts. |
| Dataclass configuration | Type-safe, validated, YAML-serializable. No magic strings or unvalidated dicts. |
- Architecture coverage. Tested on LLaMA, Mistral, Qwen, and SmolLM-family models. Novel architectures may require minor extensions in `model_utils.py`.
- Hardware requirements. Hessian accumulation and tabu search require significant GPU memory. For models >= 7B parameters, 40GB+ GPUs are recommended for Stages 1-4. Distillation benefits from multi-GPU setups via `torchrun`.
- Reproducibility. TF32 is disabled during Hessian computation. All RNG states are preserved in checkpoints for exact training resumption.
- Compression as combinatorial optimization (`optimization/tabu_search.py`) — searches over pruning configurations using tabu search with apply-restore evaluation, adaptive tenure, and elite-state memory.
- Hierarchical analysis tied to executable pruning (`analytics/`, `pruning/`) — importance modeled across neuron, channel, sublayer, and block levels, translated into architecture-aware pruning actions.
- Research infrastructure for large-model recovery (`distillation/`, `utils/`) — distributed recovery, reproducibility-focused checkpoint state, and memory-aware materialization/loading paths.
- SquareHead distillation loss — adapted from IST-DASLab/SparseFinetuning, extended with Flex SquareHead for cross-architecture teacher-student pairs.
- GPTQ, SmoothQuant, and sequence packing — adapted from vllm-project/llm-compressor.
| Package | Status | Purpose |
|---|---|---|
| `torch >= 2.0` | Required | Tensor operations, autograd, distributed |
| `transformers >= 4.35` | Required | Model loading, tokenizers, Trainer |
| `pandas` | Required | Neuron/layer metric data processing |
| `numpy` | Required | Statistical computations, SVD |
| `pyyaml` | Required | Configuration serialization |
| `tqdm` | Required | Progress bars |
| `accelerate` | Required | Device mapping, model parallelism |
| `datasets` | Optional (`[data]`) | Streaming dataloaders for distillation |
| `scipy`, `scikit-learn` | Optional (`[analytics]`) | Outlier detection, clustering |
| `matplotlib`, `seaborn` | Optional (`[viz]`) | Importance heatmaps, distribution plots |
| `deepspeed` | Optional (`[distill]`) | ZeRO-based memory-efficient distillation |
If you use this work in your research, please cite:
```bibtex
@inproceedings{Dickmann_2025_CUG,
  title={Evaluating AMD Instinct™ MI300A APU: Performance Insights on LLM Training via Knowledge Distillation},
  author={Dickmann, Daniel and Offenhäuser, Philipp and Saxena, Rishabh and Markomanolis, George and Rigazzi, Alessandro and Keller, Patrick and Kayabay, Kerem and Hoppe, Dennis},
  booktitle={Proceedings of the Cray User Group (CUG '25)},
  publisher={ACM},
  year={2025},
  month={may},
  pages={115--126},
  doi={10.1145/3757348.3757361},
  url={http://dx.doi.org/10.1145/3757348.3757361}
}
```