Skip to content

Commit 7babd79

Browse files
Saurav AgarwalSaurav Agarwal
authored and committed
Update data gen and cnn train
1 parent 8e1c53e commit 7babd79

4 files changed

Lines changed: 115 additions & 41 deletions

File tree

cppsrc/main/data_generation.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
#include <filesystem>
1010

1111
#include <CoverageControlTorch/coverage_system.h>
12+
#include <CoverageControl/algorithms/lloyd_local_voronoi.h>
13+
#include <CoverageControl/algorithms/lloyd_global_online.h>
14+
#include <CoverageControl/algorithms/oracle_global_offline.h>
1215
#include <CoverageControlTorch/generate_dataset.h>
1316

1417
using CoverageControl::Point2;
@@ -33,7 +36,7 @@ int main(int argc, char** argv) {
3336

3437
for(int i = 0; i < num_datasets; i++) {
3538
std::cout << "Generating dataset " << i << std::endl;
36-
CCT::GenerateDataset dataset_generator(argv[1], std::to_string(i));
39+
CCT::GenerateDataset<CC::LloydGlobalOnline> dataset_generator(argv[1], std::to_string(i));
3740
}
3841

3942
return 0;

cppsrc/torch/include/CoverageControlTorch/cnn_backbone.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using namespace torch::indexing;
77

88
namespace CoverageControlTorch {
99

10-
struct CoverageControlCNNImpl : torch::nn::Module {
10+
struct CNNBackboneImpl : torch::nn::Module {
1111
int input_dim_ = 4;
1212
int output_dim_ = 7;
1313
int num_layers_ = 2;
@@ -21,7 +21,7 @@ namespace CoverageControlTorch {
2121
torch::nn::Linear linear_2_;
2222
torch::nn::Linear linear_3_;
2323

24-
CoverageControlCNNImpl(int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
24+
CNNBackboneImpl(int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
2525
input_dim_(input_dim),
2626
output_dim_(output_dim),
2727
num_layers_(num_layers),
@@ -57,21 +57,21 @@ namespace CoverageControlTorch {
5757
for(size_t i = 0; i < conv_layers_->size(); ++i) {
5858
auto batch_norm = (batch_norm_layers_[i].get())->as<torch::nn::BatchNorm2d>();
5959
auto conv = (conv_layers_[i].get())->as<torch::nn::Conv2d>();
60-
x = torch::tanh(batch_norm->forward(conv->forward(x)));
60+
x = torch::leaky_relu(batch_norm->forward(conv->forward(x)));
6161
/* std::cout << "x size: " << x.sizes() << std::endl; */
6262
}
6363
x = x.flatten(1);
6464
/* std::cout << "x size: " << x.sizes() << std::endl; */
65-
x = torch::tanh(linear_1_->forward(x));
65+
x = torch::leaky_relu(linear_1_->forward(x));
6666
/* std::cout << "x size: " << x.sizes() << std::endl; */
67-
x = torch::tanh(linear_2_->forward(x));
67+
x = torch::leaky_relu(linear_2_->forward(x));
6868
/* std::cout << "x size: " << x.sizes() << std::endl; */
69-
x = linear_3_->forward(x);
69+
/* x = linear_3_->forward(x); */
7070
return x;
7171
}
7272
};
7373

74-
TORCH_MODULE(CoverageControlCNN);
74+
TORCH_MODULE(CNNBackbone);
7575

7676

7777
} // namespace CoverageControlTorch

cppsrc/torch/include/CoverageControlTorch/generate_dataset.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,18 @@
1616

1717
#include <torch/script.h>
1818
#include <torch/torch.h>
19-
#include <CoverageControl/algorithms/lloyd_global_online.h>
20-
#include <CoverageControl/algorithms/oracle_global_offline.h>
21-
2219

2320
using namespace torch::indexing;
2421
typedef long int T_idx_t;
2522
namespace F = torch::nn::functional;
26-
typedef CoverageControl::LloydGlobalOnline CoverageAlgorithm;
23+
/* typedef CoverageControl::LloydGlobalOnline CoverageAlgorithm; */
2724
/* typedef CoverageControl::OracleGlobalOffline CoverageAlgorithm; */
2825

2926
#include "coverage_system.h"
3027

3128
namespace CoverageControlTorch {
3229

30+
template <class CoverageAlgorithm>
3331
class GenerateDataset {
3432

3533
private:

cppsrc/torch/include/CoverageControlTorch/train_cnn.h

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <string>
1313
#include <vector>
1414
#include <filesystem>
15+
#include <thread>
16+
#include <limits>
1517
#include <yaml-cpp/yaml.h>
1618
#include <torch/torch.h>
1719

@@ -22,10 +24,42 @@ namespace F = torch::nn::functional;
2224

2325
namespace CoverageControlTorch {
2426

27+
struct CoverageControlCNNImpl : torch::nn::Module {
28+
int input_dim_ = 4;
29+
int output_dim_ = 7;
30+
int num_layers_ = 2;
31+
int latent_size_ = 8;
32+
int kernel_size_ = 3;
33+
int image_size_ = 32;
34+
35+
CNNBackbone cnn_backbone_;
36+
torch::nn::Linear linear_;
37+
38+
CoverageControlCNNImpl(int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
39+
input_dim_(input_dim),
40+
output_dim_(output_dim),
41+
num_layers_(num_layers),
42+
latent_size_(latent_size),
43+
kernel_size_(kernel_size),
44+
image_size_(image_size),
45+
cnn_backbone_(register_module("cnn_backbone", CNNBackbone(input_dim_, output_dim_, num_layers_, latent_size_, kernel_size_, image_size_))),
46+
linear_(register_module("linear", torch::nn::Linear(2 * output_dim_, output_dim))) {
47+
}
48+
49+
torch::Tensor forward(torch::Tensor x) {
50+
x = cnn_backbone_->forward(x);
51+
x = linear_->forward(x);
52+
return x;
53+
}
54+
};
55+
56+
TORCH_MODULE(CoverageControlCNN);
57+
2558
class TrainCNN {
2659
private:
27-
torch::Tensor maps_;
28-
torch::Tensor features_;
60+
torch::Tensor train_maps_, val_maps_;
61+
torch::Tensor train_features_, val_features_;
62+
torch::Tensor features_mean_, features_std_;
2963
torch::Device device_ = torch::kCPU;
3064
YAML::Node config_;
3165
YAML::Node cnn_config_;
@@ -36,7 +70,9 @@ namespace CoverageControlTorch {
3670
float weight_decay_ = 0.0001;
3771
int image_size_ = 32;
3872

39-
std::shared_ptr<torch::optim::Adam> optimizer_;
73+
/* std::shared_ptr<torch::optim::Adam> optimizer_; */
74+
std::shared_ptr<torch::optim::SGD> optimizer_;
75+
4076
public:
4177

4278
TrainCNN(std::string const &config_file) {
@@ -55,6 +91,7 @@ namespace CoverageControlTorch {
5591
* @param batch_size: the batch size.
5692
**/
5793
void Train() {
94+
5895
LoadDataset();
5996

6097
CoverageControlCNN model(
@@ -68,40 +105,63 @@ namespace CoverageControlTorch {
68105

69106
model->to(device_);
70107

71-
optimizer_ = std::make_shared<torch::optim::Adam>(
108+
optimizer_ = std::make_shared<torch::optim::SGD>(
72109
model->parameters(),
73-
torch::optim::AdamOptions(learning_rate_).weight_decay(weight_decay_));
74-
75-
size_t dataset_size = maps_.size(0);
110+
torch::optim::SGDOptions(learning_rate_).weight_decay(weight_decay_).momentum(0.1));
111+
112+
/* optimizer_ = std::make_shared<torch::optim::Adam>( */
113+
/* model->parameters(), */
114+
/* torch::optim::AdamOptions(learning_rate_).weight_decay(weight_decay_)); */
115+
/* const auto num_workers = std::thread::hardware_concurrency()/2; */
116+
/* auto data_loader = torch::data::make_data_loader( */
117+
/* std::move(val_maps_), */
118+
/* torch::data::DataLoaderOptions().batch_size(batch_size_).workers(num_workers).shuffle(true)); */
119+
/* auto data_loader = torch::data::make_data_loader( */
120+
/* torch::data::datasets::TensorDataset(val_maps_, val_features_), */
121+
/* torch::data::DataLoaderOptions().batch_size(batch_size_).workers(num_workers)); */
122+
123+
// Best model parameters
124+
float best_val_loss = std::numeric_limits<float>::max();
125+
std::vector<float> best_params;
126+
127+
size_t dataset_size = train_maps_.size(0);
76128
for (size_t epoch = 1; epoch < num_epochs_ + 1; ++epoch) {
129+
optimizer_->zero_grad();
77130
for (size_t i = 0; i < dataset_size; i += batch_size_) {
78131
auto loss = TrainOneBatch(model, i);
79132
std::cout << "Epoch: " << epoch << ", Batch: " << i << ", Loss: " << loss << std::endl;
80133
}
134+
// Validate
135+
val_maps_ = val_maps_.to(device_);
136+
auto pred = model->forward(val_maps_).to(torch::kCPU);
137+
val_features_ = val_features_.to(torch::kCPU);
138+
// Compute loss individually for each feature in features
139+
/* auto loss_vec = torch::norm(pred - val_features_, 2, 0).to(torch::kCPU); */
140+
auto loss_vec = torch::mse_loss(pred, val_features_, torch::Reduction::None).mean({0}).to(torch::kCPU);
141+
std::cout << "Loss vector: " << loss_vec << std::endl;
142+
auto actual_pred = pred * features_std_ + features_mean_;
143+
auto actual_features = val_features_ * features_std_ + features_mean_;
144+
auto accuracy = 100 * (torch::abs(actual_pred - actual_features).mean({0}))/(actual_features.mean({0}));
145+
std::cout << "Accuracy: " << accuracy << std::endl;
146+
auto loss = torch::mse_loss(pred, val_features_);
147+
std::cout << "Val loss: " << loss << std::endl;
148+
std::cout << "Best Val loss: " << best_val_loss << std::endl;
149+
if (loss.item<float>() < best_val_loss) {
150+
best_val_loss = loss.item<float>();
151+
torch::save(model, data_dir_ + "/model.pt");
152+
torch::save(*optimizer_, data_dir_ + "/optimizer.pt");
153+
}
81154
}
82-
maps_ = maps_.to(device_);
83-
auto pred = model->forward(maps_).to(torch::kCPU);
84-
features_ = features_.to(torch::kCPU);
85-
// Compute loss individually for each feature in features
86-
auto loss = torch::mse_loss(pred, features_);
87-
std::cout << "Final loss: " << loss.item<float>() << std::endl;
88-
auto loss_vec = torch::norm(pred - features_, 2, 0).to(torch::kCPU);
89-
std::cout << "Loss vector: " << loss_vec << std::endl;
90-
std::cout << "Max of feature 0 true: " << features_.index({Slice(), 0}).max() << std::endl;
91-
std::cout << "Max of feature 0 pred: " << pred.index({Slice(), 0}).max() << std::endl;
92-
std::cout << "Max of feature 1 true: " << features_.index({Slice(), 1}).max() << std::endl;
93-
std::cout << "Max of feature 1 pred: " << pred.index({Slice(), 1}).max() << std::endl;
94-
155+
std::cout << "Best validation loss: " << best_val_loss << std::endl;
95156
}
96157

97158
float TrainOneBatch(CoverageControlCNN &model, size_t batch_idx) {
98-
torch::Tensor batch = maps_.index({Slice(batch_idx, batch_idx + batch_size_)});
159+
torch::Tensor batch = train_maps_.index({Slice(batch_idx, batch_idx + batch_size_)});
99160
batch = batch.to(device_);
100161
auto x = model->forward(batch);
101162

102163
// Backward and optimize
103-
optimizer_->zero_grad();
104-
torch::Tensor batch_features = features_.index({Slice(batch_idx, batch_idx + batch_size_)}).to(device_);
164+
torch::Tensor batch_features = train_features_.index({Slice(batch_idx, batch_idx + batch_size_)}).to(device_);
105165
auto loss = torch::mse_loss(x, batch_features);
106166
loss.backward();
107167
optimizer_->step();
@@ -130,13 +190,26 @@ namespace CoverageControlTorch {
130190
torch::load(obstacle_maps, obstacle_maps_file);
131191
obstacle_maps = obstacle_maps.to_dense().unsqueeze(2).view({-1, 1, image_size_, image_size_});
132192

133-
torch::load(features_, features_file);
134-
features_ = features_.view({-1, features_.size(2)});
193+
torch::Tensor features;
194+
torch::load(features, features_file);
195+
features = features.view({-1, features.size(2)});
135196
int output_dim = config_["CNN"]["OutputDim"].as<int>();
136-
features_ = features_.index({Slice(), Slice(0, output_dim)});
137-
138-
maps_ = torch::cat({local_maps, comm_maps, obstacle_maps}, 1);
139-
std::cout << "maps shape: " << maps_.sizes() << std::endl;
197+
features = features.index({Slice(), Slice(0, output_dim)});
198+
199+
torch::Tensor maps = torch::cat({local_maps, comm_maps, obstacle_maps}, 1);
200+
201+
// Split into train and val
202+
size_t num_train = 0.998 * maps.size(0);
203+
size_t num_val = maps.size(0) - num_train;
204+
train_maps_ = maps.index({Slice(0, num_train), Slice()});
205+
train_features_ = features.index({Slice(0, num_train), Slice()});
206+
val_maps_ = maps.index({Slice(num_train, maps.size(0))});
207+
val_features_ = features.index({Slice(num_train, maps.size(0))});
208+
209+
torch::load(features_mean_, data_dir_ + "/coverage_features_mean.pt");
210+
torch::load(features_std_, data_dir_ + "/coverage_features_std.pt");
211+
212+
std::cout << "maps shape: " << maps.sizes() << std::endl;
140213

141214
}
142215

0 commit comments

Comments
 (0)