Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public class TokenTextSplitter extends TextSplitter {

private static final boolean KEEP_SEPARATOR = true;

private static final List<Character> DEFAULT_PUNCTUATION_MARKS = List.of('.', '?', '!', '\n');

private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();

private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
Expand All @@ -64,21 +66,27 @@ public class TokenTextSplitter extends TextSplitter {

private final boolean keepSeparator;

private final List<Character> punctuationMarks;

public TokenTextSplitter() {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR,
DEFAULT_PUNCTUATION_MARKS);
}

public TokenTextSplitter(boolean keepSeparator) {
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator);
this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator,
DEFAULT_PUNCTUATION_MARKS);
}

public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
boolean keepSeparator) {
boolean keepSeparator, List<Character> punctuationMarks) {
this.chunkSize = chunkSize;
this.minChunkSizeChars = minChunkSizeChars;
this.minChunkLengthToEmbed = minChunkLengthToEmbed;
this.maxNumChunks = maxNumChunks;
this.keepSeparator = keepSeparator;
Assert.notEmpty(punctuationMarks, "punctuationMarks must not be empty");
this.punctuationMarks = punctuationMarks;
}

public static Builder builder() {
Expand Down Expand Up @@ -109,8 +117,7 @@ protected List<String> doSplit(String text, int chunkSize) {
}

// Find the last period or punctuation mark in the chunk
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
int lastPunctuation = getLastPunctuationIndex(chunkText);

if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
// Truncate the chunk text at the punctuation mark
Expand Down Expand Up @@ -140,6 +147,16 @@ protected List<String> doSplit(String text, int chunkSize) {
return chunks;
}

protected int getLastPunctuationIndex(String chunkText) {
// find the max index of any punctuation mark
int maxLastPunctuation = -1;
for (Character punctuationMark : this.punctuationMarks) {
int lastPunctuation = chunkText.lastIndexOf(punctuationMark);
maxLastPunctuation = Math.max(maxLastPunctuation, lastPunctuation);
}
return maxLastPunctuation;
}

private List<Integer> getEncodedTokens(String text) {
Assert.notNull(text, "Text must not be null");
return this.encoding.encode(text).boxed();
Expand All @@ -164,6 +181,8 @@ public static final class Builder {

private boolean keepSeparator = KEEP_SEPARATOR;

private List<Character> punctuationMarks = DEFAULT_PUNCTUATION_MARKS;

private Builder() {
}

Expand Down Expand Up @@ -192,9 +211,14 @@ public Builder withKeepSeparator(boolean keepSeparator) {
return this;
}

public Builder withPunctuationMarks(List<Character> punctuationMarks) {
this.punctuationMarks = punctuationMarks;
return this;
}

public TokenTextSplitter build() {
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
this.maxNumChunks, this.keepSeparator);
this.maxNumChunks, this.keepSeparator, this.punctuationMarks);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,44 @@ public void testTokenTextSplitterBuilderWithAllFields() {
assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1");
}

@Test
public void testTokenTextSplitterWithCustomPunctuationMarks() {
var contentFormatter1 = DefaultContentFormatter.defaultConfig();
var contentFormatter2 = DefaultContentFormatter.defaultConfig();

assertThat(contentFormatter1).isNotSameAs(contentFormatter2);

var doc1 = new Document("Here, we set custom punctuation marks。?!. We just want to test it works or not?");
doc1.setContentFormatter(contentFormatter1);

var doc2 = new Document("And more, we add protected method getLastPunctuationIndex in TokenTextSplitter class!"
+ "The subclasses can override this method to achieve their own business logic。We just want to test it works or not?");
doc2.setContentFormatter(contentFormatter2);

var tokenTextSplitter = TokenTextSplitter.builder()
.withChunkSize(10)
.withMinChunkSizeChars(5)
.withMinChunkLengthToEmbed(3)
.withMaxNumChunks(50)
.withKeepSeparator(true)
.withPunctuationMarks(List.of('。', '?', '!'))
.build();

var chunks = tokenTextSplitter.apply(List.of(doc1, doc2));

assertThat(chunks.size()).isEqualTo(7);

// Doc 1
assertThat(chunks.get(0).getText()).isEqualTo("Here, we set custom punctuation marks。?!");
assertThat(chunks.get(1).getText()).isEqualTo(". We just want to test it works or not");

// Doc 2
assertThat(chunks.get(2).getText()).isEqualTo("And more, we add protected method getLastPunctuation");
assertThat(chunks.get(3).getText()).isEqualTo("Index in TokenTextSplitter class!");
assertThat(chunks.get(4).getText()).isEqualTo("The subclasses can override this method to achieve their own");
assertThat(chunks.get(5).getText()).isEqualTo("business logic。");
assertThat(chunks.get(6).getText()).isEqualTo("We just want to test it works or not?");

}

}