Skip to content

Commit 375843b

Browse files
author
Saurav Agarwal
committed
add avg updater
1 parent 17be0fc commit 375843b

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

python/scripts/evaluators/constrained_learning.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def __init__(
109109
self.lambda_duals = np.array([1.0 / self.num_idfs for i in range(self.num_idfs)])
110110
elif dual_updater == "max_one":
111111
self.lambda_duals = np.array([1.0 / self.num_idfs for i in range(self.num_idfs)])
112+
elif dual_updater == "avg":
113+
self.lambda_duals = np.array([1.0 / self.num_idfs for i in range(self.num_idfs)])
112114
print(f"Initial Lambda_dual: {self.lambda_duals}")
113115

114116
# Set the real values here
@@ -280,6 +282,10 @@ def fun_dual_updater(self, configs, lambdas):
280282

281283
return lambdas
282284

285+
if dual_updater == "avg":
286+
self.lambda_duals = np.array([1.0 / self.num_idfs for i in range(self.num_idfs)])
287+
return lambdas
288+
283289
raise ValueError("configs not recognized")
284290

285291

0 commit comments

Comments
 (0)