@@ -155,17 +155,19 @@ def update_idf(self, coefficients, normalize=False):
155155 )
156156
def advance_state(self):
    """Advance the main environment by one controller step.

    Calls ``self.controller.step(self.env_main)``, increments
    ``self.step_counter``, records plot data when ``self.generate_video``
    is set, and -- unless the controller reports convergence -- copies the
    main environment's robot positions into every environment in
    ``self.envs``.

    Returns:
        bool: ``True`` when the controller did not report convergence (the
        robot positions were propagated to ``self.envs``), ``False``
        otherwise.
    """
    # The first element returned by the controller is not used here.
    _, is_converged = self.controller.step(self.env_main)
    self.step_counter += 1

    # ``% 1`` acts as a recording interval of one: record on every step.
    if self.generate_video and self.step_counter % 1 == 0:
        self.env_main.RecordPlotData()
        # self.env_main.PlotRobotLocalMap("./robot_maps/", 0, self.step_counter)
        # self.env_main.PlotRobotSensorView("./robot_maps/", 0, self.step_counter)

    # Use a single predicate for both the branch and the return value so
    # the two can never disagree.
    is_state_updated = not is_converged
    if is_state_updated:
        robot_positions = self.env_main.GetRobotPositions()
        for env in self.envs:
            env.SetGlobalRobotPositions(robot_positions)
    return is_state_updated
169171
170172 def compute_obj_values (self ):
171173 obj_values = np .array (
@@ -188,33 +190,37 @@ def evaluate(self):
188190 K = self .num_steps // self .T_0
189191 self .lambda_duals = self .fun_dual_updater (self .dual_updater , self .lambda_duals )
190192
193+ self .update_idf (self .lambda_duals , normalize = self .normalize )
191194 obj_values = self .compute_obj_values ()
192195 print (
193196 f"{ 0 } Objective values: { obj_values } Lambda duals: { self .lambda_duals } , self alphas: { self .alphas } "
194197 )
195198
196199 for k in range (K ):
197- obj_values = np .zeros (self .num_idfs )
198-
199- self .update_idf (self .lambda_duals , normalize = self .normalize )
200200
201+ is_state_updated = False
201202 for _ in range (self .T_0 ):
202- self .advance_state ()
203+ is_state_updated = is_state_updated or self .advance_state ()
203204 # obj_values += self.compute_obj_values() # This is a vector
204205
205- # obj_values /= self.T_0
206- obj_values = self .compute_obj_values ()
207- obj_max = np .max (obj_values )
208- self .lambda_duals = np .maximum (
209- self .lambda_duals
210- + self .eta_dual * (obj_values - self .alphas ) / obj_max ,
211- 0 ,
212- )
213- if self .dual_updater == "max_one" or self .dual_updater == "malencia" :
214- self .lambda_duals = self .compute_obj_values ()
215- self .lambda_duals = self .fun_dual_updater (
216- self .dual_updater , self .lambda_duals
217- )
206+ if is_state_updated == True :
207+ # obj_values /= self.T_0
208+ obj_values = self .compute_obj_values ()
209+ obj_max = np .max (obj_values )
210+ self .lambda_duals = np .maximum (
211+ self .lambda_duals
212+ + self .eta_dual * (obj_values - self .alphas ) / obj_max ,
213+ 0 ,
214+ )
215+ if self .dual_updater == "max_one" or self .dual_updater == "malencia" :
216+ self .lambda_duals = self .compute_obj_values ()
217+ self .lambda_duals = self .fun_dual_updater (
218+ self .dual_updater , self .lambda_duals
219+ )
220+ self .update_idf (self .lambda_duals , normalize = self .normalize )
221+ else :
222+ obj_values = self .all_obj_values [:, k ]
223+ self .lambda_duals = self .all_lambda_duals [:, k ]
218224
219225 self .all_obj_values [:, k + 1 ] = obj_values
220226 self .all_lambda_duals [:, k + 1 ] = self .lambda_duals
@@ -296,7 +302,7 @@ def fun_dual_updater(self, configs, lambdas):
296302 envs = list (range (100 ))
297303 # T_0s = [25, 50, 75, 100]
298304 # envs = [72]
299- T_0s = [25 ]
305+ T_0s = [1 ]
300306 eta_duals = [1 ]
301307 eval_dir = sys .argv [2 ]
302308
@@ -309,7 +315,7 @@ def fun_dual_updater(self, configs, lambdas):
309315 env_id ,
310316 eta_dual ,
311317 T_0 ,
312- dual_updater = "proj_1 " ,
318+ dual_updater = "malencia " ,
313319 alpha = 0.0 ,
314320 normalize = True ,
315321 obj_normalize_factor = 1e10 ,
0 commit comments