11import base64
22import dill # type: ignore
33import json
4- import requests # type: ignore
4+ import requests # type: ignore
55import os
66import time
77from typing import Dict , Optional , Sequence
@@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs):
2222
2323 self .document_helper = DocumentHelper (self )
2424 self .tools_helper = ToolsHelper (self )
25+ self .axtract_helper = AxtractHelper (self )
2526
2627
2728class DocumentHelper :
@@ -115,6 +116,7 @@ def load_parsed_pdf(self, path: str) -> ParseResponse:
115116 inline_equations = inline_equations ,
116117 )
117118
119+
118120class AxtractHelper :
119121 _ax_client : Axiomatic
120122
@@ -160,10 +162,10 @@ def analyze_equations(
160162 Examples:
161163 # From local file
162164 client.analyze_equations(file_path="path/to/paper.pdf")
163-
165+
164166 # From URL
165167 client.analyze_equations(url_path="https://arxiv.org/pdf/2203.00001.pdf")
166-
168+
167169 # From parsed paper
168170 client.analyze_equations(parsed_paper=parsed_data)
169171 """
@@ -172,34 +174,35 @@ def analyze_equations(
172174 parsed_document = self ._ax_client .document .parse (file = pdf_file )
173175 print ("We are almost there" )
174176 response = self ._ax_client .document .equation .process (
175- markdown = parsed_document .markdown ,
177+ markdown = parsed_document .markdown ,
176178 interline_equations = parsed_document .interline_equations ,
177- inline_equations = parsed_document .inline_equations
178- )
179-
179+ inline_equations = parsed_document .inline_equations ,
180+ )
181+
180182 elif url_path :
181183 if "arxiv" in url_path and "abs" in url_path :
182184 url_path = url_path .replace ("abs" , "pdf" )
183185 url_file = requests .get (url_path )
184186 from io import BytesIO
187+
185188 pdf_stream = BytesIO (url_file .content )
186189 parsed_document = self ._ax_client .document .parse (file = url_file .content )
187190 print ("We are almost there" )
188191 response = self ._ax_client .document .equation .process (
189- markdown = parsed_document .markdown ,
192+ markdown = parsed_document .markdown ,
190193 interline_equations = parsed_document .interline_equations ,
191- inline_equations = parsed_document .inline_equations
192- )
193-
194+ inline_equations = parsed_document .inline_equations ,
195+ )
196+
194197 elif parsed_paper :
195198 response = EquationProcessingResponse .model_validate (
196199 self ._ax_client .document .equation .process (** parsed_paper .model_dump ()).model_dump ()
197200 )
198-
201+
199202 else :
200203 print ("Please provide either a file path or a URL to analyze." )
201204 return None
202-
205+
203206 return response
204207
205208 def validate_equations (
@@ -220,19 +223,15 @@ def validate_equations(
220223 EquationValidationResult containing the validation results
221224 """
222225 # equations_dict = loaded_equations.model_dump() if hasattr(loaded_equations, 'model_dump') else loaded_equations.dict()
223-
226+
224227 api_response = self ._ax_client .document .equation .validate (
225- variables = requirements ,
226- paper_equations = loaded_equations ,
227- include_internal_model = include_internal_model
228- )
229-
228+ variables = requirements , paper_equations = loaded_equations , include_internal_model = include_internal_model
229+ )
230+
230231 return api_response
231-
232232
# Method wrapper that forwards `api_response` and `user_choice`, unchanged and
# positionally, to the free function `display_full_results`.
# Inside a method body a bare name is resolved in the enclosing module scope,
# not in the class namespace, so the call below reaches that free function
# rather than recursing into this method.
# NOTE(review): assumes `display_full_results` is imported or defined at module
# level outside the visible part of the file -- confirm.
# The helper's result is not returned, so this method itself returns None.
233233 def display_full_results (self , api_response : EquationValidationResult , user_choice ):
234234 display_full_results (api_response , user_choice )
235-
236235
237236 def set_numerical_requirements (self , extracted_equations : EquationProcessingResponse ):
238237 """Launch an interactive interface for setting numerical requirements for equations.
@@ -289,13 +288,13 @@ def tool_exec(self, tool: str, code: str, poll_interval: int = 3, debug: bool =
289288 print (f"status: { result .status } " )
290289 if result .status == "SUCCEEDED" :
291290 output = json .loads (result .output or "{}" )
292- if not output [' objects' ]:
291+ if not output [" objects" ]:
293292 return result .output
294293 else :
295294 return {
296295 "job_id" : job_id ,
297- "messages" : output [' messages' ],
298- "objects" : self ._load_objects_from_base64 (output [' objects' ])
296+ "messages" : output [" messages" ],
297+ "objects" : self ._load_objects_from_base64 (output [" objects" ]),
299298 }
300299 else :
301300 return result .error_trace
@@ -306,10 +305,10 @@ def load(self, job_id: str, obj_key: str):
306305 result = self ._ax_client .tools .status (job_id = job_id )
307306 if result .status == "SUCCEEDED" :
308307 output = json .loads (result .output or "{}" )
309- if not output [' objects' ]:
308+ if not output [" objects" ]:
310309 return result .output
311310 else :
312- return self ._load_objects_from_base64 (output [' objects' ])[obj_key ]
311+ return self ._load_objects_from_base64 (output [" objects" ])[obj_key ]
313312 else :
314313 return result .error_trace
315314
0 commit comments