ericjwang.com
(Work in progress.)
OpenAI recently released a blog post detailing some of its correspondence with Elon Musk. Although large sections of this correspondence were redacted, observant readers realized that the lengths of the redacted words were leaked in the HTML of the censor bars.
This raises an obvious research question. Is it possible to recover a redacted text sequence given:
The text preceding and succeeding the redaction, which provides semantic and grammatical clues, and
The word lengths of the redacted text (more precisely, the total length of the redaction and the positions of the spaces within it)?
Many smart people think the answer is obvious. But they are split on whether it is obviously yes or obviously no.
On the “yes” side, we note that certain features of the obscured text (short words, long words, runs of either) greatly reduce the set of possibilities for certain segments of the text. Moreover, we know that the text concerns the financial costs of DeepMind or one of its projects; that it does not concern TensorFlow, TPUs, Brain, Research, or Cloud; and that it ends in a period. If we could account for all these “clues” simultaneously, as well as the prior of grammatical correctness, we might be able to determine the original text uniquely.
On the “no” side, we note that the first word of a sentence is often difficult to predict, as are proper names, etc. Moreover, the very fact that the text was redacted indicates that its content is worth hiding, i.e. surprising, i.e. difficult to predict.
What would it mean for the text to be recoverable? Certainly not that the text is uniquely determined by the lengths. We can go word by word and construct syntactically correct sentences satisfying the constraint, but the semantics will be highly unlikely — to say nothing of the pitfalls in numerical figures and typographical errors.
We might say that the text is recoverable if there exists a completion that seems significantly more plausible than any other, or if all “most plausible” completions are semantically similar up to details like capitalization or specific numerical figures.
The latter seems like a higher bar. Regardless, both require us to produce plausible completions, which will involve a painstaking, backtracking-like approach that is well-suited to a computer search.
Any algorithmic resolution to the problem is going to involve the use of a reasonably powerful open-source language model to construct a measure of “plausibility” that incorporates the semantic and grammatical information. Specifically:
For a given length \(N\), an autoregressive LLM defines a distribution \(p_\theta(x_{\leq N}) = \prod_{i=1}^N p_\theta(x_i \mid x_{<i})\) over token sequences \(x_{\leq N}\).
Let \(C_N\) be the event that the generated text \(x_{\leq N}\) satisfies the word-length constraint; formally, that the characters at the relevant positions are spaces.
Then we are concerned with the distribution \(p_\theta(x_{\leq N} \mid C_N)\).
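To make the event \(C_N\) concrete, here is a minimal Python sketch of the word-length check. The function names are hypothetical, and it makes one assumption beyond the text: characters at non-boundary positions must not be spaces, so that the word lengths come out exactly right.

```python
from typing import List


def space_positions(word_lengths: List[int]) -> List[int]:
    """Character indices that must hold a space, given the redacted word lengths."""
    positions, cursor = [], 0
    for length in word_lengths[:-1]:
        cursor += length          # move past the word
        positions.append(cursor)  # the next character is a word boundary
        cursor += 1               # move past the space itself
    return positions


def total_length(word_lengths: List[int]) -> int:
    """Total number of characters in the redaction, spaces included."""
    return sum(word_lengths) + len(word_lengths) - 1


def is_viable_prefix(text: str, word_lengths: List[int]) -> bool:
    """True if `text` has not yet violated the constraint (it may still be incomplete)."""
    if len(text) > total_length(word_lengths):
        return False
    spaces = set(space_positions(word_lengths))
    # Assumption: a character is a space if and only if it sits on a boundary.
    return all((ch == " ") == (i in spaces) for i, ch in enumerate(text))


def satisfies_constraint(text: str, word_lengths: List[int]) -> bool:
    """The event C_N: a complete text whose spaces fall exactly on the boundaries."""
    return len(text) == total_length(word_lengths) and is_viable_prefix(text, word_lengths)
```

For example, `satisfies_constraint("the lazy dog", [3, 4, 3])` holds, while `"the lazy kangaroo"` fails the same constraint on total length alone.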
What do we want to do with this distribution? Intuitively we want to find token sequences \(x_{\leq N}\) that maximize it. If we can produce two such sequences comparable in machine-assessed or human-assessed quality, the problem is underdetermined. However, this distribution is nasty:
For \(x_{\leq N} \in X_{C_N}\), the set of sequences that satisfy \(C_N\), we know that
\[\begin{align} p_\theta(x_{\leq N} \mid C_N) &= \frac{p_\theta(x_{\leq N}, C_N)}{p_\theta(C_N)} \\ &= \frac{p_\theta(x_{\leq N})}{\sum_{x \in X_{C_N}}p_\theta(x)} \\ &\propto p_\theta(x_{\leq N}). \end{align}\]
The denominator \(\sum_{x \in X_{C_N}}p_\theta(x)\) is intractable, so we cannot evaluate any sequence’s exact conditional probability.
More concerningly, sampling from this distribution is difficult because, unlike the unconditional distribution \(p_\theta\), we can’t factor \(p_\theta(\cdot \mid C_N)\) autoregressively. In particular, we would need to compute
\[\begin{align} p_\theta(x_i \mid x_{<i}, C_N) &= \frac{p_\theta(x_i, x_{<i}, C_N)}{p_\theta(x_{<i}, C_N)} \\ &= \frac{p_\theta(C_N \mid x_i, x_{<i})}{p_\theta(C_N \mid x_{<i})} \frac{p_\theta(x_i, x_{<i})}{p_\theta(x_{<i})} \\ &= \frac{p_\theta(C_N \mid x_i, x_{<i})}{p_\theta(C_N \mid x_{<i})} p_\theta(x_i \mid x_{<i}), \end{align}\]
but \(p_\theta(C_N \mid x_{\leq i})\) and \(p_\theta(C_N \mid x_{<i})\) and their ratio are again intractable (and telescoping them in the product doesn’t help either). Moreover, rejection sampling is infeasible here because \(p_\theta(C_N)\) is vanishingly small.
This intractability comes from an autoregressive LLM’s inability to effectively take future constraints into consideration. But if this is the case, how do constrained sampling frameworks like Guidance or Outlines work?
One good overview can be found here; but generally, these frameworks condition only on the next token not violating the constraint. That is, if we let \(C_i\) be the event that \(x_{\leq i}\) does not violate the constraint, they sample from the distribution
\[q_\theta(x_{\leq N}) \triangleq \prod_{i = 1}^N q_\theta(x_i \mid x_{< i}) \triangleq \prod_{i = 1}^N p_\theta(x_i \mid x_{< i}, C_{i}).\]
The advantage of \(q_\theta\) is that it’s a distribution over \(X_{C_N}\) that we can actually sample from. The disadvantage is that \(q_\theta(x_i \mid x_{<i})\) is also a terrible estimate of \(p_\theta(x_i \mid x_{<i}, C_N)\) in many cases. If we are looking for an eight-letter word to follow “the quick brown fox jumped over the lazy”, the token _dog will not violate the constraint, but the five-letter continuations needed to complete the word will all have higher net perplexity than _kangaroo.
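The gap between \(q_\theta\) and \(p_\theta(\cdot \mid C_N)\) can be checked by brute force on a toy model. The sketch below is purely illustrative: the four-entry "language model" and its probabilities are invented, tokens are written without the leading-space marker, and `<end>` is a hypothetical end-of-word token. It enumerates every sequence to get the true conditional, then compares it with greedy per-token masking.

```python
# Hypothetical toy model: maps a tuple of tokens to next-token probabilities.
TOY_MODEL = {
    (): {"dog": 0.9, "kangaroo": 0.1},
    ("dog",): {"<end>": 0.99, "house": 0.01},
    ("dog", "house"): {"<end>": 1.0},
    ("kangaroo",): {"<end>": 1.0},
}
TARGET_LEN = 8  # we want an eight-letter word


def word(tokens):
    return "".join(t for t in tokens if t != "<end>")


def viable(tokens):
    """C_i: the sequence so far has not violated the length constraint."""
    if tokens and tokens[-1] == "<end>":
        return len(word(tokens)) == TARGET_LEN
    return len(word(tokens)) <= TARGET_LEN


def enumerate_complete(prefix=(), prob=1.0):
    """Yield (tokens, p_theta(tokens)) for every finished sequence."""
    if prefix and prefix[-1] == "<end>":
        yield prefix, prob
        return
    for tok, p in TOY_MODEL[prefix].items():
        yield from enumerate_complete(prefix + (tok,), prob * p)


def true_conditional():
    """p_theta(. | C_N), by exhaustive enumeration and renormalization."""
    ok = {word(t): p for t, p in enumerate_complete() if viable(t)}
    z = sum(ok.values())
    return {w: p / z for w, p in ok.items()}


def greedy_masked(prefix=(), prob=1.0, out=None):
    """q_theta: renormalize over non-violating tokens at every step."""
    out = {} if out is None else out
    if prefix and prefix[-1] == "<end>":
        out[word(prefix)] = out.get(word(prefix), 0.0) + prob
        return out
    allowed = {t: p for t, p in TOY_MODEL[prefix].items() if viable(prefix + (t,))}
    z = sum(allowed.values())
    for tok, p in allowed.items():
        greedy_masked(prefix + (tok,), prob * p / z, out)
    return out
```

With these made-up numbers, the true conditional puts roughly 92% of its mass on "kangaroo", whereas greedy masking puts 90% on "doghouse": it commits to the likely first token and is then forced into an unlikely completion.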
Let’s reframe the problem slightly. What we want is to find \(x_{\leq N}\) that minimizes the following “surprisal” function:
\[f_\theta(x_{\leq N}) = \begin{cases} - \log p_\theta(x_{\leq N}) & \text{if $x_{\leq N} \in X_{C_N}$} \\ +\infty & \text{otherwise.} \end{cases}\]
But \(p_\theta(x_{\leq N})\) still factors:
\[\log p_\theta(x_{\leq N}) = \sum_{i=1}^N \log p_\theta(x_i \mid x_{<i})\]
So getting relatively low-surprisal samples is a tree search problem:
Each node on the tree corresponds to a sequence \(x_{<i}\) that has not violated the constraint \(C_N\). (You can also imagine each node with the same number of children and some infinite weights.)
Each node has up to \(n_{vocab}\approx 10^5\) children, corresponding to continuations that also do not violate \(C_N\). This is an enormous branching factor.
Each edge has a nonnegative weight, i.e. the conditional negative log-probability \(- \log p_\theta(x_i \mid x_{<i})\) or “partial surprisal.” (The edge weights are NOT the greedily-conditional negative log-probability \(- \log p_\theta(x_i \mid x_{<i}, C_i)\) from above because \(q_\theta \not \propto p_\theta\).)
Our goal is to find an approximate (time-bounded) algorithm that discovers paths from the root node to the goal nodes with a low total cost/distance/surprisal
\[\sum_{i=1}^N - \log p_\theta(x_i \mid x_{<i}) = -\log p_\theta(x_{\leq N}) = -\log p_\theta(x_{\leq N} \mid C_N) - \log p_\theta(C_N),\]
which, for sequences in \(X_{C_N}\), differs from \(-\log p_\theta(x_{\leq N} \mid C_N)\) only by a constant, so minimizing total surprisal maximizes the conditional probability. Time is clearly the constraint here because each forward evaluation of \(p_\theta\) on my RTX 4090 takes a few hundred milliseconds.
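The tree search described above can be sketched with a priority queue keyed on cumulative surprisal. This is a minimal illustration under stated assumptions, not the code used for the experiments: `next_token_probs` stands in for a forward pass of \(p_\theta\), and the toy model, `viable`, and `goal` below are invented for a two-word constraint with lengths 3 and 3.

```python
import heapq
import math
from typing import Callable, Dict, List, Tuple


def best_first_search(
    next_token_probs: Callable[[Tuple[str, ...]], Dict[str, float]],
    is_viable: Callable[[str], bool],
    is_goal: Callable[[str], bool],
    max_expansions: int = 1000,
) -> List[Tuple[float, str]]:
    """Return (surprisal, text) for goal nodes, in the order they are discovered."""
    frontier = [(0.0, ())]  # entries are (cumulative surprisal, token tuple)
    found = []
    for _ in range(max_expansions):  # time-bounded: stop after a fixed budget
        if not frontier:
            break
        cost, tokens = heapq.heappop(frontier)
        text = "".join(tokens)
        if is_goal(text):
            found.append((cost, text))
            continue
        for tok, p in next_token_probs(tokens).items():
            # Children that violate the constraint are pruned (infinite weight).
            if p > 0 and is_viable(text + tok):
                heapq.heappush(frontier, (cost - math.log(p), tokens + (tok,)))
    return found


# Invented toy model and constraint: two words of three letters each.
def toy_probs(tokens):
    if not tokens:
        return {"the": 0.6, "a": 0.4}
    return {" dog": 0.5, " cat": 0.3, " elephant": 0.2}


TOTAL, SPACES = 7, {3}


def viable(text):
    return len(text) <= TOTAL and all((c == " ") == (i in SPACES) for i, c in enumerate(text))


def goal(text):
    return len(text) == TOTAL and viable(text)


results = best_first_search(toy_probs, viable, goal)
```

Because every edge weight is nonnegative, goals come off the queue in nondecreasing order of surprisal, so `results` lists "the dog" before "the cat". The catch, as the rest of the post discusses, is that this ordering explores far too many shallow nodes when the tree is deep.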
How does this relate to the existing literature for language model decoding strategies? Lilian Weng has a good overview of that literature in her blog post on Controllable Neural Text Generation. One takeaway is that a good amount of work on decoding strategies pre-GPT-3.5 was concerned with the unreliability of \(p_\theta\), which seems less relevant today now that LLMs are less prone to mode collapse or strange outliers. For instance, Meister et al. (2020) propose smoothing “overfit” logits with a regularization term, but we’ve already been trusting our \(p_\theta\) enough to use it as a \(p\) this whole time.
Speaking of which — we have made no assumptions about \(p_\theta\) so far, but the choice of search algorithm is going to come down to the properties of this distribution.
For instance, if edge weights were drawn i.i.d. from a distribution with mean \(\mu_f\), and we expected \(\ell\) characters per token, a reasonable A* heuristic would be \(r\mu_f/\ell\), where \(r\) is the number of characters remaining in the constraint.
We could also take subword properties into account—for instance, if the surprisal of a word tends to cluster in its first token, we could use speculative or Medusa decoding as a shortcut to avoid costly exploration of subword tokens. Weaker heuristic refinements can also come from prior estimates of the distribution of token splits within words of specific lengths.
We could account for nodes whose descendants consistently display a low \(p(C_i \mid x_{<i})\) with some kind of MCTS-style backpropagation.
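As a worked instance of the first heuristic above, the sketch below computes the A* priority \(g + r\mu_f/\ell\) for two hypothetical frontier nodes. The function names and all numbers are made up for illustration; note also that this estimate is not guaranteed to be admissible, so A* with it may not return the optimal path.

```python
def remaining_cost_estimate(chars_remaining: int, mean_token_surprisal: float,
                            chars_per_token: float) -> float:
    """Heuristic h = r * mu_f / l: expected surprisal of the text still to be generated."""
    return chars_remaining * mean_token_surprisal / chars_per_token


def astar_priority(cost_so_far: float, chars_remaining: int,
                   mean_token_surprisal: float, chars_per_token: float) -> float:
    """A* priority f = g + h, where g is the running surprisal of the node."""
    return cost_so_far + remaining_cost_estimate(
        chars_remaining, mean_token_surprisal, chars_per_token
    )


# Hypothetical values: 2.0 nats per token on average, 4 characters per token.
MU_F, ELL = 2.0, 4.0
shallow = astar_priority(cost_so_far=10.0, chars_remaining=40,
                         mean_token_surprisal=MU_F, chars_per_token=ELL)
deep = astar_priority(cost_so_far=14.0, chars_remaining=20,
                      mean_token_surprisal=MU_F, chars_per_token=ELL)
```

Here the deeper node has the higher running surprisal (14 against 10 nats) but the lower priority value (24 against 30), so it would be expanded first. That is the intended effect: the heuristic stops the search from lingering near the root.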
At this point we have the option of searching over a tree of tokens or of words. I prefer searching over words because we can decompose the task into two separate search problems for which we can define different heuristics.
What this means is that we implement a next-word subroutine \(NextWord(prompt, \ell)\) which outputs a non-exhaustive list of words \(w_i\) of length \(\ell\), each often consisting of multiple tokens, alongside the surprisals \(s_i = -\log p_\theta(w_i \mid prompt)\). These words and surprisals define the children and edge distances of \(prompt\) in the search tree.
We should take care when constructing this subroutine to define the probability of a word carefully. The LLaMA and Mistral tokenizers include spaces at the start of a token, which means that the probability that the first word of our completion is “deepmind's” is actually something like the probability of its tokens multiplied by the probability that the token after them begins with a space.
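One way to write this down, offered as a sketch rather than the exact expression used here: suppose the word \(w\) tokenizes as \(t_1, \dots, t_k\), and let \(V_{\sqcup}\) denote the set of vocabulary tokens that begin with a space (both symbols are introduced only for illustration). Then
\[p_\theta(w \mid \text{prompt}) \approx \left[\prod_{j=1}^{k} p_\theta(t_j \mid \text{prompt}, t_{<j})\right] \cdot \sum_{t \in V_{\sqcup}} p_\theta(t \mid \text{prompt}, t_{\leq k}).\]
The first factor is the probability of emitting the word’s tokens; the second is the probability that the word actually ends there instead of continuing into a longer word. This is only approximate because the same string can have several tokenizations, and because other tokens (end of sequence, for example) can also close a word.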
Thus, the subroutine can be implemented as a breadth-first search over token sequences with the following adjustments:
We can see the results of one run of this subroutine below:
The final entry here is instructive: the model thought an ArXiv link was a plausible continuation, and indeed https://ar is ten characters long. But the likelihood of a space after these ten characters is so vanishingly small that the word has over twice the surprisal of any of the others returned.
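The behavior of the next-word subroutine, including the word-boundary term that penalizes the link fragment, can be imitated on a toy model. The sketch below is an assumption-laden illustration and not the implementation used here: `token_probs` stands in for the LLM, the toy probabilities are invented, and tokens that begin with a space are taken to start a new word.

```python
import heapq
import math


def next_word(token_probs, prompt, length, max_words=5, max_expansions=1000):
    """Return [(word, surprisal)] for words of exactly `length` characters.

    `token_probs(tokens)` is a stand-in for the model: it maps a tuple of tokens
    to a dict of next-token probabilities.
    """
    frontier = [(0.0, ())]  # (surprisal of the tokens so far, token tuple)
    results, expansions = [], 0
    while frontier and len(results) < max_words and expansions < max_expansions:
        cost, toks = heapq.heappop(frontier)
        expansions += 1
        word = "".join(toks).lstrip(" ")
        probs = token_probs(prompt + toks)
        if len(word) == length:
            # The word only counts if the next token starts a new word.
            boundary = sum(p for t, p in probs.items() if t.startswith(" "))
            if boundary > 0:
                results.append((word, cost - math.log(boundary)))
            continue
        for tok, p in probs.items():
            if p <= 0:
                continue
            # The first token must open the word; later tokens must continue it.
            if tok.startswith(" ") != (not toks):
                continue
            if len(word) + len(tok.lstrip(" ")) <= length:
                heapq.heappush(frontier, (cost - math.log(p), toks + (tok,)))
    return sorted(results, key=lambda r: r[1])


# Invented toy model; only the last token matters.
def toy_token_probs(tokens):
    table = {
        None: {" deep": 0.5, " https://ar": 0.3, " the": 0.2},
        " deep": {"mind's": 0.6, " learning": 0.4},
        "mind's": {" costs": 0.9, "x": 0.1},
        " https://ar": {"xiv": 0.999, " and": 0.001},
    }
    last = tokens[-1] if tokens else None
    return table.get(last, {" the": 1.0})


candidates = next_word(toy_token_probs, (), 10)
```

In this toy, both ten-character candidates have the same token probability (0.3), but the link fragment is almost never followed by a space, so its total surprisal is several times higher.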
Our top-level search problem is as follows.
What function do we use to assign priorities to word sequences? Using the running surprisal alone amounts to uniform-cost search, which behaves much like breadth-first search here; that is good enough for discovering words but certainly not up to the task of exploring a tree with depth 48 and a moderate branching factor.
Incorporating the heuristics above (A* search) seems like a good way to balance exploitation and exploration here. However, tuning these heuristics is difficult in practice — my experiments show that wordcount-based, character-count-based, and position-based heuristics often become too greedy later in the sequence even when they seem almost breadth-first near the start.
We run our experiments on my RTX 4090 with Mistral-7B on HuggingFace Transformers (mostly for convenience). Limiting the search space via old-fashioned beam search gets us closer to a coherent response but also becomes quite greedy due to the branching factor: with a beam size of 100, we quickly end up discarding all possibilities aside from something like:
working at the cutting edge of ai is unfortunately expensive. for example, developing something state-of-the-art in deep nets is mostly about the very large data and compute demands, which means [sic] running large neural net models and collecting very large datasets. for these purposes, large data centers and large google compute engine bills are required, with the hardware and data costs
(This was sampled from an earlier version of the code that didn’t constrain the last character to be a period.)
A completion like this may seem superficially plausible, although it still has some glaring errors: repetitive wording, uninformative content that doesn’t need to be redacted, no mention of DeepMind, and an anomalous reference to GCE.
Regardless, my thinking is that to make further progress we need to manage the exploitation-exploration tradeoff directly. We may need to enforce a certain amount of diversity in our exploration, in terms of explored sequence length (multiple queues for each length, with a scheduler that determines which queue to explore from) and/or set of words chosen (contrastive sampling).
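One possible shape for the per-length queues is sketched below. It is only an illustration of the idea: the class name is hypothetical, and the round-robin scheduler is an assumption chosen for simplicity, not something the experiments above have validated.

```python
import heapq
from collections import defaultdict


class LengthBucketedFrontier:
    """One priority queue per sequence length, served round-robin."""

    def __init__(self):
        self._queues = defaultdict(list)  # number of words -> heap of (surprisal, words)
        self._last_depth = -1

    def push(self, surprisal, words):
        heapq.heappush(self._queues[len(words)], (surprisal, tuple(words)))

    def pop(self):
        """Pick the next non-empty length after the last one served, wrapping around."""
        depths = sorted(d for d, q in self._queues.items() if q)
        if not depths:
            raise IndexError("pop from an empty frontier")
        later = [d for d in depths if d > self._last_depth]
        depth = later[0] if later else depths[0]
        self._last_depth = depth
        return heapq.heappop(self._queues[depth])


frontier = LengthBucketedFrontier()
frontier.push(1.0, ("a",))
frontier.push(5.0, ("a", "b"))
frontier.push(2.0, ("c",))
frontier.push(0.5, ("d",))
order = [frontier.pop() for _ in range(4)]
```

A single global heap would serve these entries as 0.5, 1.0, 2.0, 5.0. The bucketed frontier serves 0.5, 5.0, 1.0, 2.0, so the two-word sequence gets explored even though shorter sequences have lower running surprisal.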
Another approach is to fine-tune a model to accept conditional generation tokens:
Add tokens <l1>, <l2>, …, <l20> to the tokenizer, representing word lengths 1-20. Place these tokens before the beginning of each word.
Finally, a good number of performance optimizations are available to make the model run faster. We can squeeze out some small performance gains by increasing the model’s batch size. Moreover, because the tree structure of our token sequences is well understood, reusing the “common ancestor” KV cache between evaluations and across batches is another obvious optimization.
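Returning to the length-token idea, the data preparation step might look like the sketch below. The helper name is hypothetical and the splitting rule (single spaces, punctuation counted as part of the word) is an assumption. With HuggingFace Transformers, the new tokens would presumably be registered with `tokenizer.add_tokens(...)` followed by `model.resize_token_embeddings(len(tokenizer))` before fine-tuning; that step is not shown here.

```python
MAX_WORD_LEN = 20
LENGTH_TOKENS = [f"<l{n}>" for n in range(1, MAX_WORD_LEN + 1)]


def add_length_tokens(text: str, max_len: int = MAX_WORD_LEN) -> str:
    """Prefix every word with a token announcing its length in characters."""
    pieces = []
    for word in text.split(" "):
        if not 1 <= len(word) <= max_len:
            raise ValueError(f"word length {len(word)} is outside 1-{max_len}: {word!r}")
        pieces.append(f"<l{len(word)}>{word}")
    return " ".join(pieces)


example = add_length_tokens("working at the cutting edge")
```

At inference time, the known redaction lengths would supply the length tokens and the fine-tuned model would only have to fill in the words.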
Work continues in fits and starts. As of this writing, I’m running another beam search with larger beam size and smaller branching factor.