Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions python/agents/machine-learning-engineering/deployment/deploy.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Deployment script for Machine Learning Engineering Agent"""


import os

import vertexai
from absl import app, flags
from machine_learning_engineering.agent import root_agent
from dotenv import load_dotenv
from machine_learning_engineering.agent import root_agent
from vertexai import agent_engines
from vertexai.preview.reasoning_engines import AdkApp

Expand Down Expand Up @@ -73,18 +72,10 @@ def main(argv: list[str]) -> None:
load_dotenv()

project_id = (
FLAGS.project_id
if FLAGS.project_id
else os.getenv("GOOGLE_CLOUD_PROJECT")
)
location = (
FLAGS.location if FLAGS.location else os.getenv("GOOGLE_CLOUD_LOCATION")
)
bucket = (
FLAGS.bucket
if FLAGS.bucket
else os.getenv("GOOGLE_CLOUD_STORAGE_BUCKET")
FLAGS.project_id if FLAGS.project_id else os.getenv("GOOGLE_CLOUD_PROJECT")
)
location = FLAGS.location if FLAGS.location else os.getenv("GOOGLE_CLOUD_LOCATION")
bucket = FLAGS.bucket if FLAGS.bucket else os.getenv("GOOGLE_CLOUD_STORAGE_BUCKET")

print(f"PROJECT: {project_id}")
print(f"LOCATION: {location}")
Expand All @@ -97,9 +88,7 @@ def main(argv: list[str]) -> None:
print("Missing required environment variable: GOOGLE_CLOUD_LOCATION")
return
elif not bucket:
print(
"Missing required environment variable: GOOGLE_CLOUD_STORAGE_BUCKET"
)
print("Missing required environment variable: GOOGLE_CLOUD_STORAGE_BUCKET")
return

vertexai.init(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,10 @@ def main(argv: list[str]) -> None: # pylint: disable=unused-argument
load_dotenv()

project_id = (
FLAGS.project_id
if FLAGS.project_id
else os.getenv("GOOGLE_CLOUD_PROJECT")
)
location = (
FLAGS.location if FLAGS.location else os.getenv("GOOGLE_CLOUD_LOCATION")
)
bucket = (
FLAGS.bucket
if FLAGS.bucket
else os.getenv("GOOGLE_CLOUD_STORAGE_BUCKET")
FLAGS.project_id if FLAGS.project_id else os.getenv("GOOGLE_CLOUD_PROJECT")
)
location = FLAGS.location if FLAGS.location else os.getenv("GOOGLE_CLOUD_LOCATION")
bucket = FLAGS.bucket if FLAGS.bucket else os.getenv("GOOGLE_CLOUD_STORAGE_BUCKET")

project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
location = os.getenv("GOOGLE_CLOUD_LOCATION")
Expand All @@ -50,9 +42,7 @@ def main(argv: list[str]) -> None: # pylint: disable=unused-argument
print("Missing required environment variable: GOOGLE_CLOUD_LOCATION")
return
elif not bucket:
print(
"Missing required environment variable: GOOGLE_CLOUD_STORAGE_BUCKET"
)
print("Missing required environment variable: GOOGLE_CLOUD_STORAGE_BUCKET")
return

vertexai.init(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
import dotenv
import pytest
from google.adk.evaluation.agent_evaluator import AgentEvaluator

from machine_learning_engineering.shared_libraries import config


pytest_plugins = ("pytest_asyncio",)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import pytest
from google.adk.evaluation.agent_evaluator import AgentEvaluator

from machine_learning_engineering.shared_libraries import config


pytest_plugins = ("pytest_asyncio",)


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Machine Learning Engineer: automate the implementation of ML models."""

from . import agent
from . import agent # noqa: F401
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
"""Demonstration of Machine Learning Engineering Agent using Agent Development Kit"""

import os
import json
import os
from typing import Optional
from google.genai import types
from google.adk.agents import callback_context as callback_context_module

from google.adk import agents
from machine_learning_engineering.sub_agents.initialization import agent as initialization_agent_module
from machine_learning_engineering.sub_agents.refinement import agent as refinement_agent_module
from machine_learning_engineering.sub_agents.ensemble import agent as ensemble_agent_module
from machine_learning_engineering.sub_agents.submission import agent as submission_agent_module

from google.adk.agents import callback_context as callback_context_module
from google.genai import types
from machine_learning_engineering import prompt
from machine_learning_engineering.sub_agents.ensemble import (
agent as ensemble_agent_module,
)
from machine_learning_engineering.sub_agents.initialization import (
agent as initialization_agent_module,
)
from machine_learning_engineering.sub_agents.refinement import (
agent as refinement_agent_module,
)
from machine_learning_engineering.sub_agents.submission import (
agent as submission_agent_module,
)


def save_state(
callback_context: callback_context_module.CallbackContext
callback_context: callback_context_module.CallbackContext,
) -> Optional[types.Content]:
"""Prints the current state of the callback context."""
workspace_dir = callback_context.state.get("workspace_dir", "")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Defines the prompts in the Machine Learning Engineering Agent."""


SYSTEM_INSTRUCTION ="""You are a Machine Learning Engineering Multi Agent System.
SYSTEM_INSTRUCTION = """You are a Machine Learning Engineering Multi Agent System.
"""

FRONTDOOR_INSTRUCTION="""
FRONTDOOR_INSTRUCTION = """
You are a machine learning engineer given a machine learning task for which to engineer a solution.

- If the user asks questions that can be answered directly, answer it directly without calling any additional agents.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
"""Utility functions for leakage check agent."""

from typing import Optional
import json
import functools
import json
from typing import Optional

from google.adk import agents
from google.adk.agents import callback_context as callback_context_module
from google.adk.models import llm_response as llm_response_module
from google.adk.models import llm_request as llm_request_module
from google.adk.models import llm_response as llm_response_module
from google.genai import types

from machine_learning_engineering.shared_libraries import data_leakage_prompt
from machine_learning_engineering.shared_libraries import code_util
from machine_learning_engineering.shared_libraries import common_util
from machine_learning_engineering.shared_libraries import config
from machine_learning_engineering.shared_libraries import (
code_util,
common_util,
config,
data_leakage_prompt,
)


def get_check_leakage_agent_instruction(
Expand Down Expand Up @@ -50,11 +51,11 @@ def get_refine_leakage_agent_instruction(

def parse_leakage_status(text: str) -> tuple[str, str]:
    """Parses the leakage status and code block out of a model reply.

    The reply is expected to contain a JSON array somewhere in its text; only
    the first element of that array is read.

    Args:
        text: Raw reply text. Everything before the first "[" and after the
            last "]" is discarded before decoding.

    Returns:
        A `(leakage_status, code_block)` tuple taken from the first array
        element, with Markdown code fences removed from `code_block`.

    Raises:
        json.JSONDecodeError: If the bracketed span is missing or not valid JSON.
        KeyError: If the first element lacks "leakage_status" or "code_block".
        IndexError: If the decoded array is empty.
    """
    # Slice from the first "[" to the last "]" (inclusive) so that any prose
    # surrounding the JSON payload is ignored.
    start_idx = text.find("[")
    end_idx = text.rfind("]") + 1
    result = json.loads(text[start_idx:end_idx])[0]
    leakage_status = result["leakage_status"]
    # Strip the language-tagged fence first; otherwise the bare-fence
    # replacement would leave a stray "python" behind.
    code_block = result["code_block"].replace("```python", "").replace("```", "")
    return leakage_status, code_block


Expand All @@ -80,7 +81,7 @@ def update_extract_status(
extract_status = True
else:
extract_status = code_block in code
except:
except Exception:
code_block = ""
extract_status = False
extract_status_key = code_util.get_name_with_prefix_and_suffix(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Code related utility functions."""

from typing import Any
import subprocess
import os
import subprocess
import time
from typing import Any

from google.adk.agents import callback_context as callback_context_module

Expand Down Expand Up @@ -175,7 +175,11 @@ def get_run_code_condition(
if "exit()" not in raw_code:
return True
elif agent_name.startswith("submission"):
if "debug_agent" not in agent_name and "exit()" not in raw_code and "submission.csv" in raw_code:
if (
"debug_agent" not in agent_name
and "exit()" not in raw_code
and "submission.csv" in raw_code
):
return True
if "debug_agent" in agent_name and "exit()" not in raw_code:
return True
Expand Down Expand Up @@ -249,7 +253,7 @@ def evaluate_code(
try:
score = extract_performance_from_text(result_dict.get("stdout", ""))
score = float(score)
except:
except Exception:
score = 1e9 if lower else 0
else:
score = 1e9 if lower else 0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
"""Common utility functions."""

import random
import torch
import os
import random
import shutil
import numpy as np

import numpy as np
import torch
from google.adk.models import llm_response


def get_text_from_response(
    response: "llm_response.LlmResponse",
) -> str:
    """Concatenates the text of every part in an LLM response.

    Args:
        response: The response whose `content.parts` are scanned.

    Returns:
        The `text` of each part joined in order, or "" when the response has
        no content or no parts. Parts without a `text` attribute, or whose
        `text` is None or empty, contribute nothing.
    """
    content = response.content
    if not content or not content.parts:
        return ""
    # A part may carry `text=None`; a plain hasattr() check lets that through
    # and `str += None` raises TypeError, so drop anything that is not text.
    texts = (getattr(part, "text", None) for part in content.parts)
    return "".join(text for text in texts if text)


def set_random_seed(seed: int) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,61 @@
@dataclasses.dataclass
class DefaultConfig:
    """Default configuration."""

    # The directory path where the machine learning tasks and their data are
    # stored.
    data_dir: str = "./machine_learning_engineering/tasks/"
    # The name of the specific task to be loaded and processed.
    task_name: str = "california-housing-prices"
    # The type of machine learning problem.
    task_type: str = "Tabular Regression"
    # True if a lower value of the metric is better.
    lower: bool = True
    # Directory used for saving intermediate outputs, results, logs.
    workspace_dir: str = "./machine_learning_engineering/workspace/"
    # Name of the LLM model to be used by the agent. Read from the
    # ROOT_AGENT_MODEL environment variable once, when the class body is
    # evaluated; later changes to the environment do not affect the default.
    agent_model: str = os.environ.get("ROOT_AGENT_MODEL", "gemini-2.0-flash-001")
    # The detailed description of the task.
    task_description: str = ""
    # The concise summary of the task.
    task_summary: str = ""
    # Timestamp indicating the start time of the task. Typically represented
    # in seconds since the epoch.
    start_time: float = 0.0
    # The random seed value used to ensure reproducibility of experiments.
    seed: int = 42
    # The maximum time in seconds allowed to complete the task.
    exec_timeout: int = 600
    # The number of different solutions to generate or attempt for the given
    # task.
    num_solutions: int = 2
    # The number of different model architectures or hyperparameter sets to
    # consider as candidates.
    num_model_candidates: int = 2
    # The maximum number of times to retry a failed operation.
    max_retry: int = 10
    # The maximum number of iterations or rounds allowed for the debugging
    # step.
    max_debug_round: int = 5
    # The maximum number of times the system can rollback to a previous state,
    # in case of errors or poor performance.
    max_rollback_round: int = 2
    # The number of iterations or rounds to be executed within an inner loop
    # of the system.
    inner_loop_round: int = 1
    # The number of iterations or rounds to be executed within the outer loop,
    # which might encompass multiple inner loops.
    outer_loop_round: int = 1
    # The number of rounds or iterations dedicated to ensembling, combining
    # multiple models or solutions.
    ensemble_loop_round: int = 1
    # The number of highest-scoring plans or strategies to select or retain.
    num_top_plans: int = 2
    # Enable (`True`) or disable (`False`) a check for data leakage in the
    # machine learning pipeline.
    use_data_leakage_checker: bool = False
    # Enable (`True`) or disable (`False`) a check for how data is being used,
    # potentially for compliance or best practices.
    use_data_usage_checker: bool = False


CONFIG = DefaultConfig()
Loading