|
/** This file contains the declaration of the class TrainCNN using the Torch C++ API.
 * The class TrainCNN takes local maps, communication maps, and obstacle maps as input, and
 * predicts the Voronoi coverage features.
 *
 **/
| 6 | + |
| 7 | +#ifndef COVERAGECONTROL_TRAIN_CNN_H_ |
| 8 | +#define COVERAGECONTROL_TRAIN_CNN_H_ |
| 9 | + |
| 10 | + |
| 11 | +#include <iostream> |
| 12 | +#include <string> |
| 13 | +#include <vector> |
| 14 | +#include <filesystem> |
| 15 | +#include <yaml-cpp/yaml.h> |
| 16 | +#include <torch/torch.h> |
| 17 | + |
| 18 | +#include "cnn_backbone.h" |
| 19 | + |
| 20 | +using namespace torch::indexing; |
| 21 | +namespace F = torch::nn::functional; |
| 22 | + |
| 23 | +namespace CoverageControlTorch { |
| 24 | + |
| 25 | + class TrainCNN { |
| 26 | + private: |
| 27 | + torch::Tensor maps_; |
| 28 | + torch::Tensor features_; |
| 29 | + torch::Device device_ = torch::kCPU; |
| 30 | + YAML::Node config_; |
| 31 | + YAML::Node cnn_config_; |
| 32 | + std::string data_dir_; |
| 33 | + size_t batch_size_ = 64; |
| 34 | + size_t num_epochs_ = 10; |
| 35 | + float learning_rate_ = 0.001; |
| 36 | + float weight_decay_ = 0.0001; |
| 37 | + int image_size_ = 32; |
| 38 | + |
| 39 | + std::shared_ptr<torch::optim::Adam> optimizer_; |
| 40 | + public: |
| 41 | + |
| 42 | + TrainCNN(std::string const &config_file) { |
| 43 | + if (torch::cuda::is_available()) { |
| 44 | + device_ = torch::kCUDA; |
| 45 | + std::cout << "Using CUDA" << std::endl; |
| 46 | + } |
| 47 | + LoadConfigs(config_file); |
| 48 | + } |
| 49 | + |
| 50 | + /** Train CNN model. |
| 51 | + * @param dataset_dir: the directory of the dataset. |
| 52 | + * @param num_layers: the number of convolutional layers. |
| 53 | + * @param num_epochs: the number of epochs. |
| 54 | + * @param learning_rate: the learning rate. |
| 55 | + * @param batch_size: the batch size. |
| 56 | + **/ |
| 57 | + void Train() { |
| 58 | + LoadDataset(); |
| 59 | + |
| 60 | + CoverageControlCNN model( |
| 61 | + cnn_config_["InputDim"].as<int>(), |
| 62 | + cnn_config_["OutputDim"].as<int>(), |
| 63 | + cnn_config_["NumLayers"].as<int>(), |
| 64 | + cnn_config_["LatentSize"].as<int>(), |
| 65 | + cnn_config_["KernelSize"].as<int>(), |
| 66 | + image_size_); |
| 67 | + |
| 68 | + |
| 69 | + model->to(device_); |
| 70 | + |
| 71 | + optimizer_ = std::make_shared<torch::optim::Adam>( |
| 72 | + model->parameters(), |
| 73 | + torch::optim::AdamOptions(learning_rate_).weight_decay(weight_decay_)); |
| 74 | + |
| 75 | + size_t dataset_size = maps_.size(0); |
| 76 | + for (size_t epoch = 1; epoch < num_epochs_ + 1; ++epoch) { |
| 77 | + for (size_t i = 0; i < dataset_size; i += batch_size_) { |
| 78 | + auto loss = TrainOneBatch(model, i); |
| 79 | + std::cout << "Epoch: " << epoch << ", Batch: " << i << ", Loss: " << loss << std::endl; |
| 80 | + } |
| 81 | + } |
| 82 | + maps_ = maps_.to(device_); |
| 83 | + auto pred = model->forward(maps_).to(torch::kCPU); |
| 84 | + features_ = features_.to(torch::kCPU); |
| 85 | + // Compute loss individually for each feature in features |
| 86 | + auto loss = torch::mse_loss(pred, features_); |
| 87 | + std::cout << "Final loss: " << loss.item<float>() << std::endl; |
| 88 | + auto loss_vec = torch::norm(pred - features_, 2, 0).to(torch::kCPU); |
| 89 | + std::cout << "Loss vector: " << loss_vec << std::endl; |
| 90 | + std::cout << "Max of feature 0 true: " << features_.index({Slice(), 0}).max() << std::endl; |
| 91 | + std::cout << "Max of feature 0 pred: " << pred.index({Slice(), 0}).max() << std::endl; |
| 92 | + std::cout << "Max of feature 1 true: " << features_.index({Slice(), 1}).max() << std::endl; |
| 93 | + std::cout << "Max of feature 1 pred: " << pred.index({Slice(), 1}).max() << std::endl; |
| 94 | + |
| 95 | + } |
| 96 | + |
| 97 | + float TrainOneBatch(CoverageControlCNN &model, size_t batch_idx) { |
| 98 | + torch::Tensor batch = maps_.index({Slice(batch_idx, batch_idx + batch_size_)}); |
| 99 | + batch = batch.to(device_); |
| 100 | + auto x = model->forward(batch); |
| 101 | + |
| 102 | + // Backward and optimize |
| 103 | + optimizer_->zero_grad(); |
| 104 | + torch::Tensor batch_features = features_.index({Slice(batch_idx, batch_idx + batch_size_)}).to(device_); |
| 105 | + auto loss = torch::mse_loss(x, batch_features); |
| 106 | + loss.backward(); |
| 107 | + optimizer_->step(); |
| 108 | + |
| 109 | + return loss.item<float>(); |
| 110 | + |
| 111 | + } |
| 112 | + |
| 113 | + |
| 114 | + /** Function to load the dataset from the dataset directory. |
| 115 | + * @param dataset_dir: the directory of the dataset. |
| 116 | + **/ |
| 117 | + void LoadDataset() { |
| 118 | + std::string local_maps_file = data_dir_ + "/local_maps.pt"; |
| 119 | + std::string comm_maps_file = data_dir_ + "/comm_maps.pt"; |
| 120 | + std::string obstacle_maps_file = data_dir_ + "/obstacle_maps.pt"; |
| 121 | + std::string features_file = data_dir_ + "/normalized_coverage_features.pt"; |
| 122 | + |
| 123 | + torch::Tensor local_maps; |
| 124 | + torch::load(local_maps, local_maps_file); |
| 125 | + local_maps = local_maps.unsqueeze(2).view({-1, 1, image_size_, image_size_}); |
| 126 | + torch::Tensor comm_maps; |
| 127 | + torch::load(comm_maps, comm_maps_file); |
| 128 | + comm_maps = comm_maps.to_dense().view({-1, 2, image_size_, image_size_}); |
| 129 | + torch::Tensor obstacle_maps; |
| 130 | + torch::load(obstacle_maps, obstacle_maps_file); |
| 131 | + obstacle_maps = obstacle_maps.to_dense().unsqueeze(2).view({-1, 1, image_size_, image_size_}); |
| 132 | + |
| 133 | + torch::load(features_, features_file); |
| 134 | + features_ = features_.view({-1, features_.size(2)}); |
| 135 | + int output_dim = config_["CNN"]["OutputDim"].as<int>(); |
| 136 | + features_ = features_.index({Slice(), Slice(0, output_dim)}); |
| 137 | + |
| 138 | + maps_ = torch::cat({local_maps, comm_maps, obstacle_maps}, 1); |
| 139 | + std::cout << "maps shape: " << maps_.sizes() << std::endl; |
| 140 | + |
| 141 | + } |
| 142 | + |
| 143 | + void LoadConfigs(std::string const &config_file) { |
| 144 | + std::cout << "Using config file: " << config_file << std::endl; |
| 145 | + // Check if config_file exists |
| 146 | + if(not std::filesystem::exists(config_file)) { |
| 147 | + throw std::runtime_error("Could not open config file: " + config_file); |
| 148 | + } |
| 149 | + config_ = YAML::LoadFile(config_file); |
| 150 | + data_dir_ = config_["pDataDir"].as<std::string>(); |
| 151 | + batch_size_ = config_["BatchSize"].as<size_t>(); |
| 152 | + num_epochs_ = config_["NumEpochs"].as<size_t>(); |
| 153 | + learning_rate_ = config_["LearningRate"].as<float>(); |
| 154 | + weight_decay_ = config_["WeightDecay"].as<float>(); |
| 155 | + |
| 156 | + cnn_config_ = config_["CNN"] ; |
| 157 | + image_size_ = cnn_config_["ImageSize"].as<int>(); |
| 158 | + } |
| 159 | + |
| 160 | + }; |
| 161 | + |
| 162 | +} // namespace CoverageControlTorch |
| 163 | + |
| 164 | +#endif //COVERAGECONTROL_TRAIN_CNN_H_ |
0 commit comments