|
| 1 | +/** Main program for testing a CNN |
| 2 | + * |
| 3 | + * @file: train_cnn.cpp |
| 4 | + */ |
| 5 | + |
| 6 | +#include <string> |
| 7 | +#include <iostream> |
| 8 | +#include <CoverageControlTorch/cnn_module.h> |
| 9 | + |
| 10 | +using namespace CoverageControlTorch; |
| 11 | +void LoadDataset(std::string const &data_dir, int const image_size, int const output_dim, torch::Tensor &maps, torch::Tensor &features, torch::Tensor &features_mean, torch::Tensor &features_std) { |
| 12 | + std::string local_maps_file = data_dir + "/local_maps.pt"; |
| 13 | + std::string comm_maps_file = data_dir + "/comm_maps.pt"; |
| 14 | + std::string obstacle_maps_file = data_dir + "/obstacle_maps.pt"; |
| 15 | + std::string features_file = data_dir + "/normalized_coverage_features.pt"; |
| 16 | + |
| 17 | + torch::Tensor local_maps; |
| 18 | + torch::load(local_maps, local_maps_file); |
| 19 | + local_maps = local_maps.unsqueeze(2).view({-1, 1, image_size, image_size}); |
| 20 | + torch::Tensor comm_maps; |
| 21 | + torch::load(comm_maps, comm_maps_file); |
| 22 | + comm_maps = comm_maps.to_dense().view({-1, 2, image_size, image_size}); |
| 23 | + torch::Tensor obstacle_maps; |
| 24 | + torch::load(obstacle_maps, obstacle_maps_file); |
| 25 | + obstacle_maps = obstacle_maps.to_dense().unsqueeze(2).view({-1, 1, image_size, image_size}); |
| 26 | + |
| 27 | + torch::load(features, features_file); |
| 28 | + features = features.view({-1, features.size(2)}); |
| 29 | + features = features.index({Slice(), Slice(0, output_dim)}).to(torch::kCPU); |
| 30 | + |
| 31 | + maps = torch::cat({local_maps, comm_maps, obstacle_maps}, 1).to(torch::kCPU); |
| 32 | + |
| 33 | + torch::load(features_mean, data_dir + "/coverage_features_mean.pt"); |
| 34 | + torch::load(features_std, data_dir + "/coverage_features_std.pt"); |
| 35 | + |
| 36 | + std::cout << "maps shape: " << maps.sizes() << std::endl; |
| 37 | + |
| 38 | +} |
| 39 | + |
| 40 | +int main(int argc, char* argv[]) { |
| 41 | + |
| 42 | + if (argc < 2) { |
| 43 | + std::cout << "Usage: ./test_cnn <yaml>" << std::endl; |
| 44 | + return 1; |
| 45 | + } |
| 46 | + std::string config_file = std::string(argv[1]); |
| 47 | + |
| 48 | + torch::Device device(torch::kCPU); |
| 49 | + if (torch::cuda::is_available()) { |
| 50 | + device = torch::kCUDA; |
| 51 | + std::cout << "Using CUDA" << std::endl; |
| 52 | + } |
| 53 | + |
| 54 | + YAML::Node config = YAML::LoadFile(config_file); |
| 55 | + auto cnn_config = config["CNN"]; |
| 56 | + |
| 57 | + std::string data_dir = config["pDataDir"].as<std::string>(); |
| 58 | + |
| 59 | + CoverageControlCNN model(cnn_config); |
| 60 | + torch::load(model, data_dir + config["CNNTraining"]["ModelCkpt"].as<std::string>()); |
| 61 | + model->to(device); |
| 62 | + |
| 63 | + int image_size = cnn_config["ImageSize"].as<int>(); |
| 64 | + int output_dim = cnn_config["OutputDim"].as<int>(); |
| 65 | + |
| 66 | + torch::Tensor maps, features, features_mean, features_std; |
| 67 | + LoadDataset(data_dir, image_size, output_dim, maps, features, features_mean, features_std); |
| 68 | + std::cout << "Loaded dataset" << std::endl; |
| 69 | + |
| 70 | + maps = maps.to(torch::kCPU); |
| 71 | + torch::Tensor sub_maps = maps.index({Slice(0, 1000)}); |
| 72 | + sub_maps = sub_maps.to(device); |
| 73 | + torch::Tensor sub_features = features.index({Slice(0, 1000)}); |
| 74 | + std::cout << "submaps created" << std::endl; |
| 75 | + auto pred = model->forward(sub_maps).to(torch::kCPU); |
| 76 | + std::cout << "Pred done" << std::endl; |
| 77 | + auto loss = torch::mse_loss(pred, sub_features); |
| 78 | + std::cout << "Val loss: " << loss << std::endl; |
| 79 | + return 0; |
| 80 | +} |
0 commit comments