# ShortKit-ML

Detect and mitigate shortcuts and biases in machine learning embedding spaces: 20+ detection and mitigation methods behind a unified API, multi-attribute support for testing several sensitive attributes simultaneously, and a Model Comparison mode for benchmarking multiple embedding models.
- Overview
- Installation
- Quick Start
- Detection Methods
- Overall Assessment Conditions
- MCP Server
- Paper Benchmarks
- Reproducing Paper Results
- GPU Support
- Interactive Dashboard
- Testing
- Contributing
- Citation
## Overview

ShortKit-ML provides a comprehensive toolkit for detecting and mitigating shortcuts (unwanted biases) in embedding spaces:
- 20+ detection methods: HBAC, Probe, Statistical, Geometric, Bias Direction PCA, Equalized Odds, Demographic Parity, Intersectional, GroupDRO, GCE, Causal Effect, SSA, SIS, CAV, VAE, Early-Epoch Clustering, and more
- 6 mitigation methods: Shortcut Masking, Background Randomization, Adversarial Debiasing, Explanation Regularization, Last Layer Retraining, Contrastive Debiasing
- 5 pluggable risk conditions: indicator_count, majority_vote, weighted_risk, multi_attribute, meta_classifier
Key Features:
- Unified `ShortcutDetector` API for all methods
- Interactive Gradio dashboard with real-time analysis
- PDF/HTML/Markdown reports with visualizations
- Embedding-only mode (no model access needed)
- Multi-attribute support: test race, gender, age simultaneously
- Model Comparison mode: compare multiple embedding models side-by-side
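Multi-attribute analysis ultimately operates on per-sample attribute arrays. As an illustration only (plain numpy, not ShortKit-ML's API), intersectional groups can be formed by crossing two attributes:

```python
import numpy as np

# Hypothetical per-sample sensitive attributes (invented for this sketch).
race = np.array(["Black", "White", "Black", "White"])
gender = np.array(["Male", "Male", "Female", "Female"])

# Cross the two attributes into intersectional group labels --
# the kind of grouping an intersectional fairness check tests over.
intersection = np.char.add(np.char.add(race, "/"), gender)
groups, counts = np.unique(intersection, return_counts=True)
print(groups.tolist())
```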
## Installation

Available on PyPI at [pypi.org/project/shortkit-ml](https://pypi.org/project/shortkit-ml):

```bash
pip install shortkit-ml
```

For all optional extras (dashboard, reporting, VAE, HuggingFace, MCP, etc.):

```bash
pip install "shortkit-ml[all]"
```

From source:

```bash
git clone https://github.com/criticaldata/ShortKit-ML.git
cd ShortKit-ML
pip install -e ".[all]"
```

Or with uv:

```bash
uv venv --python 3.10
source .venv/bin/activate  # Windows: .venv\Scripts\activate
uv pip install -e ".[all]"
```

PDF report export requires system libraries:

```bash
# macOS
brew install pango gdk-pixbuf libffi

# Ubuntu/Debian
sudo apt-get install libpango-1.0-0 libpangocairo-1.0-0 libgdk-pixbuf2.0-0
```

HTML and Markdown reports work without these; PDF export is optional.
## Quick Start

```python
from shortcut_detect import ShortcutDetector
import numpy as np

embeddings = np.load("embeddings.npy")  # (n_samples, embedding_dim)
labels = np.load("labels.npy")          # (n_samples,)

detector = ShortcutDetector(methods=['hbac', 'probe', 'statistical', 'geometric', 'equalized_odds'])
detector.fit(embeddings, labels)
detector.generate_report("report.html", format="html")
print(detector.summary())
```

For closed-source models or systems that only expose embeddings:

```python
from shortcut_detect import ShortcutDetector, HuggingFaceEmbeddingSource

hf_source = HuggingFaceEmbeddingSource(model_name="sentence-transformers/all-MiniLM-L6-v2")
detector = ShortcutDetector(methods=["probe", "statistical"])
detector.fit(embeddings=None, labels=labels, group_labels=groups,
             raw_inputs=texts, embedding_source=hf_source)
```

See the Embedding-Only Guide for `CallableEmbeddingSource` and caching options.
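A callable embedding source wraps any function that maps raw inputs to a 2-D embedding array. As a self-contained illustration of that shape contract, here is a hypothetical hash-based encoder (`toy_embed` is invented for this sketch; see the guide for the real wrapper):

```python
import hashlib
import numpy as np

def toy_embed(texts, dim=32):
    """Hypothetical stand-in encoder: hash each text into a deterministic
    fixed-size vector. Any callable with this contract
    (list of n inputs -> array of shape (n, dim)) could back a source."""
    out = np.empty((len(texts), dim), dtype=np.float32)
    for i, text in enumerate(texts):
        digest = hashlib.sha256(text.encode()).digest()
        seed = int.from_bytes(digest[:8], "little")
        out[i] = np.random.default_rng(seed).standard_normal(dim)
    return out

emb = toy_embed(["a chest x-ray", "another report"])
print(emb.shape)  # (2, 32)
```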
## Detection Methods

| Method | Key | What It Detects | Reference |
|---|---|---|---|
| HBAC | `hbac` | Clustering by protected attributes | - |
| Probe | `probe` | Group info recoverable from embeddings | - |
| Statistical | `statistical` | Dimensions with group differences | - |
| Geometric | `geometric` | Bias directions & prototype overlap | - |
| Bias Direction PCA | `bias_direction_pca` | Projection gap along bias direction | Bolukbasi 2016 |
| Equalized Odds | `equalized_odds` | TPR/FPR disparities | Hardt 2016 |
| Demographic Parity | `demographic_parity` | Prediction rate disparities | Feldman 2015 |
| Early Epoch Clustering | `early_epoch_clustering` | Shortcut reliance in early representations | Yang 2023 |
| GCE | `gce` | High-loss minority samples | - |
| Frequency | `frequency` | Signal concentrated in few dimensions | - |
| GradCAM Mask Overlap | `gradcam_mask_overlap` | Attention overlap with shortcut masks | - |
| SpRAy | `spray` | Spectral clustering of heatmaps | Lapuschkin 2019 |
| CAV | `cav` | Concept-level sensitivity | Kim 2018 |
| Causal Effect | `causal_effect` | Spurious attribute influence | - |
| VAE | `vae` | Latent disentanglement signatures | - |
| SSA | `ssa` | Semi-supervised spectral shift | arXiv:2204.02070 |
| Generative CVAE | `generative_cvae` | Counterfactual embedding shifts | - |
| GroupDRO | `groupdro` | Worst-group performance gaps | Sagawa 2020 |
| SIS | `sis` | Sufficient input subsets (minimal dims for prediction) | Carter 2019 |
| Intersectional | `intersectional` | Intersectional fairness gaps (2+ attributes) | Buolamwini 2018 |
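To make the flavor of these concrete, here is a minimal numpy sketch in the spirit of the statistical detector: scan each embedding dimension for group mean differences. This is an illustration only, not ShortKit-ML's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
n, dim = 400, 16
group = rng.integers(0, 2, size=n)

# Embeddings where dimension 3 leaks the group label (a planted shortcut).
emb = rng.standard_normal((n, dim))
emb[:, 3] += 2.0 * group

# Per-dimension standardized mean difference between groups (Cohen's d-like).
g0, g1 = emb[group == 0], emb[group == 1]
pooled = np.sqrt((g0.var(axis=0) + g1.var(axis=0)) / 2)
effect = np.abs(g0.mean(axis=0) - g1.mean(axis=0)) / pooled

flagged = np.where(effect > 0.8)[0]
print(flagged)  # dimension 3 stands out
```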
### Mitigation Methods

| Method | Class | Strategy | Reference |
|---|---|---|---|
| Shortcut Masking | `ShortcutMasker` | Zero/randomize/inpaint shortcut regions | - |
| Background Randomization | `BackgroundRandomizer` | Swap foreground across backgrounds | - |
| Adversarial Debiasing | `AdversarialDebiasing` | Remove group information adversarially | Zhang 2018 |
| Explanation Regularization | `ExplanationRegularization` | Penalize attention on shortcuts (RRR) | Ross 2017 |
| Last Layer Retraining | `LastLayerRetraining` | Retrain final layer on balanced data (DFR) | Kirichenko 2023 |
| Contrastive Debiasing | `ContrastiveDebiasing` | Contrastive loss to align groups (CNC) | - |
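As a toy illustration of the masking idea (not the `ShortcutMasker` API), a masker that neutralizes embedding dimensions flagged as shortcut carriers can be sketched as:

```python
import numpy as np

def mask_dimensions(embeddings, dims, mode="zero", rng=None):
    """Illustrative-only masker: neutralize suspect dimensions by zeroing
    or resampling them. The real ShortcutMasker supports richer strategies."""
    out = embeddings.copy()
    if mode == "zero":
        out[:, dims] = 0.0
    elif mode == "randomize":
        rng = rng or np.random.default_rng()
        out[:, dims] = rng.standard_normal((len(out), len(dims)))
    return out

emb = np.ones((5, 8))
masked = mask_dimensions(emb, dims=[3, 6])
print(masked[0])  # dimensions 3 and 6 zeroed, others untouched
```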
See Detection Methods Overview for per-method usage, interpretation guides, and code examples.
## Overall Assessment Conditions

`ShortcutDetector` supports pluggable risk aggregation conditions that control how method-level results map to the final HIGH/MODERATE/LOW summary.
| Condition | Best For | Description |
|---|---|---|
| `indicator_count` | General use (default) | Count of risk signals: 2+ = HIGH, 1 = MODERATE, 0 = LOW |
| `majority_vote` | Conservative screening | Consensus across methods |
| `weighted_risk` | Nuanced analysis | Evidence strength matters (probe accuracy, effect sizes, etc.) |
| `multi_attribute` | Multi-demographic | Escalates when multiple attributes flag risk |
| `meta_classifier` | Trained pipelines | Logistic regression meta-model on detector outputs (bundled model included) |
```python
detector = ShortcutDetector(
    methods=["probe", "statistical"],
    condition_name="weighted_risk",
    condition_kwargs={"high_threshold": 0.6, "moderate_threshold": 0.3},
)
```

Custom conditions can be registered via `@register_condition("name")`. See the Conditions API for details.
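The default `indicator_count` logic is simple enough to sketch standalone. A real custom condition would be registered with `@register_condition` and follow the Conditions API; the function signature and result schema below are assumed for illustration only:

```python
def indicator_count_condition(method_results, high=2, moderate=1):
    """Sketch of the default aggregation: count methods that raised a risk
    flag, then map the count to an overall level
    (2+ flags = HIGH, 1 = MODERATE, 0 = LOW).
    The per-method dict shape here is invented, not ShortKit-ML's schema."""
    n_flags = sum(1 for r in method_results.values() if r.get("risk_flag"))
    if n_flags >= high:
        return "HIGH"
    if n_flags >= moderate:
        return "MODERATE"
    return "LOW"

results = {"probe": {"risk_flag": True},
           "statistical": {"risk_flag": True},
           "geometric": {"risk_flag": False}}
print(indicator_count_condition(results))  # HIGH
```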
## MCP Server

ShortKit-ML ships an MCP server so AI assistants (Claude, Cursor, etc.) can call detection tools directly from chat, with no Python script required.
Install the MCP extra:

```bash
pip install -e ".[mcp]"
```

Run the server:

```bash
# via entry point (after install)
shortkit-ml-mcp

# or directly
python -m shortcut_detect.mcp_server
```

| Tool | Description |
|---|---|
| `list_methods` | List all 19 detection methods with descriptions |
| `generate_synthetic_data` | Generate a synthetic shortcut dataset (linear / nonlinear / none) |
| `run_detector` | Run selected methods on embeddings; returns verdict, risk level, per-method breakdown |
| `get_summary` | Human-readable summary from a prior `run_detector` call |
| `get_method_detail` | Full raw result dict for a single method |
| `compare_methods` | Side-by-side comparison table + consensus vote across methods |
Add the following to `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS):

```json
{
  "mcpServers": {
    "shortkit-ml": {
      "command": "python",
      "args": ["-m", "shortcut_detect.mcp_server"],
      "cwd": "/path/to/ShortKit-ML"
    }
  }
}
```

A ready-to-edit template is included at `claude_desktop_config.json`.
## Paper Benchmarks

Configure `examples/paper_benchmark_config.json` to control effect sizes, sample sizes, imbalance ratios, and embedding dimensionalities. A smoke profile (`examples/paper_benchmark_config_smoke.json`) is provided for quick sanity checks.

```bash
python -m shortcut_detect.benchmark.paper_run --config examples/paper_benchmark_config.json
```

Outputs CSVs, figures, and a summary in Markdown into `output/paper_benchmark/`.

The CheXpert benchmark requires a manifest (`data/chexpert_manifest.csv`) plus model-specific embedding pickles. Supported models: `medclip`, `biomedclip`, `cxr-foundation`.

```bash
python3 scripts/run_dataset2_benchmark.py \
    --manifest data/chexpert_manifest.csv \
    --model medclip \
    --root . \
    --artifacts-dir output/paper_benchmark/chexpert_embeddings \
    --config examples/paper_benchmark_config.json
```

See `scripts/reproduce_paper.sh` and the Dockerfile for full reproducibility.
## Reproducing Paper Results

All paper results are fully reproducible with fixed seeds (`seed=42`). Every table and figure in the paper can be regenerated from the scripts and data in this repository.
13 benchmark methods are evaluated across all datasets: hbac, probe, statistical, geometric, frequency, bias_direction_pca, sis, demographic_parity, equalized_odds, intersectional, groupdro, gce, ssa. These span 5 paradigms: embedding-level analysis, representation geometry, fairness evaluation, explainability, and training dynamics.
| Step | Command | Output | Time |
|---|---|---|---|
| 1. Install | `pip install -e ".[all]"` | Package + deps | 2 min |
| 2. Synthetic benchmarks | `python scripts/generate_all_paper_tables.py` | `output/paper_tables/*.tex` | ~10 min |
| 3. Paper figures | `python scripts/generate_paper_figures.py` | `output/paper_figures/*.pdf` | ~2 min |
| 4. CheXpert benchmark | `python scripts/run_chexpert_benchmark.py` | `output/paper_benchmark/chexpert_results/` | ~1 min |
| 5. MIMIC-CXR setup | `python scripts/setup_mimic_cxr_data.py` | `data/mimic_cxr/*.npy` | ~1 min |
| 6. MIMIC-CXR benchmark | `python scripts/run_mimic_benchmark.py` | `output/paper_benchmark/mimic_cxr_results/` | ~2 min |
| 7. CelebA extraction | `python scripts/extract_celeba_embeddings.py` | `data/celeba/celeba_real_*.npy` | ~5 min (MPS) |
| 8. CelebA benchmark | `python scripts/run_celeba_real_benchmark.py` | `output/paper_benchmark/celeba_real_results/` | ~1 min |
| 9. Full pipeline (smoke) | `./scripts/reproduce_paper.sh smoke` | All synthetic outputs | ~5 min |
| 10. Full pipeline | `./scripts/reproduce_paper.sh full` | All synthetic outputs | ~2-4 hrs |
Via Docker:

```bash
docker build -t shortcut-detect .
docker run --rm -v $(pwd)/output:/app/output shortcut-detect full
```

Important: the `data/` folder in this repository is empty. All embeddings and metadata are hosted on HuggingFace and must be downloaded separately. Raw CheXpert and MIMIC-CXR images and labels are not redistributed; access requires accepting the respective dataset licenses (PhysioNet for MIMIC-CXR, Stanford for CheXpert).
All embeddings are hosted on HuggingFace (gated, PhysioNet-restricted access) at `MITCriticalData/ShortKit-ML-data`:

```bash
# Download all embeddings into data/
huggingface-cli download MITCriticalData/ShortKit-ML-data --repo-type dataset --local-dir data/
```

| Dataset | Location | Embedding Models | Dim | Samples |
|---|---|---|---|---|
| Synthetic | Generated at runtime | `SyntheticGenerator(seed=42)` | 128 | Configurable |
| CheXpert | `data/chexpert/` | MedCLIP, ResNet-50, DenseNet-121, ViT-B/16, ViT-B/32, DINOv2, RAD-DINO, MedSigLIP | 512-2048 | 2,000 each |
| MIMIC-CXR | `data/mimic_cxr/` | RAD-DINO, ViT-B/16, ViT-B/32, MedSigLIP | 768-1152 | ~1,500 each |
| CelebA | `data/celeba/` | ResNet-50 (ImageNet) | 2,048 | 10,000 |
| Paper Table | Script | Data | Seed |
|---|---|---|---|
| Tab 3: Synthetic P/R/F1 | `generate_all_paper_tables.py` | `SyntheticGenerator` | 42 |
| Tab 4: False positive rates | `generate_all_paper_tables.py` | `SyntheticGenerator` (null) | 42 |
| Tab 5: Sensitivity analysis | `generate_all_paper_tables.py` | `SensitivitySweep` | 42 |
| Tab 6: CheXpert results | `run_chexpert_benchmark.py` | `data/chest_embeddings.npy` | 42 |
| Tab 7: MIMIC-CXR cross-val | `run_mimic_benchmark.py` | `data/mimic_cxr/*.npy` | 42 |
| Tab 8: CelebA validation | `run_celeba_real_benchmark.py` | `data/celeba/celeba_real_embeddings.npy` | 42 |
| Tab 9: Risk conditions | `generate_all_paper_tables.py` | `SyntheticGenerator` | 42 |
| Fig 2: Convergence matrix | `generate_paper_figures.py` | Synthetic + CheXpert | 42 |
See docs/reproducibility.md for full details.
## GPU Support

The library auto-selects the best available device. PyTorch components (probes, VAE, GroupDRO, adversarial debiasing, etc.) use the standard `torch.device` fallback:
| Platform | Backend | Auto-detected |
|---|---|---|
| Linux/Windows with NVIDIA GPU | CUDA | Yes (`torch.cuda.is_available()`) |
| macOS Apple Silicon | MPS | Partial; pass `device="mps"` explicitly |
| CPU (any platform) | CPU | Yes (default fallback) |

Note: Most detection methods (HBAC, statistical, geometric, etc.) run on CPU via NumPy/scikit-learn and do not require a GPU. GPU acceleration benefits the torch-based probe, VAE, GroupDRO, and mitigation methods. MPS support depends on PyTorch operator coverage; if you encounter errors on Apple Silicon, fall back to `device="cpu"`.
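The fallback order above can be written as a small helper. This is the generic PyTorch pattern, not a function shipped with ShortKit-ML:

```python
def pick_device() -> str:
    """Return 'cuda', 'mps', or 'cpu' following the fallback order above.
    Degrades to 'cpu' when torch is not installed at all."""
    try:
        import torch
    except ImportError:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

print(pick_device())
```

The result can then be passed wherever a `device` argument is accepted.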
## Interactive Dashboard

```bash
python app.py
# Opens at http://127.0.0.1:7860
```

Features: sample CheXpert data, custom CSV upload, PDF/HTML reports, model comparison tab, multi-attribute analysis.

CSV format:

```
embedding_0,embedding_1,...,task_label,group_label,attr_race,attr_gender
0.123,0.456,...,1,group_a,Black,Male
```

See the Dashboard Guide for detailed usage.
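Such a CSV can be assembled with the standard library. The column names follow the convention above; the values are made up:

```python
import csv

header = ["embedding_0", "embedding_1", "task_label",
          "group_label", "attr_race", "attr_gender"]
rows = [
    [0.123, 0.456, 1, "group_a", "Black", "Male"],
    [0.789, 0.012, 0, "group_b", "White", "Female"],
]

# Write a dashboard-ready CSV: embedding columns first,
# then task label, group label, and per-attribute columns.
with open("dashboard_input.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(header)
    writer.writerows(rows)
```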
## Testing

```bash
pytest tests/ -v
pytest --cov=shortcut_detect --cov-report=html
```

638 tests pass across all detection and mitigation methods.
## Contributing

```bash
pip install -e ".[dev]"
pre-commit install
```

- Black for formatting (line length: 100), Ruff for linting, MyPy for types
- Pre-commit hooks run automatically; CI tests on Python 3.10, 3.11, and 3.12
- New detectors must implement `DetectorBase`; see `docs/contributing.md` and `shortcut_detect/detector_template.py`
Project structure:

```
shortcut_detect/
├── probes/          # Probe-based detection (sklearn + torch)
├── clustering/      # HBAC detector
├── statistical/     # Statistical testing
├── geometric/       # Geometric & bias direction analysis
├── fairness/        # Equalized Odds, Demographic Parity, Intersectional
├── frequency/       # Frequency shortcut detector
├── causal/          # Causal effect detector
├── gce/             # Generalized cross-entropy detector
├── training/        # Early epoch clustering (SPARE)
├── vae/             # VAE latent disentanglement
├── xai/             # CAV, SpRAy, GradCAM mask overlap, SIS
├── ssa/             # Semi-supervised spectral analysis
├── groupdro/        # GroupDRO worst-group robustness
├── conditions/      # Pluggable risk aggregation conditions
│   ├── base.py, registry.py, indicator_count.py, majority_vote.py
│   ├── weighted_risk.py, multi_attribute.py, meta_classifier.py
│   └── meta_model.joblib  # Trained meta-classifier (bundled)
├── benchmark/       # Paper benchmark infrastructure
│   ├── runner.py, paper_runner.py, synthetic_generator.py
│   ├── measurement.py, fp_analysis.py, sensitivity.py
│   └── convergence_viz.py, baseline_comparison.py, figures.py
├── comparison/      # Model comparison runner
├── mitigation/      # Debiasing & masking methods (M01-M07)
├── reporting/       # HTML/PDF/CSV reports & visualizations
├── unified.py       # ShortcutDetector unified API
└── detector_base.py # DetectorBase ABC with results_ schema
docs/                # MkDocs documentation site
examples/            # Notebooks and benchmark configs
app.py               # Gradio dashboard
Dockerfile           # Reproducible environment
scripts/             # Paper reproduction scripts
tests/               # Test suite (475+ tests)
```
Build the documentation locally:

```bash
pip install mkdocs mkdocs-material "mkdocstrings[python]" pymdown-extensions
mkdocs serve  # http://127.0.0.1:8000
```

- Getting Started
- Detection Methods: all 20+ methods with guides
- API Reference
- Contributing
## Citation

```bibtex
@software{shortkit_ml2025,
  title={ShortKit-ML: Tools for Identifying Biases in Embedding Spaces},
  author={Sebastian Cajas and Aldo Marzullo and Sahil Kapadia and Qingpeng Kong and Filipe Santos and Alessandro Quarta and Leo Celi},
  year={2025},
  url={https://github.com/criticaldata/ShortKit-ML}
}
```

MIT License; see the LICENSE file.
- GitHub: criticaldata/ShortKit-ML
- Issues: GitHub Issues