Skip to content

Commit fa2ebee

Browse files
authored
improved animal_name handling when analyzing videos (#2884)
1 parent 7162424 commit fa2ebee

3 files changed

Lines changed: 43 additions & 11 deletions

File tree

deeplabcut/pose_estimation_tensorflow/predict_videos.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def analyze_videos(
271271
use_shelve=False,
272272
auto_track=True,
273273
n_tracks=None,
274+
animal_names=None,
274275
calibrate=False,
275276
identity_only=False,
276277
use_openvino="CPU" if is_openvino_available else None,
@@ -397,6 +398,13 @@ def analyze_videos(
397398
animals in the video is different from the number of animals the model was
398399
trained on.
399400
401+
animal_names: list[str], optional
402+
If you want the names given to individuals in the labeled data file, you can
403+
specify those names as a list here. If given and `n_tracks` is None, `n_tracks`
404+
will be set to `len(animal_names)`. If `n_tracks` is not None, then it must be
405+
equal to `len(animal_names)`. If it is not given, then `animal_names` will
406+
be loaded from the `individuals` in the project config.yaml file.
407+
400408
use_openvino: str, optional
401409
Use "CPU" for inference if OpenVINO is available in the Python environment.
402410
@@ -630,6 +638,7 @@ def analyze_videos(
630638
trainingsetindex,
631639
destfolder=destfolder,
632640
n_tracks=n_tracks,
641+
animal_names=animal_names,
633642
modelprefix=modelprefix,
634643
save_as_csv=save_as_csv,
635644
)

deeplabcut/refine_training_dataset/stitch.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
11+
from typing import List, Optional
12+
1113
import matplotlib.pyplot as plt
1214
import networkx as nx
1315
import numpy as np
@@ -891,8 +893,11 @@ def concatenate_data(self):
891893

892894
def format_df(self, animal_names=None):
893895
data = self.concatenate_data()
894-
if not animal_names or len(animal_names) != self.n_tracks:
896+
if not animal_names or len(animal_names) < self.n_tracks:
895897
animal_names = [f"ind{i}" for i in range(1, self.n_tracks + 1)]
898+
elif len(animal_names) > self.n_tracks:
899+
animal_names = animal_names[:self.n_tracks]
900+
896901
coords = ["x", "y", "likelihood"]
897902
n_multi_bpts = data.shape[1] // (len(animal_names) * len(coords))
898903
n_unique_bpts = 0 if self.single is None else self.single.data.shape[1]
@@ -1031,6 +1036,7 @@ def stitch_tracklets(
10311036
shuffle=1,
10321037
trainingsetindex=0,
10331038
n_tracks=None,
1039+
animal_names: Optional[List[str]] = None,
10341040
min_length=10,
10351041
split_tracklets=True,
10361042
prestitch_residuals=True,
@@ -1071,6 +1077,13 @@ def stitch_tracklets(
10711077
passed if the number of animals in the video is different from
10721078
the number of animals the model was trained on.
10731079
1080+
animal_names: list, optional
1081+
If you want the names given to individuals in the labeled data file, you can
1082+
specify those names as a list here. If given and `n_tracks` is None, `n_tracks`
1083+
will be set to `len(animal_names)`. If `n_tracks` is not None, then it must be
1084+
equal to `len(animal_names)`. If it is not given, then `animal_names` will
1085+
be loaded from the `individuals` in the project config.yaml file.
1086+
10741087
min_length : int, optional
10751088
Tracklets less than `min_length` frames of length
10761089
are considered to be residuals; i.e., they do not participate
@@ -1107,8 +1120,8 @@ def stitch_tracklets(
11071120
tracklets should be stitched together, the lower the returned value.
11081121
11091122
destfolder: string, optional
1110-
Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this
1111-
folder also needs to be passed.
1123+
Specifies the destination folder for analysis data (default is the path of the
1124+
video). Note that for subsequent analysis this folder also needs to be passed.
11121125
11131126
track_method: string, optional
11141127
Specifies the tracker used to generate the pose estimation data.
@@ -1135,7 +1148,14 @@ def stitch_tracklets(
11351148
cfg = auxiliaryfunctions.read_config(config_path)
11361149
track_method = auxfun_multianimal.get_track_method(cfg, track_method=track_method)
11371150

1138-
animal_names = cfg["individuals"]
1151+
if animal_names is None:
1152+
animal_names = cfg["individuals"]
1153+
elif n_tracks is not None and n_tracks != len(animal_names):
1154+
raise ValueError(
1155+
"When setting both `n_tracks` and `animal_names`, `n_tracks` must be equal "
1156+
f"to len(animal_names)`. Found `n_tracks`={n_tracks} and `animal_names`="
1157+
f"{animal_names} of length {len(animal_names)}.`")
1158+
11391159
if n_tracks is None:
11401160
n_tracks = len(animal_names)
11411161

deeplabcut/utils/make_labeled_video.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def create_labeled_video(
455455
456456
displayedindividuals: list[str] or str, optional, default="all"
457457
Individuals plotted in the video.
458-
By default, all individuals present in the config will be showed.
458+
By default, all individuals present in the config will be shown.
459459
460460
codec: str, optional, default="mp4v"
461461
Codec for labeled video. For available options, see
@@ -613,9 +613,6 @@ def create_labeled_video(
613613
)
614614
)
615615

616-
individuals = auxfun_multianimal.IntersectionofIndividualsandOnesGivenbyUser(
617-
cfg, displayedindividuals
618-
)
619616
if draw_skeleton:
620617
bodyparts2connect = cfg["skeleton"]
621618
if displayedbodyparts != "all":
@@ -644,7 +641,7 @@ def create_labeled_video(
644641
DLCscorerlegacy,
645642
track_method,
646643
cfg,
647-
individuals,
644+
displayedindividuals,
648645
color_by,
649646
bodyparts,
650647
codec,
@@ -757,8 +754,14 @@ def proc_video(
757754
print("Labeled video already created. Skipping...")
758755
return
759756

760-
if all(individuals):
761-
df = df.loc(axis=1)[:, individuals]
757+
if individuals != "all":
758+
if isinstance(individuals, str):
759+
individuals = [individuals]
760+
761+
if all(individuals) and "individuals" in df.columns.names:
762+
mask = df.columns.get_level_values("individuals").isin(individuals)
763+
df = df.loc[:, mask]
764+
762765
cropping = metadata["data"]["cropping"]
763766
[x1, x2, y1, y2] = metadata["data"]["cropping_parameters"]
764767
labeled_bpts = [

0 commit comments

Comments
 (0)