@@ -74,13 +74,13 @@ namespace CoverageControlTorch {
7474 data_dir_append_ = data_dir_append;
7575
7676 LoadConfigs (config_file);
77- dataset_size_ = config_[" pNumDataset " ].as <size_t >();
77+ dataset_size_ = config_[" NumDataset " ].as <size_t >();
7878 num_robots_ = env_params_.pNumRobots ;
7979 comm_range_ = (float )env_params_.pCommunicationRange ;
8080 env_resolution_ = (float )env_params_.pResolution ;
81- map_size_ = config_[" pCNNMapSize " ].as <int >();
82- every_num_step_ = config_[" pEveryNumSteps " ].as <size_t >();
83- trigger_size_ = config_[" pTriggerPostProcessing " ].as <size_t >();
81+ map_size_ = config_[" CNNMapSize " ].as <int >();
82+ every_num_step_ = config_[" EveryNumSteps " ].as <size_t >();
83+ trigger_size_ = config_[" TriggerPostProcessing " ].as <size_t >();
8484 if (trigger_size_ == 0 or trigger_size_ > dataset_size_ ) {
8585 trigger_size_ = dataset_size_;
8686 }
@@ -208,7 +208,7 @@ namespace CoverageControlTorch {
208208 torch::Tensor diagonal_mask = torch::eye (edge_weights.size (1 )).repeat ({edge_weights.size (0 ), 1 , 1 }).to (torch::kBool );
209209 edge_weights.masked_fill_ (diagonal_mask, 0 );
210210 edge_weights.to (torch::kCPU );
211- if (config_[" pSaveAsSparseQ " ].as <bool >()){
211+ if (config_[" SaveAsSparseQ " ].as <bool >()){
212212 torch::save (edge_weights.to_sparse (), data_folder_ + " edge_weights.pt" );
213213 } else {
214214 torch::save (edge_weights, data_folder_ + " edge_weights.pt" );
@@ -259,15 +259,15 @@ namespace CoverageControlTorch {
259259 torch::save (actions_, data_folder_ + " /actions.pt" );
260260 torch::save (coverage_features_, data_folder_ + " /coverage_features.pt" );
261261
262- if (config_[" pSaveAsSparseQ " ].as <bool >()) {
262+ if (config_[" SaveAsSparseQ " ].as <bool >()) {
263263 torch::save (comm_maps_.to_sparse (), data_folder_ + " /comm_maps.pt" );
264264 torch::save (obstacle_maps_.to_sparse (), data_folder_ + " /obstacle_maps.pt" );
265265 }
266266 else {
267267 torch::save (comm_maps_, data_folder_ + " /comm_maps.pt" );
268268 torch::save (obstacle_maps_, data_folder_ + " /obstacle_maps.pt" );
269269 }
270- if (config_[" pNormalizeQ " ].as <bool >()) {
270+ if (config_[" NormalizeQ " ].as <bool >()) {
271271 torch::Tensor actions_mean = at::mean (actions_.view ({-1 ,2 }), 0 );
272272 torch::Tensor actions_std = at::std (actions_.view ({-1 ,2 }), 0 );
273273 torch::Tensor normalized_actions = (actions_ - actions_mean)/actions_std;
@@ -310,27 +310,26 @@ namespace CoverageControlTorch {
310310 }
311311
312312 config_ = YAML::LoadFile (config_file);
313- data_dir_ = config_[" pDataDir " ].as <std::string>();
313+ data_dir_ = config_[" DataDir " ].as <std::string>();
314314 data_folder_ = data_dir_ + " /data/" + data_dir_append_ + " /" ;
315315
316- // Check if config_["pDataDir"] directory exists
317- std::string data_dir = config_[" pDataDir" ].as <std::string>();
318- if (not std::filesystem::exists (data_dir)) {
319- throw std::runtime_error (" Could not find data directory: " + data_dir);
316+ // Check if config_["DataDir"] directory exists
317+ if (not std::filesystem::exists (data_dir_)) {
318+ throw std::runtime_error (" Could not find data directory: " + data_dir_);
320319 }
321320
322321 if (!std::filesystem::exists (data_folder_)) {
323322 std::filesystem::create_directories (data_folder_);
324323 }
325- // Check if config_["pEnvironmentConfig "] file exists
326- std::string env_config_file = data_dir + " /" + config_[" pEnvironmentConfig " ].as <std::string>();
324+ // Check if config_["EnvironmentConfig "] file exists
325+ std::string env_config_file = data_dir_ + " /" + config_[" EnvironmentConfig " ].as <std::string>();
327326 if (not std::filesystem::exists (env_config_file)) {
328327 throw std::runtime_error (" Could not find environment config file: " + env_config_file);
329328 }
330329 env_params_ = CoverageControl::Parameters (env_config_file);
331330
332- std::string resizer_model_path = data_dir + " /" + config_[" pTorchVisionTransformJIT " ].as <std::string>();
333- // Check if config_["pResizerModel "] file exists
331+ std::string resizer_model_path = data_dir_ + " /" + config_[" TorchVisionTransformJIT " ].as <std::string>();
332+ // Check if config_["ResizerModel "] file exists
334333 if (not std::filesystem::exists (resizer_model_path)) {
335334 throw std::runtime_error (" Could not find resizer model file: " + resizer_model_path);
336335 }
0 commit comments