@@ -107,6 +107,9 @@ def __init__(
107107 self .lambda_duals = np .array ([1.0 , 1.0 , 0.0 ])
108108 elif dual_updater == "proj_1" :
109109 self .lambda_duals = np .array ([1.0 / self .num_idfs for i in range (self .num_idfs )])
110+ elif dual_updater == "max_one" :
111+ self .lambda_duals = np .array ([1.0 / self .num_idfs for i in range (self .num_idfs )])
112+ print (f"Initial Lambda_dual: { self .lambda_duals } " )
110113
111114 # Set the real values here
112115 # self.alphas = np.array([1 / self.num_idfs for i in range(self.num_idfs)])
@@ -209,6 +212,8 @@ def evaluate(self):
209212 + self .eta_dual * (obj_values - self .alphas ) / obj_max ,
210213 0 ,
211214 )
215+ if self .dual_updater == "max_one" :
216+ self .lambda_duals = self .compute_obj_values ()
212217 self .lambda_duals = self .fun_dual_updater (
213218 self .dual_updater , self .lambda_duals
214219 )
@@ -266,6 +271,12 @@ def fun_dual_updater(self, configs, lambdas):
266271 max_index = np .argmax (lambdas )
267272 lambdas = np .zeros (len (lambdas ))
268273 lambdas [max_index ] = 1
274+ print (f"max_index: { max_index } " )
275+ print (f"lambdas: { lambdas } " )
276+
277+ if self .max_dual_index != max_index :
278+ self .max_dual_index = max_index
279+ self .max_dual_switch_counter += 1
269280
270281 return lambdas
271282
@@ -291,7 +302,7 @@ def fun_dual_updater(self, configs, lambdas):
291302 env_id ,
292303 eta_dual ,
293304 T_0 ,
294- dual_updater = "proj_1 " ,
305+ dual_updater = "max_one " ,
295306 alpha = 0.0 ,
296307 normalize = True ,
297308 obj_normalize_factor = 1e10 ,
0 commit comments