88#
99# Licensed under GNU Lesser General Public License v3.0
1010#
11+ from typing import List , Optional
12+
1113import matplotlib .pyplot as plt
1214import networkx as nx
1315import numpy as np
@@ -891,8 +893,11 @@ def concatenate_data(self):
891893
892894 def format_df (self , animal_names = None ):
893895 data = self .concatenate_data ()
894- if not animal_names or len (animal_names ) != self .n_tracks :
896+ if not animal_names or len (animal_names ) < self .n_tracks :
895897 animal_names = [f"ind{ i } " for i in range (1 , self .n_tracks + 1 )]
898+ elif len (animal_names ) > self .n_tracks :
899+ animal_names = animal_names [:self .n_tracks ]
900+
896901 coords = ["x" , "y" , "likelihood" ]
897902 n_multi_bpts = data .shape [1 ] // (len (animal_names ) * len (coords ))
898903 n_unique_bpts = 0 if self .single is None else self .single .data .shape [1 ]
@@ -1031,6 +1036,7 @@ def stitch_tracklets(
10311036 shuffle = 1 ,
10321037 trainingsetindex = 0 ,
10331038 n_tracks = None ,
1039+ animal_names : Optional [List [str ]] = None ,
10341040 min_length = 10 ,
10351041 split_tracklets = True ,
10361042 prestitch_residuals = True ,
@@ -1071,6 +1077,13 @@ def stitch_tracklets(
10711077 passed if the number of animals in the video is different from
10721078 the number of animals the model was trained on.
10731079
1080+ animal_names: list, optional
1081+ If you want the names given to individuals in the labeled data file, you can
1082+ specify those names as a list here. If given and `n_tracks` is None, `n_tracks`
1083+ will be set to `len(animal_names)`. If `n_tracks` is not None, then it must be
1084+ equal to `len(animal_names)`. If it is not given, then `animal_names` will
1085+ be loaded from the `individuals` in the project config.yaml file.
1086+
10741087 min_length : int, optional
10751088 Tracklets less than `min_length` frames of length
10761089 are considered to be residuals; i.e., they do not participate
@@ -1107,8 +1120,8 @@ def stitch_tracklets(
11071120 tracklets should be stitched together, the lower the returned value.
11081121
11091122 destfolder: string, optional
1110- Specifies the destination folder for analysis data (default is the path of the video). Note that for subsequent analysis this
1111- folder also needs to be passed.
1123+ Specifies the destination folder for analysis data (default is the path of the
1124+ video). Note that for subsequent analysis this folder also needs to be passed.
11121125
11131126 track_method: string, optional
11141127 Specifies the tracker used to generate the pose estimation data.
@@ -1135,7 +1148,14 @@ def stitch_tracklets(
11351148 cfg = auxiliaryfunctions .read_config (config_path )
11361149 track_method = auxfun_multianimal .get_track_method (cfg , track_method = track_method )
11371150
1138- animal_names = cfg ["individuals" ]
1151+ if animal_names is None :
1152+ animal_names = cfg ["individuals" ]
1153+ elif n_tracks is not None and n_tracks != len (animal_names ):
1154+ raise ValueError (
1155+ "When setting both `n_tracks` and `animal_names`, `n_tracks` must be equal "
1156+ f"to len(animal_names)`. Found `n_tracks`={ n_tracks } and `animal_names`="
1157+ f"{ animal_names } of length { len (animal_names )} .`" )
1158+
11391159 if n_tracks is None :
11401160 n_tracks = len (animal_names )
11411161
0 commit comments