Skip to content

Commit 9d9b954

Browse files
committed
Improve TokenTextSplitter punctuation handling
Make the punctuation marks configurable and support Chinese punctuation. Signed-off-by: Harold Li <power0721@gmail.com>
1 parent 556391b commit 9d9b954

File tree

2 files changed

+48
-6
lines changed

2 files changed

+48
-6
lines changed

spring-ai-commons/src/main/java/org/springframework/ai/transformer/splitter/TokenTextSplitter.java

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.springframework.ai.transformer.splitter;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
21+
import java.util.Collections;
2022
import java.util.List;
2123

2224
import com.knuddels.jtokkit.Encodings;
@@ -46,6 +48,8 @@ public class TokenTextSplitter extends TextSplitter {
4648

4749
private static final boolean KEEP_SEPARATOR = true;
4850

51+
// Sentence-ending marks used as preferred split points when none are configured:
// ASCII '.', '?', '!', their CJK counterparts (ideographic full stop U+3002,
// fullwidth question mark U+FF1F, fullwidth exclamation mark U+FF01) and the line
// break. The CJK marks are written as Unicode escapes so that they cannot be
// silently folded into their ASCII look-alikes by an encoding or normalization step.
private static final List<Character> DEFAULT_PUNCTUATIONS = List.of('.', '?', '!', '\u3002', '\uFF1F', '\uFF01', '\n');
52+
4953
private final EncodingRegistry registry = Encodings.newLazyEncodingRegistry();
5054

5155
private final Encoding encoding = this.registry.getEncoding(EncodingType.CL100K_BASE);
@@ -64,21 +68,24 @@ public class TokenTextSplitter extends TextSplitter {
6468

6569
private final boolean keepSeparator;
6670

71+
// Characters whose last occurrence in a chunk is used as the preferred cut point
// when the text has more tokens than the chunk size.
private final List<Character> punctuations;
72+
6773
/**
 * Creates a splitter configured entirely from the class-level defaults, including
 * the default punctuation marks.
 */
public TokenTextSplitter() {
	this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, KEEP_SEPARATOR,
			DEFAULT_PUNCTUATIONS);
}
7076

7177
/**
 * Creates a splitter that uses the class-level defaults for everything except the
 * separator handling.
 * @param keepSeparator value stored in the splitter's {@code keepSeparator} setting
 */
public TokenTextSplitter(boolean keepSeparator) {
	this(DEFAULT_CHUNK_SIZE, MIN_CHUNK_SIZE_CHARS, MIN_CHUNK_LENGTH_TO_EMBED, MAX_NUM_CHUNKS, keepSeparator,
			DEFAULT_PUNCTUATIONS);
}
7480

7581
/**
 * Creates a splitter that uses the default punctuation marks as cut points.
 * <p>
 * This five-argument form existed before punctuation marks became configurable; it
 * is retained so that existing callers keep compiling and linking.
 * @param chunkSize the chunk size setting
 * @param minChunkSizeChars character index a punctuation mark must exceed before a
 * chunk is truncated at it
 * @param minChunkLengthToEmbed the minimum-chunk-length-to-embed setting
 * @param maxNumChunks the maximum-number-of-chunks setting
 * @param keepSeparator the keep-separator setting
 */
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
		boolean keepSeparator) {
	this(chunkSize, minChunkSizeChars, minChunkLengthToEmbed, maxNumChunks, keepSeparator, DEFAULT_PUNCTUATIONS);
}

/**
 * Creates a fully configured splitter.
 * @param chunkSize the chunk size setting
 * @param minChunkSizeChars character index a punctuation mark must exceed before a
 * chunk is truncated at it
 * @param minChunkLengthToEmbed the minimum-chunk-length-to-embed setting
 * @param maxNumChunks the maximum-number-of-chunks setting
 * @param keepSeparator the keep-separator setting
 * @param punctuations characters treated as preferred cut points; must not be
 * {@code null} and must not contain {@code null}. An empty list disables
 * punctuation-based truncation.
 * @throws NullPointerException if {@code punctuations} is {@code null} or contains a
 * {@code null} element
 */
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
		boolean keepSeparator, List<Character> punctuations) {
	this.chunkSize = chunkSize;
	this.minChunkSizeChars = minChunkSizeChars;
	this.minChunkLengthToEmbed = minChunkLengthToEmbed;
	this.maxNumChunks = maxNumChunks;
	this.keepSeparator = keepSeparator;
	// Snapshot the list: List.copyOf rejects a null list and null elements right
	// here (instead of failing later with an unboxing NPE while splitting) and
	// shields the splitter from later modification of the caller's list.
	this.punctuations = List.copyOf(punctuations);
}
8390

8491
public static Builder builder() {
@@ -124,8 +131,10 @@ protected List<String> doSplit(String text, int chunkSize) {
124131
// This prevents unnecessary splitting of small texts
125132
if (tokens.size() > chunkSize) {
126133
// Find the last period or punctuation mark in the chunk
127-
int lastPunctuation = Math.max(chunkText.lastIndexOf('.'), Math.max(chunkText.lastIndexOf('?'),
128-
Math.max(chunkText.lastIndexOf('!'), chunkText.lastIndexOf('\n'))));
134+
int lastPunctuation = -1;
135+
for (char punctuation : punctuations) {
136+
lastPunctuation = Math.max(lastPunctuation, chunkText.lastIndexOf(punctuation));
137+
}
129138

130139
if (lastPunctuation != -1 && lastPunctuation > this.minChunkSizeChars) {
131140
// Truncate the chunk text at the punctuation mark
@@ -180,6 +189,8 @@ public static final class Builder {
180189

181190
private boolean keepSeparator = KEEP_SEPARATOR;
182191

192+
// Punctuation marks passed to the splitter by build(); stays at the default list
// unless withPunctuations(...) replaces it.
private List<Character> punctuations = DEFAULT_PUNCTUATIONS;
193+
183194
private Builder() {
184195
}
185196

@@ -208,9 +219,18 @@ public Builder withKeepSeparator(boolean keepSeparator) {
208219
return this;
209220
}
210221

222+
public Builder withPunctuations(char... punctuations) {
223+
List<Character> list = new ArrayList<>();
224+
for (char punctuation : punctuations) {
225+
list.add(punctuation);
226+
}
227+
this.punctuations = Collections.unmodifiableList(list);
228+
return this;
229+
}
230+
211231
public TokenTextSplitter build() {
212232
return new TokenTextSplitter(this.chunkSize, this.minChunkSizeChars, this.minChunkLengthToEmbed,
213-
this.maxNumChunks, this.keepSeparator);
233+
this.maxNumChunks, this.keepSeparator, this.punctuations);
214234
}
215235

216236
}

spring-ai-commons/src/test/java/org/springframework/ai/transformer/splitter/TokenTextSplitterTest.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,26 @@ public void testLargeTextStillSplitsAtPunctuation() {
165165
assertThat(splitted.get(0).getText()).endsWith(".");
166166
}
167167

168+
@Test
169+
public void testLargeTextStillSplitsAtChinesePunctuation() {
170+
// Verify that punctuation-based splitting still works when text exceeds chunk
171+
// size
172+
TokenTextSplitter splitter = TokenTextSplitter.builder()
173+
.withKeepSeparator(true)
174+
.withChunkSize(15)
175+
.withMinChunkSizeChars(10)
176+
.build();
177+
178+
// This text has multiple sentences and will exceed 15 tokens
179+
Document testDoc = new Document(
180+
"This is the first sentence with enough words。 This is the second sentence。 And this is the third sentence。");
181+
List<Document> splitted = splitter.split(testDoc);
182+
183+
// Should split into multiple chunks at punctuation marks
184+
assertThat(splitted.size()).isGreaterThan(1);
185+
186+
// Verify first chunk ends with punctuation
187+
assertThat(splitted.get(0).getText()).endsWith("。");
188+
}
189+
168190
}

0 commit comments

Comments
 (0)