+from __future__ import annotations
 import asyncio
 import os
 from abc import ABC, ABCMeta, abstractmethod
 from dataclasses import asdict
 from logging import Logger
 from time import time
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, Tuple
 
 from jupyter_ai.config_manager import ConfigManager
 from jupyterlab_chat.models import Message, NewMessage, User
 from jupyterlab_chat.ychat import YChat
+from litellm import ModelResponseStream, supports_function_calling
+from litellm.utils import function_to_dict
 from pydantic import BaseModel
 from traitlets import MetaHasTraits
 from traitlets.config import LoggingConfigurable
 
 from .persona_awareness import PersonaAwareness
+from ..litellm_utils import ToolCallList, ResolvedToolCall
+
+# Import toolkits
+from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit
+from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit
+from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit
 
-# prevents a circular import
-# types imported under this block have to be surrounded in single quotes on use
 if TYPE_CHECKING:
     from collections.abc import AsyncIterator
-
-    from litellm import ModelResponseStream
-
     from .persona_manager import PersonaManager
+    from ..tools import Toolkit
 
+DEFAULT_TOOLKITS: dict[str, Toolkit] = {
+    "fs": fs_toolkit,
+    "codeexec": codeexec_toolkit,
+    "git": git_toolkit,
+}
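+# NOTE: Each toolkit key above becomes a tool-name prefix in `get_tools()`
+# below; e.g. a `git_add` tool in the "git" toolkit is exposed to the model
+# as "git__git_add".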
 
 class PersonaDefaults(BaseModel):
     """
@@ -237,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]:
 
     async def stream_message(
         self, reply_stream: "AsyncIterator[ModelResponseStream | str]"
-    ) -> None:
+    ) -> Optional[Tuple[list[ResolvedToolCall], ToolCallList]]:
         """
         Takes an async iterator, dubbed the 'reply stream', and streams it to a
         new message by this persona in the YChat. The async iterator may yield
@@ -247,21 +257,36 @@ async def stream_message(
         stream, then continuously updates it until the stream is closed.
 
         - Automatically manages its awareness state to show writing status.
+
+        Returns a tuple `(resolved_toolcalls, toolcall_list)`, or `None` if
+        the stream was interrupted. If `resolved_toolcalls` is non-empty, the
+        persona should run these tools.
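+
+        A minimal sketch of a call site, for illustration only (the variable
+        names here are assumed, not part of this API):
+
+            result = await persona.stream_message(reply_stream)
+            if result is not None:
+                resolved_toolcalls, toolcall_list = result
+                if resolved_toolcalls:
+                    tool_outputs = await persona.run_tools(resolved_toolcalls)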
         """
         stream_id: Optional[str] = None
         stream_interrupted = False
+        # Initialize these before the `try` block so the `finally` block can
+        # safely reference them even if an exception is raised early.
+        toolcall_list = ToolCallList()
+        resolved_toolcalls: list[ResolvedToolCall] = []
         try:
             self.awareness.set_local_state_field("isWriting", True)
-            async for chunk in reply_stream:
-                # Coerce LiteLLM stream chunk to a string delta
-                if not isinstance(chunk, str):
-                    chunk = chunk.choices[0].delta.content
 
-                # LiteLLM streams always terminate with an empty chunk, so we
-                # ignore and continue when this occurs.
-                if not chunk:
+            async for chunk in reply_stream:
+                # Compute `content_delta` and `toolcalls_delta` based on the
+                # type of object yielded by `reply_stream`.
+                if isinstance(chunk, ModelResponseStream):
+                    delta = chunk.choices[0].delta
+                    content_delta = delta.content
+                    toolcalls_delta = delta.tool_calls
+                elif isinstance(chunk, str):
+                    content_delta = chunk
+                    toolcalls_delta = None
+                else:
+                    raise TypeError(f"Unrecognized type in stream_message(): {type(chunk)}")
+
+                # LiteLLM streams always terminate with an empty chunk, so
+                # continue in this case.
+                if not (content_delta or toolcalls_delta):
                     continue
 
+                # Terminate the stream if the user requested it.
                 if (
                     stream_id
                     and stream_id in self.message_interrupted.keys()
@@ -280,34 +305,46 @@ async def stream_message(
                     stream_interrupted = True
                     break
 
-                if not stream_id:
-                    stream_id = self.ychat.add_message(
-                        NewMessage(body="", sender=self.id)
+                # Append `content_delta` to the existing message.
+                if content_delta:
+                    # Start the stream with an empty message on the initial
+                    # reply. Bind the new message ID to `stream_id`.
+                    if not stream_id:
+                        stream_id = self.ychat.add_message(
+                            NewMessage(body="", sender=self.id)
+                        )
+                        self.message_interrupted[stream_id] = asyncio.Event()
+                        self.awareness.set_local_state_field("isWriting", stream_id)
+                    assert stream_id
+
+                    self.ychat.update_message(
+                        Message(
+                            id=stream_id,
+                            body=content_delta,
+                            time=time(),
+                            sender=self.id,
+                            raw_time=False,
+                        ),
+                        append=True,
                     )
-                self.message_interrupted[stream_id] = asyncio.Event()
-                self.awareness.set_local_state_field("isWriting", stream_id)
-
-                assert stream_id
-                self.ychat.update_message(
-                    Message(
-                        id=stream_id,
-                        body=chunk,
-                        time=time(),
-                        sender=self.id,
-                        raw_time=False,
-                    ),
-                    append=True,
-                )
+                if toolcalls_delta:
+                    toolcall_list += toolcalls_delta
+
+            # After the reply stream is complete, resolve the list of tool calls.
+            resolved_toolcalls = toolcall_list.resolve()
         except Exception as e:
             self.log.error(
                 f"Persona '{self.name}' encountered an exception printed below when attempting to stream output."
             )
             self.log.exception(e)
         finally:
+            # Reset local state.
             self.awareness.set_local_state_field("isWriting", False)
-            if stream_id:
-                # if stream was interrupted, add a tombstone
-                if stream_interrupted:
+            self.message_interrupted.pop(stream_id, None)
+
+            # If the stream was interrupted, add a tombstone and return `None`,
+            # indicating that no tools should be run afterwards.
+            if stream_id and stream_interrupted:
                 stream_tombstone = "\n\n(AI response stopped by user)"
                 self.ychat.update_message(
                     Message(
@@ -319,8 +356,15 @@ async def stream_message(
                     ),
                     append=True,
                 )
-                if stream_id in self.message_interrupted.keys():
-                    del self.message_interrupted[stream_id]
+                return None
+
+            # Otherwise, return the resolved tool calls along with the
+            # accumulated `ToolCallList`.
+            if resolved_toolcalls:
+                count = len(resolved_toolcalls)
+                names = sorted(tc.function.name for tc in resolved_toolcalls)
+                self.log.info(f"AI response triggered {count} tool calls: {names}")
+            return resolved_toolcalls, toolcall_list
+
 
     def send_message(self, body: str) -> None:
         """
@@ -361,7 +405,7 @@ def get_mcp_config(self) -> dict[str, Any]:
         Returns the MCP config for the current chat.
         """
         return self.parent.get_mcp_config()
-
+
     def process_attachments(self, message: Message) -> Optional[str]:
         """
         Process file attachments in the message and return their content as a string.
@@ -431,6 +475,99 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]:
             self.log.error(f"Failed to resolve attachment {attachment_id}: {e}")
             return None
 
+    def get_tools(self, model_id: str) -> list[dict]:
+        """
+        Returns the `tools` parameter which should be passed to
+        `litellm.acompletion()` for a given LiteLLM model ID.
+
+        If the model does not support tool-calling, this method returns an
+        empty list. Otherwise, it returns the list of tools available in the
+        current environment. These may include:
+
+        - The default set of tool functions in Jupyter AI, defined in the
+          `jupyter_ai_tools` package.
+
+        - (TODO) Tools provided by MCP server configuration, if any.
+
+        - (TODO) Web search.
+
+        - (TODO) File search using vector store IDs.
+
+        TODO: Cache this.
+
+        TODO: Implement a permissions system so users can control which tools
+        are allowed.
+
+        NOTE: LiteLLM expects the returned list to conform to the `tools`
+        parameter definition in the OpenAI API:
+        https://platform.openai.com/docs/guides/tools#available-tools
+
+        NOTE: This API is a WIP and is very likely to change.
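+
+        For illustration, each entry of the returned list has roughly the
+        shape below (the `fs__read_file` name and its parameters are
+        hypothetical), and the whole list can be passed as
+        `litellm.acompletion(..., tools=self.get_tools(model_id))`:
+
+            {
+                "type": "function",
+                "function": {
+                    "name": "fs__read_file",
+                    "description": "...",
+                    "parameters": {"type": "object", "properties": {...}},
+                },
+            }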
+        """
+        # Return early if the model does not support tool calling.
+        if not supports_function_calling(model=model_id):
+            return []
+
+        tool_descriptions = []
+
+        # Get all tools from `jupyter_ai_tools` and store their descriptions.
+        for toolkit_name, toolkit in DEFAULT_TOOLKITS.items():
+            # TODO: make these tool permissions configurable.
+            for tool in toolkit.get_tools():
+                # Use a util function from LiteLLM to coerce each `Tool`
+                # struct into the tool description dictionary expected by
+                # LiteLLM.
+                desc = {
+                    "type": "function",
+                    "function": function_to_dict(tool.callable),
+                }
+
+                # Prepend the toolkit name to each function name, hopefully
+                # ensuring every tool function has a unique name.
+                # e.g. 'git_add' => 'git__git_add'
+                #
+                # TODO: Actually ensure this instead of hoping.
+                desc["function"]["name"] = f"{toolkit_name}__{desc['function']['name']}"
+                tool_descriptions.append(desc)
+
+        # Finally, return the tool descriptions.
+        return tool_descriptions
+
+    async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]:
+        """
+        Runs the tools specified in the list of tool calls returned by
+        `self.stream_message()`. Returns a list of tool output dictionaries,
+        which should be appended directly to the message history on the next
+        invocation of the LLM.
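+
+        For illustration, each returned dictionary has the shape below (the
+        ID and content shown are made up):
+
+            {
+                "tool_call_id": "call_abc123",
+                "role": "tool",
+                "name": "git__git_add",
+                "content": "<output returned by the tool>",
+            }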
+        """
+        if not tools:
+            return []
+
+        tool_outputs: list[dict] = []
+        for tool_call in tools:
+            # Get the tool definition from the correct toolkit. Split on the
+            # first "__" only, since tool names may themselves contain "__".
+            toolkit_name, tool_name = tool_call.function.name.split("__", 1)
+            assert toolkit_name in DEFAULT_TOOLKITS
+            tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name)
+
+            # Run the tool and store its output.
+            output = await tool_defn.callable(**tool_call.function.arguments)
+
+            # Store the tool output in a dictionary accepted by LiteLLM.
+            output_dict = {
+                "tool_call_id": tool_call.id,
+                "role": "tool",
+                "name": tool_call.function.name,
+                "content": output,
+            }
+            tool_outputs.append(output_dict)
+
+        self.log.info(f"Ran {len(tools)} tool functions.")
+        return tool_outputs
+
     def shutdown(self) -> None:
         """
         Shuts the persona down. This method should: