Skip to content

Commit 67e3ed8

Browse files
Saurav AgarwalSaurav Agarwal
authored andcommitted
Update parameter convention
1 parent 4663d24 commit 67e3ed8

File tree

5 files changed

+18
-22
lines changed

5 files changed

+18
-22
lines changed

cppsrc/main/test_cnn.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ int main(int argc, char* argv[]) {
5454
YAML::Node config = YAML::LoadFile(config_file);
5555
auto cnn_config = config["CNN"];
5656

57-
std::string data_dir = config["pDataDir"].as<std::string>();
57+
std::string data_dir = config["DataDir"].as<std::string>();
5858

5959
CoverageControlCNN model(cnn_config);
6060
torch::load(model, config["CNNModel"]["Dir"].as<std::string>() + config["CNNModel"]["Model"].as<std::string>());

cppsrc/torch/include/CoverageControlTorch/cnn_backbone.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,8 @@ namespace CoverageControlTorch {
6060
/* std::cout << "x size: " << x.sizes() << std::endl; */
6161
}
6262
x = x.flatten(1);
63-
/* std::cout << "x size: " << x.sizes() << std::endl; */
6463
x = torch::leaky_relu(linear_1_->forward(x));
65-
/* std::cout << "x size: " << x.sizes() << std::endl; */
6664
x = torch::leaky_relu(linear_2_->forward(x));
67-
/* std::cout << "x size: " << x.sizes() << std::endl; */
6865
/* x = linear_3_->forward(x); */
6966
return x;
7067
}

cppsrc/torch/include/CoverageControlTorch/generate_dataset.h

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ namespace CoverageControlTorch {
7474
data_dir_append_ = data_dir_append;
7575

7676
LoadConfigs(config_file);
77-
dataset_size_ = config_["pNumDataset"].as<size_t>();
77+
dataset_size_ = config_["NumDataset"].as<size_t>();
7878
num_robots_ = env_params_.pNumRobots;
7979
comm_range_ = (float)env_params_.pCommunicationRange;
8080
env_resolution_ = (float)env_params_.pResolution;
81-
map_size_ = config_["pCNNMapSize"].as<int>();
82-
every_num_step_ = config_["pEveryNumSteps"].as<size_t>();
83-
trigger_size_ = config_["pTriggerPostProcessing"].as<size_t>();
81+
map_size_ = config_["CNNMapSize"].as<int>();
82+
every_num_step_ = config_["EveryNumSteps"].as<size_t>();
83+
trigger_size_ = config_["TriggerPostProcessing"].as<size_t>();
8484
if(trigger_size_ == 0 or trigger_size_ > dataset_size_ ) {
8585
trigger_size_ = dataset_size_;
8686
}
@@ -208,7 +208,7 @@ namespace CoverageControlTorch {
208208
torch::Tensor diagonal_mask = torch::eye(edge_weights.size(1)).repeat({edge_weights.size(0), 1, 1}).to(torch::kBool);
209209
edge_weights.masked_fill_(diagonal_mask, 0);
210210
edge_weights.to(torch::kCPU);
211-
if(config_["pSaveAsSparseQ"].as<bool>()){
211+
if(config_["SaveAsSparseQ"].as<bool>()){
212212
torch::save(edge_weights.to_sparse(), data_folder_ + "edge_weights.pt");
213213
} else {
214214
torch::save(edge_weights, data_folder_ + "edge_weights.pt");
@@ -259,15 +259,15 @@ namespace CoverageControlTorch {
259259
torch::save(actions_, data_folder_ + "/actions.pt");
260260
torch::save(coverage_features_, data_folder_ + "/coverage_features.pt");
261261

262-
if(config_["pSaveAsSparseQ"].as<bool>()) {
262+
if(config_["SaveAsSparseQ"].as<bool>()) {
263263
torch::save(comm_maps_.to_sparse(), data_folder_ + "/comm_maps.pt");
264264
torch::save(obstacle_maps_.to_sparse(), data_folder_ + "/obstacle_maps.pt");
265265
}
266266
else {
267267
torch::save(comm_maps_, data_folder_ + "/comm_maps.pt");
268268
torch::save(obstacle_maps_, data_folder_ + "/obstacle_maps.pt");
269269
}
270-
if(config_["pNormalizeQ"].as<bool>()) {
270+
if(config_["NormalizeQ"].as<bool>()) {
271271
torch::Tensor actions_mean = at::mean(actions_.view({-1,2}), 0);
272272
torch::Tensor actions_std = at::std(actions_.view({-1,2}), 0);
273273
torch::Tensor normalized_actions = (actions_ - actions_mean)/actions_std;
@@ -310,27 +310,26 @@ namespace CoverageControlTorch {
310310
}
311311

312312
config_ = YAML::LoadFile(config_file);
313-
data_dir_ = config_["pDataDir"].as<std::string>();
313+
data_dir_ = config_["DataDir"].as<std::string>();
314314
data_folder_ = data_dir_ + "/data/" + data_dir_append_ + "/";
315315

316-
// Check if config_["pDataDir"] directory exists
317-
std::string data_dir = config_["pDataDir"].as<std::string>();
318-
if(not std::filesystem::exists(data_dir)) {
319-
throw std::runtime_error("Could not find data directory: " + data_dir);
316+
// Check if config_["DataDir"] directory exists
317+
if(not std::filesystem::exists(data_dir_)) {
318+
throw std::runtime_error("Could not find data directory: " + data_dir_);
320319
}
321320

322321
if(!std::filesystem::exists(data_folder_)) {
323322
std::filesystem::create_directories(data_folder_);
324323
}
325-
// Check if config_["pEnvironmentConfig"] file exists
326-
std::string env_config_file = data_dir + "/" + config_["pEnvironmentConfig"].as<std::string>();
324+
// Check if config_["EnvironmentConfig"] file exists
325+
std::string env_config_file = data_dir_ + "/" + config_["EnvironmentConfig"].as<std::string>();
327326
if(not std::filesystem::exists(env_config_file)) {
328327
throw std::runtime_error("Could not find environment config file: " + env_config_file);
329328
}
330329
env_params_ = CoverageControl::Parameters(env_config_file);
331330

332-
std::string resizer_model_path = data_dir + "/" + config_["pTorchVisionTransformJIT"].as<std::string>();
333-
// Check if config_["pResizerModel"] file exists
331+
std::string resizer_model_path = data_dir_ + "/" + config_["TorchVisionTransformJIT"].as<std::string>();
332+
// Check if config_["ResizerModel"] file exists
334333
if(not std::filesystem::exists(resizer_model_path)) {
335334
throw std::runtime_error("Could not find resizer model file: " + resizer_model_path);
336335
}

cppsrc/torch/include/CoverageControlTorch/train_cnn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ namespace CoverageControlTorch {
171171
throw std::runtime_error("Could not open config file: " + config_file);
172172
}
173173
config_ = YAML::LoadFile(config_file);
174-
data_dir_ = config_["pDataDir"].as<std::string>();
174+
data_dir_ = config_["DataDir"].as<std::string>();
175175
batch_size_ = config_["CNNTraining"]["BatchSize"].as<size_t>();
176176
num_epochs_ = config_["CNNTraining"]["NumEpochs"].as<size_t>();
177177
learning_rate_ = config_["CNNTraining"]["LearningRate"].as<float>();

cppsrc/torch/include/CoverageControlTorch/train_gnn.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ namespace CoverageControlTorch {
192192
throw std::runtime_error("Could not open config file: " + config_file);
193193
}
194194
config_ = YAML::LoadFile(config_file);
195-
data_dir_ = config_["pDataDir"].as<std::string>();
195+
data_dir_ = config_["DataDir"].as<std::string>();
196196
batch_size_ = config_["GNNTraining"]["BatchSize"].as<size_t>();
197197
num_epochs_ = config_["GNNTraining"]["NumEpochs"].as<size_t>();
198198
learning_rate_ = config_["GNNTraining"]["LearningRate"].as<float>();

0 commit comments

Comments
 (0)