Skip to content

PyTorch: Speed up single animal models#3110

Merged
AlexEMG merged 3 commits intoDeepLabCut:mainfrom
arashsm79:arash/speedup_prediction
Oct 18, 2025
Merged

PyTorch: Speed up single animal models#3110
AlexEMG merged 3 commits intoDeepLabCut:mainfrom
arashsm79:arash/speedup_prediction

Conversation

@arashsm79
Copy link
Copy Markdown
Contributor

@arashsm79 arashsm79 commented Oct 2, 2025

This PR improves the prediction performance of various single animal PyTorch models.
The for loops for extracting values out of the heatmap tensor were replaced with advanced indexing.

Details

Profiling results showed HeatmapPredictor.get_pose_prediction as a major hot spot in single animal inference procedure.

After this change, we see a massive speed-up in the major single animal models as shown below (thanks to @maximpavliv for running the benchmark)

fps_vs_batchsize_128x128 fps_vs_batchsize_256x256 fps_vs_batchsize_512x512

Replace nested loops with vectorized PyTorch operations
to extract heatmap scores and locref offsets in parallel.
@arashsm79 arashsm79 force-pushed the arash/speedup_prediction branch from 8d4decb to e9d1b2c Compare October 9, 2025 12:52
@arashsm79 arashsm79 changed the title [WIP] Speed up PyTorch model predictions Speed up PyTorch single animal models Oct 9, 2025
@arashsm79 arashsm79 changed the title Speed up PyTorch single animal models PyTorch: Speed up single animal models Oct 9, 2025
@arashsm79 arashsm79 marked this pull request as ready for review October 9, 2025 13:01
@maximpavliv maximpavliv self-requested a review October 14, 2025 08:58
Copy link
Copy Markdown
Contributor

@maximpavliv maximpavliv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for this improvement — great work! 🙌
The vectorized implementation looks clean and efficient, and it passes all integration tests on my side.

Just one small suggestion: the Example section in the docstring should be updated for consistency with the tensor layout assumed by the code. Specifically, the example

>>> heatmap = torch.rand(32, 17, 64, 64)
>>> locref = torch.rand(32, 17, 64, 64, 2)

should be changed to

>>> heatmap = torch.rand(32, 64, 64, 17)
>>> locref = torch.rand(32, 64, 64, 17, 2)

so it matches the (batch_size, height, width, num_joints) format used in the implementation.

@AlexEMG AlexEMG self-requested a review October 14, 2025 09:29
@AlexEMG AlexEMG merged commit c99fd3b into DeepLabCut:main Oct 18, 2025
6 of 10 checks passed
deruyter92 added a commit to deruyter92/DeepLabCut-live that referenced this pull request Jan 21, 2026
This commit updates the `HeatmapPredictor` in single_predictor.py to follow the implementation in DeepLabCut 3.0.0rc13. See DeepLabCut/DeepLabCut#3110
MMathisLab pushed a commit to DeepLabCut/DeepLabCut-live that referenced this pull request Jan 22, 2026
* DEKRPredictor: add non-maximum suppression (NMS)

This commit Updates the DEKR predictor to follow the DeepLabCut implementation in version 3.0.0rc7, see
DeepLabCut/DeepLabCut#2907

* DEKRPredictor: speed up with vectorized operations

This commit updates the DEKRPredictor to follow the DeepLabCut implementation in version 3.0.0rc13.  see DeepLabCut/DeepLabCut#3121

* PartAffinityFieldPredictor (PAF): Speed up cost computation

This commit updates the PAF predictor to follow the DeepLabCut implementation in version 3.0.0.rc13. See
DeepLabCut/DeepLabCut#3117

* HeatmapPredictor (single animal): speed up with vecorized operations

This commit updates the `HeatmapPredictor` in single_predictor.py to follow the implementation in DeepLabCut 3.0.0rc13. See DeepLabCut/DeepLabCut#3110
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants