
model-opt


A research framework for LLM compression that treats pruning as combinatorial optimization — combining hierarchical importance analysis, architecture-aware structured pruning, tabu search, and knowledge distillation in a single pipeline.


> [!IMPORTANT]
> This is research software. It is designed for researchers and engineers studying LLM compression, not for one-click model optimization. The repository implements measurement, analysis, pruning, search, and recovery as a connected pipeline so compression strategies can be studied end-to-end.

Features

  • Hierarchical importance analysis — score importance at four levels (neuron, channel, sublayer, block) using Hessian sensitivity, activation magnitude, and SVD-based signal-to-noise ratio
  • Three structured pruning techniques — block removal, MLP channel pruning (SwiGLU-aligned), and N:M structured sparsity with bit-packed masks
  • Tabu search over pruning configurations — escape local optima via adaptive tenure, elite memory, and apply-restore evaluation (no model copies)
  • Knowledge distillation — CE + KL + SquareHead hidden-state matching, with flex mode for cross-architecture teacher-student pairs
  • GPTQ and SmoothQuant — post-training quantization adapted from llmcompressor
  • Single YAML config — one file controls model loading, data, analysis thresholds, pruning budgets, search, and distillation
  • CLI and library — use as model-opt analyze|prune|optimize|distill --config ... or import modules directly

Quick Start

Prerequisites: Python >= 3.10, PyTorch >= 2.0, Transformers >= 4.35

git clone https://github.com/ddickmann/llm-opt.git
cd llm-opt
pip install -e .

Install with all optional dependencies:

pip install -e ".[all,dev]"

Or pick what you need:

pip install -e ".[data]"         # + datasets, safetensors
pip install -e ".[analytics]"    # + scipy, scikit-learn, tabulate
pip install -e ".[viz]"          # + matplotlib, seaborn
pip install -e ".[distill]"      # + deepspeed

Run the test suite:

make test

Pipeline

The framework implements a five-stage pipeline. Each stage is independently executable and produces serialized artifacts consumed by downstream stages.

 ┌──────────────┐   ┌──────────────┐   ┌──────────────┐   ┌──────────────┐   ┌──────────────┐
 │   COLLECT    │──>│   ANALYZE    │──>│    PRUNE     │──>│   OPTIMIZE   │──>│   DISTILL    │
 │              │   │              │   │              │   │              │   │              │
 │ Hessian      │   │ 4-level      │   │ Block        │   │ Tabu search  │   │ CE + KL +    │
 │ accumulation,│   │ importance   │   │ removal,     │   │ over pruning │   │ SquareHead   │
 │ activation   │   │ scoring,     │   │ MLP channel, │   │ configs,     │   │ hidden-state │
 │ capture,     │   │ neuron       │   │ N:M          │   │ adaptive     │   │ matching     │
 │ SVD-based    │   │ classifi-    │   │ structured   │   │ tenure,      │   │              │
 │ SNR          │   │ cation       │   │ sparsity     │   │ elite memory │   │              │
 └──────────────┘   └──────────────┘   └──────────────┘   └──────────────┘   └──────────────┘

CLI Usage

model-opt analyze  --config configs/example.yaml
model-opt prune    --config configs/example.yaml
model-opt optimize --config configs/example.yaml
model-opt distill  --config configs/example.yaml

Library Usage

from model_opt.analysis.collector import run_collection
from model_opt.analysis.tracker import load_tracker
from model_opt.analytics.importance import analyze_model, MetricsConfig
from model_opt.pruning.block import iterative_block_pruning
from model_opt.pruning.mlp import LayerAlignedMLPPruner
from model_opt.optimization.tabu_search import TabuSearchPruner

Each stage produces serialized outputs (pickle or safetensors) that can be inspected, modified, or fed into custom downstream code.

Configuration

All parameters are managed through nested dataclasses with YAML serialization and type validation. See configs/example.yaml for a complete reference.

model:
  name_or_path: "meta-llama/Llama-3.1-8B-Instruct"
  torch_dtype: "bfloat16"

data:
  dataset_name: "HuggingFaceH4/ultrachat_200k"
  nsamples: 128
  seqlen: 2048

pruning:
  block:
    max_blocks_to_prune: 4
    protected_blocks: [0, 1, 30, 31]
  channel:
    num_groups: 32
  sparsity: { n: 2, m: 4, kl_threshold: 0.01 }

optimization:
  max_iterations: 500
  tabu_tenure: 25
  neighborhood_size: 8

distillation:
  teacher: { name_or_path: "meta-llama/Llama-3.1-8B-Instruct" }
  temperature: 3.0
  hardness_ce: 1.0
  hardness_kldiv: 1.5
  hardness_squarehead: 4.0
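As a sketch of the nested-dataclass pattern (the class and field names below are hypothetical stand-ins; the real definitions live in model_opt/config.py), a parsed YAML dict can be validated into typed config objects roughly like this:

```python
# Hypothetical sketch of nested-dataclass config validation; the real classes
# in model_opt/config.py may differ in names and fields.
from dataclasses import dataclass, field, fields, is_dataclass

@dataclass
class BlockPruningConfig:
    max_blocks_to_prune: int = 4
    protected_blocks: list = field(default_factory=list)

@dataclass
class PruningConfig:
    block: BlockPruningConfig = field(default_factory=BlockPruningConfig)

def from_dict(cls, data: dict):
    """Recursively build a dataclass from a parsed-YAML dict, rejecting unknown keys."""
    known = {f.name: f for f in fields(cls)}
    unknown = set(data) - set(known)
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")
    kwargs = {}
    for name, value in data.items():
        ftype = known[name].type
        if isinstance(ftype, type) and is_dataclass(ftype):
            value = from_dict(ftype, value)  # descend into nested section
        kwargs[name] = value
    return cls(**kwargs)

# As if produced by yaml.safe_load on the pruning section above:
raw = {"block": {"max_blocks_to_prune": 2, "protected_blocks": [0, 1, 30, 31]}}
cfg = from_dict(PruningConfig, raw)
```

Rejecting unknown keys at load time is what turns a typo'd YAML field into an immediate error instead of a silently ignored setting.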

Methodology

Stage 1 — Data Collection (analysis/)

Register forward hooks on every linear sublayer. During calibration-set inference, accumulate per-neuron activation statistics, blockwise Hessian matrices (via activation outer products), and SVD-based signal-to-noise ratios. All data is captured to CPU and serialized in a DataTracker for offline analysis.

Key implementation detail: Hessian inversion uses regularized blockwise Cholesky decomposition with 8 fallback strategies of increasing regularization strength. TF32 is disabled during Hessian computation for numerical reproducibility.
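To make the fallback idea concrete, here is a minimal sketch (not the repository's exact ladder) of Cholesky factorization with escalating diagonal regularization:

```python
import numpy as np

def robust_cholesky(H: np.ndarray, base_damp: float = 1e-6, max_tries: int = 8):
    """Factor a possibly ill-conditioned Hessian as L @ L.T, escalating the
    diagonal damping 10x per attempt until factorization succeeds.
    Illustrative sketch; the repository's 8 fallback strategies may differ."""
    mean_diag = float(np.mean(np.diag(H))) or 1.0
    for attempt in range(max_tries):
        damp = base_damp * (10.0 ** attempt) * mean_diag
        try:
            L = np.linalg.cholesky(H + damp * np.eye(H.shape[0]))
            return L, damp
        except np.linalg.LinAlgError:
            continue  # not positive definite yet: increase damping
    raise np.linalg.LinAlgError("Hessian not factorizable at maximum damping")

# A rank-deficient Hessian (4 calibration samples, 8 dims) benefits from damping:
X = np.random.default_rng(0).normal(size=(4, 8))
H = X.T @ X
L, damp = robust_cholesky(H)
```

Scaling the damping by the mean diagonal keeps the regularization proportional to the Hessian's own magnitude rather than an absolute constant.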

Stage 2 — Hierarchical Importance Analysis (analytics/)

Score every neuron at four hierarchical levels, each informing a different pruning strategy:

| Level    | Metrics                                                              | Downstream use                |
|----------|----------------------------------------------------------------------|-------------------------------|
| Neuron   | Hessian sensitivity, activation magnitude (robust/rank/outlier), SNR | N:M sparsity patterns         |
| Channel  | Group importance, adjacency influence propagation                    | MLP channel pruning positions |
| Sublayer | Aggregated q/k/v/gate/up/down projection scores                      | Per-sublayer pruning budget   |
| Block    | Full-layer importance + positional weighting                         | Block removal candidates      |

Importance scoring uses a weighted harmonic mean of per-neuron metrics, coefficient-of-variation consistency, and first-order adjacency propagation for network-level impact estimation.
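The choice of a harmonic (rather than arithmetic) mean matters: it is dominated by the weakest metric, so a neuron must score well on every signal to rank as important. A minimal sketch (metric columns and weights here are illustrative):

```python
import numpy as np

def harmonic_importance(metrics: np.ndarray, weights: np.ndarray, eps: float = 1e-8):
    """Weighted harmonic mean over per-neuron metric columns.
    Illustrative sketch of the aggregation, not the repository's exact formula."""
    w = weights / weights.sum()
    return 1.0 / np.sum(w / (metrics + eps), axis=1)

# 3 neurons x 3 metrics (columns: e.g. Hessian sensitivity, magnitude, SNR)
m = np.array([[0.9, 0.8, 0.7],
              [0.9, 0.9, 0.05],   # one weak signal drags the whole score down
              [0.5, 0.5, 0.5]])
scores = harmonic_importance(m, np.array([1.0, 1.0, 1.0]))
```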

Neuron classification uses robust z-score outlier detection with adaptive thresholding (auto-adjusts to maintain 10–30% outlier rate regardless of distribution shape):

  • Important — high outliers; preserved during pruning
  • Redundant — low outliers; highest pruning priority
  • Ambiguous — resolved via block-hierarchy labels from the analysis stage
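The adaptive-threshold mechanic can be sketched as follows (the target band and update factors are illustrative, not the repository's exact values):

```python
import numpy as np

def classify_neurons(scores, z0=3.0, target=(0.10, 0.30), steps=20):
    """Robust (median/MAD) z-score outlier detection with a threshold that
    adapts until the outlier rate lands in the target band. Sketch only."""
    med = np.median(scores)
    mad = np.median(np.abs(scores - med)) + 1e-12
    z = 0.6745 * (scores - med) / mad   # 0.6745 makes MAD consistent with sigma
    thr = z0
    for _ in range(steps):
        rate = np.mean(np.abs(z) > thr)
        if rate < target[0]:
            thr *= 0.9                  # too few outliers: loosen
        elif rate > target[1]:
            thr *= 1.1                  # too many outliers: tighten
        else:
            break
    labels = np.where(z > thr, "important",
             np.where(z < -thr, "redundant", "ambiguous"))
    return labels, thr
```

Because median and MAD are insensitive to extreme values, this classification stays stable even when activation statistics are heavy-tailed.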

Stage 3 — Structured Pruning (pruning/)

Block pruning — Remove entire transformer layers. Acceptance is gated by perplexity increase and KL divergence against precomputed teacher logits. First and last blocks are protectable via configuration.

MLP channel pruning — Zero-out neuron channels. The LayerAlignedMLPPruner enforces that gate_proj and up_proj share identical pruning positions within each layer (required by SwiGLU-style architectures), then structurally removes the corresponding down_proj input dimensions.
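The alignment constraint reduces to a shape transformation: remove the same rows from gate_proj and up_proj, and the matching columns from down_proj, so down(silu(gate(x)) * up(x)) stays well-formed. A sketch with NumPy (weight layouts follow the usual [out_features, in_features] convention; the function name is hypothetical):

```python
import numpy as np

def prune_swiglu_mlp(gate, up, down, keep_channels):
    """Structurally prune a SwiGLU MLP: gate/up lose the same output rows,
    down loses the matching input columns. Illustrative sketch only."""
    keep = np.asarray(sorted(keep_channels))
    return gate[keep, :], up[keep, :], down[:, keep]

hidden, inter = 8, 16
rng = np.random.default_rng(0)
gate = rng.normal(size=(inter, hidden))   # [intermediate, hidden]
up   = rng.normal(size=(inter, hidden))
down = rng.normal(size=(hidden, inter))   # [hidden, intermediate]

# Keep the even intermediate channels, structurally dropping the odd ones:
g2, u2, d2 = prune_swiglu_mlp(gate, up, down, range(0, inter, 2))
```

If gate_proj and up_proj were pruned at different positions, the elementwise product inside SwiGLU would mix unrelated channels, which is why the pruner enforces identical positions per layer.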

N:M structured sparsity — Hessian-guided weight selection within N:M blocks (e.g. 2:4). Sparsity masks are bit-packed into int64 tensors (~64x compression vs. float masks) and enforced via post-optimizer hooks. Layer acceptance is gated by per-task KL divergence thresholds.
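The two mechanics involved, sketched with plain magnitude selection standing in for the Hessian-guided scores and NumPy standing in for the repository's int64 tensors:

```python
import numpy as np

def nm_mask(weights: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Keep the n largest-magnitude weights in each group of m (e.g. 2:4).
    Magnitude-based sketch; the repository uses Hessian-guided selection."""
    groups = weights.reshape(-1, m)
    keep = np.argsort(-np.abs(groups), axis=1)[:, :n]
    mask = np.zeros_like(groups, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=1)
    return mask.reshape(weights.shape)

def pack_mask(mask: np.ndarray) -> np.ndarray:
    """Bit-pack a boolean mask into 64-bit words (~64x smaller than float32)."""
    flat = mask.ravel()
    pad = (-flat.size) % 64                       # pad to a whole word
    padded = np.concatenate([flat, np.zeros(pad, dtype=bool)])
    return np.packbits(padded).view(np.uint64)

W = np.random.default_rng(0).normal(size=(4, 16))
mask = nm_mask(W)            # exactly 2 of every 4 weights survive
packed = pack_mask(mask)     # 64 mask bits -> one 64-bit word
```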

Stage 4 — Combinatorial Optimization (optimization/)

The greedy solution from Stage 3 is a local optimum. The tabu search explores the neighborhood of that solution to find better configurations:

  • Move generators — block swap (add/remove transformer layers), channel redistribution (shift pruning positions between layers)
  • Adaptive tabu tenure — automatically intensifies search near promising regions, diversifies when stagnating
  • LRU state cache — avoids re-evaluating previously seen configurations
  • Elite solution archive — bounded archive of the k-best solutions found during the search
  • Apply-restore evaluation — evaluates candidate states by temporarily modifying weights in-place and restoring them, avoiding O(GB) deepcopy per candidate
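The apply-restore trick in one small sketch (a NumPy array stands in for model weights; the helper name is hypothetical):

```python
import numpy as np
from contextlib import contextmanager

@contextmanager
def apply_restore(weight: np.ndarray, rows):
    """Temporarily zero the given rows of a weight matrix, then restore them.
    Only the touched slice is saved, so evaluating a candidate pruning state
    never copies the full model. Illustrative sketch of the idea."""
    saved = weight[rows].copy()    # O(slice) memory, not O(model)
    weight[rows] = 0.0
    try:
        yield weight               # evaluate the candidate state here
    finally:
        weight[rows] = saved       # exact restore, even if evaluation raises

W = np.ones((6, 4))
with apply_restore(W, [1, 3]):
    inside = W.sum()               # e.g. compute perplexity of the candidate
```

The try/finally guarantee is what makes this safe inside a long tabu search: a crashed evaluation cannot leave the model in a half-pruned state.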

Stage 5 — Knowledge Distillation (distillation/)

Post-pruning performance recovery via three complementary loss signals:

| Loss          | Signal                                         | When to use            |
|---------------|------------------------------------------------|------------------------|
| Cross-entropy | Standard LM loss on training tokens            | Always                 |
| KL divergence | Temperature-scaled logit distribution matching | Always (with teacher)  |
| SquareHead    | Intermediate hidden-state matching             | Architecture-dependent |
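For reference, the temperature-scaled KL term can be sketched as follows, under standard knowledge-distillation conventions rather than as a copy of distillation/loss.py (the T^2 factor keeps gradient magnitudes comparable across temperatures; the hardness_* config weights then mix the three losses):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerically stable
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def kd_kl(student_logits, teacher_logits, T=3.0):
    """Temperature-scaled KL(teacher || student), averaged over positions.
    Sketch of the standard formulation; not the repository's exact code."""
    p = softmax(teacher_logits / T)
    q = softmax(student_logits / T)
    return T * T * np.sum(p * (np.log(p) - np.log(q)), axis=-1).mean()
```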

SquareHead operates in two modes:

  • Standard — 1:1 layer mapping for same-architecture teacher-student pairs (e.g. pruned LLaMA -> original LLaMA)
  • Flex — relative-depth layer mapping with adaptive pooling for cross-architecture distillation (e.g. pruned SmolLM -> Mistral teacher), supporting mismatched hidden dimensions and layer counts
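The relative-depth mapping behind flex mode can be sketched in a few lines (the function name is hypothetical; the adaptive pooling over mismatched hidden dimensions lives in distillation/loss.py):

```python
def flex_layer_map(num_student: int, num_teacher: int) -> list[int]:
    """Map each student layer to the teacher layer at the same relative depth.
    Illustrative sketch of the flex-mode idea, not the repository's exact code."""
    if num_student == 1:
        return [num_teacher - 1]   # a single student layer matches the top
    return [round(i * (num_teacher - 1) / (num_student - 1))
            for i in range(num_student)]

# A 12-layer pruned student distilling from a 32-layer teacher:
mapping = flex_layer_map(12, 32)
```

First maps to first and last maps to last, with intermediate student layers spread monotonically across the teacher's depth.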

Project Structure

model_opt/
├── config.py                  # Nested dataclass config with YAML serialization
├── analysis/                  # Stage 1: Data collection
│   ├── collector.py           #   Catcher-based forward-hook activation capture
│   ├── scores.py              #   Hessian accumulator + activation scoring
│   └── tracker.py             #   Per-neuron metric storage and serialization
├── analytics/                 # Stage 2: Hierarchical importance analysis
│   ├── importance.py          #   Neuron -> channel -> sublayer -> block scoring
│   └── visualizer.py          #   Heatmaps, distributions, sparsity profiles
├── pruning/                   # Stage 3: Pruning algorithms
│   ├── block.py               #   Greedy block removal with PPL+KL gating
│   ├── channel.py             #   MLP channel pruning (greedy + evaluation)
│   ├── mlp.py                 #   Layer-aligned structural MLP pruner
│   ├── sparsity.py            #   N:M structured sparsity manager
│   ├── masks.py               #   Bit-packed mask compression + optimizer hooks
│   └── prep.py                #   Neuron categorization + position selection
├── optimization/              # Stage 4: Combinatorial search
│   └── tabu_search.py         #   Tabu search with adaptive tenure + elite memory
├── quantization/              # Post-training quantization
│   ├── gptq.py                #   GPTQ: Hessian-weighted weight quantization
│   └── smooth_quant.py        #   SmoothQuant: activation-aware weight smoothing
├── distillation/              # Stage 5: Knowledge distillation
│   ├── loss.py                #   KL + standard/flex SquareHead losses
│   ├── trainer.py             #   HuggingFace Trainer-based distillation loop
│   ├── distributed.py         #   torch.distributed setup + model factories
│   ├── checkpoint.py          #   RNG-preserving checkpoint save/load
│   ├── training_utils.py      #   Convergence detection + ETA estimation
│   ├── data.py                #   Pretrain + SFT dataloaders with BOS/EOS control
│   └── packing.py             #   Sequence packing with block-diagonal attention
├── utils/                     # Shared infrastructure
│   ├── model_loader.py        #   HuggingFace model loading + device mapping
│   ├── data_utils.py          #   Calibration dataset preparation
│   ├── eval_utils.py          #   Perplexity + KL divergence computation
│   ├── model_utils.py         #   Architecture introspection + layer navigation
│   ├── materialize.py         #   Lazy tensor materialization + chunked transfers
│   └── logging.py             #   Tee-logger (terminal + file)
└── cli/                       # Command-line interface
    ├── analyze.py             #   model-opt analyze --config ...
    ├── prune.py               #   model-opt prune --config ...
    ├── optimize.py            #   model-opt optimize --config ...
    └── distill.py             #   model-opt distill --config ...

Positioning

model-opt is not a loose collection of pruning utilities. The core idea is that LLM compression should be treated as an optimization problem over structure, not just as a one-shot thresholding step.

This repository demonstrates three things:

  • Research framing: pruning quality depends on the search procedure, not only on the saliency metric.
  • Algorithm design: hierarchical scoring and combinatorial search can produce stronger pruning decisions than single-pass heuristics alone.
  • Systems engineering: large-model compression research needs reliable infrastructure for artifact flow, architecture-aware transformations, distributed recovery, and reproducible experimentation.

Who This Is For

  • Researchers in pruning and distillation who want an executable framework for studying saliency, structural pruning, search, and recovery in one place.
  • Engineers working on LLM efficiency who need architecture-aware utilities for modern transformer models and want to inspect the full artifact flow.
  • Readers interested in operations research ideas in ML systems who want to see a concrete application of tabu search to a large, structured model optimization problem.

It is less suitable for users looking for a push-button compression library with pre-baked recipes for every architecture.

Design Decisions

| Decision | Rationale |
|----------|-----------|
| Platform-agnostic | No hard dependency on DeepSpeed, NCCL, or HPC infrastructure. Distributed training uses torch.distributed; DeepSpeed is available as an optional extra. |
| Pruning as search | Greedy pruning provides an initial feasible solution, but not necessarily a strong one. Tabu search explores the neighborhood to find better configurations at the same compression budget. |
| Bit-packed masks | Sparsity masks are compressed to int64 bit-fields (~64x reduction) and enforced via post-optimizer hooks that survive gradient updates. |
| Apply-restore evaluation | Tabu search evaluates candidate pruning states by temporarily modifying weights in place and restoring them, avoiding O(GB) model copies. |
| Hierarchical importance | Four-level analysis (neuron -> channel -> sublayer -> block) lets each pruning technique operate at the abstraction level natural to it. |
| Cross-architecture distillation | The flex SquareHead loss uses relative-depth layer mapping plus adaptive pooling, enabling teacher-student pairs with different hidden dimensions and layer counts. |
| Dataclass configuration | Type-safe, validated, YAML-serializable. No magic strings or unvalidated dicts. |

Scope and Limitations

  • Architecture coverage. Tested on LLaMA, Mistral, Qwen, and SmolLM-family models. Novel architectures may require minor extensions in model_utils.py.
  • Hardware requirements. Hessian accumulation and tabu search require significant GPU memory. For models >= 7B parameters, 40GB+ GPUs are recommended for Stages 1-4. Distillation benefits from multi-GPU setups via torchrun.
  • Reproducibility. TF32 is disabled during Hessian computation. All RNG states are preserved in checkpoints for exact training resumption.

Original Contributions

  • Compression as combinatorial optimization (optimization/tabu_search.py) — searches over pruning configurations using tabu search with apply-restore evaluation, adaptive tenure, and elite-state memory.
  • Hierarchical analysis tied to executable pruning (analytics/, pruning/) — importance modeled across neuron, channel, sublayer, and block levels, translated into architecture-aware pruning actions.
  • Research infrastructure for large-model recovery (distillation/, utils/) — distributed recovery, reproducibility-focused checkpoint state, and memory-aware materialization/loading paths.

Attributions

Dependencies

| Package | Status | Purpose |
|---------|--------|---------|
| torch >= 2.0 | Required | Tensor operations, autograd, distributed |
| transformers >= 4.35 | Required | Model loading, tokenizers, Trainer |
| pandas | Required | Neuron/layer metric data processing |
| numpy | Required | Statistical computations, SVD |
| pyyaml | Required | Configuration serialization |
| tqdm | Required | Progress bars |
| accelerate | Required | Device mapping, model parallelism |
| datasets | Optional ([data]) | Streaming dataloaders for distillation |
| scipy, scikit-learn | Optional ([analytics]) | Outlier detection, clustering |
| matplotlib, seaborn | Optional ([viz]) | Importance heatmaps, distribution plots |
| deepspeed | Optional ([distill]) | ZeRO-based memory-efficient distillation |

Citation

If you use this work in your research, please cite:

@inproceedings{Dickmann_2025_CUG,
  title={Evaluating AMD Instinct™ MI300A APU: Performance Insights on LLM Training via Knowledge Distillation},
  author={Dickmann, Daniel and Offenhäuser, Philipp and Saxena, Rishabh and Markomanolis, George and Rigazzi, Alessandro and Keller, Patrick and Kayabay, Kerem and Hoppe, Dennis},
  booktitle={Proceedings of the Cray User Group (CUG '25)},
  publisher={ACM},
  year={2025},
  month={may},
  pages={115--126},
  doi={10.1145/3757348.3757361},
  url={http://dx.doi.org/10.1145/3757348.3757361}
}

License

Apache 2.0
