Skip to content

Commit 07abb2f

Browse files
author
Saurav Agarwal
committed
update max_one
1 parent 84ed430 commit 07abb2f

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

python/scripts/evaluators/constrained_learning.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def __init__(
107107
self.lambda_duals = np.array([1.0, 1.0, 0.0])
108108
elif dual_updater == "proj_1":
109109
self.lambda_duals = np.array([1.0 / self.num_idfs for i in range(self.num_idfs)])
110+
elif dual_updater == "max_one":
111+
self.lambda_duals = np.array([1.0 / self.num_idfs for i in range(self.num_idfs)])
112+
print(f"Initial Lambda_dual: {self.lambda_duals}")
110113

111114
# Set the real values here
112115
# self.alphas = np.array([1 / self.num_idfs for i in range(self.num_idfs)])
@@ -209,6 +212,8 @@ def evaluate(self):
209212
+ self.eta_dual * (obj_values - self.alphas) / obj_max,
210213
0,
211214
)
215+
if self.dual_updater == "max_one":
216+
self.lambda_duals = self.compute_obj_values()
212217
self.lambda_duals = self.fun_dual_updater(
213218
self.dual_updater, self.lambda_duals
214219
)
@@ -266,6 +271,12 @@ def fun_dual_updater(self, configs, lambdas):
266271
max_index = np.argmax(lambdas)
267272
lambdas = np.zeros(len(lambdas))
268273
lambdas[max_index] = 1
274+
print(f"max_index: {max_index}")
275+
print(f"lambdas: {lambdas}")
276+
277+
if self.max_dual_index != max_index:
278+
self.max_dual_index = max_index
279+
self.max_dual_switch_counter += 1
269280

270281
return lambdas
271282

@@ -291,7 +302,7 @@ def fun_dual_updater(self, configs, lambdas):
291302
env_id,
292303
eta_dual,
293304
T_0,
294-
dual_updater="proj_1",
305+
dual_updater="max_one",
295306
alpha=0.0,
296307
normalize=True,
297308
obj_normalize_factor=1e10,

0 commit comments

Comments
 (0)