Skip to content

Commit f33c66e

Browse files
author
Saurav Agarwal
committed
Add basic CNN training
1 parent a6d70bc commit f33c66e

7 files changed

Lines changed: 352 additions & 1 deletion

File tree

cppsrc/main/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ target_compile_options(compiler_flags INTERFACE
3030
"$<${gcc_like_cxx}:$<BUILD_INTERFACE:-Wall;-Wextra;-Wshadow;-Wformat=2;-Wunused;-pedantic>>"
3131
"$<${msvc_cxx}:$<BUILD_INTERFACE:-W3>>"
3232
)
# Data-generation executable: builds datasets consumed by the CNN trainer.
add_executable(data_generation data_generation.cpp)
target_link_libraries(data_generation PRIVATE compiler_flags CoverageControlCore CoverageControlTorch)
install(TARGETS data_generation DESTINATION ${CMAKE_INSTALL_BINDIR})


# CNN training executable (cppsrc/main/train_cnn.cpp); links the same
# core/torch libraries and shared warning flags as data_generation.
add_executable(train_cnn train_cnn.cpp)
target_link_libraries(train_cnn PRIVATE compiler_flags CoverageControlCore CoverageControlTorch)
install(TARGETS train_cnn DESTINATION ${CMAKE_INSTALL_BINDIR})

cppsrc/main/train_cnn.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
/** Main program for training a CNN for image classification.
 *
 * @file: train_cnn.cpp
 */

#include <iostream>
#include <string>
#include <CoverageControlTorch/train_cnn.h>

/** Entry point.
 *
 * Expects exactly one argument: the path to a YAML config file (parsed by
 * TrainCNN::LoadConfigs). Runs training to completion.
 * @return 0 on success, 1 on bad usage.
 */
int main(int argc, char* argv[]) {

	if (argc < 2) {
		// BUG FIX: the argument is a config file (handed to TrainCNN's
		// constructor), not a dataset directory as the old message claimed.
		std::cout << "Usage: ./train_cnn <config_file>" << std::endl;
		return 1;
	}

	std::string config_file = std::string(argv[1]);
	CoverageControlTorch::TrainCNN train_cnn(config_file);
	train_cnn.Train();

	return 0;
}

cppsrc/setup.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ done
1515
BUILD_DIR=${COVERAGECONTROL_WS}/build
1616
INSTALL_DIR=${COVERAGECONTROL_WS}/install
1717

18-
CMAKE_END_FLAGS="-DCMAKE_BUILD_TYPE=Release -G Ninja"
18+
CMAKE_END_FLAGS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -G Ninja"
1919

2020
CleanBuild () {
2121
rm -rf ${BUILD_DIR}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
#ifndef COVERAGECONTROL_CNN_BACKBONE_H_
#define COVERAGECONTROL_CNN_BACKBONE_H_

#include <string>
#include <vector>

#include <torch/torch.h>

namespace CoverageControlTorch {

	/** CNN backbone for coverage-feature prediction.
	 *
	 * Architecture: num_layers_ stages of (Conv2d -> BatchNorm2d -> tanh),
	 * then flatten and two Linear layers, each followed by tanh.
	 * Convolutions are unpadded, so each stage shrinks the spatial size by
	 * (kernel_size_ - 1).
	 *
	 * Input:  (batch, input_dim_, image_size_, image_size_)
	 * Output: (batch, output_dim_), values in (-1, 1) from the final tanh.
	 */
	struct CoverageControlCNNImpl : torch::nn::Module {
		int input_dim_ = 4;    ///< number of input channels
		int output_dim_ = 7;   ///< number of predicted features
		int num_layers_ = 2;   ///< number of conv + batch-norm stages
		int latent_size_ = 8;  ///< conv channel width and FC hidden size
		int kernel_size_ = 3;  ///< square convolution kernel size
		int image_size_ = 32;  ///< square input spatial size

		torch::nn::ModuleList conv_layers_;
		torch::nn::ModuleList batch_norm_layers_;
		torch::nn::Linear linear_1_;
		torch::nn::Linear linear_2_;

		CoverageControlCNNImpl(int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
			input_dim_(input_dim),
			output_dim_(output_dim),
			num_layers_(num_layers),
			latent_size_(latent_size),
			kernel_size_(kernel_size),
			image_size_(image_size),
			conv_layers_(torch::nn::ModuleList()),
			batch_norm_layers_(torch::nn::ModuleList()),
			linear_1_(nullptr),
			linear_2_(nullptr) {

			// Per-stage channel counts: input_dim -> latent -> ... -> latent.
			std::vector<int> layers_;
			layers_.push_back(input_dim_);
			for(int i = 0; i < num_layers_; ++i) {
				layers_.push_back(latent_size_);
			}

			for(int i = 0; i < num_layers_; ++i) {
				// BUG FIX: kernel size was hardcoded to 3 here while the
				// flatten_size computation below uses kernel_size_; any
				// config with KernelSize != 3 would fail at the first Linear.
				conv_layers_->push_back(register_module("conv" + std::to_string(i),
							torch::nn::Conv2d(torch::nn::Conv2dOptions(layers_[i], layers_[i + 1], kernel_size_))));
				batch_norm_layers_->push_back(register_module("batch_norm" + std::to_string(i),
							torch::nn::BatchNorm2d(layers_[i + 1])));
			}

			// Each unpadded conv shrinks each spatial dim by (kernel_size_ - 1).
			size_t flatten_size = latent_size_ * (image_size_ - num_layers_ * (kernel_size_ - 1)) * (image_size_ - num_layers_ * (kernel_size_ - 1));
			linear_1_ = register_module("linear_1", torch::nn::Linear(flatten_size, latent_size_));
			linear_2_ = register_module("linear_2", torch::nn::Linear(latent_size_, output_dim_));
		}

		/** Forward pass; see class comment for tensor shapes. */
		torch::Tensor forward(torch::Tensor x) {
			for(size_t i = 0; i < conv_layers_->size(); ++i) {
				auto batch_norm = (batch_norm_layers_[i].get())->as<torch::nn::BatchNorm2d>();
				auto conv = (conv_layers_[i].get())->as<torch::nn::Conv2d>();
				x = torch::tanh(batch_norm->forward(conv->forward(x)));
			}
			x = x.flatten(1);  // (batch, latent * H' * W')
			x = torch::tanh(linear_1_->forward(x));
			// BUG FIX: tanh was applied twice after the final Linear
			// (tanh(tanh(x))); a single activation is intended.
			x = torch::tanh(linear_2_->forward(x));
			return x;
		}
	};

	TORCH_MODULE(CoverageControlCNN);

} // namespace CoverageControlTorch

#endif // COVERAGECONTROL_CNN_BACKBONE_H_
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
/** Generator for edge weights and communication maps from robot positions. **/

#ifndef COVERAGECONTROLTORCH_COMMUNICATION_MAP_GENERATOR_H_
#define COVERAGECONTROLTORCH_COMMUNICATION_MAP_GENERATOR_H_

#include <cmath>
#include <vector>

#include <torch/torch.h>

namespace F = torch::nn::functional;
using namespace torch::indexing;

namespace CoverageControlTorch {

	/** Computes, for a set of robot positions:
	 *  (a) a pairwise edge-weight matrix exp(-d^2 / r^2), thresholded at
	 *      exp(-1) so weights vanish beyond the communication range r, and
	 *  (b) per-robot size_ x size_ communication maps marking neighbors.
	 *
	 * NOTE(review): `eval` hides torch::nn::Module::eval(); consider
	 * renaming it (e.g. Generate) in a follow-up to avoid confusion.
	 */
	struct EdgeWtsCommMapGenerator : torch::nn::Module {

		EdgeWtsCommMapGenerator (int size, double communication_range, double resolution) : size_(size), communication_range_(communication_range), resolution_(resolution) {
		}

		/** Returns {edge_weights, comm_map}.
		 * assumes robot_positions has a trailing (x, y) coordinate axis, yet
		 * num_robots is read from size(-1) — TODO(review): confirm layout. */
		auto eval(torch::Tensor robot_positions) {
			torch::Tensor edge_weights;
			auto num_robots = robot_positions.size(-1);
			auto pairwise_dist_matrices = torch::cdist(robot_positions, robot_positions, 2);
			// Gaussian weighting of pairwise distances; entries below exp(-1)
			// (distance beyond communication_range_) are zeroed in place.
			edge_weights = torch::exp(-(pairwise_dist_matrices.square())/(communication_range_*communication_range_));
			F::threshold(edge_weights, F::ThresholdFuncOptions(std::exp(-1.0), 0).inplace(true));

			// BUG FIX: was torch::empty — unmarked cells would hold
			// uninitialized memory instead of 0.
			torch::Tensor comm_map = torch::zeros({num_robots, size_, size_});
			auto relative_pos = robot_positions.unsqueeze(2) - robot_positions.unsqueeze(1);

			// One map cell spans comm_scale * resolution_ world units; the
			// translation recenters relative positions onto the map origin.
			double comm_scale = (communication_range_ * 2.) / size_;
			torch::Tensor map_translation = torch::empty({2});
			map_translation.index_put_({0}, size_ * comm_scale * resolution_/2.);
			map_translation.index_put_({1}, size_ * comm_scale * resolution_/2.);
			for(int i = 0; i < num_robots; ++i) {
				for(int j = 0; j < num_robots; ++j) {
					if(i == j) { continue; }
					// BUG FIX: removed unconditional .to(torch::kCUDA) — it
					// crashed CPU-only builds and mixed devices with the CPU
					// tensors map_translation and comm_map.
					auto neighbor_pos = relative_pos.index({Slice(), i, j, Slice()});
					auto map_pos = neighbor_pos + map_translation;
					auto indices = torch::round(map_pos / (resolution_ * comm_scale));
					// TODO(review): `indices` is a floating-point 2-vector;
					// index_put_ likely needs integer (row, col) indices —
					// confirm intent and fix the scatter.
					comm_map.index_put_({i, indices}, 1);
				}
			}
			std::vector <torch::Tensor> edge_wts_comm_map{edge_weights, comm_map};
			// Return edge weights and communication maps
			return edge_wts_comm_map;
		}

		int size_;                                 ///< map side length in cells
		double communication_range_, resolution_;  ///< world units
	};

} // namespace CoverageControlTorch

#endif // COVERAGECONTROLTORCH_COMMUNICATION_MAP_GENERATOR_H_
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
/** This file contains the declaration of the class TrainCNN using Torch C++ API.
 * The class TrainCNN takes local maps, communication maps, and obstacles maps as input, and
 * predicts the voronoi coverage features.
 *
 **/

#ifndef COVERAGECONTROL_TRAIN_CNN_H_
#define COVERAGECONTROL_TRAIN_CNN_H_


#include <iostream>
#include <string>
#include <vector>
#include <filesystem>
#include <yaml-cpp/yaml.h>
#include <torch/torch.h>

#include "cnn_backbone.h"

using namespace torch::indexing;
namespace F = torch::nn::functional;

namespace CoverageControlTorch {

	class TrainCNN {
		private:
			torch::Tensor maps_;      // (N, 4, image, image) stacked input maps
			torch::Tensor features_;  // (N, OutputDim) target coverage features
			torch::Device device_ = torch::kCPU;  // switched to CUDA if available
			YAML::Node config_;
			YAML::Node cnn_config_;   // the "CNN" subtree of config_
			std::string data_dir_;
			size_t batch_size_ = 64;
			size_t num_epochs_ = 10;
			float learning_rate_ = 0.001;
			float weight_decay_ = 0.0001;
			int image_size_ = 32;

			std::shared_ptr<torch::optim::Adam> optimizer_;
		public:

			/** Construct from a YAML config file path.
			 * Selects CUDA when available; throws via LoadConfigs if the
			 * config file does not exist. */
			TrainCNN(std::string const &config_file) {
				if (torch::cuda::is_available()) {
					device_ = torch::kCUDA;
					std::cout << "Using CUDA" << std::endl;
				}
				LoadConfigs(config_file);
			}

			/** Train the CNN on the loaded dataset, then report the final
			 * loss and per-feature error over the full dataset.
			 * All hyperparameters (layers, epochs, learning rate, batch
			 * size, ...) come from the config file loaded at construction. */
			void Train() {
				LoadDataset();

				CoverageControlCNN model(
						cnn_config_["InputDim"].as<int>(),
						cnn_config_["OutputDim"].as<int>(),
						cnn_config_["NumLayers"].as<int>(),
						cnn_config_["LatentSize"].as<int>(),
						cnn_config_["KernelSize"].as<int>(),
						image_size_);

				model->to(device_);

				optimizer_ = std::make_shared<torch::optim::Adam>(
						model->parameters(),
						torch::optim::AdamOptions(learning_rate_).weight_decay(weight_decay_));

				size_t dataset_size = maps_.size(0);
				for (size_t epoch = 1; epoch < num_epochs_ + 1; ++epoch) {
					for (size_t i = 0; i < dataset_size; i += batch_size_) {
						auto loss = TrainOneBatch(model, i);
						std::cout << "Epoch: " << epoch << ", Batch: " << i << ", Loss: " << loss << std::endl;
					}
				}

				// BUG FIX: run the final evaluation in inference mode —
				// switch BatchNorm to its running statistics and disable
				// gradient tracking (the old code evaluated in train mode
				// and built a graph over the whole dataset).
				model->eval();
				torch::NoGradGuard no_grad;
				maps_ = maps_.to(device_);
				auto pred = model->forward(maps_).to(torch::kCPU);
				features_ = features_.to(torch::kCPU);
				// Compute loss individually for each feature in features
				auto loss = torch::mse_loss(pred, features_);
				std::cout << "Final loss: " << loss.item<float>() << std::endl;
				auto loss_vec = torch::norm(pred - features_, 2, 0).to(torch::kCPU);
				std::cout << "Loss vector: " << loss_vec << std::endl;
				std::cout << "Max of feature 0 true: " << features_.index({Slice(), 0}).max() << std::endl;
				std::cout << "Max of feature 0 pred: " << pred.index({Slice(), 0}).max() << std::endl;
				std::cout << "Max of feature 1 true: " << features_.index({Slice(), 1}).max() << std::endl;
				std::cout << "Max of feature 1 pred: " << pred.index({Slice(), 1}).max() << std::endl;

			}

			/** Run one optimization step on the batch starting at batch_idx.
			 * A Slice end past the dataset is clamped by torch, so the last
			 * (possibly short) batch is handled correctly.
			 * @return the batch MSE loss. */
			float TrainOneBatch(CoverageControlCNN &model, size_t batch_idx) {
				torch::Tensor batch = maps_.index({Slice(batch_idx, batch_idx + batch_size_)});
				batch = batch.to(device_);
				auto x = model->forward(batch);

				// Backward and optimize
				optimizer_->zero_grad();
				torch::Tensor batch_features = features_.index({Slice(batch_idx, batch_idx + batch_size_)}).to(device_);
				auto loss = torch::mse_loss(x, batch_features);
				loss.backward();
				optimizer_->step();

				return loss.item<float>();

			}


			/** Load the dataset tensors from data_dir_: local, communication
			 * and obstacle maps plus normalized coverage features. Each map
			 * is reshaped to (N, C, image, image) and they are concatenated
			 * along the channel dimension into maps_ (C = 1 + 2 + 1 = 4). */
			void LoadDataset() {
				std::string local_maps_file = data_dir_ + "/local_maps.pt";
				std::string comm_maps_file = data_dir_ + "/comm_maps.pt";
				std::string obstacle_maps_file = data_dir_ + "/obstacle_maps.pt";
				std::string features_file = data_dir_ + "/normalized_coverage_features.pt";

				torch::Tensor local_maps;
				torch::load(local_maps, local_maps_file);
				local_maps = local_maps.unsqueeze(2).view({-1, 1, image_size_, image_size_});
				torch::Tensor comm_maps;
				torch::load(comm_maps, comm_maps_file);
				// comm maps are stored sparse; densify before reshaping
				comm_maps = comm_maps.to_dense().view({-1, 2, image_size_, image_size_});
				torch::Tensor obstacle_maps;
				torch::load(obstacle_maps, obstacle_maps_file);
				obstacle_maps = obstacle_maps.to_dense().unsqueeze(2).view({-1, 1, image_size_, image_size_});

				torch::load(features_, features_file);
				features_ = features_.view({-1, features_.size(2)});
				// Keep only the first OutputDim features as training targets.
				int output_dim = config_["CNN"]["OutputDim"].as<int>();
				features_ = features_.index({Slice(), Slice(0, output_dim)});

				maps_ = torch::cat({local_maps, comm_maps, obstacle_maps}, 1);
				std::cout << "maps shape: " << maps_.sizes() << std::endl;

			}

			/** Parse the YAML config file and cache hyperparameters.
			 * @throws std::runtime_error if the file does not exist. */
			void LoadConfigs(std::string const &config_file) {
				std::cout << "Using config file: " << config_file << std::endl;
				// Check if config_file exists
				if(not std::filesystem::exists(config_file)) {
					throw std::runtime_error("Could not open config file: " + config_file);
				}
				config_ = YAML::LoadFile(config_file);
				data_dir_ = config_["pDataDir"].as<std::string>();
				batch_size_ = config_["BatchSize"].as<size_t>();
				num_epochs_ = config_["NumEpochs"].as<size_t>();
				learning_rate_ = config_["LearningRate"].as<float>();
				weight_decay_ = config_["WeightDecay"].as<float>();

				cnn_config_ = config_["CNN"];
				image_size_ = cnn_config_["ImageSize"].as<int>();
			}

	};

} // namespace CoverageControlTorch

#endif //COVERAGECONTROL_TRAIN_CNN_H_

params/learning_params.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
pDataDir: "/root/CoverageControl_ws/data/pure_coverage/" # Absolute location

GPUs: [4, 5]

LearningRate: 0.001
WeightDecay: 0.0001
BatchSize: 10
NumEpochs: 50

CNN:
  InputDim: 4    # channels: 1 local map + 2 comm maps + 1 obstacle map
  OutputDim: 7   # number of coverage features predicted (and kept as targets)
  NumLayers: 2   # conv + batch-norm stages
  LatentSize: 8  # conv channel width and FC hidden size
  KernelSize: 3  # square, unpadded convolutions
  ImageSize: 32  # side length of the square input maps

0 commit comments

Comments
 (0)