diff --git a/pyproject.toml b/pyproject.toml index 4a586bc..e8dd6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "tree_sitter_language_pack>=0.7.0", "tree_sitter_languages>=1.9.1", "vtk>=9.3.1", + "pyyaml>=6.0.0", ] requires-python = ">=3.10" readme = "README.md" @@ -62,7 +63,13 @@ build-backend = "setuptools.build_meta" fallback_version = "0.1.0" [tool.setuptools.package-data] -vtk_prompt = ["prompts/*.txt"] +vtk_prompt = ["prompts/*.yml"] + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.package-dir] +"" = "src" [tool.black] include = 'src/.*.py$' diff --git a/rag-components b/rag-components index 317b2fc..32c0442 160000 --- a/rag-components +++ b/rag-components @@ -1 +1 @@ -Subproject commit 317b2fcf133d7121171cff0881e10fd0745b6a78 +Subproject commit 32c04421d1f27aa5a1a823344f13974c0518ccbb diff --git a/src/vtk_prompt/prompt.py b/src/vtk_prompt/prompt.py index 5718690..2bb110d 100755 --- a/src/vtk_prompt/prompt.py +++ b/src/vtk_prompt/prompt.py @@ -2,6 +2,7 @@ import ast import os +import re import sys import json import openai @@ -9,11 +10,8 @@ from dataclasses import dataclass from pathlib import Path -from .prompts import ( - get_no_rag_context, - get_rag_context, - get_python_role, -) +# Using YAML system exclusively +from .yaml_prompt_loader import GitHubModelYAMLLoader @dataclass @@ -85,30 +83,32 @@ def run_code(self, code_string): print(code_string) return None - def query( + def query_yaml( self, - message="", - api_key=None, - model="gpt-4o", + message, + api_key, + prompt_source="vtk_python_code_generation", base_url=None, - max_tokens=1000, - temperature=0.1, - top_k=5, rag=False, + top_k=5, retry_attempts=1, + override_model=None, + override_temperature=None, + override_max_tokens=None, ): - """Generate VTK code with optional RAG enhancement and retry logic. + """Generate VTK code using YAML prompt templates. 
Args: message: The user query api_key: API key for the service - model: Model name to use + prompt_source: Name of the YAML prompt file to use or binary blob of the prompt file content base_url: API base URL - max_tokens: Maximum tokens to generate - temperature: Temperature for generation - top_k: Number of RAG examples to retrieve rag: Whether to use RAG enhancement - retry_attempts: Number of times to retry if AST validation fails + top_k: Number of RAG examples to retrieve + retry_attempts: Number of retry attempts for failed generations + + Returns: + Generated code string or None if failed """ if not api_key: api_key = os.environ.get("OPENAI_API_KEY") @@ -121,13 +121,18 @@ def query( # Create client with current parameters client = openai.OpenAI(api_key=api_key, base_url=base_url) - # Load existing conversation if present - if self.conversation_file and not self.conversation: - self.conversation = self.load_conversation() + # Load YAML prompt configuration + from pathlib import Path + + prompts_dir = Path(__file__).parent / "prompts" + yaml_loader = GitHubModelYAMLLoader(prompts_dir) + model_params = yaml_loader.get_model_parameters(prompt_source) + model = override_model or yaml_loader.get_model_name(prompt_source) - if not message and not self.conversation: - raise ValueError("No prompt or conversation file provided") + # Prepare variables for template substitution + variables = {"request": message} + # Handle RAG if requested if rag: from .rag_chat_wrapper import ( check_rag_components_available, @@ -148,27 +153,33 @@ def query( raise ValueError("Failed to load RAG snippets") context_snippets = "\n\n".join(rag_snippets["code_snippets"]) - context = get_rag_context(message, context_snippets) + variables["context_snippets"] = context_snippets if self.verbose: - print("CONTEXT: " + context) references = rag_snippets.get("references") if references: print("Using examples from:") for ref in references: print(f"- {ref}") - else: - context = get_no_rag_context(message) - if self.verbose: - print("CONTEXT: " + context) - # If no conversation exists, start with system role - if not self.conversation: - self.conversation = [{"role": "system", "content": get_python_role()}] + # Load existing conversation or start fresh + conversation_messages = self.load_conversation() - # Add current user message - if message: - self.conversation.append({"role": "user", "content": context}) + # Build base messages from YAML template + base_messages = yaml_loader.build_messages(prompt_source, variables) + + # If conversation exists, extend it with new user message + if conversation_messages: + # Add the current request as a new user message + conversation_messages.append({"role": "user", "content": message}) + self.conversation = conversation_messages + else: + # Use YAML template as starting point + self.conversation = base_messages + + # Extract parameters with overrides + temperature = override_temperature or model_params.get("temperature", 0.3) + max_tokens = override_max_tokens or model_params.get("max_tokens", 2000) # Retry loop for AST validation for attempt in range(retry_attempts): @@ -197,58 +208,50 @@ def query( f"Output was truncated due to max_tokens limit ({max_tokens}). Please increase max_tokens." 
) - generated_code = None - if "import vtk" not in content: - generated_code = "import vtk\n" + content - else: - pos = content.find("import vtk") - if pos != -1: - generated_code = content[pos:] - else: - generated_code = content + generated_explanation = re.findall( + "(.*?)", content, re.DOTALL + )[0] + generated_code = re.findall("(.*?)", content, re.DOTALL)[0] + if "import vtk" not in generated_code: + generated_code = f"import vtk\n{generated_code}" is_valid, error_msg = self.validate_code_syntax(generated_code) if is_valid: - if message: - self.conversation.append( - {"role": "assistant", "content": content} - ) - self.save_conversation() - return generated_code, response.usage + # Save conversation with assistant response + self.conversation.append({"role": "assistant", "content": content}) + self.save_conversation() - elif attempt < retry_attempts - 1: # Don't print on last attempt if self.verbose: - print(f"AST validation failed: {error_msg}. Retrying...") - # Add error feedback to context for retry - self.conversation.append({"role": "assistant", "content": content}) - self.conversation.append( - { - "role": "user", - "content": ( - f"The generated code has a syntax error: {error_msg}. " - "Please fix the syntax and generate valid Python code." - ), - } - ) + print("Code validation successful!") + return generated_code, generated_explanation else: - # Last attempt failed if self.verbose: - print(f"Final attempt failed AST validation: {error_msg}") - - if message: - self.conversation.append( - {"role": "assistant", "content": content} + print( + f"Code validation failed on attempt {attempt + 1}: {error_msg}" ) + print("Generated code:") + print(generated_code) + + if attempt < retry_attempts - 1: + # Add error feedback to messages for retry + error_feedback = ( + f"The previous code had a syntax error: {error_msg}. " + "Please fix the syntax and try again." + ) + self.conversation.append({"role": "user", "content": error_feedback}) + else: + # Save conversation even if final attempt failed + self.conversation.append({"role": "assistant", "content": content}) self.save_conversation() - return ( - generated_code, - response.usage, - ) # Return anyway, let caller handle + print( + f"All {retry_attempts} attempts failed. Final error: {error_msg}" + ) + return generated_code, generated_explanation # Return anyway, let caller handle else: - if attempt == retry_attempts - 1: - return "No response generated", response.usage + print("No response content received") + return None - return "No response generated" + return None @click.command() @@ -259,14 +262,14 @@ def query( default="openai", help="LLM provider to use", ) -@click.option("-m", "--model", default="gpt-4o", help="Model name to use") +@click.option("-m", "--model", default="gpt-4o-mini", help="Model name to use") @click.option( "-k", "--max-tokens", type=int, default=1000, help="Max # of tokens to generate" ) @click.option( "--temperature", type=float, - default=0.7, + default=0.1, help="Temperature for generation (0.0-2.0)", ) @click.option( @@ -310,7 +313,7 @@ def main( retry_attempts, conversation, ): - """Generate and execute VTK code using LLMs. + """Generate and execute VTK code using LLMs with YAML prompts. 
INPUT_STRING: The code description to generate VTK code for """ @@ -340,22 +343,26 @@ def main( verbose=verbose, conversation_file=conversation, ) - generated_code, usage = client.query( + + # Use YAML system directly + prompt_source = "rag_context" if rag else "no_rag_context" + generated_code = client.query_yaml( input_string, api_key=token, - model=model, + prompt_source=prompt_source, base_url=base_url, - max_tokens=max_tokens, - temperature=temperature, - top_k=top_k, rag=rag, + top_k=top_k, retry_attempts=retry_attempts, + # Override parameters if specified in CLI + override_model=model if model != "gpt-4o-mini" else None, + override_temperature=temperature if temperature != 0.1 else None, + override_max_tokens=max_tokens if max_tokens != 1000 else None, ) - if verbose and usage is not None: - print( - f"Used tokens: input={usage.prompt_tokens} output={usage.completion_tokens}" - ) + # Usage tracking not yet implemented for YAML system + if verbose: + print("Token usage tracking not available in YAML mode") client.run_code(generated_code) diff --git a/src/vtk_prompt/prompts/__init__.py b/src/vtk_prompt/prompts/__init__.py index b63dcb9..e9423c8 100644 --- a/src/vtk_prompt/prompts/__init__.py +++ b/src/vtk_prompt/prompts/__init__.py @@ -2,6 +2,7 @@ from pathlib import Path import vtk +from ..yaml_prompt_loader import GitHubModelYAMLLoader PYTHON_VERSION = ">=3.10" VTK_VERSION = vtk.__version__ @@ -9,70 +10,40 @@ # Path to the prompts directory PROMPTS_DIR = Path(__file__).parent +# Initialize YAML loader for current directory (src/vtk_prompt/prompts) +_yaml_loader = GitHubModelYAMLLoader(PROMPTS_DIR) -def load_template(template_name: str) -> str: - """Load a template file from the prompts directory. - Args: - template_name: Name of the template file (without .txt extension) - - Returns: - The template content as a string - """ - template_path = PROMPTS_DIR / f"{template_name}.txt" - if not template_path.exists(): - raise FileNotFoundError( - f"Template {template_name} not found at {template_path}" - ) - - return template_path.read_text() - - -def get_base_context() -> str: - """Get the base context template with version variables filled in.""" - template = load_template("base_context") - return template.format(VTK_VERSION=VTK_VERSION, PYTHON_VERSION=PYTHON_VERSION) - - -def get_no_rag_context(request: str) -> str: - """Get the no-RAG context template with request filled in.""" - base_context = get_base_context() - template = load_template("no_rag_context") - return template.format(BASE_CONTEXT=base_context, request=request) - - -def get_rag_context(request: str, context_snippets: str) -> str: - """Get the RAG context template with request and snippets filled in.""" - base_context = get_base_context() - template = load_template("rag_context") - return template.format( - BASE_CONTEXT=base_context, request=request, context_snippets=context_snippets +# Legacy functions for backward compatibility with rag_chat_wrapper +def get_rag_chat_context(context: str, query: str) -> str: + """Get the RAG chat context template with context and query filled in.""" + # Use YAML version + messages = _yaml_loader.build_messages( + "rag_chat_context", {"CONTEXT": context, "QUERY": query} ) - - -def get_python_role() -> str: - """Get the Python role template with version filled in.""" - template = load_template("python_role") - return template.format(PYTHON_VERSION=PYTHON_VERSION) + # Return combined system + user content for backward compatibility + system_content = "" + user_content = "" + for msg in 
messages: + if msg["role"] == "system": + system_content = msg["content"] + elif msg["role"] == "user": + user_content = msg["content"] + return f"{system_content}\n\n{user_content}" def get_vtk_xml_context(description: str) -> str: """Get the VTK XML context template with description filled in.""" - template = load_template("vtk_xml_context") - return template.format(VTK_VERSION=VTK_VERSION, description=description) - - -def get_xml_role() -> str: - """Get the XML role template.""" - return load_template("xml_role") - - -def get_ui_post_prompt() -> str: - """Get the UI post prompt template.""" - return load_template("ui_post_prompt") - - -def get_rag_chat_context(context: str, query: str) -> str: - """Get the RAG chat context template with context and query filled in.""" - template = load_template("rag_chat_context") - return template.format(CONTEXT=context, QUERY=query) + # Use YAML version + messages = _yaml_loader.build_messages( + "vtk_xml_context", {"description": description} + ) + # Return combined system + user content for backward compatibility + system_content = "" + user_content = "" + for msg in messages: + if msg["role"] == "system": + system_content = msg["content"] + elif msg["role"] == "user": + user_content = msg["content"] + return f"{system_content}\n\n{user_content}" diff --git a/src/vtk_prompt/prompts/base_context.txt b/src/vtk_prompt/prompts/base_context.txt deleted file mode 100644 index 62d6add..0000000 --- a/src/vtk_prompt/prompts/base_context.txt +++ /dev/null @@ -1,23 +0,0 @@ -Write only python source code that uses VTK. - - -- DO NOT READ OUTSIDE DATA -- DO NOT DEFINE FUNCTIONS -- NO TEXT, ONLY SOURCE CODE -- ONLY import VTK and numpy if needed -- Only use {VTK_VERSION} python basic components. -- Only use {PYTHON_VERSION} or above. - - - -- Only output verbatin python code. -- Only VTK library -- No explanations -- No ```python marker -- No markdown - - - -input: Only create a vtkShpere -output: sphere = vtk.vtkSphereSource() - diff --git a/src/vtk_prompt/prompts/no_rag_context.prompt.yml b/src/vtk_prompt/prompts/no_rag_context.prompt.yml new file mode 100644 index 0000000..89bf293 --- /dev/null +++ b/src/vtk_prompt/prompts/no_rag_context.prompt.yml @@ -0,0 +1,14 @@ +name: "VTK No RAG Context" +description: "Standard VTK code generation without RAG enhancement" +model: "openai/gpt-4o-mini" +modelParameters: + temperature: 0.1 + max_tokens: 1000 +messages: + - role: system + content: !include ./src/vtk_prompt/prompts/prompt_system_content.yml + - role: assistant + content: !include ./src/vtk_prompt/prompts/prompt_assistant_content.yml + - role: user + content: | + {{request}} diff --git a/src/vtk_prompt/prompts/no_rag_context.txt b/src/vtk_prompt/prompts/no_rag_context.txt deleted file mode 100644 index 6c3d336..0000000 --- a/src/vtk_prompt/prompts/no_rag_context.txt +++ /dev/null @@ -1,4 +0,0 @@ -{BASE_CONTEXT} - -Request: -{request} \ No newline at end of file diff --git a/src/vtk_prompt/prompts/prompt_assistant_content.yml b/src/vtk_prompt/prompts/prompt_assistant_content.yml new file mode 100644 index 0000000..0891d63 --- /dev/null +++ b/src/vtk_prompt/prompts/prompt_assistant_content.yml @@ -0,0 +1,29 @@ +>- + First, provide a **short but complete explanation** written in **full + sentences**. + + The explanation must describe **what the code does** at each step. + + The explanation must describe **why the code does what it does with regards to the VTK library** at each step. 
+ + The explanation must describe **why the code does what it does with regards to the data being visualized** at each step. + + The explanation must always come **before** the code. + + The explanation MUST begin with a "" tag and end with a "" tag. + + The code MUST begin with a "" tag and end with a "" tag. + + Do not summarize, introduce, or conclude outside the explanation or code + itself. + + Output the Python code **exactly as written**, with no additional text + before or after the code. + + **No** markdown markers like ```python or ``` anywhere. + + Do not add phrases like “Here is the source code” or similar. + + The explanation must stay **above the code**. + + You may use inline comments in the code if helpful for clarity. diff --git a/src/vtk_prompt/prompts/prompt_system_content.yml b/src/vtk_prompt/prompts/prompt_system_content.yml new file mode 100644 index 0000000..17d048e --- /dev/null +++ b/src/vtk_prompt/prompts/prompt_system_content.yml @@ -0,0 +1,32 @@ +>- + You are a python >=3.10 source code producing entity, your output will be + fed to a python interpreter. + + Write Python source code with an explanation that uses VTK. + + DO NOT READ OUTSIDE DATA. + + DO NOT DEFINE FUNCTIONS. + + DO NOT USE MARKDOWN. + + ALWAYS PROVIDE SOURCE CODE. + + ALWAYS MAKE SURE IMPORTS ARE CORRECT AND VALID. + + ALWAYS MAKE SURE THE CODE IS VALID AND CAN BE RUN. + + ONLY import VTK and numpy if needed. + + Only use 9.5.0 Python basic components. + + Only use >=3.10 or above. + + ALWAYS consider all prior instructions, corrections, and constraints from this conversation as still in effect. + + NEVER remove or undo fixes made in previous steps unless explicitly told to. + + Before producing new code, mentally check that: + - All earlier errors are still fixed. + - No earlier constraints are violated. + - The output follows the entire conversation’s accumulated rules. diff --git a/src/vtk_prompt/prompts/python_role.txt b/src/vtk_prompt/prompts/python_role.txt deleted file mode 100644 index ade13b9..0000000 --- a/src/vtk_prompt/prompts/python_role.txt +++ /dev/null @@ -1 +0,0 @@ -You are a python {PYTHON_VERSION} source code producing entity, your output will be fed to a python interpreter \ No newline at end of file diff --git a/src/vtk_prompt/prompts/rag_chat_context.prompt.yml b/src/vtk_prompt/prompts/rag_chat_context.prompt.yml new file mode 100644 index 0000000..8f18a3f --- /dev/null +++ b/src/vtk_prompt/prompts/rag_chat_context.prompt.yml @@ -0,0 +1,36 @@ +name: "VTK RAG Chat Assistant" +description: "AI assistant for VTK documentation and support with context" +model: "openai/gpt-4o-mini" +modelParameters: + temperature: 0.3 + max_tokens: 2000 +messages: + - role: system + content: | + You are an AI assistant specializing in VTK (Visualization Toolkit) + documentation. Your primary task is to provide accurate, concise, and helpful + responses to user queries about VTK, including relevant code snippets + + Here is the context information you should use to answer queries: + + {{CONTEXT}} + + + When responding to a user query, follow these guidelines: + + 1. Relevance Check: + + - If the query is not relevant to VTK, respond with "This question is not relevant to VTK." + + 2. Answer Formulation: + + - If you don't know the answer, clearly state that. + - If uncertain, ask the user for clarification. + - Respond in the same language as the user's query. + - Be concise while providing complete information. 
+ - If the answer isn't in the context but you have the knowledge, explain this to the user and provide the answer based on your understanding. + - role: user + content: | + + {{QUERY}} + diff --git a/src/vtk_prompt/prompts/rag_chat_context.txt b/src/vtk_prompt/prompts/rag_chat_context.txt deleted file mode 100644 index 20d262f..0000000 --- a/src/vtk_prompt/prompts/rag_chat_context.txt +++ /dev/null @@ -1,28 +0,0 @@ -You are an AI assistant specializing in VTK (Visualization Toolkit) -documentation. Your primary task is to provide accurate, concise, and helpful -responses to user queries about VTK, including relevant code snippets - -Here is the context information you should use to answer queries: - -{CONTEXT} - - -Here's the user's query: - - -{QUERY} - - -When responding to a user query, follow these guidelines: - -1. Relevance Check: - - - If the query is not relevant to VTK, respond with "This question is not relevant to VTK." - -2. Answer Formulation: - - - If you don't know the answer, clearly state that. - - If uncertain, ask the user for clarification. - - Respond in the same language as the user's query. - - Be concise while providing complete information. - - If the answer isn't in the context but you have the knowledge, explain this to the user and provide the answer based on your understanding. diff --git a/src/vtk_prompt/prompts/rag_context.prompt.yml b/src/vtk_prompt/prompts/rag_context.prompt.yml new file mode 100644 index 0000000..aef3ce1 --- /dev/null +++ b/src/vtk_prompt/prompts/rag_context.prompt.yml @@ -0,0 +1,20 @@ +name: "VTK RAG Context" +description: "VTK code generation with RAG enhancement from examples" +model: "openai/gpt-4o-mini" +modelParameters: + temperature: 0.1 + max_tokens: 1000 +messages: + - role: system + content: !include ./src/vtk_prompt/prompts/prompt_system_content.yml + - role: assistant + content: !include ./src/vtk_prompt/prompts/prompt_assistant_content.yml + - role: user + content: | + {{#if context_snippets}} + + {{context_snippets}} + + + {{/if}} + {{request}} diff --git a/src/vtk_prompt/prompts/rag_context.txt b/src/vtk_prompt/prompts/rag_context.txt deleted file mode 100644 index a38fc0a..0000000 --- a/src/vtk_prompt/prompts/rag_context.txt +++ /dev/null @@ -1,12 +0,0 @@ -{BASE_CONTEXT} - - -- Refer to the below vtk_examples snippets, this is the the main source of thruth - - - -{context_snippets} - - -Request: -{request} \ No newline at end of file diff --git a/src/vtk_prompt/prompts/ui_context.prompt.yml b/src/vtk_prompt/prompts/ui_context.prompt.yml new file mode 100644 index 0000000..0ed709c --- /dev/null +++ b/src/vtk_prompt/prompts/ui_context.prompt.yml @@ -0,0 +1,31 @@ +name: "VTK UI Context" +description: "UI-specific VTK code generation with renderer instructions" +model: "openai/gpt-4o-mini" +modelParameters: + temperature: 0.1 + max_tokens: 1000 +messages: + - role: system + content: !include ./src/vtk_prompt/prompts/prompt_system_content.yml + Do not create a new vtkRenderer + + Use the injected vtkrenderer object named renderer + + Do not manage rendering things + + You must connect the actors to the renderer injected object + + You must render what I ask even if I do not ask to render it + + Do not render if I explicitly ask you not to render it + - role: assistant + content: !include ./src/vtk_prompt/prompts/prompt_assistant_content.yml + - role: user + content: | + {{#if context_snippets}} + + {{context_snippets}} + + + {{/if}} + {{request}} diff --git a/src/vtk_prompt/prompts/ui_post_prompt.txt 
b/src/vtk_prompt/prompts/ui_post_prompt.txt deleted file mode 100644 index 76edb83..0000000 --- a/src/vtk_prompt/prompts/ui_post_prompt.txt +++ /dev/null @@ -1,8 +0,0 @@ - -- Do not create a new vtkRenderer -- Use the injected vtkrenderer object named renderer -- Do not manager rendering things -- You must connect the actors to the renderer injected object -- You must render what I ask even if I do not ask to render it -- Only avoid rendering if I explictitly ask you not to render it - diff --git a/src/vtk_prompt/prompts/vtk_xml_context.prompt.yml b/src/vtk_prompt/prompts/vtk_xml_context.prompt.yml new file mode 100644 index 0000000..2457972 --- /dev/null +++ b/src/vtk_prompt/prompts/vtk_xml_context.prompt.yml @@ -0,0 +1,21 @@ +name: "VTK XML Context" +description: "Generates VTK XML files with version compatibility" +model: "openai/gpt-4o-mini" +modelParameters: + temperature: 0.1 + max_tokens: 1000 +messages: + - role: system + content: | + You are a XML VTK file generator, the generated file will be read by VTK file reader + + Generate VTK XML files that can be read by ParaView or other VTK-compatible applications. + + Output only valid XML content, no explanations or markdown. + + Use VTK {{VTK_VERSION}} compatible XML format. + - role: user + content: | + Generate a VTK XML file for: {{description}} + + The XML should be compatible with VTK {{VTK_VERSION}} and readable by ParaView. diff --git a/src/vtk_prompt/prompts/vtk_xml_context.txt b/src/vtk_prompt/prompts/vtk_xml_context.txt deleted file mode 100644 index 55f70be..0000000 --- a/src/vtk_prompt/prompts/vtk_xml_context.txt +++ /dev/null @@ -1,82 +0,0 @@ -Write only text that is the content of a XML VTK file. - - -- NO COMMENTS, ONLY CONTENT OF THE FILE -- Only use VTK {VTK_VERSION} basic components. - - - -- Only output verbatim XML content. 
-- No explanations -- No markup or code blocks - - - -input: A VTP file example of a 4 points with temperature and pressure data -output: - - - - - - - - 0.0 0.0 0.0 - 1.0 0.0 0.0 - 0.0 1.0 0.0 - 1.0 1.0 0.0 - - - - - - - - 25.5 - 26.7 - 24.3 - 27.1 - - - - 101.3 - 101.5 - 101.2 - 101.4 - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -Request: -{description} \ No newline at end of file diff --git a/src/vtk_prompt/prompts/xml_role.prompt.yml b/src/vtk_prompt/prompts/xml_role.prompt.yml new file mode 100644 index 0000000..7e49d9c --- /dev/null +++ b/src/vtk_prompt/prompts/xml_role.prompt.yml @@ -0,0 +1,10 @@ +name: "VTK XML Generator" +description: "Generates VTK XML files for visualization" +model: "openai/gpt-4o-mini" +modelParameters: + temperature: 0.1 + max_tokens: 1000 +messages: + - role: system + content: | + You are a XML VTK file generator, the generated file will be read by VTK file reader diff --git a/src/vtk_prompt/prompts/xml_role.txt b/src/vtk_prompt/prompts/xml_role.txt deleted file mode 100644 index 4d31804..0000000 --- a/src/vtk_prompt/prompts/xml_role.txt +++ /dev/null @@ -1 +0,0 @@ -You are a XML VTK file generator, the generated file will be read by VTK file reader \ No newline at end of file diff --git a/src/vtk_prompt/utils.js b/src/vtk_prompt/utils.js index 39e6514..6e0ff5e 100644 --- a/src/vtk_prompt/utils.js +++ b/src/vtk_prompt/utils.js @@ -5,6 +5,16 @@ window.trame.utils.vtk_prompt = { return "Invalid file type"; } return true; + }, + yaml_file(obj) { + if ( + obj && + (obj.type !== "application/x-yaml" || (!obj.name.endsWith(".yaml") && + !obj.name.endsWith(".yml"))) + ) { + return "Invalid file type"; + } + return true; } } } diff --git a/src/vtk_prompt/vtk_prompt_ui.py b/src/vtk_prompt/vtk_prompt_ui.py index 79d7f6a..8783ab9 100644 --- a/src/vtk_prompt/vtk_prompt_ui.py +++ b/src/vtk_prompt/vtk_prompt_ui.py @@ -2,6 +2,8 @@ import json from pathlib import Path +import re +import yaml # Add VTK and Trame imports from vtkmodules.vtkInteractionStyle import vtkInteractorStyleSwitch # noqa @@ -15,9 +17,10 @@ # Import our prompt functionality from .prompt import VTKPromptClient +from .query_error_handler import QueryErrorHandler -# Import our template system -from .prompts import get_ui_post_prompt +# Legacy prompts removed - using YAML system exclusively +from .yaml_prompt_loader import GitHubModelYAMLLoader EXPLAIN_RENDERER = ( "# renderer is a vtkRenderer injected by this webapp" @@ -59,6 +62,7 @@ def __init__(self, server=None): self.renderer.SetBackground(0.1, 0.1, 0.1) # Add a simple coordinate axes as default content + self.state.config_source = "ui_context" self._add_default_scene() # Initial render @@ -85,12 +89,22 @@ def _add_default_scene(self): # App state variables self.state.query_text = "" self.state.generated_code = "" + self.state.generated_explanation = "" self.state.is_loading = False self.state.use_rag = False self.state.error_message = "" self.state.conversation_object = None self.state.conversation_file = None - self.state.conversation = None + self.state.conversation = [] # Initialize as empty list instead of None + + # YAML prompt configuration - UI always uses ui_context prompt + from pathlib import Path + + prompts_dir = Path(__file__).parent / "prompts" + self.yaml_loader = GitHubModelYAMLLoader(prompts_dir) + + # Get default parameters from YAML ui_context prompt + self.default_params = self.yaml_loader.get_model_parameters(self.state.config_source) # Token usage tracking self.state.input_tokens = 0 @@ -101,8 +115,8 @@ 
def _add_default_scene(self): self.state.tab_index = 0 # Tab navigation state # Cloud model configuration - self.state.provider = "openai" - self.state.model = "gpt-4o" + self.state.provider = self.yaml_loader.get_model_provider(self.state.config_source) + self.state.model = self.yaml_loader.get_model_name(self.state.config_source) self.state.available_providers = [ "openai", "anthropic", @@ -153,12 +167,17 @@ def _init_prompt_client(self): ) return - self.prompt_client = VTKPromptClient( - collection_name="vtk-examples", - database_path="./db/codesage-codesage-large-v2", - verbose=False, - conversation=self.state.conversation, - ) + # Create the client if it doesn't exist, otherwise update its conversation + if not hasattr(self, 'prompt_client'): + self.prompt_client = VTKPromptClient( + collection_name="vtk-examples", + database_path="./db/codesage-codesage-large-v2", + verbose=False, + conversation=self.state.conversation, + ) + else: + # Update the conversation in the existing client + self.prompt_client.conversation = self.state.conversation except ValueError as e: self.state.error_message = str(e) @@ -274,62 +293,100 @@ def reset_camera(self): except Exception as e: print(f"Error resetting camera: {e}") - def _generate_and_execute_code(self): + def _generate_and_execute_code(self, result=None): """Generate VTK code using Anthropic API and execute it.""" self.state.is_loading = True self.state.error_message = "" + + original_query = self.state.query_text # Store original query for retries try: - # Generate code using prompt functionality - reuse existing methods - enhanced_query = self.state.query_text - if self.state.query_text: - post_prompt = get_ui_post_prompt() - enhanced_query = post_prompt + self.state.query_text - - # Reinitialize client with current settings + # Update the prompt client with current settings + # This will use the existing client or create one if it doesn't exist self._init_prompt_client() if hasattr(self.state, "error_message") and self.state.error_message: return - result = self.prompt_client.query( - enhanced_query, - api_key=self._get_api_key(), - model=self._get_model(), - base_url=self._get_base_url(), - max_tokens=int(self.state.max_tokens), - temperature=float(self.state.temperature), - top_k=int(self.state.top_k), - rag=self.state.use_rag, - retry_attempts=int(self.state.retry_attempts), - ) - # Keep UI in sync with conversation - self.state.conversation = self.prompt_client.conversation - - # Handle both code and usage information - if isinstance(result, tuple) and len(result) == 2: - generated_code, usage = result - if usage: - self.state.input_tokens = usage.prompt_tokens - self.state.output_tokens = usage.completion_tokens - else: - generated_code = result - # Reset token counts if no usage info - self.state.input_tokens = 0 - self.state.output_tokens = 0 - - self.state.generated_code = EXPLAIN_RENDERER + "\n" + generated_code - - # Execute the generated code using the existing run_code method - # But we need to modify it to work with our renderer - self._execute_with_renderer(generated_code) - + retry_attempts = int(self.state.retry_attempts) + current_query = original_query + last_error = None + last_generated_code = None + error_history = [] + + for attempt in range(retry_attempts + 1): # +1 for initial attempt + print(f"Attempt {attempt + 1} of {retry_attempts + 1}") + try: + if result is None: + # Use YAML system exclusively - UI uses ui_context prompt + print(f"Query\n{current_query}\n\n") + result = self.prompt_client.query_yaml( + 
current_query, + api_key=self._get_api_key(), + prompt_source=self.state.config_source, + base_url=self._get_base_url(), + rag=self.state.use_rag, + top_k=int(self.state.top_k), + retry_attempts=int(self.state.retry_attempts), + # Override parameters from UI settings when different from defaults + override_temperature=( + float(self.state.temperature) + if float(self.state.temperature) + != self.default_params.get("temperature", 0.1) + else None + ), + override_max_tokens=( + int(self.state.max_tokens) + if int(self.state.max_tokens) + != self.default_params.get("max_tokens", 1000) + else None + ), + ) + # Keep UI in sync with conversation + self.state.conversation = self.prompt_client.conversation + + # Handle generated code + generated_code, generated_explanation = result + # Reset token counts for YAML system (no usage info yet) + self.state.input_tokens = 0 + self.state.output_tokens = 0 + + self.state.generated_explanation = generated_explanation.strip() + self.state.generated_code = EXPLAIN_RENDERER + "\n" + generated_code.strip() + # Execute the generated code using the existing run_code method + # But we need to modify it to work with our renderer + self._execute_with_renderer(generated_code) + + # Success, break out of retry loop + break + except Exception as execution_error: + last_error = execution_error + last_generated_code = generated_code if 'generated_code' in locals() else None + + # Add this error to the history + error_history.append(str(execution_error)) + + # If this was the last attempt, re-raise the exception + if attempt >= retry_attempts: + raise execution_error + + current_query = QueryErrorHandler.build_retry_query( + execution_error, + original_query, + last_generated_code, + error_history, + ) + print(current_query) except ValueError as e: if "max_tokens" in str(e): self.state.error_message = f"{str(e)} Current: {self.state.max_tokens}. Try increasing max tokens." else: - self.state.error_message = f"Error generating code: {str(e)}" + self.state.error_message = f"Value error generating code: {str(e)}" except Exception as e: - self.state.error_message = f"Error generating code: {str(e)}" + # If we exhausted retries, provide more detailed error message + if last_error and last_generated_code: + self.state.error_message = f"Code execution failed after {retry_attempts + 1} attempts. 
Final error: {str(e)}" + else: + self.state.error_message = f"Error generating code: {str(e)}" finally: self.state.is_loading = False @@ -371,21 +428,62 @@ def _execute_with_renderer(self, code_string): self.ctrl.view_update() except Exception as e: - self.state.error_message = f"Error executing code: {str(e)}" + # Don't set error_message here - let the retry logic handle it + # Re-raise the exception so retry logic can catch it + raise e @change("conversation_object") def on_conversation_file_data_change(self, conversation_object, **_): + self.state.conversation = None + self.state.conversation_file = None invalid = ( conversation_object is None or conversation_object["type"] != "application/json" or Path(conversation_object["name"]).suffix != ".json" ) - self.state.conversation = ( - None if invalid else json.loads(conversation_object["content"]) + if invalid: + return + + self.state.conversation_file = conversation_object["name"] + content = conversation_object["content"] + if not content: + return + + self.state.conversation = json.loads(content) + + # Update the conversation in the prompt client if it exists + if hasattr(self, 'prompt_client'): + self.prompt_client.conversation = self.state.conversation + + if not invalid and content and self.state.auto_run_conversation_file: + result = self.state.conversation[-1]["content"] + generated_explanation = re.findall( + "(.*?)", result, re.DOTALL + )[0] + generated_code = re.findall("(.*?)", result, re.DOTALL)[0] + if "import vtk" not in generated_code: + generated_code = f"import vtk\n{generated_code}" + self._generate_and_execute_code([generated_code, generated_explanation]) + + @change("config_object") + def on_config_file_data_change(self, config_object, **_): + invalid = ( + config_object is None + or isinstance(config_object, str) + or config_object["type"] != "application/x-yaml" + and Path(config_object["name"]).suffix != ".yaml" + and Path(config_object["name"]).suffix != ".yml" ) - self.state.conversation_file = None if invalid else conversation_object["name"] - if not invalid and self.state.auto_run_conversation_file: - self.generate_code() + if invalid: + self.state.config_file_name = None + return + + if not config_object["content"]: + return + + self.state.config_source = config_object["content"] + self.state.config_file_name = config_object["name"] + self.clear_scene() @trigger("save_conversation") def save_conversation(self): @@ -393,6 +491,25 @@ def save_conversation(self): return "" return json.dumps(self.prompt_client.conversation, indent=2) + @trigger("save_config") + def save_config(self): + config_data = { + "model": f"{self.state.provider}/{self.state.model}", + "modelParameters": { + "max_completion_tokens": self.state.max_tokens, + "temperature": self.state.temperature, + "rag": self.state.use_rag, + "collection": self.state.collection, + "database": self.state.database, + "top_k": self.state.top_k, + "retry_attempts": self.state.retry_attempts, + "conversation": self.state.conversation_file, + }, + } + default = self.yaml_loader.load_prompt(self.state.config_source) + config_data = {**default, **config_data} + return yaml.safe_dump(config_data) + def _build_ui(self): """Build a simplified Vuetify UI.""" # Initialize drawer state as collapsed @@ -464,38 +581,39 @@ def _build_ui(self): with vuetify.VTabsWindowItem(): with vuetify.VCard(flat=True, style="mt-2"): with vuetify.VCardText(): - # Provider selection - vuetify.VSelect( - label="Provider", - v_model=("provider", "openai"), - items=("available_providers", []), 
- density="compact", - variant="outlined", - prepend_icon="mdi-cloud", - ) + with vuetify.VForm(): + # Provider selection + vuetify.VSelect( + label="Provider", + v_model=("provider", "openai"), + items=("available_providers", []), + density="compact", + variant="outlined", + prepend_icon="mdi-cloud", + ) - # Model selection - vuetify.VSelect( - label="Model", - v_model=("model", "gpt-4o"), - items=("available_models[provider] || []",), - density="compact", - variant="outlined", - prepend_icon="mdi-brain", - ) + # Model selection + vuetify.VSelect( + label="Model", + v_model=("model", "gpt-4o"), + items=("available_models[provider] || []",), + density="compact", + variant="outlined", + prepend_icon="mdi-brain", + ) - # API Token - vuetify.VTextField( - label="API Token", - v_model=("api_token", ""), - placeholder="Enter your API token", - type="password", - density="compact", - variant="outlined", - prepend_icon="mdi-key", - hint="Required for cloud providers", - persistent_hint=True, - ) + # API Token + vuetify.VTextField( + label="API Token", + v_model=("api_token", ""), + placeholder="Enter your API token", + type="password", + density="compact", + variant="outlined", + prepend_icon="mdi-key", + hint="Required for cloud providers", + persistent_hint=True, + ) # Local Models Tab Content with vuetify.VTabsWindowItem(): @@ -565,7 +683,10 @@ def _build_ui(self): with vuetify.VCardText(): vuetify.VSlider( label="Temperature", - v_model=("temperature", 0.1), + v_model=( + "temperature", + self.default_params.get("temperature", 0.1), + ), min=0.0, max=1.0, step=0.1, @@ -576,7 +697,10 @@ def _build_ui(self): ) vuetify.VTextField( label="Max Tokens", - v_model=("max_tokens", 1000), + v_model=( + "max_tokens", + self.default_params.get("max_tokens", 1000), + ), type="number", density="compact", variant="outlined", @@ -584,7 +708,7 @@ def _build_ui(self): ) vuetify.VTextField( label="Retry Attempts", - v_model=("retry_attempts", 1), + v_model=("retry_attempts", 5), type="number", min=1, max=5, @@ -598,16 +722,60 @@ def _build_ui(self): "⚙️ Files", hide_details=True, density="compact" ) with vuetify.VCardText(): + with html.Div( + classes="d-flex align-center justify-space-between mb-2" + ): + with vuetify.VTooltip( + text=("config_file_name", "No config loaded"), + location="top", + disabled=("!config_source",), + ): + with vuetify.Template(v_slot_activator="{ props }"): + vuetify.VFileInput( + label="Configuration File", + v_model=("config_object", None), + accept=".yaml, .yml", + density="compact", + variant="solo", + prepend_icon="mdi-file-cog-outline", + hide_details="auto", + classes="py-1 pr-1 mr-1 text-truncate", + open_on_focus=False, + clearable=False, + v_bind="props", + rules=[ + "[utils.vtk_prompt.rules.yaml_file]" + ], + ) + with vuetify.VTooltip( + text="Download configuration file", + location="right", + ): + with vuetify.Template(v_slot_activator="{ props }"): + with vuetify.VBtn( + icon=True, + density="comfortable", + color="secondary", + rounded="lg", + v_bind="props", + disabled=("!config_source",), + click="utils.download(" + + "`config_${new Date().toISOString()}.prompt.yaml`," + + "trigger('save_config')," + + "'application/yaml'" + + ")", + ): + vuetify.VIcon("mdi-file-download-outline") vuetify.VCheckbox( label="Run new conversation files", v_model=("auto_run_conversation_file", True), - prepend_icon="mdi-file-refresh-outline", + prepend_icon="mdi-run", density="compact", color="primary", hide_details=True, ) with html.Div( - classes="d-flex align-center 
justify-space-between" + classes="d-flex align-center justify-space-between mb-2" ): with vuetify.VTooltip( text=("conversation_file", "No file loaded"), @@ -652,26 +820,61 @@ def _build_ui(self): vuetify.VIcon("mdi-file-download-outline") with layout.content: - with vuetify.VContainer(fluid=True, classes="fill-height"): + with vuetify.VContainer( + classes="fluid fill-height", style="min-width: 100%;" + ): with vuetify.VRow(rows=12, classes="fill-height"): # Left column - Generated code view with vuetify.VCol(cols=6, classes="fill-height"): - with vuetify.VCard(classes="mb-2", style="height: 100%;"): - vuetify.VCardTitle("Generated Code") - with vuetify.VCardText( - classes="overflow-auto", + with vuetify.VExpansionPanels( + v_model=("explanation_expanded", [0, 1]), + classes="fill-height", + multiple=True, + ): + with vuetify.VExpansionPanel( + classes="mt-1", + style="height: fit-content; max-height: 30%;", ): - vuetify.VTextarea( - v_model=("generated_code", ""), - readonly=True, - solo=True, - hide_details=True, - no_resize=True, - auto_grow=True, - classes="overflow-y", - style="font-family: monospace;", - placeholder="Generated VTK code will appear here...", + vuetify.VExpansionPanelTitle( + "Explanation", classes="text-h6" ) + with vuetify.VExpansionPanelText( + style="overflow: hidden;" + ): + vuetify.VTextarea( + v_model=("generated_explanation", ""), + readonly=True, + solo=True, + hide_details=True, + no_resize=True, + classes="overflow-y-auto fill-height", + placeholder="Explanation will appear here...", + ) + with vuetify.VExpansionPanel( + classes="mt-1 fill-height", + readonly=True, + style=( + "explanation_expanded.length > 1 ? 'max-height: 75%;' : 'max-height: 95%;'", + ), + ): + vuetify.VExpansionPanelTitle( + "Generated Code", + collapse_icon=False, + classes="text-h6", + ) + with vuetify.VExpansionPanelText( + style="overflow: hidden; height: 90%;" + ): + vuetify.VTextarea( + v_model=("generated_code", ""), + readonly=True, + solo=True, + hide_details=True, + no_resize=True, + classes="overflow-y-auto fill-height", + style="font-family: monospace;", + placeholder="Generated VTK code will appear here...", + ) # Right column - VTK viewer and prompt with vuetify.VCol(cols=6, classes="fill-height"): diff --git a/src/vtk_prompt/yaml_prompt_loader.py b/src/vtk_prompt/yaml_prompt_loader.py new file mode 100644 index 0000000..e7b2203 --- /dev/null +++ b/src/vtk_prompt/yaml_prompt_loader.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 + +import yaml +from pathlib import Path +from typing import Dict, List, Any, Optional, Union +import vtk +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass + +PYTHON_VERSION = ">=3.10" +VTK_VERSION = vtk.__version__ + + +@dataclass +class LoaderConfig: + prompts_dir: Path + default_extension: str = ".prompt.yml" + encoding: str = "utf-8" + + +class PromptSource(ABC): + @abstractmethod + def get_content(self) -> str: + """Get the raw content as string.""" + pass + + +class FilePromptSource(PromptSource): + # CLI provides a filename + def __init__(self, filename: str, config: LoaderConfig): + self.filename = filename + self.config = config + + def get_content(self) -> str: + filename = self.filename + if not filename.endswith(self.config.default_extension): + filename = f"{filename}{self.config.default_extension}" + + prompt_path = self.config.prompts_dir / filename + if not prompt_path.exists(): + raise FileNotFoundError(f"Prompt file not found: {prompt_path}") + + try: + with open(prompt_path, "r", 
encoding=self.config.encoding) as f: + return f.read() + except IOError as e: + raise IOError(f"Failed to read prompt file {prompt_path}: {e}") + + +class ContentPromptSource(PromptSource): + # UI provides a content blob + def __init__(self, content: Union[str, bytes]): + if isinstance(content, bytes): + self.content = content.decode('utf-8') + else: + self.content = content + + def get_content(self) -> str: + return self.content + + +class GitHubModelYAMLLoader: + """Loader for GitHub Models YAML prompt files.""" + + def __init__(self, prompts_dir: Optional[Path] = None): + """Initialize with prompts directory path.""" + if prompts_dir is None: + # Default to prompts directory in repository root + prompts_dir = Path(__file__).parent.parent.parent / "prompts" + + self.config = LoaderConfig(prompts_dir=Path(prompts_dir)) + + @classmethod + def from_file(cls, filename: str, prompts_dir: Optional[Path] = None) -> 'GitHubModelYAMLLoader': + loader = cls(prompts_dir) + loader._current_source = FilePromptSource(filename, loader.config) + return loader + + @classmethod + def from_content(cls, content: Union[str, bytes]) -> 'GitHubModelYAMLLoader': + loader = cls() + loader._current_source = ContentPromptSource(content) + return loader + + def _get_prompt_source(self, prompt: Union[str, bytes]) -> PromptSource: + if ( + isinstance(prompt, (str, bytes)) and + not isinstance(prompt, str) or + isinstance(prompt, bytes) + ): + return ContentPromptSource(prompt) + else: + return FilePromptSource(prompt, self.config) + + def _parse_yaml_content(self, content: str) -> Dict[str, Any]: + try: + return yaml.load(content, Loader=get_loader()) + except yaml.YAMLError as e: + raise ValueError(f"Failed to parse YAML content: {e}") + + def load_prompt(self, prompt: Union[str, bytes]) -> Dict[str, Any]: + """Load a YAML prompt file. + + Args: + prompt: Name of the prompt file (with or without .prompt.yml extension) or binary blob + of the prompt file content + + Returns: + Parsed YAML content as dictionary + """ + source = self._get_prompt_source(prompt) + content = source.get_content() + return self._parse_yaml_content(content) + + def save_prompt(self, prompt: str) -> None: + # TODO + return + + def substitute_variables(self, content: str, variables: Dict[str, str]) -> str: + """Substitute template variables in content using GitHub Models format. 
+ + Args: + content: Template content with {{variable}} placeholders + variables: Dictionary of variable name -> value mappings + + Returns: + Content with variables substituted + """ + # Add default variables + default_vars = {"VTK_VERSION": VTK_VERSION, "PYTHON_VERSION": PYTHON_VERSION} + variables = {**default_vars, **variables} + + # Handle conditional blocks like {{#if variable}}...{{/if}} + def handle_conditionals(text: str) -> str: + # Simple conditional handling for {{#if variable}}...{{/if}} + conditional_pattern = r"\{\{#if\s+(\w+)\}\}(.*?)\{\{/if\}\}" + + def replace_conditional(match): + var_name = match.group(1) + block_content = match.group(2) + # Include block if variable exists and is truthy + if var_name in variables and variables[var_name]: + return block_content + return "" + + return re.sub( + conditional_pattern, replace_conditional, text, flags=re.DOTALL + ) + + # First handle conditionals + content = handle_conditionals(content) + + # Then substitute regular variables + for var_name, var_value in variables.items(): + placeholder = f"{{{{{var_name}}}}}" + content = content.replace(placeholder, str(var_value)) + + return content + + def build_messages( + self, + prompt: str | bytes, + variables: Dict[str, str] = None, + system_only: bool = False, + ) -> List[Dict[str, str]]: + """Build messages list from YAML prompt with variable substitution. + + Args: + prompt: Name of the prompt file or binary blob of the prompt file content + variables: Variables to substitute in the template + system_only: If True, return only the first system message content as string + + Returns: + List of message dictionaries compatible with OpenAI API, or string if system_only=True + """ + if variables is None: + variables = {} + + prompt_data = self.load_prompt(prompt) + messages = prompt_data.get("messages", []) + + # Substitute variables in each message + processed_messages = [] + for message in messages: + processed_message = { + "role": message["role"], + "content": self.substitute_variables(message["content"], variables), + } + processed_messages.append(processed_message) + + # If system_only is True, return only the first system message content + if system_only: + for message in processed_messages: + if message["role"] == "system": + return message["content"] + return "" # No system message found + + return processed_messages + + def get_model_parameters(self, prompt: str | bytes) -> Dict[str, Any]: + """Get model parameters from YAML prompt. + + Args: + prompt: Name of the prompt file or binary blob of the prompt file content + + Returns: + Dictionary of model parameters + """ + prompt_data = self.load_prompt(prompt) + return prompt_data.get("modelParameters", {}) + + def get_model_name(self, prompt: str | bytes) -> str: + """Get model name from YAML prompt. + + Args: + prompt: Name of the prompt file or binary blob of the prompt file content + + Returns: + Model name string + """ + prompt_data = self.load_prompt(prompt) + return prompt_data.get("model", "openai/gpt-4o-mini").split("/")[1] + + def get_model_provider(self, prompt: str | bytes) -> str: + """Get model provider from YAML prompt. + + Args: + prompt: Name of the prompt file or binary blob of the prompt file content + + Returns: + Model name string + """ + prompt_data = self.load_prompt(prompt) + return prompt_data.get("model", "openai/gpt-4o-mini").split("/")[0] + + def list_available_prompts(self) -> List[str]: + """List all available prompt files. 
+ + Returns: + List of prompt file names (without extension) + """ + if not self.config.prompts_dir.exists(): + return [] + + prompt_files = list(self.config.prompts_dir.glob("*.prompt.yml")) + return [f.stem.replace(".prompt", "") for f in prompt_files] + + +# Convenience functions for backward compatibility +def get_yaml_prompt_messages( + prompt: str | bytes, variables: Dict[str, str] = None +) -> List[Dict[str, str]]: + """Get messages from a YAML prompt file. + + Args: + prompt: Name of the prompt file or binary blob of the prompt file content + variables: Variables to substitute in the template + + Returns: + List of message dictionaries + """ + loader = GitHubModelYAMLLoader() + return loader.build_messages(prompt, variables) + + +def get_yaml_prompt_parameters(prompt: str | bytes) -> Dict[str, Any]: + """Get model parameters from a YAML prompt file. + + Args: + prompt: Name of the prompt file or binary blob of the prompt file content + + Returns: + Dictionary of model parameters + """ + loader = GitHubModelYAMLLoader() + return loader.get_model_parameters(prompt) + + +def include_constructor(loader, node): + data = loader.construct_scalar(node) + + match = re.search(r'\s', data) + if match: + file = data[:match.start()].strip() + other_text = data[match.start():].strip() + else: + file = data.strip() + other_text = "" + + with Path(file).resolve().open() as f: + file_content = yaml.safe_load(f) + + return file_content + "\n" + other_text + +def get_loader(): + loader = yaml.SafeLoader + loader.add_constructor("!include", include_constructor) + return loader