In this ICML 2025 paper, we showed that the behaviour of learning-rate schedules in deep learning can be explained - perhaps surprisingly - with convex optimization theory. This can be exploited to design better schedules: for example, we show that a schedule which we designed based on the theory (hence, for free) gives significant improvements for continual LLM training (~7% of steps saved for 124M and 210M Llama-type models, see Section 5.1 in the paper for details).
In the paper, we specifically designed a schedule for a two-stage continual learning scenario. Our proposed schedule is piecewise constant in each stage (combined with warmup and linear cooldown in the end).
In this blog post, we ask ourselves the question: what happens with more than two stages? What is the limiting behaviour?
Disclaimer: in contrast to the paper, where we tested these schedules on actual LLM pretraining tasks, here we will only look at their theoretical performance in terms of a worst-case suboptimality bound (see Figure 2 for the concrete bound).
Let us consider the iterates of SGD given by
\[x_{t+1} = x_t - \gamma \eta_t g_t,\]where \( (\eta_t)_{t\in \mathbb{N}}\) is a sequence of positive reals, \(\gamma > 0 \) and \(g_t\) is the (batch) gradient of our loss function.
Note that we decompose the step size into a base learning-rate \(\gamma\) and a schedule \((\eta_t)_t\). In a PyTorch training script, these two objects correspond to lr (passed to the optimizer) and the scheduler from torch.optim.lr_scheduler.
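To make the decomposition concrete, here is a minimal sketch in plain Python. It assumes a one-dimensional toy loss with a deterministic gradient; the names `sgd`, `grad` and `schedule` are illustrative and not taken from the paper.

```python
import math


def sgd(grad, x0, gamma, schedule, num_steps):
    """Run x_{t+1} = x_t - gamma * eta_t * g_t for t = 1, ..., num_steps."""
    x = x0
    for t in range(1, num_steps + 1):
        eta_t = schedule(t)              # schedule value, independent of gamma
        x = x - gamma * eta_t * grad(x)  # base learning rate times schedule
    return x


# Toy example (assumption): f(x) = 0.5 * x**2, so the gradient is x itself.
x_final = sgd(grad=lambda x: x, x0=1.0, gamma=0.5,
              schedule=lambda t: 1.0 / math.sqrt(t), num_steps=100)
```

Changing `gamma` rescales every step, while changing `schedule` only changes the shape over time; this is the separation used throughout the post.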
We give a quick summary of our experiment in the paper, which will be the starting point of this blog post. For more background, see Section 5.1 of the paper.
Consider the following scenario: we have previously trained a model for \(T_1\) steps. Now, we want to continue training up to \(T_2>T_1\) steps, but reuse the checkpoint after \(T_1\) steps instead of starting from scratch.
The problem: with increasing training horizon, the optimal base learning-rate \(\gamma\) decreases. This is true in theory and practice. So, if we had tuned our base learning-rate \(\gamma\) for the short run (which we assume here), it will be suboptimal for the long run.
Our solution is to adapt the schedule \((\eta_t)_t\) between \(T_1\) and \(T_2\), so that we can keep the base learning rate the same. Specifically, we construct a piecewise constant schedule which decreases in the second stage of training (from \(T_1\) to \(T_2\)). The decreasing factor is computed such that the optimal base LR \(\gamma\) (tuned in the first stage) remains the same, according to the bound from convex optimization theory. This computation can be done effectively for free.
This piecewise constant schedule is combined with a WSD-like approach (warmup-stable-decay), which uses a short warmup and linear cooldown to zero over the final 20% of steps. In our experiments we set \(T_1=50k\) and tried both \(T_2\in \{100k,200k\}\). Figure 1 below shows that this adapted schedule leads to an improvement in LLM pretraining, compared to keeping the schedule unmodified for the longer horizon.
Figure 1 (from Schaipp et al., 2025): When extending the training horizon T from 50k to 100k or 200k steps, the piecewise constant schedule (left, green) constructed from theory achieves loss improvement for LLM pretraining compared to baseline (grey).
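The WSD-like shape described above can be sketched as a small Python function. The warmup length of 100 steps and the convention that the multiplier reaches zero one step after the last iterate are assumptions made for illustration; the post only states a short warmup and a linear cooldown over the final 20% of steps.

```python
def wsd_schedule(t, total_steps, warmup_steps=100, cooldown_frac=0.2):
    """Warmup-stable-decay multiplier eta_t for step t in 1..total_steps."""
    cooldown_start = total_steps - int(cooldown_frac * total_steps)
    if t <= warmup_steps:
        return t / warmup_steps          # linear warmup up to 1
    if t <= cooldown_start:
        return 1.0                       # stable phase
    # Linear cooldown; hits zero at step total_steps + 1 (assumption).
    return (total_steps + 1 - t) / (total_steps + 1 - cooldown_start)


# Multiplier values for a 50k-step run.
values = [wsd_schedule(t, 50_000) for t in (50, 20_000, 45_000, 50_000)]
```

In the two-stage setting, the stable phase would additionally be multiplied by the stage-dependent constant computed from the bound.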
This brings us to the main question of this blog post: what happens if we go beyond two stages of training? What is the limiting schedule if we do the same sequentially for training horizons \(T_1 < T_2<\cdots<T_n\) and \(n\to +\infty\)?
Let \(n\in \mathbb{N}\) and \(\Delta t\in \mathbb{N}\). The starting point is indexed with \(t=1\). Assume that we train in total for \(n\) stages, defined by the checkpoints \(T_k = 1+ k \cdot \Delta t\) with \(k=1,\ldots,n\).
We focus on three different schedules (depicted in the animation further above):
1) Piecewise constant. This schedule is constructed sequentially: at each checkpoint \(T_k\), we decide on the next stage up to \(T_{k+1}\). During each stage, the schedule is constant (but at decreasing values). As described above and in the paper, we compute the decreasing factor stage-by-stage based on the theory, in order to keep the optimal base LR as close to theoretically optimal as possible.
2) Sqrt. Schedule set as \(\eta_t = 1/\sqrt{t}\). In the convex, Lipschitz-continuous setting, it achieves a last-iterate bound of \(\frac{DG}{2\sqrt{T}} \mathcal{O}(\log(T)^{1/2})\) (Orabona 2020, see further details below in the Appendix).
3) Linear-decay. For each \(k=1,\ldots,n\), we construct a schedule that decreases linearly from 1 to 0 over the interval \([1, T_k]\). Thus, for each checkpoint/horizon \(T_k\), the resulting schedule is different, which in practice would require retraining from scratch.
Notice that for linear-decay we construct a different schedule for each checkpoint/training horizon \(T_k\) (\(k=1,\ldots,n\)); in contrast to this, we construct only one single piecewise constant and sqrt schedule that covers \([1,T_n]\). One could call these “anytime schedules”, or “infinite schedules” (Zhai et al., 2022) as \(T_n\) can be arbitrarily large.
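As a rough sketch, the checkpoints and the two closed-form schedules can be written down directly. The piecewise constant schedule is left out because its stage-wise factor depends on the bound shown in Figure 2. The exact parametrization of linear-decay (zero is reached one step after the horizon) is an assumption, and all function names are illustrative.

```python
import math


def checkpoints(n, delta_t):
    """Checkpoints T_k = 1 + k * delta_t for k = 1, ..., n."""
    return [1 + k * delta_t for k in range(1, n + 1)]


def sqrt_schedule(t):
    """Anytime schedule eta_t = 1 / sqrt(t); needs no horizon."""
    return 1.0 / math.sqrt(t)


def linear_decay(t, horizon):
    """Linear decay starting from 1 at t = 1.

    Assumption: zero is reached at step horizon + 1, so the last step
    size is still positive.
    """
    return 1.0 - (t - 1) / horizon


T = checkpoints(n=3, delta_t=100)
# One sqrt schedule covers all horizons ...
sqrt_values = [sqrt_schedule(t) for t in range(1, T[-1] + 1)]
# ... while linear-decay needs a separate schedule for every T_k.
linear_values = {T_k: [linear_decay(t, T_k) for t in range(1, T_k + 1)]
                 for T_k in T}
```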
The notion of look-ahead. We should also point out that these three schedules are different in the way that they have knowledge of the training horizon: the sqrt is fully independent of the horizon; linear-decay is fully dependent on the horizon; the piecewise constant schedule lies in between, as it is constructed sequentially and only needs to know the next checkpoint in advance.
In other words, the three schedules differ in how far they can look ahead along the training horizon. At one extreme, the sqrt schedule can only plan ahead until the next step; at the other extreme, the linear-decay schedule can plan ahead for the entire training horizon (see Table 1). Knowledge of the training horizon can lead to better performance (or bounds).
| Schedule | Look-ahead (number of steps) |
|---|---|
| Sqrt | 1 |
| Piecewise constant | \(\Delta t\) |
| Linear-decay | \(T_k\) |
We will evaluate the theoretical bound from our paper (see Theorem 3.1 therein, or Figure 2 below) for each of the three schedule candidates. The central questions are:
(i) How much improvement (in terms of the bound) do we achieve by looking ahead further? This can be indicative of real-world performance (based on the conclusions from our paper).
(ii) What happens in the limit when \(n \to \infty\)? What are the asymptotics of the piecewise constant schedule?
Figure 2 (from Schaipp et al., 2025): Theoretical bound for SGD with arbitrary schedule in convex, nonsmooth optimization.
Tuning the base-learning rate. Before going to results, it remains open how we choose the base learning-rate \(\gamma\) for each of the schedules:
For the piecewise constant schedule, we pick the \(\gamma\) that minimizes the bound at the first checkpoint \(T_1\) when setting \(D=G=1\). (Here, \(G\) denotes an upper bound of the gradient norms \(\|g_t\| \).) We then keep this value of \(\gamma\) for the rest of training until \(T_n\).
For the sqrt-schedule, we do the same as for piecewise constant in order to allow a fair comparison; note that this tuning of \(\gamma\) implicitly introduces some look-ahead for this schedule, however only during the first stage of training up to \(T_1\). In contrast, the piecewise constant schedule has a look-ahead advantage for the entire rest \([T_1, T_n]\) as well.
For linear-decay remember that we construct a schedule for each checkpoint/horizon \(T_k\) with \(k\in[n]\). Accordingly, we pick \(\gamma_k\) for each \(k=1,\ldots,n\) such that it minimizes the bound for the respective \(T_k\) (again setting \(D=G=1\)).
One subtlety here is that the optimal value of \(\gamma\) requires knowledge of \(D\) and \(G\); through the above tuning procedure we allow all three schedules implicit knowledge of these values. (I made sure manually that the specific values of \(D\) and \(G\) do not affect the results below qualitatively.)
Warmup and cooldown. To keep things simple, we include neither warmup nor cooldown for the schedules here.
We first set \(\Delta t =100\) and \(n\) such that \(T_n=10\,000\). First, let us compare the piecewise constant schedule to the linear-decay schedule(s).
Performance gap to linear-decay. Figure 3 reveals that the piecewise constant schedule (blue) performs worse than linear-decay (greys) for each single horizon \(T_k\). This is not surprising: linear-decay is known to achieve the best worst-case bound (Defazio et al., 2023; Zamani & Glineur, 2023), and as it can plan for the entire horizon, it should also yield better performance.
Figure 3: Piecewise constant schedule (blue) performs worse than linear-decay (greys).
So how big is the gap in terms of the convergence rate? Here, it seems appropriate to fit a “scaling law”: let \(y_T\) be the ratio of the final bounds \(\Omega_{T}\) of the two schedules at horizon \(T\). Then, we can fit \(y_T\) as a function of \(T\) based on our observations at \(T_k,~k=1,\ldots,n\). We fit two laws:
\[ y_T = a + b \cdot T^c, \quad y_T = a + b\cdot \log(T)^c \]
Here, \(a,b,c\) are fittable parameters, and we constrain \(a,b\) to be positive; for the second law, we constrain \(c\in (-\infty, 1]\). The power law with respect to \(\log(T)\) is motivated by the fact that log-factors often appear for slightly suboptimal schedules (for example, a cooldown also brings \(\sqrt{\log(T)}\) of improvement).
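The post does not spell out the fitting procedure, so the following is only one simple way to fit the first law: a grid search over \(c\) combined with closed-form least squares in \((a,b)\), keeping only fits with positive \(a,b\). The function name and the synthetic data are made up for illustration and are not the values behind Figure 4.

```python
def fit_power_law(T_values, y_values, c_grid):
    """Fit y = a + b * T**c; returns (mean abs deviation, a, b, c)."""
    best = None
    n = len(T_values)
    for c in c_grid:
        x = [T ** c for T in T_values]
        x_mean, y_mean = sum(x) / n, sum(y_values) / n
        sxx = sum((xi - x_mean) ** 2 for xi in x)
        if sxx == 0.0:
            continue
        # Ordinary least squares for the line y = a + b * x.
        b = sum((xi - x_mean) * (yi - y_mean)
                for xi, yi in zip(x, y_values)) / sxx
        a = y_mean - b * x_mean
        if a <= 0 or b <= 0:             # positivity constraint on a and b
            continue
        mad = sum(abs(a + b * xi - yi) for xi, yi in zip(x, y_values)) / n
        if best is None or mad < best[0]:
            best = (mad, a, b, c)
    return best


# Synthetic check (not data from the post): y = 1 + 0.1 * T**0.3.
T_values = [1 + 100 * k for k in range(1, 101)]
y_values = [1.0 + 0.1 * T ** 0.3 for T in T_values]
c_grid = [round(0.05 * i, 2) for i in range(1, 21)]
mad, a, b, c = fit_power_law(T_values, y_values, c_grid)
```

The second law can be fitted with the same function by passing \(\log(T_k)\) instead of \(T_k\) and extending the grid to negative values of \(c\).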
Figure 4: Scaling laws for the performance loss compared to linear-decay. Number in brackets indicates the mean absolute deviation over the data points (black).
Figure 4 shows that both laws fit the data well, with a slight advantage in favor of the standard power law. As it seems intractable to derive the bound in analytical form for the piecewise schedule, we cannot go any further down this road at this point.
Asymptotics of the piecewise constant schedule. Next, we look at the resulting shape of the piecewise constant schedule (again for \(\Delta t=100\)). Figure 5 plots the obtained schedule (left) and its bound (right) compared to the sqrt-schedule. We can see that the schedules look quite similar, and that the piecewise schedule yields a slightly better bound.
Figure 5: Learning rate (left) and bound (right) for piecewise constant schedule (blue) and sqrt schedule (green).
Again, this suggests that bigger look-ahead (here for piecewise schedule) achieves better bounds. We take a closer look: what happens when the look-ahead advantage vanishes? For this we look at three different values \(\Delta t \in \{20, 100, 500\}\). First, we can see (Figure 6, left) that the sequentially constructed piecewise constant schedule indeed quickly arrives at \(1/\sqrt{t}\).
Second, the advantage of piecewise constant vs. sqrt after a sufficiently large burn-in time also becomes smaller when \(\Delta t\) is smaller (Figure 6, right). However, even with only 20 look-ahead steps we can still obtain a 5% smaller bound.
Figure 6: Left: Piecewise constant schedule asymptotic shape. Right: relative improvement for different look-ahead steps.
This leaves us with the following insight: being able to plan ahead for \(\Delta t\) iterations allows us to design schedules with better bounds. The sqrt-schedule, which is a textbook example in convex, nonsmooth optimization, comes out naturally as the limiting schedule when \(\Delta t \to 1\).
This connection led me to re-visit how the sqrt-schedule is usually motivated: classical textbooks (e.g. the ones by Nesterov and Polyak) introduce the sqrt-schedule for the subgradient method as the anytime-proxy for the optimal constant step size \(\propto 1/\sqrt{T}\). This structure originates from balancing two terms in the suboptimality bound, where one relates to the sum of step sizes, and the other to the sum of squared stepsizes. I could not find a geometric motivation for the rule \(1/\sqrt{t}\) (let me know, if you know one); it seems that our construction of a piecewise constant schedule refines the classical derivation (from constant \(1/\sqrt{T}\) to anytime \(1/\sqrt{t}\)).
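The balancing argument can be made explicit with the textbook bound for the averaged iterate \(\bar x_T\) of the subgradient method with constant step size \(\gamma\). This is a sketch of the standard derivation and not taken from the paper; \(D\) is assumed to bound the initial distance to a minimizer and \(G\) the subgradient norms.

```latex
\begin{align*}
f(\bar x_T) - f_\star &\leq \frac{D^2}{2\gamma T} + \frac{\gamma G^2}{2}, \\
\frac{\mathrm{d}}{\mathrm{d}\gamma}\Big[\frac{D^2}{2\gamma T} + \frac{\gamma G^2}{2}\Big] = 0
&\iff \gamma^\star = \frac{D}{G\sqrt{T}}, \\
f(\bar x_T) - f_\star &\leq \frac{DG}{\sqrt{T}} \quad \text{for } \gamma = \gamma^\star.
\end{align*}
```

The first term stems from the sum of step sizes (\(\gamma T\)), the second from the sum of squared step sizes (\(\gamma^2 T\)) divided by it. Replacing \(T\) by \(t\) in \(\gamma^\star\) gives the anytime rule \(\propto 1/\sqrt{t}\).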
Knowing the training horizon in advance (“lookahead”) allows us to design better learning-rate schedules. We can use convex, nonsmooth optimization bounds to design a piecewise constant schedule that plans ahead from one checkpoint to the next.
Compared to the ideal linear-decay schedule, the drawback of our piecewise constant strategy grows with training length. Compared to the no-lookahead sqrt-schedule, even rather short lookahead still yields a noticeable (theoretical) advantage.
Maybe the most interesting takeaway: the sqrt-schedule appears as the limit of the piecewise constant strategy, when the lookahead vanishes.
Acknowledgements. I’d like to thank Francis Bach for his feedback on this post.
Defazio, A., Cutkosky, A., Mehta, H., and Mishchenko, K. Optimal linear decay learning rate schedules and further refinements. arXiv:2310.07831, 2023. [link]
Orabona, F. Last iterate of SGD converges (even in unbounded domains), 2020. [link]
Schaipp, F., Hägele, A., Taylor, A., Simsekli, U., Bach, F. The Surprising Agreement Between Convex Optimization Theory and Learning-Rate Scheduling for Large Model Training. ICML, 2025. [link]
Shamir, O. and Zhang, T. Stochastic gradient descent for non-smooth optimization: Convergence results and optimal averaging schemes. ICML, 2013. [link]
Zamani, M. and Glineur, F. Exact convergence rate of the last iterate in subgradient methods. arXiv:2307.11134, 2023. [link]
Zhai, X., Kolesnikov, A., Houlsby, N., and Beyer, L. Scaling vision transformers. CVPR, 2022. [link]
Orabona (2020) analyses the schedule \(\eta_t = 1/\sqrt{t}\) in the convex, Lipschitz setting. With a slight modification of their analysis (see Cor. 3 therein), if \(G\) denotes the Lipschitz constant of the objective \(f\), then we get
\[\mathbb{E}[f(x_T) - f_\star] \leq \frac{1}{2\sqrt{T}} \big[\frac{D^2}{\gamma} + \gamma G^2(1+\log(T)) + 3\gamma G^2(1+\log(T-1)) \big].\]Minimizing this bound with respect to \(\gamma\), we obtain \(\gamma^\star = \frac{D}{G\sqrt{1+\log(T)+ 3(1+\log(T-1))}}\). Plugging this back into the bound, we obtain
\[\mathbb{E}[f(x_T) - f_\star] \leq \frac{DG}{2\sqrt{T}}\mathcal{O}(\sqrt{\log(T)}).\]We also point towards Shamir and Zhang (2013), who derive the bound \(\frac{DG(2+\log{T})}{\sqrt{T}} \) if \(\gamma=D/G \), however assuming a bounded domain.
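The minimization over \(\gamma\) can be checked numerically. The sketch below implements the bound and the closed-form \(\gamma^\star\) exactly as stated above; the function names and the choice \(T=10\,000\), \(D=G=1\) are illustrative.

```python
import math


def log_terms(T):
    """Sum of the two logarithmic factors multiplying gamma * G**2."""
    return (1 + math.log(T)) + 3 * (1 + math.log(T - 1))


def sqrt_bound(gamma, T, D=1.0, G=1.0):
    """Last-iterate bound for eta_t = 1/sqrt(t) as a function of gamma."""
    return (D**2 / gamma + gamma * G**2 * log_terms(T)) / (2 * math.sqrt(T))


def optimal_gamma(T, D=1.0, G=1.0):
    """Closed-form minimizer of sqrt_bound with respect to gamma."""
    return D / (G * math.sqrt(log_terms(T)))


T = 10_000
gamma_star = optimal_gamma(T)
# Sanity check: on a grid around gamma_star, no value gives a smaller bound.
grid = [gamma_star * (0.5 + 0.01 * i) for i in range(101)]
best_on_grid = min(grid, key=lambda g: sqrt_bound(g, T))
```

Plugging \(\gamma^\star\) back in gives \(DG\sqrt{\texttt{log\_terms}(T)}/\sqrt{T}\), which matches the \(\mathcal{O}(\sqrt{\log(T)})\) factor in the final bound.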
To simplify this, I created a single bib file with all published papers for NeurIPS, ICML and ICLR. The files are available on Github.
Some remarks and caveats:
Having a database of papers published at the major ML conferences, we can also do some simple data analysis.
For this, I used the library bibtexparser to create a csv file with title, authors and year of each paper.
(Note that the ICML bib even contains the abstracts, which would allow for a more detailed analysis.)
First, let’s plot the number of papers per venue each year. Unsurprisingly, the ML paper factory is growing exponentially fast.
Fig. 1: Number of accepted papers per year and conference.
Not only do we have more papers, but also the average number of authors per paper increased from (approximately) three to five within 2010-2024.
Fig. 2: Number of authors per paper (computed across all conferences).
To wrap up, a lazy approach to finding historical trends in ML topics is to count papers that have specific keywords in their title.
Maybe because I started to do research in the machine learning field rather late, I sometimes find it quite hard to understand the historical context of certain topics; below is a selection of well-known keywords over time, that might serve as a proxy.
Fig. 3: Historical timelines: percentage of papers (per conference and year) with the paper title containing certain keywords.
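The keyword count behind such a timeline can be sketched in a few lines. The rows below are hypothetical placeholders; in practice they would be read from the csv file with title, authors and year, and the exact matching rule used for the figure is not stated in the post.

```python
from collections import defaultdict


def keyword_share(papers, keyword):
    """Percentage of titles per (venue, year) that contain `keyword`."""
    total = defaultdict(int)
    hits = defaultdict(int)
    for venue, year, title in papers:
        total[(venue, year)] += 1
        if keyword.lower() in title.lower():   # case-insensitive substring match
            hits[(venue, year)] += 1
    return {key: 100.0 * hits[key] / total[key] for key in total}


# Hypothetical rows (venue, year, title), for illustration only.
papers = [
    ("ICML", 2015, "A Kernel Method for Example Data"),
    ("ICML", 2015, "Deep Networks for Example Data"),
    ("ICLR", 2024, "Transformer Scaling on Example Data"),
]
share = keyword_share(papers, "kernel")
```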
Consider the training problem
\[\nonumber \min_{w\in \mathbb{R}^d} \ell(w),\]where \(\ell: \mathbb{R}^d\to \mathbb{R}\) is a loss function, and \(w\) are learnable parameters. Assume that the loss is given as \(\ell(w) = \mathbb{E}_{x\sim \mathcal{P}} [\ell(w,x)]\), where \(x\) is a batch of data, sampled from the training data distribution \(\mathcal{P}\).
Suppose we want to solve this problem with stochastic gradient methods. Let us introduce some notation: we denote
The learning rate in iteration \(t\) will be given by \(\alpha_t := \alpha \eta_t\). We will often refer to \(\alpha\) as the learning-rate parameter, which is slightly imprecise, but for most of the contents the schedule \(\eta_t\) will be constant anyway.
The arguably most widely used method for training large-scale machine learning models is AdamW. It has been proposed by Loshchilov and Hutter and its main feature is that it handles weight decay separately from the loss \(\ell\) (as opposed to the original Adam [3]). For readers not familiar with AdamW, we refer to [1] and briefly explain the AdamW update formula below.
A short outline of this post:
Disclaimer: I have previously written a blog post for the ICLR 2023 blog post track [2], that discusses the weight decay mechanism of AdamW, and how it can be seen as a proximal version of Adam (the blog post is based on the paper by Zhuang et al [4]). This post will re-use some of the figures and contents. In fact, I stumbled upon the central question of this blog post during writing back then.
The quantities involved in AdamW are mostly the same as in the original version of Adam: let \(g_t=\nabla \ell(w_t,x_t)\) be the stochastic gradient in iteration \(t\), then for \(\beta_1,\beta_2\in[0,1)\) we compute
\[m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t, \nonumber \\ v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t \odot g_t. \nonumber \\\]The bias-corrected quantities are then given by
\[\hat m_t = m_{t}/(1-\beta_1^{t}), \nonumber \\ \hat v_t = v_{t}/(1-\beta_2^{t}) . \nonumber \\\]Let us denote the Adam preconditioner by \(D_t = \mathrm{diag}(\epsilon + \sqrt{\hat v_t})\). The way that AdamW was proposed originally in the paper by Loshchilov and Hutter [1] is
\[w_{t+1} = (1-\lambda\eta_t)w_t - \alpha_t D_t^{-1}\hat m_t. \tag{AdamW-LH}\]In Pytorch, the method is implemented slightly differently:
\[w_{t+1} = (1-\lambda\alpha_t)w_t - \alpha_t D_t^{-1}\hat m_t. \tag{AdamW-PT}\]Note that the only difference consists in the coefficient \(1-\lambda\alpha_t= 1-\lambda\alpha\eta_t\) instead of \(1-\lambda\eta_t\). While this seems like trivia at first sight (one could easily reparametrize \(\lambda\) by \(\lambda \alpha\)), we will show that it has an important practical implication for tuning.
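To see the difference in isolation, here is a scalar sketch of one update step that follows the formulas above. It is written for a single weight in plain Python; the function name and the `variant` switch are illustrative and not part of any library.

```python
import math


def adamw_step(w, g, m, v, t, alpha, eta_t, lam,
               beta1=0.9, beta2=0.999, eps=1e-8, variant="PT"):
    """One scalar AdamW step; `variant` selects the decay coefficient."""
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)           # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)           # bias-corrected second moment
    alpha_t = alpha * eta_t
    if variant == "LH":
        decay = 1 - lam * eta_t            # (AdamW-LH)
    else:
        decay = 1 - lam * alpha_t          # (AdamW-PT)
    w = decay * w - alpha_t * m_hat / (eps + math.sqrt(v_hat))
    return w, m, v


# With a zero gradient only the decay acts, which exposes the difference.
w_lh, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, alpha=0.1, eta_t=1.0,
                        lam=0.01, variant="LH")
w_pt, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, alpha=0.1, eta_t=1.0,
                        lam=0.01, variant="PT")
```

Here the LH variant shrinks the weight by the factor \(1-\lambda\eta_t = 0.99\), while the PT variant uses \(1-\lambda\alpha\eta_t = 0.999\); the two agree only for \(\alpha=1\).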
Remark: The implementation of AdamW is the same in Optax as in Pytorch. Hence, what follows applies similarly to tuning AdamW in Optax.
So what do we mean when we say that learning rate \(\alpha\) and weight decay \(\lambda\) are decoupled? We will work with the following (approximate) definition: we say that \(\alpha\) and \(\lambda\) are decoupled, if the optimal choice for \(\lambda\) does not depend on the choice of \(\alpha\). Here, we mean optimal with respect to some metric of interest - for the rest of the blog post, this metric will be the loss \(\ell\) computed over a validation dataset.
The graphic below illustrates this phenomenon: imagine, we draw a heatmap of the validation loss over a \((\alpha,\lambda)\) grid. Bright values indicate a better model performance. Then, in a coupled scenario (left) the bright valley could have a diagonal shape, while for the decoupled scenario (right) the valley is more rectangular.
Fig. 1: Model performance (bright = good) as a function of learning rate and weight decay parameters. Illustration taken from [2].
Note that in practice this can make a huge difference: in general, we need to tune over the 2D-space of \((\alpha,\lambda)\), assuming that all other hyperparameters are already set. The naive way to do this is a grid search. However, if we know that \(\alpha\) and \(\lambda\) are decoupled, then it would be sufficient to do two separate line searches for \(\alpha\) and \(\lambda\), followed by combining the best values from each line search. For example, if for each parameter we have \(N\) candidate values, this reduces the tuning effort from \(N^2\) (naive grid search) to \(2N\) (two line searches).
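The saving can be illustrated with a toy example. The loss below is a hypothetical stand-in that is separable in \(\alpha\) and \(\lambda\) (that is, perfectly decoupled); for such a loss, two line searches find the same point as the full grid.

```python
import itertools
import math


def toy_loss(alpha, lam):
    """Hypothetical decoupled loss, separable in log(alpha) and log(lam)."""
    return (math.log10(alpha) + 2) ** 2 + (math.log10(lam) + 1) ** 2


alphas = [10.0 ** e for e in range(-4, 1)]   # N = 5 candidates
lams = [10.0 ** e for e in range(-4, 1)]     # N = 5 candidates

# Naive grid search: N * N evaluations.
grid = list(itertools.product(alphas, lams))
best_grid = min(grid, key=lambda p: toy_loss(*p))

# Two line searches: N + N evaluations, the other parameter held fixed.
best_alpha = min(alphas, key=lambda a: toy_loss(a, lams[0]))
best_lam = min(lams, key=lambda l: toy_loss(alphas[0], l))
best_lines = (best_alpha, best_lam)
```

For a coupled loss (a diagonal valley), the two line searches would in general miss the best grid point, which is why the distinction matters.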
This motivates why the decoupling property is important for practical use.
One of the main contributions of the AdamW paper [1] was that it showed how to treat weight decay separately from the loss. This is declared in the paper as follows:
The main contribution of this paper is to improve regularization in Adam by decoupling the weight decay from the gradient-based update.
The authors also claim that their method decouples the weight decay parameter \(\lambda\) and the learning rate \(\alpha\) (which goes beyond decoupling weight decay and loss).
We provide empirical evidence that our proposed modification decouples the optimal choice of weight decay factor from the setting of the learning rate for […] Adam.
While this claim is supported by experiments in the paper, we will show next an example where there is no decoupling when using AdamW from Pytorch. The reason for this is, as we will show, the implementation subtlety we described in the previous section.
The experiment is as follows: we solve a ridge regression problem for some synthetic data \(A \in \mathbb{R}^{n \times d},~b\in \mathbb{R}^{n}\) with \(n=200,~d=1000\). Hence \(\ell\) is the squared loss, given by \(\ell(w) = \Vert Aw-b \Vert^2\).
We run both AdamW-LH and AdamW-PT, for a grid of learning-rate values \(\alpha\) and weight-decay values \(\lambda\). For now, we set the scheduler to be constant, that is \(\eta_t=1\). We run everything for 50 epochs, with batch size 20, and average all results over five seeds.
Below is the final validation-set loss, plotted as heatmap over \(\alpha\) and \(\lambda\). Again, brighter values indicate lower loss values.
Fig. 2: Final validation loss (bright = low) as a function of learning rate \(\alpha\) and weight decay parameter \(\lambda\).
This matches the previous illustrative picture in Figure 1 pretty well (it’s not a perfect rectangle for AdamW-LH, but I guess it proves the point)!
Conclusion 1: Using the PyTorch implementation AdamW-PT, the parameter choices for \(\alpha\) and \(\lambda\) are not decoupled in general. However, the originally proposed method AdamW-LH indeed shows decoupling for the above example.
Based on this insight, the obvious question is: what is the best (joint) tuning strategy when using the Pytorch version AdamW-PT? We answer this next.
Assume that we have already found a good candidate value \(\bar \lambda\) for the weight-decay parameter; for example, we obtained \(\bar \lambda\) by tuning for a fixed (initial) learning rate \(\bar \alpha\). Now we also want to tune the (initial) learning-rate value \(\alpha\).
Assume that our tuning budget only allows for one line search/sweep. We will present two options for tuning:
(S1) Keep \(\lambda = \bar \lambda\) fixed, and simply sweep over a range of values for \(\alpha\).
If \(\alpha\) and \(\lambda\) are decoupled, then (S1) should work fine. However, as we saw before, the Pytorch version of AdamW, called AdamW-PT, seems not to be decoupled. Instead, the decay coefficient in each iteration is given by \(1 - \alpha \lambda \eta_t\). Thus, it seems intuitive to keep the quantity \(\alpha \lambda\) fixed, which is implemented by the following strategy:
(S2) When sweeping over \(\alpha\), adapt \(\lambda\) accordingly such that the product \(\alpha \lambda\) stays fixed. For example, if \(\alpha = 2\bar \alpha\), then set \(\lambda=\frac12 \bar \lambda\).
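Both strategies amount to generating a list of \((\alpha,\lambda)\) pairs to run. A minimal sketch, where the helper names, the reference values and the factor-of-two grid are all hypothetical:

```python
def sweep_s1(alphas, lam_bar):
    """(S1): keep the weight decay fixed while sweeping the learning rate."""
    return [(alpha, lam_bar) for alpha in alphas]


def sweep_s2(alphas, alpha_bar, lam_bar):
    """(S2): rescale the weight decay so that alpha * lam stays constant."""
    return [(alpha, lam_bar * alpha_bar / alpha) for alpha in alphas]


alpha_bar, lam_bar = 1e-3, 1e-2           # hypothetical reference values
alphas = [alpha_bar * 2.0 ** k for k in range(-2, 3)]
pairs_s1 = sweep_s1(alphas, lam_bar)
pairs_s2 = sweep_s2(alphas, alpha_bar, lam_bar)
```

In a heatmap over \((\alpha,\lambda)\) with logarithmic axes, `pairs_s1` forms a straight line parallel to the \(\alpha\)-axis and `pairs_s2` a diagonal line.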
Strategy (S1) is slightly easier to code; my conjecture is that (S1) is also employed more often than (S2) in practice.
However, and this is the main argument, from the way AdamW-PT is implemented, (S2) seems to be more reasonable. We verify this next.
We plot again the heatmaps as before, but now highlighting the points that we would actually observe by the tuning strategy (S1) or (S2). We show below
Here, we set \((\bar \lambda, \bar \alpha) =\) (1e-2,3.2e-1) for AdamW-PT, and \((\bar \lambda, \bar \alpha) =\) (1e-2,3.2e-3) for AdamW-LH.
In the below plots, the circle-shaped markers highlight the sweep that corresponds to the tuning strategy (S1) or (S2). The bottom plot shows the validation loss as a curve over the highlighted markers.
Fig. 3: Three different tuning strategies: (S1) for AdamW-PT (left), (S2) for AdamW-PT (middle) and (S1) for AdamW-LH (right). Top: Heatmap of final validation loss where the highlighted points show the results of the respective sweep. Bottom: A curve of the final validation loss at each of the highlighted points (learning rate increases from left to right on x-axis).
Note that the bottom plot displays the final validation-loss values that a practitioner would observe for the sweep of each respective tuning strategy. What is important is the width of the valley of this curve, as it reflects how dense the sweep would need to be to obtain a low final loss. The main insight here is: for the middle and right ones, it would be much easier to obtain a low final loss than for the left one. This is important when the sweep has only a few trials due to high computational costs for a single run, or other practical constraints.
Conclusion 2: when using the Pytorch version of AdamW (i.e. AdamW-PT), tuning strategy (S2) should be used. That is, when doubling the learning rate, the weight decay should be halved.
In fact, Figure 3 also shows that tuning strategy (S2) for AdamW-PT is essentially the same as strategy (S1) for AdamW-LH.
Summary and final remarks:
Implementation details can have an effect on hyperparameter tuning strategies. We showed this phenomenon for AdamW, where the tuning strategy should be a diagonal line search if the Pytorch implementation is used.
In the appendix, we show that the results are similar when using a square-root decaying scheduler for \(\eta_t\) instead.
This blog post only covers a ridge regression problem, and one might argue that the results could be different for other tasks. However, the exercise certainly shows there is no decoupling for AdamW-PT for one of the simplest possible problems, ridge regression. I also observed good performance of the (S2) strategy for AdamW-PT when training a vision transformer on ImageNet (with the timm library).
If you want to cite this post, please use
@misc{adamw-decoupling-blog,
title = {How to jointly tune learning rate and weight decay for {A}dam{W}},
author = {Schaipp, Fabian},
howpublished = {\url{https://fabian-sp.github.io/posts/2024/02/decoupling/}},
year = {2024}
}

[1] Loshchilov, I. and Hutter, F., Decoupled Weight Decay Regularization, ICLR 2019.
[2] Schaipp F., Decay No More, ICLR Blog Post Track 2023.
[3] Kingma, D. and Ba, J., Adam: A Method for Stochastic Optimization, ICLR 2015.
[4] Zhuang Z., Liu M., Cutkosky A., Orabona F., Understanding AdamW through Proximal Methods and Scale-Freeness, TMLR 2022.
To validate that the effects are similar for non-constant learning rates, we run the same experiment but now with a square-root decaying learning rate schedule. That is \(\eta_t = 1/\sqrt{\text{epoch of iteration } t}\). We sweep again over the initial learning rate \(\alpha\) and weight decay parameter \(\lambda\). The results are plotted below:
Fig. 4: Same as Figure 3, but with a square-root decaying learning-rate schedule.
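The epoch-wise square-root schedule can be sketched as follows. Counting both iterations and epochs from 1 is an assumption; the number of steps per epoch follows from \(n=200\) samples and batch size 20.

```python
import math


def sqrt_epoch_schedule(t, steps_per_epoch):
    """eta_t = 1 / sqrt(epoch of iteration t), with epochs counted from 1."""
    epoch = (t - 1) // steps_per_epoch + 1
    return 1.0 / math.sqrt(epoch)


# With n = 200 samples and batch size 20, one epoch has 10 iterations.
steps_per_epoch = 200 // 20
values = [sqrt_epoch_schedule(t, steps_per_epoch) for t in range(1, 41)]
```

The schedule is constant within each epoch and drops at epoch boundaries, so `values` holds ten entries each for epochs 1 to 4.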
For completeness, this is the code we used for AdamW-LH. It is adapted from here.
import math

import torch
from torch.optim import Optimizer


class AdamLH(Optimizer):
    """ AdamW with fully decoupled weight decay.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Remember the initial learning rate alpha, so that lr / alpha = eta_t.
        self._init_lr = lr
        super(AdamLH, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamLH, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Parameters
        ----------
        closure : LossClosure, optional
            A callable that evaluates the model (possibly with backprop) and returns the loss,
            by default None.

        Returns
        -------
        (Stochastic) Loss function value.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lmbda = group['weight_decay']
            eps = group['eps']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                # Decay with coefficient 1 - lambda * eta_t, where eta_t = lr / initial lr
                p.mul_(1 - lmbda*lr/self._init_lr)

                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1-beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                step_size = lr / bias_correction1
                update = -step_size * exp_avg / denom
                p.add_(update)

        return loss
The post is part of his (quite awesome) blog post series Optimization Nuggets and can be found here.


It is very hard to reach all goals at the same time. Typically, compiled languages such as C++ offer speed but lack the other two (at least for my taste). Some concrete examples of this tradeoff:
In scikit-learn, many classical solvers are written in C++ or Cython, for example the logistic regression solvers.
ProximalOperators.jl and ProximalAlgorithms.jl packages.
However, the goal of this article is to present one approach to reaching all three goals in Python.
Consider problems of the form
\[\min_x f(x) + r(x)\]where we assume \(f\) to be continuously differentiable and \(r\) is a (closed, convex) regularizer. For a step size \(\alpha>0\), the proximal gradient algorithm for such problems is given by the iterates
\[x^{k+1} = \mathrm{prox}_{\alpha r}(x^k- \alpha \nabla f(x^k)),\]where \(\mathrm{prox}\) is the proximal operator of a closed, convex function.
If we implement an algorithm for problems of the above type, it would be favourable to have code that works for any functions f and r fulfilling the respective assumptions. Moreover, as we have a composite objective, we would prefer to have a solver which we can call for any combination of f and r we would like - without adapting the code of the solver.
An obvious approach to achieve this is to handle both f and r as instances of classes with the following methods:
f needs the method grad, which computes a gradient at a specific point,
r needs the method prox, which computes the proximal operator of \(\alpha\cdot r\) at a specific point.
Let us show the implementation for f being a quadratic function and r being the 1-norm.
class Quadratic:

    def __init__(self, A, b):
        self.A = A
        self.b = b

    def grad(self, x):
        # Gradient of 0.5 * x^T A x + b^T x (assuming A is symmetric)
        g = self.A @ x + self.b
        return g
The formula below for the proximal operator is well known, but understanding it is not essential here.
import numpy as np

class L1Norm:

    def __init__(self, l):
        self.l = l

    def prox(self, x, alpha):
        # Soft-thresholding with threshold alpha * l
        return np.sign(x) * np.maximum(np.abs(x) - alpha*self.l, 0.)
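For reference, the soft-thresholding formula implemented by prox can be written out as follows, with \(l\) denoting the regularization weight stored in the class. This is the standard closed form, stated for each coordinate \(i\).

```latex
\[
\big[\mathrm{prox}_{\alpha\, l \|\cdot\|_1}(x)\big]_i
= \mathrm{sign}(x_i)\,\max\big(|x_i| - \alpha l,\; 0\big).
\]
```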
Now, proximal gradient descent can be implemented generally with the following simple function:
def prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        y = x - alpha*f.grad(x)   # gradient step on f
        x = r.prox(y, alpha)      # proximal step on r
    return x
This is general and very simple to read. To apply the algorithm to a different objective, one only needs to write the respective f and/or r. With this, a library of functions can be built and used modularly.
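As a small illustration of this modularity, here is a sketch where the 1-norm is swapped for a different regularizer without touching the solver. The class NonNegative is a hypothetical example of mine (the indicator function of the non-negative orthant, whose proximal operator is the projection); Quadratic and prox_gd are repeated only to keep the snippet self-contained.

```python
import numpy as np

class Quadratic:
    """f(x) = 0.5 * x^T A x + b^T x (A symmetric), so grad f(x) = A x + b."""
    def __init__(self, A, b):
        self.A = A
        self.b = b

    def grad(self, x):
        return self.A @ x + self.b

class NonNegative:
    """Hypothetical regularizer: indicator function of the set {x : x >= 0}."""
    def prox(self, x, alpha):
        # The prox of an indicator function is the projection onto the set;
        # it does not depend on alpha.
        return np.maximum(x, 0.0)

def prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        y = x - alpha * f.grad(x)
        x = r.prox(y, alpha)
    return x

# With A = I and b = (-1, 2), the unconstrained minimizer of f is -b = (1, -2).
f = Quadratic(np.eye(2), np.array([-1.0, 2.0]))
r = NonNegative()
x_hat = prox_gd(f, r, np.zeros(2), alpha=0.5, max_iter=100)
print(x_hat)  # close to (1, 0), the projection of (1, -2) onto x >= 0
```

The solver only relies on f.grad and r.prox, so any object offering these two methods can be plugged in.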
However, all of the above is pure Python code and will therefore be pretty slow. Our goal is to use Numba in order to accelerate the implementation while keeping generality and readability.
Numba is a package for just-in-time (JIT) compilation. It is designed to speed up pure Python code using the decorator @njit.
Numba supports many of the functions built into numpy. A detailed list is here.
The speedup comes typically from for-loops - which naturally appear in optimization algorithms. Thus, we want to write our solver as a JIT-compiled numba function. The problem: everything that happens inside a JIT-compiled function must itself be compiled. Thus, if we want to make use of class methods inside the solver, the class must be such that every method is JIT-compiled. Luckily, numba offers this possibility using @jitclass.
When using @jitclass, it is important to specify the type of every attribute of the class. See the example below or the docs for all details. Our quadratic function class can be implemented as follows:
import numpy as np
from numba.experimental import jitclass
from numba import float64, njit

# every attribute of the class needs an explicit type
spec = [
    ('b', float64[:]),      # 1-D float array
    ('A', float64[:,:])     # 2-D float array
]

@jitclass(spec)
class Quadratic:
    def __init__(self, A, b):
        self.A = A
        self.b = b

    def grad(self, x):
        g = self.A @ x + self.b
        return g
Same with the 1-norm:
spec_l1 = [('l', float64)]

@jitclass(spec_l1)
class L1Norm:
    def __init__(self, l):
        self.l = l

    def prox(self, x, alpha):
        return np.sign(x) * np.maximum(np.abs(x) - alpha*self.l, 0.)
Remark: @jitclass alone does not necessarily speed up the code. The main speedup will come from for-loops, typically appearing in the solver.
After implementing Quadratic and L1Norm as specific examples for f and r, we can now implement a numba version of proximal gradient descent. We can pretty much copy the code and simply add the @njit decorator.
@njit()
def fast_prox_gd(f, r, x0, alpha=0.1, max_iter=50):
    x = x0.copy()
    for i in range(max_iter):
        y = x - alpha*f.grad(x)
        x = r.prox(y, alpha)
    return x
Some remarks on the @njit decorator (mainly a reminder to myself):
float instead of int), numba will recompile the function (which takes longer).
I implemented the pure Python and the Numba version of proximal gradient descent in this notebook.
For a simple 50-dimensional example with f being a quadratic function and r the 1-norm, we get the following result:
# Python version
%timeit prox_gd(f, r, x0, alpha=0.001, max_iter=20000)
10 loops, best of 5: 164 ms per loop
# Numba version
%timeit fast_prox_gd(f, r, x0, alpha=0.001, max_iter=20000)
10 loops, best of 5: 54.2 ms per loop
Even for this simple example, we already get a speedup factor of about 3. Of course, how much speedup is possible depends on how much of the computation is due to the loop itself and how much is due to heavy numerical operations (e.g. matrix-vector multiplication in high dimensions).
If the gradient or prox computation involves for-loops (e.g. Condat’s algorithm for total variation regularization), using numba will result in significant speedups in my experience.
The outlined approach can also be applied to stochastic algorithms where the number of iterations and thus the speedup is typically large. You can find some standard algorithms such as SGD, SAGA or SVRG in this repository.
Thanks for reading!

numba and cython: http://gouthamanbalaraman.com/blog/optimizing-python-numba-vs-cython.html

This article serves as a short checklist - mainly as a reminder to myself - for converting your research code into an open-source, distributable and well documented package. Some, but not all, steps might only apply to Python projects. Most of the individual steps are very well documented, so you can see this as a collection of websites/tutorials that helped me for my own projects.
Many of the following steps are much simplified if your code is already a Github repository.
If you aim to make your package available to others, it should have a license. While there are many standard open-source licenses around, be aware that your choice can make a difference in how others can use or redistribute your package. You can add a license directly via the Github page of your repository (link to docu).
A great introductory article on the legal background of open-source licenses is here.
When your project grows, at some point you might need to use some of your functions across multiple other scripts. In order to import from your module, you only need one additional file, a setup.py file, and install the module locally as a package in your (virtual) environment. Fortunately, a setup file is basically all you need in order to make your package distributable - for example with pip or conda.
A useful guide on how to create a setup file and make your package distributable with pip is here.
Other great resources with many details are this packaging guide and this introduction from the Python packaging authority.
If you ever had to become familiar with a code repository you did not write yourself, you will understand the importance of a proper documentation. Apart from the standard advice of using docstrings and comments where needed, you can also create and publish a documentation for your package as a whole. Typically, this could be included in the README of your repository. However, if your package becomes more complex and needs more explaining, you might consider creating a documentation on Read the docs. I will list the steps on how to achieve this (obviously other tools could be used, but I will describe the ones I used myself).
Create a documentation using Sphinx. This mainly involves writing .md or .rst files where you explain everything which is needed. Here is a guide on how to get started.
One of the great features of Sphinx is that it can parse the docstrings of your functions into nice-looking and readable websites (as you might know from the docs of numpy). Moreover, you can include math formulas, cross-references or links in the docstrings. This way, if you change the source code, you only need to update the docstrings and the documentation will be up-to-date automatically (see section Autodocs below).
Build the documentation locally (see below how to do that).
If you created the documentation files within the subfolder docs, the commands for this are as simple as
cd docs/
make html
.readthedocs.yaml on the top level of your repository. A minimal file example is here.
main or master). Readthedocs will install all dependencies from requirements.txt (whereas Pypi uses the ones from setup.py).
The Readthedocs documentation provides you with all the details about importing and building your documentation.
If your package gets more involved, having some text with pictures and automatically generated class and function documentation might not satisfy you. Fortunately, there are numerous ways to bring your documentation to the next level:
Explaining by example is often much more effective. Thus, show what your package can do by simply setting up small example scripts. Sphinx Gallery offers an easy way to include showcase examples in your documentation and beautifully embed plots, visualization and code snippets.
The main idea is very simple: in a subfolder of your repository (e.g. called examples) every Python script will be parsed. Every script with a filename starting with plot_ will be executed and all plots are shown.
As the gallery is an extension of Sphinx, it can be easily integrated into your configuration from step 3. All essential infos for getting started can be found in the documentation. Advanced configuration options, e.g. ignoring some of the files in the examples directory, are here.
As an alternative to an example gallery, you can also use Jupyter notebooks to create tutorials on how to use your package, or showcase some of its features.
Like on Github, where Jupyter notebooks are rendered directly in your browser when you open them, the nbsphinx package allows you to do the same on your readthedocs page. Simply create your tutorial notebooks, save them to your docs subfolder, and add the notebooks to your documentation index as described here.
From there, the possibilities are almost endless. You can even do things like linking to an interactive version of the notebook on Google colab or similar.
For projects with many submodules, it can be tedious to manually write a file that has links to all the documentation pages of different classes, functions etc. Fortunately, the autodocs extension for Sphinx can automatically generate them for you.
It can automatically generate documentation pages from the docstrings in an entire source file, class or function in your package.
Note: As you might have guessed, getting all docstrings rendered properly can be a little tricky at times, so try this out locally first before including it in your readthedocs page.
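To give an idea of what such a docstring can look like, here is a sketch for a hypothetical function my_sqrt, written in the numpy docstring style. It assumes that your Sphinx configuration enables sphinx.ext.autodoc together with an extension that understands this style (e.g. sphinx.ext.napoleon); the :math: role is rendered as a formula by Sphinx.

```python
import math

def my_sqrt(b):
    r"""Compute the non-negative square root of a number.

    Returns the unique :math:`a \geq 0` with :math:`a^2 = b`.

    Parameters
    ----------
    b : float
        Non-negative input value.

    Returns
    -------
    float
        The square root of ``b``.

    Raises
    ------
    ValueError
        If ``b`` is negative.

    See Also
    --------
    math.sqrt : The standard-library routine used internally.
    """
    if b < 0:
        raise ValueError("b must be non-negative")
    return math.sqrt(b)

# The docstring is a plain attribute of the function; this is what Sphinx reads.
print(my_sqrt.__doc__.splitlines()[0])
```

The raw string prefix r is needed because of the backslash in the formula.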
Even if you do not aim to open-source your code, you should include (unit) tests. This means writing test functions which ensure that your code is a) running without errors and b) giving the correct result.
Remark: even though it is the last point in this checklist, writing tests is best done while developing the package.
For example, if you wrote a function my_sqrt which should always return a non-negative result, you could add a test like this:
def test_my_sqrt():
    b = 4.0  # any valid (non-negative) input
    a = my_sqrt(b)
    assert a >= 0
Often, you want to assert that two numbers or arrays are equal up to some numerical inaccuracy. For this, numpy provides useful functionalities.
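A minimal sketch of such a test, using np.testing.assert_allclose. The implementation of my_sqrt below (a few Newton iterations) is only a stand-in of mine so that the snippet runs on its own.

```python
import numpy as np

def my_sqrt(b, n_iter=50):
    """Hypothetical implementation: Newton's method for a**2 = b (elementwise, b > 0)."""
    b = np.asarray(b, dtype=float)
    a = np.maximum(b, 1.0)          # starting point above the root
    for _ in range(n_iter):
        a = 0.5 * (a + b / a)
    return a

def test_my_sqrt_matches_numpy():
    b = np.array([0.25, 1.0, 2.0, 9.0, 1e6])
    # Passes if |actual - desired| <= atol + rtol * |desired| holds elementwise.
    np.testing.assert_allclose(my_sqrt(b), np.sqrt(b), rtol=1e-10, atol=0.0)

test_my_sqrt_matches_numpy()  # when using pytest, this call is done for you
```

Comparing with a relative tolerance instead of == avoids tests that fail only because of rounding errors.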
Using pytest, all files named test_*.py or *_test.py will be scanned for functions whose names start with the prefix test (as in the example above).
If you want to compute a coverage report (i.e. which lines of your code are executed by at least one test), install the pytest-cov plugin and run
pytest --cov=my_package my_package/
Make it easy for others to cite your software. A citation snippet can be added directly via Github.
Automate testing and building using Github actions. For example, you can create automatic coverage reports with Codecov.
Thank you for reading! Many thanks to Johannes Ostner for giving feedback and adding some resources.
