1212#include < string>
1313#include < vector>
1414#include < filesystem>
15+ #include < thread>
16+ #include < limits>
1517#include < yaml-cpp/yaml.h>
1618#include < torch/torch.h>
1719
@@ -22,10 +24,42 @@ namespace F = torch::nn::functional;
2224
2325namespace CoverageControlTorch {
2426
// CoverageControlCNNImpl: a torch module that chains a CNNBackbone with a
// single fully connected (Linear) layer. forward() feeds its input through
// the backbone and then through the linear layer.
27+ struct CoverageControlCNNImpl : torch::nn::Module {
// Hyper-parameters. The in-class values are only defaults: the constructor
// below overwrites all six of them from its arguments.
28+ int input_dim_ = 4 ;
29+ int output_dim_ = 7 ;
30+ int num_layers_ = 2 ;
31+ int latent_size_ = 8 ;
32+ int kernel_size_ = 3 ;
33+ int image_size_ = 32 ;
34+
// Sub-modules; both are registered with this module in the constructor.
35+ CNNBackbone cnn_backbone_;
36+ torch::nn::Linear linear_;
37+
// Constructor: stores the six hyper-parameters, then builds and registers
// the two sub-modules.
// Members are initialized in declaration order, and the dimension members
// are declared before the sub-modules, so the `*_` values used as
// arguments below are already set when the sub-modules are constructed.
38+ CoverageControlCNNImpl (int input_dim, int output_dim, int num_layers, int latent_size, int kernel_size, int image_size) :
39+ input_dim_ (input_dim),
40+ output_dim_ (output_dim),
41+ num_layers_ (num_layers),
42+ latent_size_ (latent_size),
43+ kernel_size_ (kernel_size),
44+ image_size_ (image_size),
// The backbone receives all six hyper-parameters unchanged.
45+ cnn_backbone_ (register_module(" cnn_backbone" , CNNBackbone(input_dim_, output_dim_, num_layers_, latent_size_, kernel_size_, image_size_))),
// Linear layer: 2 * output_dim_ inputs -> output_dim outputs.
// NOTE(review): this presumes CNNBackbone::forward yields 2 * output_dim_
// features per sample; CNNBackbone is not visible here - confirm.
46+ linear_ (register_module(" linear" , torch::nn::Linear(2 * output_dim_, output_dim))) {
47+ }
48+
// forward: applies cnn_backbone_ to x, then linear_ to that result, and
// returns the linear layer's output.
49+ torch::Tensor forward (torch::Tensor x) {
50+ x = cnn_backbone_->forward (x);
51+ x = linear_->forward (x);
52+ return x;
53+ }
54+ };
55+
// Declares the CoverageControlCNN module-holder type for
// CoverageControlCNNImpl; Train() uses it with `->` (model->to, model->forward).
56+ TORCH_MODULE (CoverageControlCNN);
57+
2558 class TrainCNN {
2659 private:
27- torch::Tensor maps_;
28- torch::Tensor features_;
60+ torch::Tensor train_maps_, val_maps_;
61+ torch::Tensor train_features_, val_features_;
62+ torch::Tensor features_mean_, features_std_;
2963 torch::Device device_ = torch::kCPU ;
3064 YAML::Node config_;
3165 YAML::Node cnn_config_;
@@ -36,7 +70,9 @@ namespace CoverageControlTorch {
3670 float weight_decay_ = 0.0001 ;
3771 int image_size_ = 32 ;
3872
39- std::shared_ptr<torch::optim::Adam> optimizer_;
73+ /* std::shared_ptr<torch::optim::Adam> optimizer_; */
74+ std::shared_ptr<torch::optim::SGD> optimizer_;
75+
4076 public:
4177
4278 TrainCNN (std::string const &config_file) {
@@ -55,6 +91,7 @@ namespace CoverageControlTorch {
5591 * @param batch_size: the batch size.
5692 **/
// Train(): end-to-end training driver. In the lines visible here it calls
// LoadDataset(), constructs the model and moves it to device_, builds an SGD
// optimizer, then for every epoch trains batch by batch via TrainOneBatch()
// and evaluates on the validation tensors, saving the model and optimizer
// under data_dir_ whenever the validation loss improves.
// NOTE(review): this span is a rendered diff. Lines whose number is followed
// by `-` are the removed pre-change version, `+` lines are additions, and the
// `@@` line stands for source lines that are not shown (presumably the model
// constructor arguments).
5793 void Train () {
94+
5895 LoadDataset ();
5996
6097 CoverageControlCNN model (
@@ -68,40 +105,63 @@ namespace CoverageControlTorch {
68105
69106 model->to (device_);
70107
// Optimizer: SGD over all model parameters, using learning_rate_,
// weight_decay_ and a fixed momentum of 0.1. It replaces the Adam set-up on
// the `-` lines.
71- optimizer_ = std::make_shared<torch::optim::Adam >(
108+ optimizer_ = std::make_shared<torch::optim::SGD >(
72109 model->parameters (),
73- torch::optim::AdamOptions (learning_rate_).weight_decay (weight_decay_));
74-
75- size_t dataset_size = maps_.size (0 );
110+ torch::optim::SGDOptions (learning_rate_).weight_decay (weight_decay_).momentum (0.1 ));
111+
// Disabled experiments: the previous Adam optimizer and two data-loader
// variants are kept below as commented-out code.
112+ /* optimizer_ = std::make_shared<torch::optim::Adam>( */
113+ /* model->parameters(), */
114+ /* torch::optim::AdamOptions(learning_rate_).weight_decay(weight_decay_)); */
115+ /* const auto num_workers = std::thread::hardware_concurrency()/2; */
116+ /* auto data_loader = torch::data::make_data_loader( */
117+ /* std::move(val_maps_), */
118+ /* torch::data::DataLoaderOptions().batch_size(batch_size_).workers(num_workers).shuffle(true)); */
119+ /* auto data_loader = torch::data::make_data_loader( */
120+ /* torch::data::datasets::TensorDataset(val_maps_, val_features_), */
121+ /* torch::data::DataLoaderOptions().batch_size(batch_size_).workers(num_workers)); */
122+
// best_val_loss starts at the largest float so that the first smaller
// validation loss is recorded as an improvement.
// NOTE(review): best_params is not referenced again in the visible lines.
123+ // Best model parameters
124+ float best_val_loss = std::numeric_limits<float >::max ();
125+ std::vector<float > best_params;
126+
// dataset_size is the number of training samples (dimension 0 of train_maps_).
127+ size_t dataset_size = train_maps_.size (0 );
76128 for (size_t epoch = 1 ; epoch < num_epochs_ + 1 ; ++epoch) {
// NOTE(review): gradients are cleared only here, once per epoch. The
// per-batch optimizer_->zero_grad() in TrainOneBatch is on a removed (`-`)
// line, so gradients appear to accumulate across every batch of an epoch
// while step() is still called per batch - confirm this is intended.
129+ optimizer_->zero_grad ();
// i is the first sample index of the batch; the "Batch:" value printed below
// is that offset, not a batch counter.
77130 for (size_t i = 0 ; i < dataset_size; i += batch_size_) {
78131 auto loss = TrainOneBatch (model, i);
79132 std::cout << " Epoch: " << epoch << " , Batch: " << i << " , Loss: " << loss << std::endl;
80133 }
// Validation pass, run after every epoch on the whole validation set in one
// forward call. val_maps_ is reassigned to its device_ copy; predictions and
// val_features_ are brought to the CPU for the metrics below.
// NOTE(review): no torch::NoGradGuard or model->eval() is visible around this
// forward pass, so autograd state is presumably recorded here too - confirm.
134+ // Validate
135+ val_maps_ = val_maps_.to (device_);
136+ auto pred = model->forward (val_maps_).to (torch::kCPU );
137+ val_features_ = val_features_.to (torch::kCPU );
138+ // Compute loss individually for each feature in features
139+ /* auto loss_vec = torch::norm(pred - val_features_, 2, 0).to(torch::kCPU); */
// Element-wise squared error (no reduction) averaged over dimension 0, i.e.
// one MSE value per feature column.
140+ auto loss_vec = torch::mse_loss (pred, val_features_, torch::Reduction::None).mean ({0 }).to (torch::kCPU );
141+ std::cout << " Loss vector: " << loss_vec << std::endl;
// Rescale predictions and targets with features_std_ / features_mean_
// (presumably undoing the dataset normalization - verify against the data
// preparation code).
142+ auto actual_pred = pred * features_std_ + features_mean_;
143+ auto actual_features = val_features_ * features_std_ + features_mean_;
// Despite its name, `accuracy` is a relative error in percent:
// 100 * mean(|pred - target|) / mean(target), per feature column, so lower is
// better.
// NOTE(review): the divisor is not guarded against a zero column mean.
144+ auto accuracy = 100 * (torch::abs (actual_pred - actual_features).mean ({0 }))/(actual_features.mean ({0 }));
145+ std::cout << " Accuracy: " << accuracy << std::endl;
// Scalar MSE over all validation elements; this is the model-selection metric.
146+ auto loss = torch::mse_loss (pred, val_features_);
147+ std::cout << " Val loss: " << loss << std::endl;
148+ std::cout << " Best Val loss: " << best_val_loss << std::endl;
// Checkpoint: on a strict improvement, record the new best loss and save both
// the model and the optimizer under data_dir_.
149+ if (loss.item <float >() < best_val_loss) {
150+ best_val_loss = loss.item <float >();
151+ torch::save (model, data_dir_ + " /model.pt" );
152+ torch::save (*optimizer_, data_dir_ + " /optimizer.pt" );
153+ }
81154 }
// The `-` lines below are the removed one-shot evaluation that used to run
// once after training on the full maps_/features_ tensors.
82- maps_ = maps_.to (device_);
83- auto pred = model->forward (maps_).to (torch::kCPU );
84- features_ = features_.to (torch::kCPU );
85- // Compute loss individually for each feature in features
86- auto loss = torch::mse_loss (pred, features_);
87- std::cout << " Final loss: " << loss.item <float >() << std::endl;
88- auto loss_vec = torch::norm (pred - features_, 2 , 0 ).to (torch::kCPU );
89- std::cout << " Loss vector: " << loss_vec << std::endl;
90- std::cout << " Max of feature 0 true: " << features_.index ({Slice (), 0 }).max () << std::endl;
91- std::cout << " Max of feature 0 pred: " << pred.index ({Slice (), 0 }).max () << std::endl;
92- std::cout << " Max of feature 1 true: " << features_.index ({Slice (), 1 }).max () << std::endl;
93- std::cout << " Max of feature 1 pred: " << pred.index ({Slice (), 1 }).max () << std::endl;
94-
// Report the best validation loss seen across all epochs.
155+ std::cout << " Best validation loss: " << best_val_loss << std::endl;
95156 }
96157
97158 float TrainOneBatch (CoverageControlCNN &model, size_t batch_idx) {
98- torch::Tensor batch = maps_ .index ({Slice (batch_idx, batch_idx + batch_size_)});
159+ torch::Tensor batch = train_maps_ .index ({Slice (batch_idx, batch_idx + batch_size_)});
99160 batch = batch.to (device_);
100161 auto x = model->forward (batch);
101162
102163 // Backward and optimize
103- optimizer_->zero_grad ();
104- torch::Tensor batch_features = features_.index ({Slice (batch_idx, batch_idx + batch_size_)}).to (device_);
164+ torch::Tensor batch_features = train_features_.index ({Slice (batch_idx, batch_idx + batch_size_)}).to (device_);
105165 auto loss = torch::mse_loss (x, batch_features);
106166 loss.backward ();
107167 optimizer_->step ();
@@ -130,13 +190,26 @@ namespace CoverageControlTorch {
130190 torch::load (obstacle_maps, obstacle_maps_file);
131191 obstacle_maps = obstacle_maps.to_dense ().unsqueeze (2 ).view ({-1 , 1 , image_size_, image_size_});
132192
133- torch::load (features_, features_file);
134- features_ = features_.view ({-1 , features_.size (2 )});
193+ torch::Tensor features;
194+ torch::load (features, features_file);
195+ features = features.view ({-1 , features.size (2 )});
135196 int output_dim = config_[" CNN" ][" OutputDim" ].as <int >();
136- features_ = features_.index ({Slice (), Slice (0 , output_dim)});
137-
138- maps_ = torch::cat ({local_maps, comm_maps, obstacle_maps}, 1 );
139- std::cout << " maps shape: " << maps_.sizes () << std::endl;
197+ features = features.index ({Slice (), Slice (0 , output_dim)});
198+
199+ torch::Tensor maps = torch::cat ({local_maps, comm_maps, obstacle_maps}, 1 );
200+
201+ // Split into train and val
202+ size_t num_train = 0.998 * maps.size (0 );
203+ size_t num_val = maps.size (0 ) - num_train;
204+ train_maps_ = maps.index ({Slice (0 , num_train), Slice ()});
205+ train_features_ = features.index ({Slice (0 , num_train), Slice ()});
206+ val_maps_ = maps.index ({Slice (num_train, maps.size (0 ))});
207+ val_features_ = features.index ({Slice (num_train, maps.size (0 ))});
208+
209+ torch::load (features_mean_, data_dir_ + " /coverage_features_mean.pt" );
210+ torch::load (features_std_, data_dir_ + " /coverage_features_std.pt" );
211+
212+ std::cout << " maps shape: " << maps.sizes () << std::endl;
140213
141214 }
142215
0 commit comments