Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ repos:
hooks:
- id: black
args: ["--line-length", "119", "--skip-string-normalization"]


- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand Down
12 changes: 5 additions & 7 deletions lagent/actions/arxiv_search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Type

from aioify import aioify
from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.actions.parser import BaseParser, JsonParser
Expand Down Expand Up @@ -42,12 +42,10 @@ def get_arxiv_article_information(self, query: str) -> dict:

try:
results = arxiv.Search( # type: ignore
query[:self.max_query_len],
max_results=self.top_k_results).results()
query[: self.max_query_len], max_results=self.top_k_results
).results()
except Exception as exc:
return ActionReturn(
errmsg=f'Arxiv exception: {exc}',
state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR)
docs = [
f'Published: {result.updated.date()}\nTitle: {result.title}\n'
f'Authors: {", ".join(a.name for a in result.authors)}\n'
Expand All @@ -67,7 +65,7 @@ class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
"""

@tool_api(explode_return=True)
@aioify
@asyncify
def get_arxiv_article_information(self, query: str) -> dict:
"""Run Arxiv search and get the article meta information.

Expand Down
185 changes: 102 additions & 83 deletions lagent/actions/google_scholar_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os
from typing import Optional, Type

from aioify import aioify
from asyncer import asyncify

from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
from lagent.schema import ActionReturn, ActionStatusCode
Expand Down Expand Up @@ -31,7 +31,8 @@ def __init__(
if api_key is None:
raise ValueError(
'Please set Serper API key either in the environment '
'as SERPER_API_KEY or pass it as `api_key` parameter.')
'as SERPER_API_KEY or pass it as `api_key` parameter.'
)
self.api_key = api_key

@tool_api(explode_return=True)
Expand Down Expand Up @@ -78,6 +79,7 @@ def search_google_scholar(
- pub_info: publication information of selected papers
"""
from serpapi import GoogleSearch

params = {
'q': query,
'engine': 'google_scholar',
Expand All @@ -94,7 +96,7 @@ def search_google_scholar(
'as_sdt': as_sdt,
'safe': safe,
'filter': filter,
'as_vis': as_vis
'as_vis': as_vis,
}
search = GoogleSearch(params)
try:
Expand All @@ -112,27 +114,24 @@ def search_google_scholar(
cited_by.append(citation['total'])
snippets.append(item['snippet'])
organic_id.append(item['result_id'])
return dict(
title=title,
cited_by=cited_by,
organic_id=organic_id,
snippets=snippets)
return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets)
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(explode_return=True)
def get_author_information(self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None) -> dict:
def get_author_information(
self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None,
) -> dict:
"""Search for an author's information by author's id provided by get_author_id.

Args:
Expand All @@ -155,6 +154,7 @@ def get_author_information(self,
* website: the author's homepage url
"""
from serpapi import GoogleSearch

params = {
'engine': 'google_scholar_author',
'author_id': author_id,
Expand All @@ -167,7 +167,7 @@ def get_author_information(self,
'num': num,
'no_cache': no_cache,
'async': async_req,
'output': output
'output': output,
}
try:
search = GoogleSearch(params)
Expand All @@ -178,20 +178,19 @@ def get_author_information(self,
name=author['name'],
affiliations=author.get('affiliations', ''),
website=author.get('website', ''),
articles=[
dict(title=article['title'], authors=article['authors'])
for article in articles[:3]
])
articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
)
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(explode_return=True)
def get_citation_format(self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json') -> dict:
def get_citation_format(
self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json',
) -> dict:
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.

Args:
Expand All @@ -206,13 +205,14 @@ def get_citation_format(self,
* citation: the citation format of the article
"""
from serpapi import GoogleSearch

params = {
'q': q,
'engine': 'google_scholar_cite',
'api_key': self.api_key,
'no_cache': no_cache,
'async': async_,
'output': output
'output': output,
}
try:
search = GoogleSearch(params)
Expand All @@ -221,18 +221,19 @@ def get_citation_format(self,
citation_info = citation[0]['snippet']
return citation_info
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)

@tool_api(explode_return=True)
def get_author_id(self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json') -> dict:
def get_author_id(
self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json',
) -> dict:
"""The getAuthorId function is used to get the author's id by his or her name.

Args:
Expand All @@ -249,6 +250,7 @@ def get_author_id(self,
* author_id: the author_id of the author
"""
from serpapi import GoogleSearch

params = {
'mauthors': mauthors,
'engine': 'google_scholar_profiles',
Expand All @@ -258,7 +260,7 @@ def get_author_id(self,
'before_author': before_author,
'no_cache': no_cache,
'async': _async,
'output': output
'output': output,
}
try:
search = GoogleSearch(params)
Expand All @@ -267,8 +269,7 @@ def get_author_id(self,
author_info = dict(author_id=profile[0]['author_id'])
return author_info
except Exception as e:
return ActionReturn(
errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)


class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
Expand All @@ -283,7 +284,7 @@ class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
"""

@tool_api(explode_return=True)
@aioify
@asyncify
def search_google_scholar(
self,
query: str,
Expand Down Expand Up @@ -326,23 +327,38 @@ def search_google_scholar(
- organic_id: a list of the organic results' ids of the three selected papers
- pub_info: publication information of selected papers
"""
return super().search_google_scholar(query, cites, as_ylo, as_yhi,
scisbd, cluster, hl, lr, start,
num, as_sdt, safe, filter, as_vis)
return super().search_google_scholar(
query,
cites,
as_ylo,
as_yhi,
scisbd,
cluster,
hl,
lr,
start,
num,
as_sdt,
safe,
filter,
as_vis,
)

@tool_api(explode_return=True)
@aioify
def get_author_information(self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None) -> dict:
@asyncify
def get_author_information(
self,
author_id: str,
hl: Optional[str] = None,
view_op: Optional[str] = None,
sort: Optional[str] = None,
citation_id: Optional[str] = None,
start: Optional[int] = None,
num: Optional[int] = None,
no_cache: Optional[bool] = None,
async_req: Optional[bool] = None,
output: Optional[str] = None,
) -> dict:
"""Search for an author's information by author's id provided by get_author_id.

Args:
Expand All @@ -364,17 +380,19 @@ def get_author_information(self,
* articles: at most 3 articles by the author
* website: the author's homepage url
"""
return super().get_author_information(author_id, hl, view_op, sort,
citation_id, start, num,
no_cache, async_req, output)
return super().get_author_information(
author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
)

@tool_api(explode_return=True)
@aioify
def get_citation_format(self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json') -> dict:
@asyncify
def get_citation_format(
self,
q: str,
no_cache: Optional[bool] = None,
async_: Optional[bool] = None,
output: Optional[str] = 'json',
) -> dict:
"""Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.

Args:
Expand All @@ -391,15 +409,17 @@ def get_citation_format(self,
return super().get_citation_format(q, no_cache, async_, output)

@tool_api(explode_return=True)
@aioify
def get_author_id(self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json') -> dict:
@asyncify
def get_author_id(
self,
mauthors: str,
hl: Optional[str] = 'en',
after_author: Optional[str] = None,
before_author: Optional[str] = None,
no_cache: Optional[bool] = False,
_async: Optional[bool] = False,
output: Optional[str] = 'json',
) -> dict:
"""The getAuthorId function is used to get the author's id by his or her name.

Args:
Expand All @@ -415,5 +435,4 @@ def get_author_id(self,
:class:`dict`: author id
* author_id: the author_id of the author
"""
return super().get_author_id(mauthors, hl, after_author, before_author,
no_cache, _async, output)
return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)
Loading