Skip to content

Commit ad149d6

Browse files
Saurav AgarwalSaurav Agarwal
authored andcommitted
Modularize cnn
1 parent c5372a4 commit ad149d6

File tree

4 files changed

+144
-41
lines changed

4 files changed

+144
-41
lines changed

cppsrc/main/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,7 @@ install(TARGETS data_generation DESTINATION ${CMAKE_INSTALL_BINDIR})
3838
add_executable(train_cnn train_cnn.cpp)
3939
target_link_libraries(train_cnn PRIVATE compiler_flags CoverageControlCore CoverageControlTorch)
4040
install(TARGETS train_cnn DESTINATION ${CMAKE_INSTALL_BINDIR})
41+
42+
add_executable(test_cnn test_cnn.cpp)
43+
target_link_libraries(test_cnn PRIVATE compiler_flags CoverageControlCore CoverageControlTorch)
44+
install(TARGETS test_cnn DESTINATION ${CMAKE_INSTALL_BINDIR})

cppsrc/main/test_cnn.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/** Main program for testing a CNN
2+
*
3+
* @file: train_cnn.cpp
4+
*/
5+
6+
#include <string>
7+
#include <iostream>
8+
#include <CoverageControlTorch/cnn_module.h>
9+
10+
using namespace CoverageControlTorch;
11+
void LoadDataset(std::string const &data_dir, int const image_size, int const output_dim, torch::Tensor &maps, torch::Tensor &features, torch::Tensor &features_mean, torch::Tensor &features_std) {
12+
std::string local_maps_file = data_dir + "/local_maps.pt";
13+
std::string comm_maps_file = data_dir + "/comm_maps.pt";
14+
std::string obstacle_maps_file = data_dir + "/obstacle_maps.pt";
15+
std::string features_file = data_dir + "/normalized_coverage_features.pt";
16+
17+
torch::Tensor local_maps;
18+
torch::load(local_maps, local_maps_file);
19+
local_maps = local_maps.unsqueeze(2).view({-1, 1, image_size, image_size});
20+
torch::Tensor comm_maps;
21+
torch::load(comm_maps, comm_maps_file);
22+
comm_maps = comm_maps.to_dense().view({-1, 2, image_size, image_size});
23+
torch::Tensor obstacle_maps;
24+
torch::load(obstacle_maps, obstacle_maps_file);
25+
obstacle_maps = obstacle_maps.to_dense().unsqueeze(2).view({-1, 1, image_size, image_size});
26+
27+
torch::load(features, features_file);
28+
features = features.view({-1, features.size(2)});
29+
features = features.index({Slice(), Slice(0, output_dim)}).to(torch::kCPU);
30+
31+
maps = torch::cat({local_maps, comm_maps, obstacle_maps}, 1).to(torch::kCPU);
32+
33+
torch::load(features_mean, data_dir + "/coverage_features_mean.pt");
34+
torch::load(features_std, data_dir + "/coverage_features_std.pt");
35+
36+
std::cout << "maps shape: " << maps.sizes() << std::endl;
37+
38+
}
39+
40+
int main(int argc, char* argv[]) {
41+
42+
if (argc < 2) {
43+
std::cout << "Usage: ./test_cnn <yaml>" << std::endl;
44+
return 1;
45+
}
46+
std::string config_file = std::string(argv[1]);
47+
48+
torch::Device device(torch::kCPU);
49+
if (torch::cuda::is_available()) {
50+
device = torch::kCUDA;
51+
std::cout << "Using CUDA" << std::endl;
52+
}
53+
54+
YAML::Node config = YAML::LoadFile(config_file);
55+
auto cnn_config = config["CNN"];
56+
57+
std::string data_dir = config["pDataDir"].as<std::string>();
58+
59+
CoverageControlCNN model(cnn_config);
60+
torch::load(model, data_dir + config["CNNTraining"]["ModelCkpt"].as<std::string>());
61+
model->to(device);
62+
63+
int image_size = cnn_config["ImageSize"].as<int>();
64+
int output_dim = cnn_config["OutputDim"].as<int>();
65+
66+
torch::Tensor maps, features, features_mean, features_std;
67+
LoadDataset(data_dir, image_size, output_dim, maps, features, features_mean, features_std);
68+
std::cout << "Loaded dataset" << std::endl;
69+
70+
maps = maps.to(torch::kCPU);
71+
torch::Tensor sub_maps = maps.index({Slice(0, 1000)});
72+
sub_maps = sub_maps.to(device);
73+
torch::Tensor sub_features = features.index({Slice(0, 1000)});
74+
std::cout << "submaps created" << std::endl;
75+
auto pred = model->forward(sub_maps).to(torch::kCPU);
76+
std::cout << "Pred done" << std::endl;
77+
auto loss = torch::mse_loss(pred, sub_features);
78+
std::cout << "Val loss: " << loss << std::endl;
79+
return 0;
80+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
2+
#ifndef COVERAGECONTROL_CNN_MODULE_H_
3+
#define COVERAGECONTROL_CNN_MODULE_H_
4+
5+
#include <thread>
6+
#include <string>
7+
#include <yaml-cpp/yaml.h>
8+
#include <torch/torch.h>
9+
#include "cnn_backbone.h"
10+
11+
using namespace torch::indexing;
12+
namespace F = torch::nn::functional;
13+
14+
namespace CoverageControlTorch {
15+
16+
17+
struct CoverageControlCNNImpl : torch::nn::Module {
18+
int input_dim_ = 4;
19+
int output_dim_ = 7;
20+
int num_layers_ = 2;
21+
int latent_size_ = 8;
22+
int kernel_size_ = 3;
23+
int image_size_ = 32;
24+
25+
CNNBackbone cnn_backbone_;
26+
torch::nn::Linear linear_;
27+
28+
CoverageControlCNNImpl(YAML::Node config) : CoverageControlCNNImpl(
29+
config["InputDim"].as<int>(),
30+
config["OutputDim"].as<int>(),
31+
config["NumLayers"].as<int>(),
32+
config["LatentSize"].as<int>(),
33+
config["KernelSize"].as<int>(),
34+
config["ImageSize"].as<int>()) {}
35+
36+
CoverageControlCNNImpl(int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
37+
input_dim_(input_dim),
38+
output_dim_(output_dim),
39+
num_layers_(num_layers),
40+
latent_size_(latent_size),
41+
kernel_size_(kernel_size),
42+
image_size_(image_size),
43+
cnn_backbone_(register_module("cnn_backbone", CNNBackbone(input_dim_, output_dim_, num_layers_, latent_size_, kernel_size_, image_size_))),
44+
linear_(register_module("linear", torch::nn::Linear(2 * output_dim_, output_dim))) {
45+
}
46+
47+
torch::Tensor forward(torch::Tensor x) {
48+
x = cnn_backbone_->forward(x);
49+
x = linear_->forward(x);
50+
return x;
51+
}
52+
};
53+
54+
TORCH_MODULE(CoverageControlCNN);
55+
56+
} // CoverageControlTorch
57+
#endif // COVERAGECONTROL_CNN_MODULE_H_

cppsrc/torch/include/CoverageControlTorch/train_cnn.h

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -17,44 +17,13 @@
1717
#include <yaml-cpp/yaml.h>
1818
#include <torch/torch.h>
1919

20-
#include "cnn_backbone.h"
20+
#include "cnn_module.h"
2121

2222
using namespace torch::indexing;
2323
namespace F = torch::nn::functional;
2424

2525
namespace CoverageControlTorch {
2626

27-
struct CoverageControlCNNImpl : torch::nn::Module {
28-
int input_dim_ = 4;
29-
int output_dim_ = 7;
30-
int num_layers_ = 2;
31-
int latent_size_ = 8;
32-
int kernel_size_ = 3;
33-
int image_size_ = 32;
34-
35-
CNNBackbone cnn_backbone_;
36-
torch::nn::Linear linear_;
37-
38-
CoverageControlCNNImpl(int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
39-
input_dim_(input_dim),
40-
output_dim_(output_dim),
41-
num_layers_(num_layers),
42-
latent_size_(latent_size),
43-
kernel_size_(kernel_size),
44-
image_size_(image_size),
45-
cnn_backbone_(register_module("cnn_backbone", CNNBackbone(input_dim_, output_dim_, num_layers_, latent_size_, kernel_size_, image_size_))),
46-
linear_(register_module("linear", torch::nn::Linear(2 * output_dim_, output_dim))) {
47-
}
48-
49-
torch::Tensor forward(torch::Tensor x) {
50-
x = cnn_backbone_->forward(x);
51-
x = linear_->forward(x);
52-
return x;
53-
}
54-
};
55-
56-
TORCH_MODULE(CoverageControlCNN);
57-
5827
class TrainCNN {
5928
private:
6029
torch::Tensor train_maps_, val_maps_;
@@ -94,14 +63,7 @@ namespace CoverageControlTorch {
9463

9564
LoadDataset();
9665

97-
CoverageControlCNN model(
98-
cnn_config_["InputDim"].as<int>(),
99-
cnn_config_["OutputDim"].as<int>(),
100-
cnn_config_["NumLayers"].as<int>(),
101-
cnn_config_["LatentSize"].as<int>(),
102-
cnn_config_["KernelSize"].as<int>(),
103-
image_size_);
104-
66+
CoverageControlCNN model(cnn_config_);
10567

10668
model->to(device_);
10769

@@ -226,7 +188,7 @@ namespace CoverageControlTorch {
226188
learning_rate_ = config_["CNNTraining"]["LearningRate"].as<float>();
227189
weight_decay_ = config_["CNNTraining"]["WeightDecay"].as<float>();
228190

229-
cnn_config_ = config_["CNN"] ;
191+
cnn_config_ = config_["CNN"];
230192
image_size_ = cnn_config_["ImageSize"].as<int>();
231193
}
232194

0 commit comments

Comments
 (0)