90 lines
3.5 KiB
Python

import re
import hashlib
from typing import Dict, List, Tuple, Union
# Alias for the integer values stored in the ``ranks`` / ``encoder`` tables.
# The same value is compared as a merge priority (lowest first) and returned
# as a token id by the functions below.
Rank = int
def _byte_pair_merge(ranks: Dict[bytes, Rank], piece: bytes) -> List[Tuple[bytes, Rank]]:
parts = []
min_rank = (float('inf'), float('inf'))
for i in range(len(piece) - 1):
rank = ranks.get(piece[i:i + 2], float('inf'))
if rank < min_rank[0]:
min_rank = (rank, i)
parts.append((piece[i:i + 2], rank))
parts.append((piece[len(piece) - 1:], float('inf')))
parts.append((piece[len(piece):], float('inf')))
while min_rank[0] != float('inf'):
i = min_rank[1]
if i > 0:
parts[i - 1] = (parts[i - 1][0], get_rank_with_ranks(piece, parts, i - 1, ranks))
parts[i] = (parts[i][0], get_rank_with_ranks(piece, parts, i, ranks))
del parts[i + 1]
min_rank = (float('inf'), float('inf'))
for j, (_, rank) in enumerate(parts[:-1]):
if rank < min_rank[0]:
min_rank = (rank, j)
return parts
def get_rank_with_ranks(piece: bytes, parts: List[Tuple[bytes, Rank]], i: int, ranks: Dict[bytes, Rank]) -> Rank:
if (i + 3) < len(parts):
key = piece[parts[i][0].start:parts[i + 3][0].start]
return ranks.get(key, float('inf'))
else:
return float('inf')
def byte_pair_encode(piece: bytes, ranks: Dict[bytes, Rank]) -> List[Rank]:
    """Encode ``piece`` as a list of values looked up in ``ranks``.

    The entries returned by ``_byte_pair_merge`` (all but the last one) are
    walked left to right.  Consecutive entries are grouped for as long as the
    concatenation of the group is a key of ``ranks``; each finished group is
    emitted as ``ranks[group]``.

    Args:
        piece: Bytes to encode; must be longer than one byte (asserted).
        ranks: Maps byte strings to their rank.

    Returns:
        One ``ranks`` value per emitted group, in order.

    Raises:
        KeyError: If a finished group is not a key of ``ranks``.
    """
    assert len(piece) > 1
    merged = _byte_pair_merge(ranks, piece)
    segments = [entry[0] for entry in merged[:-1]]

    token_ids = []
    pending = []
    for segment in segments:
        # Flush the current group only when it is non-empty and cannot be
        # extended by this segment; otherwise keep growing it.
        if pending and ranks.get(b''.join(pending + [segment])) is None:
            token_ids.append(ranks[b''.join(pending)])
            pending = [segment]
        else:
            pending.append(segment)
    token_ids.append(ranks[b''.join(pending)])
    return token_ids
def byte_pair_split(piece: bytes, ranks: Dict[bytes, Rank]) -> List[bytes]:
    """Return the byte strings that ``_byte_pair_merge`` produces for ``piece``.

    Args:
        piece: Bytes to split; must be longer than one byte (asserted).
        ranks: Maps byte strings to their rank.

    Returns:
        The first element of every merge entry except the final one.
    """
    assert len(piece) > 1
    merged = _byte_pair_merge(ranks, piece)
    segments = []
    for entry in merged[:-1]:
        segments.append(entry[0])
    return segments
class CoreBPE:
    """Tokenizer that maps text to token ids and token ids back to bytes."""

    def __init__(self, encoder: Dict[bytes, Rank], special_tokens_encoder: Dict[str, Rank], pattern: str):
        """Build the lookup tables and compile the splitting patterns.

        Args:
            encoder: Maps token bytes to token id.
            special_tokens_encoder: Maps special-token text to token id.
            pattern: Regular expression (``re`` syntax); each match in a text
                is one piece that gets encoded on its own.
        """
        self.encoder = encoder
        self.special_tokens_encoder = special_tokens_encoder
        self.decoder = {v: k for k, v in encoder.items()}
        self.special_tokens_decoder = {v: k.encode('utf-8') for k, v in special_tokens_encoder.items()}
        self.regex = re.compile(pattern)
        if special_tokens_encoder:
            self.special_regex = re.compile('|'.join(map(re.escape, special_tokens_encoder.keys())))
        else:
            # An empty alternation compiles to '' and matches at every
            # position; use a pattern that can never match instead.
            self.special_regex = re.compile(r'(?!)')
        # Read by token_byte_values().
        self.sorted_token_bytes = sorted(encoder.keys())

    def encode_ordinary(self, text: str) -> List[Rank]:
        """Encode ``text`` without recognising special tokens.

        Args:
            text: The text to encode.

        Returns:
            Token ids for every match of the splitting pattern, in order.

        Raises:
            KeyError: If an empty or single-byte piece is not in the encoder.
        """
        tokens = []
        # finditer + group(0) always yields the whole match; findall would
        # return the capture groups when the pattern contains any.
        for match in self.regex.finditer(text):
            piece = match.group(0).encode("utf-8")
            token = self.encoder.get(piece)
            if token is not None:
                tokens.append(token)
            elif len(piece) > 1:
                # Not a single vocabulary entry: split it with the merge table.
                tokens.extend(byte_pair_encode(piece, self.encoder))
            else:
                raise KeyError(piece)
        return tokens

    def encode(self, text: str, allowed_special: set) -> List[Rank]:
        """Encode ``text``, mapping allowed special tokens to their own ids.

        Args:
            text: The text to encode.
            allowed_special: Special-token strings that may be emitted as
                special tokens.  Any other special token found in ``text`` is
                encoded like ordinary text.

        Returns:
            The token ids for ``text``.
        """
        tokens = []
        start = 0
        for match in self.special_regex.finditer(text):
            special = match.group(0)
            if special not in allowed_special:
                # Leave it inside the surrounding ordinary segment rather
                # than dropping its text.
                continue
            if match.start() > start:
                tokens.extend(self.encode_ordinary(text[start:match.start()]))
            tokens.append(self.special_tokens_encoder[special])
            start = match.end()
        if start < len(text):
            tokens.extend(self.encode_ordinary(text[start:]))
        return tokens

    def decode_bytes(self, tokens: List[Rank]) -> bytes:
        """Concatenate the bytes of ``tokens``.

        Each id is looked up in the ordinary decoder first and in the
        special-token decoder otherwise.

        Raises:
            TypeError: If an id is in neither table (the lookup yields
                ``None``, which cannot be joined).
        """
        return b''.join(self.decoder.get(token, self.special_tokens_decoder.get(token)) for token in tokens)

    def token_byte_values(self) -> List[bytes]:
        """Return the encoder's keys in sorted order (computed in ``__init__``)."""
        return self.sorted_token_bytes