Building a Tokenizer from Scratch
In learning how LLMs work, flashy ideas like attention and transformers get all of the, well, attention. Arguably, tokenization is the unsung hero that is critical to actually training large language models. Without a way of representing text in a format digestible to neural networks, no training or inference can occur. Therefore, if I want to truly develop a deep understanding of large language models, I need to understand tokenization at a deeper level than simply importing tiktoken. So I set out to implement my own tokenizer, leaning heavily on the supplementary material from Sebastian Raschka.
I have a lot of respect for Sebastian Raschka, and I understand that he has no obligation to provide such a wealth of supplementary information that falls outside the scope of the book. While I am grateful for all of these notebooks, I must admit that the notebook on training a tokenizer from scratch is a bit rough. First, I find the language confusing. If a “tokenizer” produces a series of tokens from text during encoding, it seems that the word “token” should refer to the integer representation of one or many characters. So I find the use of “token” to mean a string, followed by the unnecessary introduction of “token_id”, conceptually cumbersome. Next, the implementation of the tokenizer is hard to follow due to the lack of useful abstractions and the reliance on long, complicated functions. Walking through the code line by line to build up an understanding of the logic is challenging, enough so that I suspect most readers skip this notebook entirely because of the mental energy required.
This final claim is supported by the fact that I did work through the code line by line and discovered a long-standing bug in the implementation, which I fixed. However, after understanding the code, I knew there was more that could be done than a simple bug fix. I believed that I could introduce helpful abstractions to greatly clean up the implementation and significantly reduce the cognitive load of understanding all of the relevant concepts for training and using a tokenizer.
UML
After a bit of thought, my implementation converged on the following organization:
Module-Level View
Architecture view of the tokenization module.
I created several core domain modules:
- tokenization.text_segment: Represents units of text to be tokenized.
- tokenization.token_pair: A simple container for a pair of tokens.
- tokenization.token_merges: Handles the accounting for identifying which tokens to merge using byte-pair encoding.
- tokenization.vocabulary: Encapsulates the token vocabulary and look-up logic.
These are then used by tokenization.tokenizer, which acts as the orchestration layer. Note that the dependency arrows reveal that this structure allows the tokenizer to drive the process while leveraging the well-defined, single-responsibility modules listed above.
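For orientation, this corresponds to a package layout roughly like the following; the file names are my assumption, inferred from the import statements in the implementation later in this post.

# tokenization/
#     text_segment.py   -> TextSegment, form_text_segments
#     token_pair.py     -> TokenPair
#     token_merges.py   -> TokenMerges, identify_byte_pair_encoding_merges
#     vocabulary.py     -> Vocabulary, initialize_vocabulary,
#                          extract_preprocessed_characters
#     tokenizer.py      -> Tokenizer (the orchestration layer)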
Object-Level View
Class-level view of the tokenization module.
Here we see that Tokenizer is the central coordinator, delegating most of the work to specialized collaborator classes. It exposes the main public API for tokenization:
- encode: Converting text into tokens.
- decode: Reconstructing text from tokens.
- train: Learning the vocabulary and merge rules from a corpus.
The two substantial domain classes are:
- Vocabulary
- TokenMerges
And there are supporting value objects for small, reusable abstractions:
- TextSegment
- TokenPair
Collectively, this organization of classes provides a narrow, well-defined public interface through Tokenizer. We use delegation over inheritance, in which Tokenizer orchestrates the use of the single-responsibility classes.
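Since the class diagram itself does not reproduce well in prose, here is a rough skeleton of the public surface it describes. The class, function, and parameter names are taken from the implementation later in this post wherever they appear there; anything not visible in that code (for example, the TokenPair helpers and some parameter names and type hints) is my guess at the shape rather than the actual signature, and all bodies are elided.

from typing import List, Sequence


class TextSegment:
    """A run of either standard or special characters."""
    def encode(self, vocabulary: "Vocabulary",
               token_merges: "TokenMerges") -> List[int]: ...

def form_text_segments(
        text: str, special_token_texts: Sequence[str]) -> List[TextSegment]: ...


class TokenPair:
    """Value object holding two tokens, with match/text helpers (names guessed)."""


class Vocabulary:
    def tokenize_preprocessed_characters(
            self, chars: Sequence[str]) -> List[int]: ...
    def get_text(self, token: int, raw: bool) -> str: ...
    def __len__(self) -> int: ...

def extract_preprocessed_characters(text: str) -> List[str]: ...

def initialize_vocabulary(
        chars: Sequence[str], special_token_texts: Sequence[str]) -> Vocabulary: ...


class TokenMerges:
    # ...plus a method that applies the merges to a segment's tokens
    # (its name is not shown in this post).
    def update_vocabulary(self, vocabulary: Vocabulary) -> None: ...

def identify_byte_pair_encoding_merges(
        tokens: List[int], current_vocab_size: int,
        vocab_size: int) -> TokenMerges: ...

# ...and Tokenizer itself, whose full implementation appears below.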
Description
For readers less familiar with UML, let me walk through the relevant classes and their relationships at a high level, covering the main public functions associated with each:
- The first class is TextSegment, which represents an arbitrary section of text from a corpus. This class is designed to be composed exclusively of “standard” characters or “special” characters, and it has a single public method for encoding this text, which is the single responsibility of the class. The module also has a public function, form_text_segments, which takes an entire corpus of text and breaks it up into a sequence of TextSegment instances.
- The next abstraction comes from Vocabulary, which contains the mapping from text to tokens and vice versa. It is produced by the initialize_vocabulary function, which takes a sequence of preprocessed characters and a sequence of special characters as inputs. Then, for each character that exists, the mappings are updated to contain both token to character and character to token.
- There is a simple dataclass named TokenPair containing some utility functions for a pair of tokens. It has a public method to determine if the tokens match and also a method to form the text representation of this pair of tokens given an existing instance of Vocabulary.
- In identify_byte_pair_encoding_merges, a loop repeatedly identifies the most frequent token pair and selects it to be replaced by a new token. This returns an instance of TokenMerges, which is initialized with a merge dictionary containing the pairs of tokens to be merged along with the resulting tokens. The first public method of this class updates an existing instance of Vocabulary given these merges, and the second tokenizes a segment of text. (A sketch of this merge-identification loop appears just after this list.)
- With all of these abstractions, the implementation of the Tokenizer is very straightforward. For training, the characters in the training text are preprocessed before forming the vocabulary and corresponding tokens. Then, the token merges from byte pair encoding are identified, and the vocabulary is updated to reflect these merges. For encoding a new text string, the text is split up into disjoint text segments, each of which contains the necessary logic to encode itself. Then the resulting tokens are combined into a single list that is returned. And for decoding, the vocabulary is used to map each token to text, and this sequence of strings is concatenated into a single string.
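To make the byte-pair-encoding step concrete, here is a minimal, standalone sketch of the merge-identification loop. This is not the actual identify_byte_pair_encoding_merges / TokenMerges implementation; the function name identify_merges and the plain dictionary it returns are simplifications of my own, and the real code also tracks the information needed to update the Vocabulary. The core idea is simply: count adjacent token pairs, replace the most frequent pair with a freshly allocated token, and repeat.

from collections import Counter
from typing import Dict, List, Tuple


def identify_merges(
        tokens: List[int], next_token: int, num_merges: int
) -> Dict[Tuple[int, int], int]:
    """Greedily replaces the most frequent adjacent token pair with a new token."""
    merges: Dict[Tuple[int, int], int] = {}
    for _ in range(num_merges):
        # Count every adjacent pair in the current token stream.
        pair_counts = Counter(zip(tokens, tokens[1:]))
        if not pair_counts:
            break
        best_pair, _ = pair_counts.most_common(1)[0]
        merges[best_pair] = next_token
        # Rewrite the stream, replacing each occurrence of best_pair.
        merged: List[int] = []
        i = 0
        while i < len(tokens):
            if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == best_pair:
                merged.append(next_token)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
        next_token += 1
    return merges


# For example, identify_merges([1, 2, 1, 2, 3], next_token=4, num_merges=1)
# returns {(1, 2): 4}: the most frequent pair (1, 2) becomes the new token 4.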
Tokenizer Implementation
With this understanding in place, we arrive at a very human-digestible implementation for the tokenizer:
from typing import List, Optional, Sequence

from tokenization import text_segment as text_segment_module
from tokenization import token_merges as token_merges_module
from tokenization import vocabulary as vocabulary_module


class Tokenizer:

    def __init__(self):
        self._vocabulary: Optional[vocabulary_module.Vocabulary] = None
        self._token_merges: Optional[token_merges_module.TokenMerges] = None

    def train(
            self,
            text: str,
            vocab_size: int,
            special_token_texts: Sequence[str] = ["<|endoftext|>"]) -> None:
        preprocessed_chars = (
            vocabulary_module.extract_preprocessed_characters(text))
        if self._vocabulary is None:
            self._vocabulary = vocabulary_module.initialize_vocabulary(
                preprocessed_chars, special_token_texts)
        tokens = self._vocabulary.tokenize_preprocessed_characters(
            preprocessed_chars)
        if self._token_merges is None:
            self._token_merges = (
                token_merges_module.identify_byte_pair_encoding_merges(
                    tokens, len(self._vocabulary), vocab_size))
        self._token_merges.update_vocabulary(self._vocabulary)

    def encode(
            self, text: str, special_token_texts: Sequence[str]) -> List[int]:
        text_segments = text_segment_module.form_text_segments(
            text, special_token_texts)
        tokens: List[int] = []
        for text_segment in text_segments:
            tokens += text_segment.encode(self._vocabulary, self._token_merges)
        return tokens

    def decode(self, tokens: Sequence[int]) -> str:
        token_texts = [
            self._vocabulary.get_text(token, raw=False) for token in tokens]
        return ''.join(token_texts)
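With the class in hand, basic usage looks something like the sketch below. The corpus, vocabulary size, and sample text are placeholders of my own; the exact token ids depend on the training text, and how closely decode round-trips the input depends on the preprocessing choices.

# A toy corpus; any text works, but the learned merges depend on it.
corpus = "the quick brown fox jumps over the lazy dog. " * 100

tokenizer = Tokenizer()
tokenizer.train(corpus, vocab_size=100)

tokens = tokenizer.encode(
    "the quick fox<|endoftext|>", special_token_texts=["<|endoftext|>"])
print(tokens)                    # a list of integer token ids
print(tokenizer.decode(tokens))  # text reconstructed from those tokens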
While I find this implementation very elegant and the single responsibility of each class satisfying, my hope is that my tokenizer can be a useful learning tool for others. I believe it provides a much better introduction for anyone looking to deeply understand byte pair encoding and tokenization than other available resources.