Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions src/Microsoft.ML.Tokenizers/Utils/BytePairEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@ public static (int Id, int TokenIndex, int TokenLength)[] BytePairEncode(ReadOnl
return [(ranks[mergingBytes], 0, 1)];
}

// For large inputs, use heap-based algorithm to avoid O(n²) behavior.
// Threshold of 128 chosen empirically: linear scan is cache-friendly for small inputs,
// while heap overhead (O(log n) per operation) becomes worthwhile for larger inputs.
// Based on upstream tiktoken using 100, adjusted upward for C#'s efficient span operations.
if (mergingBytes.Length > 128)
Comment thread
stephentoub marked this conversation as resolved.
{
return BytePairEncodeLarge(mergingBytes, ranks, indexMappingSpan);
}

(int Index, int Rank)[]? arrayPoolArray = null;
int requiredLength = mergingBytes.Length + 1;
Span<(int Index, int Rank)> byteIndicesAndRanks = requiredLength <= 64 ?
Expand Down Expand Up @@ -116,6 +125,166 @@ int GetRank(Span<(int Index, int Rank)> byteIndicesAndRanks, int startIndex, int
return result;
}

/// <summary>
/// Per-byte-position node in the linked-list representation of the evolving token
/// segmentation used by the heap-based BPE algorithm for large inputs.
/// </summary>
private struct State
{
    // Start index of the preceding token segment; int.MaxValue is the sentinel
    // for the first segment (which has no predecessor).
    public int Prev;
    // Exclusive end (byte index) of the token that starts at this position.
    public int End;
    // Exclusive end of the candidate merge of this token with its right neighbor.
    public int NextEnd;
    // Rank of that candidate merge; int.MaxValue when no such merge exists in the vocabulary.
    public int NextRank;
    // Rank (token id) produced by the merge that formed this token;
    // int.MaxValue when the token was never merged (still a single byte).
    public int CurRank;
}

/// <summary>
/// Priority-queue entry describing a candidate BPE merge: the pair beginning at
/// <see cref="Start"/> whose merged byte sequence has the given <see cref="Rank"/>.
/// Lower ranks merge first; ties break on the earlier start position so the
/// algorithm is deterministic.
/// </summary>
private struct MergeEntry : IComparable<MergeEntry>
{
    public int Rank;
    public int Start;

    /// <summary>Orders entries by <see cref="Rank"/>, then by <see cref="Start"/>.</summary>
    public int CompareTo(MergeEntry other) =>
        Rank != other.Rank
            ? Rank.CompareTo(other.Rank)
            : Start.CompareTo(other.Start);
}
Comment thread
stephentoub marked this conversation as resolved.

/// <summary>
/// Heap-based BPE for large inputs. The current segmentation is kept as a linked list of
/// <see cref="State"/> entries (one per byte position) and candidate merges are kept in a
/// priority queue ordered by rank, so the lowest-ranked merge is always applied next.
/// This avoids the O(n^2) rescanning behavior of the linear-scan path used for small inputs.
/// </summary>
/// <param name="mergingBytes">The byte sequence being tokenized.</param>
/// <param name="ranks">Maps byte sequences to their BPE rank / token id.</param>
/// <param name="indexMappingSpan">Maps byte offsets to offsets in the original text.
/// NOTE(review): assumed to contain at least mergingBytes.Length + 1 entries (the small-input
/// path allocates Length + 1) so the reads at and past a token's end stay in range — confirm against caller.</param>
/// <returns>The encoded tokens as (Id, TokenIndex, TokenLength) tuples in input order.</returns>
private static (int Id, int TokenIndex, int TokenLength)[] BytePairEncodeLarge(ReadOnlyMemory<byte> mergingBytes, IReadOnlyDictionary<ReadOnlyMemory<byte>, int> ranks, ReadOnlySpan<int> indexMappingSpan)
{
    // One State per input byte: stackalloc for small inputs, pooled array otherwise.
    State[]? statePoolArray = null;
    int stateLength = mergingBytes.Length;
    Span<State> state = stateLength <= 256 ?
        stackalloc State[256] :
        (statePoolArray = ArrayPool<State>.Shared.Rent(stateLength));
    state = state.Slice(0, stateLength);

    // First position: no predecessor (int.MaxValue sentinel); its token is the single byte [0, 1).
    state[0] = new State
    {
        Prev = int.MaxValue,
        End = 1,
        NextEnd = 2,
        NextRank = int.MaxValue,
        CurRank = int.MaxValue
    };

    var heap = new PriorityQueue<MergeEntry>(0);

    // Seed the heap with every adjacent byte pair present in the vocabulary, and
    // initialize the linked-list node for each position after the first.
    for (int i = 0; i < mergingBytes.Length - 1; i++)
    {
        var slice = mergingBytes.Slice(i, 2);
        if (ranks.TryGetValue(slice, out int rank))
        {
            heap.Enqueue(new MergeEntry { Start = i, Rank = rank });
            state[i].NextRank = rank;
        }

        state[i + 1] = new State
        {
            Prev = i,
            End = i + 2,
            NextEnd = i + 3,
            NextRank = int.MaxValue,
            CurRank = int.MaxValue
        };
    }

    // Local function to add a potential merge to the heap.
    // Records that the token at `start` could next extend to `nextEndItem` and enqueues the
    // candidate when that byte sequence exists in the vocabulary. Superseded heap entries are
    // not removed; they are detected as stale on dequeue by comparing against NextRank.
    void PotentialMerge(Span<State> stateSpan, PriorityQueue<MergeEntry> heapQueue, int start, int nextEndItem)
    {
        stateSpan[start].NextEnd = nextEndItem;
        stateSpan[start].NextRank = int.MaxValue;

        if (nextEndItem <= mergingBytes.Length)
        {
            var slice = mergingBytes.Slice(start, nextEndItem - start);
            if (ranks.TryGetValue(slice, out int rank))
            {
                heapQueue.Enqueue(new MergeEntry { Start = start, Rank = rank });
                stateSpan[start].NextRank = rank;
            }
        }
    }

    // Repeatedly apply the lowest-ranked pending merge until none remain.
    while (heap.Count > 0)
    {
        MergeEntry left = heap.Dequeue();

        if (left.Rank == int.MaxValue)
        {
            break;
        }

        // Stale entry: the segment at Start changed since this candidate was enqueued,
        // so its recorded rank no longer matches the live NextRank.
        if (left.Rank != state[left.Start].NextRank)
        {
            continue;
        }

        int leftStart = left.Start;
        int rightStart = state[leftStart].End;        // start of the right-hand token being absorbed
        int rightEnd = state[leftStart].NextEnd;      // end of the newly merged token
        int rightNextEnd = state[rightStart].NextEnd; // end the absorbed token could itself have merged to

        // Absorb the right token into the left one, then propose the follow-up merge
        // (merged token + its new right neighbor, ending at rightNextEnd).
        state[leftStart].CurRank = state[leftStart].NextRank;
        state[leftStart].End = rightEnd;
        PotentialMerge(state, heap, leftStart, rightNextEnd);

        // Re-link the successor's back-pointer past the absorbed token.
        if (rightEnd < state.Length)
        {
            state[rightEnd].Prev = leftStart;
        }

        // The predecessor's candidate merge now spans up to the merged token's end.
        if (leftStart > 0)
        {
            int prevStart = state[leftStart].Prev;
            PotentialMerge(state, heap, prevStart, rightEnd);
        }

        // Invalidate any heap entries still pending for the absorbed right token.
        state[rightStart].NextRank = int.MaxValue;
    }

    // Walk the surviving segments in order, translating byte offsets to original-text
    // offsets via indexMappingSpan and emitting one result tuple per token.
    var resultList = new List<(int Id, int TokenIndex, int TokenLength)>();
    int currentIndex = 0;

    while (currentIndex < state.Length)
    {
        int startIndex = currentIndex;
        int endIndex = state[currentIndex].End;

        int mappedStartIndex = indexMappingSpan[startIndex];
        int mappedEndIndex = indexMappingSpan[endIndex];

        int finalEndIndex = endIndex;

        // Handle partial characters/elements at token boundaries.
        // If the byte at endIndex-1 maps to the same character as endIndex,
        // extend the token to include the complete character.
        if (finalEndIndex > 0 && indexMappingSpan[finalEndIndex - 1] == mappedEndIndex)
        {
            finalEndIndex++;
            while (finalEndIndex < indexMappingSpan.Length && indexMappingSpan[finalEndIndex] == mappedEndIndex)
            {
                finalEndIndex++;
            }
        }

        // CurRank caches the id produced by the final merge; segments that were never
        // merged (single bytes) fall back to a vocabulary lookup.
        int tokenId = state[currentIndex].CurRank != int.MaxValue
            ? state[currentIndex].CurRank
            : ranks[mergingBytes.SliceStartEnd(startIndex, endIndex)];

        resultList.Add((tokenId, mappedStartIndex, indexMappingSpan[finalEndIndex] - mappedStartIndex));

        currentIndex = state[currentIndex].End;
    }

    if (statePoolArray is not null)
    {
        ArrayPool<State>.Shared.Return(statePoolArray);
    }

    return resultList.ToArray();
}

/// <summary>Returns the slice of <paramref name="memory"/> covering the half-open range [<paramref name="start"/>, <paramref name="end"/>).</summary>
private static ReadOnlyMemory<byte> SliceStartEnd(this ReadOnlyMemory<byte> memory, int start, int end)
{
    int length = end - start;
    return memory.Slice(start, length);
}
}
}
72 changes: 71 additions & 1 deletion test/Microsoft.ML.Tokenizers.Tests/TiktokenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.DotNet.RemoteExecutor;
using System;
using System.Buffers;
using System.Collections.Generic;
Expand All @@ -13,6 +12,7 @@
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using Microsoft.DotNet.RemoteExecutor;
using Xunit;

namespace Microsoft.ML.Tokenizers.Tests
Expand Down Expand Up @@ -848,6 +848,76 @@ public void TestOss()

private static IReadOnlyDictionary<string, int>? GetVocabulary(TiktokenTokenizer tiktoken)
=> typeof(TiktokenTokenizer).GetProperty("Vocabulary", BindingFlags.Instance | BindingFlags.NonPublic)?.GetValue(tiktoken) as IReadOnlyDictionary<string, int>;

[Fact]
public void TestLargeInputOptimization()
{
    // Round-trip checks through the public API for inputs that exercise the
    // large-input (heap-based) BPE path and the sizes straddling its threshold.
    static void AssertRoundTrips(string input)
    {
        IReadOnlyList<int> encoded = GPT4.EncodeToIds(input);
        Assert.Equal(input, GPT4.Decode(encoded));
    }

    // Adversarial repeated-character input — the case that previously caused O(n^2) behavior.
    AssertRoundTrips(new string('a', 1000));

    // A more realistic large mixed-content input.
    AssertRoundTrips(string.Join(" ", Enumerable.Repeat("Hello World! This is a test.", 50)));

    // Boundary sizes around the 128-byte threshold between the linear-scan
    // and heap-based algorithms: just below, exactly at, and just above.
    AssertRoundTrips(new string('x', 127));
    AssertRoundTrips(new string('x', 128));
    AssertRoundTrips(new string('x', 129));
}

[Theory]
[InlineData(200)]
[InlineData(500)]
[InlineData(1000)]
[InlineData(2000)]
public void TestLargeInputConsistency(int length)
{
    // Exercise the large-input path through the public API: every encoding
    // must round-trip back to the original text, for both a repeated single
    // character (worst case for pairwise merging) and realistic mixed content.

    string repeated = new string('z', length);
    IReadOnlyList<int> repeatedIds = GPT4.EncodeToIds(repeated);
    Assert.Equal(repeated, GPT4.Decode(repeatedIds));

    // Mixed content trimmed to exactly `length` characters.
    string mixed = string.Join(" ", Enumerable.Repeat("Hello World! Test123", length / 20 + 1)).Substring(0, length);
    IReadOnlyList<int> mixedIds = GPT4.EncodeToIds(mixed);
    Assert.Equal(mixed, GPT4.Decode(mixedIds));

    // EncodeToTokens should apply no normalization, and its tokens must
    // concatenate back to the original input.
    IReadOnlyList<EncodedToken> encodedTokens = GPT4.EncodeToTokens(repeated, out string? normalizedText);
    Assert.Null(normalizedText);
    Assert.Equal(repeated, string.Concat(encodedTokens.Select(t => t.Value)));
}
}
}

Loading