A comprehensive framework for training high-level policies and language-conditioned Action Chunking with Transformers (ACT) for autonomous surgical suturing.
This repository provides tools for training both low-level and high-level policies for the SutureBot system, enabling autonomous surgical suturing through imitation learning and language conditioning.
- Quick Start
- Training Pipeline
- Low-Level Policy Training
- High-Level Policy Training
- Notes
- Contributing
## Quick Start

### Prerequisites

- Python 3.8.10
- CUDA-compatible GPU (recommended: RTX 4090 with 21 GB+ memory)
- Ubuntu/Linux environment
1. **Clone the repository**

2. **Set up the Python environment**

   ```bash
   conda create -n suturebot python=3.8.10
   conda activate suturebot
   pip install -r requirements.txt
   ```

3. **Optional: audio processing setup**

   ```bash
   # For Whisper integration
   sudo apt update && sudo apt install ffmpeg
   # For audio recording capabilities
   sudo apt install portaudio19-dev python3-pyaudio
   ```
Add these environment variables to your `~/.bashrc`:

```bash
export PATH_TO_SUTUREBOT=/path/to/your/srth/directory
export PATH_TO_DATASET=/path/to/your/dataset/folders
export YOUR_CKPT_PATH="$PATH_TO_SUTUREBOT/model_ckpts"
```

Then reload your shell:

```bash
source ~/.bashrc
```

```
├── src/
│   ├── act/                          # Low-level policy implementation
│   │   ├── dvrk_scripts/
│   │   │   └── constants_dvrk.py     # Task configuration settings
│   │   ├── generic_dataset.py        # Dataset handling
│   │   ├── auto_training_suturing.py # Training orchestration
│   │   ├── imitate_episodes.py       # Core training logic
│   │   └── img_aug.py                # Image augmentation utilities
│   └── instructor/                   # High-level policy implementation
└── script/                           # Utility scripts
    ├── calculate_std_mean.py         # Dataset normalization statistics
    ├── suture_point_labeling.py      # Annotation tools
    └── encode_instruction.py         # Text embedding generation
```
## Training Pipeline

Organize your training data according to this hierarchy:

```
$PATH_TO_DATASET/
└── [DATASET_NAME]/                        # Dataset root directory
    ├── tissue_1/                          # Tissue sample 1
    │   └── 1_[task_name]/                 # Task directory
    │       └── [episode_timestamp]/       # Episode (timestamped)
    │           ├── left_img_dir/          # Left endoscope images
    │           │   └── frame000000_left.jpg
    │           ├── right_img_dir/         # Right endoscope images
    │           │   └── frame000000_right.jpg
    │           ├── endo_psm1/             # Right wrist camera
    │           │   └── frame000000_psm1.jpg
    │           ├── endo_psm2/             # Left wrist camera
    │           │   └── frame000000_psm2.jpg
    │           └── ee_csv.csv             # Kinematics data
    ├── tissue_2/                          # Additional tissue samples
    └── ...                                # Same structure as above
```
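Before training, it can help to confirm that every episode folder matches the layout above. The sketch below builds a synthetic episode and checks it; the task name `1_suture_throw` and the timestamp are placeholders, not names from the repository.

```python
# Sketch: validate that an episode folder contains the expected camera
# sub-directories and kinematics file, per the layout above. The episode
# built here is synthetic; point validate_episode at your real data.
import tempfile
from pathlib import Path

REQUIRED_DIRS = ["left_img_dir", "right_img_dir", "endo_psm1", "endo_psm2"]

def validate_episode(episode_dir: Path) -> list:
    """Return a list of missing items for one episode directory."""
    missing = [d for d in REQUIRED_DIRS if not (episode_dir / d).is_dir()]
    if not (episode_dir / "ee_csv.csv").is_file():
        missing.append("ee_csv.csv")
    return missing

# Build a minimal synthetic episode and check it.
root = Path(tempfile.mkdtemp())
ep = root / "tissue_1" / "1_suture_throw" / "20250101-120000"
for d in REQUIRED_DIRS:
    (ep / d).mkdir(parents=True)
(ep / "ee_csv.csv").touch()

print(validate_episode(ep))  # -> []
```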
## Low-Level Policy Training

Follow these steps to train your low-level policy:
```bash
python encode_instruction.py \
    --dataset_dir $PATH_TO_DATASET/[DATASET_NAME] \
    --encoder distilbert \
    --from_count
```

This creates `candidate_embeddings_distilbert.json` with task-name and direction-correction embeddings.
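Downstream code can match a query embedding against these candidates. The sketch below assumes the JSON is a flat mapping from instruction text to a float vector; the toy task names and 3-dimensional vectors here are illustrative stand-ins (real DistilBERT embeddings are 768-dimensional).

```python
# Sketch: load the generated embeddings and pick the candidate closest to
# a query vector by cosine similarity. The JSON schema (flat mapping from
# instruction text to a float vector) is an assumption.
import json, math, os, tempfile

# Stand-in for candidate_embeddings_distilbert.json.
sample = {
    "suture_throw": [0.9, 0.1, 0.0],
    "needle_handoff": [0.1, 0.9, 0.0],
}
path = os.path.join(tempfile.mkdtemp(), "candidate_embeddings_distilbert.json")
with open(path, "w") as f:
    json.dump(sample, f)

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))

with open(path) as f:
    candidates = json.load(f)

def closest(query, candidates):
    return max(candidates, key=lambda name: cosine(query, candidates[name]))

print(closest([1.0, 0.0, 0.0], candidates))  # -> suture_throw
```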
```bash
python script/calculate_std_mean.py
```

Note: Configure the tissue IDs and data directory in the script before running. Results are saved to `script/chole/std_mean.txt`.
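The statistic itself is a per-channel mean and standard deviation over all training images. A minimal sketch on a synthetic image batch (real usage would load frames from the camera directories instead):

```python
# Sketch: per-channel mean/std over a stack of RGB images, the statistic
# that calculate_std_mean.py produces. The batch here is synthetic.
import numpy as np

rng = np.random.default_rng(0)
images = rng.random((8, 64, 64, 3))  # (N, H, W, C), values in [0, 1]

mean = images.mean(axis=(0, 1, 2))   # one value per channel
std = images.std(axis=(0, 1, 2))

print(mean.round(3), std.round(3))
```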
Edit `src/act/dvrk_scripts/constants_dvrk.py` with your task configuration:

| Parameter | Description | Example Values |
|---|---|---|
| `dataset_dir` | Path to your dataset | `$PATH_TO_DATASET/my_dataset` |
| `num_episodes` | Total episodes (from step 2) | `150` |
| `tissue_samples_ids` | Tissue IDs for training | `[1, 2, 3]` |
| `camera_names` | Cameras to use | `["left", "right", "left_wrist", "right_wrist"]` |
| `action_mode` | Action representation | `"hybrid"` (recommended) |
| `norm_scheme` | Normalization method | `"std"` (recommended) |
| `goal_condition_style` | Goal condition mode (`"dot"`, `"map"`, or `"mask"`) | `"dot"` |
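Put together, a task entry might look like the following. The dictionary name `TASK_CONFIGS` and the exact nesting are assumptions for illustration; the keys mirror the table above.

```python
# Illustrative task entry for constants_dvrk.py. The surrounding dict
# name (TASK_CONFIGS) is an assumption; values are placeholders.
TASK_CONFIGS = {
    "my_suturing_task": {
        "dataset_dir": "$PATH_TO_DATASET/my_dataset",
        "num_episodes": 150,
        "tissue_samples_ids": [1, 2, 3],
        "camera_names": ["left", "right", "left_wrist", "right_wrist"],
        "action_mode": "hybrid",        # recommended
        "norm_scheme": "std",           # recommended
        "goal_condition_style": "dot",  # "dot", "map", or "mask"
    }
}
```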
Configure `auto_training_suturing.py` with these key parameters:

| Parameter | Description | Recommended Value |
|---|---|---|
| `task_name` | Task config name | From `constants_dvrk.py` |
| `policy_class` | Policy type | `"ACT"` |
| `batch_size` | Training batch size | `16` (requires ~21 GB GPU memory) |
| `num_epochs` | Training epochs | `1000` |
| `language_encoder` | Text encoder | `"distilbert"` |
| `image_encoder` | Vision backbone | `"efficientnet_b3film"` |
| `policy_level` | Policy level | `"low"` |
```bash
python auto_training_suturing.py
```

Tip: The training script automatically handles interruptions and resumes from the last checkpoint.
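Resuming typically means locating the checkpoint with the highest epoch number in the checkpoint directory. A minimal sketch of that logic, assuming a `policy_epoch_<n>.ckpt` naming pattern (the actual filenames used by the script may differ):

```python
# Sketch: find the most recent checkpoint to resume from. The filename
# pattern policy_epoch_<n>.ckpt is an assumption; the .ckpt files built
# here are empty placeholders.
import re, tempfile
from pathlib import Path

def latest_checkpoint(ckpt_dir: Path):
    """Return (epoch, path) of the newest checkpoint, or None if none exist."""
    pattern = re.compile(r"policy_epoch_(\d+)\.ckpt")
    best = None
    for p in ckpt_dir.glob("*.ckpt"):
        m = pattern.fullmatch(p.name)
        if m and (best is None or int(m.group(1)) > best[0]):
            best = (int(m.group(1)), p)
    return best

ckpt_dir = Path(tempfile.mkdtemp())
for epoch in (100, 250, 400):
    (ckpt_dir / f"policy_epoch_{epoch}.ckpt").touch()

print(latest_checkpoint(ckpt_dir)[0])  # -> 400
```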
## High-Level Policy Training

1. **Configure the dataset path**

   - Set `DATA_DIR` in `src/instructor/constants_daVinci.py`
   - Define camera folder names and validation/test tissue splits
   - Labels are automatically extracted from directory names

2. **Compute dataset statistics**

   ```bash
   python script/chole/dataset_rgb_mean_std.py
   ```

   Alternative: Use ImageNet statistics if dataset-specific stats aren't available.
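If you fall back to ImageNet statistics, normalization looks like this (the ImageNet channel means and standard deviations below are the standard published values; the synthetic image is a placeholder):

```python
# Sketch: normalize an image with the standard ImageNet channel
# statistics, the fallback mentioned above.
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

img = np.full((224, 224, 3), 0.5)  # synthetic image scaled to [0, 1]
normalized = (img - IMAGENET_MEAN) / IMAGENET_STD

print(normalized[0, 0].round(3))
```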
- Check that all recordings look good:
  - Create a video concatenating all demonstrations for one tissue (`script/chole/concatenate_all_tissue_demos.py`) and verify that every demonstration looks good and is in the correct task directory (occasionally a demonstration is saved in the previous task's folder, or vice versa).
  - Specifically check that the demonstrations are complete. If a demonstration is incomplete (started too late or ended too early), concatenating the task recordings will be erroneous.
  - If a task recording started too early or runs too long, add an `indices_curated.json` file to the demonstration directory with the keys `"start"` and/or `"end"` set to the frame index of the curated start/end.
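A loader honoring the curation file can be sketched as follows; the exact way the repository consumes `indices_curated.json` may differ, this only illustrates the `"start"`/`"end"` convention described above.

```python
# Sketch: trim an episode's frame range using an optional
# indices_curated.json with "start" and/or "end" keys.
import json, tempfile
from pathlib import Path

def curated_range(episode_dir: Path, num_frames: int):
    """Return (start, end) frame indices, honoring indices_curated.json."""
    start, end = 0, num_frames
    curated = episode_dir / "indices_curated.json"
    if curated.is_file():
        with open(curated) as f:
            idx = json.load(f)
        start = idx.get("start", start)
        end = idx.get("end", end)
    return start, end

ep = Path(tempfile.mkdtemp())
(ep / "indices_curated.json").write_text(json.dumps({"start": 12, "end": 480}))
print(curated_range(ep, 600))  # -> (12, 480)
```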
| File | Purpose |
|---|---|
| `model_daVinci.py` | Temporal models (Transformer, etc.) |
| `backbone_models_daVinci.py` | Vision backbones (ResNet, SwinT) |
| `dataset_daVinci.py` | Dataset loading and augmentation |
| `train_daVinci.py` | Training orchestration |
| `instructor_pipeline.py` | Inference pipeline |
```bash
python train_daVinci.py \
    --dataset_names [dataset_name] \
    --ckpt_dir ./model_ckpts/hl/suturing_hl_3 \
    --gpu 0 \
    --recovery_probability 0.6 \
    --batch_size 16 \
    --num_epochs 2000 \
    --lr 4e-4 \
    --min_lr 1e-5 \
    --lr_cycle 25 \
    --warmup_epochs 5 \
    --weight_decay 0.05 \
    --validation_interval 10 \
    --prediction_offset 15 \
    --history_len 4 \
    --save_ckpt_interval 5 \
    --history_step_size 30 \
    --one_hot_flag \
    --early_stopping_interval 300 \
    --seed 5 \
    --plot_val_images_flag \
    --max_num_images 5 \
    --cameras_to_use left_img_dir \
    --backbone_model swin-t \
    --model_init_weights imagenet \
    --image_dim 224 224 \
    --freeze_backbone_until none \
    --multitask_loss_weight 0.6 \
    --uniform_sampling_flag \
    --extra_repeated_phase_last_frame_sampling_flag \
    --extra_repeated_phase_last_frame_sampling_probability 0.15 \
    --add_center_crop_view_flag \
    --global_pool_image_features_flag \
    --dataset_mean_std_file_names "dataset_mean_std_camera_type='left_img_dir'_image_step_size=1.json" \
    --val_split_number 0 \
    --use_complexer_multitask_mlp_head_flag \
    --selected_multitasks dominant_moving_direction
```
## Notes

- **Experimental features:** the codebase contains experimental features that can be removed for simplification.
- **Deprecated files:**
  - `future_frame_predictor_model.py`
  - `hl_correction_publisher_ui_w_whisper.py`
  - `temporal_models.py` (contains only the TCN, not the final Transformer architecture)
## Contributing

When contributing to this project, please ensure:

- Code follows existing style conventions
- New features are documented
- Training configurations are tested before submission

For questions or issues, please contact: [email protected]
```bibtex
@misc{suturebot2025,
  title        = {SutureBot: A Precision Framework and Benchmark for Autonomous End-to-End Suturing},
  author       = {Haworth, Jesse and Chen, Juo-Tung and Nelson, Nigel and Kim, Ji Woong and Moghani, Masoud and Finn, Chelsea and Krieger, Axel},
  year         = {2025},
  note         = {Under review at NeurIPS 2025 Datasets and Benchmarks Track},
  howpublished = {\url{https://huggingface.co/datasets/jchen396/suturebot}}
}
```