Skip to content

Commit bacc3f9

Browse files
author
Saurav Agarwal
committed
minor updates
1 parent 2fe5774 commit bacc3f9

File tree

6 files changed

+131
-95
lines changed

6 files changed

+131
-95
lines changed

cppsrc/core/include/CoverageControl/cuda_utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class CudaUtils {
107107
std::cerr << "No CUDA device found" << std::endl;
108108
return false;
109109
}
110-
std::cout << "Initializing CUDA device " << device_id_ << std::endl;
110+
/* std::cout << "Initializing CUDA device " << device_id_ << std::endl; */
111111
if (GPUDeviceInit(device_id_) != device_id_) {
112112
std::cerr << "Failed to initialize CUDA device" << std::endl;
113113
return false;

cppsrc/core/src/cuda/cuda_utils.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ namespace CoverageControl {
9494
}
9595

9696
CheckCudaErrors(cudaSetDevice(dev_id));
97-
std::cout << "GPU Device " << dev_id << " has been set" << std::endl;
98-
std::cout << "CUDA Device [" << dev_id << "]: \""
99-
<< _ConvertSMVer2ArchName(major, minor) << "\"" << std::endl;
97+
/* std::cout << "GPU Device " << dev_id << " has been set" << std::endl; */
98+
/* std::cout << "CUDA Device [" << dev_id << "]: \"" */
99+
/* << _ConvertSMVer2ArchName(major, minor) << "\"" << std::endl; */
100100
is_cuda_initialized_ = true;
101101
device_id_ = dev_id;
102102
return dev_id;
@@ -133,10 +133,10 @@ namespace CoverageControl {
133133
&major, cudaDevAttrComputeCapabilityMajor, current_device));
134134
CheckCudaErrors(cudaDeviceGetAttribute(
135135
&minor, cudaDevAttrComputeCapabilityMinor, current_device));
136-
std::cout << "GPU Device " << current_device << " has been set"
137-
<< std::endl;
138-
std::cout << "CUDA Device [" << current_device << "]: \""
139-
<< _ConvertSMVer2ArchName(major, minor) << "\"" << std::endl;
136+
/* std::cout << "GPU Device " << current_device << " has been set" */
137+
/* << std::endl; */
138+
/* std::cout << "CUDA Device [" << current_device << "]: \"" */
139+
/* << _ConvertSMVer2ArchName(major, minor) << "\"" << std::endl; */
140140
return current_device;
141141
} else {
142142
devices_prohibited++;

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ classifiers = [
2626
dynamic = ["version"]
2727
dependencies = ["numpy", "pyyaml",
2828
'toml; python_version < "3.11"',
29+
"rich"
2930
]
3031

3132
[project.optional-dependencies]

python/coverage_control/algorithms/controllers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
108108
self.cnn_map_size = self.config["CNNMapSize"]
109109

110110
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111-
print(f"Using device: {self.device}")
111+
# print(f"Using device: {self.device}")
112112

113113
if "ModelFile" in self.config:
114114
self.model_file = IOUtils.sanitize_path(self.config["ModelFile"])
@@ -125,7 +125,7 @@ def __init__(self, config: dict, params: Parameters, env: CoverageSystem):
125125
self.actions_std = self.model.actions_std.to(self.device)
126126
self.model = self.model.to(self.device)
127127
self.model.eval()
128-
self.model = torch.compile(self.model, dynamic=True)
128+
# self.model = torch.compile(self.model, dynamic=True)
129129

130130
def step(self, env):
131131
"""

python/coverage_control/nn/models/lpac.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,20 +80,34 @@ def forward(self, data: torch_geometric.data.Data) -> torch.Tensor:
8080
x = self.output_linear(self.gnn_mlp(self.gnn_backbone(gnn_backbone_in, edge_index)))
8181
return x
8282

83+
def load_compiled_state_dict(self, model_state_dict_path: str) -> None:
84+
# remove _orig_mod from the state dict keys
85+
state_dict = torch.load(model_state_dict_path)
86+
new_state_dict = {}
87+
for key in state_dict.keys():
88+
new_state_dict[key.replace("_orig_mod.", "")] = state_dict[key]
89+
self.load_state_dict(new_state_dict, strict=True)
90+
8391
def load_model(self, model_state_dict_path: str) -> None:
8492
"""
8593
Load the model from the state dict
8694
"""
87-
self.load_state_dict(torch.load(model_state_dict_path), strict=False)
95+
self.load_state_dict(torch.load(model_state_dict_path), strict=True)
96+
97+
def load_model_state_dict(self, model_state_dict: dict) -> None:
98+
"""
99+
Load the model from the state dict
100+
"""
101+
self.load_state_dict(model_state_dict, strict=True)
88102

89103
def load_cnn_backbone(self, model_path: str) -> None:
90104
"""
91105
Load the CNN backbone from the model path
92106
"""
93-
self.load_state_dict(torch.load(model_path).state_dict(), strict=False)
107+
self.load_state_dict(torch.load(model_path).state_dict(), strict=True)
94108

95109
def load_gnn_backbone(self, model_path: str) -> None:
96110
"""
97111
Load the GNN backbone from the model path
98112
"""
99-
self.load_state_dict(torch.load(model_path).state_dict(), strict=False)
113+
self.load_state_dict(torch.load(model_path).state_dict(), strict=True)

python/scripts/evaluators/eval.py

Lines changed: 103 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
# @file eval.py
22
# @brief Evaluates the performance of the controllers on a set of environments
33
import os
4-
import sys
4+
import argparse
55

66
import coverage_control as cc
77
import numpy as np
8+
from rich.progress import (
9+
Progress,
10+
BarColumn,
11+
TextColumn,
12+
TimeRemainingColumn,
13+
TimeElapsedColumn,
14+
TaskProgressColumn,
15+
MofNCompleteColumn,
16+
)
817
from coverage_control import CoverageSystem
918
from coverage_control import IOUtils
1019
from coverage_control import WorldIDF
1120
from coverage_control.algorithms import ControllerCVT
1221
from coverage_control.algorithms import ControllerNN
1322

1423

15-
# @ingroup python_api
1624
class Evaluator:
1725
"""
1826
Evaluates the performance of the controllers on a set of environments
@@ -46,102 +54,115 @@ def __init__(self, in_config):
4654
self.num_steps = self.config["NumSteps"]
4755
os.makedirs(self.env_dir + "/init_maps", exist_ok=True)
4856

57+
self.columns = [
58+
BarColumn(bar_width=None),
59+
TaskProgressColumn(),
60+
TextColumn("[progress.description]{task.description}"),
61+
MofNCompleteColumn(),
62+
TextColumn("Controller: {task.fields[info]}"),
63+
TimeRemainingColumn(),
64+
TimeElapsedColumn()
65+
]
66+
4967
def evaluate(self, save=True):
50-
dataset_count = 0
5168
cost_data = np.zeros((self.num_controllers, self.num_envs, self.num_steps))
5269

53-
while dataset_count < self.num_envs:
54-
print(f"Environment {dataset_count}")
55-
pos_file = self.env_dir + "/" + str(dataset_count) + ".pos"
56-
env_file = self.env_dir + "/" + str(dataset_count) + ".env"
57-
58-
if os.path.isfile(env_file) and os.path.isfile(pos_file):
59-
world_idf = WorldIDF(self.cc_params, env_file)
60-
env_main = CoverageSystem(self.cc_params, world_idf, pos_file)
61-
else:
62-
print(f"Creating new environment {dataset_count}")
63-
env_main = CoverageSystem(self.cc_params)
64-
env_main.WriteEnvironment(pos_file, env_file)
65-
world_idf = env_main.GetWorldIDFObject()
66-
67-
# env_main.PlotInitMap(self.env_dir + "/init_maps", f"{dataset_count}")
68-
robot_init_pos = env_main.GetRobotPositions(force_no_noise=True)
69-
70-
for controller_id in range(self.num_controllers):
71-
step_count = 0
72-
env = CoverageSystem(self.cc_params, world_idf, robot_init_pos)
73-
74-
# map_dir = self.eval_dir + "/" + self.controllers[controller_id]["Name"] + "/plots/"
75-
# env.RecordPlotData()
76-
# env.PlotMapVoronoi(map_dir, step_count)
77-
78-
if self.controllers_configs[controller_id]["Type"] == "Learning":
79-
Controller = ControllerNN
70+
with Progress(*self.columns, expand=True) as progress:
71+
task = progress.add_task(
72+
"[bold blue]Evaluation",
73+
total=self.num_envs,
74+
info="",
75+
auto_refresh=False,
76+
)
77+
78+
for env_count in range(self.num_envs):
79+
pos_file = self.env_dir + "/" + str(env_count) + ".pos"
80+
env_file = self.env_dir + "/" + str(env_count) + ".env"
81+
82+
if os.path.isfile(env_file) and os.path.isfile(pos_file):
83+
world_idf = WorldIDF(self.cc_params, env_file)
84+
env_main = CoverageSystem(self.cc_params, world_idf, pos_file)
8085
else:
81-
Controller = ControllerCVT
82-
controller = Controller(
83-
self.controllers_configs[controller_id], self.cc_params, env
84-
)
85-
initial_objective_value = env.GetObjectiveValue()
86-
cost_data[controller_id, dataset_count, step_count] = (
87-
env.GetObjectiveValue() / initial_objective_value
88-
)
89-
step_count = step_count + 1
90-
91-
while step_count < self.num_steps:
92-
objective_value, converged = controller.step(env)
93-
cost_data[controller_id, dataset_count, step_count] = (
94-
objective_value / initial_objective_value
95-
)
86+
# print(f"Creating new environment {env_count}")
87+
env_main = CoverageSystem(self.cc_params)
88+
env_main.WriteEnvironment(pos_file, env_file)
89+
world_idf = env_main.GetWorldIDFObject()
9690

97-
if converged:
98-
cost_data[controller_id, dataset_count, step_count:] = (
99-
objective_value / initial_objective_value
100-
)
91+
# env_main.PlotInitMap(self.env_dir + "/init_maps", f"{env_count}")
92+
robot_init_pos = env_main.GetRobotPositions(force_no_noise=True)
10193

102-
break
103-
# env.PlotMapVoronoi(map_dir, step_count)
94+
for controller_id in range(self.num_controllers):
95+
step_count = 0
96+
env = CoverageSystem(self.cc_params, world_idf, robot_init_pos)
97+
98+
# map_dir = self.eval_dir + "/" + self.controllers[controller_id]["Name"] + "/plots/"
10499
# env.RecordPlotData()
100+
# env.PlotMapVoronoi(map_dir, step_count)
101+
102+
if self.controllers_configs[controller_id]["Type"] == "Learning":
103+
Controller = ControllerNN
104+
else:
105+
Controller = ControllerCVT
106+
controller = Controller(
107+
self.controllers_configs[controller_id], self.cc_params, env
108+
)
109+
initial_objective_value = env.GetObjectiveValue()
110+
cost_data[controller_id, env_count, step_count] = (
111+
env.GetObjectiveValue() / initial_objective_value
112+
)
105113
step_count = step_count + 1
106114

107-
if step_count % 100 == 0:
108-
val = cost_data[controller_id, dataset_count, step_count - 1]
109-
print(
110-
f"Environment {dataset_count} "
111-
f"{controller.name} "
112-
f"Step {step_count} "
113-
f"Objective Value {val:.3e}"
115+
while step_count < self.num_steps:
116+
objective_value, converged = controller.step(env)
117+
cost_data[controller_id, env_count, step_count] = (
118+
objective_value / initial_objective_value
114119
)
115120

116-
print(
117-
f"Environment {dataset_count} "
118-
f"{controller.name} "
119-
f"Step {step_count} "
120-
f"Objective Value {val:.3e}"
121-
)
122-
if save is True:
123-
self.controller_dir = (
124-
self.eval_dir
125-
+ "/"
126-
+ self.controllers_configs[controller_id]["Name"]
127-
)
128-
controller_data_file = self.controller_dir + "/" + "eval.csv"
129-
np.savetxt(
130-
controller_data_file,
131-
cost_data[controller_id, : dataset_count + 1, :],
132-
delimiter=",",
133-
)
134-
# env.RenderRecordedMap(self.eval_dir + "/" + self.controllers[controller_id]["Name"] + "/", "video.mp4")
135-
del controller
136-
del env
137-
dataset_count = dataset_count + 1
121+
if converged:
122+
cost_data[controller_id, env_count, step_count:] = (
123+
objective_value / initial_objective_value
124+
)
125+
126+
break
127+
# env.PlotMapVoronoi(map_dir, step_count)
128+
# env.RecordPlotData()
129+
130+
if step_count % 10 == 0:
131+
info = (
132+
f"{controller_id}/{self.num_controllers} {controller.name} "
133+
f"Step: {step_count} Obj: {cost_data[controller_id, env_count, step_count]:.2e}"
134+
)
135+
progress.update(task, info=info)
136+
progress.refresh()
137+
138+
step_count = step_count + 1
139+
140+
if save is True:
141+
self.controller_dir = (
142+
self.eval_dir
143+
+ "/"
144+
+ self.controllers_configs[controller_id]["Name"]
145+
)
146+
controller_data_file = self.controller_dir + "/" + "eval.csv"
147+
np.savetxt(
148+
controller_data_file,
149+
cost_data[controller_id, : env_count + 1, :],
150+
delimiter=",",
151+
)
152+
# env.RenderRecordedMap(self.eval_dir + "/" + self.controllers[controller_id]["Name"] + "/", "video.mp4")
153+
del controller
154+
del env
155+
progress.advance(task)
156+
progress.refresh()
138157

139158
return cost_data
140159

141160

142161
if __name__ == "__main__":
143-
config_file = sys.argv[1]
144-
config = IOUtils.load_toml(config_file)
162+
parser = argparse.ArgumentParser()
163+
parser.add_argument("config_file", type=str, help="Path to config file")
164+
args = parser.parse_args()
165+
config = IOUtils.load_toml(args.config_file)
145166

146167
evaluator = Evaluator(config)
147168
evaluator.evaluate()

0 commit comments

Comments
 (0)