Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions domaintools/base_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,53 @@ def _wait_time(self):

return wait_for

def _get_feeds_results_generator(self, parameters, headers):
    """Lazily yield the HTTP responses of a feeds query, one request at a time.

    A single ``Client`` session is opened for the lifetime of the generator.
    GET requests are issued against ``self.url`` until a response with status
    200 is received, or after the first response when the query carries no
    ``sessionID``. Every response is passed to ``self.setStatus`` before it is
    yielded.

    Args:
        parameters: query-string parameters forwarded to ``session.get``.
        headers: HTTP headers forwarded to ``session.get``.

    Yields:
        The response object returned by each ``session.get`` call.

    Raises:
        ServiceException: code 503, raised when the generator is resumed after
            yielding a response whose text contains both ``response`` and
            ``limit_exceeded`` (or when ``self._limit_exceeded`` was already set).
    """
    with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as http:
        while True:
            reply = http.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
            code = reply.status_code
            self.setStatus(code, reply)

            # Limit detection is a plain substring test on the raw body; the
            # payload is not parsed here.
            if "response" in reply.text and "limit_exceeded" in reply.text:
                self._limit_exceeded = True
                self._limit_exceeded_message = "limit exceeded"

            # The response is handed to the consumer first; the limit error
            # only surfaces when the consumer asks for the next item.
            yield reply

            if self._limit_exceeded:
                raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))

            # Only queries with a sessionID are requested iteratively. Without
            # one, asking for more than the maximum (1 hour of data) would
            # otherwise repeat the same request forever.
            if code == 200 or not self.kwargs.get("sessionID"):
                break

def _get_session_params(self):
    """Split ``self.kwargs`` into query parameters and headers for a feeds request.

    Works on a deep copy, so ``self.kwargs`` itself is never modified.

    Returns:
        A dict with two keys: ``"parameters"`` (the query parameters) and
        ``"headers"`` (the HTTP headers).
    """
    query = deepcopy(self.kwargs)

    # "output_format" only selects the accept header below. "format" is
    # dropped because, for some unknown reason, it gets populated for feeds
    # endpoints even when it was not passed (happens only when using the CLI).
    for unwanted in ("output_format", "format"):
        query.pop(unwanted, None)

    request_headers = {}
    chosen_format = self.kwargs.get("output_format", OutputFormat.JSONL.value)
    if chosen_format == OutputFormat.CSV.value:
        # CSV output: send the "headers" flag as 0/1 and ask for CSV content.
        query["headers"] = int(bool(self.kwargs.get("headers", False)))
        request_headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT

    # An API key supplied as a kwarg travels as a header, not as a query parameter.
    api_key = query.pop("X-Api-Key", None)
    if api_key:
        request_headers["X-Api-Key"] = api_key

    return {"parameters": query, "headers": request_headers}

def _make_request(self):
if self.product in FEEDS_PRODUCTS_LIST:
session_params = self._get_session_params()
parameters = session_params.get("parameters")
headers = session_params.get("headers")

return self._get_feeds_results_generator(parameters=parameters, headers=headers)

with Client(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
if self.product in [
Expand All @@ -92,30 +138,15 @@ def _make_request(self):
patch_data = self.kwargs.copy()
patch_data.update(self.api.extra_request_params)
return session.patch(url=self.url, json=patch_data)
elif self.product in FEEDS_PRODUCTS_LIST:
parameters = deepcopy(self.kwargs)
parameters.pop("output_format", None)
parameters.pop(
"format", None
) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI.
headers = {}
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT

header_api_key = parameters.pop("X-Api-Key", None)
if header_api_key:
headers["X-Api-Key"] = header_api_key

return session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
else:
return session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)

def _get_results(self):
wait_for = self._wait_time()
if self.api.rate_limit and (wait_for is None or self.product == "account-information"):
data = self._make_request()
if data.status_code == 503: # pragma: no cover
status_code = data.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
if status_code == 503: # pragma: no cover
sleeptime = 60
log.info(
"503 encountered for [%s] - sleeping [%s] seconds before retrying request.",
Expand All @@ -135,12 +166,15 @@ def _get_results(self):
def data(self):
if self._data is None:
results = self._get_results()
self.setStatus(results.status_code, results)
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
self.setStatus(status_code, results)
if (
self.kwargs.get("format", "json") == "json"
and self.product not in FEEDS_PRODUCTS_LIST # Special handling of feeds products' data to preserve the result in jsonline format
):
self._data = results.json()
elif self.product in FEEDS_PRODUCTS_LIST:
self._data = results # Uses generator to handle large data results from feeds endpoint
else:
self._data = results.text
limit_exceeded, message = self.check_limit_exceeded()
Expand All @@ -155,6 +189,10 @@ def data(self):
return self._data

def check_limit_exceeded(self):
if self.product in FEEDS_PRODUCTS_LIST:
# bypass here as this is handled in generator already
return False, ""

if self.kwargs.get("format", "json") == "json" and self.product not in FEEDS_PRODUCTS_LIST:
if "response" in self._data and "limit_exceeded" in self._data["response"] and self._data["response"]["limit_exceeded"] is True:
return True, self._data["response"]["message"]
Expand All @@ -172,7 +210,7 @@ def status(self):

def setStatus(self, code, response=None):
self._status = code
if code == 200:
if code == 200 or (self.product in FEEDS_PRODUCTS_LIST and code == 206):
return

reason = None
Expand Down Expand Up @@ -211,7 +249,7 @@ def response(self):
return self._response

def items(self):
return self.response().items()
return self.response().items() if isinstance(self.response(), dict) else self.response()

def emails(self):
"""Find and returns all emails mentioned in the response"""
Expand Down
10 changes: 8 additions & 2 deletions domaintools/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from rich.progress import Progress, SpinnerColumn, TextColumn

from domaintools.api import API
from domaintools.constants import Endpoint, OutputFormat
from domaintools.constants import Endpoint, OutputFormat, FEEDS_PRODUCTS_LIST
from domaintools.cli.utils import get_file_extension
from domaintools.exceptions import ServiceException
from domaintools._version import current as version
Expand Down Expand Up @@ -110,6 +110,9 @@ def args_to_dict(*args) -> Dict:
def _get_formatted_output(cls, cmd_name: str, response, out_format: str = "json"):
    """Render a command's response as the text that will be printed or written.

    Args:
        cmd_name: name of the CLI command that produced ``response``.
        response: the command result; an iterable of strings for
            ``available_api_calls``, otherwise an object exposing ``product``.
        out_format: name of the attribute to read from ``response``, or
            ``"list"`` to use ``response.as_list()``.

    Returns:
        The formatted output as a string.
    """
    if cmd_name == "available_api_calls":
        return "\n".join(response)

    if response.product in FEEDS_PRODUCTS_LIST:
        # Iterating a feeds response yields objects with a ``.text`` body;
        # the bodies are joined one per line.
        return "\n".join(chunk.text for chunk in response)

    if out_format != "list":
        rendered = getattr(response, out_format)
    else:
        rendered = response.as_list()
    return str(rendered)

@classmethod
Expand Down Expand Up @@ -227,7 +230,10 @@ def run(cls, name: str, params: Optional[Dict] = {}, **kwargs):

if isinstance(out_file, _io.TextIOWrapper):
# use rich `print` command to prettify the ouput in sys.stdout
print(response)
if response.product in FEEDS_PRODUCTS_LIST:
print(output)
else:
print(response)
else:
# if it's a file then write
out_file.write(output if output.endswith("\n") else output + "\n")
Expand Down
50 changes: 30 additions & 20 deletions domaintools_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@

import asyncio

from copy import deepcopy
from httpx import AsyncClient

from domaintools.base_results import Results
from domaintools.constants import FEEDS_PRODUCTS_LIST, OutputFormat, HEADER_ACCEPT_KEY_CSV_FORMAT
from domaintools.exceptions import ServiceUnavailableException
from domaintools.constants import FEEDS_PRODUCTS_LIST
from domaintools.exceptions import ServiceUnavailableException, ServiceException


class _AIter(object):
Expand Down Expand Up @@ -42,6 +41,26 @@ class AsyncResults(Results):
def __await__(self):
return self.__awaitable__().__await__()

async def _get_feeds_async_results_generator(self, session, parameters, headers):
    """Asynchronously yield the HTTP responses of a feeds query, one request at a time.

    GET requests are awaited against ``self.url`` until a response with status
    200 is received, or after the first response when the query carries no
    ``sessionID``. Every response is passed to ``self.setStatus`` before it is
    yielded.

    Args:
        session: the client object whose ``get`` coroutine performs the request.
        parameters: query-string parameters forwarded to ``session.get``.
        headers: HTTP headers forwarded to ``session.get``.

    Yields:
        The response object returned by each awaited ``session.get`` call.

    Raises:
        ServiceException: code 503, raised when the generator is resumed after
            yielding a response whose text contains both ``response`` and
            ``limit_exceeded`` (or when ``self._limit_exceeded`` was already set).
    """
    while True:
        reply = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
        code = reply.status_code
        self.setStatus(code, reply)

        # Limit detection is a plain substring test on the raw body; the
        # payload is not parsed here.
        if "response" in reply.text and "limit_exceeded" in reply.text:
            self._limit_exceeded = True
            self._limit_exceeded_message = "limit exceeded"

        # The response is handed to the consumer first; the limit error only
        # surfaces when the consumer asks for the next item.
        yield reply

        if self._limit_exceeded:
            raise ServiceException(503, "Limit Exceeded{}".format(self._limit_exceeded_message))

        # Only queries with a sessionID are requested iteratively. Without
        # one, asking for more than the maximum (1 hour of data) would
        # otherwise repeat the same request forever.
        if code == 200 or not self.kwargs.get("sessionID"):
            break

async def _make_async_request(self, session):
if self.product in ["iris-investigate", "iris-enrich", "iris-detect-escalate-domains"]:
post_data = self.kwargs.copy()
Expand All @@ -52,27 +71,19 @@ async def _make_async_request(self, session):
patch_data.update(self.api.extra_request_params)
results = await session.patch(url=self.url, json=patch_data)
elif self.product in FEEDS_PRODUCTS_LIST:
parameters = deepcopy(self.kwargs)
parameters.pop("output_format", None)
parameters.pop(
"format", None
) # For some unknownn reasons, even if "format" is not included in the cli params for feeds endpoint, it is being populated thus we need to remove it. Happens only if using CLI.
headers = {}
if self.kwargs.get("output_format", OutputFormat.JSONL.value) == OutputFormat.CSV.value:
parameters["headers"] = int(bool(self.kwargs.get("headers", False)))
headers["accept"] = HEADER_ACCEPT_KEY_CSV_FORMAT

header_api_key = parameters.pop("X-Api-Key", None)
if header_api_key:
headers["X-Api-Key"] = header_api_key

results = await session.get(url=self.url, params=parameters, headers=headers, **self.api.extra_request_params)
generator_params = self._get_session_params()
parameters = generator_params.get("parameters")
headers = generator_params.get("headers")
results = await self._get_feeds_async_results_generator(session=session, parameters=parameters, headers=headers)
else:
results = await session.get(url=self.url, params=self.kwargs, **self.api.extra_request_params)
if results:
self.setStatus(results.status_code, results)
status_code = results.status_code if self.product not in FEEDS_PRODUCTS_LIST else 200
self.setStatus(status_code, results)
if self.kwargs.get("format", "json") == "json":
self._data = results.json()
elif self.product in FEEDS_PRODUCTS_LIST:
self._data = results # Uses generator to handle large data results from feeds endpoint
else:
self._data = results.text()
limit_exceeded, message = self.check_limit_exceeded()
Expand All @@ -83,7 +94,6 @@ async def _make_async_request(self, session):

async def __awaitable__(self):
if self._data is None:

async with AsyncClient(verify=self.api.verify_ssl, proxy=self.api.proxy_url, timeout=None) as session:
wait_time = self._wait_time()
if wait_time is None and self.api:
Expand Down
Loading
Loading