Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/python/src/flamepy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ async def put_object(session_id: str, data: bytes) -> "DataExpr":

return DataExpr(source=DataSource.REMOTE, url=metadata.endpoint, data=data, version=metadata.version)


async def get_object(de: DataExpr) -> "DataExpr":
"""Get an object from the cache."""
if de.source != DataSource.REMOTE:
Expand All @@ -61,6 +62,7 @@ async def get_object(de: DataExpr) -> "DataExpr":

return de


async def update_object(de: DataExpr) -> "DataExpr":
"""Update an object in the cache."""
if de.source != DataSource.REMOTE:
Expand Down
185 changes: 38 additions & 147 deletions sdk/python/src/flamepy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,7 @@ async def connect(addr: str) -> "Connection":
return await Connection.connect(addr)


async def create_session(application: str,
common_data: Dict[str, Any] = None,
session_id: str = None,
slots: int = 1) -> "Session":
async def create_session(application: str, common_data: Dict[str, Any] = None, session_id: str = None, slots: int = 1) -> "Session":

conn = await ConnectionInstance.instance()

Expand All @@ -81,19 +78,14 @@ async def create_session(application: str,
elif isinstance(common_data, FlameRequest):
common_data = common_data.to_json()
elif not isinstance(common_data, CommonData):
raise ValueError(
"Invalid common data type, must be a Request or CommonData")
raise ValueError("Invalid common data type, must be a Request or CommonData")

session_id = short_name(application) if session_id is None else session_id

data_expr = await put_object(session_id, common_data)
common_data = CommonData(data_expr.to_json())

session = await conn.create_session(
SessionAttributes(id=session_id,
application=application,
common_data=common_data,
slots=slots))
session = await conn.create_session(SessionAttributes(id=session_id, application=application, common_data=common_data, slots=slots))
return session


Expand All @@ -102,8 +94,7 @@ async def open_session(session_id: SessionID) -> "Session":
return await conn.open_session(session_id)


async def register_application(
name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None:
async def register_application(name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None:
conn = await ConnectionInstance.instance()
await conn.register_application(name, app_attrs)

Expand Down Expand Up @@ -201,17 +192,13 @@ async def connect(cls, addr: str) -> "Connection":
return cls(addr, channel, frontend)

except Exception as e:
raise FlameError(
FlameErrorCode.INVALID_CONFIG, f"failed to connect to {addr}: {str(e)}"
)
raise FlameError(FlameErrorCode.INVALID_CONFIG, f"failed to connect to {addr}: {str(e)}")

async def close(self) -> None:
"""Close the connection."""
await self._channel.close()

async def register_application(
self, name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]
) -> None:
async def register_application(self, name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None:
"""Register a new application."""
if isinstance(app_attrs, dict):
app_attrs = ApplicationAttributes(**app_attrs)
Expand Down Expand Up @@ -292,9 +279,7 @@ async def list_applications(self) -> List[Application]:
name=app.metadata.name,
shim=Shim(app.spec.shim),
state=ApplicationState(app.status.state),
creation_time=datetime.fromtimestamp(
app.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(app.status.creation_time / 1000, tz=timezone.utc),
image=app.spec.image,
command=app.spec.command,
arguments=list(app.spec.arguments),
Expand All @@ -309,9 +294,7 @@ async def list_applications(self) -> List[Application]:
return applications

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to list applications: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to list applications: {e.details()}")

async def get_application(self, name: str) -> Application:
"""Get an application by name."""
Expand All @@ -337,9 +320,7 @@ async def get_application(self, name: str) -> Application:
name=response.metadata.name,
shim=Shim(response.spec.shim),
state=ApplicationState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
image=response.spec.image,
command=response.spec.command,
arguments=list(response.spec.arguments),
Expand All @@ -351,9 +332,7 @@ async def get_application(self, name: str) -> Application:
)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to get application: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to get application: {e.details()}")

async def create_session(self, attrs: SessionAttributes) -> "Session":
"""Create a new session."""
Expand All @@ -374,26 +353,16 @@ async def create_session(self, attrs: SessionAttributes) -> "Session":
application=response.spec.application,
slots=response.spec.slots,
state=SessionState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
pending=response.status.pending,
running=response.status.running,
succeed=response.status.succeed,
failed=response.status.failed,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
)
return session
except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to create session: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to create session: {e.details()}")

async def list_sessions(self) -> List["Session"]:
"""List all sessions."""
Expand All @@ -411,29 +380,19 @@ async def list_sessions(self) -> List["Session"]:
application=session.spec.application,
slots=session.spec.slots,
state=SessionState(session.status.state),
creation_time=datetime.fromtimestamp(
session.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(session.status.creation_time / 1000, tz=timezone.utc),
pending=session.status.pending,
running=session.status.running,
succeed=session.status.succeed,
failed=session.status.failed,
completion_time=(
datetime.fromtimestamp(
session.status.completion_time / 1000, tz=timezone.utc
)
if session.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(session.status.completion_time / 1000, tz=timezone.utc) if session.status.HasField("completion_time") else None),
)
)

return sessions

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to list sessions: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to list sessions: {e.details()}")

async def open_session(self, session_id: SessionID) -> "Session":
"""Open a session."""
Expand All @@ -447,26 +406,16 @@ async def open_session(self, session_id: SessionID) -> "Session":
application=response.spec.application,
slots=response.spec.slots,
state=SessionState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
pending=response.status.pending,
running=response.status.running,
succeed=response.status.succeed,
failed=response.status.failed,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to open session: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to open session: {e.details()}")

async def get_session(self, session_id: SessionID) -> "Session":
"""Get a session by ID."""
Expand All @@ -481,26 +430,16 @@ async def get_session(self, session_id: SessionID) -> "Session":
application=response.spec.application,
slots=response.spec.slots,
state=SessionState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
pending=response.status.pending,
running=response.status.running,
succeed=response.status.succeed,
failed=response.status.failed,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to get session: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to get session: {e.details()}")

async def close_session(self, session_id: SessionID) -> "Session":
"""Close a session."""
Expand All @@ -515,26 +454,16 @@ async def close_session(self, session_id: SessionID) -> "Session":
application=response.spec.application,
slots=response.spec.slots,
state=SessionState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
pending=response.status.pending,
running=response.status.running,
succeed=response.status.succeed,
failed=response.status.failed,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to close session: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to close session: {e.details()}")


class Session:
Expand Down Expand Up @@ -594,33 +523,21 @@ async def create_task(self, input_data: TaskInput) -> Task:
id=response.metadata.id,
session_id=self.id,
state=TaskState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
input=input_data,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
events=[
Event(
code=event.code,
message=event.message,
creation_time=datetime.fromtimestamp(
event.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc),
)
for event in response.status.events
],
)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to create task: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to create task: {e.details()}")

async def get_task(self, task_id: TaskID) -> Task:
"""Get a task by ID."""
Expand All @@ -633,34 +550,22 @@ async def get_task(self, task_id: TaskID) -> Task:
id=response.metadata.id,
session_id=self.id,
state=TaskState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
input=response.spec.input,
output=response.spec.output,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
events=[
Event(
code=event.code,
message=event.message,
creation_time=datetime.fromtimestamp(
event.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc),
)
for event in response.status.events
],
)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to get task: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to get task: {e.details()}")

async def watch_task(self, task_id: TaskID) -> "TaskWatcher":
"""Watch a task for updates."""
Expand All @@ -671,13 +576,9 @@ async def watch_task(self, task_id: TaskID) -> "TaskWatcher":
return TaskWatcher(stream)

except grpc.RpcError as e:
raise FlameError(
FlameErrorCode.INTERNAL, f"failed to watch task: {e.details()}"
)
raise FlameError(FlameErrorCode.INTERNAL, f"failed to watch task: {e.details()}")

async def invoke(
self, input_data, informer: Optional[TaskInformer] = None
) -> TaskOutput:
async def invoke(self, input_data, informer: Optional[TaskInformer] = None) -> TaskOutput:
"""Invoke a task with the given input and optional informer."""
if input_data is None:
pass
Expand Down Expand Up @@ -729,25 +630,15 @@ async def __anext__(self) -> Task:
id=response.metadata.id,
session_id=response.spec.session_id,
state=TaskState(response.status.state),
creation_time=datetime.fromtimestamp(
response.status.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc),
input=response.spec.input,
output=response.spec.output,
completion_time=(
datetime.fromtimestamp(
response.status.completion_time / 1000, tz=timezone.utc
)
if response.status.HasField("completion_time")
else None
),
completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None),
events=[
Event(
code=event.code,
message=event.message,
creation_time=datetime.fromtimestamp(
event.creation_time / 1000, tz=timezone.utc
),
creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc),
)
for event in response.status.events
],
Expand Down
Loading
Loading