Every token, everywhere, all at once


Every token, everywhere, all at once — idlemachines


Almost any model that ends up producing an embedding has a moment where it has to take a variable-length set of vectors and produce exactly one output. The step that does this collapsing is called pooling.

It's worth a moment to think about why we need this layer. Most outputs we want from a model are either single values (classes, regression targets, etc.) or fixed-size vectors (embeddings, features, etc.). That is at odds with much of the data we want to process, which arrives as sequences, images, or graphs of varying size. We can think of it as a dimensionality problem: our data often has at least one more dimension than the output we want, and the pooling step is how we reduce one to the other. It is also a familiar story about compression. Unless we literally concatenate all the vectors together, we have to sacrifice some information, and the question is which part. Looking at how different models make this compression efficient tells us a lot about what they treat as important, and how much capacity their learned representations have.

Jump in the pool

Pooling methods can be arranged along a spectrum by how much information they take from each token. At the simplest end are last-token pooling, which takes a single token to represent the whole sequence, and max pooling, which keeps only the single largest value in each dimension across all tokens. At the other end is learned attention pooling, where a learned query vector scores every token and the output is a weighted average of the whole sequence. In between sit mean pooling, which gives every token equal weight, and GeM (generalized mean) pooling, which gives more weight to the larger values without throwing the smaller ones away entirely. The pooling method you use to read out from your model is an important design choice, and it needs to be made in the context of the encoder architecture and the task, as well as the strengths and limitations of your model.
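To make the spectrum concrete, here is a minimal NumPy sketch of the methods named above, run on one small set of token vectors. It is an illustration under stated assumptions rather than a reference implementation: GeM assumes non-negative features and uses an illustrative p = 3, the attention query is a random stand-in for a vector that would normally be learned, and all function names are hypothetical.

```python
import numpy as np

def max_pool(x):
    # Elementwise max over the token axis: each output entry is copied from
    # whichever token is largest in that position, so different entries can
    # come from different tokens.
    return x.max(axis=0)

def last_token_pool(x):
    # Keep the final token whole and ignore the rest.
    return x[-1]

def mean_pool(x):
    # Every token gets the same weight, 1/n.
    return x.mean(axis=0)

def gem_pool(x, p=3.0, eps=1e-6):
    # Generalized mean (mean(x**p))**(1/p). Assumes non-negative features,
    # so values are clamped at eps first. p=1 gives the mean; as p grows the
    # result moves toward the max.
    return np.mean(np.clip(x, eps, None) ** p, axis=0) ** (1.0 / p)

def attention_pool(x, q):
    # A single query q scores each token; a softmax turns the scores into
    # weights that sum to 1, and the output is the weighted average.
    scores = x @ q
    w = np.exp(scores - scores.max())   # subtract the max for numerical stability
    w = w / w.sum()
    return w @ x, w

rng = np.random.default_rng(0)
x = rng.random((6, 4))      # six tokens with four entries each, all in [0, 1)
q = rng.normal(size=4)      # hypothetical stand-in for a learned query vector
pooled, weights = attention_pool(x, q)
for name, out in [("max", max_pool(x)), ("last", last_token_pool(x)),
                  ("mean", mean_pool(x)), ("gem", gem_pool(x)),
                  ("attention", pooled)]:
    print(f"{name:>9}", np.round(out, 3))
```

Two limits tie the methods together: with an all-zero query every score is equal and attention pooling reduces to the mean, and raising GeM's p pushes it from the mean toward the max.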

Practice: Pooling Toolkit (Easy, Q386). Implement the key methods along the pooling spectrum with their masking rules, and see how they differ on the same input. The figure is a good reference for what the outputs should look like.

Figure: Six ways of pooling the same six token vectors. Each token is drawn as a short vector with four entries shaded by value; the bar above it is the weight that method assigns it, and the vector on the right is the pooled result. Mean spreads the weight evenly. Max and last-token throw most of the sequence away: last-token copies a single token, and max keeps only the largest value in each entry. GeM leans toward the larger vectors. Attention pooling sets the weights from content, and you can see it pulling toward t5 even though t5 is faint. Latent attention runs several learned queries and pools their outputs. Reading down the figure is reading along the spectrum, from fixed and uniform at the top to learned and content-dependent at the bottom.

This essay is about single-vector pooling, where a collection of vectors goes in and one vector comes out. We'll talk about tokens in an NLP context, but nothing here is specific to language: these methods apply just as well to images, graphs, or any other modality you can encode as a set of vectors.

All tokens are equal...

Practice: Sum Pooling and Length Sensitivity (Easy, Q376). Implement the canonical Deep Sets aggregator, and observe in the tests why it isn't length-invariant.

Practice: Deep Sets, Variance from Sums (Medium, Q384). Implement the Deep Sets aggregation $\rho\!\left(\sum_i \phi(x_i)\right)$ and recover the sample variance as a concrete permutation-invariant set function.
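One way to see the Deep Sets recipe at work is to pick $\phi$ by hand. The sketch below assumes $\phi(x) = (1, x, x^2)$, so the sum collects the count, the sum, and the sum of squares, and $\rho$ turns those into the sample variance with the $n - 1$ denominator. In a real Deep Sets model both $\phi$ and $\rho$ would be learned networks; the function names here are hypothetical.

```python
import numpy as np

def phi(x):
    # Per-element features x -> (1, x, x^2). Chosen by hand for this sketch;
    # a Deep Sets model would learn phi.
    x = np.asarray(x, dtype=float)
    return np.stack([np.ones_like(x), x, x ** 2], axis=-1)

def rho(s):
    # Read-out on the pooled features (n, sum x, sum x^2):
    # sample variance = (sum x^2 - (sum x)^2 / n) / (n - 1).
    n, s1, s2 = s
    return (s2 - s1 ** 2 / n) / (n - 1)

def deep_sets_variance(xs):
    # rho(sum_i phi(x_i)): elements only interact through the sum,
    # so the result cannot depend on their order.
    return rho(phi(xs).sum(axis=0))

print(deep_sets_variance([2.0, 4.0, 4.0, 5.0]))  # 1.5833..., matches np.var(..., ddof=1)
```

This textbook formula is fine for illustration, but it can lose precision when the values are large relative to their spread.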

Sum. This is in many ways the simplest idea: take all the tokens in the sequence and add them up into one mega vector. It is permutation invariant by construction, since the order in which we see the tokens doesn't matter. The problem is magnitude. If we sum the token vectors of a sequence of length $n$, the result grows in magnitude with $n$. That's a problem if we want to compare sequences of different lengths, or if we want a fixed scale for our embeddings. ML models generally have a fairly narrow window of input magnitudes where they are well behaved.
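A quick NumPy sketch of the length problem, assuming tokens are stored as an (n_tokens, d) array. Repeating the same token n times keeps the content fixed while the norm of the summed vector grows in proportion to n. With varied tokens the growth rate depends on how aligned they are: roughly linear for aligned tokens, closer to $\sqrt{n}$ for independent zero-mean noise.

```python
import numpy as np

def sum_pool(x):
    # x has shape (n_tokens, d); summing over the token axis gives one d-vector.
    return x.sum(axis=0)

rng = np.random.default_rng(0)
token = rng.normal(size=8)

for n in (1, 4, 16, 64):
    # n copies of the same token: identical content, growing length.
    x = np.tile(token, (n, 1))
    ratio = np.linalg.norm(sum_pool(x)) / np.linalg.norm(token)
    print(f"n={n:>2}  |sum| / |token| = {ratio:.2f}")  # ratio is n, up to rounding
```

Shuffling the rows of x before summing leaves the result unchanged (up to floating-point rounding), which is the permutation invariance the paragraph describes.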

Practice: Permutation Invariance Check (Medium, Q383). Measure how much a pooling function drifts when you shuffle its input. Mean drifts by zero; last-token does not.
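The check can be sketched as: shuffle the tokens many times and record the worst change in the pooled output. This minimal version assumes a pooling function maps an (n_tokens, d) array to a d-vector; the names are illustrative. In floating point, "zero" for the mean means zero up to rounding, since reordering a sum can change the last few bits.

```python
import numpy as np

def mean_pool(x):
    return x.mean(axis=0)

def last_token_pool(x):
    return x[-1]

def permutation_drift(pool, x, n_trials=100, seed=0):
    # Worst-case L2 distance between pool(x) and pool of a shuffled copy of x.
    # A permutation-invariant pool should give ~0 (up to float rounding).
    rng = np.random.default_rng(seed)
    reference = pool(x)
    worst = 0.0
    for _ in range(n_trials):
        shuffled = x[rng.permutation(len(x))]
        worst = max(worst, float(np.linalg.norm(pool(shuffled) - reference)))
    return worst

x = np.random.default_rng(1).normal(size=(6, 4))
print("mean:", permutation_drift(mean_pool, x))
print("last-token:", permutation_drift(last_token_pool, x))
```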

Mean. To get around the length sensitivity of sum pooling, we divide by the sequence length and take the mean. This is a common default for turning encoder-only models like BERT into text embedders, and a strong baseline for many tasks. It is everywhere in text embeddings, and plays a huge role in Protein Language Models (PLMs) too. The mean is still permutation invariant, but its magnitude no longer grows with the number of tokens: it is bounded by the largest token norm, whatever the length.
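The fix is a single division, and its effect is easy to check: duplicating a sequence doubles its sum but leaves its mean unchanged. A minimal sketch, assuming the same (n_tokens, d) layout as before.

```python
import numpy as np

def sum_pool(x):
    return x.sum(axis=0)

def mean_pool(x):
    # Dividing by the number of tokens removes the length dependence of the sum.
    return x.sum(axis=0) / x.shape[0]

rng = np.random.default_rng(0)
x = rng.normal(size=(10, 8))      # 10 tokens, 8 dims
x_twice = np.concatenate([x, x])  # same content, twice the length

print(np.allclose(sum_pool(x_twice), 2 * sum_pool(x)))  # True: sum doubles
print(np.allclose(mean_pool(x_twice), mean_pool(x)))    # True: mean unchanged
```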

Practice: Masked Mean Pooling (Easy, Q375). Average over real tokens only. Divide by mask.sum(), not...
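In practice, sequences in a batch are padded to a common length, so the mean has to skip the padding. Here is a minimal NumPy sketch, assuming the common convention of a 0/1 mask of shape (batch, seq_len) where 1 marks a real token. Clipping the count at 1 to avoid dividing by zero on an all-padding row is an assumption added here, not something the exercise specifies.

```python
import numpy as np

def masked_mean_pool(h, mask):
    # h:    (batch, seq_len, d) token vectors, padded to a shared seq_len.
    # mask: (batch, seq_len), 1 for real tokens and 0 for padding.
    m = mask[..., None].astype(h.dtype)         # (batch, seq_len, 1) so it broadcasts over d
    summed = (h * m).sum(axis=1)                # padding rows contribute zero
    counts = np.clip(m.sum(axis=1), 1.0, None)  # real tokens per sequence; clip avoids 0/0
    return summed / counts                      # (batch, d)

rng = np.random.default_rng(0)
h = rng.normal(size=(2, 3, 4))    # two sequences, padded to length 3, d = 4
mask = np.array([[1, 1, 1],
                 [1, 0, 0]])      # the second sequence has one real token
pooled = masked_mean_pool(h, mask)
print(pooled.shape)  # (2, 4)
```

Whatever values sit in the padded positions, they never reach the output, so the pooled vector for a sequence is the same whether it is batched alone or next to a much longer one.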

