Mechanistic interpretability and LLM steering both require being able to read and modify activations during inference. For instance, to apply steering vectors to control model generation, we first need to collect hidden activations to find a steering direction, and then intervene by modifying the model's hidden activations during inference.
To read and patch activations from an LLM, you first need to find the layers you care about and either add hooks to them or wrap them. This tends to lead to one of two approaches: either 1. writing a custom model wrapper for every model you might want to work with (the approach taken by RepE and CAA), or 2. leaving it to the user to manually specify the names of the layers to patch, and applying the patch using Pytorch hooks (the approach taken by Baukit). The first approach is a never-ending battle as new models are released, and the second approach, while very flexible, passes the complexity on to anyone using what you’ve written.
In this post, I’ll discuss a third option: auto-detecting the types of layers in a Pytorch LLM and reading/patching them using Pytorch hooks. This is the approach used by the steering-vectors library. It leverages the fact that all transformer LMs have the same basic structure: a series of layers containing attention and MLP blocks. This post assumes the model is from Huggingface, although the same technique will likely work with any sanely constructed transformer LM. This post will use the terms “transformer LM” and “LLM” interchangeably to refer to a decoder-only generative language model like GPT or LLaMa.
Finding the component parts of any Pytorch module is easy: calling named_modules() on the model yields a (name, submodule) pair for every submodule, which we can collect into a dictionary with dict(). This is demonstrated for GPT2-small below:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("gpt2")
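# named_modules() yields (name, module) pairs; the first name is "" (the root model itself)
for name, _module in model.named_modules():
    print(name)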
# transformer
# transformer.wte
# transformer.wpe
# transformer.drop
# transformer.h
# transformer.h.0
# transformer.h.0.ln_1
# transformer.h.0.attn
# transformer.h.0.attn.c_attn
# transformer.h.0.attn.c_proj
# transformer.h.0.attn.attn_dropout
# transformer.h.0.attn.resid_dropout
# transformer.h.0.ln_2
# transformer.h.0.mlp
# transformer.h.0.mlp.c_fc
# transformer.h.0.mlp.c_proj
# transformer.h.0.mlp.act
# transformer.h.0.mlp.dropout
# ...
# transformer.h.11
# transformer.h.11.ln_1
# transformer.h.11.attn
# transformer.h.11.attn.c_attn
# transformer.h.11.attn.c_proj
# transformer.h.11.attn.attn_dropout
# transformer.h.11.attn.resid_dropout
# transformer.h.11.ln_2
# transformer.h.11.mlp
# transformer.h.11.mlp.c_fc
# transformer.h.11.mlp.c_proj
# transformer.h.11.mlp.act
# transformer.h.11.mlp.dropout
# transformer.ln_f
# lm_head
Here, it’s clear that the 12 decoder block layers of the model are of the form transformer.h.{num}, the attention layers are transformer.h.{num}.attn, and the MLP layers are transformer.h.{num}.mlp. It’s similarly easy to see the input and output layer norms and dropout.
For LLaMa, the layers are of the form model.layers.{num} for each decoder block, model.layers.{num}.self_attn for attention, and model.layers.{num}.mlp for the MLP layers. For Pythia, the decoder block, attention and MLP layers are of the form gpt_neox.layers.{num}, gpt_neox.layers.{num}.attention, and gpt_neox.layers.{num}.mlp, respectively.
This hints at a simple rule for finding the relevant layer names in any transformer LM: look for the shortest template string of the form *.{num}* that also contains any other terms you might care about (in practice, preferring templates that match the most layers and breaking ties by length). For instance, for attention layers, looking for the shortest template that contains either “attn” or “attention” should cover nearly all LLMs. Likewise, looking for the shortest template with “mlp” should find the MLP layers in nearly all cases. We can implement this in code as follows:
import re
from collections import defaultdict
from typing import Callable

# look for layers of the form "*.{num}*"
LAYER_GUESS_RE = r"^([^\d]+)\.([\d]+)(.*)$"

def guess_matcher_from_layers(
    model, filter: Callable[[str], bool] | None = None
) -> str | None:
    counts_by_guess: dict[str, int] = defaultdict(int)
    for layer in dict(model.named_modules()).keys():
        if re.match(LAYER_GUESS_RE, layer):
            # e.g. "transformer.h.3.attn" -> "transformer.h.{num}.attn"
            guess = re.sub(LAYER_GUESS_RE, r"\1.{num}\3", layer)
            if filter is None or filter(guess):
                counts_by_guess[guess] += 1
    if len(counts_by_guess) == 0:
        return None
    # score is higher for guesses that match more often, and are shorter in length
    guess_scores = [
        (guess, count + 1 / len(guess)) for guess, count in counts_by_guess.items()
    ]
    return max(guess_scores, key=lambda x: x[1])[0]
Then we can find a layer matcher template for the base decoder block, attention, and MLP layers of a model as shown below:
model = AutoModelForCausalLM.from_pretrained("gpt2")
guess_matcher_from_layers(model)
# "transformer.h.{num}"
guess_matcher_from_layers(model, lambda l: "attn" in l or "attention" in l)
# "transformer.h.{num}.attn"
guess_matcher_from_layers(model, lambda l: "mlp" in l)
# "transformer.h.{num}.mlp"
This code will also successfully guess the corresponding layer templates for LLaMa, Pythia, and nearly any other transformer LM.
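To sanity-check that claim without downloading any weights, here is a minimal sketch that feeds the same heuristic a stand-in object whose named_modules() yields hand-written module names. The FakeModel class and the two name lists are hypothetical: they only approximate a subset of the LLaMa and Pythia naming schemes described above, which works because the function needs nothing beyond named_modules(). The heuristic is repeated so the snippet runs on its own.

```python
import re
from collections import defaultdict
from typing import Callable, Iterator, Optional

LAYER_GUESS_RE = r"^([^\d]+)\.([\d]+)(.*)$"


def guess_matcher_from_layers(
    model, filter: Optional[Callable[[str], bool]] = None
) -> Optional[str]:
    # same heuristic as above: most frequent template wins, shorter breaks ties
    counts_by_guess: dict[str, int] = defaultdict(int)
    for layer in dict(model.named_modules()).keys():
        if re.match(LAYER_GUESS_RE, layer):
            guess = re.sub(LAYER_GUESS_RE, r"\1.{num}\3", layer)
            if filter is None or filter(guess):
                counts_by_guess[guess] += 1
    if not counts_by_guess:
        return None
    return max(counts_by_guess, key=lambda g: counts_by_guess[g] + 1 / len(g))


class FakeModel:
    """Hypothetical stand-in that only implements named_modules()."""

    def __init__(self, names: list[str]):
        self._names = names

    def named_modules(self) -> Iterator[tuple[str, None]]:
        # like PyTorch, start with the root module's empty name
        for name in [""] + self._names:
            yield name, None


def llama_like_names(n_layers: int = 4) -> list[str]:
    # approximate subset of LLaMa-style module names
    names = ["model", "model.embed_tokens", "model.layers"]
    for i in range(n_layers):
        prefix = f"model.layers.{i}"
        names += [
            prefix,
            f"{prefix}.self_attn",
            f"{prefix}.self_attn.q_proj",
            f"{prefix}.self_attn.o_proj",
            f"{prefix}.mlp",
            f"{prefix}.mlp.down_proj",
            f"{prefix}.input_layernorm",
            f"{prefix}.post_attention_layernorm",
        ]
    return names + ["model.norm", "lm_head"]


def pythia_like_names(n_layers: int = 4) -> list[str]:
    # approximate subset of Pythia (GPT-NeoX) style module names
    names = ["gpt_neox", "gpt_neox.embed_in", "gpt_neox.layers"]
    for i in range(n_layers):
        prefix = f"gpt_neox.layers.{i}"
        names += [
            prefix,
            f"{prefix}.attention",
            f"{prefix}.attention.dense",
            f"{prefix}.mlp",
            f"{prefix}.mlp.dense_h_to_4h",
        ]
    return names + ["gpt_neox.final_layer_norm", "embed_out"]


def is_attn(name: str) -> bool:
    return "attn" in name or "attention" in name


def is_mlp(name: str) -> bool:
    return "mlp" in name


for names in (llama_like_names(), pythia_like_names()):
    model = FakeModel(names)
    print(
        guess_matcher_from_layers(model),
        guess_matcher_from_layers(model, is_attn),
        guess_matcher_from_layers(model, is_mlp),
    )
```

Note that for the LLaMa-style names, "post_attention_layernorm" also passes the attention filter, but "self_attn" still wins because it matches just as many layers and is shorter.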
Extracting layers using a layer template
Now that we have a layer template string for each type of layer we care about, we just need a way to specify a layer number and get back the corresponding submodule to patch. Fortunately, the named_modules() method of Pytorch modules already gives us everything we need. First, let’s start by finding all the numbered layers in the model that match a given template string:
def collect_matching_layers(model, layer_matcher: str) -> list[str]:
    all_layer_names = set(dict(model.named_modules()).keys())
    matching_layers = []
    # try layer 0, 1, 2, ... until a numbered layer is missing
    for layer_num in range(len(all_layer_names)):
        layer_name = layer_matcher.format(num=layer_num)
        if layer_name in all_layer_names:
            matching_layers.append(layer_name)
        else:
            break
    return matching_layers
If we run this function on GPT2 with the decoder block layer matcher (transformer.h.{num}), we’ll get back an ordered list of all matching layers: transformer.h.0, transformer.h.1, etc…
Once we have this list, it’s trivial to select any layer number from it, and again, use named_modules() to get back the actual Pytorch module corresponding to that layer:
model = AutoModelForCausalLM.from_pretrained("gpt2")
layer_matcher = guess_matcher_from_layers(model) # "transformer.h.{num}"
modules_by_name = dict(model.named_modules())
layer_names = collect_matching_layers(model, layer_matcher)
# layer 2
layer2 = modules_by_name[layer_names[2]]
# layer 7
layer7 = modules_by_name[layer_names[7]]
Add hooks and profit
We now have a way to automatically find and extract all the relevant layers from a Pytorch LLM. The next step is to add Pytorch hooks to read or modify activations.
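# add a hook to layer2 and layer7 from above
def do_something_cool(module, args, output):
    # save or modify the layer output here;
    # returning a new value from the hook replaces the layer's output
    ...

handles = []
for layer in [layer2, layer7]:
    handles.append(layer.register_forward_hook(do_something_cool))

# later, remove the hooks to restore the original model behavior
for handle in handles:
    handle.remove()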
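As a slightly more concrete illustration of what do_something_cool might do, here is a minimal sketch of two forward hooks: one that records a layer's hidden states, and one that adds a steering vector to them. It assumes, as is common for Huggingface decoder blocks, that a layer's output is either a tensor or a tuple whose first element is the hidden states. The ToyBlock and ToyModel classes are hypothetical stand-ins so the example runs without downloading a model, and this is not the steering-vectors library's own implementation.

```python
from typing import Callable

import torch
from torch import nn


def make_recording_hook(store: list) -> Callable:
    """Forward hook that saves a detached copy of a layer's hidden states."""

    def hook(module: nn.Module, args: tuple, output):
        hidden = output[0] if isinstance(output, tuple) else output
        store.append(hidden.detach().clone())
        # returning None leaves the layer output unchanged

    return hook


def make_steering_hook(vector: torch.Tensor, scale: float = 1.0) -> Callable:
    """Forward hook that adds scale * vector to a layer's hidden states."""

    def hook(module: nn.Module, args: tuple, output):
        if isinstance(output, tuple):
            # keep any extra outputs (e.g. attention weights or caches) as-is
            return (output[0] + scale * vector,) + tuple(output[1:])
        # returning a value from a forward hook replaces the layer's output
        return output + scale * vector

    return hook


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a decoder block that returns a tuple."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> tuple:
        return (self.proj(x),)


class ToyModel(nn.Module):
    def __init__(self, dim: int = 4, n_layers: int = 3):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)[0]
        return x


torch.manual_seed(0)
model = ToyModel()
x = torch.randn(2, 4)

recorded: list = []
handles = [
    # hooks run in registration order, so this one sees the unsteered output
    model.layers[1].register_forward_hook(make_recording_hook(recorded)),
    model.layers[1].register_forward_hook(make_steering_hook(torch.ones(4), 2.0)),
]
with torch.no_grad():
    steered_out = model(x)
print(steered_out.shape)

# remove the hooks when done so later forward passes are unaffected
for handle in handles:
    handle.remove()
```

With a real Huggingface model, the layers passed to register_forward_hook would be the modules found via the layer template, e.g. layer2 and layer7 above.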
… and that’s all there is to it! To see this in action, check out layer_matching.py in the steering_vectors library.
nltk.download) which I cannot edit easily. And even if I could, it’s unsettling to disable SSL verification, since that opens you up to potential man-in-the-middle attacks. The errors would look something like below:
urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1002)
I didn’t have any luck following most of what I found on Stack Overflow to solve this issue, but eventually stumbled on a solution combining ideas from Red Hat’s guide to Python cert errors and a Stack Overflow answer. Specifically, I needed to install certifi certs via pip install certifi, but this was not enough. I then needed to set an environment variable called SSL_CERT_FILE to the location of the certs installed via certifi. I don’t know why Python wasn’t using these certs automatically as it should have been, but this solved the issue for me.
The full steps I took are as follows:
pip install certifi
Next, in Python, find the certifi install location by running:
from requests.utils import DEFAULT_CA_BUNDLE_PATH
print(DEFAULT_CA_BUNDLE_PATH)
# /path/to/python/site-packages/certifi/cacert.pem
Note the path to cacert.pem printed above, and add the following to your .bashrc (or .bash_profile, .zshrc, etc., depending on your system):
export SSL_CERT_FILE=/path/to/python/site-packages/certifi/cacert.pem
Of course, make sure you use the actual path to cacert.pem on your system in the line above.
Next, restart the terminal and hopefully everything should work!
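To confirm the environment variable is actually being picked up after restarting the terminal, a small check like the one below can help. It is a sketch using the standard-library ssl module: ssl.get_default_verify_paths() reports the environment variable OpenSSL consults for a CA file (normally SSL_CERT_FILE) and, in its cafile field, the CA file that will be used by default if that file exists. certifi is imported optionally, only to show where its bundle lives; the helper name describe_cert_config is made up for this example.

```python
import os
import ssl


def describe_cert_config() -> dict:
    """Summarize which CA bundle Python's ssl module will use by default."""
    paths = ssl.get_default_verify_paths()
    try:
        import certifi  # optional; only used to show certifi's bundle location

        certifi_bundle = certifi.where()
    except ImportError:
        certifi_bundle = None
    return {
        "env_var_name": paths.openssl_cafile_env,  # normally "SSL_CERT_FILE"
        "env_var_value": os.environ.get(paths.openssl_cafile_env),
        "effective_cafile": paths.cafile,  # None if no CA file is found
        "certifi_bundle": certifi_bundle,
    }


if __name__ == "__main__":
    for key, value in describe_cert_config().items():
        print(f"{key}: {value}")
```

If the fix worked, effective_cafile and certifi_bundle should print the same path.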
Cake with melted plastic lego pieces, delicious. Generated by Midjourney
I recently started a PhD in Computer Science after spending the past 10 years working as a software engineer. One of the biggest shocks to me in this transition (aside from how incompetent I am as a researcher) has been the appalling state of code that accompanies published research papers. Usually when I complain about academic code, people think I’m just talking about code quality being poor (which it is), but it’s much deeper than that. The way code is open-sourced in most academic papers is typically completely broken, and shows a deep misunderstanding about what code is for, and what open-source is about.
Imagine you invite some friends to your apartment, and one of them brings a cake they baked. When you try to eat the cake, you find that it has melted plastic lego pieces in it. Shocked, you point this out to your friend, who just replies, “Oh, I’m not good at cooking.” You then realize your friend has a misunderstanding at a fundamental level about what cooking is for, and what food even is. The code that accompanies research papers is like that cake - it fails at the most basic thing code is meant to do, which is to run and be usable.
In this post, we’ll go over what I view as the problem, and share tips for academics on how to do a better job of open-sourcing their code. I’ll be focusing on Python, as that’s mainly what’s used in AI research, but a lot of this will apply to other languages as well. First and foremost, we’ll focus on the big picture of making the code fit for human consumption, and then we’ll go over how to improve the taste, aka code quality.
The current state of academic code is both a travesty and a huge missed opportunity. In the few cases I’ve seen where research code is properly packaged and easy to use, both the repo and the corresponding paper get massive numbers of citations and wide usage. In this article, I hope to show that doing a good job of open-sourcing research code is worth the effort, and that it’s not difficult.
Most code I encounter that’s released as part of academic papers is completely broken, in the sense that it’s not possible to run the code as provided at all. This is typically due to things like missing files the researcher forgot to upload, hardcoded file paths to stuff on the researcher’s own machine, missing documentation, or unpinned Python dependency versions. This shows that the researcher never even tried running the code they open-sourced, and instead just copied and pasted some files from their local hard drive into a Github repo, linked the repo in their paper, and high-fived everyone for a job well done. There is no notion that other people are going to want to actually try to run that code, and that by uploading broken code and advertising it in a paper you are directly wasting thousands of hours of other people’s time.
Academics are not bad people, and I don’t believe they’re intentionally being malicious. Instead, I think the mindset of most researchers towards open-source code is the following:
The problem with the academic mindset to open-sourcing code above is that it misses the core thing that code is for, which is for actually running and accomplishing a task.
As a researcher, success means having your work widely cited and used by other researchers. One of the most direct ways to accomplish that is for other researchers to use your code in their work. If you make your code easy to use and package it properly, other researchers will use it and then cite your papers. If your code is completely broken or not usable due to not being packaged properly, nobody will use it. Other researchers want to use your code too - it’s a win-win for everyone if research code is open-sourced properly.
It should go without saying, but it’s not OK to publish code that’s completely broken. Your paper is a giant advertisement for your code, and people who read your paper will naturally go to the repo you link and try running what they find there. If the code is broken, you are collectively wasting thousands of hours of other people’s time. Typically, when research code is broken, I find it’s for one of the following reasons:
After ensuring that the code is actually working, the next most important thing is packaging it properly so others can use it in their work. Ideally, your goal should be to release a library which does the thing in your paper, not a pile of random Python files.
Python libraries should be packaged and released on PyPI
If your code is just a bunch of Python scripts in a Github repo, it’s nearly impossible for other people to use that code in their work. What are they supposed to do, copy and paste files from your repo onto their hard drive? Are they supposed to open up the files and copy/paste individual chunks of Python code out? Nobody is going to do that. Fortunately, there’s a well-established way to import code into a Python project which makes it easy for your code to be used by others, and that’s for the code to be packaged as a library on PyPI. This lets it be installed with pip install <your-library-name>.
The idea of releasing a library might sound daunting, but it’s really easy once you get used to it. The difference is a code organization question more than anything else, plus putting some basic thought into “what would someone want to do with this code?”. Once you’ve learned to package code into a library, you’ll see that doing a decent job of packaging your code is far easier than learning LaTeX, or writing a paper, or finding a research idea to begin with. We’ll discuss how to make this easy in the section on Poetry later in the article.
The git repo for your code should have the following components:
pip install your-awesome-library
from your_awesome_library import do_awesome_thing
result = do_awesome_thing(input)
… and that’s basically it. If you do this, your code should be easy for others to use, and you’re already doing better than 95% of the open-source code released by researchers.
By all means, do include code to reproduce the experiments in your paper, as reproducibility of results is important. However, recognize that the majority of users of your code won’t want to reproduce your results, so it shouldn’t be the main focus. It’s fine to include an experiments folder in your git repo for reproducing your results that’s not published to PyPI, or even to split the experiments into a separate git repo from the reusable library code so the library can evolve separately. If you take the approach of splitting the repos, then the experiments repo can import the library as a normal pip dependency, which also has the bonus of verifying that your library works when installed as a dependency in other projects. This leads naturally to the next point:
If your paper has several distinct components, each of which could be used independently as its own library, there’s nothing wrong with releasing multiple open-source repos or libraries along with your paper. Your goal should be to make your code useful to others, and you might find that it’s more natural to release 2 or even 3 different libraries so the various parts of your paper can be used independently rather than trying to fit everything into 1 library. There’s no rule that says every paper must correspond to 1 and only 1 git repo. If splitting into separate libraries makes it easier for others to use, go for it!
The best examples of researchers releasing their code are also some of the best known projects in the field. I don’t believe this is a coincidence - if you package your code properly and release it on PyPI, then others will use it in their own projects and cite your paper. Two excellent examples that come to mind are the following:
FlashAttention is a beautiful illustration of how you don’t need to overthink this to do a good job. This repo has a simple README with installation of the library via pip install flash-attn and basic instructions on how to use it in Python. There’s a benchmarks folder in the repo to reproduce the results in the paper, but it’s not the main focus of the library. The library itself is simple and focused. A+
Sentence Transformers goes above and beyond, with a documentation website, and it continues to evolve and add pretrained models to the library. The library corresponds to an original paper by the author on a technique for sentence similarity, but the code was packaged well and focused on ease of use, and the author has clearly put a lot of care into it.
Both of these libraries were created by individual researchers along with their papers, by a PhD student in the case of FlashAttention, and a postdoc in the case of Sentence Transformers. In both these cases, the authors could have just copy/pasted a collection of unusable Python scripts into a Github repo and left it at that, but then likely neither of their papers would have achieved anywhere near the level of success they both have seen. I believe the level of polish that these libraries show is very achievable for all academics, and should be the norm rather than the exception.
Personally, I like using Poetry for managing Python projects. Poetry handles a lot of the complexity of virtual environments for Python, dependency management, and finally, publishing your library to PyPI so it can be installed with pip install <your-library>. Poetry isn’t the only way to do this, but it provides a good foundation.
Let’s assume we’re the authors of a paper about dog image classification using a technique called “DogBert”. We could start by making a new Poetry project:
poetry new dogbert
This will give us the following file structure:
dogbert
├── README.md
├── dogbert
│ └── __init__.py
├── pyproject.toml
└── tests
    └── __init__.py
It may seem confusing that there are 2 nested folders, both named dogbert, but this is a standard setup for Python projects. The inner dogbert folder containing __init__.py is where all our library Python files will go. If you write tests (and you should!), those go in the tests folder.
We cd into the outer dogbert folder and run poetry install to create a new virtual environment for our project and make sure any needed dependencies are installed.
We can add any pip dependencies our project needs with poetry add <dependency>. Finally, when we want to publish our library on PyPI so it can be installed with pip install dogbert, we just run the following 2 commands:
poetry build
poetry publish
And that’s it, our library is on PyPI! That’s really all it takes to package and publish a library to PyPI.
If you’re used to just writing standalone Python scripts in a single file and running them with python my_file.py, Poetry might seem strange at first. If we have the file utils.py at dogbert/utils.py with a function called preprocess(), and we have another file which wants to import that preprocess function, we can import it like below:
from dogbert.utils import preprocess
Poetry creates its own virtual environment so different projects have independent sets of installed Python packages. This is great, but it means that instead of directly running python, you need to prefix all commands on the CLI with poetry run so the correct environment is used. Also, for scripts, it’s best to run them using Python’s module flag, -m. Where you may be used to directly running a script file with python path/to/script.py, when using Poetry you’d instead run poetry run python -m path.to.script. If we had a script called train.py at dogbert/scripts/train.py, we could run that with poetry run python -m dogbert.scripts.train.
This may take some getting used-to initially, but it’s a minor workflow change which quickly becomes second-nature.
I’ve had issues in the past with adding PyTorch as a dependency from Poetry, since PyTorch has multiple versions with different CUDA requirements, which Poetry doesn’t handle well. I find it’s best to simply leave it to the end-user of your library to install PyTorch, and not try to force it via poetry add torch, since it’s easy to end up with a non-CUDA version of PyTorch that way. Oftentimes, if you’re relying on libraries like pytorch-lightning or other popular machine learning libraries, they’ll already handle making sure PyTorch is installed. If you want to include torch as a dependency, I’d recommend adding it as a dev dependency with poetry add --group dev torch, so you won’t accidentally force a CPU-only version of PyTorch on end-users of your library. Hopefully this will be handled better in future versions of Poetry/PyTorch!
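If you do leave PyTorch out of your declared dependencies, one option is to fail early with a clear message when it's missing, rather than with a bare ImportError deep inside your library. Below is a minimal sketch using the standard importlib module; the helper name require_module and the wording of the hint are hypothetical.

```python
import importlib
from types import ModuleType


def require_module(name: str, install_hint: str) -> ModuleType:
    """Import `name`, or raise an ImportError that tells the user how to install it."""
    try:
        return importlib.import_module(name)
    except ImportError as err:
        raise ImportError(
            f"{name} is required but not installed. {install_hint}"
        ) from err


# e.g. at the top of a module in your library that needs PyTorch:
# torch = require_module(
#     "torch", "See https://pytorch.org/get-started/locally/ for install instructions."
# )
```

This keeps the choice of CUDA or CPU build with the user, while still telling them exactly what to do when the import fails.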
Whatever users of your code are most likely to want to do should “just work” out of the box if possible. For instance, in the case of our DogBert image classification paper example, the most likely thing a user would want to do with our code is to classify images with our pretrained model. We should make this use-case as painless as possible. For instance, we can upload our pretrained model to the Huggingface Model Hub and then have our code automatically download and use that pretrained model if the user doesn’t specify a different model to use. If you want to upload your pretrained model somewhere else, that’s fine, just make sure your library can auto-download it by default so your code can “just work”.
from dogbert import DogbertModel
# default case, auto-download our pretrained model from Huggingface
model = DogbertModel()
# allow the user to specify their own model if they want
model = DogbertModel("/path/to/model")
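One way to implement that default is sketched below. It assumes the huggingface_hub package and its hf_hub_download function, which downloads a single file from a Hub repo (caching it locally) and returns the local path. The repo id and checkpoint filename are placeholders, and actual weight loading is left out. huggingface_hub is imported lazily, so it's only needed when the default model is used.

```python
from pathlib import Path
from typing import Optional, Union

# hypothetical Hub location of the pretrained DogBert checkpoint
DEFAULT_REPO_ID = "your-hf-username/dogbert"
DEFAULT_FILENAME = "model.pt"


def resolve_checkpoint(model_path: Optional[Union[str, Path]] = None) -> Path:
    """Return a local checkpoint path, downloading the default model if none is given."""
    if model_path is not None:
        path = Path(model_path)
        if not path.exists():
            raise FileNotFoundError(f"No model found at {path}")
        return path
    # lazy import: only needed when falling back to the pretrained default
    from huggingface_hub import hf_hub_download

    return Path(hf_hub_download(repo_id=DEFAULT_REPO_ID, filename=DEFAULT_FILENAME))


class DogbertModel:
    def __init__(self, model_path: Optional[Union[str, Path]] = None):
        self.checkpoint_path = resolve_checkpoint(model_path)
        # load the weights from self.checkpoint_path here (omitted in this sketch)
```

Checking that a user-supplied path exists up front gives a clear error message instead of a confusing failure later during loading.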
Nothing makes me doubt the results of a paper more than opening up the code the paper links to and seeing unused variables and linting errors strewn throughout the files. Linting errors like this are the coding equivalent of submitting a paper to a journal written in crayon. Fortunately, this is easy to remedy by just using a linter during development.
Linters like Flake8 or Pylint can check your code for common code-quality issues like unused variables and report them as errors. All popular code editors have plugins for Python linters which will highlight linting errors directly in your code. You can also customize the errors the linters report if there are types of errors you want to ignore. Linting errors often correspond to real bugs in your code, and fixing them is an easy way to improve your code quality at almost no cost. There’s really no downside to using a linter.
Related to linters are code formatters like Black. Black will automatically format your code for you so the formatting is consistent, and takes away the need for you to think about formatting entirely. Personally, I think code formatters are great and recommend using Black, but this isn’t a universal opinion in the Python world. I’d recommend experimenting with it and seeing if you like it. There are plugins for all code editors which will let you automatically run Black whenever you save a Python file, which makes it really seamless.
Type hinting is a relatively recent addition to the Python world, but is something I’m a big fan of. Type hints take some getting used to initially, but the payoff is worth it. Adding type hints to your code allows the editor to auto-suggest variable and method names to you, and to tell you automatically if you mistyped some parameter name somewhere, rather than your code crashing at runtime. Furthermore, users of your code will benefit from your type hints, since your library’s functions and parameter names will autocomplete in their editor too! If you use type hints, you also need to use a type checker to make sure the types are correct, as Python will not do this for you. The two most popular are MyPy and PyRight. These work like linters, and can easily be added to your editor to automatically report type errors as you code.
I also recommend moving away from using Python dictionaries to pass structured data around and instead using dataclasses. Dataclasses allow you to specify exactly what fields some data should contain, and will ensure those fields exist. This fits in nicely with type hinting since MyPy and other type checkers can verify you’re using the dataclasses correctly, and you never have to worry about accidentally mistyping a key of some python dict ever again.
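For example, rather than passing a prediction around as a dict like {"label": ..., "score": ...}, you could define a small dataclass. The ClassificationResult name and its fields are hypothetical, chosen to fit the DogBert example; with type hints in place, a type checker like MyPy would flag a typo such as result.lable before the code ever runs.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ClassificationResult:
    label: str
    score: float


def top_prediction(results: List[ClassificationResult]) -> ClassificationResult:
    """Return the result with the highest score."""
    if not results:
        raise ValueError("results must not be empty")
    return max(results, key=lambda r: r.score)


results = [
    ClassificationResult(label="corgi", score=0.91),
    ClassificationResult(label="husky", score=0.07),
]
best = top_prediction(results)
print(best.label)  # corgi
```

Even without a type checker, misspelling a field when constructing the dataclass, e.g. ClassificationResult(lable="pug", score=0.5), fails immediately with a TypeError instead of silently creating a dict with the wrong key.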
Testing is something that I didn’t understand the value of until I started working professionally. There’s a natural aversion to writing tests as it feels like a bunch of extra work you need to do, and everyone always feels like they don’t have time for that. However, as I’ve improved as a software engineer and gotten more comfortable with writing tests, I find the exact opposite: I don’t have time not to write tests.
If you don’t add test cases as you code, you’re probably testing manually. However, this manual testing means that anytime you want to make a change to your existing code, either to refactor or to add new features, you’re always terrified you might accidentally break something you already wrote. Then you need to either go back and manually test everything again, likely forgetting something, or you just hack in your change in whatever way is the least likely to break something, resulting in hacks on top of hacks. This “fear-driven development” leads to horrible messes of code that are almost certainly broken and I believe leads to a lot of the code quality issues endemic throughout academic codebases.
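As a minimal sketch of replacing a manual check with an automated one, here is what a few pytest tests might look like for a hypothetical normalize_scores helper. The helper is defined inline so the snippet is self-contained; in a real project it would live in your library and the tests in the tests folder, where poetry run pytest would re-check this behavior after every refactor.

```python
import math
from typing import List

import pytest


def normalize_scores(scores: List[float]) -> List[float]:
    """Hypothetical helper: softmax-normalize raw scores so they sum to 1."""
    if not scores:
        raise ValueError("scores must not be empty")
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def test_normalized_scores_sum_to_one():
    assert sum(normalize_scores([1.0, 2.0, 3.0])) == pytest.approx(1.0)


def test_equal_scores_get_equal_probability():
    assert normalize_scores([5.0, 5.0]) == pytest.approx([0.5, 0.5])


def test_empty_input_raises():
    with pytest.raises(ValueError):
        normalize_scores([])
```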
Testing doesn’t have to be difficult. If you can test some piece of code in a Pytest test case rather than manually, do it. Some further tips for testing practically:
approx which lets you write assertions like assert x == pytest.approx(3.1), and torch has assert torch.allclose(tensor1, tensor2) to check if tensors are “close enough”.
import pdb; pdb.set_trace() into your code, then run a test that runs through that code path so you can interactively experiment as you work.
The following are a few things that drive me crazy when reading research code. These are just my personal preferences, so YMMV.
classification_vector rather than just c.
If there’s a single thing I want to leave you with, it’s that code published along with research papers should be usable by others. If you can accomplish that, you’re most of the way there. I believe that everything discussed in this article is very achievable for researchers, and is a lot easier than doing research itself, or learning LaTeX, or publishing papers. Researchers are smart people, and none of this is difficult. Once you get used to packaging code into a library that can be installed with pip install, it becomes second-nature, and the benefits to your success as a researcher and to others who want to use your code are immense.
For further reading, I’d recommend this excellent article on Python best practices. It’s a couple years old at this point, but I think the ideas in the article are still very valid today.
A fun way to spend a Saturday
The specific error for my build was that Sphinx was throwing an ImportError whenever it tried to import the Python code with Rust bindings, despite everything working when I tested locally, and Readthedocs even being able to build the actual Python/Rust package without issue. After several hours of pushing a change to try to fix the build, waiting 5 minutes for the build to finish, seeing it fail, and then repeating, I got frustrated and decided to see if I could somehow ssh into the live build and debug it directly there.
I remembered the Github Action actions-tmate, which provides exactly this functionality in Github using tmate, so I figured it might work for Readthedocs too. However, there are several impediments to getting this working easily in Readthedocs. Specifically:
tmate -F in the build output in order to connect.
tmate in the background also doesn’t work; I suspect Readthedocs must kill processes between commands or something.
Fortunately, tmate lets you set up webhooks, which it calls whenever it starts up a session, and the webhook payload contains all the info needed to connect! Combining this with ngrok makes it possible to get notified via webhook when the session starts, so you can ssh in and debug to your heart’s desire.
The full steps to get this working are laid out below.
Grab a copy of ngrok for your local machine from https://ngrok.com or via your package manager of choice. Start it up with ngrok http 5000 (the port doesn’t really matter much), and you should see output like below.
Make a note of the URL, which in the example above is “https://ed59-81-107-232-184.eu.ngrok.io”, for the next step.
In the Readthedocs UI for your project, go to “Admin” then “Environment Variables” and add a new environment variable. The name should be “WEBHOOK”, and the value is the ngrok URL from step 1.
You can customize your Readthedocs build using a .readthedocs.yml file in your Git repo. To set up tmate and remote debugging, configure your .readthedocs.yml to look like the example below. This will install tmate, configure it to use your ngrok URL as a webhook, and begin running tmate during the build.
build:
  os: ubuntu-22.04
  apt_packages:
    - tmate
  jobs:
    post_install:
      - echo "set-option -g tmate-webhook-url '${WEBHOOK}'" >> ~/.tmate.conf
      - tmate -F
Commit this change to your git repo so that Readthedocs starts building.
Ngrok lets you inspect all incoming requests via a locally running web interface. Open your web browser to http://127.0.0.1:4040 and keep an eye out for the webhook from Readthedocs, which should show up as a POST / with a 403 Forbidden response, since we’re not actually returning anything. We just want to see the info that got posted in the JSON.
Once a webhook comes in, find the fields ssh_cmd_format and stoken. The ssh_cmd_format field should look something like ssh %s@ followed by a tmate server hostname, and the stoken field should look like a random string of characters.
An example of what the webhook data looks like in the local ngrok UI, with the ssh_cmd_format and stoken fields highlighted
Just replace the %s in the ssh_cmd_format with the value in stoken, and copy / paste the command into a terminal on your local computer and run it.
You should now have a working tmate terminal into your running Readthedocs build. Hopefully it should be a breeze to debug from there!
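If you'd rather not dig through the ngrok UI each time, one alternative is to run a tiny webhook receiver on the port ngrok forwards to (5000 in step 1) that prints the ready-to-run ssh command. This is a minimal sketch using only the Python standard library. It assumes, per the fields described above, that the JSON payload contains ssh_cmd_format and stoken somewhere; since the exact nesting may vary, it searches nested objects. The function and class names are made up for this example, and it replies 200 rather than the 403 mentioned above.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Optional


def find_key(data: Any, key: str) -> Optional[Any]:
    """Recursively search nested dicts/lists for the first value stored under `key`."""
    if isinstance(data, dict):
        if key in data:
            return data[key]
        children = list(data.values())
    elif isinstance(data, list):
        children = data
    else:
        return None
    for child in children:
        found = find_key(child, key)
        if found is not None:
            return found
    return None


def build_ssh_command(payload: Any) -> Optional[str]:
    """Substitute stoken into ssh_cmd_format, or return None if either is missing."""
    fmt = find_key(payload, "ssh_cmd_format")
    token = find_key(payload, "stoken")
    if not isinstance(fmt, str) or not isinstance(token, str):
        return None
    return fmt.replace("%s", token, 1)


class WebhookHandler(BaseHTTPRequestHandler):
    def do_POST(self) -> None:
        length = int(self.headers.get("Content-Length", 0))
        body = self.rfile.read(length)
        try:
            payload = json.loads(body)
        except json.JSONDecodeError:
            payload = {}
        command = build_ssh_command(payload)
        if command:
            print(f"Connect with: {command}", flush=True)
        self.send_response(200)
        self.end_headers()


def serve(port: int = 5000) -> None:
    """Listen on localhost:port (the port ngrok forwards to) until interrupted."""
    with HTTPServer(("127.0.0.1", port), WebhookHandler) as server:
        server.serve_forever()
```

Calling serve() from a Python shell (or from a small script) starts the receiver. If something else is already listening on port 5000, pick a free port and pass the same number to both serve() and ngrok http.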
In my case, I solved my docs build bug about 5 minutes after getting this working. In case it’s useful to anyone else with the same issue with hybrid Rust/Python apps, the solution was to delete the main module folder containing Python code (tensor_theorem_prover in my case) after running pip install .. I don’t fully understand why this works, but it seems like Python was somehow finding the local folder rather than the compiled wheel with the Rust code in it, and deleting the local module folder forced it to find the compiled module instead. ¯\_(ツ)_/¯
Hopefully this technique is helpful if you’re ever stuck debugging Readthedocs builds.
release. Whatever is in the branch will deploy to Netlify
release branch at the most recent tag
release branch
We’ll discuss the rationale behind this and go through how to do this in more detail below.
Sadly, directly deploying tags to Netlify is tricky for a few reasons. First, Netlify doesn’t support building on tags, which would be the most obvious way to get this to work. Next, you may think: why not have Github Actions build the app and push it to Netlify? This would work, but Netlify doesn’t allow creation of scoped API tokens, so any API token you generate will have access to everything in your Netlify account, not just the app you’re trying to deploy. You can get around this by creating a new Netlify account which only has access to the single app you want to deploy, but Netlify charges for every account that can access the project.
In addition, deploying directly on a tag makes it hard to implement hotfix workflows, where you may have code in your main or master branch that isn’t ready to go to production, but you need to release a fix ASAP. It’s possible to get around this by checking out the latest release into its own branch, adding a fix, and then releasing that, but it’s an annoying process. Using a release branch as an intermediary means you can also just treat it as a normal branch and push hotfixes directly to that branch in an emergency.
First, create a branch in your Github repo called release.
Next, create the Github Action which will point the release branch at a tag whenever a new tag is created. This is done by creating a workflow file in your repo in the folder .github/workflows/. In the example below, we name our workflow file release.yml.
# .github/workflows/release.yml
name: release
on:
  push:
    tags:
      - "v*"
# let the workflow's GITHUB_TOKEN push to the release branch
permissions:
  contents: write
jobs:
  deploy_releases:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
        with:
          fetch-depth: 0
          ref: release
      # Point release branch at the latest tag so Netlify can deploy it.
      # GITHUB_REF_NAME holds the name of the pushed tag, e.g. v1.2.3
      - name: Point release branch at tag
        run: |
          echo "Setting 'release' branch to 'tags/${GITHUB_REF_NAME}'"
          git checkout release
          git reset --hard "tags/${GITHUB_REF_NAME}"
          git push -f
This workflow will run on any tag that starts with a v, assuming your releases are named like v1.2.3 or v17. If you want this to run on a different tag, or on all tags, just modify the - "v*" line in the workflow above to the pattern you’d like. Once you push this file to your repo, the Github side of things is good to go!
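To sanity-check which tag names a pattern like "v*" will pick up before pushing, you can approximate the filter locally. This sketch uses Python's fnmatch as a stand-in; it is not GitHub's matcher, and the two differ in places (GitHub's * does not match /, while fnmatch's does), so treat it as a quick check only. The helper name is hypothetical.

```python
from fnmatch import fnmatchcase


def tag_triggers_release(tag, pattern="v*"):
    """Rough local check of whether a tag name matches the workflow's tag filter."""
    # fnmatchcase is case-sensitive and does not normalize the pattern
    return fnmatchcase(tag, pattern)


if __name__ == "__main__":
    for tag in ["v1.2.3", "v17", "1.2.3", "release-1"]:
        print(tag, tag_triggers_release(tag))
```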
Finally, you just need to change the deploy branch for your Netlify site to deploy from the release branch instead of the default main or master. You can find this setting in Netlify at “Site settings” → “Build & deploy” → “Branches and deploy contexts” → “Production branch”.
And that’s it! You’re now set up to deploy to production in Netlify on Git release tags.
On the other end of the development cycle, there’s Jupyter, which lets you write Python code in an interactive notebook, mixing text and images in with executable code. Development and experimentation in Jupyter is a joy since you can easily print interactive tables with data to the screen, or draw images, output interactive tensorboards - basically anything that can be displayed in a web browser can be turned into a Jupyter widget. If we can combine Jupyter with Grid Engine we can get the power of Grid Engine with the development ease of Jupyter.
The issue is that usually Grid Engine jobs don’t have ports open to the outside or directly allow ssh access to the running job, so running Jupyter inside of a Grid Engine session is difficult. Fortunately that’s where Ngrok comes in. Ngrok is a tool which can forward a service running on a local machine and give you a web URL where you can access that service from the internet. This is perfect since it solves the problem of letting you easily access a Jupyter notebook that’s running inside of a Grid Engine session.
First, sign up for a free Ngrok account at ngrok.com. After you sign in, find the link to download the ngrok client for Linux and copy the URL of this link.
Next, ssh into Grid Engine and install ngrok in your home directory. This should look something like below:
# paste the linux download URL for ngrok here, it may be different than what's below
wget https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
tar -xvzpf ngrok-v3-stable-linux-amd64.tgz
At this point, you should have an executable called ngrok in your home directory.
Next, copy your auth token from the ngrok website (it should be a long pseudo-random string like c8132179a3cE725B4e267_51F32179C3eE725B4E267) and run the following command:
./ngrok config add-authtoken <your token here>
At this point, ngrok should be good to go! Next just make sure you have jupyter installed with:
pip install notebook
Since the notebook will be accessed on the web, you’ll need to modify the notebook config to allow remote connections. First, ensure you have a jupyter config available by running the following:
jupyter notebook --generate-config
This will generate a config file in your home directory, which should be located at ~/.jupyter/jupyter_notebook_config.py. Run the following command to update this config file and allow remote access:
echo "c.NotebookApp.allow_remote_access = True" >> ~/.jupyter/jupyter_notebook_config.py
It’s a good idea to set a password for jupyter since we’re going to make it accessible on the internet, and you don’t want random strangers on the internet to be able to run code in your notebook if they stumble on the URL somehow.
jupyter notebook password
Next, start an interactive session in Grid Engine, something like the following:
qrsh -l tmem=10G,h_rt=2:00:00,gpu=true -now no -verbose
Once your session has started, you need to run both ngrok and Jupyter on the same port. The specific port number doesn’t matter much - you just don’t want to pick a number that someone else on the same machine might also be using. Below I’m using port 7923, but change this to whatever number you prefer (numbers in the 7000-9999 range tend to be good choices).
(trap 'kill 0' SIGINT; jupyter notebook --no-browser --port 7923 &
~/ngrok http 7923 --log=stdout)
The command above runs jupyter and ngrok in parallel, and kills them both when you press Ctrl+C.
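Picking a port by hand (or with shuf, as in the job script below) can still collide with someone else on a shared node. A small Python helper can pick a random port in the same range and check that nothing is bound to it first. This is a sketch: there's still a small window where another process could grab the port between the check and Jupyter starting, and pick_free_port is a hypothetical helper.

```python
import random
import socket


def pick_free_port(low=7001, high=7999, attempts=50):
    """Return a random port in [low, high] that nothing on this machine is currently bound to."""
    for _ in range(attempts):
        port = random.randint(low, high)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind(("127.0.0.1", port))
            except OSError:
                continue  # someone is already using this port, try another one
        # The test socket is closed here, so the port is free again for Jupyter and ngrok
        return port
    raise RuntimeError(f"no free port found in {low}-{high} after {attempts} attempts")


if __name__ == "__main__":
    print(pick_free_port())
```

Saved as a hypothetical pick_port.py, it could replace the shuf line with PORT=`python pick_port.py`.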
Now, ngrok should display a URL on the screen (something like https://aba4-128-90-27-382.eu.ngrok.io) which you can open up in your browser, and, voila, you should see your jupyter notebook running inside of your Grid Engine interactive shell! And that’s it, you’ve got Jupyter running inside of Grid Engine.
You can of course run the same commands as a standalone SGE job which you submit with qsub. Copy the script below and save it as remote_jupyter.qsub.sh. Tweak the parameters as needed to suit your use case.
#$ -l mem=10G
#$ -l h_rt=24:0:0
#$ -S /bin/bash
#$ -N remote-jupyter
set -e
# pick a port at random between 7001-7999
PORT=`shuf -i 7001-7999 -n 1`
echo "Starting Jupyter and tunnel on port ${PORT}"
# run jupyter in the background and ngrok in the foreground
# connect to the URL that ngrok outputs to the terminal
(trap 'kill 0' SIGINT; jupyter notebook --no-browser --port ${PORT} &
~/ngrok http ${PORT} --log=stdout)
Then, submit the job as usual with qsub remote_jupyter.qsub.sh. When the job runs, you’ll be able to find the URL of the ngrok session in the logs, or you can check the ngrok web interface at ngrok.com and click “tunnels” to find your jupyter notebook.
This post takes a lot of inspiration and ideas from Khuyen Tran’s blog post. Thanks, Khuyen!
If you have any improvements to this technique, let me know!
Below is a sample of how the AMR looks for the sentence The boy must not go:
(o / obligate-01
   :ARG2 (g / go-02
            :ARG0 (b / boy)
            :polarity -))
Here, the “not” from “must not go” is extracted as the :polarity - attribute, which is helpfully machine-readable. obligate-01 is a semantic frame from Propbank, as is go-02. AMR strips away the tense of verbs and ideally results in sentences with the same meaning having the same AMR representation. AMR is also English-only, since the semantic frames correspond to English, but there are projects for AMR in other languages as well.
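To show what "machine-readable" buys you, here is a small hand-rolled parser that turns the PENMAN notation above into (source, role, target) triples, after which finding negated events is a one-line filter. It's a minimal sketch with no error handling, alignments, or full PENMAN support; amr_to_triples is a hypothetical helper, and for real work a dedicated library such as penman is a better choice.

```python
import re

# Tokens: parens, the "/" separator, roles like ":ARG0", quoted strings, and bare symbols
TOKEN_RE = re.compile(r'\(|\)|/|:[^\s()]+|"[^"]*"|[^\s()/]+')


def amr_to_triples(amr):
    """Convert one AMR graph in PENMAN notation into (source, role, target) triples."""
    tokens = TOKEN_RE.findall(amr)
    pos = 0
    triples = []

    def parse_node():
        nonlocal pos
        pos += 1                       # skip "("
        var = tokens[pos]
        concept = tokens[pos + 2]      # tokens[pos + 1] is the "/" separator
        pos += 3
        triples.append((var, ":instance", concept))
        while tokens[pos] != ")":
            role = tokens[pos]
            pos += 1
            if tokens[pos] == "(":
                target = parse_node()  # nested node: recurse and link to its variable
            else:
                target = tokens[pos]   # constant (e.g. "-") or a re-used variable
                pos += 1
            triples.append((var, role, target))
        pos += 1                       # skip ")"
        return var

    parse_node()
    return triples


if __name__ == "__main__":
    AMR = """
    (o / obligate-01
       :ARG2 (g / go-02
                :ARG0 (b / boy)
                :polarity -))
    """
    triples = amr_to_triples(AMR)
    negated = [src for src, role, tgt in triples if role == ":polarity" and tgt == "-"]
    print(negated)  # ['g'], i.e. the go-02 event is negated
```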
Ideas like AMR are only as useful as the ecosystem around them. The most important parts of that ecosystem are parsers that can take a natural language sentence and return a representation of the sentence in AMR format. Here, there’s state-of-the-art work done by IBM with their Transition AMR Parser. This parser had the best parsing results when I played around with it, and it even returns alignments so you can map each token in the AMR back to the original sentence! However, it’s also extremely difficult to get working: it has finicky requirements, it isn’t packaged on PyPI, and it requires emailing someone at IBM to get access to a pretrained checkpoint.
Sapienza University also has an open-source parser called Spring. This is easier to get working since it just uses Huggingface internally, which is very standard, and the pretrained model checkpoints are publicly available. However, like IBM’s parser, this is research code which isn’t properly packaged on PyPI and isn’t really documented, so using it in your own code requires reading through their source code.
Finally, there’s Amrlib. Amrlib is a true joy to use. It’s simple to set up, and even integrates with Spacy! However, Amrlib seems to be the least accurate of the parsers I experimented with. Still, its ease of use makes it definitely worth it for something that just works. Hopefully the parsers from IBM and Sapienza can learn from Amrlib’s usability.
All of the parsers discussed above suffer from the same core problem: a lack of diverse, freely-available AMR training data in large quantity. This is an unfortunate problem which I feel holds AMR back from reaching its true potential. The largest AMR training data set, which all the parsers mentioned above train on, is the official AMR Corpus. This corpus contains 59,255 AMR-annotated sentences, which is significant but still tiny compared to the amount of data modern NLP systems are trained on. Furthermore, this dataset is not freely available - it requires paying $300 just to access it! This is almost certainly discouraging innovation in the AMR space by setting such a huge financial bar to even experiment with the data.
This wouldn’t be so bad if tools to create more high-quality annotated AMR training data were readily available. However, here, too, the only AMR editor that I could find, the official AMR Editor, is closed-source and outdated. According to the editor page, it takes an estimated 10 minutes just to annotate a single sentence with the tool! I imagine this editor could be improved if it were open-source so the community could contribute, or if someone in the community created and open-sourced a high-quality AMR editor.
I feel like AMR has so much potential. The core idea is simple yet powerful, and it feels like a natural way to parse a sentence for semantic meaning. The number of research papers still being written on AMR shows that there’s a lot of other people who can see the power of AMR as well. It’s just too bad that the closed, paywalled nature of the training data and the closed-source editor hinder AMR from reaching its full potential. Hopefully the community will address these issues in the future!
A transformer reading a book, generated by pixray-vqgan
TL;DR: If you want to skip the details of how this works, the end-project is available here: Frame-Semantic-Transformer. A live demo is available as well.
If you want to get meaningful semantic information from a sentence that can be used by algorithms, you first need to parse it. One powerful framework for this is the idea of Frame Semantics. Frame semantics breaks a sentence apart into concepts called “frames”, where each frame contains attributes called “frame elements” which describe what’s going on in the frame, and has a “trigger” in the sentence which evokes the frame.
For instance, consider the sentence below:
Sergey dodged the flower pot that Larry threw in disgust.
A frame that might be present is the idea of “dodging”:
frame: Dodging
trigger: "dodged"
elements:
  Dodger="Sergey"
  Bad_entity="the flower pot that Larry threw in disgust"
There can be multiple frames present in a sentence at a time, and frames can relate to and inherit from other frames as well.
The gold standard for Frame Semantics is a project called FrameNet, which contains an open database of thousands of frames and annotated example texts.
I’m certainly not the first person to attempt to build a frame semantic parser (sometimes also called automatic semantic role labeling). The 2 state-of-the-art projects I found are Open-Sesame, and the paper Open-Domain Frame Semantic Parsing Using Transformers.
Open-Sesame is the best performing open-source frame semantic parser, but has a number of problems that make it difficult to work with as an end-user:
The paper Open-Domain Frame Semantic Parsing Using Transformers looks great - it uses Google’s T5 Transformer and claims to achieve even better performance than Open-Sesame. However, it’s not open-source, so there’s no actual code to run or a library to work with.
I decided to combine the best of Open-Sesame and “Open-Domain Frame Semantic Parsing Using Transformers” to build an easy-to-use open-source frame semantic parser on modern technology. I used the data splitting, task definitions, and evaluation criteria from Open-Sesame, while using a T5 transformer as a base model as in the open-domain parsing paper.
My goal is to create a frame-semantic parser which meets the following criteria:
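- …pip install.

Semantic parsing of a sentence as performed by Open-Sesame requires 3 steps:
1. Trigger identification: finding the words in the sentence that evoke a frame.
2. Frame classification: deciding which frame each trigger evokes.
3. Argument extraction: finding the frame elements for each trigger and frame.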
For example, consider the following the sentence:
It was no use trying the lift.
For the first step, trigger identification, we would identify the 2 following locations, indicated by *’s in the sentence below:
It was no use *trying the *lift.
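Marking triggers this way is just string insertion at character offsets. A minimal sketch, assuming the start offset of each trigger is already known (for example, from FrameNet's annotations); mark_triggers is a hypothetical helper, not part of Frame-Semantic-Transformer's API.

```python
def mark_triggers(sentence, trigger_starts):
    """Insert a "*" before each trigger, given the character offset where each trigger starts."""
    marked = sentence
    # Insert from the rightmost offset first so earlier offsets stay valid
    for start in sorted(trigger_starts, reverse=True):
        marked = marked[:start] + "*" + marked[start:]
    return marked


if __name__ == "__main__":
    sentence = "It was no use trying the lift."
    starts = [sentence.index("trying"), sentence.index("lift")]
    print(mark_triggers(sentence, starts))  # It was no use *trying the *lift.
```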
Next, we need to identify which frame corresponds with each trigger location:
It was no use *trying the lift.
-> Attempt_means
It was no use trying the *lift.
-> Connecting_architecture
Finally, for each trigger and frame, we need to find the frame elements in the frame:
It was no use *trying the lift. :: Attempt_means
-> Means="the lift"
It was no use trying the *lift. :: Connecting_architecture
-> Part="lift"
In FrameNet, there are tens of thousands of annotated sentences like this indicating the triggers, frames, and frame elements in the sentence which we can use to train our model.
Transformer architectures have revolutionized the field of natural language processing (NLP) since their introduction in 2017. The typical idea is to start with a transformer model that’s already pre-trained on a massive quantity of text from the internet, and just “fine-tune” it on the actual task you care about.
In this case, we use the T5 transformer provided by HuggingFace. T5 uses the idea of having a single model perform multiple tasks, with each task simply being indicated by adding a keyword to the input text.
For example, for the sentence we discussed above, we could break apart the tasks as follows:
First, trigger identification
input: "TRIGGER: It was no use trying the lift."
output: "It was no use *trying the *lift."
Next, frame classification
input: "FRAME: It was no use *trying the lift."
output: "Attempt_means"
input: "FRAME: It was no use trying the *lift."
output: "Connecting_architecture"
Finally, argument extraction:
input: "ARGS Attempt_means: It was no use *trying the lift."
output: "Means=the lift"
input: "ARGS Connecting_architecture: It was no use trying the *lift."
output: "Part=lift"
Notice above how all the tasks follow the same input/output format: each task takes a string as input and returns a string as output. Furthermore, each task is specified by putting a keyword at the start of the input followed by a :, for example FRAME: for frame classification and TRIGGER: for trigger identification. For argument extraction, we also put the name of the frame into the task definition: ARGS <frame_name>:.
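Since every task is a plain string-to-string pair, building training samples is mostly string formatting. A sketch of the three formats shown above; the function names are hypothetical, and the " | " separator for sentences with more than one frame element is an assumption, since the examples here each have only one element.

```python
def trigger_identification_sample(sentence, marked_sentence):
    """Input is the raw sentence, target is the sentence with every trigger marked."""
    return f"TRIGGER: {sentence}", marked_sentence


def frame_classification_sample(marked_sentence, frame):
    """marked_sentence should contain exactly one "*" marker."""
    return f"FRAME: {marked_sentence}", frame


def argument_extraction_sample(marked_sentence, frame, elements):
    """elements maps frame element names to the text spans they cover."""
    # Joining multiple elements with " | " is an assumption made for this sketch
    target = " | ".join(f"{name}={text}" for name, text in elements.items())
    return f"ARGS {frame}: {marked_sentence}", target


if __name__ == "__main__":
    print(trigger_identification_sample("It was no use trying the lift.",
                                        "It was no use *trying the *lift."))
    print(frame_classification_sample("It was no use *trying the lift.", "Attempt_means"))
    print(argument_extraction_sample("It was no use *trying the lift.", "Attempt_means",
                                     {"Means": "the lift"}))
```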
I based the T5 training on SimpleT5, which uses Pytorch Lightning and HuggingFace under the hood.
…and that’s all it really takes to get a working frame semantic parser using T5!
That’s not the end of the story, unfortunately. This approach performs well already - it actually beats Open-Sesame at argument extraction even without any extra tweaks! However, it doesn’t perform as well at frame classification, and we can do even better at argument extraction.
The key insight is that for the frame classification and argument extraction tasks, we can give T5 some extra hints to help it choose the best results. For frame classification, FrameNet includes a list of “lexical units” which are likely triggers of each frame. We can use this list to find some candidate frames for each trigger word.
For instance, for the word try, the following lexical units appear in FrameNet:
try.v : Attempt
try.v : Try_defendant
try.v : Attempt_means
try.v : Tasting

With the sentence It was no use *trying the lift., we can extract the labeled trigger word trying, stem it with NLTK to get try, and then check the lexical unit list in FrameNet to see a list of reasonable frames to guess. Then, we can pass these into T5 in the task definition, like below:
input: "FRAME Attempt Try_defendant Attempt_means Tasting: It was no use *trying the lift."
output: "Attempt_means"
By checking the lexical units list in FrameNet, we’re able to provide 4 possible frames to T5 in the task definition, which makes it much easier for it to simply pick one of those 4 frames rather than needing to guess the frame out of thin air! In the Frame-Semantic-Transformer project, we take this even further and check bigrams of words involving the trigger as well to search for matching lexical units.
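The lookup itself is just a dictionary from lexical unit to frames. A sketch, assuming the trigger has already been reduced to its base form ("try") and using only the four try.v entries listed above; the real project also checks bigrams and uses FrameNet's full lexical unit list, and these helper names are hypothetical.

```python
# A tiny hand-copied slice of FrameNet's lexical units (the try.v entries above)
LEXICAL_UNITS = {
    "try.v": ["Attempt", "Try_defendant", "Attempt_means", "Tasting"],
}


def candidate_frames(lemma, pos="v"):
    """Frames whose lexical units match lemma.pos, or [] if none are known."""
    return LEXICAL_UNITS.get(f"{lemma}.{pos}", [])


def frame_classification_input(marked_sentence, lemma, pos="v"):
    """Build a FRAME task input with the candidate frames listed in the header."""
    header = " ".join(["FRAME"] + candidate_frames(lemma, pos))
    return f"{header}: {marked_sentence}"


if __name__ == "__main__":
    print(frame_classification_input("It was no use *trying the lift.", "try"))
```

With no matching lexical units the header falls back to the plain FRAME: form.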
For the argument extraction task, we can similarly help T5 by pre-emptively pulling out a list of all the possible frame elements for the frame in question. For instance, the Attempt_means frame has the following possible elements, abbreviated for clarity:
Agent
Means
Goal
Circumstances

We can similarly provide this list to T5 as part of the task header:
input: "ARGS Attempt_means | Agent Means Goal Circumstances: It was no use *trying the lift."
output: "Means=the lift"
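Building this header works the same way as the frame hints: look up the frame's elements and put them between the frame name and the colon. A sketch using only the abbreviated Attempt_means element list above; FRAME_ELEMENTS and argument_extraction_input are hypothetical names, and the real element lists would come from FrameNet.

```python
# Possible frame elements per frame (abbreviated, as in the Attempt_means example above)
FRAME_ELEMENTS = {
    "Attempt_means": ["Agent", "Means", "Goal", "Circumstances"],
}


def argument_extraction_input(marked_sentence, frame):
    """Build an ARGS task input that lists the frame's possible elements as hints."""
    elements = FRAME_ELEMENTS.get(frame, [])
    return f"ARGS {frame} | {' '.join(elements)}: {marked_sentence}"


if __name__ == "__main__":
    print(argument_extraction_input("It was no use *trying the lift.", "Attempt_means"))
```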
After first trying this out, it became immediately apparent that the model was overfit to the data as it appears in FrameNet. Specifically, in FrameNet, all sentences end in proper punctuation. If you try asking for the frames of a sentence that doesn’t have a period at the end, the model often freaks out and starts repeating itself over and over and outputting nonsense. Since it never encountered an input sentence without a period at the end during training, it didn’t know what to do!
To alleviate this, Frame-Semantic-Transformer adds some extra data augmentations to the training samples, like occasionally dropping the period at the end of the sentence, or changing “can’t” into “cannot”, or making everything lowercase. These tweaks won’t help the model improve its score on the test data, but it should help it work better on unseen data.
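A sketch of what such augmentations could look like. The probabilities are made-up defaults, not the values Frame-Semantic-Transformer uses, and augment_sentence is a hypothetical helper. For tasks whose target echoes the sentence (like trigger identification), the same change would need to be applied to the target text too.

```python
import random
import re


def augment_sentence(sentence, rng, p_drop_period=0.3, p_expand_cannot=0.3, p_lowercase=0.1):
    """Randomly apply the kinds of tweaks described above to one training sentence."""
    if rng.random() < p_expand_cannot:
        # Handles straight and curly apostrophes and keeps the original capitalization of "c"
        sentence = re.sub(r"\b([Cc])an['’]t\b", r"\1annot", sentence)
    if rng.random() < p_lowercase:
        sentence = sentence.lower()
    if rng.random() < p_drop_period and sentence.endswith("."):
        sentence = sentence[:-1]
    return sentence


if __name__ == "__main__":
    rng = random.Random(0)  # seeded so the augmentations are reproducible
    for _ in range(3):
        print(augment_sentence("I can't open it.", rng))
```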
So of course the question is: how does this T5-based approach compare to Open-Sesame? I trained the model on the same breakdown of train/dev/test documents from FrameNet as Open-Sesame, and I tried to use the same metrics as Open-Sesame so the results would be a fair apples-to-apples comparison. I also trained 2 variants of the T5 model - one variant using t5-base which is about 850MB, and another using t5-small, which is about 230MB.
The results are as follows on the Open-Sesame test set:
| Task | Sesame F1 | Small Model F1 | Base Model F1 |
|---|---|---|---|
| Trigger identification | 0.73 | 0.70 | 0.72 |
| Frame classification | 0.87 | 0.81 | 0.87 |
| Argument extraction | 0.61 | 0.70 | 0.72 |
The base model performs pretty similarly to Open-Sesame at trigger identification and frame classification, but performs significantly better at argument extraction. The small model performs a bit worse than the base model and under-performs Open-Sesame on trigger identification and frame classification, but is still significantly better than Sesame at argument extraction.
I expect there are still more improvements that can be made to help Frame-Semantic-Transformer perform even better than it does now:
- …t5-large and t5-3b, which could perform even better.

Longer term, it would be great to expand this to bigger / better datasets than FrameNet (ex. multi-lingual FrameNet) that can be used to train the model. It would also be awesome to try to generate more frames / lexical units for FrameNet automatically, using a technique like the one used to generate Atomic10x in the paper Symbolic Knowledge Distillation: from General Language Models to Commonsense Models.
Any contributions to the project or thoughts/feedback are welcome!
Over the past few weeks I’ve been using Vowpal Wabbit (VW) to develop contextual bandit algorithms in Python. Vowpal Wabbit’s core functionality is excellent and it appears to be the industry standard for working with bandits. However, the library is not well documented and has numerous gotchas and partially-working features, especially in the Python bindings. The library overall feels like it was built by academics rather than engineers, so the documentation treats most of the core engineering tasks as trivial and not worth explaining, while frequently linking off to 50-page long academic research papers as explanations of what the options in the library mean.
As an engineer, there’s a lot I’ve learned that I wish I knew when I first started using this library. This post is a brain-dump of what I’ve learned that’s been useful, important, or surprising for me working with this library. I hope it will be useful for others as well! The core functionality of the library is truly excellent; it just takes a bit of effort to get it into a state where it can really shine.
This post will focus on working with the Python bindings, but most of this will apply to working with the command-line interface as well, since the Python wrapper is just a thin wrapper around the CLI. I used the --cb_explore_adf setting, which is the most complicated, least documented, and, in my opinion, most useful setting for bandits. This setting allows for picking from a different set of actions at each invocation of the library, and allows actions to have rich sets of features as well. This post will focus on using this setting, but a lot of this post will still be relevant for using other bandit settings in Vowpal Wabbit as well.
If you see any mistakes or places where there are misunderstandings in this post, please leave a comment and let me know! I’m still learning and will continue to make corrections and improvements to this article as I learn more.
The default VW input format is a string format that looks like the following:
shared | UserAge:15
| elections maine SourceTV
0:3:.3 | Sourcewww topic:4
VW also supports a JSON input format, which would look like the following:
{
"UserAge": 15,
"_multi": [
{ "_text": "elections maine", "Source": "TV" },
{ "Source": "www", "topic": 4, "_label": "0:3:.3" }
]
}
I went with the JSON format since it feels more structured, but this has been hard since this format isn’t super well documented. This JSON format is only valid JSON for individual examples. If you want to use it for more than a single example, you need to concat JSON examples with newlines between them, NOT use a JSON array, as you would probably expect. For example:
Correct:
{
"User": ...
"_multi":[...]
}
{
"User": ...
"_multi":[...]
}
{
"User": ...
"_multi":[...]
}
Incorrect:
[
{
"User": ...
"_multi":[...]
},
{
"User": ...
"_multi":[...]
},
{
"User": ...
"_multi":[...]
}
]
This was a surprise, because the “correct” way to use the JSON format here is not actually valid JSON! Also, if you use this format you need to pass the --json param to VW.
There’s another json format called --dsjson. This is even less documented than the --json format, so I wasn’t able to figure out how to use it.
If you want to use the JSON format in python, you need to pass the JSON to VW as a JSON-encoded string, not a Python dict. So something like the following:
import vowpalwabbit
import json
vw = vowpalwabbit.Workspace("--cb_explore_adf --json")
example = {
"UserAge":15,
"_multi":[
{"_text":"elections maine", "Source":"TV"},
{"Source":"www", "topic":4, "_label":"0:3:.3"}
]
}
vw.learn(json.dumps(example))
You should put all your features into namespaces rather than on the top level, since this lets you make your model more powerful with the --quadratic and --cubic options as we’ll see later. For instance, below we put shared features into a namespace called “User”, and action features into a namespace called “Action”, although you can have multiple shared and action-level namespaces if you want. In JSON, this looks like the following:
{
"User": { "age": 15 },
"_multi": [
{ "Action": { "_text": "elections maine", "Source": "TV" } },
{ "Action": { "Source": "www", "topic": 4 }, "_label": "0:3:.3" }
]
}
Note: the _label property appears outside of the namespace for the action that was chosen.
I was originally pretty confused by the format of _label for --cb_explore_adf. The label has 3 components, the action number, the cost, and the probability that this action was picked by the policy that generated the data. For --cb_explore_adf, the action number is meaningless, so just write 0 ¯\_(ツ)_/¯.
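Putting the namespace layout and the label format together, a small helper can build examples as JSON strings ready to pass to vw.learn(). This is a sketch: build_cb_adf_example and its parameters are hypothetical names, and it follows this post's convention of writing 0 as the action number.

```python
import json


def build_cb_adf_example(user_features, action_features, chosen=None, cost=None, prob=None):
    """Build one --cb_explore_adf JSON example with "User" and "Action" namespaces.

    If `chosen` is given, the label "0:<cost>:<prob>" is attached to that action,
    outside its namespace. Leave it out when building examples just for prediction.
    """
    multi = [{"Action": features} for features in action_features]
    if chosen is not None:
        multi[chosen]["_label"] = f"0:{cost}:{prob}"
    return json.dumps({"User": user_features, "_multi": multi})


if __name__ == "__main__":
    print(build_cb_adf_example(
        {"age": 15},
        [{"_text": "elections maine", "Source": "TV"}, {"Source": "www", "topic": 4}],
        chosen=1, cost=3, prob=0.3,
    ))
```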
VW will hash all your features into a large number of buckets (2^18 by default) and learn a weight for each bucket. Then, to score an action, it multiplies each feature’s value by the weight of that feature’s bucket and sums the results. This is demonstrated in the diagram below.
Basic Vowpal Wabbit model architecture
This is just a simple linear combination of the features passed in, which is very fast to compute, optimize, and understand, but it means the model can’t learn interactions between input features. For example, if users in Maine who watch TV react well to an action, but not users who watch TV in other places, this model cannot capture that. It can only capture features of users in Maine on the whole, and features of users who watch TV on the whole, but not together.
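To make the architecture concrete, here is a toy version of a hashed linear scorer. It's only an illustration of the idea, not VW's implementation: zlib.crc32 stands in for VW's actual hash function, and the namespace^feature naming and "Source=TV" feature names are made up.

```python
import zlib

NUM_BITS = 18                  # 2^18 buckets, the default table size mentioned above
NUM_BUCKETS = 1 << NUM_BITS


def bucket(namespace, feature):
    """Map a (namespace, feature) pair to a bucket index."""
    # crc32 is a stand-in; VW uses its own hash, so bucket indices won't match VW's
    return zlib.crc32(f"{namespace}^{feature}".encode()) % NUM_BUCKETS


def score(weights, features):
    """features: list of (namespace, feature_name, value); weights: dict bucket -> weight."""
    # Linear model: sum of weight * value, with unseen buckets contributing 0
    return sum(weights.get(bucket(ns, name), 0.0) * value for ns, name, value in features)


if __name__ == "__main__":
    weights = {bucket("User", "age"): 0.01, bucket("Action", "Source=TV"): 0.5}
    features = [("User", "age", 15.0), ("Action", "Source=TV", 1.0)]
    print(score(weights, features))
```

Nothing in score can express "Maine and TV together"; each feature contributes independently.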
--quadratic and --cubic
The default model architecture is almost never going to give good results, so you need to tweak the model architecture to allow it to learn a better estimator. One of the simplest yet still powerful ways to do that is via the --quadratic or -q option. This option allows you to generate new features from every combination of features across the namespaces you specify.
The syntax for this is pretty strange: you take the first letter of the name of each namespace and pass a 2-character string after -q to indicate which 2 namespaces to mix together. In our case above, where we have an Action namespace and a User namespace, we could mix them with -q UA. We could even do -q UU to mix the User namespace with itself. You can also pass -q multiple times with different combinations of namespaces. You can use : to indicate all namespaces, so -q U: would mix the User namespace with every namespace.
If you want to generate features by mixing 3 namespaces together, you can use --cubic, like --cubic UAC or --cubic UUA. If you want to mix more than 3 namespaces together, you can use --interactions to specify any number of namespaces to mix together, for example --interactions UAXBY to mix 5 namespaces together.
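Conceptually, -q UA adds a new feature for each pair of User and Action features, so the linear model gets a separate weight for, say, "Maine and TV" together. A toy sketch of that crossing; interact is a hypothetical helper, and multiplying the values of crossed numeric features is an illustrative choice here, not a confirmed description of VW's internals (see the caveat about numeric features below).

```python
from itertools import product


def interact(*namespaces):
    """Cross every feature in each namespace with every feature in the others.

    Each namespace is a dict of feature name -> value. Crossed features get joined
    names and (as an illustrative choice) the product of the values.
    """
    crossed = {}
    for combo in product(*(ns.items() for ns in namespaces)):
        name = "*".join(feature for feature, _ in combo)
        value = 1.0
        for _, v in combo:
            value *= v
        crossed[name] = value
    return crossed


if __name__ == "__main__":
    user = {"Maine": 1.0, "age": 15.0}
    action = {"TV": 1.0}
    print(interact(user, action))  # like -q UA: {'Maine*TV': 1.0, 'age*TV': 15.0}
```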
I think that if there are numeric features, only the value of the last feature in the namespace will be used as the numeric value, so if you have a namespace with a lot of numeric features it should probably go last. (I could be wrong about this!)
There’s a list of more feature enhancement settings available in the VW Wiki.
--nn
In addition to mixing features together, you can use a simple feed-forward neural network as the model instead of a pure linear model with the --nn param. The number passed to --nn is the number of hidden units in a single hidden layer, so --nn 2 gives a network with one hidden layer of 2 units. There are a number of options available to further tune the neural network architecture in the VW Wiki.
Vowpal Wabbit is extremely fast to train, which is nice because it makes it easy to test out lots of different model settings using offline policy evaluation (OPE). There’s a good tutorial on how to do this on the vowpal wabbit website, so I won’t go into too much detail here, but I found offline policy evaluation essential to figuring out which model params to use to get good results.
One thing that confused me at first was that OPE outputs what it calls “average loss”, but really this means “average cost”. If you use negative cost like I did, then “average loss” will be negative. In all cases, the lower the number for “average loss” the better, even if it’s negative.
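For intuition about what an offline "average loss" estimate means, the simplest OPE estimator is inverse propensity scoring (IPS): keep only the logged decisions where the new policy agrees with the logged action, and reweight their costs by one over the logged probability. This sketch is the textbook estimator for a deterministic policy, not necessarily the estimator VW reports, and ips_average_cost is a hypothetical helper. As with VW's output, lower (more negative) is better when costs are negative rewards.

```python
def ips_average_cost(logged, new_policy):
    """Inverse propensity scoring estimate of the average cost a new policy would have had.

    logged: list of (context, chosen_action, cost, prob), where prob is the logging
            policy's probability of the action it actually chose (must be > 0).
    new_policy: function mapping a context to the action the new policy would choose.
    """
    total = 0.0
    for context, action, cost, prob in logged:
        if new_policy(context) == action:
            # Only matching decisions are informative; 1/prob reweighting keeps it unbiased
            total += cost / prob
    return total / len(logged)


if __name__ == "__main__":
    # (context, logged action, cost, logged probability); negative cost = reward
    logged = [("user1", 0, -1.0, 0.5), ("user2", 1, 0.0, 0.5), ("user3", 0, -1.0, 0.25)]
    always_first = lambda context: 0
    print(ips_average_cost(logged, always_first))  # -2.0
```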
Make sure to try out lots of different settings for things like learning rate (-l) and number of passes over the data (--passes) as well. I also found --cover 1 seems to work much better than --cover 3 for some reason.
In Python, if you’d rather not use the CLI, I found that you can call vw.get_sum_loss() after a test run and divide it by the number of test samples to get the same “average loss” that the CLI outputs.
There are a number of strange quirks with the Python wrapper. It doesn’t always accept examples in the same format. For example, you can pass an example directly to .learn() and .predict(), but for some methods like .audit_example() you need to parse the example into multiple parts using vw.parse() first.
For JSON input, you need to run the Python dict examples through json.dumps before passing them to Vowpal Wabbit.
There are also methods that just print stuff out to stdout instead of returning a value which is obnoxious. For instance it’s not currently possible to get the results from --audit into a string in Python for further processing. If you don’t pass --quiet, the python library will just print stuff to stdout and stderr as it runs. As far as I can tell, there’s no good way to get this data into a more natural Python interface.
I found it’s easier to just write data to temporary files on disk and train by passing in a reference to the training file, rather than passing training examples in Python, due to some of the quirks around how the Python library handles example parsing. This of course assumes that the data for learning isn’t so large that it can’t fit into memory or on disk. The code might look something like the following:
import vowpalwabbit
import json
from tempfile import NamedTemporaryFile
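def create_and_train_vw(json_examples):
    # Write one JSON example per line to a temporary file
    file = NamedTemporaryFile("w")
    file.write("\n".join([json.dumps(ex) for ex in json_examples]))
    file.flush()
    # Train by pointing VW at the file with -d
    vw = vowpalwabbit.Workspace(f"--cb_explore_adf --json --quiet -d {file.name}")
    # Closing a NamedTemporaryFile also deletes it from disk
    file.close()
    return vw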
When you call .predict() on a vw instance, you’ll get back a list of probabilities, one for every potential action you could take. To use the output from Vowpal Wabbit for prediction, you’ll need to sample an action according to the probabilities it returns, like below:
import random
def sample_prediction(action_probs):
    "return the index of the selected action, and the probability of that action"
    [selected_index] = random.choices(range(len(action_probs)), weights=action_probs)
    return selected_index, action_probs[selected_index]
action_index, probability = sample_prediction(vw.predict(ex))
The documentation for Vowpal Wabbit leaves a lot to be desired, so you’ll likely need to venture outside of the official website docs while trying to use the library. There’s an official Wiki on Github for VW which has some good info, but it also has a lot of gaps and some of the pages are incomplete. I found it helpful to ask questions in the VW Community Gitter, as there are people there who respond quickly to any questions. There’s also some good info on Stack Overflow. As a last resort, posting issues on the VW Github page got me a lot of in-depth responses from the devs when I thought something looked like a bug.
I’ll keep updating this post as I learn more. If you see anything that’s not correct, please leave a comment to let me know and I’ll update it!
Contextual Bandits are probably one of the most bizarrely named concepts in reinforcement learning. The idea is an extension of “multi-armed bandits”, whose name comes from “one-armed bandit”, an old nickname for slot machines. Imagine you’re at a casino faced with 4 slot machines. Each machine has a different chance of giving a payout, but you don’t know the chance of winning for each machine. How should you play so as to maximize your winnings? You can spend time testing out each slot machine to get a better sense of the reward for each machine, but then you might be missing out on any potential rewards from just sticking with the machine that seems like it’s the best from what you’ve seen so far.
How should you play each machine to maximize your reward?
For multi-armed bandits, an action corresponds to a possible choice you can make. In the example of 4 slot machines above, there are 4 possible actions, each referring to pulling the lever of one of the 4 slot machines. A policy is an algorithm that determines how you play. Typically a policy is probabilistic, meaning it expresses a probability distribution over the actions. For instance, in the case above you could try a completely random policy and just pull an arm uniformly at random. Or you could try a policy that pulls lever 1 60% of the time and each of the other 3 levers 10% of the time. Or the probabilities can change over time, starting out more random and becoming more deterministic.
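To make "a policy is a probability distribution over actions" concrete, here is a minimal sketch of one policy that starts out random and becomes more deterministic over time: epsilon-greedy with a decaying epsilon. Epsilon-greedy is just one common choice, not something this post prescribes, and the helper names (`epsilon_greedy_probs`, `play`) and the payout rates used with them are hypothetical, made up for illustration.

```python
import random

def epsilon_greedy_probs(estimates, epsilon):
    """Return one probability per arm: with probability epsilon pick uniformly
    at random, otherwise pick the arm with the highest estimated payout."""
    n = len(estimates)
    best = max(range(n), key=lambda i: estimates[i])
    probs = [epsilon / n] * n
    probs[best] += 1 - epsilon
    return probs

def play(payout_rates, rounds=10_000, seed=0):
    """Simulate slot machines with hypothetical payout rates.
    Returns the total winnings and how often each arm was pulled."""
    rng = random.Random(seed)
    n = len(payout_rates)
    pulls = [0] * n
    wins = [0] * n
    total = 0
    for t in range(1, rounds + 1):
        # epsilon shrinks as t grows, so the policy gets more deterministic
        epsilon = 1 / t ** 0.5
        estimates = [wins[i] / pulls[i] if pulls[i] else 0.0 for i in range(n)]
        probs = epsilon_greedy_probs(estimates, epsilon)
        [arm] = rng.choices(range(n), weights=probs)
        reward = 1 if rng.random() < payout_rates[arm] else 0
        pulls[arm] += 1
        wins[arm] += reward
        total += reward
    return total, pulls

total, pulls = play([0.2, 0.5, 0.3, 0.6])
print(total, pulls)
```

At every step the policy is an explicit distribution over the 4 arms, which is also the kind of output you sample from with `sample_prediction` above.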
Contextual bandits are a type of multi-armed bandit problem where you have some extra information that might be useful in determining which action to take. For instance, if you have an online store and you want to recommend an item to a user who visits your website, the item you choose to recommend might depend on the age and location of the user. Contextual bandit problems are very common in digital products where you have a set of items that can be shown to a user and the goal is to choose the best item to show that will optimize some metric, for instance the chance the user will buy your product.
Let’s say you have an online store with several items, and currently you show those items to your users at random. You know this isn’t optimal though, because you have some extra information about each of your users, like what they last purchased, their age, and their location. If someone just bought a pair of shoes from you, it probably doesn’t make sense to try to show them that same pair of shoes immediately after. You have a log of data from your store for each time you showed an item to a user, and whether or not the user ended up purchasing the item.
How can you use this data to try out different policies for showing items to users? For each user, you only know what happened after you showed them 1 item; you have no way of knowing what would have happened if they had seen a different item instead. Maybe user X didn't buy when you showed them shoes, but maybe they would have if you had shown them a shirt. Or maybe they wouldn't have purchased anything regardless. You can try to come up with a new policy for picking which item to show users on your website, but how can you tell how that policy would have performed given the data that's been logged so far?
It seems like this should be impossible, but Inverse Propensity Scoring (IPS) offers a way to estimate how well any new policy would have performed, given the log of which items were shown to users and whether they were purchased, with some caveats as we'll see below.
For IPS to work, the policy used to generate the log data must be probabilistic and must have a non-zero probability of picking every action that the new policy we want to test could pick. In general, as long as the generating policy never assigns a probability of 0 to any action, you should be fine. Contextual bandit libraries like Vowpal Wabbit do this automatically. Also, in the data log where we record every action taken and the reward it produced, we also need to record the probability the generating policy assigned to taking that action.
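Here is a minimal sketch of a logging policy that meets both requirements: it never assigns probability 0 to any item, and it records the probability of the item it actually showed. The catalogue, the "favourite item" rule, and the names `logging_policy`, `log_interaction`, and `get_reward` are all hypothetical; `get_reward` stands in for however you find out whether the user bought the item. Each log entry uses the same `(reward, action, probability, context)` order as the IPS code below.

```python
import random

ITEMS = ["shoes", "shirt", "hat", "socks"]  # hypothetical catalogue

def logging_policy(context, epsilon=0.2, rng=random):
    """Hypothetical production policy: usually show a favourite item, but spread
    epsilon of the probability mass uniformly over all items so that no item
    ever has probability 0. Returns (item, probability of that item)."""
    favourite = "shirt" if context["last_purchase"] == "shoes" else "shoes"
    probs = [epsilon / len(ITEMS)] * len(ITEMS)
    probs[ITEMS.index(favourite)] += 1 - epsilon
    [index] = rng.choices(range(len(ITEMS)), weights=probs)
    return ITEMS[index], probs[index]

def log_interaction(data_log, context, get_reward, rng=random):
    """Show an item, observe the reward, and log everything IPS needs."""
    action, probability = logging_policy(context, rng=rng)
    reward = get_reward(context, action)  # e.g. 1 if the user bought the item, else 0
    data_log.append((reward, action, probability, context))
```

The epsilon mixing is just one simple way to keep every probability above 0; any stochastic policy with full support would do.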
The idea behind IPS is to replay the log and weight each reward inversely to how likely the generating policy was to pick that action. This corrects for the fact that actions the generating policy selects often also show up more frequently in the log than actions it rarely selects. The adjusted reward is then multiplied by 1 if the policy we're testing selects the same action the generating policy picked for that context, or by 0 if it selects a different action. Finally, these results are averaged to give the expected reward of the policy we're testing. In Python, this would look something like the following:
def ips_estimate_avg_reward(new_policy, data_log):
    total_reward = 0
    for (reward, action, probability, context) in data_log:
        new_action, new_probability = new_policy(context)
        if new_action == action:
            total_reward += reward / probability
    return total_reward / len(data_log)
Not bad for 7 lines of code! Note that the new_probability returned by our new policy isn't needed for IPS, but if we were to deploy this policy, we'd want to record it in the data log so we could keep running IPS estimates of other policies in the future.
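One way to convince yourself that IPS works is to run it on a synthetic log where the right answer is known. The sketch below assumes a made-up store: users rarely re-buy what they just bought (5% chance) and otherwise buy 30% of the time, and the old policy showed items uniformly at random. The candidate policy `avoid_last_purchase` never shows the last-purchased item, so its true average reward is exactly 0.30. All names and numbers here are hypothetical; the IPS function is the one from above.

```python
import random

ITEMS = ["shoes", "shirt", "hat", "socks"]  # hypothetical catalogue

def purchase_probability(context, item):
    # Hypothetical ground truth: people rarely re-buy what they just bought.
    return 0.05 if item == context["last_purchase"] else 0.30

def ips_estimate_avg_reward(new_policy, data_log):
    total_reward = 0
    for (reward, action, probability, context) in data_log:
        new_action, new_probability = new_policy(context)
        if new_action == action:
            total_reward += reward / probability
    return total_reward / len(data_log)

def build_uniform_log(n, rng):
    """Log n visits where the old policy picked an item uniformly at random."""
    data_log = []
    for _ in range(n):
        context = {"last_purchase": rng.choice(ITEMS)}
        action = rng.choice(ITEMS)
        reward = 1 if rng.random() < purchase_probability(context, action) else 0
        data_log.append((reward, action, 1 / len(ITEMS), context))
    return data_log

def avoid_last_purchase(context):
    """Deterministic candidate policy: show the item after the last purchase."""
    i = ITEMS.index(context["last_purchase"])
    return ITEMS[(i + 1) % len(ITEMS)], 1.0

rng = random.Random(42)
data_log = build_uniform_log(50_000, rng)
logged_avg = sum(r for (r, _, _, _) in data_log) / len(data_log)
ips_avg = ips_estimate_avg_reward(avoid_last_purchase, data_log)
print(f"old policy (observed):     {logged_avg:.3f}")
print(f"new policy (IPS estimate): {ips_avg:.3f}  (true value: 0.300)")
```

The old uniform policy's expected reward is 0.25 × 0.05 + 0.75 × 0.30 = 0.2375, so the log on its own never shows a 0.30 average; IPS recovers it by reweighting the rows where the old policy happened to agree with the new one.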
IPS isn't perfect, however. While it is an unbiased estimator of the expected reward of a new policy, it can have high variance, especially when the policy used to generate the data and the policy being tested are very different. IPS gives 0 reward in every case where the test policy and the generating policy select different actions, so if the policies have little overlap, it takes a lot of data before IPS gives good estimates. There are more sophisticated estimators that handle this better than IPS, such as Doubly Robust or Importance-Weighted Regression.
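The variance problem is easy to see in a small simulation. This sketch assumes a hypothetical setup with 4 actions, a test policy that always picks action 0, and a true reward rate of 0.5 for that action. It repeats the IPS estimate many times with two different logging policies: a uniform one that picks action 0 a quarter of the time, and a skewed one that picks it only 5% of the time. The function names are made up for illustration.

```python
import random
import statistics

def ips_always_pick(target, logging_probs, reward_rate, n, rng):
    """IPS estimate of the policy 'always pick `target`' from n logged rounds
    generated by a stochastic policy with the given action probabilities.
    Rewards are only simulated for `target`, since IPS ignores every other row."""
    total = 0.0
    for _ in range(n):
        [action] = rng.choices(range(len(logging_probs)), weights=logging_probs)
        if action == target:
            reward = 1 if rng.random() < reward_rate else 0
            total += reward / logging_probs[action]
    return total / n

def estimate_spread(logging_probs, trials=200, n=1_000, seed=0):
    """Mean and standard deviation of repeated IPS estimates."""
    rng = random.Random(seed)
    estimates = [ips_always_pick(0, logging_probs, 0.5, n, rng) for _ in range(trials)]
    return statistics.mean(estimates), statistics.stdev(estimates)

uniform_mean, uniform_sd = estimate_spread([0.25, 0.25, 0.25, 0.25])
skewed_mean, skewed_sd = estimate_spread([0.05, 0.65, 0.15, 0.15])
print(f"uniform logging: mean={uniform_mean:.3f} sd={uniform_sd:.3f}")
print(f"skewed logging:  mean={skewed_mean:.3f} sd={skewed_sd:.3f}")
```

Both estimates average out near the true value of 0.5 (IPS is unbiased in both cases), but the skewed log produces a much wider spread. With a Bernoulli reward, a single IPS term here has variance 0.5/p − 0.25, so the standard deviation of a 1,000-row estimate is roughly 0.04 when p = 0.25 and roughly 0.10 when p = 0.05.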
All these techniques are implemented in the also bizarrely named but otherwise excellent Vowpal Wabbit library.