Here's a sampling trick I saw for the first time today. My normal pattern for autoregressive rollouts is something like

logits = model(idx)[:, -1]  # assuming model returns (batch, seq, vocab); keep the last position
probs = logits.softmax(dim=-1)
next_idx = torch.multinomial(probs, 1)
idx = torch.cat((idx, next_idx), dim=-1)

This is fine, but torch.multinomial incurs a CUDA sync, which we'd like to avoid when we're trying to generate quickly. There's a trick that does categorical sampling without the sync (it's mathematically equivalent to the Gumbel-max trick).

Let's say we have $V$ independent random variables $E_i\sim\text{Exp}(\lambda_i)$. We have the following fact:

$$\mathbb{P}(\text{argmin}_i E_i = k) = \frac{\lambda_k}{\sum_{j}\lambda_j}$$

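To derive this, note that the CDF of an exponential distribution is just $1 - e^{-\lambda_i x}$, so $\mathbb{P}(E_i > x) = e^{-\lambda_i x}$. The event $\text{argmin}_i E_i = k$ means $E_k$ lands at some $x$ (with density $\lambda_k e^{-\lambda_k x}$) while every other $E_j$ exceeds $x$. Integrating over $x$ and using independence gives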

$$\begin{aligned} \mathbb{P}(\text{argmin}_i E_i = k) &= \int_0^\infty \lambda_k e^{-\lambda_k x}\prod_{j=1, j\neq k}^V e^{-\lambda_j x}\,dx\\ &= \lambda_k\int_{0}^\infty \exp\left(\left(-\sum_{i=1}^V \lambda_i \right)x\right)\,dx\\ &= \frac{\lambda_k}{\sum_{i=1}^V \lambda_i} \end{aligned}$$
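As a quick sanity check of the formula above, here is a small Monte Carlo sketch using only the Python standard library. The helper name `argmin_frequencies`, the example rates, and the trial count are all made up for illustration; the only thing it relies on is that `random.expovariate(lambd)` draws from an exponential with rate `lambd`.

```python
import random

def argmin_frequencies(rates, num_trials=100_000, seed=0):
    """Estimate P(argmin_i E_i = k) for independent E_i ~ Exp(rates[i])."""
    rng = random.Random(seed)
    counts = [0] * len(rates)
    for _ in range(num_trials):
        # One independent exponential draw per rate.
        draws = [rng.expovariate(lam) for lam in rates]
        # Record which variable was the smallest in this trial.
        counts[draws.index(min(draws))] += 1
    return [c / num_trials for c in counts]

rates = [1.0, 2.0, 3.0, 4.0]  # hypothetical example rates
empirical = argmin_frequencies(rates)
expected = [lam / sum(rates) for lam in rates]
print(empirical)
print(expected)
```

The two printed lists should agree up to Monte Carlo noise, which shrinks like one over the square root of the number of trials.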

So to sample from a categorical distribution with probabilities $p_i$, we set $\lambda_i = p_i$, generate one sample from each exponential distribution, and take the argmin. Now, the other neat part is that the exponential distribution has the property that if $x\sim\text{Exp}(1)$, then $x/\lambda\sim\text{Exp}(\lambda)$. This means it's quite simple to sample from these different distributions: we draw a bunch of samples from Exp(1), then divide them by the probabilities (not the logits). Explicitly, we can do something like

w = torch.empty_like(probs).exponential_(1)  # Exp(1) noise, same shape/device as probs
next_idx = torch.argmin(w / probs, dim=-1, keepdim=True)  # w / probs ~ Exp(probs)
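To convince myself this really samples from `probs`, here's a rough empirical check. It's a sketch, not a benchmark: the wrapper name `sample_no_sync` and the toy distribution are hypothetical, and it runs on CPU, so it only tests the distribution of the samples, not the sync behavior.

```python
import torch

def sample_no_sync(probs: torch.Tensor) -> torch.Tensor:
    """Draw one index per row of `probs` (shape [..., V]) via exponential noise."""
    # Exp(1) noise with the same shape, dtype, and device as probs.
    w = torch.empty_like(probs).exponential_(1)
    # w / probs is Exp(probs) elementwise; the smallest "arrival time" wins.
    return torch.argmin(w / probs, dim=-1, keepdim=True)

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])  # toy distribution, made up for the check
samples = sample_no_sync(probs.repeat(200_000, 1))
# Empirical frequency of each index across all rows.
freqs = torch.bincount(samples.flatten(), minlength=probs.numel()) / samples.numel()
print(freqs)
```

The printed frequencies should land close to `probs`, with the usual sampling noise.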

The argmin has a nice interpretation as the "earliest arrival time": each category runs a race at a rate proportional to its probability, and the first to arrive wins. The important thing is that it avoids the CUDA sync point, which helps with fast inference. It's also a bit cheaper than the Gumbel-max trick, I think: Gumbel noise is the negative log of an Exp(1) draw, so the exponential version skips one log per element. The main reason to prefer the Gumbel formulation is when you want differentiability: you can't differentiate through the argmin, whereas the Gumbel version has a standard softmax relaxation (Gumbel-softmax). I noticed this when looking through the gpt-fast repo, full credit to them.
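The equivalence with Gumbel-max can be seen directly: if $E\sim\text{Exp}(1)$ then $-\log E$ is a standard Gumbel variable, and $\text{argmin}_i\, E_i/p_i = \text{argmax}_i\,(\log p_i - \log E_i)$. Here's a small sketch that checks this numerically by reusing the same noise for both forms. The shapes and the random test distribution are arbitrary choices for illustration, and float64 is used so that near-ties don't get broken differently by rounding.

```python
import torch

torch.manual_seed(0)
# Arbitrary batch of categorical distributions: 1000 rows, 50 categories each.
probs = torch.softmax(torch.randn(1000, 50, dtype=torch.float64), dim=-1)

e = torch.empty_like(probs).exponential_(1)  # E ~ Exp(1)
g = -torch.log(e)                            # G = -log E is standard Gumbel noise

# Exponential form: smallest scaled arrival time.
race_idx = torch.argmin(e / probs, dim=-1)
# Gumbel-max form: largest perturbed log-probability, using the same noise.
gumbel_idx = torch.argmax(torch.log(probs) + g, dim=-1)

# Fraction of rows where both forms pick the same index.
agreement = (race_idx == gumbel_idx).double().mean().item()
print(agreement)
```

Because both forms are driven by the same draws, they should pick the same index on essentially every row; any disagreement would come from floating-point ties.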