Skip to content

Commit 7bde4cc

Browse files
deruyter92, C-Achard, MMathisLab
authored
Fix SuperAnimal / pretrained load for RTMPose: implement convert_weights on RTMCCHead (#3270)
* Fix SuperAnimal / pretrained load for RTMPose: implement convert_weights on RTMCCHead * update RTMCC convert_weights: add flag for optional gau.w omission; ensure deterministic init. --------- Co-authored-by: Cyril Achard <[email protected]> Co-authored-by: Mackenzie Mathis <[email protected]>
1 parent 80b229c commit 7bde4cc

1 file changed

Lines changed: 60 additions & 1 deletion

File tree

deeplabcut/pose_estimation_pytorch/models/heads/rtmcc_head.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from deeplabcut.pose_estimation_pytorch.models.heads.base import (
2727
HEADS,
2828
BaseHead,
29+
WeightConversionMixin,
2930
)
3031
from deeplabcut.pose_estimation_pytorch.models.modules import (
3132
GatedAttentionUnit,
@@ -37,7 +38,7 @@
3738

3839

3940
@HEADS.register_module
40-
class RTMCCHead(BaseHead):
41+
class RTMCCHead(WeightConversionMixin, BaseHead):
4142
"""RTMPose Coordinate Classification head.
4243
4344
The RTMCC head is itself adapted from the SimCC head. For more information, see
@@ -136,6 +137,64 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
136137
x, y = self.cls_x(feats), self.cls_y(feats)
137138
return dict(x=x, y=y)
138139

140+
@staticmethod
def convert_weights(
    state_dict: dict[str, torch.Tensor],
    module_prefix: str,
    conversion: torch.Tensor,
    *,
    omit_gau_w: bool = False,
) -> dict[str, torch.Tensor]:
    """Permute / subset keypoint-dependent weights so a SuperAnimal (or other
    pretrained) RTMCC head can be loaded for a different bodypart set.

    Args:
        state_dict: State dict of this head, mutated and returned.
        module_prefix: Prefix prepended to this head's keys in ``state_dict``.
        conversion: Index tensor mapping each new bodypart to its source
            bodypart in the pretrained head.
        omit_gau_w: When True, ``gau.w`` is dropped from the returned dict
            instead of being rebuilt; the caller must then load with
            ``strict=False`` so the missing key is tolerated. Prefer dropping
            when source/target keypoint ordering semantics differ.

    Returns:
        The converted ``state_dict``.
    """
    mapping = conversion.long()
    num_new = int(mapping.shape[0])

    # The final projection has one output channel per bodypart: select and
    # re-order those channels according to the conversion table.
    for suffix in ("weight", "bias"):
        key = f"{module_prefix}final_layer.{suffix}"
        if key in state_dict:
            state_dict[key] = state_dict[key][mapping]

    w_key = f"{module_prefix}gau.w"
    if w_key not in state_dict:
        return state_dict

    if omit_gau_w:
        # Caller opted out of remapping: drop the GAU relative-position
        # weights entirely (requires strict=False at load time).
        del state_dict[w_key]
        return state_dict

    # gau.w stores one weight per relative token offset, laid out for
    # offsets -(k-1) .. (k-1), i.e. a vector of length 2k - 1.
    w_old = state_dict[w_key]
    old_len = int(w_old.shape[0])
    center_old = (old_len + 1) // 2 - 1  # index of offset 0 in w_old

    # Deterministic fallback for new offsets with no source weight
    # (mean of the original vector keeps init reproducible).
    fallback = w_old.mean()

    # Rebuild the offset vector for the new keypoint set: each new offset
    # averages the old weights of every (target, source) token pair that
    # realises it under the conversion mapping.
    w_new = torch.empty(2 * num_new - 1, dtype=w_old.dtype, device=w_old.device)
    for slot, offset in enumerate(range(1 - num_new, num_new)):
        gathered = []
        for tgt in range(num_new):
            src = tgt - offset
            if not 0 <= src < num_new:
                continue
            old_slot = int(mapping[tgt]) - int(mapping[src]) + center_old
            if 0 <= old_slot < old_len:
                gathered.append(w_old[old_slot])
        w_new[slot] = torch.stack(gathered).mean() if gathered else fallback
    state_dict[w_key] = w_new
    return state_dict
197+
139198
@staticmethod
140199
def update_input_size(model_cfg: dict, input_size: tuple[int, int]) -> None:
141200
"""Updates an RTMPose model configuration file for a new image input size.

0 commit comments

Comments
 (0)