1717package org .springframework .ai .transformer .splitter ;
1818
1919import java .util .ArrayList ;
20+ import java .util .Arrays ;
21+ import java .util .Collections ;
2022import java .util .List ;
2123
2224import com .knuddels .jtokkit .Encodings ;
@@ -46,6 +48,8 @@ public class TokenTextSplitter extends TextSplitter {
4648
4749 private static final boolean KEEP_SEPARATOR = true ;
4850
51+ private static final List <Character > DEFAULT_PUNCTUATIONS = List .of ('.' , '?' , '!' , '。' , '?' , '!' , '\n' );
52+
4953 private final EncodingRegistry registry = Encodings .newLazyEncodingRegistry ();
5054
5155 private final Encoding encoding = this .registry .getEncoding (EncodingType .CL100K_BASE );
@@ -64,21 +68,24 @@ public class TokenTextSplitter extends TextSplitter {
6468
6569 private final boolean keepSeparator ;
6670
71+ private final List <Character > punctuations ;
72+
/**
 * Creates a splitter in which every setting, including the punctuation marks,
 * takes the class-level default.
 */
public TokenTextSplitter() {
    this(DEFAULT_CHUNK_SIZE,
            MIN_CHUNK_SIZE_CHARS,
            MIN_CHUNK_LENGTH_TO_EMBED,
            MAX_NUM_CHUNKS,
            KEEP_SEPARATOR,
            DEFAULT_PUNCTUATIONS);
}
7076
/**
 * Creates a splitter that takes the class-level default for every setting
 * except the keep-separator flag.
 * @param keepSeparator value to use for the keep-separator flag
 */
public TokenTextSplitter(boolean keepSeparator) {
    this(DEFAULT_CHUNK_SIZE,
            MIN_CHUNK_SIZE_CHARS,
            MIN_CHUNK_LENGTH_TO_EMBED,
            MAX_NUM_CHUNKS,
            keepSeparator,
            DEFAULT_PUNCTUATIONS);
}
7480
/**
 * Creates a splitter with explicit sizing options and the default punctuation
 * marks.
 * <p>
 * This is the five-argument signature that was public before punctuation marks
 * became configurable; it is retained so existing callers keep compiling.
 * @param chunkSize target chunk size
 * @param minChunkSizeChars character index a punctuation mark must exceed before
 * a chunk is truncated at it
 * @param minChunkLengthToEmbed minimum chunk length setting
 * @param maxNumChunks maximum number of chunks setting
 * @param keepSeparator value to use for the keep-separator flag
 */
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
        boolean keepSeparator) {
    this(chunkSize, minChunkSizeChars, minChunkLengthToEmbed, maxNumChunks, keepSeparator, DEFAULT_PUNCTUATIONS);
}

/**
 * Creates a splitter with explicit sizing options and punctuation marks.
 * @param chunkSize target chunk size
 * @param minChunkSizeChars character index a punctuation mark must exceed before
 * a chunk is truncated at it
 * @param minChunkLengthToEmbed minimum chunk length setting
 * @param maxNumChunks maximum number of chunks setting
 * @param keepSeparator value to use for the keep-separator flag
 * @param punctuations characters treated as candidate split points; may be empty,
 * in which case no punctuation-based truncation can occur
 * @throws IllegalArgumentException if {@code punctuations} is {@code null}
 * @throws NullPointerException if {@code punctuations} contains a {@code null}
 * element
 */
public TokenTextSplitter(int chunkSize, int minChunkSizeChars, int minChunkLengthToEmbed, int maxNumChunks,
        boolean keepSeparator, List<Character> punctuations) {
    // Fail at construction time instead of with an opaque NullPointerException
    // the first time a text is split.
    if (punctuations == null) {
        throw new IllegalArgumentException("punctuations must not be null");
    }
    this.chunkSize = chunkSize;
    this.minChunkSizeChars = minChunkSizeChars;
    this.minChunkLengthToEmbed = minChunkLengthToEmbed;
    this.maxNumChunks = maxNumChunks;
    this.keepSeparator = keepSeparator;
    // Keep an immutable snapshot so that later changes to the caller's list
    // cannot alter how this splitter behaves; copyOf also rejects null elements,
    // which would otherwise fail on unboxing while splitting.
    this.punctuations = List.copyOf(punctuations);
}
8390
8491 public static Builder builder () {
@@ -124,8 +131,10 @@ protected List<String> doSplit(String text, int chunkSize) {
124131 // This prevents unnecessary splitting of small texts
125132 if (tokens .size () > chunkSize ) {
126133 // Find the last period or punctuation mark in the chunk
127- int lastPunctuation = Math .max (chunkText .lastIndexOf ('.' ), Math .max (chunkText .lastIndexOf ('?' ),
128- Math .max (chunkText .lastIndexOf ('!' ), chunkText .lastIndexOf ('\n' ))));
134+ int lastPunctuation = -1 ;
135+ for (char punctuation : punctuations ) {
136+ lastPunctuation = Math .max (lastPunctuation , chunkText .lastIndexOf (punctuation ));
137+ }
129138
130139 if (lastPunctuation != -1 && lastPunctuation > this .minChunkSizeChars ) {
131140 // Truncate the chunk text at the punctuation mark
@@ -180,6 +189,8 @@ public static final class Builder {
180189
181190 private boolean keepSeparator = KEEP_SEPARATOR ;
182191
192+ private List <Character > punctuations = DEFAULT_PUNCTUATIONS ;
193+
183194 private Builder () {
184195 }
185196
@@ -208,9 +219,18 @@ public Builder withKeepSeparator(boolean keepSeparator) {
208219 return this ;
209220 }
210221
222+ public Builder withPunctuations (char ... punctuations ) {
223+ List <Character > list = new ArrayList <>();
224+ for (char punctuation : punctuations ) {
225+ list .add (punctuation );
226+ }
227+ this .punctuations = Collections .unmodifiableList (list );
228+ return this ;
229+ }
230+
211231 public TokenTextSplitter build () {
212232 return new TokenTextSplitter (this .chunkSize , this .minChunkSizeChars , this .minChunkLengthToEmbed ,
213- this .maxNumChunks , this .keepSeparator );
233+ this .maxNumChunks , this .keepSeparator , this . punctuations );
214234 }
215235
216236 }
0 commit comments