-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path byte_pair_encoding.cpp
More file actions
79 lines (66 loc) · 2.3 KB
/
byte_pair_encoding.cpp
File metadata and controls
79 lines (66 loc) · 2.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>
namespace {

// Token ids 0..255 are the raw bytes; merged tokens get ids from here upward.
constexpr int kByteVocabSize = 256;

// Packs an ordered pair of token ids into one 64-bit hash key:
// left id in the high 32 bits, right id in the low 32 bits.
uint64_t pack_pair(int left, int right) {
    return (static_cast<uint64_t>(static_cast<uint32_t>(left)) << 32) |
           static_cast<uint64_t>(static_cast<uint32_t>(right));
}

// Parses `arg` as a base-10 int that must occupy the whole string.
// Returns true and stores the value in `out` on success; returns false and
// leaves `out` untouched on empty, non-numeric, trailing-junk, or
// out-of-range input.
bool parse_int(const char* arg, int& out) {
    const std::string text(arg);
    try {
        std::size_t consumed = 0;
        const int value = std::stoi(text, &consumed);
        if (consumed != text.size()) {
            return false;  // reject inputs like "12abc" that stoi would truncate
        }
        out = value;
        return true;
    } catch (const std::invalid_argument&) {
        return false;
    } catch (const std::out_of_range&) {
        return false;
    }
}

// Counts every adjacent token pair in `tokens` (overlapping occurrences
// included) and returns the highest count, storing that pair's packed key in
// `best_key`. Returns 0 when `tokens` has fewer than two elements.
// `counts` is a scratch map owned by the caller; it is cleared on entry.
// Ties go to the smaller packed key so the choice does not depend on
// unordered_map iteration order.
std::size_t most_frequent_pair(const std::vector<int>& tokens,
                               std::unordered_map<uint64_t, std::size_t>& counts,
                               uint64_t& best_key) {
    counts.clear();
    // `i + 1 < size` (not `i < size - 1`) so an empty vector cannot underflow.
    for (std::size_t i = 0; i + 1 < tokens.size(); ++i) {
        ++counts[pack_pair(tokens[i], tokens[i + 1])];
    }
    std::size_t best_count = 0;
    best_key = 0;
    for (const auto& entry : counts) {
        if (entry.second > best_count ||
            (entry.second == best_count && entry.first < best_key)) {
            best_count = entry.second;
            best_key = entry.first;
        }
    }
    return best_count;
}

// Returns a copy of `tokens` where every non-overlapping left-to-right
// occurrence of (left, right) is replaced by the single token `new_id`.
std::vector<int> merge_pair(const std::vector<int>& tokens, int left, int right,
                            int new_id) {
    std::vector<int> merged;
    merged.reserve(tokens.size());
    for (std::size_t i = 0; i < tokens.size(); ++i) {
        if (i + 1 < tokens.size() && tokens[i] == left && tokens[i + 1] == right) {
            merged.push_back(new_id);
            ++i;  // the right half of the pair is consumed by the merge
        } else {
            merged.push_back(tokens[i]);
        }
    }
    return merged;
}

}  // namespace

// Byte-pair-encoding merge driver.
// Usage: <program> <filename> <vocab_size>
// Reads the whole file as raw bytes (one initial token per byte), then
// repeatedly replaces the most frequent adjacent token pair with a new token
// id until the vocabulary reaches <vocab_size> or no pair occurs at least
// twice. Progress and final token count go to stdout.
// Returns 1 on bad arguments or if the file cannot be opened, 0 otherwise.
int main(int argc, char* argv[]) {
    if (argc < 3) {
        std::cerr << "Usage: " << argv[0] << " <filename> <vocab_size>" << std::endl;
        return 1;
    }
    const std::string filename = argv[1];
    int target_vocab_size = 0;
    if (!parse_int(argv[2], target_vocab_size)) {
        std::cerr << "Invalid vocab_size: " << argv[2] << std::endl;
        return 1;
    }

    // Binary mode: BPE operates on raw bytes, so no newline translation.
    std::ifstream t(filename, std::ios::binary);
    if (!t.is_open()) {
        std::cerr << "Error opening file: " << filename << std::endl;
        return 1;
    }
    std::stringstream buffer;
    buffer << t.rdbuf();
    const std::string text = buffer.str();

    // One token per byte; going through unsigned char keeps ids in 0..255.
    std::vector<int> tokens;
    tokens.reserve(text.size());
    for (char c : text) {
        tokens.push_back(static_cast<unsigned char>(c));
    }

    int current_vocab = kByteVocabSize;
    std::cout << "Starting token count: " << tokens.size() << std::endl;
    std::cout << "Starting encoding in C++..." << std::endl;
    std::cout << "|----------------------|" << std::endl;

    std::unordered_map<uint64_t, std::size_t> pair_counts;  // byte pair -> count
    // Pre-test loop: a target at or below the byte vocabulary performs no merges.
    while (current_vocab < target_vocab_size) {
        uint64_t best_key = 0;
        const std::size_t max_freq = most_frequent_pair(tokens, pair_counts, best_key);
        if (max_freq < 2) {
            break;  // no pair repeats; stop early
        }

        const int new_token_id = current_vocab++;
        const int token_a = static_cast<int>(best_key >> 32);
        const int token_b = static_cast<int>(best_key & 0xFFFFFFFFu);
        // current_vocab has already been incremented: this is the size after the merge.
        std::cout << "vocab: " << std::left << std::setw(5) << current_vocab
                  << " / " << target_vocab_size
                  << " | merging: (" << token_a << ", " << token_b << ") -> "
                  << new_token_id << " | count: " << max_freq << "\n";

        tokens = merge_pair(tokens, token_a, token_b, new_token_id);  // move-assigned
    }

    std::cout << "Encoding Complete." << std::endl;
    std::cout << "Final token count: " << tokens.size() << std::endl;
    return 0;
}