|
13 | 13 |
|
14 | 14 | #include "type_conversions.h" |
15 | 15 |
|
| 16 | +typedef long int T_idx_t; |
16 | 17 | using namespace torch::indexing; |
| 18 | +using namespace CoverageControl; |
17 | 19 | namespace CoverageControlTorch { |
18 | 20 |
|
19 | 21 | class CoverageSystem : public CoverageControl::CoverageSystem { |
| 22 | + private: |
| 23 | + float env_resolution_ = 1; |
| 24 | + float comm_range_ = 256; |
| 25 | + |
| 26 | + void init() { |
| 27 | + env_resolution_ = (float) params_.pResolution; |
| 28 | + comm_range_ = (float) params_.pCommunicationRange; |
| 29 | + } |
20 | 30 |
|
21 | 31 | public: |
22 | 32 |
|
23 | | - CoverageSystem(Parameters const ¶ms, size_t const num_features, size_t const num_robots) : CoverageControl::CoverageSystem(params, num_features, num_robots) { } |
| 33 | + CoverageSystem(Parameters const ¶ms, size_t const num_features, size_t const num_robots) : CoverageControl::CoverageSystem(params, num_features, num_robots) { init(); } |
24 | 34 |
|
25 | | - CoverageSystem(Parameters const ¶ms, WorldIDF const &world_idf, std::string const &pos_file_name) : CoverageControl::CoverageSystem(params, world_idf, pos_file_name) { } |
| 35 | + CoverageSystem(Parameters const ¶ms, WorldIDF const &world_idf, std::string const &pos_file_name) : CoverageControl::CoverageSystem(params, world_idf, pos_file_name) { init(); } |
26 | 36 |
|
27 | | - CoverageSystem(Parameters const ¶ms, WorldIDF const &world_idf, std::vector <Point2> const &robot_positions) : CoverageControl::CoverageSystem(params, world_idf, robot_positions) { } |
| 37 | + CoverageSystem(Parameters const ¶ms, WorldIDF const &world_idf, std::vector <Point2> const &robot_positions) : CoverageControl::CoverageSystem(params, world_idf, robot_positions) { init(); } |
28 | 38 |
|
29 | | - CoverageSystem(Parameters const ¶ms, std::vector <BivariateNormalDistribution> const &dists, std::vector <Point2> const &robot_positions) : CoverageControl::CoverageSystem(params, dists, robot_positions) { } |
| 39 | + CoverageSystem(Parameters const ¶ms, std::vector <BivariateNormalDistribution> const &dists, std::vector <Point2> const &robot_positions) : CoverageControl::CoverageSystem(params, dists, robot_positions) { init(); } |
30 | 40 |
|
31 | 41 | torch::Tensor GetAllRobotsLocalMaps() { |
32 | 42 | torch::Tensor maps = torch::zeros({num_robots_, params_.pLocalMapSize, params_.pLocalMapSize}); |
33 | 43 | #pragma omp parallel for |
34 | 44 | for(size_t i = 0; i < num_robots_; i++) { |
35 | | - maps[i] = EigenToLibTorch(robots_[i].GetRobotLocalMap()); |
| 45 | + maps[i] = ToTensor(robots_[i].GetRobotLocalMap()); |
36 | 46 | } |
37 | 47 | return maps; |
38 | 48 | } |
39 | 49 |
|
40 | | - void GetAllRobotsCommunicationMaps(torch::Tensor maps, size_t const &map_size) const { |
41 | | -#pragma omp parallel for |
42 | | - for(int i = 0; i < num_robots_; i++) { |
43 | | - auto robot_neighbors_pos = GetRobotsInCommunication(i); |
44 | | - double comm_scale = (params_.pCommunicationRange * 2.) / map_size; |
45 | | - Point2 map_translation(map_size * comm_scale * params_.pResolution/2., map_size * comm_scale * params_.pResolution/2.); |
46 | | - for(Point2 const& relative_pos:robot_neighbors_pos) { |
47 | | - Point2 map_pos = relative_pos + map_translation; |
48 | | - int pos_idx, pos_idy; |
49 | | - MapUtils::GetClosestGridCoordinate(params_.pResolution * comm_scale, map_pos, pos_idx, pos_idy); |
50 | | - if(pos_idx < map_size and pos_idy < map_size and pos_idx >= 0 and pos_idy >= 0) { |
51 | | - maps.index_put_({i, pos_idx, pos_idy}, 1); |
52 | | - } |
53 | | - } |
| 50 | + void GetAllRobotsCommunicationMaps(torch::Tensor maps, size_t const &map_size) { |
| 51 | + float f_map_size = (float) map_size; |
| 52 | + torch::Tensor robot_positions = ToTensor(GetRobotPositions()); |
| 53 | + torch::Tensor scaled_relative_pos = torch::round((robot_positions.unsqueeze(0) - robot_positions.unsqueeze(1)) * f_map_size/(comm_range_ * env_resolution_ * 2.) + (f_map_size/2. - env_resolution_/2.)).to(torch::kInt64); |
| 54 | + torch::Tensor pairwise_dist_matrices = torch::cdist(robot_positions, robot_positions, 2); |
| 55 | + torch::Tensor diagonal_mask = torch::eye(pairwise_dist_matrices.size(0)).to(torch::kBool); |
| 56 | + pairwise_dist_matrices.masked_fill_(diagonal_mask, comm_range_ + 1); |
| 57 | + for (T_idx_t r_idx = 0; r_idx < num_robots_; ++r_idx) { |
| 58 | + torch::Tensor indices = scaled_relative_pos.index({r_idx, pairwise_dist_matrices[r_idx] < (comm_range_ - 0.001), Slice()}); |
| 59 | + maps.index_put_({r_idx, indices.index({Slice(),0}), indices.index({Slice(), 1})}, 1); |
54 | 60 | } |
55 | 61 | } |
56 | 62 | }; |
|
0 commit comments