This post investigates the following questions: Where should we steer a model? And how expressive can steering actually be?
By comparing steering and finetuning through a first-order lens, we find that the best place to steer is often after the skip connection, where attention and MLP outputs meet. Steering here is more expressive than steering individual submodules, and it ends up looking a lot closer to what finetuning does. Using this insight, we build lightweight post-block adapters that train only a fraction of the model’s parameters and achieve performance remarkably close to SFT.
Note: This post is aimed at readers comfortable with transformers and some linear algebra. We’ll keep the math light but precise.
Activation steering is an alternative to parameter-efficient finetuning (PEFT). Instead of updating weights, it directly edits a model’s hidden activations at inference time, cutting the number of trainable parameters by an order of magnitude. ReFT [1], for example, reaches LoRA-level performance while using 15×–65× fewer parameters. If we can reliably match finetuning with so few parameters, that changes what’s feasible: we can adapt larger models, run more expensive training objectives, or simply get more done under the same compute budget.
Existing steering methods mainly differ in where they apply these interventions: ReFT modifies MLP outputs (post-MLP), LoFIT [2] steers at attention heads (pre-MLP), and JoLA [3] jointly learns both the steering vectors and the intervention locations.
Despite this empirical success, we still lack a clear understanding of why steering works and how to reason about where the best steering location is. Two key questions remain open:
- Where should we steer?
- How expressive can steering really be?
This work investigates both questions.
First: Where should we steer? We begin with a simple, linearized analysis of a GLU’s output changes under weight updates compared to activation steering. This shows us that steering at some locations can easily match the behavior of certain weight updates, but not others. However, we notice that linear steering at these locations cannot completely capture the full behavior of weight updates.
Then: How expressive can steering really be? Next, we experiment with oracle steering, which, while not a practical method, provides a principled way to test which locations are best to steer at. With this tool, one pattern stands out: the most expressive intervention point is the block output, after the skip connection. Steering here can draw on both the skip-connection input and the transformed MLP output, instead of relying solely on either the MLP or the attention pathway.
Motivated by this, we introduce a new activation adapter placed at each block output. It retains a LoRA-like low-rank structure but incorporates a non-linearity after the down-projection. This allows it to capture some of the nonlinear effects of SFT, giving activation steering a more expressive update space.
And finally: A bit of theory. Whatever form the steering adapter takes, if it is expressive enough to match the fine-tuned model at each layer, the steered model will match the fine-tuned model. So how accurately must each layer match for the steered model to closely track the fine-tuned one?
We also show that, at least in some settings relating to the geometry of the hidden states and the residuals at each module, post-block steering can replicate post-MLP steering. We further show that under some (very specific) parameter settings, post-MLP steering cannot learn any input-dependent steering, while post-block steering still can. This can be extended to an approximation showing that, in broader, more practical settings, post-block steering can get quite close to post-MLP steering.
Throughout this article, we will look at a number of different places to steer, along with different ways to steer. Even if some of these choices don’t make sense yet, don’t worry! Most of this will be explained in more detail later on. Use this section as a reference for any unclear notation or names as you read.
We will use $\delta\cdot$ to represent small induced changes in our analysis, and $\Delta\cdot$ will represent changes to parameters, such as the finetuning updates to matrices $\Delta W$ or steering vector $\Delta h$.
First, a Transformer model is built from transformer blocks. Each block contains an Attention module and an MLP module. Unless otherwise specified, the MLP modules will be GLU layers, a popular variant of the standard 1-layer MLP. The input to each submodule passes through a LayerNorm, and each block has two skip connections, one around each submodule. This can all be seen in the picture below.
As for steering, there are 3 main variants we consider. (1) Pre-MLP steering steers the attention outputs, most commonly by steering the output of individual attention heads before the skip connection and normalization. (2) Post-MLP steering steers the output of the MLP/GLU layer before it goes through the skip connection. (3) Post-block steering steers the output of each block, which is equivalent to steering the output of the MLP/GLU layer after it goes through the skip connection. In our notation, a GLU is represented as
\[y_{\mathrm{GLU}}(h) = W_d(\sigma(W_g h) \odot W_u h).\]The matrices $W_d, W_g, W_u$ are called the down-projection, gate, and up-projection matrices, respectively. When convenient, we will also write
\[y(h)=W_d m(h), \quad m(h) = \sigma(a_g) \odot a_u, \quad a_g = W_g h, \quad a_u = W_u h.\]For mathematical notation, the hidden state will be represented as a vector $h$ and steering will be represented as $\Delta h$. So, steering works by replacing $h$ with $h + \Delta h$.
Note: $\Delta h$, the steering vector, can depend on the input. Sometimes this is written explicitly as $\Delta h(h)$, but other times it is omitted.
With notation in place, we are now ready to begin our analysis!
Let’s start with the main question that drives the rest of our analysis:
At which points in the network can steering match the effect of updating the weights in that same module?
We will examine pre-MLP (steering attention outputs, like LoFIT) and post-MLP (steering the MLP output before the skip connection, like ReFT) steering, as they are the most common choices in the literature. These spots nicely sandwich the MLP, so the most immediate module affected by steering at these points is the MLP itself. Our first step is simple: compare the output changes caused by steering at these locations to the output changes caused by tuning the MLP weights.
Note: Before we start: it may feel like there is a lot of math here, but we promise that everything in this section is linear algebra.
Let the MLP output be
\[y(h)=W_d m(h), \quad m(h) = \sigma(a_g) \odot a_u, \quad a_g = W_g h, \quad a_u = W_u h.\]Finetuning the weights gives us
\[W_g \mapsto W_g + \Delta W_g, \quad W_u \mapsto W_u + \Delta W_u, \quad W_d \mapsto W_d + \Delta W_d.\]The updates $\Delta W_g$ and $\Delta W_u$ induce the following changes in their immediate outputs:
\[\delta a_g = (\Delta W_g) h, \quad \delta a_u = (\Delta W_u) h.\]A first order Taylor expansion of $m = \sigma(a_g) \odot a_u$ gives
\[\delta m = (\sigma'(a_g) \odot a_u) \odot \delta a_g + \sigma(a_g) \odot \delta a_u + \text{(higher order terms)}.\]Plugging $\delta m$ into the finetuned output gives us
\[y_{\mathrm{FT}} (h) = (W_d + \Delta W_d)(m+\delta m) \approx W_d m + \Delta W_d m + W_d \delta m.\]This yields the first-order shift caused by finetuning:
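\[\boxed{ \begin{aligned} \delta y_{\mathrm{FT}} &\equiv y_{\mathrm{FT}}(h) - y(h) \\ &\approx (\Delta W_d) m + W_d [ (\sigma'(a_g) \odot a_u) \odot ((\Delta W_g) h) + \sigma(a_g) \odot ((\Delta W_u) h) ]. \end{aligned} }\]Now consider pre-MLP steering. Here the weights stay fixed and the input changes instead: replacing $h$ with $h + \Delta h$ induces $\delta a_g = W_g \Delta h$ and $\delta a_u = W_u \Delta h$. Plugging the resulting $\delta m$ into $y = W_d m$ gives us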
\[\boxed{ \delta y_{\mathrm{pre}} \approx W_d [(\sigma'(a_g) \odot a_u) \odot (W_g \Delta h) + \sigma (a_g) \odot (W_u \Delta h)]. }\]Notice that the 2nd term in $\delta y_{\mathrm{FT}}$ is structurally similar to $\delta y_{\mathrm{pre}}$. What does this imply?
In principle, pre-MLP steering can match the shift caused by the updates $\Delta W_u$ and $\Delta W_g$ if there exists a single $\Delta h$ such that both $W_g \Delta h \approx (\Delta W_g) h$ and $W_u \Delta h \approx (\Delta W_u) h$.
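The boxed first-order formula for $\delta y_{\mathrm{FT}}$ can be sanity-checked numerically. The sketch below assumes a SiLU gate (as in SwiGLU-style GLUs; the post leaves $\sigma$ generic), random toy weights, and random update directions scaled by a shrinking `eps`; all names are hypothetical. If the formula is right, the relative gap between the exact shift and the first-order shift should shrink roughly in proportion to `eps`, since the neglected terms are second order.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def silu_prime(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 + x * (1.0 - s))

def glu(h, Wg, Wu, Wd):
    """y(h) = W_d (sigma(W_g h) * W_u h), with sigma = SiLU (an assumption)."""
    return Wd @ (silu(Wg @ h) * (Wu @ h))

def first_order_ft_shift(h, Wg, Wu, Wd, dWg, dWu, dWd):
    """(dW_d) m + W_d [(sigma'(a_g) * a_u) * (dW_g h) + sigma(a_g) * (dW_u h)]."""
    a_g, a_u = Wg @ h, Wu @ h
    m = silu(a_g) * a_u
    delta_m = (silu_prime(a_g) * a_u) * (dWg @ h) + silu(a_g) * (dWu @ h)
    return dWd @ m + Wd @ delta_m

rng = np.random.default_rng(0)
d_model, d_ff = 8, 16
Wg = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_model)
Wu = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_model)
Wd = rng.normal(size=(d_model, d_ff)) / np.sqrt(d_ff)
h = rng.normal(size=d_model)
# Random update directions, scaled like the weights they perturb.
directions = [rng.normal(size=W.shape) / np.sqrt(W.shape[1]) for W in (Wg, Wu, Wd)]

rel_errors = []
for eps in (1e-1, 1e-2, 1e-3):
    dWg, dWu, dWd = (eps * D for D in directions)
    exact = glu(h, Wg + dWg, Wu + dWu, Wd + dWd) - glu(h, Wg, Wu, Wd)
    approx = first_order_ft_shift(h, Wg, Wu, Wd, dWg, dWu, dWd)
    rel_errors.append(np.linalg.norm(exact - approx) / np.linalg.norm(exact))
print(rel_errors)  # shrinks roughly 10x per step: the neglected terms are O(eps^2)
```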
Wait, what about the first term $(\Delta W_d) m$? For a $\Delta h$ to match this term, $(\Delta W_d) m$ must lie in a space reachable by pre-MLP steering. Let’s factor out $\Delta h$ to see what that space looks like:
\[\delta y_{\mathrm{pre}} \approx W_d [(\sigma'(a_g) \odot a_u) \odot W_g + \sigma(a_g) \odot W_u] \Delta h.\]Define $J(h) = (\sigma'(a_g) \odot a_u) \odot W_g + \sigma(a_g) \odot W_u$ and $A(h) = W_d J(h)$, where the elementwise product between a vector and a matrix is performed row-wise, i.e. $a \odot W = (a\mathbf{1}^\top) \odot W = \operatorname{diag}(a)\,W$.
We can then rewrite \(\delta y_{\mathrm{pre}} \approx A(h) \Delta h.\)
For pre-MLP steering to match the MLP’s finetuning shift, we must have: \((\Delta W_d) m \in \text{col}(A(h)).\) What does this mean?
Pre-MLP steering can partially imitate MLP finetuning (the $\Delta W_g$ and $\Delta W_u$ effects), but matching the full MLP update is generally very hard, if not impossible.
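To make the column-space condition concrete, the sketch below builds $A(h) = W_d J(h)$ for a toy GLU (SiLU gate assumed), checks $A(h)\Delta h$ against a central finite difference of the GLU, and uses least squares to measure how much of a target shift no pre-MLP steer can reach at this $h$. Purely for illustration, we use a rank-deficient $W_d$ so that a random $(\Delta W_d) m$ visibly falls outside $\text{col}(A(h))$. All helper names are hypothetical.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def silu_prime(x):
    s = 1.0 / (1.0 + np.exp(-x))
    return s * (1.0 + x * (1.0 - s))

def glu(h, Wg, Wu, Wd):
    return Wd @ (silu(Wg @ h) * (Wu @ h))

def steering_matrix(h, Wg, Wu, Wd):
    """A(h) = W_d J(h), J(h) = diag(sigma'(a_g) * a_u) W_g + diag(sigma(a_g)) W_u."""
    a_g, a_u = Wg @ h, Wu @ h
    J = (silu_prime(a_g) * a_u)[:, None] * Wg + silu(a_g)[:, None] * Wu
    return Wd @ J

def unreachable_fraction(A, target):
    """min over dh of ||A dh - target|| / ||target||, via least squares."""
    dh, *_ = np.linalg.lstsq(A, target, rcond=None)
    return np.linalg.norm(A @ dh - target) / np.linalg.norm(target)

rng = np.random.default_rng(0)
d_model, d_ff, k = 8, 16, 3
Wg = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_model)
Wu = rng.normal(size=(d_ff, d_model)) / np.sqrt(d_model)
# Deliberately rank-k down-projection, so col(A(h)) has dimension at most k.
Wd = rng.normal(size=(d_model, k)) @ rng.normal(size=(k, d_ff)) / np.sqrt(d_ff)
h = rng.normal(size=d_model)
A = steering_matrix(h, Wg, Wu, Wd)

# delta y_pre ≈ A(h) dh, checked against a central finite difference.
dh, t = rng.normal(size=d_model), 1e-4
fd = (glu(h + t * dh, Wg, Wu, Wd) - glu(h - t * dh, Wg, Wu, Wd)) / (2 * t)
jacobian_error = np.linalg.norm(fd - A @ dh) / np.linalg.norm(A @ dh)

m = silu(Wg @ h) * (Wu @ h)
dWd = rng.normal(size=(d_model, d_ff))                          # an unconstrained Delta W_d
residual_reachable = unreachable_fraction(A, A @ rng.normal(size=d_model))
residual_dWd = unreachable_fraction(A, dWd @ m)                 # partly outside col(A(h))
print(jacobian_error, residual_reachable, residual_dWd)
```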
Post-MLP steering directly modifies the output of the MLP $y$.
Because it acts after all the non-linearities inside the MLP, a sufficiently expressive parameterization could, in principle, reproduce any change made by finetuning the MLP weights. For example, if we were allowed a fully flexible oracle vector
\[\Delta y_{\mathrm{oracle}}(h) = y_{\mathrm{FT}}(h) - y(h),\]then adding this vector would give us the exact fine-tuned MLP output.
This already puts post-MLP steering in a much better position than pre-MLP steering when it comes to matching MLP weight updates. So are we all set with post-MLP steering as the way to go? Not quite.
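Let’s look back at the structure of a Transformer block. Ignoring normalization, as in the 1-layer model later in this post, the block output is
\[h_{\mathrm{out}} = \underbrace{h + \mathrm{Attn}(h)}_{\text{skip connection}} + \underbrace{\mathrm{MLP}(h + \mathrm{Attn}(h))}_{\text{MLP term}}.\]
Post-MLP steering only modifies the MLP term.
But the block output is the sum of both: the skip-connection term, which carries the block input and the attention output, and the MLP term.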
So even if post-MLP steering perfectly matches the MLP shift, it will not modify the skip-connection term, which would be needed to mimic a fine-tuned model. This is not to say that post-MLP steering cannot learn a more complex steer through its linear layers that mimics the same effect, but doing so is, in a way, less natural. This ‘naturality’ shows up in the experiments below.
Here’s a look at the relative scale of the outputs of the MLP and Attention modules at different layers of an LLM. Across layers, post-MLP steering covers at most ~70% of the block output that finetuning changes, and in some layers as little as ~40%.
The rest remains untouched, meaning post-MLP steering on a block-specific level cannot fully replicate the effect of finetuning at that block.
Now that we’ve sorted out the “where to steer” part of the story, the next piece of the puzzle is how much steering can actually do. In other words: how expressive can activation steering be? Since ReFT is the most widely used steering method today and represents the strongest linear steering baseline, it’s the right place to begin.
ReFT does a post-MLP steering parameterized by
\[\delta y_{\mathrm{ReFT}} = \textbf{R}^\top (\textbf{W} y + \textbf{b} - \textbf{R}y).\]We use $\delta y$ to indicate that this happens after the MLP rather than before. $\textbf{R}, \textbf{W}, \textbf{b}$ are learned parameters, where $\textbf{R} \in \mathbb{R}^{r \times d_{\mathrm{model}}}$ has rank $r$ and orthonormal rows, $\textbf{W} \in \mathbb{R}^{r \times d_{\mathrm{model}}}$, and $\mathbf{b} \in \mathbb{R}^r$.
Recall that the output of the MLP layer is $y = W_d m$, so we can write
\[\begin{aligned} \delta y_{\mathrm{ReFT}} &= y_{\mathrm{ReFT}} - y = (\textbf{R}^\top \textbf{W} - \textbf{R}^\top\textbf{R})W_d m + \textbf{R}^\top \textbf{b} \\ &= \Delta W_d ^{\mathrm{eff}} m + \textbf{R}^\top \textbf{b}, \quad \quad \Delta W_d ^{\mathrm{eff}} = (\textbf{R}^\top \textbf{W} - \textbf{R}^\top\textbf{R})W_d \end{aligned}\]Now, let’s compare $\delta y_{\mathrm{ReFT}}$ with the $\delta y_{\mathrm{FT}}$ we derived before. ReFT can induce a $\Delta W_d$-like update, but only within the $r$-dimensional subspace spanned by the rows of $\textbf{R}$. So its ability to mimic full finetuning depends on the nature of the $\Delta W_d$ update: whether it is low-rank enough to fit inside that subspace.
What about the $\Delta W_u$- and $\Delta W_g$-induced shift in $\delta y_{\mathrm{FT}}$? Since $\delta y_{\mathrm{ReFT}}$ is an affine function of the post-MLP output (a linear term plus the constant $\textbf{R}^\top\textbf{b}$), ReFT can only reproduce this shift if the shift itself is approximately an affine function of $y$. This depends on how locally linear the mapping $h \mapsto y$ is. When these conditions hold, ReFT can approximate the effects of MLP weight updates reasonably well. However, as we show in our experiments (Table 1 in the next section), there are many situations where they do not hold, and ReFT performs significantly below the SFT model.
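The decomposition of $\delta y_{\mathrm{ReFT}}$ is easy to verify numerically. This sketch uses random toy parameters (all hypothetical), builds $\textbf{R}$ from a QR decomposition so its rows are orthonormal, and checks that the ReFT edit equals $\Delta W_d^{\mathrm{eff}} m + \textbf{R}^\top\textbf{b}$, that $\Delta W_d^{\mathrm{eff}}$ has rank at most $r$, and that the edit lies in the row space of $\textbf{R}$.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, r = 8, 16, 2

# R: r x d_model with orthonormal rows (transpose of a reduced-QR factor).
R = np.linalg.qr(rng.normal(size=(d_model, r)))[0].T
W = rng.normal(size=(r, d_model))
b = rng.normal(size=r)

Wd = rng.normal(size=(d_model, d_ff))
m = rng.normal(size=d_ff)
y = Wd @ m                                   # MLP output

delta_reft = R.T @ (W @ y + b - R @ y)       # the ReFT edit as written above
dWd_eff = (R.T @ W - R.T @ R) @ Wd           # effective down-projection update
delta_decomposed = dWd_eff @ m + R.T @ b
print(np.allclose(delta_reft, delta_decomposed))
```

The rank check is the key point: however $\textbf{W}$ and $\textbf{b}$ are trained, the induced $\Delta W_d$-like update cannot leave the $r$-dimensional row space of $\textbf{R}$.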
Now that we’ve seen that steering after the skip connection gives us the most expressive location, let’s see how good it can really be! Our goal will be to match SFT, so let’s see how far we get.
The simplest (and strongest) option, given the SFT model, is to let the steering vectors be the oracle from above. Just to recall,
\[\Delta h_{\mathrm{oracle}} = h_{\mathrm{FT}} - h_{\mathrm{base}}\]so, quite literally,
\[h_{\mathrm{steer}} = h_{\mathrm{base}} + \Delta h_{\mathrm{oracle}} = h_{\mathrm{FT}}\]
Okay, that’s a bit much: this simply overwrites the hidden state of the base model with the hidden state of the fine-tuned model. Still, it provided a lot of insight into where to steer and into the properties a good steering method should have.
Collecting these oracle steering vectors, we can look for patterns in their properties to exploit. We quickly found that these vectors were close to low-rank (their covariance had a concentrated spectrum). But be careful! Just because the oracle steering vectors nearly lie in some low-dimensional subspace does not mean the map from hidden states to steering vectors is linear. Sometimes it is, sometimes it isn’t.
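One simple way to quantify “close to low-rank” is the fraction of total variance captured by the top-$k$ eigenvalues of the covariance of the oracle vectors. The sketch below computes it from the singular values of the centered matrix (squared singular values are proportional to covariance eigenvalues). Real oracle vectors would come from paired base and fine-tuned forward passes; here we substitute synthetic, nearly rank-4 data, so the numbers are illustrative only.

```python
import numpy as np

def top_k_energy(vectors, k):
    """Fraction of total variance captured by the top-k covariance eigenvalues."""
    X = vectors - vectors.mean(axis=0, keepdims=True)   # center each coordinate
    s = np.linalg.svd(X, compute_uv=False)              # singular values, descending
    energy = s ** 2                                      # proportional to covariance eigenvalues
    return float(energy[:k].sum() / energy.sum())

rng = np.random.default_rng(0)
n, d, true_rank = 500, 64, 4
# Synthetic stand-in for oracle steering vectors: rank-4 signal plus small noise.
oracle = rng.normal(size=(n, true_rank)) @ rng.normal(size=(true_rank, d))
oracle += 0.01 * rng.normal(size=(n, d))
noise = rng.normal(size=(n, d))                          # isotropic baseline for comparison

print(top_k_energy(oracle, 4), top_k_energy(noise, 4))
```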
In fact, if we try to replace the oracle with the best linear approximation of the map between hidden states and steering vectors, we find that the oracle goes from perfect matching to performance similar to or worse than ReFT!
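The “best linear approximation” can be found by ordinary least squares from hidden states to steering vectors. The toy sketch below (synthetic data, hypothetical names) contrasts two maps whose outputs live in the same rank-3 subspace: one linear and one non-linear. The affine fit recovers the linear map essentially exactly but leaves a substantial residual on the non-linear one, which is the point made above: low-rank outputs do not imply a linear map.

```python
import numpy as np

def affine_fit_error(H, D):
    """Relative residual of the least-squares affine fit D ≈ H @ M + c."""
    H1 = np.hstack([H, np.ones((H.shape[0], 1))])      # append a bias column
    coef, *_ = np.linalg.lstsq(H1, D, rcond=None)
    return float(np.linalg.norm(H1 @ coef - D) / np.linalg.norm(D))

rng = np.random.default_rng(0)
n, d, r = 1000, 32, 3
H = rng.normal(size=(n, d))                              # stand-in hidden states
down = rng.normal(size=(d, r)) / np.sqrt(d)              # shared low-rank "read" directions
up = rng.normal(size=(r, d))                             # shared low-rank "write" directions

D_linear = (H @ down) @ up                               # rank-3 outputs, linear in H
D_nonlinear = (np.tanh(3.0 * (H @ down)) ** 2) @ up      # rank-3 outputs, NOT linear in H

err_linear = affine_fit_error(H, D_linear)
err_nonlinear = affine_fit_error(H, D_nonlinear)
print(err_linear, err_nonlinear)
```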
Here, we generate on some prompts and compare the average KL divergence between the fully fine-tuned model’s output probabilities and those of ReFT and of the linearized oracle. The linearized oracle often deviates further from the fine-tuned model than ReFT does, although not by much.
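The comparison metric can be sketched as follows, assuming per-position logits over the vocabulary from the fine-tuned model and from a steered model on the same generated sequence. The post does not state the direction of the KL divergence; this sketch assumes $\mathrm{KL}(p_{\mathrm{FT}} \,\|\, p_{\mathrm{steered}})$ averaged over positions, and the logits are random stand-ins.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)               # for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def mean_token_kl(logits_ft, logits_other):
    """Mean over positions of KL(p_FT || p_other); inputs are (num_positions, vocab) logits."""
    logp = log_softmax(logits_ft)
    logq = log_softmax(logits_other)
    return float((np.exp(logp) * (logp - logq)).sum(axis=-1).mean())

rng = np.random.default_rng(0)
logits_ft = rng.normal(size=(5, 10))                          # stand-in fine-tuned logits
logits_close = logits_ft + 0.1 * rng.normal(size=(5, 10))     # stand-in for a close steered model
logits_far = rng.normal(size=(5, 10))                         # stand-in for a distant model
print(mean_token_kl(logits_ft, logits_close), mean_token_kl(logits_ft, logits_far))
```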
The best way around this would be to learn a low-rank, non-linear function as the map between hidden states and steering vectors. What better than a small autoencoder? Its output space is constrained to the column space of its up-projection, so it will still be low-rank. Steering after the skip connection is not a new idea ([4], [5], for example), but the systematic justification presented here is: when steering block-by-block, post-block steering is the most expressive choice.
At this point, it seemed like we understood what would likely work as a steering vector, so we moved away from using the oracle as the gold standard to match and instead trained these adapters end-to-end.
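Now we train these steering modules without a guide. We add low-rank steering modules at the end of each block (similar in spirit to LoRA/PEFT), but applied to activations rather than weights. Based on the discussion above, we test two variants:
- a linear variant, with a LoRA-like low-rank structure (a down-projection followed by an up-projection);
- a non-linear variant, with the same low-rank structure plus a non-linearity after the down-projection.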
The nonlinear version is motivated by our earlier observation (and by [4] and [5]) that the map from hidden states to steering vectors may itself be nonlinear, while still being largely confined to a low-rank subspace.
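A minimal NumPy sketch of the two adapter variants, assuming the form $h \mapsto h + U f(Dh + b_D) + b_U$, where $f$ is the identity (linear variant) or a non-linearity right after the down-projection (here `tanh`, an assumption; the post does not name one). The class name, the zero initialization of the up-projection (borrowed from common LoRA practice), and the bias terms are hypothetical choices. Training would normally use an autodiff framework; this only shows the forward pass and why the steering vector stays in the column space of the up-projection.

```python
import numpy as np

class PostBlockAdapter:
    """Hypothetical low-rank activation adapter applied to a block output h.

    Computes h + U f(D h + b_D) + b_U, where f is the identity (linear variant)
    or tanh applied right after the down-projection (non-linear variant).
    """

    def __init__(self, d_model, rank, linear=True, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        self.down = rng.normal(size=(rank, d_model)) / np.sqrt(d_model)  # D
        self.b_down = np.zeros(rank)                                    # b_D
        self.up = np.zeros((d_model, rank))  # U, zero-init so the adapter starts as the identity
        self.b_up = np.zeros(d_model)                                   # b_U
        self.linear = linear

    def num_params(self):
        return self.down.size + self.b_down.size + self.up.size + self.b_up.size

    def steering_vector(self, h):
        """Delta h for hidden states h of shape (..., d_model)."""
        z = h @ self.down.T + self.b_down
        if not self.linear:
            z = np.tanh(z)                   # non-linearity after the down-projection
        return z @ self.up.T + self.b_up     # always in col(U) + b_U

    def __call__(self, h):
        return h + self.steering_vector(h)

rng = np.random.default_rng(0)
H = rng.normal(size=(50, 16))                # 50 token positions, d_model = 16
adapter = PostBlockAdapter(d_model=16, rank=4, linear=False, rng=rng)
print(np.allclose(adapter(H), H), adapter.num_params())   # prints: True 148
```

Applied at every block output and every token position, each adapter in this sketch costs $2rd + r + d$ parameters for hidden size $d$ and rank $r$, independent of the MLP width.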
Here are the results we’re currently seeing for 1B-parameter models:
For a fair comparison, we match the parameter counts of our adapters to the baselines.
Across both Llama-1B and Gemma-1B, the trend is consistent: simply moving the steering location to post-block leads to a substantial boost in performance. Under identical parameter budgets, our linear post-block steering outperforms ReFT, and our fixed-vector and rank-1 variants outperform LoFIT and JoLA respectively.
In a few cases, linear steering even outperforms LoRA of the same rank (which learns an adapter on every Linear module), and occasionally it even outperforms full-rank SFT! This shows there are situations where steering is a better choice than finetuning.
What is surprising about these results is that linear steering performs better than non-linear steering. Non-linear steering is not necessarily more expressive than linear steering of the same rank, and for these tasks the required steering seems to be genuinely close to linear, validating the choice made in ReFT. In this situation, it is better to spend the budget on more rank in a purely linear adapter than to add a non-linearity that disrupts this structure.
Now compare this to some larger, 4B-parameter models:
The behavior is different! Now, non-linear steering typically performs better than linear steering (with the notable exception of Winogrande). This shift is likely due to optimization effects at the different model scales tested here. At the larger scale, the loss landscape typically ends up smoother/flatter than for smaller models, which could let the larger models learn the more expressive non-linear adapter instead of settling for the easier-to-learn linear one.
This shift could also point to something more fundamental about the best adapters. If a linear adapter with the right rank is well suited to the task, the non-linear adapter has to work around its non-linearity (for example, with large-scale parameters) to match it. So it’s possible that smaller models need a low-rank linear adapter while larger models need a high-rank non-linear adapter, but we leave this to future work.
Now that we’ve seen a clear difference between post-MLP and post-block steering, we should ask why there is such a difference. What makes post-block steering better than post-MLP steering?
The key difference, highlighted earlier, is that a post-block steer can depend on the outputs of the attention layer without them passing through the GLU. For some simple intuition, consider a 1-layer model made of one Attention layer and one GLU:
\[y(x) = x + \mathrm{Attn}(x) + \mathrm{GLU}(x + \mathrm{Attn}(x))\]where $x$ is the input of the model.
The two steering types we will compare are a post-MLP steer, which depends only on the outputs of the GLU, and a post-block steer, which steers the model after the skip connection is added back (for this model, that is simply a steer at the very end that depends on the original model outputs). Let’s remind ourselves of the structure of the GLU/MLP:
\[y_{\mathrm{GLU}}(h) = W_d(\sigma(W_g h) \odot W_u h)\]Consider the (fairly extreme) case where $W_d = 0$. In this situation, the GLU output is identically 0, so any post-MLP steer $\Delta y_{MLP}$, which can only depend on the GLU output, is a fixed vector. Compare this to the post-block steer, which can still depend on $x + \mathrm{Attn}(x)$. This extends to cases where $W_d$ is not full-rank: a post-MLP steer cannot depend on the components of $m$ that lie in the null space of $W_d$, so with respect to that information it contributes only a fixed vector. Post-block steering, which also sees $x + \mathrm{Attn}(x)$, does not have this restriction.
This tells us:
There are situations in which post-MLP steering performs poorly while post-block steering performs well.
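The $W_d = 0$ argument can be made concrete with a toy, single-token version of the 1-layer model. The attention layer is replaced by a hypothetical linear map (real attention mixes tokens), and the GLU uses a SiLU gate. Whatever function a post-MLP steer computes, it receives the same (zero) input for every $x$, while a post-block steer still sees inputs that vary with $x$.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
d, d_ff = 6, 12
W_attn = rng.normal(size=(d, d)) / np.sqrt(d)   # hypothetical linear stand-in for Attn
Wg = rng.normal(size=(d_ff, d)) / np.sqrt(d)
Wu = rng.normal(size=(d_ff, d)) / np.sqrt(d)
Wd = np.zeros((d, d_ff))                        # the extreme case W_d = 0

def inputs_to_steers(x):
    """Return (input seen by a post-MLP steer, input seen by a post-block steer)."""
    h = x + W_attn @ x                          # skip connection around "attention"
    glu_out = Wd @ (silu(Wg @ h) * (Wu @ h))    # identically zero because W_d = 0
    return glu_out, h + glu_out                 # GLU output vs. full model output

x1, x2 = rng.normal(size=d), rng.normal(size=d)
post_mlp_1, post_block_1 = inputs_to_steers(x1)
post_mlp_2, post_block_2 = inputs_to_steers(x2)
print(np.allclose(post_mlp_1, post_mlp_2), np.allclose(post_block_1, post_block_2))  # prints: True False
```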
Is it the case that post-block steering can always do what post-MLP steering can do? Unfortunately, not always. Once the skip connection is added back, the sum is not necessarily invertible, i.e., the skip-connection and GLU terms cannot always be told apart, since their subspaces might overlap. When they do not overlap, post-block steering can match post-MLP steering perfectly.
To gain some intuition, let’s restrict ourselves to the linear case, where post-MLP steering is ReFT-style. To ensure invertibility, assume that the subspace spanned by the skip-connection outputs and the subspace spanned by the MLP outputs have trivial intersection. That is, there is a projection map $P$ such that
\[P(h + \mathrm{Attn}(h) + \mathrm{GLU}(h + \mathrm{Attn}(h))) = h + \mathrm{Attn}(h)\]At this point, the equivalence is easy to see. Write $h_{\mathrm{out}}$ for the block output, so that $(I - P)h_{\mathrm{out}} = \mathrm{GLU}(h + \mathrm{Attn}(h))$. If $A$ is the rank-$r$ linear map used for post-MLP steering, then the post-block steer $\Delta h(h_{\mathrm{out}}) = A(I - P)h_{\mathrm{out}}$ matches the post-MLP steering perfectly. It also has rank at most $r$, since $A(I - P)$ is a matrix of rank at most $r$.
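The projection argument can also be checked numerically. In this sketch (synthetic subspaces, hypothetical names), the skip-connection and GLU outputs live in two random subspaces whose bases together have full column rank, so the intersection is trivial. $P$ is the oblique projection onto the skip subspace along the GLU subspace, and the post-block steer $A(I-P)h_{\mathrm{out}}$ reproduces the post-MLP steer exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k_skip, k_glu, r = 10, 4, 3, 2

S = rng.normal(size=(d, k_skip))            # basis of the skip-connection subspace (synthetic)
G = rng.normal(size=(d, k_glu))             # basis of the GLU-output subspace (synthetic)
B = np.hstack([S, G])                       # full column rank => trivial intersection
P = S @ np.linalg.pinv(B)[:k_skip]          # oblique projection onto span(S) along span(G)

A = rng.normal(size=(d, r)) @ rng.normal(size=(r, d))   # rank-r post-MLP steering map

skip = S @ rng.normal(size=k_skip)          # stands in for h + Attn(h)
glu_out = G @ rng.normal(size=k_glu)        # stands in for GLU(h + Attn(h))
h_out = skip + glu_out                      # block output

post_mlp_steer = A @ glu_out                          # sees only the GLU output
post_block_steer = A @ (np.eye(d) - P) @ h_out        # sees only the block output
print(np.allclose(post_mlp_steer, post_block_steer))
```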
Of course, this setting where these two vector spaces are completely separate is a bit extreme. Steering might need to be done in the same direction as the steered vector. At this point, some probability needs to be involved since any projection will no longer be perfect. However, we save this for the paper.
One last thing to mention is that, since we have an oracle at each layer, we can measure how close we are to it. This assumes, of course, that we have access to the fine-tuned model we want to mimic. Soon we will remove this assumption and move towards a more oracle-free understanding, but for the moment, assume we have access to the oracle. We now also have fine-grained control over errors and how they grow through the model! At each layer, we can construct an approximate steering vector $\Delta h'$ that matches the oracle $\Delta h_{\mathrm{oracle}}$ closely enough that the error between them is at most some $\epsilon$ at every layer.
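To illustrate why a per-layer $\epsilon$ is useful, here is a generic, textbook-style bound (not the bound from the paper). Assume each fine-tuned block, viewed as the base block plus its oracle steer, is $L$-Lipschitz, and that the adapter is within $\epsilon$ of the oracle at every layer. Then the steered trajectory’s deviation satisfies $e_\ell \le L\,e_{\ell-1} + \epsilon$ with $e_0 = 0$, so after $N$ layers $e_N \le \epsilon\,(L^N - 1)/(L - 1)$ (or $N\epsilon$ when $L = 1$). The sketch evaluates the recursion, its closed form, and the per-layer budget for a target final error; the values of $L$ and $N$ are illustrative assumptions.

```python
def error_bound_recursive(eps, lip, n_layers):
    """Iterate e_l <= lip * e_{l-1} + eps starting from e_0 = 0."""
    e = 0.0
    for _ in range(n_layers):
        e = lip * e + eps
    return e

def error_bound_closed_form(eps, lip, n_layers):
    """Closed form of the same recursion."""
    if lip == 1.0:
        return n_layers * eps
    return eps * (lip ** n_layers - 1.0) / (lip - 1.0)

def per_layer_budget(target, lip, n_layers):
    """Largest per-layer eps for which the bound on the final error stays <= target."""
    return target / error_bound_closed_form(1.0, lip, n_layers)

# Example: 16 layers with Lipschitz constant 1.1 (both illustrative assumptions).
eps_max = per_layer_budget(0.1, 1.1, 16)
print(eps_max, error_bound_closed_form(eps_max, 1.1, 16))
```

Because the bound grows geometrically when $L > 1$, the per-layer tolerance must shrink with depth, which is exactly why per-layer control over $\epsilon$ matters.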
Importantly, this tells us how complex our adapters need to be. For example, if the update is nearly linear, we can choose a rank $r$ large enough that the approximation is within $\epsilon$, ensuring that the error does not grow out of control through the model.
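When the map from hidden states to oracle steering vectors is (nearly) a linear map $M$, choosing the rank is a truncated-SVD question. By the Eckart-Young theorem, the best rank-$r$ approximation $M_r$ has spectral-norm error equal to the $(r+1)$-th singular value, so $\|(M - M_r)h\| \le \epsilon$ for every unit-norm $h$ exactly when that singular value is at most $\epsilon$. The sketch assumes such an $M$ (a synthetic stand-in with known singular values) and hypothetical helper names.

```python
import numpy as np

def min_rank_for_tolerance(M, eps):
    """Smallest r whose truncated SVD satisfies ||M - M_r||_2 <= eps.

    Singular values come sorted in descending order, so r is simply the
    number of singular values strictly greater than eps.
    """
    s = np.linalg.svd(M, compute_uv=False)
    return int(np.count_nonzero(s > eps))

def truncated_svd(M, r):
    """Best rank-r approximation of M (spectral and Frobenius norm)."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    return (U[:, :r] * s[:r]) @ Vt[:r]

rng = np.random.default_rng(0)
Q1 = np.linalg.qr(rng.normal(size=(8, 5)))[0]
Q2 = np.linalg.qr(rng.normal(size=(8, 5)))[0]
M = Q1 @ np.diag([5.0, 3.0, 1.0, 0.1, 0.01]) @ Q2.T   # stand-in "nearly linear" steering map
r = min_rank_for_tolerance(M, eps=0.5)
M_r = truncated_svd(M, r)
print(r, np.linalg.norm(M - M_r, 2))
```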
The exact form of this required control is beyond the scope of this blog. However, there are a few key insights.
We plan to further improve these bounds with nicer assumptions based on the behavior of real models soon!
Note: The model we analyze in this theory section is a drastically simplified stand-in for a real block, but the same geometry (post-block steering sees both the Attn and GLU outputs, while post-MLP steering sees only the GLU output) persists in deeper models.
If you’ve made it this far, kudos! That’s pretty much all we have to say (for now, wink). To conclude, here are some highlights of everything we unpacked:
- Pre-MLP vs. Post-MLP: They behave very differently, and post-MLP generally does a better job matching MLP weight updates.
- But Post-Block is generally better than Post-MLP. Steering the residual stream, not individual module outputs, is the real sweet spot.
- With only 0.04% trainable parameters (compared to LoRA’s 0.45% using the same rank), our method at this post-block location gets remarkably close to SFT, which updates all parameters.
Lowering the number of trainable parameters is more than an efficiency win. It buys us compute headroom to adapt larger models and to run more expensive algorithms like RL, all within realistic compute budgets.
Because our framework reasons about steering and weight updates in their most general form (arbitrary $\Delta h$ and arbitrary $\Delta W$), there’s nothing stopping those updates from being learned by far more expensive methods (e.g., GRPO or other RL-style objectives). So the natural next step is clear: test whether our steering can plug into these stronger algorithms and deliver the same performance with a fraction of the trainable parameters.
And now that we know it’s possible to learn nearly as well in activation space as in weight space, in even more settings than ReFT originally could, an even more ambitious idea opens up:
What happens if we optimize in both spaces together?
Perhaps, jointly optimizing them might help us break past the limitations or local minima that each space hits on its own?
Stay tuned to find out!
We’ve packaged everything (activation adapters, post-block steering layers, training scripts) into a clean, lightweight library.
👉 **Try our code on GitHub** 🐙
[1] Wu, Z., Arora, A., Wang, Z., Geiger, A., Jurafsky, D., Manning, C., & Potts, C. (2024). ReFT: Representation Finetuning for Language Models.
[2] Yin, F., Ye, X., & Durrett, G. (2024). LoFiT: Localized Fine-tuning on LLM Representations. In The Thirty-eighth Annual Conference on Neural Information Processing Systems.
[3] Lai, W., Fraser, A., & Titov, I. (2025). Joint Localization and Activation Editing for Low-Resource Fine-Tuning. arXiv preprint arXiv:2502.01179.
[4] Houlsby, N., Giurgiu, A., Jastrzebski, S., Morrone, B., De Laroussilhe, Q., Gesmundo, A., … & Gelly, S. (2019). Parameter-efficient transfer learning for NLP. In International Conference on Machine Learning.
[5] Tomanek, K., Zayats, V., Padfield, D., Vaillancourt, K., & Biadsy, F. (2021). Residual adapters for parameter-efficient ASR adaptation to atypical and accented speech. In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing.
Implementation details:
For all methods, we sweep across 5 learning rates and keep other hyperparameters constant (batch size, scheduler, warmup ratio, weight decay). We keep the same learning rate sweep space for ours, ReFT, JoLA, and LoFIT, and slightly shift to smaller values for SFT and LoRA (since they train substantially more parameters).
The ReFT paper also treats the token positions to steer as a hyperparameter. After selecting the best learning rate, we sweep over two settings: the last prompt token (the default) and (prefix+7, suffix+7) (the best configuration reported for GSM8K). Our method does not require this sweep: it intervenes at all token positions (both prompt and generated).
Over the past decade, natural language processing (NLP) has evolved significantly, beginning with word2vec in 2013. Word2vec’s word embeddings can capture nuanced lexical relationships through vector arithmetic: the famous equation “king – man + woman ≈ queen” is a prime example of how word co-occurrences were used to embed semantic meaning. This approach’s emphasis on maintaining synonym and hypernym relationships stands in contrast to CLIP’s text encoder, which is primarily optimized for aligning images with text rather than preserving these detailed linguistic structures. As NLP developed, GloVe (2014) improved word representations using global co-occurrence statistics. In 2018, BERT introduced a transformer-based approach, providing context-dependent word representations and setting new performance benchmarks. Following this, large language models (LLMs) like GPT-2 and GPT-3 emerged and further refined language understanding. While these models shifted the focus toward richer contextual embeddings, the legacy of explicitly encoding relational properties gradually diminished, leaving modern VLMs like CLIP sensitive to linguistic variations.
We will start our story with the following observation:
One major issue is that CLIP’s latent space, as learned during pretraining, struggles to capture semantic differences. Consider these observations:
These challenges underscore why traditional methods—such as manual prompt engineering or static fine-tuning—often fall short. They are typically time-consuming, narrowly focused, and lack the ability to generalize across varied datasets or scenarios.
Our approach focuses exclusively on fine-tuning the text encoder while leaving the image encoder unchanged. The key idea is to explicitly incorporate lexical relationships and adjust the text embedding space such that:
Inspired by network embedding methods that are explicitly designed to preserve structure, we seek a metric that captures these inherent relationships. This motivation leads us to the Wu-Palmer Similarity, which is a measure used to quantify the semantic similarity between two concepts within a taxonomy, like WordNet. It relies on the depth of the concepts and their Least Common Subsumer (LCS), which is the most specific ancestor common to both concepts.
The similarity between two concepts $c_1$ and $c_2$ is given by:
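\[\text{sim}_{wup}(c_1, c_2) = 2 \times \frac{\text{depth}(LCS(c_1, c_2))}{\text{depth}(c_1) + \text{depth}(c_2)}\]where $\text{depth}(c)$ is the depth of concept $c$ in the taxonomy and $LCS(c_1, c_2)$ is the least common subsumer of $c_1$ and $c_2$.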
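The formula is easy to compute once a taxonomy is available. The sketch below uses a tiny, hypothetical hand-built tree (child to parent) instead of WordNet, and counts the root as depth 1. In practice, NLTK’s WordNet interface provides `wup_similarity` on synsets, though its depth conventions may differ slightly from this toy version.

```python
# Hypothetical toy taxonomy (child -> parent); None marks the root.
PARENT = {
    "entity": None,
    "animal": "entity",
    "dog": "animal",
    "cat": "animal",
    "vehicle": "entity",
    "car": "vehicle",
}

def ancestors(concept):
    """Path from `concept` up to the root, including `concept` itself."""
    path = []
    while concept is not None:
        path.append(concept)
        concept = PARENT[concept]
    return path

def depth(concept):
    return len(ancestors(concept))          # the root has depth 1

def lcs(c1, c2):
    """Least common subsumer: the deepest ancestor shared by c1 and c2 (tree case)."""
    shared = set(ancestors(c2))
    return next(a for a in ancestors(c1) if a in shared)

def wup_similarity(c1, c2):
    return 2 * depth(lcs(c1, c2)) / (depth(c1) + depth(c2))

print(wup_similarity("dog", "cat"), wup_similarity("dog", "car"))
```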
Our training objective is to fine-tune the text encoder using a composite loss with two parts: a distance loss and a regularization loss. Given two tokenized words $w_i$ and $w_j$ and the text encoder $M$, the losses are defined as follows:
This component enforces that the cosine similarity between the embeddings $M(w_i)$ and $M(w_j)$ approaches a target similarity derived from the Wu-Palmer metric. Mathematically, it is given by:
\[\mathcal{L}_{\text{distance}}(w_i, w_j) = \left(c\times \Bigl( s_{\text{WP}}(w_i, w_j) - \cos\theta \bigl(M(w_i), M(w_j)\bigr) \Bigr) \right)^2\]where:
To prevent the fine-tuning process from deviating too far from the original embeddings, we introduce a regularization term. This is computed as the Euclidean distance between the current embedding $M(w)$ and the precomputed original embedding $M_0(w)$ for each word $w$:
\[\mathcal{L}_{\text{reg}}(w) = \text{Euclidean}\Bigl( M(w), \, M_0(w) \Bigr)\]This term is scaled by a regularization strength multiplier $\lambda$.
The combined loss function that is minimized during training is:
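\[\mathcal{L} = \mathcal{L}_{\text{distance}} + \lambda \, \mathcal{L}_{\text{reg}}\]This loss encourages the model to adjust its embeddings so that the cosine similarity of each word pair tracks its Wu-Palmer similarity, while each embedding stays close to its original value.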
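A minimal sketch of the composite loss for one word pair, with the text encoder’s outputs replaced by plain vectors. Two details are assumptions not fixed by the text: the regularizer is applied to both words of the pair and summed, and the defaults for $c$ and $\lambda$ are arbitrary placeholders.

```python
import numpy as np

def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def distance_loss(e_i, e_j, s_wp, c=1.0):
    """(c * (s_WP(w_i, w_j) - cos(M(w_i), M(w_j))))^2"""
    return (c * (s_wp - cosine(e_i, e_j))) ** 2

def reg_loss(e, e0):
    """Euclidean distance between the current embedding M(w) and the frozen M_0(w)."""
    return float(np.linalg.norm(e - e0))

def pair_loss(e_i, e_j, e0_i, e0_j, s_wp, lam=0.1, c=1.0):
    # Assumption: the regularizer covers both words in the pair and is summed.
    return distance_loss(e_i, e_j, s_wp, c) + lam * (reg_loss(e_i, e0_i) + reg_loss(e_j, e0_j))

dog, cat = np.array([1.0, 0.2]), np.array([0.3, 1.0])     # toy "current" embeddings
print(pair_loss(dog, cat, dog, cat, s_wp=2 / 3))           # no drift yet, so only the distance term
```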
The training process involves iteratively updating the text encoder based on the computed loss over all word pairs. The high-level algorithm is as follows:

By applying our method, we can
This method takes a step toward overcoming the linguistic rigidity of CLIP, potentially paving the way for more robust and versatile vision-language applications.
Stay tuned as we delve deeper into the experimental insights and future directions of this research.
In this experiment, we focus exclusively on subsets of class labels for which synonyms or hypernyms are available in WordNet. To minimize potential disturbances, we ensure that all words within each subset are unique during the subset creation process.
For the hypernym setting, we limit our use to direct (level-1) hypernyms in the WordNet hierarchy, as higher-level hypernyms tend to be overly abstract and less applicable in practical real-world scenarios.
To evaluate the effectiveness of our method, we compare the zero-shot classification accuracy of the original pre-trained model and our fine-tuned version on two sets of class labels: the original class names, and the class names replaced by their corresponding synonyms or hypernyms. We select distinct subsets and combinations to reduce variance.
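The evaluation follows the standard CLIP zero-shot recipe: embed each class name (original, synonym, or hypernym) with the text encoder, and predict the class whose text embedding has the highest cosine similarity with the image embedding. The sketch below uses synthetic embeddings and a hypothetical helper name; with real models, only the label text embeddings change between the original and the fine-tuned encoder, since the image encoder is frozen.

```python
import numpy as np

def zero_shot_accuracy(image_emb, label_emb, targets):
    """Predict, for each image, the label with the highest cosine similarity; return accuracy."""
    img = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    txt = label_emb / np.linalg.norm(label_emb, axis=1, keepdims=True)
    preds = np.argmax(img @ txt.T, axis=1)
    return float(np.mean(preds == targets))

targets = np.array([0, 1, 2])
images = np.eye(3) + 0.1                         # synthetic image embeddings, one per class
labels_original = np.eye(3)                      # synthetic text embeddings of the class names
labels_swapped = np.roll(np.eye(3), 1, axis=0)   # label texts the encoder places badly
print(zero_shot_accuracy(images, labels_original, targets),
      zero_shot_accuracy(images, labels_swapped, targets))   # prints: 1.0 0.0
```

Running it once with the original class-name embeddings and once with synonym (or hypernym) embeddings gives the two accuracies being compared.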
Our method improves classification performance in both the synonym and hypernym settings across different datasets. On ImageNet, we also run an experiment mixing synonyms and hypernyms.
As a starting point, we test our method on the Fer2013 dataset. The tuned model shows a notable increase in accuracy for both the original class labels and those replaced by synonyms whenever we use a sentence template in the classification.


Our ImageNet experiments show that fine-tuning not only boosts accuracy when we swap label names with synonyms or hypernyms, but it also improves zero-shot performance using the default labels. This means our method helps the base model get a better handle on semantics and generalize more effectively.
We conduct our experiments in all three settings below.



To show the adaptability of our approach, we also craft a subset of OpenImage and run an experiment similar to the one on ImageNet. We observe a similar pattern of improvement.


Now, let’s see how well our method generalizes. We’ll show that a model fine-tuned on different ImageNet subsets can also boost the classification accuracy on the OpenImage subset.


In this post, we explored a simple way to fine-tune CLIP’s text encoder without heavy computation, by aligning synonyms and hypernyms in the text embedding space. This tweak improves zero-shot classification accuracy across various datasets without requiring any image data. Looking ahead, we’ll refine this approach and test its real-world applications to better connect language and vision.
Our exploration of fine-tuning CLIP’s text encoder has revealed several critical challenges and exciting directions for future research:
Scalability/Polysemy Challenges:
In WordNet, 31,989 out of 148,730 words are polysemous. This inherent ambiguity could compromise the integrity of the underlying data structure as we scale up, so more advanced techniques will be needed to handle multiple word meanings. In our experiments, we also observe diminishing marginal gains from our method as the number of classes increases. And while 53,811 out of 117,659 words have a synonym in WordNet, scaling to all of them is another concern. Both issues underscore the need for scalable and robust solutions.
Adapting to Image-Caption Datasets:
For broader applicability, we need to adjust the current methodology to work with image-caption datasets like LAION and Conceptual Captions. This adaptation could pave the way for more versatile and comprehensive vision-language models.
Limitations with Propositional Words:
Frameworks like CLIP struggle with propositional terms such as not, is a, or comparative expressions like more/less than. These limitations hinder the model’s ability to fully grasp complex semantic relationships.
We hope you enjoyed our post! Our code is also available on GitHub.
]]>
The ability to generate high-quality synthetic tabular data has far-reaching applications across numerous domains. One of the most compelling use cases is medical research, where strict privacy regulations prevent real patient datasets from being widely shared. These restrictions, while essential for protecting sensitive information, can slow down collaboration and innovation. If researchers could instead generate and share synthetic patient datasets that mimic real data while ensuring privacy, medical discoveries could be accelerated on a global scale—enabling scientists to uncover insights without compromising patient confidentiality.
Table synthesis also plays crucial roles in data augmentation and missing value imputation. In many real-world scenarios, such as observations of rare weather events, collecting large, high-quality datasets is expensive or impractical (feel free to check out my current work at NASA on this exact topic, for atypically powerful and destructive hurricanes). Additionally, many datasets suffer from incomplete records, and advanced synthesis techniques can be used to intelligently fill in missing values, preserving the integrity and usability of the data. For example, missing data is a common challenge in air traffic control when analyzing flight patterns and schedules. Reliable tabular missing value imputation methods would allow aviation authorities to reconstruct missing flight paths, estimate delay probabilities, and optimize air traffic flow—even when real-time data is incomplete.
Tabby introduces a set of architectural modifications that can be applied to any transformer-based language model (LM), enabling it to generate high-fidelity synthetic tabular data. At its core, Tabby incorporates Gated Mixture-of-Experts (MoE) layers with column-specific parameter sets, allowing the model to better represent relationships between different table columns. These modifications introduce the necessary inductive biases that help the LM model structured tabular data rather than free-form text.
The figure below compares a Tabby model to a standard, non-Tabby “base” LLM, when intended for use with a dataset that has V columns:
Despite these significant architectural changes, Tabby’s fine-tuning process remains straightforward—closely mirroring the pre-existing approaches for adapting LMs to tabular data. Moreover, Tabby is designed to retain and leverage the knowledge gained during the large-scale text pre-training phase, allowing for faster and more efficient adaptation to structured datasets.
Beyond Tabby itself, we also introduce Plain—a lightweight yet powerful training method for fine-tuning LMs (both Tabby and non-Tabby) on tabular data. Plain consistently improves the quality of synthetic data generation, regardless of the underlying LM. If you’re curious about how it works, check out our paper for the full details!
To assess Tabby’s performance, we follow standard benchmarks for tabular data synthesis, using a diverse set of datasets and the widely-accepted Machine Learning Efficacy (MLE) metric to measure data quality.
Datasets: We train Tabby models on six datasets spanning various domains and sizes:
| Name | # Rows | # Columns | Domain |
|---|---|---|---|
| Diabetes | 576 | 9 | Medical |
| Travel | 715 | 7 | Business |
| Adult | 36631 | 15 | Census |
| Abalone | 3132 | 9 | Biology |
| Rainfall | 12566 | 4 | Weather |
| House | 15480 | 9 | Geographical |
Metrics: MLE measures how well synthetic data preserves real-world patterns by comparing the performance of machine learning models trained on synthetic vs. real data. The closer the synthetic data’s MLE score is to the real data’s, the higher the fidelity of the synthetic dataset.
In the table below, the first row represents the MLE score of real (non-synthetic) data. Any synthetic method that matches or surpasses this score is considered to have reached parity with real data.
Tabby achieves parity with real data on three out of six datasets (Diabetes, Travel, and Adult). Additionally, Tabby outperforms the prior best LLM-based tabular synthesis method on all six datasets.
| Method | Diabetes | Travel | Adult | Abalone | Rainfall | House |
|---|---|---|---|---|---|---|
| Real data (upper bound) | 0.73 | 0.87 | 0.85 | 0.45 | 0.54 | 0.61 |
| Our Best Tabby Model | 0.74 ✅ | 0.88 ✅ | 0.85 ✅ | 0.43 | 0.49 | 0.60 |
| Prior best LM approach | 0.72 | 0.87 | 0.83 | 0.40 | 0.05 | 0.55 |
✅: Parity with real data!
One of the most exciting discoveries in our work is that Tabby is not limited to tabular data. Unlike previous tabular synthesis models, Tabby successfully adapts to other structured data formats too, such as nested JSON records.
Why does this matter?
Most prior tabular synthesis methods struggle when faced with hierarchical or nested structures, which are common in web data, API responses, and metadata-rich datasets. Tabby’s architecture enables it to capture these structures effectively, opening up new possibilities for structured data generation.
We’re excited to explore this direction further and believe Tabby could be a foundation for generating a much broader class of structured synthetic data beyond just tables.
Tabby is a highly promising and easy-to-use approach for generating realistic synthetic tabular data. By leveraging Mixture-of-Experts (MoE) layers and the Plain training process, Tabby reaches parity with real data on three of six datasets and outperforms prior LLM-based tabular synthesis methods on all six.
In our paper, we dive deeper into:
✅ The technical details behind Tabby’s architecture
✅ The Plain training process, which boosts data quality across different LLMs
✅ Extensive evaluations and comparisons against prior methods
For a deeper look, check out our paper or code. Feel free to reach out at [email protected]—I’d love to hear your thoughts!
]]>For instance, labeling a dataset of 7,569 points with GPT-4 could cost over $1,200. Even worse, the resulting labels are static, making them difficult to tweak or audit.
Instead of directly prompting LLMs for labels, we ask LLMs to generate programs that act as annotators. These synthesized programs encode the LLM’s labeling logic and can either label data directly or label a training dataset used to train a distilled specialist model.
We built Alchemist, a system that implements this approach. Empirically, Alchemist improves labeling performance on five out of eight datasets, with an average accuracy boost of 12.9%, while reducing costs by approximately 500×.
Step 1: Prompt the LLM for Programs. We start with an unlabeled dataset (e.g., YouTube comments or medical abstracts) and write a simple prompt, an instruction telling the LLM what we want, such as a function that labels spam (1) or ham (0). Prompts can integrate relevant task information and vary in their design, which allows us to synthesize multiple programs.
Example prompt:
[Task Description] Write a bug-free Python function to label YouTube comments as spam or ham.
[Labeling Instruction] Return 1 for spam, 0 for ham, -1 if unsure.
[Function Signature] def label_spam(text_comment):
And the generated program:
```python
def label_spam(text_comment):
    """
    Classifies YouTube comments as spam (1), ham (0), or unsure (-1).
    """
    if not isinstance(text_comment, str) or not text_comment.strip():
        return -1
    text = text_comment.lower()
    # Key spam indicators
    spam_phrases = ["sub4sub", "subscribe to my", "check out my channel", "follow me",
                    "make money", "click here", "free gift", "www.", "http", ".com"]
    # Check for spam indicators
    if any(phrase in text for phrase in spam_phrases):
        return 1
    # Check for suspicious patterns; the all-caps test uses the original casing,
    # because `text` has already been lowercased and would never be all caps
    suspicious = (
        text.count('!') > 3 or
        text.count('?') > 3 or
        (len(text) > 10 and text_comment.isupper()) or
        any(char * 3 in text for char in "!?.,$*") or
        any(segment.isdigit() and len(segment) >= 10 for segment in text.split())
    )
    return -1 if suspicious else 0
```
A single program might not capture all aspects of the labeling logic. To improve robustness, Alchemist generates multiple programs with diverse heuristics—some using keyword matching, others leveraging more complex patterns.
Step 2: Aggregate Labels with Weak Supervision. The generated programs may be noisy or inconsistent. To address this, Alchemist uses a weak supervision framework (such as Snorkel) to aggregate their outputs into a single, high-quality set of pseudolabels.
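Snorkel's label model estimates how accurate each program is from the pattern of agreements and disagreements among them. As a much simpler stand-in, the sketch below aggregates program outputs with an unweighted majority vote, treating -1 as an abstention. The three labeling functions are hypothetical examples written for illustration, not programs Alchemist produced.

```python
from collections import Counter

ABSTAIN = -1

# Hypothetical labeling programs (stand-ins for LLM-generated ones).
def lf_keywords(text):
    return 1 if any(k in text.lower() for k in ("subscribe", "click here", "http")) else ABSTAIN

def lf_short_praise(text):
    return 0 if len(text.split()) < 8 and "love" in text.lower() else ABSTAIN

def lf_exclaim(text):
    return 1 if text.count("!") > 3 else ABSTAIN

def majority_vote(texts, programs):
    """Combine program outputs per example; all-abstain or a tie gives ABSTAIN."""
    labels = []
    for text in texts:
        votes = [v for v in (p(text) for p in programs) if v != ABSTAIN]
        if not votes:
            labels.append(ABSTAIN)
            continue
        counts = Counter(votes).most_common()
        if len(counts) > 1 and counts[0][1] == counts[1][1]:
            labels.append(ABSTAIN)  # tie between classes
        else:
            labels.append(counts[0][0])
    return labels

comments = ["Click here to subscribe!!!!", "love this song"]
print(majority_vote(comments, [lf_keywords, lf_short_praise, lf_exclaim]))
```

On a tie between classes this sketch abstains; a learned label model would instead weight each program by its estimated accuracy.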
Step 3: Train a Local Model. We can either use the pseudolabels directly or train a small, specialized model (e.g., a fine-tuned BERT model) to generalize the labeling logic. This allows fully local execution, with no further API calls required.
Alchemist isn’t limited to text data. For non-text modalities like images, we introduce an intermediate step:
Concept Extraction: We first prompt the LLM to identify key concepts relevant to the classification task. For example, in a waterbird vs. landbird categorization task, the model may identify “wing shape,” “beak shape,” or “foot type” as distinguishing characteristics.
Feature Representation: A local multimodal model (e.g., CLIP) extracts features corresponding to these concepts. This step generates low-dimensional feature vectors that can be effectively used by the labeling programs.
Program Synthesis: Using the extracted features and their similarity scores, we prompt the LLM to generate labeling programs that automate the annotation process.
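The sketch below illustrates this data flow for non-text inputs, assuming the image and concept-prompt embeddings have already been computed by a model such as CLIP (hand-made vectors stand in for them when testing). The concept prompts, the 0.05 margin, and the `label_bird` program are all hypothetical; they only show the kind of rule an LLM could write once it sees concept names paired with similarity scores.

```python
import numpy as np

# Hypothetical concept prompts; in practice the LLM proposes these in the concept-extraction step.
CONCEPTS = ["webbed feet", "long thin beak", "short perching feet"]

def cosine_scores(image_emb, concept_embs):
    """Cosine similarity between one image embedding and each concept embedding (one per row)."""
    img = image_emb / np.linalg.norm(image_emb)
    con = concept_embs / np.linalg.norm(concept_embs, axis=1, keepdims=True)
    return con @ img

# A hypothetical program of the kind the LLM might synthesize over these scores.
def label_bird(scores, margin=0.05):
    webbed, _beak, perching = scores
    if webbed - perching > margin:
        return 1   # waterbird
    if perching - webbed > margin:
        return 0   # landbird
    return -1      # unsure

concept_embs = np.eye(len(CONCEPTS))          # stand-in for CLIP text embeddings
image_emb = np.array([0.9, 0.1, 0.1])         # stand-in for a CLIP image embedding
print(label_bird(cosine_scores(image_emb, concept_embs)))
```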
We use eight text domain datasets to evaluate Alchemist. We use GPT-3.5 to generate 10 labeling programs for each dataset and compare labeling performance to LLM zero-shot prompting. Here are the results.
Next, we validate the extension of Alchemist to richer modalities. We extract features for the key recognized concepts using CLIP as our local feature extractor, then convert these feature vectors into a set of similarity scores. Armed with these scores, we describe each score and its associated concept in the prompt and ask GPT-4o and Claude 3 for 10 programs. Results show that Alchemist achieves comparable average accuracy while improving robustness to spurious correlations. This is a key strength of Alchemist: using salient concepts as features may help steer models away from spurious shortcuts found in the data, which supports Alchemist’s ability to handle complex modalities.
We propose an alternative approach to costly annotation procedures that require repeated API requests for labels. Our solution introduces a simple notion of prompting programs to serve as annotators. We developed an automated labeling system called Alchemist to embody this idea. Empirically, our results indicate that Alchemist demonstrates comparable or even superior performance compared to language model-based annotation, improving five out of eight datasets with an average enhancement of 12.9%.
🙋🏻 Still prompting ChatGPT for your labels over and over? Try generating labeling programs instead to save on project costs!
Imagine you can steer a language model’s behavior on the fly: no extra training, no rounds of fine-tuning, just on-demand alignment. In our paper, “Alignment, Simplified: Steering LLMs with Self-Generated Preferences”, we show that this isn’t just possible but practical, even in complex scenarios like pluralistic alignment and personalization.
Traditional LLM alignment requires two critical components: (1) collecting large volumes of preference data, and (2) using this data to further optimize pretrained model weights to better follow these preferences. As models continue to scale, these requirements become increasingly prohibitive—creating a bottleneck in the deployment pipeline.
This problem intensifies when facing the growing need to align LLMs to multiple, often conflicting preferences simultaneously (Sorensen et al., 2024), alongside mounting demands for rapid, fine-grained individual user preference adaptation (Salemi et al., 2023).
Our research challenges this status quo: Must we always rely on expensive data collection and lengthy training cycles to achieve effective alignment?
The evidence suggests we don’t. When time and resources are limited—making it impractical to collect large annotated datasets—traditional methods like DPO struggle significantly with few training samples. Our more cost-effective approach, however, consistently outperforms these conventional techniques across multiple benchmarks, as demonstrated below:
These results reveal a clear path forward: alignment doesn’t have to be a resource-intensive bottleneck in your LLM deployment pipeline. Enter AlignEZ—our novel approach that reimagines how models can adapt to preferences without the traditional overhead.
At its core, AlignEZ combines the two most cost-efficient choices of data and algorithm: it uses self-generated preference data, and it cuts compute cost by replacing fine-tuning with embedding editing. This combination is non-trivial for several reasons:
Now that we have hopefully convinced you that this is the way to go, let’s break down how AlignEZ works, in plain English:
Instead of collecting human-labeled preference data, AlignEZ lets the model create its own diverse preference pairs. Diversity is key: it ensures we capture as broad a range of alignment signals as possible. We achieve this through a two-step prompting strategy:
By applying this process across our dataset, we develop a rich preference dataset where each query is paired with multiple responses that reflect various dimensions of “helpful” and “unhelpful” behavior.
Importantly, we recognize that the initial batch of generated data may contain significant noise, often because the model fails to properly follow the conditioned characteristic. As a critical first filtering step, we eliminate preference pairs whose responses are too similar in the embedding space, a property that Razin et al. (2024) have shown increases the likelihood of dispreferred responses.
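A minimal sketch of this filter, assuming each preference pair is represented by one embedding per response and that "too similar" means a cosine similarity above a fixed threshold (0.95 here, chosen arbitrarily). The similarity measure used in AlignEZ, and the one analyzed by Razin et al. (2024), may be defined differently.

```python
import numpy as np

def filter_similar_pairs(help_embs, harm_embs, max_cos=0.95):
    """help_embs, harm_embs: (K, d) arrays, row i holding the two responses of pair i.
    Drops pairs whose helpful and harmful embeddings are nearly identical."""
    h = help_embs / np.linalg.norm(help_embs, axis=1, keepdims=True)
    g = harm_embs / np.linalg.norm(harm_embs, axis=1, keepdims=True)
    cos = np.sum(h * g, axis=1)          # per-pair cosine similarity
    keep = cos < max_cos                 # hypothetical threshold
    return help_embs[keep], harm_embs[keep], keep
```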
With our self-generated preference data in hand, we next identify the alignment subspace within the LLM’s latent representation. Our approach adapts classic techniques from embedding debiasing literature (Bolukbasi et al., 2016) that were originally developed to identify subspaces representing specific word groups.
Formally, let $\Phi_l$ denote the function mapping an input sentence to the embedding space at layer $l$, and denote each preference pair as $(p_i^{help}, p_i^{harm})$. First, we construct embedding matrices for helpful and harmful preferences:
\[\begin{equation} \textbf{H}_{l}^{help} := \begin{bmatrix} \Phi_{l}(p_1^{help}) \\ \vdots \\ \Phi_{l}(p_K^{help}) \end{bmatrix}^T, \quad \textbf{H}_{l}^{harm} := \begin{bmatrix} \Phi_{l}(p_1^{harm}) \\ \vdots \\ \Phi_{l}(p_K^{harm}) \end{bmatrix}^T, \end{equation}\]where $K$ is the total number of preference pairs. Next, alignment subspace is identified by computing the difference between the helpful and harmful embeddings:
\[\begin{equation} \textbf{H}_{l}^{align} := \textbf{H}_{l}^{help} - \textbf{H}_{l}^{harm}. \end{equation}\]We then perform SVD on $\textbf{H}_{l}^{align}$:
\[\begin{equation} \textbf{H}_{l}^{align} = \textbf{U}\Sigma\textbf{V} \\ \Theta_l^{align} := \textbf{V}^T, \end{equation}\]An important trick we add here is to remove subspace directions that are already well-represented in the original LLM embedding. Formally for a query $q$:
\[\begin{equation} \Theta_{l,help}^{align}(q) := \left\{\,\theta \in \Theta_l^{align} \,\middle|\, \cos\left(\Phi_l(q),\theta\right) \leq 0 \right\}, \end{equation}\]This prevents any single direction from dominating the editing process and ensures we only add necessary new directions to the embedding space.
Finally, during inference, when generating a new response, we modify the model’s hidden representations by projecting them in the direction of the alignment subspace $\Theta_l^{align}$. Our editing process is as follows:
\[\begin{aligned} \hat{x}_l &\leftarrow x_l,\\ \text{for each } \theta_l \in \Theta_l^{align}:\quad \hat{x}_l &\leftarrow \hat{x}_l + \alpha\,\sigma\!\bigl(\langle \hat{x}_l, \theta_l \rangle\bigr)\,\theta_l, \end{aligned}\]where $\sigma(\cdot)$ is an activation function and $\langle \cdot,\cdot \rangle$ denotes the inner product. We iteratively adjust $\hat{x}_l$ by moving it toward or away from each direction $\theta_l$ in $\Theta_l^{align}$. We set $\sigma(\cdot)=\tanh(\cdot)$ with $\alpha = 1$, which gives smooth bidirectional scaling bounded in $[-1,1]$.
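The subspace identification and editing steps can be sketched in NumPy as follows. This is an illustrative reimplementation, not the authors' code. It assumes the embeddings are stacked as rows of a $K \times d$ matrix, so that the rows of $V^T$ from the SVD are unit directions in the $d$-dimensional embedding space; the optional cutoff `k` on the number of directions is a hypothetical knob.

```python
import numpy as np

def alignment_directions(help_embs, harm_embs, k=None):
    """help_embs, harm_embs: (K, d) arrays, one row per preference pair.
    Returns unit directions (rows) spanning the difference embeddings."""
    diff = help_embs - harm_embs                      # H^align, stored as (K, d)
    _, _, vt = np.linalg.svd(diff, full_matrices=False)
    return vt if k is None else vt[:k]                # rows of V^T are unit-norm, length d

def select_directions(query_emb, directions):
    """Keep directions with non-positive cosine to the query embedding,
    i.e. directions not already expressed in the query's representation."""
    q = query_emb / np.linalg.norm(query_emb)
    cos = directions @ q                              # sign matches cosine for positive norms
    return directions[cos <= 0]

def edit_hidden(x, directions, alpha=1.0):
    """Sequentially apply x <- x + alpha * tanh(<x, theta>) * theta."""
    x_hat = x.astype(float).copy()
    for theta in directions:
        x_hat = x_hat + alpha * np.tanh(x_hat @ theta) * theta
    return x_hat

rng = np.random.default_rng(0)
dirs = alignment_directions(rng.normal(size=(16, 8)), rng.normal(size=(16, 8)), k=4)
x = rng.normal(size=8)
x_edited = edit_hidden(x, select_directions(x, dirs))
```

Because SVD directions are only defined up to sign, the $\tanh$ term is what lets each direction push the hidden state either way, depending on the sign of $\langle \hat{x}_l, \theta_l \rangle$.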
The core insight behind AlignEZ is that alignment information already exists within the pre-trained model; we just need to find it and amplify it.
Think of it like a radio signal. The alignment “station” is already broadcasting inside the model, but it’s mixed with static. Traditional methods try to boost the signal by retraining the entire radio (expensive!). AlignEZ instead acts like a targeted equalizer that simply turns up the volume on the channels where the alignment signal is strongest.
For a more detailed explanation of our method, with more rigorous mathematical notation, check out our paper!
Our experiments reveal that AlignEZ achieves strong alignment gains with a fraction of the computational resources traditionally required, significantly simplifies the multi-objective/pluralistic alignment process, is compatible with and expedites more expensive alignment algorithms, and shows early promise on mathematical reasoning.
We use the standard automatic alignment evaluation, GPT-as-a-judge (Zheng et al., 2023), and report $\Delta\%$, defined as the win rate (W%) minus the lose rate (L%) against the base models.
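As a concrete reading of this metric, the helper below computes $\Delta\%$ from a list of per-prompt judge verdicts. It assumes ties are counted in the denominator but as neither a win nor a loss; the verdict strings are hypothetical.

```python
def delta_percent(judgments):
    """judgments: verdicts 'win', 'tie', or 'lose' for the steered model vs. the base model."""
    judgments = list(judgments)
    n = len(judgments)
    win_rate = 100.0 * judgments.count("win") / n
    lose_rate = 100.0 * judgments.count("lose") / n
    return win_rate - lose_rate

print(delta_percent(["win", "win", "tie", "lose"]))
```

Here 2 of 4 verdicts are wins (50%) and 1 of 4 is a loss (25%), so $\Delta\% = 25$.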
AlignEZ delivers consistent improvements, achieving positive gains in 87.5% of cases with an average $\Delta\%$ of 7.2%, which is more reliable than the test-time alignment baselines ITI (75%) and CAA (56.3%). Perhaps the most significant advantage? AlignEZ accomplishes all this without ground-truth preference data, which both ITI and CAA require.
Next, we test AlignEZ’s capacity for steering LLMs toward multiple preferences at once. We test for two key abilities: (1) fine-grained control across two preference axes (precise regulation of each axis’s influence), and (2) the ability to align to three preferences simultaneously. Following the setup of Yang et al. (2024), we evaluate on three preference traits: helpfulness, harmlessness, and humor.
For fine-grained control, we modulate the steering between two preference axes by applying weight pairs ($\alpha$, 1 − $\alpha$), where $\alpha$ ranges from 0.1 to 0.9 in increments of 0.1.
The results above show that for uncorrelated preference pairs, (helpful, humor) and (harmless, humor), AlignEZ grants fine-grained control: the rewards closely track the weight pairs ($\alpha$, 1 − $\alpha$). Steering between the correlated pair (helpful, harmless), however, has limited effect. When we try to increase one while decreasing the other, their effects tend to cancel out, resulting in minimal net change in model behavior.
On steering across three preference axes at once, we can see that AlignEZ can simultaneously increase the desired preferences–even outperforming RLHF-ed model prompted to generate these characteristics on the harmless and helpful axes.
Wait, there’s more? Yes! We also show that AlignEZ is compatible with classic, more expensive alignment techniques, and can even give them a significant boost. As shown above, AlignEZ lifts a DPO model trained on only 1% of the data to the performance of one trained on 25% of the data.
Given these results and AlignEZ’s cost-efficient, practical nature, we are excited about extending it to more challenging tasks, ones that require specialized knowledge such as mathematical reasoning and code intelligence. As a first step in this direction, we run a proof-of-concept experiment applying AlignEZ to multi-step mathematical reasoning benchmarks.
Surprisingly, even when starting from a strong reasoning model, vanilla AlignEZ provides improvements! We attribute these gains to the identified subspace, which appears to strengthen the model’s tendency toward step-by-step reasoning while suppressing shortcuts to direct answers.
AlignEZ represents a paradigm shift in how we approach LLM alignment. By leveraging self-generated preference data and targeted embedding editing, we’ve demonstrated that effective alignment doesn’t require massive datasets or expensive fine-tuning cycles. Our approach offers several key advantages:
Resource Efficiency: AlignEZ works at inference time with minimal computational overhead, making it accessible to researchers and developers with limited resources.
Versatility: From single-objective alignment to complex multi-preference scenarios, AlignEZ provides flexible control without sacrificing performance.
Compatibility: As shown in our DPO experiments, AlignEZ can complement existing alignment techniques, accelerating their effectiveness even with limited training data.
No Preference Bottleneck: By generating its own preference data, AlignEZ removes one of the most significant bottlenecks in the alignment pipeline.
Our promising results open several exciting avenues for future research. The most obvious next direction is domain-specific alignment: extending our proof-of-concept reasoning experiments, we plan to investigate how AlignEZ can enhance performance in specialized domains like medical advice, legal reasoning, and scientific research.
📜🔥: Check out our paper! https://arxiv.org/abs/2406.03642
💻 : Code coming soon! Stay tuned for our GitHub repository.
]]>
Weak-to-strong generalization describes a scenario in which a strong model, trained on pseudolabels or outputs provided by a weak model, achieves superior performance compared to the weak model itself. In such settings, the weak model is capable of making predictions on a broad range of data, while the strong model leverages these predictions as a foundation to learn additional, more complex aspects of the data. This phenomenon is critical in contexts such as data-efficient learning and has implications for building more advanced, robust systems.
A key insight from our work is that the potential for weak-to-strong generalization is driven by the overlap density in the data. Overlap density quantifies the proportion of data points that contain two complementary types of informative patterns: one that the weak model can readily capture and another that requires the capacity of a strong model. More formally, consider a dataset where each input \(x\) can be decomposed as
\[x = [x_{\text{easy}}, x_{\text{hard}}],\]
with \(x_{\text{easy}}\) representing features that are easily learned by the weak model and \(x_{\text{hard}}\) representing the more challenging features. Based on this decomposition, data points can be partitioned into three regions: easy-only points (containing only easy features), hard-only points (containing only hard features), and overlap points (containing both).
Overlap density is defined as the proportion of overlap points in the dataset. These overlap points are critical because they serve as a bridge: the weak model can accurately label them using the easy features, while the strong model can leverage these labels to learn the challenging features. This mechanism lays the foundation for a strong model to effectively leverage supervision signals from a weak model.
However, overlap density is not directly observable in real-world data since easy and hard features are not easily defined. To overcome this challenge, we developed an overlap detection algorithm that operates in two main steps:
Confidence-Based Separation:
The algorithm first uses the weak model’s confidence scores to separate data points. Typically, points where the weak model is less confident are likely to lack the patterns it can easily capture. By thresholding these confidence scores (using methods like change-point detection), we identify a subset of points that are likely to be dominated by the challenging patterns.
Overlap Scoring: After the first step, we identify two groups: hard-only points (low confidence) and non-hard-only points (high confidence). Our next goal is to pinpoint the overlap points within the non-hard-only group. We achieve this by defining overlap scores based on the distance to the hard-only points. The intuition is that overlap points are closer to hard-only points because they contain some of the challenging features that easy-only points lack. Thus, among the non-hard-only points, those with small distances to the hard-only points are classified as overlap points.
This procedure yields an estimate of the overlap density in a dataset, which is crucial for understanding and enhancing weak-to-strong generalization.
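The two steps can be sketched as below. This is a simplified illustration under stated assumptions, not the released implementation: a fixed confidence threshold replaces change-point detection, the overlap score is the Euclidean distance to the nearest hard-only point, and non-hard-only points at or below a distance quantile (`dist_quantile`, a hypothetical parameter) are labeled as overlap.

```python
import numpy as np

def estimate_overlap(features, weak_conf, conf_threshold, dist_quantile=0.5):
    """features: (n, d) representations; weak_conf: (n,) weak-model confidences.
    Returns a boolean mask of estimated overlap points and the overlap density."""
    n = len(weak_conf)
    hard_only = weak_conf < conf_threshold            # step 1: low confidence -> hard-only
    rest = np.where(~hard_only)[0]
    hard_idx = np.where(hard_only)[0]
    if len(hard_idx) == 0 or len(rest) == 0:
        return np.zeros(n, dtype=bool), 0.0
    # step 2: overlap score = distance to the nearest hard-only point
    diffs = features[rest][:, None, :] - features[hard_idx][None, :, :]
    nearest = np.linalg.norm(diffs, axis=2).min(axis=1)
    cutoff = np.quantile(nearest, dist_quantile)
    overlap = np.zeros(n, dtype=bool)
    overlap[rest[nearest <= cutoff]] = True           # closest non-hard-only points
    return overlap, float(overlap.mean())
```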
Beyond estimating overlap density within a single dataset, our approach extends to selecting the best data sources from multiple candidates. The idea is to prioritize sources that exhibit a high overlap density, as these are more likely to provide supervisory signals that enable the strong model to learn the challenging aspects of the data.
Our UCB-based (Upper Confidence Bound) data selection algorithm works as follows:
Our experiments on datasets such as Amazon Polarity and DREAM demonstrate that this UCB-based strategy consistently identifies data sources with higher overlap density. When training the strong model with data selected through our algorithm, we observe improved generalization performance compared to using randomly sampled data.
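As a rough illustration of UCB-based source selection, here is a generic UCB1-style sketch: each candidate source is an arm, pulling an arm means sampling a small batch from that source, and the reward is an overlap-density estimate for that batch. The batch size of 8, the exploration constant `c`, and the reward function are assumptions for illustration.

```python
import math
import random

def ucb_select(sources, estimate_overlap_density, rounds, c=1.0, batch_size=8, seed=0):
    """sources: dict name -> list of examples.
    estimate_overlap_density: callable(batch) -> reward in [0, 1].
    Returns per-source mean rewards and pull counts."""
    rng = random.Random(seed)
    names = list(sources)
    counts = {name: 0 for name in names}
    means = {name: 0.0 for name in names}
    for t in range(1, rounds + 1):
        untried = [name for name in names if counts[name] == 0]
        if untried:
            choice = untried[0]                       # pull every arm once first
        else:
            ucb = {name: means[name] + c * math.sqrt(2 * math.log(t) / counts[name])
                   for name in names}
            choice = max(ucb, key=ucb.get)
        pool = sources[choice]
        batch = rng.sample(pool, k=min(batch_size, len(pool)))
        reward = estimate_overlap_density(batch)
        counts[choice] += 1
        means[choice] += (reward - means[choice]) / counts[choice]   # running mean
    return means, counts
```

Sources that repeatedly yield high overlap estimates get pulled more often, while the square-root bonus keeps occasionally revisiting the others.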
Our exploration of weak-to-strong generalization through the lens of overlap density paves the way for continuous, iterative enhancements in model performance. By identifying and leveraging data points that exhibit both easily captured and more challenging patterns, we can significantly boost the effectiveness of weak supervision. This data-focused strategy—prioritizing optimal data sources above all else—may prove crucial for developing truly advanced systems.
Weak-to-strong generalization may hold the key to achieving superintelligence. As AI evolves, the pool of human experts capable of providing meaningful supervision is shrinking. For instance, as mathematicians are increasingly tasked with annotating complex math problems, the scarcity of such expertise becomes a significant bottleneck. In the long run, efficient supervisory signals will be essential, making a deep understanding of weak-to-strong generalization vital.
Moreover, this process mirrors human academia—learning from imperfect past knowledge, generalizing it, and continuously pushing the boundaries forward. We believe our study on weak-to-strong generalization offers a principled pathway for scientific discovery.
To delve further into our theoretical analysis and experimental findings, please refer to our paper at https://arxiv.org/abs/2412.03881 and explore our implementation on GitHub at https://github.com/SprocketLab/datacentric_w2s.
]]>
Zero-shot models are impressive—they can classify images or texts they’ve never seen before. However, these models often inherit biases from their massive pretraining datasets. If a model is predominantly exposed to certain labels during training, it may overpredict those labels when deployed in new tasks. OTTER (Optimal TransporT adaptER) addresses this challenge by correcting label bias at inference time without requiring extra training data.
In our recent work, we introduce OTTER, a lightweight method that rebalances the predictions of a pretrained model to better align with the label distribution of the downstream task. The key insight is to leverage optimal transport—a mathematical framework for matching probability distributions—to adjust the model’s output.
OTTER reinterprets classification as the problem of transporting probability mass from the input space to the label space. In a traditional zero-shot classifier, given a set of \(n\) data points \(\{x_1, x_2, \dots, x_n\}\), the model outputs scores \(s_\theta(x_i, j)\) for each class \(j \in \{1, \dots, K\}\). Typically, we assign each data point the label corresponding to the highest score (i.e. \(\hat{y}_i = \arg\max_{j} s_\theta(x_i, j)\)).
OTTER, however, views these scores as indicating how much “mass” should ideally be transported from each data point \(x_i\) to a class \(j\). We first represent the empirical distribution of inputs as
\[\mu = \frac{1}{n} \sum_{i=1}^{n} \delta_{x_i},\]and we prescribe a target label distribution
\[\nu = (p_1, p_2, \dots, p_K),\]with \(\sum_{j=1}^{K} p_j = 1\). The goal is to reassign the mass from the input points to the classes so that the overall distribution of predicted labels matches \(\nu\).
This is achieved by formulating an optimal transport problem. We define a cost matrix \(C\) where each element is given by
\[C_{ij} = -\log s_\theta(x_i, j).\]This cost function naturally penalizes lower prediction scores, so moving mass to classes with higher scores incurs a lower cost. Then, OTTER solves for a transport plan \(\pi\) via
\[\pi = \arg\min_{\gamma \in \Pi(\mu, \nu)} \langle \gamma, C \rangle,\]where \(\Pi(\mu, \nu)\) denotes the set of all joint distributions whose marginals are \(\mu\) and \(\nu\). In other words, the plan \(\pi\) determines how to reassign the input mass such that exactly \(n \cdot p_j\) points are assigned to class \(j\).
By computing the optimal \(\pi\) and then taking
\(\hat{y}_i = \arg\max_{j} \pi_{ij},\) OTTER produces modified predictions that not only reflect the model’s confidence (through the cost structure) but also enforce the desired label distribution. When the target distribution \(\nu\) is chosen to match the true downstream distribution, this procedure effectively corrects for the bias introduced during pretraining.
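The sketch below approximates this transport problem with entropic regularization (log-domain Sinkhorn iterations) in NumPy. That is a swapped-in solver: it approaches the exact optimal plan as `eps` shrinks, and the OTTER code may instead solve the unregularized problem directly. `eps`, `n_iters`, and the clipping constant are illustrative choices, and `target` is assumed to be strictly positive.

```python
import numpy as np

def _logsumexp(a, axis):
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def otter_predict(scores, target, eps=0.05, n_iters=1000):
    """scores: (n, K) zero-shot class probabilities s_theta(x_i, j).
    target: (K,) label distribution nu. Returns (predictions, transport plan)."""
    n, _ = scores.shape
    log_mu = np.full(n, -np.log(n))                   # uniform mass 1/n per input
    log_nu = np.log(np.asarray(target, dtype=float))
    C = -np.log(np.clip(scores, 1e-12, None))         # C_ij = -log s_theta(x_i, j)
    f, g = np.zeros(n), np.zeros(len(log_nu))
    for _ in range(n_iters):                          # alternate row / column marginal fits
        f = eps * (log_mu - _logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_nu - _logsumexp((f[:, None] - C) / eps, axis=0))
    pi = np.exp((f[:, None] + g[None, :] - C) / eps)  # approximate transport plan
    return pi.argmax(axis=1), pi

scores = np.array([[0.9, 0.1], [0.8, 0.2], [0.6, 0.4], [0.7, 0.3]])
preds, plan = otter_predict(scores, target=[0.5, 0.5])
print(scores.argmax(axis=1), preds)
```

In this toy example the plain argmax predicts class 0 for every point, while the rebalanced plan moves the two points with the weakest preference for class 0 into class 1 to meet the 50/50 target. Because each row of the plan is proportional to \(\exp((g_j - C_{ij})/\varepsilon)\), the prediction equals \(\arg\max_j \left(\log s_\theta(x_i, j) + g_j\right)\), so in this sketch the rebalancing amounts to a per-class offset on the log-scores.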
The theoretical results in the paper show that if the cost matrix were derived from the true posterior (i.e. \(-\log P(Y = j \mid x_i)\)), then the optimal transport solution would recover the Bayes-optimal classifier. Since the true target distribution is typically unknown, OTTER uses an estimated downstream label distribution to rebalance the predictions accordingly.
A key theoretical insight is that under mild conditions, OTTER recovers the Bayes-optimal classifier. Specifically, if the true target probabilities are \(P_t(Y = j \mid X = x_i)\), then OTTER’s predictions \(\hat{y}_i = \arg\max_{j \in [K]} \pi_{ij}\) match the Bayes-optimal decisions
\[f_t(x_i) = \arg\max_{j \in [K]} P_t(Y = j \mid X = x_i).\]Moreover, our analysis provides error bounds using perturbation theory, bounding the sensitivity of the transport plan to deviations in both the cost matrix and the target distribution. This ensures that OTTER is robust in practical settings, even when the label distribution estimate is slightly noisy.
We evaluated OTTER on a diverse set of image and text classification tasks, and our findings reveal several key benefits:
OTTER consistently boosts zero-shot classification accuracy, achieving an average improvement of about 4.8% on image tasks and up to 15.9% on text tasks across a variety of datasets.
OTTER requires a potentially large batch size during prediction to function effectively. Our online variant, R-OTTER, overcomes this challenge by learning reweighting parameters from the model’s own pseudo-labels on a validation set, enabling real-time adjustments in dynamic environments without relying on additional labeled data.
Selection bias in LLMs refers to their tendency to favor certain answer choices in multiple-choice questions (MCQs). OTTER effectively mitigates this bias by providing a simple yet effective mechanism to ensure a more balanced and representative distribution of model outputs.
For practitioners deploying zero-shot models, OTTER offers:
OTTER offers a practical approach to mitigating label bias in zero-shot models, enhancing their reliability and adaptability in real-world applications. Check out our paper: https://arxiv.org/abs/2404.08461 and code on GitHub: https://github.com/SprocketLab/OTTER.
Thank you for reading!
Robustness against these spurious correlations has been widely studied in the literature. Sagawa et al., 2019 drew attention to the discrepancy between model performance on data slices that conform to training data biases and those that break the correlations. A plethora of methods similar to Arjovsky et al., 2019 tackle this at the training stage, which makes them less practical for large pre-trained architectures. Zhang & Ré, 2022 demonstrated that the performance discrepancy between data slices also appears in pre-trained models, and proposed a lightweight adapter training approach. Yang et al., 2023 introduced a spurious-correlation-aware fine-tuning approach. However, some might argue that fine-tuning breaks one promise of large pre-trained models: their capacity to be used out of the box.
In this post we describe RoboShot: an approach to robustify pre-trained models and steer them away from these biases/correlations. What’s more? RoboShot does this without additional data and fine-tuning! The core of our idea is inspired by the embedding debiasing literature, which seeks to remove subspaces that contain predefined harmful or unwanted concepts. However, here we do not seek to produce fully invariant embeddings; our goal is simply to improve pre-trained model robustness at low or zero cost.
Before diving in, let’s first discuss and formulate the zero-shot inference setup. Similar to Dalvi et al., 2022, we think of a pre-trained model’s embedding space as spanning unknown concepts $\lbrace z_1, z_2, \ldots, z_k \rbrace$, and a pre-trained embedding $x$ is a mixture of concepts $\Sigma_i \gamma_i z_i$, where $\gamma_i \geq 0$ are weights.
Now, given $x$, $c^0=\sum_i \beta_{i,0} z_i$ (embedding of the first class), and $c^1=\sum_i \beta_{i,1} z_i$ (embedding of the second class), its zero-shot prediction is made by
\[\hat{y} = \arg\max_{j \in \lbrace 0, 1 \rbrace} x^T c^j. \tag{1}\]
That is, the prediction is the class whose embedding has the higher inner product with the datapoint’s embedding. The above equation describes binary classification, but it is straightforward to extend it to multi-class settings.
RoboShot assumes that the input embedding mixture can be partitioned into three concept groups: harmful, helpful, and benign:
\[x = \sum_{s=1}^S \alpha_s^{\text{harmful}} z_s + \sum_{r=S+1}^{S+R} \alpha_r^{\text{helpful}} z_r + \sum_{b=S+R+1}^{S+R+B} \alpha_b^{\text{benign}} z_b.\]For better illustration, we will start with a working example on a benchmark dataset: Waterbirds. The task is to distinguish $y \in \lbrace \texttt{waterbird}, \texttt{landbird} \rbrace$. The training data contains unwanted correlations between waterbird and water background, and between landbird and land background. For the sake of illustration, let’s assume that in the embedding space, $z_{water} = -z_{land}$ and $z_{waterbird} = -z_{landbird}$, and that $z_{water}$ and $z_{landbird}$ are orthonormal (the arithmetic below relies on this).
Let’s say we have a test image that does not follow the training correlations (e.g., a landbird over water). In the embedding space, this might be $x=0.7z_{\texttt{water}}+ 0.3 z_{\texttt{landbird}}$.
Our class embeddings might be $c^{\texttt{waterbird}}=0.4z_{\texttt{water}}+0.6z_{\texttt{waterbird}}$ and $c^{\texttt{landbird}}=0.4z_{\texttt{land}}+0.6z_{\texttt{landbird}}$.
Our zero-shot prediction is then $x^T c^{\texttt{waterbird}}= 0.1 > x^T c^{\texttt{landbird}}= -0.1$, which gives us the prediction waterbird, and that is incorrect.
In this example, we see how harmful components contained in $x$ cause wrong predictions. Now, let’s see how RoboShot avoids this by reducing harmful components in embeddings and boosting the helpful ones.
Suppose that in $x$ we have a ground-truth harmful component $v^{\texttt{harmful}}$ and a ground-truth helpful component $v^{\texttt{helpful}}$. Note that in reality, we do not have access to $v^{\texttt{harmful}}$ and $v^{\texttt{helpful}}$ (in the next part of this blog post, we will describe a proxy for these ground-truth components). RoboShot reduces $v^{\texttt{harmful}}$’s effect on $x$ by classical vector rejection:
\[\hat{x} = x - \frac{x^T v^{\texttt{harmful}}}{\lVert v^{\texttt{harmful}} \rVert^2} v^{\texttt{harmful}}. \tag{2}\]
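Intuitively, this procedure subtracts the component of $x$ along $v^{\texttt{harmful}}$. Similarly, to increase $v^{\texttt{helpful}}$’s influence, we can add the component of $x$ along $v^{\texttt{helpful}}$ once more:
\[\hat{x} = x + \frac{x^T v^{\texttt{helpful}}}{\lVert v^{\texttt{helpful}} \rVert^2} v^{\texttt{helpful}}. \tag{3}\]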
Let’s try this on our example.
Suppose that we have a single harmful and helpful insight:
$v^{\texttt{harmful}}=0.9z_{\texttt{water}}+0.1z_{\texttt{landbird}} \quad \quad v^{\text{helpful}}=0.1z_{\texttt{water}}+0.9z_{\texttt{landbird}} $
First let’s reduce $v^{\texttt{harmful}}$’s effect by plugging it into equation 2, which results in $\hat{x} = -0.0244z_{\texttt{water}}+0.2195z_{\texttt{landbird}}$
Making a zero-shot prediction with $\hat{x}$, we have $\hat{x}^T c^{\texttt{waterbird}}= -0.1415 < \hat{x}^T c^{\texttt{landbird}}= 0.1415$, which gives us the correct prediction: landbird.
We have seen that removing a single harmful component is enough to flip the prediction to the correct class! Next, let’s see the effect of increasing $v^{\texttt{helpful}}$’s influence by plugging it into equation 3, applied to the $\hat{x}$ we just obtained. This results in
$\hat{x} = -0.0006z_{\texttt{water}}+0.4337z_{\texttt{landbird}}$
This further increases the classification margin, from about 0.283 to about 0.521!

Algorithm 1 details the RoboShot algorithm. In real scenarios, we often have multiple helpful and harmful concepts (e.g., shape of beak, wing size, etc). We can simply do the vector rejection and addition iteratively (lines 2-5 and 6-8, respectively).
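The two projection steps (equations 2 and 3) and the iterative scheme fit in a few lines of NumPy. The sketch below is an illustration under stated assumptions, not the released RoboShot code: it applies all rejections and then all additions, implements only the projection steps described in this post, and encodes the toy Waterbirds example in an assumed orthonormal basis $(z_{water}, z_{landbird})$. The helper names `reject`, `boost`, and `roboshot` are hypothetical.

```python
import numpy as np


def reject(x, v):
    """Equation 2: remove the component of x along v (vector rejection)."""
    return x - (x @ v) / (v @ v) * v


def boost(x, v):
    """Equation 3: add the component of x along v once more."""
    return x + (x @ v) / (v @ v) * v


def roboshot(x, harmful, helpful):
    """Reject every harmful direction first, then boost every helpful one."""
    for v in harmful:
        x = reject(x, v)
    for v in helpful:
        x = boost(x, v)
    return x


# Coordinates in the assumed orthonormal basis (z_water, z_landbird);
# z_land = -z_water and z_waterbird = -z_landbird, as in the example.
x = np.array([0.7, 0.3])
c_waterbird = np.array([0.4, -0.6])  # 0.4 z_water + 0.6 z_waterbird
c_landbird = np.array([-0.4, 0.6])   # 0.4 z_land + 0.6 z_landbird
v_harmful = np.array([0.9, 0.1])
v_helpful = np.array([0.1, 0.9])

for name, emb in [("original", x),
                  ("reject harmful", roboshot(x, [v_harmful], [])),
                  ("reject + boost", roboshot(x, [v_harmful], [v_helpful]))]:
    print(name, emb.round(4),
          "waterbird:", round(float(emb @ c_waterbird), 4),
          "landbird:", round(float(emb @ c_landbird), 4))
```

Running it reproduces the worked example: the original embedding scores higher for waterbird, a single rejection flips the prediction to landbird, and the addition step widens the margin.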
In real scenarios, how do we get access to $v^{\texttt{harmful}}$ and $v^{\texttt{helpful}}$, especially since features are entangled with one another in latent space?
First, let’s take a step back and think of $v^{\texttt{harmful}}$ and $v^{\texttt{helpful}}$ in the context of the task. For instance, in the Waterbirds dataset, the task is to distinguish between landbirds and waterbirds. Ideally, our predictions should depend only on bird features (e.g., beak shape, wing size) and be independent of confounding factors like backgrounds (i.e., land or water). If we can somehow isolate the background and bird components in the embedding space, set them as $v^{\texttt{harmful}}$ and $v^{\texttt{helpful}}$, and plug them into equations 2 and 3, we are golden. In other words, $v^{\texttt{harmful}}$ corresponds to the background features and $v^{\texttt{helpful}}$ to the bird features! This way, we can think of $v^{\texttt{harmful}}$ and $v^{\texttt{helpful}}$ as priors inherent to the task. Now, we have two remaining pieces to tie it all together: (i) how to obtain these insights without training, and (ii) how to translate them into latent space.
In RoboShot, we get the textual descriptions of harmful and helpful concepts by querying language models (LMs) using only the task description. For example, in the Waterbirds dataset, we use the prompt “What are the biased/spurious differences between waterbirds and landbirds?”.
We translate the answers we get into $v^{\texttt{harmful}}$ by using their embeddings. Let $s^1, s^2$ be the text insights obtained from the answer (e.g., {‘water background’, ‘land background’}). We obtain a spurious insight representation by taking the difference of their embeddings:
\[v^{\texttt{harmful}} = g(s^1) - g(s^2),\]
where $g$ is the text encoder of our model.
Similarly, to obtain a proxy for $v^{\texttt{helpful}}$, we ask LMs “What are the true characteristics of waterbirds and landbirds?” and obtain, e.g., {‘short beak’, ‘long beak’}. The remainder of the procedure is identical to the case of harmful components.
That’s it! With this cheap, fine-tuning-free approach, we can now robustify our zero-shot models against unwanted correlations from training data. In the table below, we measure baseline and RoboShot performance in terms of average accuracy across all groups (AVG), worst-group accuracy (WG), and the gap between them (Gap). A model that is less influenced by unwanted correlations has high AVG and WG, and low Gap. We can see that RoboShot improves vision-language model predictions across multiple spurious correlation and distribution shift benchmarks.
On language tasks, RoboShot also lifts the performance of weaker/older LMs to a level comparable to modern LLMs, and surpasses direct prompting of BART-MNLI and ChatGPT on several datasets.
The image below illustrates the effect of rejecting $v^{\texttt{harmful}}$, increasing $v^{\texttt{helpful}}$, and doing both. Rejecting $v^{\texttt{harmful}}$ reduces variance in one direction, while increasing $v^{\texttt{helpful}}$ amplifies variance in the orthogonal direction. When both projections are applied, they create a balanced mixture.
In this post, we have described RoboShot: our approach to robustify pre-trained models against unwanted correlations without any fine-tuning. RoboShot is almost zero-cost: we obtain insights from cheap (or even free) available knowledge resources and use them to improve pre-trained models – defying the usual need to collect extra labels for fine-tuning. RoboShot works on multiple modalities, and opens the way to using textual embeddings to debias image embeddings.
Thank you for reading! 😊 Kindly check our 👩💻 GitHub repo and 📜 paper!
In this post we discuss a simple way to do this based on one of our NeurIPS ‘22 papers. The core principle is a (very general) form of the weak supervision algorithms that we’ve been playing with for several years. For binary outputs, this idea has already been successfully used in our Ask Me Anything prompting strategy. Here, we focus on lifting this to the richer structures needed for CoT and other techniques.

Warning: our discussion will get a bit technical—but we promise it will be fun! In fact we’ll get to connect to a ton of different fields, including graphical models, unsupervised learning, embeddings and non-Euclidean geometry, tensor algorithms, and more!
First, a roadmap for this post. We will
Let’s dive in!
Let’s take the example in the figure above. We are performing a basic email classification task, where we want to categorize each message as spam or not spam. We repeatedly query the model by varying the prompt, obtaining a number of observations for each email.
We’ll refer to each prompting approach as an object source (OS). These sources are just estimates of the ground truth answer for whatever task we’re interested in. What can we do with these? First, let’s collect the outputs. These are arranged in a matrix as shown in the figure below. The instances (examples) are the emails. Of course, the column for the ground truth label $Y$ is just a placeholder since we don’t get to see it.

After observing the outputs of the sources, the goal of aggregation is to estimate the ground truth object—and hopefully more accurately than any single source by itself! A naive but reasonable first-cut way to aggregate is to take the majority vote of the outputs for each point. This approach will work well when the OSs are independent and have similar qualities. However, some OSs could be more accurate and some more noisy. They might also be correlated. This can make majority vote less effective. Imagine, for example, that one source is right 95% of the time, while the others are right only 51% of the time. Clearly aggregation will help, but we’d like to dramatically upweight the accurate source.
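A quick simulation makes the 95% vs. 51% scenario concrete. This sketch assumes binary labels in $\lbrace -1, +1 \rbrace$, sources that are conditionally independent given $Y$ with symmetric noise, and known accuracies (in practice they must be estimated, which is what the rest of this section is about). Under those assumptions, weighting each vote by the log-odds of its source being correct is the naive Bayes decision rule for balanced classes.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
accuracies = np.array([0.95, 0.51, 0.51, 0.51, 0.51])  # one strong source, four weak

y = rng.choice([-1, 1], size=n)
correct = rng.random((n, accuracies.size)) < accuracies  # each source right w.p. its accuracy
votes = np.where(correct, y[:, None], -y[:, None])      # lambda_i in {-1, +1}

majority = np.sign(votes.sum(axis=1))  # five voters, so no ties
weights = np.log(accuracies / (1 - accuracies))  # log-odds that each source is right
weighted = np.sign(votes @ weights)

print("majority vote accuracy:", (majority == y).mean())
print("weighted vote accuracy:", (weighted == y).mean())
```

With these settings the weighted vote stays close to the strong source's 95% accuracy, while majority vote is dragged down (to roughly 68% in expectation) by the four near-random sources.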
How can we model these possibilities? Weak supervision approaches often model the distribution of the unobserved ground truth $Y$ and source outputs $\lambda_1, \ldots \lambda_m$ as a probabilistic graphical model with parameters $\theta$, for example the Ising model:
\[P_{\theta}(\lambda_1,\lambda_2,\ldots \lambda_m,Y) = \frac{1}{Z}\exp \Big( \theta_Y Y + \sum_{i=1}^m \theta_i \lambda_i Y + \sum_{(i,j)\in E} \theta_{ij}\lambda_i \lambda_j \Big)\]What does this do for us? First, we can now think of learning the accuracies and correlations described above as learning the parameters of this model. These are the $\theta$’s, also known as canonical parameters in the PGM literature. Note that unlike conventional learning for graphical models, we have a latent variable problem, as we do not observe $Y$. If we have learned these parameters, we can rely on the estimated model to perform aggregations. The resulting pipeline looks like:

The $\theta$ parameters above encode how accurate each of the OSes is, with a large $\theta_i$ indicating that the $i$th noisy estimate frequently agrees with $Y$, the ground truth. How do we estimate these? We’ll need a few technical pieces from the graphical model literature. It turns out that we need only estimate the mean parameters: terms like $\mathbb{E}[\lambda_i Y]$ and $\mathbb{E}[\lambda_i \lambda_j]$! Note that the correlation terms $\mathbb{E}[\lambda_i \lambda_j]$ do not involve $Y$, so as long as we know the structure (the edge set $E$), they are easy to estimate directly from the observed source outputs.
How about the accuracy parameters, i.e., the correlations between $\lambda_i$ and $Y$? This is challenging, as we don’t get to see any ground truth! There are classical methods like EM (Expectation-Maximization) and variants such as Dawid-Skene that could be applied. However, these approaches are prone to converging to local optima and sometimes perform poorly. FlyingSquid, a simple and elegant approach based on the method of moments, comes to the rescue! The key idea is based on the observation that for any three conditionally independent sources $\lambda_1,\lambda_2,\lambda_3$ with binary labels (values in $\lbrace -1, +1 \rbrace$), the second-order moments can be written as
\[\mathbb{E}[\lambda_1\lambda_2] = \mathbb{E}[\lambda_1 Y]\mathbb{E}[\lambda_2 Y]\] \[\mathbb{E}[\lambda_2\lambda_3] = \mathbb{E}[\lambda_2 Y]\mathbb{E}[\lambda_3 Y]\] \[\mathbb{E}[\lambda_3\lambda_1] = \mathbb{E}[\lambda_3 Y]\mathbb{E}[\lambda_1 Y]\]This system of three equations can be solved directly for $\mathbb{E}[\lambda_i Y]$ without observing $Y$, as follows: \[\lvert\mathbb{E}[\lambda_1 Y]\rvert = \sqrt{\frac{\mathbb{E}[\lambda_1\lambda_2] \mathbb{E}[\lambda_3\lambda_1] }{\mathbb{E}[\lambda_2\lambda_3]}}, \quad \lvert\mathbb{E}[\lambda_2 Y]\rvert = \sqrt{\frac{\mathbb{E}[\lambda_1\lambda_2] \mathbb{E}[\lambda_2\lambda_3] }{\mathbb{E}[\lambda_3\lambda_1]}}, \quad \lvert\mathbb{E}[\lambda_3 Y]\rvert = \sqrt{\frac{\mathbb{E}[\lambda_2\lambda_3] \mathbb{E}[\lambda_3\lambda_1] }{\mathbb{E}[\lambda_1\lambda_2]}}.\]This analytical solution is easy to obtain for the binary classification setting. All that is left is to figure out the signs of the above, in order to break symmetry. As long as our sources are better than random on average, this can be done.
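The triplet identities are easy to check numerically. In this sketch the assumptions are ours: binary labels in $\lbrace -1, +1 \rbrace$ and three sources that are conditionally independent given $Y$ with symmetric noise. The estimator only sees pairwise products of source outputs, yet it recovers $\mathbb{E}[\lambda_i Y] = 2\,P(\lambda_i = Y) - 1$. We resolve the sign by taking positive roots, i.e., by assuming every source is better than random.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 400_000
true_acc = np.array([0.9, 0.8, 0.7])  # P(lambda_i = Y); hidden from the estimator

y = rng.choice([-1, 1], size=n)
correct = rng.random((n, 3)) < true_acc
lam = np.where(correct, y[:, None], -y[:, None])  # three conditionally independent sources

# Observable second moments E[lambda_a lambda_b]; y is not used below this line.
m12 = np.mean(lam[:, 0] * lam[:, 1])
m23 = np.mean(lam[:, 1] * lam[:, 2])
m31 = np.mean(lam[:, 2] * lam[:, 0])

# Triplet solution, taking positive roots (sources assumed better than random).
est = np.sqrt([m12 * m31 / m23, m12 * m23 / m31, m23 * m31 / m12])

print("estimated E[lambda_i Y]:", est.round(3))
print("true      E[lambda_i Y]:", 2 * true_acc - 1)
print("estimated accuracies:   ", ((1 + est) / 2).round(3))
```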
What does knowing these accuracies buy us? It turns out that we can use them to do weighted aggregation, or, more concretely given our model, to compute a posterior probability \(P_{\hat{\theta}}(Y \vert \lambda_1, \ldots, \lambda_m)\).
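For the binary Ising model above with $Y, \lambda_i \in \lbrace -1, +1 \rbrace$, this posterior has a closed form. Normalizing over the two values of $Y$ cancels every term that does not involve $Y$ (including the $\theta_{ij}\lambda_i\lambda_j$ correlation terms), leaving $P(Y = 1 \mid \lambda) = \sigma\big(2(\theta_Y + \sum_i \theta_i \lambda_i)\big)$, where $\sigma$ is the logistic function. With no correlation edges, the same model gives $P(\lambda_i = Y \mid Y) = \sigma(2\theta_i)$, which is how a large $\theta_i$ translates into a reliable source. The sketch below, with hypothetical function names and made-up parameter values, computes the closed form and checks it against direct enumeration over $Y$.

```python
import numpy as np


def posterior_positive(votes, theta, theta_y=0.0):
    """P(Y = +1 | lambda) in closed form for the binary Ising model."""
    s = theta_y + np.dot(theta, votes)
    return 1.0 / (1.0 + np.exp(-2.0 * s))


def posterior_by_enumeration(votes, theta, theta_y=0.0, edges=None):
    """Same quantity by normalizing exp(energy) over Y in {-1, +1}.

    `edges` maps (i, j) -> theta_ij; these terms do not involve Y and cancel.
    """
    edges = edges or {}
    pair_term = sum(t * votes[i] * votes[j] for (i, j), t in edges.items())
    unnorm = {y: np.exp(theta_y * y + y * np.dot(theta, votes) + pair_term)
              for y in (-1, 1)}
    return unnorm[1] / (unnorm[1] + unnorm[-1])


theta = np.array([2.0, 0.1, 0.1])  # made-up values: source 1 is far more reliable
votes = np.array([1, -1, -1])      # the strong source disagrees with both weak ones
print(posterior_positive(votes, theta))
print(posterior_by_enumeration(votes, theta, edges={(1, 2): 0.5}))
```

Because the correlation terms cancel here, they matter when estimating the $\theta_i$, not in this final aggregation step.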
This basic idea can also be extended easily to multi-class settings by solving multiple one vs. rest binary classification problems. This method has nice theoretical guarantees and works well for classification settings especially when the number of classes is small—and when the model has special kinds of symmetry. More details about FlyingSquid can be found in the blog post and paper. Try it!
As we saw, the main challenge in WS is to estimate the accuracies $\theta_i$ of the object sources without having access to the ground truth object. While approaches like FlyingSquid are simple and efficient, they make some strong assumptions. If we want to handle outputs that have high-cardinality or special structure (e.g. parse trees, rankings, math expressions etc.), we may need a more powerful tool. Tensor decompositions are a great candidate for this—having already been used for learning many kinds of mixtures. Before we proceed, let’s see how we can adapt this class of algorithms to our aggregation setting.
We’ll start with some quick background on classical multi-view mixture model learning. Our first task is to understand whether it is suitable for aggregating more complicated foundation model objects. As a first step, we ask whether it works on par with existing methods in simple settings like binary classification. If so, does it directly scale up to more challenging objects, such as those that take on many possible values?
We’ll see that tensor methods are competitive for simple cases, but that this approach doesn’t scale well when the objects live in higher-cardinality spaces with structure. To make it possible to use tensor decomposition approaches for such scenarios, we’ll have to make some careful adjustments.
We can think of source outputs as observations from a multi-view mixture model, i.e., each source $\lambda_a$ is a view of the true object $Y$. In a multi-view mixture model, multiple views \(\{\lambda_{a}\}_{a=1}^m\) of a latent variable $Y$ are observed, and these views are independent when conditioned on $Y$: for all $a \neq b$, $\lambda_a$ and $\lambda_b$ are conditionally independent given $Y = y$. This mixture model is depicted as a graphical model in the figure below.
Now, suppose we have a cardinality $k$ problem (the true object $Y$ takes $k$ values). We use one-hot vector representations of the objects (denoted in boldface). Let \(\mathbb{E}[{\boldsymbol{\lambda}}_a|Y=y] = {\boldsymbol{\mu}}_{a,y}\) denote the mean of \(\boldsymbol{\lambda}_a\) conditioned on the true object $y$ (for all $a$ and $y$). Then it is easy to see the following for the tensor product (third-order moment) of any three conditionally independent ${\boldsymbol{\lambda}}_a,{\boldsymbol{\lambda}}_b,{\boldsymbol{\lambda}}_c$:
\[{\bf{T}} = \mathbb{E}_{\lambda_a,\lambda_b,\lambda_c,y}[{\boldsymbol{\lambda}}_a \otimes {\boldsymbol{\lambda}}_b \otimes {\boldsymbol{\lambda}}_c] = \sum_{y\in[k]} w_y {\boldsymbol{\mu}}_{a,y} \otimes {\boldsymbol{\mu}}_{b,y} \otimes {\boldsymbol{\mu}}_{c,y}\]i.e. $\bf{T}$ can be written as a sum of $k$ rank-1 tensors. Here $w_y \in [0,1]$ are the prior probabilities of label $Y=y$. Note that we do not know the true distribution of $\lambda,y$. Instead we have $n$ i.i.d. observations \(\{ {\boldsymbol{\lambda}}_{a,i}\}_{a\in[m],i\in[n]}\). Using these we can produce an empirical estimate of $\bf{T}$:
\[\hat{\bf{T}} =\hat{\mathbb{E}}[{\boldsymbol{\lambda}}_a \otimes {\boldsymbol{\lambda}}_b \otimes {\boldsymbol{\lambda}}_c] = \frac{1}{n}\sum_{i\in[n]} {\boldsymbol{\lambda}}_{a,i} \otimes {\boldsymbol{\lambda}}_{b,i} \otimes {\boldsymbol{\lambda}}_{c,i}\]Suppose \(\tilde{\bf{T}} = \sum_{y\in[k]} \hat{w}_y \hat{\boldsymbol{\mu}}_{a,y}\otimes \hat{\boldsymbol{\mu}}_{b,y} \otimes\hat{\boldsymbol{\mu}}_{c,y}\) is a rank-$k$ factorization of the empirical tensor $\hat{\bf{T}}$. If $\hat{\bf{T}}$ is a good approximation of the true tensor ${\bf{T}}$ and $\tilde{\bf{T}}$ is a good approximation of $\hat{\bf{T}}$, then \(\hat{\boldsymbol{\mu}}_{a,y}\) is a good approximation of the true mean parameters ${\boldsymbol{\mu}}_{a,y}$ (up to a permutation of the labels $y$). This idea is developed in the fantastic Anandkumar et al. 2012, 2014 and lots of follow-up work.
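Here is a small end-to-end sketch of the moment estimate and a rank-$k$ factorization for one-hot sources. Two caveats: the sources use a symmetric-noise model we chose for illustration, and the factorization uses Jennrich's simultaneous-diagonalization algorithm rather than the tensor power method of Anandkumar et al., simply because it is short. Jennrich's algorithm assumes square, invertible factor matrices (true here) and recovers the means $\boldsymbol{\mu}_{a,y}$ of the first view up to a permutation of $y$; the other views follow by permuting tensor modes. All function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
k = 3
w = np.array([0.5, 0.3, 0.2])  # prior P(Y = y)


def confusion(acc, k):
    """Columns are mu_{., y} = E[one-hot(lambda) | Y = y] for a symmetric-noise source."""
    off = (1.0 - acc) / (k - 1)
    return np.full((k, k), off) + (acc - off) * np.eye(k)


A, B, C = confusion(0.8, k), confusion(0.6, k), confusion(0.5, k)

# Population third moment: T[i, j, l] = sum_y w_y A[i, y] B[j, y] C[l, y]
T = np.einsum("y,iy,jy,ly->ijl", w, A, B, C)

# Empirical estimate T_hat from n samples of one-hot source outputs.
n = 100_000
y = rng.choice(k, size=n, p=w)


def sample_onehot(M):
    """Draw each sample's output from column M[:, y_i] and one-hot encode it."""
    cdf = np.cumsum(M[:, y], axis=0)  # shape (k, n)
    idx = np.minimum((rng.random(n) > cdf).sum(axis=0), k - 1)
    return np.eye(k)[idx]  # shape (n, k)


La, Lb, Lc = sample_onehot(A), sample_onehot(B), sample_onehot(C)
T_hat = np.einsum("ni,nj,nl->ijl", La, Lb, Lc) / n


def jennrich(T, rng):
    """Recover first-mode factors (columns of A) by simultaneous diagonalization."""
    x, z = rng.standard_normal(T.shape[2]), rng.standard_normal(T.shape[2])
    Mx = np.einsum("ijl,l->ij", T, x)  # equals A diag(w * C^T x) B^T
    Mz = np.einsum("ijl,l->ij", T, z)
    _, vecs = np.linalg.eig(Mx @ np.linalg.inv(Mz))  # eigenvectors: columns of A
    vecs = np.real(vecs)
    return vecs / vecs.sum(axis=0, keepdims=True)  # conditional means sum to 1


def factor_error(A_true, A_hat):
    """Worst L1 distance from a true column to its closest recovered column."""
    return max(np.abs(A_hat - A_true[:, [j]]).sum(axis=0).min()
               for j in range(A_true.shape[1]))


print("max |T_hat - T|:", np.abs(T_hat - T).max())
print("error on population T:", factor_error(A, jennrich(T, rng)))
print("error on empirical T_hat:", factor_error(A, jennrich(T_hat, rng)))
```

On the population tensor the recovery is exact up to floating-point error; on the empirical tensor the error shrinks as $n$ grows, though this plain version can be sensitive to noise.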
Using the estimates $\hat{\boldsymbol{\mu}}_{a,y}$ we obtain estimates of our canonical $\theta$ parameters, and so we’ll have the accuracies, just as with FlyingSquid or other weak supervision methods. We’ll call this procedure the tensor aggregation model.
The big question: how well does this work? We run a simple experiment on simulated sources to show that this method is competitive. For this, we simulate three object sources outputting multiclass values with $\theta=[4,0.5,0.5]$. We run tensor aggregation on the 1-hot encodings of the outputs and compare the accuracy of the aggregated object against FlyingSquid and majority vote baselines. The results are shown in the figure below (averaged over 100 trials). Tensor aggregation offers competitive performance, but because 1-hot encodings lead to high dimensionality, its performance also degrades as we increase the cardinality of the object space.

Note that we used the simplest one-versus-all approach to multiclass FlyingSquid. There are much more powerful variants that would likely outperform it (as is the case for binary), but for simplicity, we don’t include them here.
Overall, the tensor method is encouraging and we’re motivated to apply it beyond simple classification settings. How do we scale up to such settings?
As we alluded to, many foundation model applications will require aggregating items more diverse than just a multiclass label. Even more generally, we’ll often want to aggregate a huge range of object types. We’ve thought about how to do this with semantic dependency parse trees, classes of objects having hierarchical structure, continuous or manifold-valued objects for regressions, and more. We can often think of the spaces these objects live in as metric spaces, since they have natural distance functions. Here we’ll discuss the finite metric space case, but we have lots of ideas about how to extend it to infinite cardinality spaces. Our approach consists of two high-level steps:
As we shall see, both of these steps are critical. We show a full pipeline below.

Now that our objects of interest live in metric spaces, our new goal is to aggregate them into something close to the ground truth. For example, suppose the distance metric is $d$. We’d like to again aggregate $\lambda_a, \lambda_b, \lambda_c$. Ideally we’d like to get an output $\hat{y}$ so that $\mathbb{E}[d(\hat{y}, y)]$ is small. Once again, we’d need to account for accuracies—which are now average distances like $\mathbb{E}[d(\lambda_a, y)]$.
Working directly with discrete metric spaces is challenging—we can’t use our favorite off-the-shelf optimization approaches. To make life easy we’ll do the usual: work with low-dimensional vector space representations. If we can do this, we’ll be set: we’ll get away with using tensor aggregation without needing to scale it up to high dimensions, where we could get hurt by noise, as we saw earlier.
The key is to have these low-dimensional representations preserve distances, since otherwise we can’t hope to perform a reasonable aggregation. That is, if our embeddings of objects distort these distances, our aggregation might end up with an output $\hat{y}$ that is far from $y$, i.e., a large $\mathbb{E}[d(\hat{y}, y)]$. Learning faithful embeddings has been a very active area of research for several decades. Here we are particularly interested in learning isometric (perfectly distance-preserving) embeddings.
In general, such isometric embeddings might not exist in the conventional case of vector space embeddings. Instead, we use Pseudo-Euclidean Embeddings (PSE). These are a generalization of classical Multi-Dimensional Scaling (MDS). The main benefit of PSE over MDS is that it can isometrically embed metric spaces that are not isometrically embeddable in Euclidean space. The main drawback, as we shall see, is that pseudo-Euclidean spaces are weird!
We’ll discuss PSE more technically below, but first let’s understand its utility. As an example, take our metric spaces to be graphs, where the distance is the smallest number of hops between nodes. Two examples of graphs are shown below. We learn their node embeddings using MDS and PSE. MDS gives low-dimensional representations but cannot produce isometric embeddings for general metric spaces. Note that MDS (blue line) never reaches zero distortion, but with just three-dimensional embeddings, PSE does! For a more complex graph, the tree to the right, we see the same effect. Adding dimensions helps MDS a bit, but it still fails to produce isometric embeddings, while PSE succeeds again (red line drops to $10^{-14}$).

How do these pseudo-Euclidean spaces work? Basically, their metrics are no longer induced by p.s.d. inner products, so distinct points can have distance 0. This behavior is often challenging to deal with geometrically, but for our purposes, it works fine.
Let’s see some technical details: a vector ${\bf{u}}$ in a pseudo-Euclidean space $\mathbb{R}^{d^+,d^-}$ has two parts: ${\bf{u}}^+ \in \mathbb{R}^{d^+}$ and ${\bf{u}}^- \in \mathbb{R}^{d^-}$. The dot product and the squared distance between any two vectors ${\bf{u}},{\bf{v}}$ are $\langle {\bf{u}}, {\bf{v}}\rangle_{\phi} = \langle {\bf{u}}^+,{\bf{v}}^+ \rangle - \langle {\bf{u}}^-,{\bf{v}}^- \rangle$ and $d^2_{\phi}({\bf{u}},{\bf{v}}) = \lVert{\bf{u}}^{+}-{\bf{v}}^{+}\rVert_2^2 - \lVert {\bf{u}}^{-}-{\bf{v}}^{-}\rVert_2^2$. These properties enable isometric embeddings: the distance can be decomposed into two components that are individually induced from p.s.d. inner products—and can thus be embedded via MDS. Indeed, pseudo-Euclidean embeddings effectively run MDS for each component. To recover the original distance, we obtain $\lVert {\bf{u}}^{+}-{\bf{v}}^{+}\rVert_2^2$ and $\lVert{\bf{u}}^{-}-{\bf{v}}^{-}\rVert_2^2$ and subtract. More details on these embeddings can be found in a classic treatise.
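The recipe above (double-center the squared distances, then keep the positive and negative eigenvalues as two separate MDS-style components) can be sketched in NumPy. The graph is a tree with three branches of two nodes each, chosen by us for illustration. Its hop metric is not isometrically embeddable in Euclidean space, because the central node would have to be the Euclidean midpoint of every pair of its three neighbors, which is impossible. So classical MDS, which keeps only the positive part, cannot be exact, while the pseudo-Euclidean embedding is. Function names are hypothetical.

```python
import numpy as np


def shortest_path_distances(n, edges):
    """Hop-count metric of an unweighted graph (Floyd-Warshall)."""
    D = np.full((n, n), np.inf)
    np.fill_diagonal(D, 0.0)
    for u, v in edges:
        D[u, v] = D[v, u] = 1.0
    for m in range(n):
        D = np.minimum(D, D[:, [m]] + D[[m], :])
    return D


def pseudo_euclidean_embedding(D, tol=1e-9):
    """Split the double-centered Gram matrix into positive and negative parts."""
    n = D.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n
    G = -0.5 * J @ (D ** 2) @ J  # classical MDS Gram matrix, possibly indefinite
    vals, vecs = np.linalg.eigh(G)
    pos, neg = vals > tol, vals < -tol
    return vecs[:, pos] * np.sqrt(vals[pos]), vecs[:, neg] * np.sqrt(-vals[neg])


def sq_dists(X):
    """Pairwise squared Euclidean distances between the rows of X."""
    diff = X[:, None, :] - X[None, :, :]
    return (diff ** 2).sum(axis=-1)


# A tree with three branches of two nodes each, hanging off node 0.
edges = [(0, 1), (1, 2), (0, 3), (3, 4), (0, 5), (5, 6)]
D = shortest_path_distances(7, edges)
U_pos, U_neg = pseudo_euclidean_embedding(D)

pse = sq_dists(U_pos) - sq_dists(U_neg)  # d^2_phi = ||u+ - v+||^2 - ||u- - v-||^2
mds = sq_dists(U_pos)                    # positive part only, as classical MDS does
print("dims (d+, d-):", U_pos.shape[1], U_neg.shape[1])
print("PSE max error:", np.abs(pse - D ** 2).max())
print("MDS max error:", np.abs(mds - D ** 2).max())
```

Here "MDS" keeps every positive eigenvalue, so its error comes only from the discarded negative directions, not from truncating to a few dimensions.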
Armed with the powerful PSE technique, we first obtain isometric representations of the objects in a PSE space and then solve the parameter estimation problem using tensor decomposition. The original tensor decomposition algorithm was designed for Euclidean vectors, so we cannot apply it off the shelf to PSE points. We overcome this issue by using the fact that the two parts of any vector in PSE individually live in Euclidean spaces $\mathbb{R}^{d^+},\mathbb{R}^{d^-}$. This allows us to treat the positive and negative components \({\boldsymbol{\lambda}}_{a}^+ \in \mathbb{R}^{d^+}\) and \({\boldsymbol{\lambda}}_{a}^{-} \in \mathbb{R}^{d^-}\) of our pseudo-Euclidean embedding as separate multi-view mixtures. We apply tensor decomposition to them separately, which gives us mean parameters \(\hat{\boldsymbol{\mu}}^+_{a,y}\) and \(\hat{\boldsymbol{\mu}}^-_{a,y}\) for each $a,y$. Using these, we obtain our estimates of the canonical parameters \(\hat{\bf{\theta}}\).
With this adaptation, we retain the nice theoretical guarantees of tensor decomposition for parameter recovery while working with any finite metric space. We can also see the benefit of our approach in a simple synthetic experiment on the tree metric we saw earlier. In this experiment, we simulate three sources on a tree metric with three branches, each with $b$ nodes. We use $\theta=[4,0.5,0.5]$, i.e., the first source is highly accurate and the other two are somewhat noisy. We run two variations of our method: one with PSE embeddings and the other with 1-hot embeddings of the labels. We keep the number of samples fixed at $n=1000$ and vary the number of nodes $b$ to increase the cardinality of the label space. The results can be seen in the figure below.
As expected, with PSE embeddings we achieve much better accuracy of the aggregated objects, and unlike the other methods, this performance does not degrade with higher cardinality, since this metric space is isometrically embeddable in a 3-dimensional PSE space.
This aggregation approach is quite general and can be applied in any setting where we can obtain multiple noisy observations of a ground truth object living in a discrete metric space.
We’ll show off its potential in a toy CoT example. We consider in-context learning for language models. The in-context examples typically consist of paired input and output data, which effectively guide LLMs in comprehending the task at hand and generating accurate predictions. Recent advances in this area have shed light on the effectiveness of prompts that incorporate explicit steps, known as Chain of Thought (CoT). These step-by-step instructions help LLMs make precise predictions while providing detailed reasoning steps. Building upon this concept, more nuanced variations such as Tree of Thoughts (ToT) and Graph of Thoughts (GoT) have emerged. These expanded frameworks have demonstrated impressive efficacy when tackling complex reasoning problems with LLMs.

While highly effective, these methods require access to high-quality explanations, which can be a bottleneck for their broad applicability. Nevertheless, one can always come up with many self-obtained, heuristic, or otherwise inexpensive sources that provide potentially noisy reasoning steps. How can we use these to construct accurate chains, trees, or graphs of thoughts?

Indeed, we can use our aggregation approach. As an illustration, we consider the Game of 24, a complex reasoning puzzle with 4 numbers from 1 to 13 as input. The expected output is an expression using the given numbers and basic arithmetic operations (+, -, x, /) so that the expression evaluates to 24. Note that this task can be easily solved by enumerating all possible expressions and selecting the ones that evaluate to 24. However, we are interested in solving this task using LLMs by providing them some in-context examples. Here the CoT steps could be an expression broken down into multiple steps. We use the same 1362 puzzles as in the Tree of Thoughts paper and simulate 3 sources with different noise levels ($\theta= [5,0.6,0.5]$) that provide noisy expressions (CoTs). We then apply our aggregation procedure (i.e., PSE + tensor decomposition) to recover the true expressions and evaluate the recovered expressions for correctness. We run this procedure 10 times with different random seeds and report the mean accuracies in the above bar chart. We can clearly see that our method based on tensor decompositions outperforms majority vote.
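As noted above, the Game of 24 is easy to solve by exhaustive search, which also makes it easy to score recovered expressions. The sketch below is such a brute-force solver, with hypothetical names. It uses exact arithmetic via `fractions.Fraction` so that intermediate divisions like 8/3 do not introduce rounding. It is a reference checker, not the LLM-based aggregation pipeline.

```python
from fractions import Fraction
from itertools import combinations


def solve24(numbers, target=24):
    """Return one expression over `numbers` that evaluates to `target`, or None."""
    items = [(Fraction(x), str(x)) for x in numbers]
    return _search(items, Fraction(target))


def _search(items, target):
    if len(items) == 1:
        value, expr = items[0]
        return expr if value == target else None
    # Pick any two sub-expressions, combine them with one operation, recurse.
    for i, j in combinations(range(len(items)), 2):
        (a, ea), (b, eb) = items[i], items[j]
        rest = [items[m] for m in range(len(items)) if m not in (i, j)]
        options = [(a + b, f"({ea} + {eb})"), (a * b, f"({ea} * {eb})"),
                   (a - b, f"({ea} - {eb})"), (b - a, f"({eb} - {ea})")]
        if b != 0:
            options.append((a / b, f"({ea} / {eb})"))
        if a != 0:
            options.append((b / a, f"({eb} / {ea})"))
        for value, expr in options:
            found = _search(rest + [(value, expr)], target)
            if found is not None:
                return found
    return None


print(solve24([4, 7, 8, 8]))  # one valid expression, e.g. built from (7 - 8/8) * 4
print(solve24([1, 1, 1, 1]))  # None: no expression reaches 24
```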
Although obtained on a small-scale toy problem, these findings are quite exciting and demonstrate the potential of weak supervision for aggregating foundation model objects, such as in CoT, ToT, GoT, or other forms of reasoning.
We hope you enjoyed our post! Please check out our paper, and our code!
It’s no secret that large-scale supervised machine learning is expensive and that one of the biggest challenges is obtaining the labeled data required for training machine learning models. Weak Supervision (WS) is a popular and quite successful technique for reducing this need for labeled data. WS relies on access to noisy, heuristic functions that produce reasonable label guesses; these are called labeling functions, or LFs for short. Given a handful of these LFs, WS attempts to learn the relationships between the LFs and the true but unobserved label; the component that does this is called the Label Model. While WS is fairly easy to apply to text data, it’s harder to apply to data with more complex features. Automated Weak Supervision (AutoWS) solves this problem by instead learning the LFs using a small amount of labeled data. The beauty of all of this is that WS and AutoWS can be combined with other ways of dealing with a lack of labeled data, like zero-shot learning with foundation models, or self-supervised learning. In this blog post, we will shed some light on AutoWS and explain the motivation behind our AutoWS-Bench-101 benchmark, the first-ever benchmark for AutoWS!

Let’s step through a quick example of WS on movie review data. Here, the goal is to classify Rotten Tomatoes reviews as either “Fresh (+)” or “Rotten (-)”. Suppose we start off with three LFs and, for simplicity, use majority voting as our Label Model:
Now let’s apply these LFs to the following review of the 2019 movie Cats:
"At best, it’s an ambitious misfire. At worst, it’s straight-up nightmare fuel that will haunt generations. Enter into the world of the Jellicles at your own peril."
Most people would probably assign this review the label “Rotten,” though since we’re doing WS, let’s check to see if our LFs agree… LF1 doesn’t vote because the word “amazing” does not appear in the text, LF2 votes “Rotten,” and for the sake of argument, suppose that LF3 also votes “Rotten.” Since we’re aggregating these LF outputs using majority vote, WS correctly labels this review as “Rotten.”
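To make the walkthrough concrete, here is a minimal sketch of keyword LFs plus majority vote. Only LF1’s keyword (“amazing”) comes from the post, and we assume it votes Fresh when the word appears; the rules for `lf2` and `lf3` are hypothetical stand-ins chosen so that they vote Rotten on this review, as in the example. Ties and all-abstain cases are treated as abstentions.

```python
from collections import Counter

FRESH, ROTTEN, ABSTAIN = "Fresh", "Rotten", None


def lf1(text):
    # From the post: abstains unless "amazing" appears (assumed to vote Fresh then).
    return FRESH if "amazing" in text.lower() else ABSTAIN


def lf2(text):
    # Hypothetical keyword rule standing in for the post's LF2.
    return ROTTEN if "worst" in text.lower() else ABSTAIN


def lf3(text):
    # Hypothetical keyword rule standing in for the post's LF3.
    return ROTTEN if "nightmare" in text.lower() else ABSTAIN


def majority_vote(votes):
    """Label Model stand-in: most common non-abstain vote; ties or no votes abstain."""
    counts = Counter(v for v in votes if v is not ABSTAIN).most_common()
    if not counts or (len(counts) > 1 and counts[0][1] == counts[1][1]):
        return ABSTAIN
    return counts[0][0]


review = ("At best, it's an ambitious misfire. At worst, it's straight-up nightmare "
          "fuel that will haunt generations. Enter into the world of the Jellicles "
          "at your own peril.")
votes = [lf(review) for lf in (lf1, lf2, lf3)]
print(votes, "->", majority_vote(votes))  # [None, 'Rotten', 'Rotten'] -> Rotten
```

A learned Label Model would replace `majority_vote` by weighting each LF according to its estimated accuracy instead of counting every vote equally.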
The purpose of this example was twofold: first, to illustrate WS if you aren’t familiar with it, and second, to show that it’s easy to write LFs for text data! The “features” that come with text (i.e., words) are intuitive for humans to reason about, which makes it easier to come up with fairly general rules for text tasks.
But what about data with more complex features, such as images? To our knowledge, traditional WS hasn’t even been applied to MNIST, because writing LFs from scratch for raw pixel data is simply not practical.

Fortunately, dealing with more complex features only requires a few extra steps. The general idea is to learn the LFs from a small set of labeled examples instead of writing them by hand. This technique is called Automated Weak Supervision, or AutoWS, and the pipeline illustrating this process is shown in the above diagram. What makes traditional weak supervision difficult for data like images is that they are typically represented as tensors of pixel values, and it is challenging for a human to write explicit pixel-level LFs to classify them. Of course, this is true of other data types as well, including data from PDEs (which are often used for physics simulations), medical data, and featurized tabular data.
The first step in most AutoWS techniques is to obtain a more useful representation of the complex data. This is typically done with some form of dimensionality reduction, using either a classical technique such as PCA or an embedding obtained from a modern foundation model.
Next, AutoWS techniques often use a small set of existing labeled examples to train simple models on this feature representation. These are called weak learners, and they are used in place of the hand-designed LFs of traditional WS.
Finally, we proceed with the rest of the original WS pipeline by learning the parameters of the Label Model and generating training data! Except now, we can generate training data for much more complex and diverse domains, including a large variety of previously challenging scientific domains.
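The three steps above can be sketched in a toy, NumPy-only form. Every concrete choice here is an assumption for illustration, not one of the benchmarked methods (this is not Snuba, Interactive Weak Supervision, or GOGGLES): PCA via SVD as the representation, one threshold “stump” per principal component as the weak learners (abstaining near their threshold, and discarded if their labeled accuracy is below a hypothetical `min_acc`), and plain majority vote standing in for a learned Label Model. The names `autows_labels`, `min_acc`, and `margin_q` are hypothetical.

```python
import numpy as np

ABSTAIN = -1


def pca_features(X, k):
    """Step 1: project rows of X onto their top-k principal components (via SVD)."""
    Xc = X - X.mean(axis=0)
    _, _, vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ vt[:k].T


def fit_stump(z, y):
    """Step 2: fit a 1-D threshold weak learner on feature z for labels y in {0, 1}.

    Returns (threshold, polarity, training accuracy)."""
    best = (0.0, 1, -1.0)
    for t in np.unique(z):
        for polarity in (1, -1):
            acc = np.mean((polarity * (z - t) > 0).astype(int) == y)
            if acc > best[2]:
                best = (t, polarity, acc)
    return best


def apply_stump(z, threshold, polarity, margin):
    """Use a stump as an LF: vote 0/1, but abstain within `margin` of the threshold."""
    votes = (polarity * (z - threshold) > 0).astype(int)
    votes[np.abs(z - threshold) < margin] = ABSTAIN
    return votes


def majority_vote(L):
    """Step 3 stand-in for the Label Model: row-wise majority; ties abstain."""
    pos, neg = (L == 1).sum(axis=1), (L == 0).sum(axis=1)
    out = np.full(L.shape[0], ABSTAIN)
    out[pos > neg] = 1
    out[neg > pos] = 0
    return out


def autows_labels(X_lab, y_lab, X_unlab, k=3, min_acc=0.8, margin_q=0.2):
    """Label X_unlab using weak learners trained on a small labeled set."""
    Z = pca_features(np.vstack([X_lab, X_unlab]), k)
    Z_lab, Z_unlab = Z[: len(X_lab)], Z[len(X_lab):]
    columns = []
    for j in range(k):
        t, p, acc = fit_stump(Z_lab[:, j], y_lab)
        if acc < min_acc:  # drop weak learners that look no better than noise
            continue
        margin = np.quantile(np.abs(Z_lab[:, j] - t), margin_q)
        columns.append(apply_stump(Z_unlab[:, j], t, p, margin))
    if not columns:
        return np.full(len(X_unlab), ABSTAIN)
    return majority_vote(np.column_stack(columns))
```

The abstention margin trades coverage for precision: points near a learner’s decision boundary get no vote, which mirrors how hand-written LFs abstain when their rule does not fire.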
Armed with the two key takeaways from our deep dive into AutoWS methods,
With AutoWS-Bench-101, we benchmark AutoWS methods using only 100 initial labeled examples; this gives our benchmark its name, as our goal is to generate the 101st label onward! We do so by applying the two previously mentioned takeaways: we evaluate the cross product of a set of feature representation methods and a set of AutoWS methods, on a diverse set of applications.
In particular, we tried a wide range of feature representation techniques of varying complexity: raw features, PCA, an ImageNet-trained ResNet-18, and features from CLIP, a modern foundation model. We plug each of these into a handful of AutoWS techniques, including Snuba, Interactive Weak Supervision, and GOGGLES.
Our benchmark comprises three main categories of datasets:
In the first category, we include MNIST, CIFAR-10, a spherically-projected version of MNIST, and MNIST with permuted pixels. Next, for backward compatibility with WRENCH, a benchmark for WS, we include three of the NLP datasets from their benchmark: YouTube, Yelp, and IMDb. Finally, we include three datasets from diverse application domains, where we think that AutoWS is quite promising: electrocardiograms (ECG), classifying the turbulence of a PDE (Navier-Stokes), and malware detection (EMBER).
The standard of evaluation for AutoWS-Bench-101 relies on performance profile curves, which are a holistic way to evaluate different methods across various settings or “environments.”
We won’t go into too many details of how these are computed here, and instead, I’ll refer you to a nice blog post by Ben Recht on the topic.
The key idea of performance profiles is that higher curves are better: a high curve means that a method is the best, or at least close to the best, on most tasks. The curves can also express situations in which a method is dramatically worse than the best method.
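As a rough illustration of how such curves can be computed, here is a sketch of the classic ratio-based definition due to Dolan and Moré, adapted to higher-is-better scores such as accuracy. This is an assumption on our part: AutoWS-Bench-101 may normalize performance differently, and the toy scores below are made up, not results from the benchmark.

```python
import numpy as np


def performance_profile(scores, taus):
    """Ratio-based performance profiles for higher-is-better scores.

    scores: (n_tasks, n_methods) array of strictly positive scores (e.g. accuracy).
    taus:   1-D array of tolerance factors >= 1.
    Returns rho with shape (n_methods, len(taus)), where rho[m, i] is the fraction
    of tasks on which method m is within a factor taus[i] of the best method.
    """
    scores = np.asarray(scores, dtype=float)
    if np.any(scores <= 0):
        raise ValueError("ratio-based profiles need strictly positive scores")
    ratios = scores.max(axis=1, keepdims=True) / scores  # 1.0 means best on that task
    taus = np.asarray(taus, dtype=float)
    return (ratios[:, :, None] <= taus[None, None, :]).mean(axis=0)


# Toy scores for 3 tasks (rows) and 2 methods (columns); not results from the paper.
scores = [[0.8, 0.4],
          [0.5, 0.5],
          [0.25, 0.75]]
rho = performance_profile(scores, taus=[1.0, 2.0, 3.0])
print(rho)
```

Here method 0 is best on two of the three tasks (rho = 2/3 at tau = 1) but only reaches 1 at tau = 3, because it is 3× worse than the best method on the last task. A long flat stretch like this is how the curves reveal a method that is dramatically worse somewhere, even if it wins elsewhere.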
Using performance profiles, we were able to see several interesting trends across our three categories of datasets:
Using our benchmark, we also came away with these other key findings:
For more details about these findings, and our ablation studies of the various AutoWS methods that we tried, check out our paper! 😃
And if you arrived at this page by scanning our QR code at NeurIPS, and you made it all the way here, here’s a cookie. 🍪
We’re excited to add more functionality and methods to AutoWS-Bench-101! But beyond this benchmark and WRENCH, what is left to do? I mentioned before that WS and AutoWS can be combined with zero-shot learning with foundation models and self-supervised learning… But how do we find out which methods, or combination of methods, are actually the most useful for different types of tasks?
I am personally excited about the idea of leveraging community involvement to answer these big questions. As an organizer of the AutoML Decathlon competition at NeurIPS 2022, I’m excited about one idea in particular: running a Weak Supervision coopetition, that is, a cooperative competition. The idea is to solicit LFs from the community for a set of diverse tasks with mostly unobserved labels, and to cooperatively solve the challenge of programmatically labeling large datasets via a Kaggle-like interface. I like this idea because it is goal-driven: the community must find a way to label these datasets by any means necessary, and everyone can help one another by contributing to a shared GitHub repository. I think that this could be similar to something like Google’s BIG-bench, with the promise of publishing a paper with many authors, and the eternal glory of having contributed to a large-scale (possibly registered + peer-reviewed) supervision experiment.
Whether the next steps for WS benchmarking end up being related to this idea of a coopetition or something entirely different, I’m super excited to see where we go next with reducing the need for labeled data!
Nicholas Roberts