diff --git a/crawl4ai/async_configs.py b/crawl4ai/async_configs.py
index bfa0d398b..b93745846 100644
--- a/crawl4ai/async_configs.py
+++ b/crawl4ai/async_configs.py
@@ -1,3 +1,4 @@
+import json
 import os
 from typing import Union
 import warnings
@@ -1792,7 +1793,8 @@ def __init__(
         frequency_penalty: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         stop: Optional[List[str]] = None,
-        n: Optional[int] = None,
+        n: Optional[int] = None,
+        **kwargs,
     ):
         """Configuaration class for LLM provider and API token."""
         self.provider = provider
@@ -1803,13 +1805,24 @@ def __init__(
         else:
             # Check if given provider starts with any of key in PROVIDER_MODELS_PREFIXES
             # If not, check if it is in PROVIDER_MODELS
+
             prefixes = PROVIDER_MODELS_PREFIXES.keys()
             if any(provider.startswith(prefix) for prefix in prefixes):
-                selected_prefix = next(
-                    (prefix for prefix in prefixes if provider.startswith(prefix)),
-                    None,
-                )
-                self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix)
+
+                if provider.startswith("vertex_ai"):
+                    credential_path = PROVIDER_MODELS_PREFIXES["vertex_ai"]
+
+                    with open(credential_path, "r") as file:
+                        vertex_credentials = json.load(file)
+                    # Convert to JSON string
+                    self.vertex_credentials = json.dumps(vertex_credentials)
+                    self.api_token = None
+                else:
+                    selected_prefix = next(
+                        (prefix for prefix in prefixes if provider.startswith(prefix)),
+                        None,
+                    )
+                    self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix)
             else:
                 self.provider = DEFAULT_PROVIDER
                 self.api_token = os.getenv(DEFAULT_PROVIDER_API_KEY)
@@ -1834,11 +1847,11 @@ def from_kwargs(kwargs: dict) -> "LLMConfig":
            frequency_penalty=kwargs.get("frequency_penalty"),
            presence_penalty=kwargs.get("presence_penalty"),
            stop=kwargs.get("stop"),
-           n=kwargs.get("n")
+           n=kwargs.get("n"),
        )

    def to_dict(self):
-        return {
+        result = {
            "provider": self.provider,
            "api_token": self.api_token,
            "base_url": self.base_url,
@@ -1848,8 +1861,11 @@ def to_dict(self):
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "stop": self.stop,
-           "n": self.n
+           "n": self.n,
        }
+        if self.provider.startswith("vertex_ai"):
+            result["extra_args"] = {"vertex_credentials": self.vertex_credentials}
+        return result

    def clone(self, **kwargs):
        """Create a copy of this configuration with updated values.
@@ -1864,6 +1880,7 @@ def clone(self, **kwargs):
        config_dict.update(kwargs)
        return LLMConfig.from_kwargs(config_dict)

+
class SeedingConfig:
    """
    Configuration class for URL discovery and pre-validation via AsyncUrlSeeder.
diff --git a/crawl4ai/config.py b/crawl4ai/config.py
index 08f56b832..2e61ffee8 100644
--- a/crawl4ai/config.py
+++ b/crawl4ai/config.py
@@ -22,6 +22,11 @@
     "anthropic/claude-3-opus-20240229": os.getenv("ANTHROPIC_API_KEY"),
     "anthropic/claude-3-sonnet-20240229": os.getenv("ANTHROPIC_API_KEY"),
     "anthropic/claude-3-5-sonnet-20240620": os.getenv("ANTHROPIC_API_KEY"),
+    "vertex_ai/gemini-2.0-flash-lite": os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
+    'vertex_ai/gemini-2.0-flash': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
+    'vertex_ai/gemini-2.5-flash': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
+    'vertex_ai/gemini-2.5-pro': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
+    'vertex_ai/gemini-3-pro-preview': os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
     "gemini/gemini-pro": os.getenv("GEMINI_API_KEY"),
     'gemini/gemini-1.5-pro': os.getenv("GEMINI_API_KEY"),
     'gemini/gemini-2.0-flash': os.getenv("GEMINI_API_KEY"),
@@ -35,6 +40,7 @@
     "openai": os.getenv("OPENAI_API_KEY"),
     "anthropic": os.getenv("ANTHROPIC_API_KEY"),
     "gemini": os.getenv("GEMINI_API_KEY"),
+    "vertex_ai": os.getenv("GOOGLE_APPLICATION_CREDENTIALS"),
     "deepseek": os.getenv("DEEPSEEK_API_KEY"),
 }
 
diff --git a/crawl4ai/extraction_strategy.py b/crawl4ai/extraction_strategy.py
index 4a64e5d46..6e6c93907 100644
--- a/crawl4ai/extraction_strategy.py
+++ b/crawl4ai/extraction_strategy.py
@@ -574,7 +574,10 @@ def __init__(
        self.overlap_rate = overlap_rate
        self.word_token_rate = word_token_rate
        self.apply_chunking = apply_chunking
-        self.extra_args = kwargs.get("extra_args", {})
+        # Merge both extra kwargs
+        self.extra_args = kwargs.get("extra_args", {}) | self.llm_config.to_dict().get(
+            "extra_args", {}
+        )
        if not self.apply_chunking:
            self.chunk_token_threshold = 1e9
        self.verbose = verbose
diff --git a/deploy/docker/c4ai-code-context.md b/deploy/docker/c4ai-code-context.md
index c18fbc784..6acc318c4 100644
--- a/deploy/docker/c4ai-code-context.md
+++ b/deploy/docker/c4ai-code-context.md
@@ -1269,7 +1269,8 @@ class LLMConfig:
         frequency_penalty: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         stop: Optional[List[str]] = None,
-        n: Optional[int] = None,
+        n: Optional[int] = None,
+        **kwargs,
     ):
         """Configuaration class for LLM provider and API token."""
         self.provider = provider
@@ -1280,13 +1281,25 @@ class LLMConfig:
         else:
             # Check if given provider starts with any of key in PROVIDER_MODELS_PREFIXES
             # If not, check if it is in PROVIDER_MODELS
+
             prefixes = PROVIDER_MODELS_PREFIXES.keys()
             if any(provider.startswith(prefix) for prefix in prefixes):
-                selected_prefix = next(
-                    (prefix for prefix in prefixes if provider.startswith(prefix)),
-                    None,
-                )
-                self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix)
+
+                if provider.startswith("vertex_ai"):
+                    credential_path = PROVIDER_MODELS_PREFIXES["vertex_ai"]
+
+                    with open(credential_path, "r") as file:
+                        vertex_credentials = json.load(file)
+                    # Convert to JSON string
+                    self.vertex_credentials = json.dumps(vertex_credentials)
+
+                    self.api_token = None
+                else:
+                    selected_prefix = next(
+                        (prefix for prefix in prefixes if provider.startswith(prefix)),
+                        None,
+                    )
+                    self.api_token = PROVIDER_MODELS_PREFIXES.get(selected_prefix)
             else:
                 self.provider = DEFAULT_PROVIDER
                 self.api_token = os.getenv(DEFAULT_PROVIDER_API_KEY)
@@ -1311,11 +1324,11 @@ class LLMConfig:
            frequency_penalty=kwargs.get("frequency_penalty"),
            presence_penalty=kwargs.get("presence_penalty"),
            stop=kwargs.get("stop"),
-           n=kwargs.get("n")
+           n=kwargs.get("n"),
        )

    def to_dict(self):
-        return {
+        result = {
            "provider": self.provider,
            "api_token": self.api_token,
            "base_url": self.base_url,
@@ -1325,8 +1338,11 @@ class LLMConfig:
            "frequency_penalty": self.frequency_penalty,
            "presence_penalty": self.presence_penalty,
            "stop": self.stop,
-           "n": self.n
+           "n": self.n,
        }
+        if self.provider.startswith("vertex_ai"):
+            result["extra_args"] = {"vertex_credentials": self.vertex_credentials}
+        return result

    def clone(self, **kwargs):
        """Create a copy of this configuration with updated values.
@@ -4094,7 +4110,9 @@ class LLMExtractionStrategy(ExtractionStrategy):
        self.overlap_rate = overlap_rate
        self.word_token_rate = word_token_rate
        self.apply_chunking = apply_chunking
-        self.extra_args = kwargs.get("extra_args", {})
+        self.extra_args = kwargs.get("extra_args", {}) | self.llm_config.to_dict().get(
+            "extra_args", {}
+        )
        if not self.apply_chunking:
            self.chunk_token_threshold = 1e9
        self.verbose = verbose