From 632df5f9f7c60d760301a477f2cd28b8076ac758 Mon Sep 17 00:00:00 2001 From: oneby-wang Date: Sun, 23 Nov 2025 21:45:44 +0800 Subject: [PATCH] feat: Support custom punctuation marks in TokenTextSplitter Signed-off-by: oneby-wang --- .../splitter/TokenTextSplitter.java | 36 ++++++++++++++--- .../splitter/TokenTextSplitterTest.java | 40 +++++++++++++++++++ 2 files changed, 70 insertions(+), 6 deletions(-) diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java index a202aac426c..764323600df 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java @@ -46,6 +46,8 @@ public class TokenTextSplitter extends TextSplitter { private static final boolean KEEP_SEPARATOR = true; + private static final List DEFAULT_PUNCTUATION_MARKS = List.of('.', '?', '!', '\n'); + private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry(); private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE); @@ -64,21 +66,27 @@ public class TokenTextSplitter extends TextSplitter { private final boolean keepSeparator; + private final List punctuationMarks; + public TokenTextSplitter() { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR, + DEFAULT_PUNCTUATION_MARKS); } public TokenTextSplitter(boolean keepSeparator) { - this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator); + this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator, + DEFAULT_PUNCTUATION_MARKS); } public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks, - boolean keepSeparator) { + boolean keepSeparator, List punctuationMarks) { this.chunkSize = chunkSize; this.minChunkSizeChars = minChunkSizeChars; this.minChunkLengthToEmbed = minChunkLengthToEmbed; this.maxNumChunks = maxNumChunks; this.keepSeparator = keepSeparator; + Assert.notEmpty(punctuationMarks, "punctuationMarks must not be empty"); + this.punctuationMarks = punctuationMarks; } public static Builder builder() { @@ -109,8 +117,7 @@ protected List doSplit(String text, int chunkSize) { } // Find the last period or punctuation mark in the chunk - int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'), - Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n')))); + int lastPunctuation = getLastPunctuationIndex(chunkText); if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) { // Truncate the chunk text at the punctuation mark @@ -140,6 +147,16 @@ protected List doSplit(String text, int chunkSize) { return chunks; } + protected int getLastPunctuationIndex(String chunkText) { + // find the max index of any punctuation mark + int maxLastPunctuation = -1; + for (Character punctuationMark : this.punctuationMarks) { + int lastPunctuation = chunkText.lastIndexOf(punctuationMark); + maxLastPunctuation = Math.max(maxLastPunctuation, lastPunctuation); + } + return maxLastPunctuation; + } + private List getEncodedTokens(String text) { Assert.notNull(text, "Text must not be null"); return this.encoding.encode(text).boxed(); @@ -164,6 +181,8 @@ public static final class Builder { private boolean keepSeparator = KEEP_SEPARATOR; + private List punctuationMarks = DEFAULT_PUNCTUATION_MARKS; + private Builder() { } @@ -192,9 +211,14 @@ public Builder withKeepSeparator(boolean keepSeparator) { return this; } + public Builder withPunctuationMarks(List punctuationMarks) { + this.punctuationMarks = punctuationMarks; + return this; + } + public TokenTextSplitter build() { return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed, - this.maxNumChunks, this.keepSeparator); + this.maxNumChunks, this.keepSeparator, this.punctuationMarks); } } diff --git a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java index 96c58f3fa9a..05f4f01b3b2 100644 --- a/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java +++ b/spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java @@ -125,4 +125,44 @@ public void testTokenTextSplitterBuilderWithAllFields() { assertThat(chunks.get(2).getMetadata()).containsKeys("key2", "key3").doesNotContainKeys("key1"); } + @Test + public void testTokenTextSplitterWithCustomPunctuationMarks() { + var contentFormatter1 = DefaultContentFormatter.defaultConfig(); + var contentFormatter2 = DefaultContentFormatter.defaultConfig(); + + assertThat(contentFormatter1).isNotSameAs(contentFormatter2); + + var doc1 = new Document("Here, we set custom punctuation marks。?!. We just want to test it works or not?"); + doc1.setContentFormatter(contentFormatter1); + + var doc2 = new Document("And more, we add protected method getLastPunctuationIndex in TokenTextSplitter class!" + + "The subclasses can override this method to achieve their own business logic。We just want to test it works or not?"); + doc2.setContentFormatter(contentFormatter2); + + var tokenTextSplitter = TokenTextSplitter.builder() + .withChunkSize(10) + .withMinChunkSizeChars(5) + .withMinChunkLengthToEmbed(3) + .withMaxNumChunks(50) + .withKeepSeparator(true) + .withPunctuationMarks(List.of('。', '?', '!')) + .build(); + + var chunks = tokenTextSplitter.apply(List.of(doc1, doc2)); + + assertThat(chunks.size()).isEqualTo(7); + + // Doc 1 + assertThat(chunks.get(0).getText()).isEqualTo("Here, we set custom punctuation marks。?!"); + assertThat(chunks.get(1).getText()).isEqualTo(". We just want to test it works or not"); + + // Doc 2 + assertThat(chunks.get(2).getText()).isEqualTo("And more, we add protected method getLastPunctuation"); + assertThat(chunks.get(3).getText()).isEqualTo("Index in TokenTextSplitter class!"); + assertThat(chunks.get(4).getText()).isEqualTo("The subclasses can override this method to achieve their own"); + assertThat(chunks.get(5).getText()).isEqualTo("business logic。"); + assertThat(chunks.get(6).getText()).isEqualTo("We just want to test it works or not?"); + + } + }