Skip to content

Commit b507b89

Browse files
committed
add axtract helper
1 parent a0de0a5 commit b507b89

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed

src/axiomatic/client.py

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import base64
22
import dill # type: ignore
33
import json
4-
import requests # type: ignore
4+
import requests # type: ignore
55
import os
66
import time
77
from typing import Dict, Optional, Sequence
@@ -22,6 +22,7 @@ def __init__(self, *args, **kwargs):
2222

2323
self.document_helper = DocumentHelper(self)
2424
self.tools_helper = ToolsHelper(self)
25+
self.axtract_helper = AxtractHelper(self)
2526

2627

2728
class DocumentHelper:
@@ -115,6 +116,7 @@ def load_parsed_pdf(self, path: str) -> ParseResponse:
115116
inline_equations=inline_equations,
116117
)
117118

119+
118120
class AxtractHelper:
119121
_ax_client: Axiomatic
120122

@@ -160,10 +162,10 @@ def analyze_equations(
160162
Examples:
161163
# From local file
162164
client.analyze_equations(file_path="path/to/paper.pdf")
163-
165+
164166
# From URL
165167
client.analyze_equations(url_path="https://arxiv.org/pdf/2203.00001.pdf")
166-
168+
167169
# From parsed paper
168170
client.analyze_equations(parsed_paper=parsed_data)
169171
"""
@@ -172,34 +174,35 @@ def analyze_equations(
172174
parsed_document = self._ax_client.document.parse(file=pdf_file)
173175
print("We are almost there")
174176
response = self._ax_client.document.equation.process(
175-
markdown=parsed_document.markdown,
177+
markdown=parsed_document.markdown,
176178
interline_equations=parsed_document.interline_equations,
177-
inline_equations=parsed_document.inline_equations
178-
)
179-
179+
inline_equations=parsed_document.inline_equations,
180+
)
181+
180182
elif url_path:
181183
if "arxiv" in url_path and "abs" in url_path:
182184
url_path = url_path.replace("abs", "pdf")
183185
url_file = requests.get(url_path)
184186
from io import BytesIO
187+
185188
pdf_stream = BytesIO(url_file.content)
186189
parsed_document = self._ax_client.document.parse(file=url_file.content)
187190
print("We are almost there")
188191
response = self._ax_client.document.equation.process(
189-
markdown=parsed_document.markdown,
192+
markdown=parsed_document.markdown,
190193
interline_equations=parsed_document.interline_equations,
191-
inline_equations=parsed_document.inline_equations
192-
)
193-
194+
inline_equations=parsed_document.inline_equations,
195+
)
196+
194197
elif parsed_paper:
195198
response = EquationProcessingResponse.model_validate(
196199
self._ax_client.document.equation.process(**parsed_paper.model_dump()).model_dump()
197200
)
198-
201+
199202
else:
200203
print("Please provide either a file path or a URL to analyze.")
201204
return None
202-
205+
203206
return response
204207

205208
def validate_equations(
@@ -220,19 +223,15 @@ def validate_equations(
220223
EquationValidationResult containing the validation results
221224
"""
222225
# equations_dict = loaded_equations.model_dump() if hasattr(loaded_equations, 'model_dump') else loaded_equations.dict()
223-
226+
224227
api_response = self._ax_client.document.equation.validate(
225-
variables=requirements,
226-
paper_equations=loaded_equations,
227-
include_internal_model=include_internal_model
228-
)
229-
228+
variables=requirements, paper_equations=loaded_equations, include_internal_model=include_internal_model
229+
)
230+
230231
return api_response
231-
232232

233233
def display_full_results(self, api_response: EquationValidationResult, user_choice):
234234
display_full_results(api_response, user_choice)
235-
236235

237236
def set_numerical_requirements(self, extracted_equations: EquationProcessingResponse):
238237
"""Launch an interactive interface for setting numerical requirements for equations.
@@ -289,13 +288,13 @@ def tool_exec(self, tool: str, code: str, poll_interval: int = 3, debug: bool =
289288
print(f"status: {result.status}")
290289
if result.status == "SUCCEEDED":
291290
output = json.loads(result.output or "{}")
292-
if not output['objects']:
291+
if not output["objects"]:
293292
return result.output
294293
else:
295294
return {
296295
"job_id": job_id,
297-
"messages": output['messages'],
298-
"objects": self._load_objects_from_base64(output['objects'])
296+
"messages": output["messages"],
297+
"objects": self._load_objects_from_base64(output["objects"]),
299298
}
300299
else:
301300
return result.error_trace
@@ -306,10 +305,10 @@ def load(self, job_id: str, obj_key: str):
306305
result = self._ax_client.tools.status(job_id=job_id)
307306
if result.status == "SUCCEEDED":
308307
output = json.loads(result.output or "{}")
309-
if not output['objects']:
308+
if not output["objects"]:
310309
return result.output
311310
else:
312-
return self._load_objects_from_base64(output['objects'])[obj_key]
311+
return self._load_objects_from_base64(output["objects"])[obj_key]
313312
else:
314313
return result.error_trace
315314

0 commit comments

Comments
 (0)