Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 59 additions & 20 deletions crawl4ai/adaptive_crawler.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class AdaptiveConfig:
# Embedding strategy parameters
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
embedding_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for embeddings
query_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for query generation
n_query_variations: int = 10
coverage_threshold: float = 0.85
alpha_shape_alpha: float = 0.5
Expand Down Expand Up @@ -252,7 +253,7 @@ def validate(self):
assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1"

@property
def _embedding_llm_config_dict(self) -> Optional[Dict]:
def _llm_config_dict(self) -> Optional[Dict]:
"""Convert LLMConfig to dict format for backward compatibility."""
if self.embedding_llm_config is None:
return None
Expand Down Expand Up @@ -614,12 +615,19 @@ def _get_document_terms(self, crawl_result: CrawlResult) -> List[str]:
return self._tokenize(content.lower())


# strategy = EmbeddingStrategy(
# embedding_model=self.config.embedding_model,
# embedding_llm_config=self.config.embedding_llm_config,
# query_llm_config=self.config.query_llm_config
# )
# -> Forwards the embedding model and both LLM configs from AdaptiveConfig
class EmbeddingStrategy(CrawlStrategy):
"""Embedding-based adaptive crawling using semantic space coverage"""

def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dict] = None):
def __init__(self, embedding_model: str = None, embedding_llm_config: Union[LLMConfig, Dict] = None, query_llm_config: Union[LLMConfig, Dict] = None):
self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
self.llm_config = llm_config
self.embedding_llm_config = embedding_llm_config # For embeddings only
self.query_llm_config = query_llm_config # For query generation only

self._embedding_cache = {}
self._link_embedding_cache = {} # Cache for link embeddings
self._validation_passed = False # Track if validation passed
Expand All @@ -632,6 +640,19 @@ def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dic

def _get_embedding_llm_config_dict(self) -> Dict:
"""Get embedding LLM config as dict with fallback to default."""
# First check if we have a direct embedding_llm_config
if self.embedding_llm_config:
if isinstance(self.embedding_llm_config, dict):
return self.embedding_llm_config
else:
# Convert LLMConfig object to dict
return {
'provider': self.embedding_llm_config.provider,
'api_token': self.embedding_llm_config.api_token,
'base_url': self.embedding_llm_config.base_url
}

# Then check if we have it from AdaptiveConfig
if hasattr(self, 'config') and self.config:
config_dict = self.config._embedding_llm_config_dict
if config_dict:
Expand All @@ -642,11 +663,38 @@ def _get_embedding_llm_config_dict(self) -> Dict:
'provider': 'openai/text-embedding-3-small',
'api_token': os.getenv('OPENAI_API_KEY')
}

def _get_query_llm_config_dict(self) -> Dict:
"""Get query generation LLM config as dict with fallback to default."""
# First check if we have a direct query_llm_config
if self.query_llm_config:
if isinstance(self.query_llm_config, dict):
return self.query_llm_config
else:
# Convert LLMConfig object to dict
return {
'provider': self.query_llm_config.provider,
'api_token': self.query_llm_config.api_token,
'base_url': self.query_llm_config.base_url
}

# Then check if we have it from AdaptiveConfig
if hasattr(self, 'config') and self.config:
config_dict = self.config._query_llm_config_dict
if config_dict:
return config_dict

# Fallback to default if no config provided
return {
'provider': 'openai/gpt-4o-mini',
'api_token': os.getenv('OPENAI_API_KEY')
}

async def _get_embeddings(self, texts: List[str]) -> Any:
"""Get embeddings using configured method"""
from .utils import get_text_embeddings
embedding_llm_config = self._get_embedding_llm_config_dict()

return await get_text_embeddings(
texts,
embedding_llm_config,
Expand Down Expand Up @@ -712,27 +760,17 @@ async def map_query_semantic_space(self, query: str, n_synthetic: int = 10) -> A

Return as a JSON array of strings."""

# Use the LLM for query generation
# Convert LLMConfig to dict if needed
llm_config_dict = None
if self.llm_config:
if isinstance(self.llm_config, dict):
llm_config_dict = self.llm_config
else:
# Convert LLMConfig object to dict
llm_config_dict = {
'provider': self.llm_config.provider,
'api_token': self.llm_config.api_token
}

provider = llm_config_dict.get('provider', 'openai/gpt-4o-mini') if llm_config_dict else 'openai/gpt-4o-mini'
api_token = llm_config_dict.get('api_token') if llm_config_dict else None
query_llm_config_dict = self._get_query_llm_config_dict()
provider = query_llm_config_dict.get('provider', 'openai/gpt-4o-mini')
api_token = query_llm_config_dict.get('api_token')
base_url = query_llm_config_dict.get('base_url')

response = perform_completion_with_backoff(
provider=provider,
prompt_with_variables=prompt,
api_token=api_token,
json_response=True
json_response=True,
base_url=base_url,
)

variations = json.loads(response.choices[0].message.content)
Expand Down Expand Up @@ -1298,7 +1336,8 @@ def _create_strategy(self, strategy_name: str) -> CrawlStrategy:
elif strategy_name == "embedding":
strategy = EmbeddingStrategy(
embedding_model=self.config.embedding_model,
llm_config=self.config.embedding_llm_config
embedding_llm_config=self.config.embedding_llm_config,
query_llm_config=self.config.query_llm_config # Pass both configs
)
strategy.config = self.config # Pass config to strategy
return strategy
Expand Down