3737import org .springframework .ai .chat .model .ChatModel ;
3838import org .springframework .ai .chat .model .ChatResponse ;
3939import org .springframework .ai .chat .model .Generation ;
40- import org .springframework .ai .chat .model .MessageAggregator ;
4140import org .springframework .ai .chat .prompt .Prompt ;
4241
4342import static org .assertj .core .api .Assertions .assertThat ;
4443import static org .mockito .BDDMockito .given ;
4544
4645/**
46+ * Tests for the ChatClient with a focus on verifying the handling of conversation memory
47+ * and the integration of PromptChatMemoryAdvisor to ensure accurate responses based on
48+ * previous interactions.
49+ *
4750 * @author Christian Tzolov
4851 * @author Alexandros Pappas
4952 */
@@ -63,32 +66,33 @@ private String join(Flux<String> fluxContent) {
6366 @ Test
6467 public void promptChatMemory () {
6568
66- var builder = ChatResponseMetadata .builder ()
67- .id ("124" )
68- .usage (new MessageAggregator .DefaultUsage (1 , 2 , 3 ))
69- .model ("gpt4o" )
70- .keyValue ("created" , 0L )
71- .keyValue ("system-fingerprint" , "john doe" );
72- ChatResponseMetadata chatResponseMetadata = builder .build ();
69+ // Create a ChatResponseMetadata instance with default values
70+ ChatResponseMetadata chatResponseMetadata = ChatResponseMetadata .builder ().build ();
7371
72+ // Mock the chatModel to return predefined ChatResponse objects when called
7473 given (this .chatModel .call (this .promptCaptor .capture ()))
7574 .willReturn (
7675 new ChatResponse (List .of (new Generation (new AssistantMessage ("Hello John" ))), chatResponseMetadata ))
7776 .willReturn (new ChatResponse (List .of (new Generation (new AssistantMessage ("Your name is John" ))),
7877 chatResponseMetadata ));
7978
79+ // Initialize an in-memory chat memory to store conversation history
8080 ChatMemory chatMemory = new InMemoryChatMemory ();
8181
82+ // Build a ChatClient with default system text and a memory advisor
8283 var chatClient = ChatClient .builder (this .chatModel )
8384 .defaultSystem ("Default system text." )
8485 .defaultAdvisors (new PromptChatMemoryAdvisor (chatMemory ))
8586 .build ();
8687
88+ // Simulate a user prompt and verify the response
8789 ChatResponse chatResponse = chatClient .prompt ().user ("my name is John" ).call ().chatResponse ();
8890
91+ // Assert that the response content matches the expected output
8992 String content = chatResponse .getResult ().getOutput ().getText ();
9093 assertThat (content ).isEqualTo ("Hello John" );
9194
95+ // Capture and verify the system message instructions
9296 Message systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
9397 assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
9498 Default system text.
@@ -101,13 +105,17 @@ public void promptChatMemory() {
101105 """ );
102106 assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
103107
108+ // Capture and verify the user message instructions
104109 Message userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
105110 assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("my name is John" );
106111
112+ // Simulate another user prompt and verify the response
107113 content = chatClient .prompt ().user ("What is my name?" ).call ().content ();
108114
115+ // Assert that the response content matches the expected output
109116 assertThat (content ).isEqualTo ("Your name is John" );
110117
118+ // Capture and verify the updated system message instructions
111119 systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
112120 assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
113121 Default system text.
@@ -122,13 +130,15 @@ public void promptChatMemory() {
122130 """ );
123131 assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
124132
133+ // Capture and verify the updated user message instructions
125134 userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
126135 assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("What is my name?" );
127136 }
128137
129138 @ Test
130139 public void streamingPromptChatMemory () {
131140
141+ // Mock the chatModel to stream predefined ChatResponse objects
132142 given (this .chatModel .stream (this .promptCaptor .capture ())).willReturn (Flux .generate (
133143 () -> new ChatResponse (List .of (new Generation (new AssistantMessage ("Hello John" )))), (state , sink ) -> {
134144 sink .next (state );
@@ -143,17 +153,22 @@ public void streamingPromptChatMemory() {
143153 return state ;
144154 }));
145155
156+ // Initialize an in-memory chat memory to store conversation history
146157 ChatMemory chatMemory = new InMemoryChatMemory ();
147158
159+ // Build a ChatClient with default system text and a memory advisor
148160 var chatClient = ChatClient .builder (this .chatModel )
149161 .defaultSystem ("Default system text." )
150162 .defaultAdvisors (new PromptChatMemoryAdvisor (chatMemory ))
151163 .build ();
152164
165+ // Simulate a streaming user prompt and verify the response
153166 var content = join (chatClient .prompt ().user ("my name is John" ).stream ().content ());
154167
168+ // Assert that the streamed content matches the expected output
155169 assertThat (content ).isEqualTo ("Hello John" );
156170
171+ // Capture and verify the system message instructions
157172 Message systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
158173 assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
159174 Default system text.
@@ -166,13 +181,17 @@ public void streamingPromptChatMemory() {
166181 """ );
167182 assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
168183
184+ // Capture and verify the user message instructions
169185 Message userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
170186 assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("my name is John" );
171187
188+ // Simulate another streaming user prompt and verify the response
172189 content = join (chatClient .prompt ().user ("What is my name?" ).stream ().content ());
173190
191+ // Assert that the streamed content matches the expected output
174192 assertThat (content ).isEqualTo ("Your name is John" );
175193
194+ // Capture and verify the updated system message instructions
176195 systemMessage = this .promptCaptor .getValue ().getInstructions ().get (0 );
177196 assertThat (systemMessage .getText ()).isEqualToIgnoringWhitespace ("""
178197 Default system text.
@@ -187,6 +206,7 @@ public void streamingPromptChatMemory() {
187206 """ );
188207 assertThat (systemMessage .getMessageType ()).isEqualTo (MessageType .SYSTEM );
189208
209+ // Capture and verify the updated user message instructions
190210 userMessage = this .promptCaptor .getValue ().getInstructions ().get (1 );
191211 assertThat (userMessage .getText ()).isEqualToIgnoringWhitespace ("What is my name?" );
192212 }
0 commit comments