1. Overview
Learning the GPT tokenizer by following Andrej Karpathy.
2. Jupyter Notebook
1. Overview¶
This follows Andrej Karpathy's tutorial throughout: https://www.youtube.com/watch?v=zduSFxRajkE&t=5600s
Mapping every character directly to an int produces a very large number of tokens; to make good use of the model's context length, the tokens should be packed to some degree.
The simplest example: 'was' and 'Harry' are very common words, and when we read them we treat them as single units rather than as the individual letters 'w', 'a', 's', 'H', and so on, which wastes capacity.
So we need better ways to encode and decode.
Coming from coding theory, the first idea that comes to mind is Huffman coding, but Huffman is designed to compress the encoded bits, not to pack up words like 'Harry'.
GPT uses BPE (Byte Pair Encoding): https://en.wikipedia.org/wiki/Byte_pair_encoding
It is an easy algorithm to understand, it captures the syntactic and semantic structure of language reasonably well, and it still provides a moderate amount of compression.
2. From Characters to Numbers¶
Before compressing, the characters to be processed have to be converted to numbers (and it must be possible to convert them back).
2.1 Unicode Code Points¶
Unicode wiki: https://zh.wikipedia.org/wiki/Unicode#%E7%BC%96%E7%A0%81%E6%96%B9%E5%BC%8F Unicode assigns a unique code point to essentially every character and symbol a computer may need, roughly 150,000 in total.
# look up a Unicode code point
ord("我")
25105
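The mapping also works in the other direction: chr() turns a code point back into its character (added here as a quick illustration).
# the inverse lookup: code point -> character
chr(25105)
'我'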
2.2 Encoding Strings to Bytes¶
Because different characters' Unicode code points need different numbers of bits, there are several ways to actually encode them as bytes.
- UTF-8: uses 1 byte wherever 1 byte is enough, extending to 2 bytes when needed, up to at most 4 bytes.
- UTF-16: uses at least 2 bytes per character, padding with 0 when less would suffice; characters that do not fit in 2 bytes take 4 bytes (a surrogate pair).
- UTF-32: every character is encoded with 4 bytes.
UTF-8 is clearly the most flexible and space-efficient, though a bit more work to parse. UTF-32 is simple to parse but wastes space. UTF-16 is a compromise that, to me, gets the worst of both worlds.
GPT uses UTF-8.
# UTF-8 encoding
list("我a,3".encode("utf-8"))
[230, 136, 145, 97, 44, 51]
# UTF-16 encoding (the leading 255, 254 is the byte order mark; note the zero padding)
list("我a,3".encode("utf-16"))
[255, 254, 17, 98, 97, 0, 44, 0, 51, 0]
# UTF-32 encoding (a leading BOM again, and even more zero padding)
list("我a,3".encode("utf-32"))
[255, 254, 0, 0, 17, 98, 0, 0, 97, 0, 0, 0, 44, 0, 0, 0, 51, 0, 0, 0]
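The byte sequences can also be decoded back into the original string; as a quick added check, feeding the UTF-8 bytes from above back through decode:
# decode the UTF-8 bytes back into a string
bytes([230, 136, 145, 97, 44, 51]).decode("utf-8")
'我a,3'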
3. The BPE Algorithm¶
BPE wiki: https://en.wikipedia.org/wiki/Byte_pair_encoding
The idea in brief: repeatedly replace the most frequent pair of adjacent symbols in the sequence with a new symbol, and keep going until you are satisfied.
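As a small added illustration, here is the classic example from the Wikipedia article, applying the merges by hand with str.replace (which happens to match the BPE merges for this toy string):
# the Wikipedia example: compress "aaabdaaabac" with three merges
s = "aaabdaaabac"
s = s.replace("aa", "Z")  # most frequent pair "aa" -> new symbol "Z", giving "ZabdZabac"
s = s.replace("ab", "Y")  # "ab" is now (jointly) most frequent -> new symbol "Y", giving "ZYdZYac"
s = s.replace("ZY", "X")  # most frequent pair "ZY" -> new symbol "X", giving "XdXac"
s
'XdXac'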
3.1 A First Try¶
# use a passage from the BPE Wikipedia article as the demo text
text = 'Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial "tokens"). Then, successively, the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set characters.[8] This algorithmic approach has been extended from spoken language to sign language in recent years.[9]\nAll the unique tokens found in a corpus are listed in a token vocabulary, the size of which, in the case of GPT-3.5 and GPT-4, is 100256.\nThe difference between the modified and the original algorithm is that the original algorithm does not merge the most frequent pair of bytes of data, but replaces them with a new byte that was not contained in the initial dataset. A lookup table of the replacements is required to rebuild the initial dataset. The algorithm is effective for tokenization because it has low computational overhead and remains consistent and reliable.'
# first convert the text to its UTF-8 encoding (a sequence of ints in 0..255)
text_codes = text.encode("utf-8")
# count how many times each pair occurs
# stored in pair_counts: the key is the pair, the value is the count, e.g. {(123, 333): 1}
pair_counts = {}
for pair in zip(text_codes, text_codes[1:]): # iterate over all consecutive pairs
    pair_counts[pair] = pair_counts.get(pair, 0) + 1 # count occurrences
from itertools import islice
# inspect pair_counts
# print the first 20 key-value pairs
for key, value in islice(pair_counts.items(), 20):
print(f"{key}: {value}")
(66, 121): 1
(121, 116): 3
(116, 101): 17
(101, 32): 49
(32, 112): 6
(112, 97): 4
(97, 105): 7
(105, 114): 7
(114, 32): 12
(32, 101): 6
(101, 110): 25
(110, 99): 8
(99, 111): 12
(111, 100): 10
(100, 105): 10
(105, 110): 33
(110, 103): 15
(103, 91): 1
(91, 49): 1
(49, 93): 1
# use max with key=pair_counts.get to find the key with the largest value
max_key = max(pair_counts, key=pair_counts.get)
max_key, pair_counts[max_key]
((101, 32), 49)
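As a quick added check, that winning pair is just the letter 'e' followed by a space:
# the most frequent pair as bytes
bytes(max_key)
b'e '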
4. Full BPE¶
4.1 Preparing the Basic Helper Functions¶
def get_stats(ids, counts=None):
"""
获取ids中各个pair的出现次数
Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
"""
counts = {} if counts is None else counts
for pair in zip(ids, ids[1:]): # iterate consecutive elements
counts[pair] = counts.get(pair, 0) + 1
return counts
# the merge (replacement) function
def merge(ids, pair, idx):
"""
将 ids 序列中的 pair 替换为 idx
Example: ids=[1, 2, 3, 1, 2], pair=(1, 2), idx=4 -> [4, 3, 4]
"""
newids = []
i = 0
while i < len(ids):
# if not at the very last position AND the pair matches, replace it
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i + 1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
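A quick added sanity check that the two helpers reproduce the examples from their docstrings:
# both asserts should pass silently
assert get_stats([1, 2, 3, 1, 2]) == {(1, 2): 2, (2, 3): 1, (3, 1): 1}
assert merge([1, 2, 3, 1, 2], (1, 2), 4) == [4, 3, 4]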
# from raw text to token ids
def encode(original_text, merges):
    # convert the raw characters into token ids
    text_bytes = original_text.encode("utf-8") # raw bytes
    ids = list(text_bytes) # list of integers in range 0..255
    while len(ids) >= 2:
        # pick the next pair to merge, in the same order the merges were learned during training
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        # pairs not in merges map to inf (lowest priority for merging)
        # but if every pair maps to inf, min simply returns the first pair, so check explicitly
        if pair not in merges:
            break # nothing left to merge
        # merge the earliest-learned pair
        idx = merges[pair]
        ids = merge(ids, pair, idx)
    return ids
# from token ids back to the original text
def decode(ids, vocab):
    # given ids (a list of integers), return a Python string
    text_bytes = b"".join(vocab[idx] for idx in ids)
    # if the bytes are not valid UTF-8 (e.g. 128 can never be the first byte of a UTF-8 character), substitute the replacement character
    text = text_bytes.decode("utf-8", errors="replace")
    return text
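Before training, a small added round-trip check with an empty merge table and the base 256-entry byte vocabulary: encode is then just UTF-8, and decode should invert it exactly.
# round trip with no merges learned yet
base_vocab = {idx: bytes([idx]) for idx in range(256)}
decode(encode("我a,3", {}), base_vocab)
'我a,3'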
4.2 A Trial BPE Training Run¶
# hyperparameter
vocab_size = 265
# prepare the training text
train_text = text
text_bytes = train_text.encode("utf-8") # raw bytes
# training objects (vocabulary and merge table)
# keep all 256 base byte values, so that any future input, including characters from other scripts such as Chinese, can still be represented
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
merges = {} # (int, int) -> int
# prepare the id sequence
ids = list(text_bytes) # list of integers in range 0..255
# target number of merges
num_merges = vocab_size - 256
for i in range(num_merges):
    # count current pair frequencies
    stats = get_stats(ids)
    # find the most frequent pair
    pair = max(stats, key=stats.get)
    # assign it a new token id
    idx = 256 + i
    # replace the pair everywhere in the sequence
    ids = merge(ids, pair, idx)
    # record the merge
    merges[pair] = idx
    # record the bytes the new id stands for
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    # prints
    print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
merge 1/9: (101, 32) -> 256 (b'e ') had 49 occurrences
merge 2/9: (115, 32) -> 257 (b's ') had 39 occurrences
merge 3/9: (105, 110) -> 258 (b'in') had 33 occurrences
merge 4/9: (116, 104) -> 259 (b'th') had 31 occurrences
merge 5/9: (101, 110) -> 260 (b'en') had 25 occurrences
merge 6/9: (32, 97) -> 261 (b' a') had 23 occurrences
merge 7/9: (116, 32) -> 262 (b't ') had 21 occurrences
merge 8/9: (100, 32) -> 263 (b'd ') had 19 occurrences
merge 9/9: (111, 114) -> 264 (b'or') had 16 occurrences
4.3 Checking the Trial Training Results¶
# try encoding the training text with the learned merges
test_ids = encode(text,merges)
len(test_ids)
1369
# compare with the length of the plain UTF-8 encoding
len(text.encode('utf-8'))
1625
# just 9 merges reduced the length to about 84% of the original
len(test_ids)/len(text.encode('utf-8'))
0.8424615384615385
# try decoding
decode(test_ids,vocab)
'Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial "tokens"). Then, successively, the most frequent pair of adjacent characters is merged into a new, 2-character long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set characters.[8] This algorithmic approach has been extended from spoken language to sign language in recent years.[9]\nAll the unique tokens found in a corpus are listed in a token vocabulary, the size of which, in the case of GPT-3.5 and GPT-4, is 100256.\nThe difference between the modified and the original algorithm is that the original algorithm does not merge the most frequent pair of bytes of data, but replaces them with a new byte that was not contained in the initial dataset. A lookup table of the replacements is required to rebuild the initial dataset. The algorithm is effective for tokenization because it has low computational overhead and remains consistent and reliable.'
# compare against the original text to confirm they match
decode(test_ids,vocab) == text
True
# now try a passage that was not in the training text
text2 = "Harry Potter and the Sorcerer's Stone\n\nCHAPTER ONE\n\nTHE BOY WHO LIVED\n\nMr. and Mrs. Dursley, of number four, Privet Drive, were proud to saythat they were perfectly normal, thank you very much. They were the lastpeople you'd expect to be involved in anything strange or mysterious,because they just didn't hold with such nonsense.\n\nMr. Dursley was the director of a firm called Grunnings, which madedrills. He was a big, beefy man with hardly any neck, although he didhave a very large mustache. Mrs. Dursley was thin and blonde and hadnearly twice the usual amount of neck, which came in very useful as shespent so much of her time craning over garden fences, spying on theneighbors. The Dursleys had a small son called Dudley and in theiropinion there was no finer boy anywhere.\n\nThe Dursleys had everything they wanted, but they also had a secret, andtheir greatest fear was that somebody would discover it. They didn'tthink they could bear it if anyone found out about the Potters. Mrs.Potter was"
test_ids2 = encode(text2,merges)
# the length also drops, to about 88%
len(test_ids2)/len(text2.encode('utf-8'))
0.876
# and decoding reproduces the original text
decode(test_ids2,vocab) == text2
True
4.4 Training on a Larger Corpus¶
# load the corpus
text_harry = ""
the_file_path = 'Harry Potter 1-7.txt'
# note: "ansi" is a Windows-only alias for Python's mbcs codec
with open(the_file_path, "r", encoding="ansi") as f:
    text_harry = f.read()
# the file's line breaks are inconsistent: some places use a single newline and others two, so normalize everything to a double newline
placeholder = "##DOUBLE_NEWLINE##"
text_harry = text_harry.replace("\n\n", placeholder)
# turn every remaining single '\n' into '\n\n'
text_harry = text_harry.replace("\n", "\n\n")
# restore the previously protected '\n\n' pairs
text_harry = text_harry.replace(placeholder, "\n\n")
# also remove the full-width spaces in the text
text_harry = text_harry.replace("\u3000", "")
# hyperparameter
vocab_size = 360
# prepare the training text
train_text = text_harry
text_bytes = train_text.encode("utf-8") # raw bytes
# training objects (vocabulary and merge table)
# keep all 256 base byte values, so that any future input, including characters from other scripts such as Chinese, can still be represented
vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
merges = {} # (int, int) -> int
# prepare the id sequence
ids = list(text_bytes) # list of integers in range 0..255
# target number of merges
num_merges = vocab_size - 256
for i in range(num_merges):
    # count current pair frequencies
    stats = get_stats(ids)
    # find the most frequent pair
    pair = max(stats, key=stats.get)
    # assign it a new token id
    idx = 256 + i
    # replace the pair everywhere in the sequence
    ids = merge(ids, pair, idx)
    # record the merge
    merges[pair] = idx
    # record the bytes the new id stands for
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
    # prints
    print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
merge 1/104: (101, 32) -> 256 (b'e ') had 172289 occurrences
merge 2/104: (100, 32) -> 257 (b'd ') had 129879 occurrences
merge 3/104: (116, 104) -> 258 (b'th') had 112753 occurrences
merge 4/104: (116, 32) -> 259 (b't ') had 100260 occurrences
merge 5/104: (105, 110) -> 260 (b'in') had 93785 occurrences
merge 6/104: (115, 32) -> 261 (b's ') had 93641 occurrences
merge 7/104: (101, 114) -> 262 (b'er') had 74124 occurrences
merge 8/104: (44, 32) -> 263 (b', ') had 73924 occurrences
merge 9/104: (97, 110) -> 264 (b'an') had 59080 occurrences
merge 10/104: (121, 32) -> 265 (b'y ') had 57034 occurrences
merge 11/104: (111, 117) -> 266 (b'ou') had 55077 occurrences
merge 12/104: (97, 114) -> 267 (b'ar') had 54143 occurrences
merge 13/104: (46, 32) -> 268 (b'. ') had 53776 occurrences
merge 14/104: (10, 10) -> 269 (b'\n\n') had 51693 occurrences
merge 15/104: (111, 110) -> 270 (b'on') had 50938 occurrences
merge 16/104: (258, 256) -> 271 (b'the ') had 47764 occurrences
merge 17/104: (101, 257) -> 272 (b'ed ') had 46897 occurrences
merge 18/104: (260, 103) -> 273 (b'ing') had 46588 occurrences
merge 19/104: (111, 32) -> 274 (b'o ') had 43601 occurrences
merge 20/104: (111, 114) -> 275 (b'or') had 39805 occurrences
merge 21/104: (273, 32) -> 276 (b'ing ') had 38149 occurrences
merge 22/104: (101, 110) -> 277 (b'en') had 37481 occurrences
merge 23/104: (116, 274) -> 278 (b'to ') had 30433 occurrences
merge 24/104: (32, 115) -> 279 (b' s') had 29826 occurrences
merge 25/104: (108, 108) -> 280 (b'll') had 29670 occurrences
merge 26/104: (104, 105) -> 281 (b'hi') had 28651 occurrences
merge 27/104: (104, 97) -> 282 (b'ha') had 28559 occurrences
merge 28/104: (264, 257) -> 283 (b'and ') had 27650 occurrences
merge 29/104: (32, 32) -> 284 (b'  ') had 27496 occurrences
merge 30/104: (102, 32) -> 285 (b'f ') had 27142 occurrences
merge 31/104: (101, 97) -> 286 (b'ea') had 24001 occurrences
merge 32/104: (119, 97) -> 287 (b'wa') had 23833 occurrences
merge 33/104: (104, 256) -> 288 (b'he ') had 23331 occurrences
merge 34/104: (262, 32) -> 289 (b'er ') had 22135 occurrences
merge 35/104: (115, 116) -> 290 (b'st') had 22011 occurrences
merge 36/104: (111, 119) -> 291 (b'ow') had 20687 occurrences
merge 37/104: (111, 285) -> 292 (b'of ') had 20524 occurrences
merge 38/104: (267, 114) -> 293 (b'arr') had 19524 occurrences
merge 39/104: (97, 105) -> 294 (b'ai') had 18887 occurrences
merge 40/104: (108, 101) -> 295 (b'le') had 18623 occurrences
merge 41/104: (97, 32) -> 296 (b'a ') had 18547 occurrences
merge 42/104: (72, 293) -> 297 (b'Harr') had 18187 occurrences
merge 43/104: (97, 259) -> 298 (b'at ') had 18099 occurrences
merge 44/104: (111, 111) -> 299 (b'oo') had 17295 occurrences
merge 45/104: (114, 101) -> 300 (b're') had 16406 occurrences
merge 46/104: (103, 104) -> 301 (b'gh') had 15845 occurrences
merge 47/104: (114, 105) -> 302 (b'ri') had 15580 occurrences
merge 48/104: (46, 269) -> 303 (b'.\n\n') had 15409 occurrences
merge 49/104: (109, 32) -> 304 (b'm ') had 15266 occurrences
merge 50/104: (99, 104) -> 305 (b'ch') had 15214 occurrences
merge 51/104: (287, 261) -> 306 (b'was ') had 15173 occurrences
merge 52/104: (260, 32) -> 307 (b'in ') had 14738 occurrences
merge 53/104: (270, 32) -> 308 (b'on ') had 14412 occurrences
merge 54/104: (121, 266) -> 309 (b'you') had 14262 occurrences
merge 55/104: (281, 261) -> 310 (b'his ') had 14105 occurrences
merge 56/104: (105, 116) -> 311 (b'it') had 13814 occurrences
merge 57/104: (226, 128) -> 312 (b'\xe2\x80') had 13606 occurrences
merge 58/104: (114, 111) -> 313 (b'ro') had 13413 occurrences
merge 59/104: (107, 32) -> 314 (b'k ') had 13191 occurrences
merge 60/104: (98, 101) -> 315 (b'be') had 13177 occurrences
merge 61/104: (280, 32) -> 316 (b'll ') had 13148 occurrences
merge 62/104: (97, 99) -> 317 (b'ac') had 13019 occurrences
merge 63/104: (294, 257) -> 318 (b'aid ') had 12910 occurrences
merge 64/104: (258, 101) -> 319 (b'the') had 12802 occurrences
merge 65/104: (277, 32) -> 320 (b'en ') had 12798 occurrences
merge 66/104: (39, 261) -> 321 (b"'s ") had 12083 occurrences
merge 67/104: (115, 101) -> 322 (b'se') had 11964 occurrences
merge 68/104: (97, 116) -> 323 (b'at') had 11842 occurrences
merge 69/104: (297, 265) -> 324 (b'Harry ') had 11574 occurrences
merge 70/104: (119, 105) -> 325 (b'wi') had 11573 occurrences
merge 71/104: (108, 105) -> 326 (b'li') had 11313 occurrences
merge 72/104: (279, 318) -> 327 (b' said ') had 11025 occurrences
merge 73/104: (108, 265) -> 328 (b'ly ') had 10906 occurrences
merge 74/104: (275, 32) -> 329 (b'or ') had 10876 occurrences
merge 75/104: (97, 103) -> 330 (b'ag') had 10740 occurrences
merge 76/104: (115, 259) -> 331 (b'st ') had 10047 occurrences
merge 77/104: (118, 256) -> 332 (b've ') had 10013 occurrences
merge 78/104: (282, 257) -> 333 (b'had ') had 9955 occurrences
merge 79/104: (117, 114) -> 334 (b'ur') had 9806 occurrences
merge 80/104: (101, 100) -> 335 (b'ed') had 9791 occurrences
merge 81/104: (111, 109) -> 336 (b'om') had 9655 occurrences
merge 82/104: (109, 105) -> 337 (b'mi') had 9418 occurrences
merge 83/104: (73, 32) -> 338 (b'I ') had 9285 occurrences
merge 84/104: (121, 263) -> 339 (b'y, ') had 9256 occurrences
merge 85/104: (101, 115) -> 340 (b'es') had 9075 occurrences
merge 86/104: (97, 108) -> 341 (b'al') had 9044 occurrences
merge 87/104: (262, 256) -> 342 (b'ere ') had 9015 occurrences
merge 88/104: (117, 110) -> 343 (b'un') had 8980 occurrences
merge 89/104: (258, 32) -> 344 (b'th ') had 8936 occurrences
merge 90/104: (284, 284) -> 345 (b'    ') had 8797 occurrences
merge 91/104: (258, 298) -> 346 (b'that ') had 8765 occurrences
merge 92/104: (101, 263) -> 347 (b'e, ') had 8662 occurrences
merge 93/104: (110, 111) -> 348 (b'no') had 8611 occurrences
merge 94/104: (105, 259) -> 349 (b'it ') had 8568 occurrences
merge 95/104: (108, 257) -> 350 (b'ld ') had 8493 occurrences
merge 96/104: (105, 99) -> 351 (b'ic') had 8461 occurrences
merge 97/104: (104, 101) -> 352 (b'he') had 8361 occurrences
merge 98/104: (108, 256) -> 353 (b'le ') had 8322 occurrences
merge 99/104: (117, 112) -> 354 (b'up') had 8042 occurrences
merge 100/104: (105, 114) -> 355 (b'ir') had 7778 occurrences
merge 101/104: (309, 32) -> 356 (b'you ') had 7693 occurrences
merge 102/104: (119, 104) -> 357 (b'wh') had 7640 occurrences
merge 103/104: (270, 256) -> 358 (b'one ') had 7625 occurrences
merge 104/104: (34, 32) -> 359 (b'" ') had 7618 occurrences
4.5 Checking the Compression¶
test_ids_harry = encode(text_harry,merges)
# check the compression ratio
len(test_ids_harry)/len(text_harry.encode('utf-8'))
0.5762402764853447
Note: with only 104 merges the text is already compressed to about 58% of its original size, which is a pretty good result.
len(vocab)
360
4.6 Saving the Training Results¶
Serialize the merge table and vocabulary to a file with pickle.
import pickle
bpe_data = {'merges':merges,'vocab':vocab}
# save the dict with pickle
with open("bpe_data.pkl", "wb") as file:
pickle.dump(bpe_data, file)
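To reuse the tokenizer later, the file can be loaded back the same way; a minimal sketch using the bpe_data.pkl written above:
# load the saved merge table and vocabulary, then round-trip a sample string
with open("bpe_data.pkl", "rb") as file:
    bpe_data = pickle.load(file)
decode(encode("Harry Potter", bpe_data['merges']), bpe_data['vocab']) == "Harry Potter"
True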