@@ -82,32 +82,32 @@ def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
8282
8383 def load_compiled_state_dict (self , model_state_dict_path : str ) -> None :
8484 # remove _orig_mod from the state dict keys
85- state_dict = torch .load (model_state_dict_path )
85+ state_dict = torch .load (model_state_dict_path , weights_only = True )
8686 new_state_dict = {}
8787 for key in state_dict .keys ():
8888 new_state_dict [key .replace ("_orig_mod." , "" )] = state_dict [key ]
89- self .load_state_dict (new_state_dict , strict = True )
89+ self .load_state_dict (new_state_dict , strict = True , weights_only = True )
9090
9191 def load_model (self , model_state_dict_path : str ) -> None :
9292 """
9393 Load the model from the state dict
9494 """
95- self .load_state_dict (torch .load (model_state_dict_path ), strict = True )
95+ self .load_state_dict (torch .load (model_state_dict_path ), strict = True , weights_only = True )
9696
9797 def load_model_state_dict (self , model_state_dict : dict ) -> None :
9898 """
9999 Load the model from the state dict
100100 """
101- self .load_state_dict (model_state_dict , strict = True )
101+ self .load_state_dict (model_state_dict , strict = True , weights_only = True )
102102
103103 def load_cnn_backbone (self , model_path : str ) -> None :
104104 """
105105 Load the CNN backbone from the model path
106106 """
107- self .load_state_dict (torch .load (model_path ).state_dict (), strict = True )
107+ self .load_state_dict (torch .load (model_path ).state_dict (), strict = True , weights_only = True )
108108
109109 def load_gnn_backbone (self , model_path : str ) -> None :
110110 """
111111 Load the GNN backbone from the model path
112112 """
113- self .load_state_dict (torch .load (model_path ).state_dict (), strict = True )
113+ self .load_state_dict (torch .load (model_path ).state_dict (), strict = True , weights_only = True )
0 commit comments