Demystifying Noise Contrastive Estimation – Jack Morris
Demystifying Noise Contrastive Estimation
January 21, 2022
Introduction
This document describes two machine learning methods, Noise Contrastive Estimation (NCE) [5] and its follow-up InfoNCE [15]. We discuss two variants of NCE as well as InfoNCE and a related technique called partition function estimation. NCE-based methods are used for estimating the parameters of a statistical distribution by differentiating between “real data” and “noise”.
We will try to keep our terminology consistent. The original NCE method is sometimes referred to as Local NCE [3,13] or Binary NCE. We’ll call it Local NCE. The follow-up InfoNCE is very similar to an NCE variant called both Global NCE [13] and Ranking NCE [8], which we’ll call Global NCE.
Applications. Local NCE and Global NCE are both computationally inexpensive methods for learning a conditional likelihood $p_\theta(x \mid c)$. NCE is useful in general when the number of possible $x$’s is very large, like in language modeling [6,8], where $|X|$ is the number of words in a vocabulary. InfoNCE is a method for maximizing the mutual information between two variables, which forms the foundation of joint text-to-image learning methods like CLIP [11], where $x$ is an image and $c$ the text of its caption, as well as general contrastive learning-based techniques like SimCLR [1], where $x$ and $c$ are two different ‘views’ of the same image.
Overview. We begin by discussing how NCE works to approximate $p(x \mid c)$. We then consider how we could approximate $p(x \mid c)$ via importance sampling, a method sometimes known as partition function estimation. Then we connect partition function estimation to Global NCE and InfoNCE. We end with discussion of the differences between the methods and suggestions for which to use in practice.
What are x and c?
Local NCE and Global NCE focus on learning $p(x \mid c)$, the likelihood of some $x$ given some ‘context’ $c$. InfoNCE maximizes $\frac{p(x \mid c)}{p(x)}$, a proxy for mutual information between $x$ and $c$. But what are $x$ and $c$ anyway? Here are examples of x’s and c’s from the literature:
In NLP: Local NCE and Global NCE are both used for language modeling, where $x$ is a word and $c$ is the word’s context (the other words in a window around $x$). Here, we want to learn $p(x \mid c)$, the probability of the next word given context. [6]
In speech recognition: Local NCE is used for learning $p(x \mid c)$ where $x$ is a word predicted from audio $c$ in which the word was spoken. [18]
Also in speech recognition: InfoNCE is used to maximize the mutual information between representations of the same word in different contexts. Here, $x$ is a word and $c$ is the audio context. [19]
In reinforcement learning: InfoNCE is used as a regularizer, where $x$ and $c$ are representations of the same game state at different times. Here, we want to maximize the mutual information between $x$ and $c$, which is the expected value of $\log \frac{p(x \mid c)}{p(x)}$. [16]
In computer vision: InfoNCE is used for contrastive learning, where $x$ and $c$ are different views (different crops with random augmentations like color filters and image-stretching applied) of the same image, and we again want to maximize the mutual information between $x$ and $c$. [1]
In generative adversarial networks (GANs): Local NCE is similar to the ‘discriminator’ of a GAN, except that with a GAN, the ‘generator’ $q(x)$ is learned, and with Local NCE the generator $q(x)$ is fixed throughout training. [17]
Local NCE
How can we estimate $p_{\theta}(x \mid c)$ where $c$ is some context and $x$ is one of a large number of classes? We can write the model as a normalized score:
\[p_{\theta}(x \mid c) = \dfrac{f_\theta (x, c)}{\sum_{x'} f_\theta (x', c)} = \dfrac{f_\theta (x, c)}{Z_\theta (c)}\]
where $f_\theta(x, c)$ assigns a score to $x$ in context $c$ and $Z_\theta (c)$ is the normalizing constant or partition function. $Z_\theta(c)$ is difficult to compute when there are many possible classes, because we have to sum over every possible $x'$.
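To make the cost of the partition function concrete, here is a minimal sketch of the normalization above. The function name `conditional_probs` and the toy scores are illustrative assumptions, not part of any library; the sketch also assumes the scores $f_\theta(x', c)$ are nonnegative.

```python
import numpy as np

def conditional_probs(scores):
    """Normalize nonnegative scores f_theta(x', c) into p_theta(x | c).

    `scores` holds one entry per class x', all for the same context c.
    """
    scores = np.asarray(scores, dtype=float)
    # Partition function Z_theta(c): one term per possible class x'.
    Z = scores.sum()
    return scores / Z

# Hypothetical scores for a toy "vocabulary" of 5 classes in one context.
f = np.array([2.0, 1.0, 0.5, 0.25, 0.25])
p = conditional_probs(f)
```

With five classes the sum is trivial, but it needs one score per class, so the work grows linearly with the number of classes and must be repeated for every context $c$.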
Local NCE reduces the problem of learning $p_\theta(x \mid c)$ to learning a binary classifier $p(D \mid x, c)$, where $D$ tells us whether data point $x$ is “real”, i.e. sampled from $p(x \mid c)$, rather than “noise”, sampled from some noise distribution $q(x)$. Technically, $D$ is an indicator variable where $[D=0]$ implies that $x \sim q(x)$.
In Local NCE, we sample $k$ noise samples (or negative samples) from $q(x)$ and assign them the label $D=0$. We sample one true data point (or positive sample) from $p(x \mid c)$ and assign it the label $D=1$. We can write the conditional probability of $D$ having observed $x$ and $c$:
\[\begin{align*}
p(D = 0 \mid x, c) &= \dfrac{k \cdot q(x)}{p(x \mid c) + k \cdot q(x)} \\
p(D = 1 \mid x, c) &= \dfrac{p(x \mid c)}{p(x \mid c) + k \cdot q(x)}
\end{align*}\]
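These two expressions follow from Bayes’ rule: with one positive sample for every $k$ noise samples, the prior probabilities of the labels are $\frac{1}{k+1}$ and $\frac{k}{k+1}$, and the common factor $\frac{1}{k+1}$ cancels between numerator and denominator. The sketch below evaluates both posteriors for a single candidate $x$. The function name `local_nce_posterior` and the example values are hypothetical, chosen only for illustration.

```python
import numpy as np

def local_nce_posterior(p_x_given_c, q_x, k):
    """Return (p(D=1 | x, c), p(D=0 | x, c)) for one candidate x.

    p_x_given_c: data likelihood p(x | c)
    q_x:         noise likelihood q(x)
    k:           number of noise samples drawn per real sample
    """
    noise_mass = k * q_x
    denom = p_x_given_c + noise_mass
    return p_x_given_c / denom, noise_mass / denom

# Hypothetical values: the data model gives x probability 0.2,
# the noise distribution gives it 0.05, and we draw k = 4 noise samples.
p_real, p_noise = local_nce_posterior(p_x_given_c=0.2, q_x=0.05, k=4)
```

In this example $k \cdot q(x) = 0.2$ equals $p(x \mid c)$, so the classifier cannot tell the two sources apart and both posteriors come out to one half. Increasing $k$ with everything else fixed shifts the posterior toward $D=0$.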
Now we want to write these probabilities in terms of the scoring function...