Skip to content

Commit 2fe5774

Browse files
Saurav AgarwalSaurav Agarwal
authored andcommitted
Update maps and cleanup
1 parent 97748c7 commit 2fe5774

File tree

17 files changed

+419
-488
lines changed

17 files changed

+419
-488
lines changed

cppsrc/core/include/CoverageControl/coverage_system.h

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,6 @@ class CoverageSystem {
7070
WorldIDF world_idf_; //!< World IDF
7171
size_t num_robots_ = 0; //!< Number of robots
7272
std::vector<RobotModel> robots_; //!< Vector of robots of type RobotModel
73-
std::vector<std::pair<MapType, MapType>>
74-
communication_maps_; //!< Communication maps (2 channels) for each robot
7573
double normalization_factor_ = 0; //!< Normalization factor for the world IDF
7674
Voronoi voronoi_; //!< Voronoi object
7775
std::vector<VoronoiCell> voronoi_cells_; //!< Voronoi cells for each robot
@@ -567,14 +565,17 @@ class CoverageSystem {
567565
return robot_neighbors_pos;
568566
}
569567

570-
std::pair<MapType, MapType> const &GetCommunicationMap(size_t const, size_t);
568+
std::pair<MapType, MapType> GetRobotCommunicationMaps(size_t const, size_t);
571569

572-
const auto &GetCommunicationMaps(size_t map_size) {
573-
#pragma omp parallel for num_threads(num_robots_)
570+
std::vector<MapType> GetCommunicationMaps(size_t map_size) {
571+
std::vector<MapType> communication_maps(2 * num_robots_);
572+
/* #pragma omp parallel for num_threads(num_robots_) */
574573
for (size_t i = 0; i < num_robots_; ++i) {
575-
GetCommunicationMap(i, map_size);
574+
auto comm_map = GetRobotCommunicationMaps(i, map_size);
575+
communication_maps[2 * i] = comm_map.first;
576+
communication_maps[2 * i + 1] = comm_map.second;
576577
}
577-
return communication_maps_;
578+
return communication_maps;
578579
}
579580

580581
auto GetObjectiveValue() {

cppsrc/core/src/coverage_system.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,10 @@ CoverageSystem::CoverageSystem(
145145
InitSetup();
146146
}
147147

148-
std::pair<MapType, MapType> const &CoverageSystem::GetCommunicationMap(
148+
std::pair<MapType, MapType> CoverageSystem::GetRobotCommunicationMaps(
149149
size_t const id, size_t map_size) {
150-
communication_maps_[id] = std::make_pair(MapType::Zero(map_size, map_size),
151-
MapType::Zero(map_size, map_size));
150+
std::pair<MapType, MapType> communication_maps = std::make_pair(
151+
MapType::Zero(map_size, map_size), MapType::Zero(map_size, map_size));
152152
PointVector robot_neighbors_pos = GetRelativePositonsNeighbors(id);
153153
double center = map_size / 2. - params_.pResolution / 2.;
154154
Point2 center_point(center, center);
@@ -157,24 +157,23 @@ std::pair<MapType, MapType> const &CoverageSystem::GetCommunicationMap(
157157
relative_pos * map_size /
158158
(params_.pCommunicationRange * params_.pResolution * 2.) +
159159
center_point;
160-
int scaled_indices_x = scaled_indices_val[0];
161-
int scaled_indices_y = scaled_indices_val[1];
160+
int scaled_indices_x = std::round(scaled_indices_val[0]);
161+
int scaled_indices_y = std::round(scaled_indices_val[1]);
162162
Point2 normalized_relative_pos = relative_pos / params_.pCommunicationRange;
163163

164-
communication_maps_[id].first(scaled_indices_x, scaled_indices_y) +=
164+
communication_maps.first(scaled_indices_x, scaled_indices_y) +=
165165
normalized_relative_pos[0];
166-
communication_maps_[id].second(scaled_indices_x, scaled_indices_y) +=
166+
communication_maps.second(scaled_indices_x, scaled_indices_y) +=
167167
normalized_relative_pos[1];
168168
}
169-
return communication_maps_[id];
169+
return communication_maps;
170170
}
171171

172172
void CoverageSystem::InitSetup() {
173173
num_robots_ = robots_.size();
174174
robot_positions_history_.resize(num_robots_);
175175

176176
voronoi_cells_.resize(num_robots_);
177-
communication_maps_.resize(num_robots_);
178177

179178
robot_global_positions_.resize(num_robots_);
180179
for (size_t iRobot = 0; iRobot < num_robots_; ++iRobot) {
@@ -588,7 +587,7 @@ void CoverageSystem::PlotRobotCommunicationMaps(std::string const &dir_name,
588587
int const &robot_id,
589588
int const &step,
590589
size_t const &map_size) {
591-
auto robot_communication_maps = GetCommunicationMap(robot_id, map_size);
590+
auto robot_communication_maps = GetRobotCommunicationMaps(robot_id, map_size);
592591
Plotter plotter_x(dir_name, map_size * params_.pResolution,
593592
params_.pResolution);
594593
plotter_x.SetPlotName(

cppsrc/python_bindings/core_binds.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ void pyCoverageControl_core(py::module &m) {
8585

8686
py::bind_vector<PointVector>(m, "PointVector");
8787
py::bind_vector<std::vector<Point3>>(m, "Point3Vector");
88+
py::bind_vector<std::vector<MapType>>(m, "MapTypeVector");
8889

8990
py::class_<PolygonFeature>(m, "PolygonFeature")
9091
.def(py::init<>())
@@ -307,8 +308,7 @@ void pyCoverageControl_core_coverage_system(py::module &m) {
307308
py::return_value_policy::reference_internal)
308309
.def("GetRobotSensorView", &CoverageSystem::GetRobotSensorView,
309310
py::return_value_policy::reference_internal)
310-
.def("GetCommunicationMap", &CoverageSystem::GetCommunicationMap,
311-
py::return_value_policy::reference_internal)
311+
.def("GetCommunicationMaps", &CoverageSystem::GetCommunicationMaps)
312312
.def("GetRobotsInCommunication",
313313
&CoverageSystem::GetRobotsInCommunication)
314314
.def("ComputeVoronoiCells", &CoverageSystem::ComputeVoronoiCells,

params/learning_params.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ NumWorkers = 4
77
# Similarly, for the optimizer
88
[LPACModel]
99
Dir = "${CoverageControl_ws}/lpac/models/"
10-
Model = "model.pt"
11-
Optimizer = "optimizer.pt"
1210

1311
[CNNModel]
1412
Dir = "${CoverageControl_ws}/lpac/models/" # Absolute location

python/coverage_control/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
from ._version import version as __version__
1212
from .core import *
1313
from .io_utils import IOUtils
14+
from .coverage_env_utils import CoverageEnvUtils
1415

1516
# from .nn import *
1617

17-
__all__ = ["__version__", "core", "nn", "IOUtils"]
18+
__all__ = ["__version__", "core", "nn", "IOUtils", "CoverageEnvUtils"]
1819

1920
cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), "lib", "cmake")

python/coverage_control/algorithms/controllers.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
from . import DecentralizedCVT
2929
from . import NearOptimalCVT
3030
from .. import IOUtils
31+
from .. import CoverageEnvUtils
3132
from ..core import CoverageSystem
3233
from ..core import Parameters
3334
from ..core import PointVector
34-
from ..nn import CoverageEnvUtils
3535

3636
__all__ = ["ControllerCVT", "ControllerNN"]
3737

@@ -112,11 +112,11 @@ def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
112112

113113
if "ModelFile" in self.config:
114114
self.model_file = IOUtils.sanitize_path(self.config["ModelFile"])
115-
self.model = torch.load(self.model_file)
115+
self.model = torch.load(self.model_file).to(self.device)
116116
else: # Load from ModelStateDict
117117
self.learning_params_file = IOUtils.sanitize_path(
118-
self.config["LearningParams"]
119-
)
118+
self.config["LearningParams"]
119+
)
120120
self.learning_params = IOUtils.load_toml(self.learning_params_file)
121121
self.model = cc_nn.LPAC(self.learning_params).to(self.device)
122122
self.model.load_model(IOUtils.sanitize_path(self.config["ModelStateDict"]))
@@ -125,6 +125,7 @@ def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
125125
self.actions_std = self.model.actions_std.to(self.device)
126126
self.model = self.model.to(self.device)
127127
self.model.eval()
128+
self.model = torch.compile(self.model, dynamic=True)
128129

129130
def step(self, env):
130131
"""
@@ -140,8 +141,8 @@ def step(self, env):
140141
Objective value and convergence flag
141142
"""
142143
pyg_data = CoverageEnvUtils.get_torch_geometric_data(
143-
env, self.params, True, self.use_comm_map, self.cnn_map_size
144-
).to(self.device)
144+
env, self.params, True, self.use_comm_map, self.cnn_map_size
145+
).to(self.device)
145146
with torch.no_grad():
146147
actions = self.model(pyg_data)
147148
actions = actions * self.actions_std + self.actions_mean

python/coverage_control/nn/data_loaders/coverage_env_utils.py renamed to python/coverage_control/coverage_env_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import torchvision
3737
from scipy.spatial import distance_matrix
3838

39-
from ...core import CoverageSystem, DblVector, DblVectorVector, Parameters, PointVector
39+
from .core import CoverageSystem, DblVector, DblVectorVector, Parameters, PointVector
4040

4141

4242
## @ingroup python_api
@@ -257,9 +257,11 @@ def get_maps(
257257
)
258258

259259
if use_comm_map:
260-
comm_maps = CoverageEnvUtils.get_communication_maps(
261-
env, params, resized_map_size
262-
)
260+
comm_maps = env.GetCommunicationMaps(resized_map_size)
261+
comm_maps = torch.tensor(numpy.array(env.GetCommunicationMaps(resized_map_size)), dtype=torch.float32).reshape(num_robots, 2, resized_map_size, resized_map_size)
262+
# comm_maps = CoverageEnvUtils.get_communication_maps(
263+
# env, params, resized_map_size
264+
# )
263265
maps = torch.cat(
264266
[
265267
resized_local_maps.unsqueeze(1),

python/coverage_control/nn/data_loaders/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,12 @@
44

55
from __future__ import annotations
66

7-
from .coverage_env_utils import CoverageEnvUtils
87
from .data_loader_utils import DataLoaderUtils
98
from .loaders import (
109
CNNGNNDataset,
11-
LocalMapCNNDataset,
12-
LocalMapGNNDataset,
13-
VoronoiGNNDataset,
1410
)
1511

1612
__all__ = [
1713
"DataLoaderUtils",
18-
"CoverageEnvUtils",
19-
"LocalMapCNNDataset",
20-
"LocalMapGNNDataset",
2114
"CNNGNNDataset",
22-
"VoronoiGNNDataset",
2315
]

python/coverage_control/nn/data_loaders/data_loader_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,15 @@ def to_torch_geometric_data(
204204
if pos is None:
205205
data = torch_geometric.data.Data(
206206
x=feature,
207-
edge_index=edge_index.clone().detach(),
208-
edge_weight=weights.clone().detach(),
207+
edge_index=edge_index.clone(),
208+
edge_weight=weights.clone(),
209209
)
210210
else:
211211
data = torch_geometric.data.Data(
212212
x=feature,
213-
edge_index=edge_index.clone().detach(),
214-
edge_weight=weights.clone().detach(),
215-
pos=pos.clone().detach(),
213+
edge_index=edge_index.clone(),
214+
edge_weight=weights.clone(),
215+
pos=pos.clone(),
216216
)
217217

218218
return data

0 commit comments

Comments
 (0)