@@ -92,7 +92,8 @@ def test_get_raw_local_maps():
9292 )
9393 assert local_maps .dtype == torch .float32
9494 saved_local_maps = torch .load (
95- os .path .join (script_dir , "data/coverage_env_utils/local_maps.pt" )
95+ os .path .join (script_dir , "data/coverage_env_utils/local_maps.pt" ),
96+ weights_only = True
9697 )
9798 is_all_close = torch .allclose (local_maps , saved_local_maps )
9899 assert is_all_close
@@ -112,7 +113,8 @@ def test_get_raw_obstacle_maps():
112113 )
113114 assert obstacle_maps .dtype == torch .float32
114115 saved_obstacle_maps = torch .load (
115- os .path .join (script_dir , "data/coverage_env_utils/obstacle_maps.pt" )
116+ os .path .join (script_dir , "data/coverage_env_utils/obstacle_maps.pt" ),
117+ weights_only = True
116118 )
117119 is_all_close = torch .allclose (obstacle_maps , saved_obstacle_maps )
118120 assert is_all_close
@@ -128,7 +130,8 @@ def test_get_communication_maps():
128130 assert comm_maps .shape == (params .pNumRobots , 2 , 32 , 32 )
129131 assert comm_maps .dtype == torch .float32
130132 saved_comm_maps = torch .load (
131- os .path .join (script_dir , "data/coverage_env_utils/comm_maps.pt" )
133+ os .path .join (script_dir , "data/coverage_env_utils/comm_maps.pt" ),
134+ weights_only = True
132135 )
133136 is_all_close = torch .allclose (comm_maps , saved_comm_maps )
134137 max_error = torch .max (torch .abs (comm_maps - saved_comm_maps ))
@@ -147,7 +150,8 @@ def test_resize_maps():
147150 assert resized_local_maps .shape == (params .pNumRobots , 32 , 32 )
148151 assert resized_local_maps .dtype == torch .float32
149152 saved_resized_local_maps = torch .load (
150- os .path .join (script_dir , "data/coverage_env_utils/resized_local_maps.pt" )
153+ os .path .join (script_dir , "data/coverage_env_utils/resized_local_maps.pt" ),
154+ weights_only = True
151155 )
152156 is_all_close = torch .allclose (resized_local_maps , saved_resized_local_maps )
153157 assert is_all_close
@@ -162,7 +166,7 @@ def test_get_maps():
162166 assert isinstance (maps , torch .Tensor )
163167 assert maps .shape == (params .pNumRobots , 4 , 32 , 32 )
164168 assert maps .dtype == torch .float32
165- saved_maps = torch .load (os .path .join (script_dir , "data/coverage_env_utils/maps.pt" ))
169+ saved_maps = torch .load (os .path .join (script_dir , "data/coverage_env_utils/maps.pt" ), weights_only = True )
166170 is_all_close = torch .allclose (maps , saved_maps )
167171 assert is_all_close
168172 is_all_equal = torch .equal (maps , saved_maps )
@@ -178,7 +182,8 @@ def test_get_voronoi_features():
178182 assert voronoi_features .shape == (params .pNumRobots , feature_len )
179183 assert voronoi_features .dtype == torch .float32
180184 saved_voronoi_features = torch .load (
181- os .path .join (script_dir , "data/coverage_env_utils/voronoi_features.pt" )
185+ os .path .join (script_dir , "data/coverage_env_utils/voronoi_features.pt" ),
186+ weights_only = True
182187 )
183188 is_all_close = torch .allclose (voronoi_features , saved_voronoi_features )
184189 assert is_all_close
@@ -194,7 +199,8 @@ def test_get_robot_positions():
194199 assert robot_positions .shape == (params .pNumRobots , 2 )
195200 assert robot_positions .dtype == torch .float32
196201 saved_robot_positions = torch .load (
197- os .path .join (script_dir , "data/coverage_env_utils/robot_positions.pt" )
202+ os .path .join (script_dir , "data/coverage_env_utils/robot_positions.pt" ),
203+ weights_only = True
198204 )
199205 is_all_close = torch .allclose (robot_positions , saved_robot_positions )
200206 assert is_all_close
@@ -210,7 +216,8 @@ def test_get_weights():
210216 assert weights .shape == (params .pNumRobots , params .pNumRobots )
211217 assert weights .dtype == torch .float32
212218 saved_weights = torch .load (
213- os .path .join (script_dir , "data/coverage_env_utils/weights.pt" )
219+ os .path .join (script_dir , "data/coverage_env_utils/weights.pt" ),
220+ weights_only = True
214221 )
215222 is_all_close = torch .allclose (weights , saved_weights )
216223 assert is_all_close
@@ -229,9 +236,10 @@ def test_get_torch_geometric_data():
229236 assert data .x .dtype == torch .float32
230237 assert data .edge_index .shape == (2 , 16 )
231238 assert data .edge_index .dtype == torch .long
232- saved_data = torch .load (
233- os .path .join (script_dir , "data/coverage_env_utils/torch_geometric_data.pt" )
234- )
239+ saved_data = torch_geometric .data .data .Data .from_dict (torch .load (
240+ os .path .join (script_dir , "data/coverage_env_utils/torch_geometric_data.pt" ),
241+ weights_only = True
242+ ))
235243 is_all_close = torch .allclose (data .x , saved_data .x )
236244 assert is_all_close
237245 is_all_equal = torch .equal (data .x , saved_data .x )
0 commit comments