Skip to content

Commit af89c0f

Browse files
author
Saurav Agarwal
committed
Add convergence check
1 parent 2b6e9d7 commit af89c0f

File tree

2 files changed

+33
-24
lines changed

2 files changed

+33
-24
lines changed

python/coverage_control/algorithms/controllers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,4 +150,7 @@ def step(self, env):
150150
point_vector_actions = PointVector(actions.cpu().numpy())
151151
env.StepActions(point_vector_actions)
152152

153+
# Check if actions are all zeros (1e-12)
154+
if torch.allclose(actions, torch.zeros_like(actions), atol=1e-5):
155+
return env.GetObjectiveValue(), True
153156
return env.GetObjectiveValue(), False

python/scripts/evaluators/constrained_learning.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,19 @@ def update_idf(self, coefficients, normalize=False):
155155
)
156156

157157
def advance_state(self):
158-
self.controller.step(self.env_main)
158+
obj_val, is_converged = self.controller.step(self.env_main)
159159
self.step_counter = self.step_counter + 1
160160

161161
if self.generate_video and self.step_counter % 1 == 0:
162162
self.env_main.RecordPlotData()
163163
# self.env_main.PlotRobotLocalMap("./robot_maps/", 0, self.step_counter)
164164
# self.env_main.PlotRobotSensorView("./robot_maps/", 0, self.step_counter)
165-
robot_positions = self.env_main.GetRobotPositions()
166-
167-
for env in self.envs:
168-
env.SetGlobalRobotPositions(robot_positions)
165+
if is_converged == False:
166+
robot_positions = self.env_main.GetRobotPositions()
167+
for env in self.envs:
168+
env.SetGlobalRobotPositions(robot_positions)
169+
is_state_updated = not is_converged
170+
return is_state_updated
169171

170172
def compute_obj_values(self):
171173
obj_values = np.array(
@@ -188,33 +190,37 @@ def evaluate(self):
188190
K = self.num_steps // self.T_0
189191
self.lambda_duals = self.fun_dual_updater(self.dual_updater, self.lambda_duals)
190192

193+
self.update_idf(self.lambda_duals, normalize=self.normalize)
191194
obj_values = self.compute_obj_values()
192195
print(
193196
f"{0} Objective values: {obj_values} Lambda duals: {self.lambda_duals}, self alphas: {self.alphas}"
194197
)
195198

196199
for k in range(K):
197-
obj_values = np.zeros(self.num_idfs)
198-
199-
self.update_idf(self.lambda_duals, normalize=self.normalize)
200200

201+
is_state_updated = False
201202
for _ in range(self.T_0):
202-
self.advance_state()
203+
is_state_updated = is_state_updated or self.advance_state()
203204
# obj_values += self.compute_obj_values() # This is a vector
204205

205-
# obj_values /= self.T_0
206-
obj_values = self.compute_obj_values()
207-
obj_max = np.max(obj_values)
208-
self.lambda_duals = np.maximum(
209-
self.lambda_duals
210-
+ self.eta_dual * (obj_values - self.alphas) / obj_max,
211-
0,
212-
)
213-
if self.dual_updater == "max_one" or self.dual_updater == "malencia":
214-
self.lambda_duals = self.compute_obj_values()
215-
self.lambda_duals = self.fun_dual_updater(
216-
self.dual_updater, self.lambda_duals
217-
)
206+
if is_state_updated == True:
207+
# obj_values /= self.T_0
208+
obj_values = self.compute_obj_values()
209+
obj_max = np.max(obj_values)
210+
self.lambda_duals = np.maximum(
211+
self.lambda_duals
212+
+ self.eta_dual * (obj_values - self.alphas) / obj_max,
213+
0,
214+
)
215+
if self.dual_updater == "max_one" or self.dual_updater == "malencia":
216+
self.lambda_duals = self.compute_obj_values()
217+
self.lambda_duals = self.fun_dual_updater(
218+
self.dual_updater, self.lambda_duals
219+
)
220+
self.update_idf(self.lambda_duals, normalize=self.normalize)
221+
else:
222+
obj_values = self.all_obj_values[:, k]
223+
self.lambda_duals = self.all_lambda_duals[:, k]
218224

219225
self.all_obj_values[:, k + 1] = obj_values
220226
self.all_lambda_duals[:, k + 1] = self.lambda_duals
@@ -296,7 +302,7 @@ def fun_dual_updater(self, configs, lambdas):
296302
envs = list(range(100))
297303
# T_0s = [25, 50, 75, 100]
298304
# envs = [72]
299-
T_0s = [25]
305+
T_0s = [1]
300306
eta_duals = [1]
301307
eval_dir = sys.argv[2]
302308

@@ -309,7 +315,7 @@ def fun_dual_updater(self, configs, lambdas):
309315
env_id,
310316
eta_dual,
311317
T_0,
312-
dual_updater="proj_1",
318+
dual_updater="malencia",
313319
alpha=0.0,
314320
normalize=True,
315321
obj_normalize_factor=1e10,

0 commit comments

Comments
 (0)