|
26 | 26 | from deeplabcut.pose_estimation_pytorch.models.heads.base import ( |
27 | 27 | HEADS, |
28 | 28 | BaseHead, |
| 29 | + WeightConversionMixin, |
29 | 30 | ) |
30 | 31 | from deeplabcut.pose_estimation_pytorch.models.modules import ( |
31 | 32 | GatedAttentionUnit, |
|
37 | 38 |
|
38 | 39 |
|
39 | 40 | @HEADS.register_module |
40 | | -class RTMCCHead(BaseHead): |
| 41 | +class RTMCCHead(WeightConversionMixin, BaseHead): |
41 | 42 | """RTMPose Coordinate Classification head. |
42 | 43 |
|
43 | 44 | The RTMCC head is itself adapted from the SimCC head. For more information, see |
@@ -136,6 +137,64 @@ def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: |
136 | 137 | x, y = self.cls_x(feats), self.cls_y(feats) |
137 | 138 | return dict(x=x, y=y) |
138 | 139 |
|
| 140 | + @staticmethod |
| 141 | + def convert_weights( |
| 142 | + state_dict: dict[str, torch.Tensor], |
| 143 | + module_prefix: str, |
| 144 | + conversion: torch.Tensor, |
| 145 | + *, |
| 146 | + omit_gau_w: bool = False, |
| 147 | + ) -> dict[str, torch.Tensor]: |
| 148 | + """Re-order / subset bodypart (token) channels for transfer from SuperAnimal. |
| 149 | +
|
| 150 | + Args: |
| 151 | + state_dict: State dict for this head. |
| 152 | + module_prefix: Prefix for state-dict keys. |
| 153 | + conversion: Mapping from new bodyparts to source bodyparts. |
| 154 | + omit_gau_w: If True, remove ``gau.w`` from the returned dict instead of |
| 155 | + constructing a remapped replacement. This requires loading with |
| 156 | + ``strict=False`` to avoid missing-key errors. |
| 157 | + Prefer omitting when source/target keypoint ordering semantics differ. |
| 158 | + """ |
| 159 | + conv = conversion.long() |
| 160 | + k_new = int(conv.shape[0]) |
| 161 | + |
| 162 | + # Remap final layer weights and biases if they exist. |
| 163 | + fl_w = f"{module_prefix}final_layer.weight" |
| 164 | + fl_b = f"{module_prefix}final_layer.bias" |
| 165 | + if fl_w in state_dict: |
| 166 | + state_dict[fl_w] = state_dict[fl_w][conv] |
| 167 | + if fl_b in state_dict: |
| 168 | + state_dict[fl_b] = state_dict[fl_b][conv] |
| 169 | + |
| 170 | + # Remap or re-init gau.w if it exists (only if omit_gau_w is False) |
| 171 | + w_key = f"{module_prefix}gau.w" |
| 172 | + if w_key in state_dict: |
| 173 | + if omit_gau_w: |
| 174 | + state_dict.pop(w_key, None) |
| 175 | + return state_dict |
| 176 | + |
| 177 | + w_old = state_dict[w_key] |
| 178 | + k_old = (w_old.shape[0] + 1) // 2 |
| 179 | + old_center = k_old - 1 |
| 180 | + new_center = k_new - 1 |
| 181 | + |
| 182 | + # Deterministic default for unmapped offsets (mean of original weights). |
| 183 | + default_val = w_old.mean() |
| 184 | + w_new = torch.empty(2 * k_new - 1, dtype=w_old.dtype, device=w_old.device) |
| 185 | + for idx_new, d in enumerate(range(-new_center, new_center + 1)): |
| 186 | + old_vals = [] |
| 187 | + for i in range(k_new): |
| 188 | + j = i - d |
| 189 | + if not (0 <= j < k_new): |
| 190 | + continue |
| 191 | + old_idx = int(conv[i] - conv[j]) + old_center |
| 192 | + if 0 <= old_idx < w_old.shape[0]: |
| 193 | + old_vals.append(w_old[old_idx]) |
| 194 | + w_new[idx_new] = torch.stack(old_vals).mean() if old_vals else default_val |
| 195 | + state_dict[w_key] = w_new |
| 196 | + return state_dict |
| 197 | + |
139 | 198 | @staticmethod |
140 | 199 | def update_input_size(model_cfg: dict, input_size: tuple[int, int]) -> None: |
141 | 200 | """Updates an RTMPose model configuration file for a new image input size. |
|
0 commit comments