Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyTigerGraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection
from pyTigerGraph.common.exception import TigerGraphException

__version__ = "1.8.6"
__version__ = "1.8.7"

__license__ = "Apache 2"
49 changes: 39 additions & 10 deletions pyTigerGraph/ai/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, conn: TigerGraphConnection) -> None:
"""
self.conn = conn
self.nlqs_host = None
self.aiserver = "supportai"
if conn.tgCloud:
# split scheme and host
scheme, host = conn.host.split("://")
Expand All @@ -70,6 +71,16 @@ def configureCoPilotHost(self, hostname: str):
"""
self.nlqs_host = hostname

def configureServerHost(self, hostname: str, aiserver: str):
""" Configure the hostname of the AI service.
Not necessary if using TigerGraph AI on TigerGraph Cloud.
Args:
hostname (str):
The hostname (and port number) of the CoPilot serivce.
"""
self.nlqs_host = hostname
self.aiserver = aiserver

def registerCustomQuery(self, query_name: str, description: str = None, docstring: str = None, param_types: dict = None):
""" Register a custom query with the InquiryAI service.
Args:
Expand Down Expand Up @@ -227,7 +238,22 @@ def initializeSupportAI(self):
Returns:
JSON response from the SupportAI service.
"""
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/initialize"
return self.initializeAIServer("supportai")

def initializeGraphAI(self):
""" Initialize the GraphAI service.
Returns:
JSON response from the GraphAI service.
"""
return self.initializeAIServer("graphai")

def initializeAIServer(self, server="supportai"):
""" Initialize the given service.
Returns:
JSON response from the given service.
"""
self.aiserver = server
url = f"{self.nlqs_host}/{self.conn.graphname}/{self.aiserver}/initialize"
return self.conn._req("POST", url, authMode="pwd", resKey=None)

def createDocumentIngest(self, data_source, data_source_config, loader_config, file_format):
Expand All @@ -251,10 +277,10 @@ def createDocumentIngest(self, data_source, data_source_config, loader_config, f
"file_format": file_format
}

url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/create_ingest"
url = f"{self.nlqs_host}/{self.conn.graphname}/{self.aiserver}/create_ingest"
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def runDocumentIngest(self, load_job_id, data_source_id, data_path):
def runDocumentIngest(self, load_job_id, data_source_id, data_path, data_source="remote"):
""" Run a document ingest.
Args:
load_job_id (str):
Expand All @@ -266,13 +292,16 @@ def runDocumentIngest(self, load_job_id, data_source_id, data_path):
Returns:
JSON response from the document ingest.
"""
data = {
"load_job_id": load_job_id,
"data_source_id": data_source_id,
"file_path": data_path
}
url = self.nlqs_host+"/"+self.conn.graphname+"/supportai/ingest"
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)
if data_source.lower() == "local" or data_path.startswith(("/", ".", "~")) :
return self.conn.runLoadingJobWithFile(data_path, data_source_id, load_job_id)
else:
data = {
"load_job_id": load_job_id,
"data_source_id": data_source_id,
"file_path": data_path
}
url = f"{self.nlqs_host}/{self.conn.graphname}/{self.aiserver}/ingest"
return self.conn._req("POST", url, authMode="pwd", data=data, jsonData=True, resKey=None)

def searchDocuments(self, query, method="hnswoverlap", method_parameters: dict = {"indices": ["Document", "DocumentChunk", "Entity", "Relationship"], "top_k": 2, "num_hops": 2, "num_seen_min": 2}):
""" Search documents.
Expand Down
11 changes: 8 additions & 3 deletions pyTigerGraph/pyTigerGraphBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,17 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N
_headers, _data, verify = self._prep_req(
authMode, headers, url, method, data)

if "GSQL-TIMEOUT" in _headers:
http_timeout = (10, int(int(_headers["GSQL-TIMEOUT"])/1000) + 10)
else:
http_timeout = None

if jsonData:
res = requests.request(
method, url, headers=_headers, json=_data, params=params, verify=verify)
method, url, headers=_headers, json=_data, params=params, verify=verify, timeout=http_timeout)
else:
res = requests.request(
method, url, headers=_headers, data=_data, params=params, verify=verify)
method, url, headers=_headers, data=_data, params=params, verify=verify, timeout=http_timeout)

try:
if not skipCheck and not (200 <= res.status_code < 300):
Expand Down Expand Up @@ -563,4 +568,4 @@ def _version_greater_than_4_0(self) -> bool:
version = self.getVer().split('.')
if version[0] >= "4" and version[1] > "0":
return True
return False
return False
9 changes: 7 additions & 2 deletions pyTigerGraph/pytgasync/pyTigerGraphBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,16 @@ async def _req(self, method: str, url: str, authMode: str = "token", headers: di
_headers, _data, verify = self._prep_req(
authMode, headers, url, method, data)

if "GSQL-TIMEOUT" in _headers:
http_timeout = (10, int(_headers["GSQL-TIMEOUT"]/1000) + 10)
else:
http_timeout = None

async with httpx.AsyncClient(timeout=None) as client:
if jsonData:
res = await client.request(method, url, headers=_headers, json=_data, params=params)
res = await client.request(method, url, headers=_headers, json=_data, params=params, timeout=http_timeout)
else:
res = await client.request(method, url, headers=_headers, data=_data, params=params)
res = await client.request(method, url, headers=_headers, data=_data, params=params, timeout=http_timeout)

try:
if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404:
Expand Down