diff --git a/e2e/src/e2e/api.py b/e2e/src/e2e/api.py index 51fd0c84..a5711d94 100644 --- a/e2e/src/e2e/api.py +++ b/e2e/src/e2e/api.py @@ -11,15 +11,18 @@ limitations under the License. """ -import flamepy from typing import Optional +from dataclasses import dataclass -class TestContext(flamepy.Request): +@dataclass +class TestContext: common_data: Optional[str] = None -class TestRequest(flamepy.Request): +@dataclass +class TestRequest: input: Optional[str] = None -class TestResponse(flamepy.Response): +@dataclass +class TestResponse: output: Optional[str] = None common_data: Optional[str] = None diff --git a/e2e/src/e2e/service.py b/e2e/src/e2e/service.py index 1fc4506d..a81f7a37 100644 --- a/e2e/src/e2e/service.py +++ b/e2e/src/e2e/service.py @@ -13,21 +13,16 @@ import flamepy -from e2e import TestRequest, TestResponse, TestContext +from e2e.api import TestRequest, TestResponse instance = flamepy.FlameInstance() -sys_context = None - @instance.entrypoint def e2e_service_entrypoint(req: TestRequest) -> TestResponse: - return TestResponse(output=req.input, common_data=sys_context) + cxt = instance.context() + data = cxt.common_data if cxt is not None else None -@instance.context -def e2e_service_context(ctx: TestContext = None): - global sys_context - if ctx is not None: - sys_context = ctx.common_data + return TestResponse(output=req.input, common_data=data) if __name__ == "__main__": - instance.run() \ No newline at end of file + instance.run() diff --git a/e2e/tests/test_session.py b/e2e/tests/test_session.py index fec8e0d3..ba560497 100644 --- a/e2e/tests/test_session.py +++ b/e2e/tests/test_session.py @@ -36,8 +36,7 @@ def __init__(self, expected_output): def on_update(self, task): self.latest_state = task.state if task.state == flamepy.TaskState.SUCCEED: - resp = TestResponse.from_json(task.output) - assert resp.output == self.expected_output + assert task.output.output == self.expected_output, f"Task output: {task.output.output}, Expected: {self.expected_output}" elif task.state == flamepy.TaskState.FAILED: for event in task.events: if event.code == flamepy.TaskState.FAILED: @@ -99,8 +98,7 @@ async def test_invoke_task_without_common_data(): input = random_string() - resp = await session.invoke(TestRequest(input=input)) - output = TestResponse.from_json(resp) + output = await session.invoke(TestRequest(input=input)) assert output.output == input assert output.common_data is None @@ -122,8 +120,7 @@ async def test_invoke_task_with_common_data(): assert ssn_list[0].application == FLM_TEST_APP assert ssn_list[0].state == SessionState.OPEN - resp = await session.invoke(TestRequest(input=input)) - output = TestResponse.from_json(resp) + output = await session.invoke(TestRequest(input=input)) assert output.output == input assert output.common_data == sys_context diff --git a/sdk/python/example/instance.py b/sdk/python/example/instance.py index c2bd57a5..f32edd2d 100644 --- a/sdk/python/example/instance.py +++ b/sdk/python/example/instance.py @@ -32,13 +32,6 @@ class Summary(flamepy.Response): You are a helpful assistant. """ - -@ins.context -def sys_context(sp: SysPrompt): - global sys_prompt - sys_prompt = sp.prompt - - @ins.entrypoint def summarize_blog(bl: Blog) -> Summary: global sys_prompt diff --git a/sdk/python/src/flamepy/__init__.py b/sdk/python/src/flamepy/__init__.py index 1b70d9e1..afe4c2ca 100644 --- a/sdk/python/src/flamepy/__init__.py +++ b/sdk/python/src/flamepy/__init__.py @@ -52,8 +52,6 @@ Application, FlameContext, TaskInformer, - Request, - Response, ) from .client import ( @@ -113,8 +111,6 @@ "Application", "FlameContext", "TaskInformer", - "Request", - "Response", # Client classes "Connection", "connect", diff --git a/sdk/python/src/flamepy/cache.py b/sdk/python/src/flamepy/cache.py index 79de5671..7185dc06 100644 --- a/sdk/python/src/flamepy/cache.py +++ b/sdk/python/src/flamepy/cache.py @@ -20,7 +20,7 @@ class Object(BaseModel): """Object.""" version: int - data: list + data: list class ObjectMetadata(BaseModel): """Object metadata.""" @@ -45,8 +45,12 @@ async def put_object(session_id: str, data: bytes) -> "DataExpr": return DataExpr(source=DataSource.REMOTE, url=metadata.endpoint, data=data, version=metadata.version) + async def get_object(de: DataExpr) -> "DataExpr": """Get an object from the cache.""" + if de is None: + return None + if de.source != DataSource.REMOTE: return de @@ -63,6 +67,9 @@ async def get_object(de: DataExpr) -> "DataExpr": async def update_object(de: DataExpr) -> "DataExpr": """Update an object in the cache.""" + if de is None: + return None + if de.source != DataSource.REMOTE: return de diff --git a/sdk/python/src/flamepy/client.py b/sdk/python/src/flamepy/client.py index 347edb50..99527639 100644 --- a/sdk/python/src/flamepy/client.py +++ b/sdk/python/src/flamepy/client.py @@ -17,6 +17,7 @@ import grpc import grpc.aio import asyncio +import pickle from datetime import datetime, timezone from .cache import put_object @@ -39,8 +40,6 @@ FlameError, FlameErrorCode, TaskInformer, - Request as FlameRequest, - Response as FlameResponse, FlameContext, ApplicationSchema, short_name, @@ -69,32 +68,9 @@ async def connect(addr: str) -> "Connection": return await Connection.connect(addr) -async def create_session(application: str, - common_data: Dict[str, Any] = None, - session_id: str = None, - slots: int = 1) -> "Session": - +async def create_session(application: str, common_data: Any = None, session_id: Optional[str] = None, slots: int = 1) -> "Session": conn = await ConnectionInstance.instance() - - if common_data is None: - pass - elif isinstance(common_data, FlameRequest): - common_data = common_data.to_json() - elif not isinstance(common_data, CommonData): - raise ValueError( - "Invalid common data type, must be a Request or CommonData") - - session_id = short_name(application) if session_id is None else session_id - - data_expr = await put_object(session_id, common_data) - common_data = CommonData(data_expr.to_json()) - - session = await conn.create_session( - SessionAttributes(id=session_id, - application=application, - common_data=common_data, - slots=slots)) - return session + return await conn.create_session(SessionAttributes(id=session_id, application=application, common_data=common_data, slots=slots)) async def open_session(session_id: SessionID) -> "Session": @@ -102,8 +78,7 @@ async def open_session(session_id: SessionID) -> "Session": return await conn.open_session(session_id) -async def register_application( - name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None: +async def register_application(name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None: conn = await ConnectionInstance.instance() await conn.register_application(name, app_attrs) @@ -201,17 +176,13 @@ async def connect(cls, addr: str) -> "Connection": return cls(addr, channel, frontend) except Exception as e: - raise FlameError( - FlameErrorCode.INVALID_CONFIG, f"failed to connect to {addr}: {str(e)}" - ) + raise FlameError(FlameErrorCode.INVALID_CONFIG, f"failed to connect to {addr}: {str(e)}") async def close(self) -> None: """Close the connection.""" await self._channel.close() - async def register_application( - self, name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]] - ) -> None: + async def register_application(self, name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None: """Register a new application.""" if isinstance(app_attrs, dict): app_attrs = ApplicationAttributes(**app_attrs) @@ -292,9 +263,7 @@ async def list_applications(self) -> List[Application]: name=app.metadata.name, shim=Shim(app.spec.shim), state=ApplicationState(app.status.state), - creation_time=datetime.fromtimestamp( - app.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(app.status.creation_time / 1000, tz=timezone.utc), image=app.spec.image, command=app.spec.command, arguments=list(app.spec.arguments), @@ -309,9 +278,7 @@ async def list_applications(self) -> List[Application]: return applications except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to list applications: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to list applications: {e.details()}") async def get_application(self, name: str) -> Application: """Get an application by name.""" @@ -337,9 +304,7 @@ async def get_application(self, name: str) -> Application: name=response.metadata.name, shim=Shim(response.spec.shim), state=ApplicationState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), image=response.spec.image, command=response.spec.command, arguments=list(response.spec.arguments), @@ -351,19 +316,24 @@ async def get_application(self, name: str) -> Application: ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to get application: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to get application: {e.details()}") async def create_session(self, attrs: SessionAttributes) -> "Session": """Create a new session.""" + + session_id = short_name(attrs.application) if attrs.id is None else attrs.id + + common_data_bin = pickle.dumps(attrs.common_data, protocol=pickle.HIGHEST_PROTOCOL) + + data_expr = await put_object(session_id, common_data_bin) + session_spec = SessionSpec( application=attrs.application, slots=attrs.slots, - common_data=attrs.common_data, + common_data=data_expr.encode(), ) - request = CreateSessionRequest(session_id=attrs.id, session=session_spec) + request = CreateSessionRequest(session_id=session_id, session=session_spec) try: response = await self._frontend.CreateSession(request) @@ -374,26 +344,16 @@ async def create_session(self, attrs: SessionAttributes) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) return session except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to create session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to create session: {e.details()}") async def list_sessions(self) -> List["Session"]: """List all sessions.""" @@ -411,29 +371,19 @@ async def list_sessions(self) -> List["Session"]: application=session.spec.application, slots=session.spec.slots, state=SessionState(session.status.state), - creation_time=datetime.fromtimestamp( - session.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(session.status.creation_time / 1000, tz=timezone.utc), pending=session.status.pending, running=session.status.running, succeed=session.status.succeed, failed=session.status.failed, - completion_time=( - datetime.fromtimestamp( - session.status.completion_time / 1000, tz=timezone.utc - ) - if session.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(session.status.completion_time / 1000, tz=timezone.utc) if session.status.HasField("completion_time") else None), ) ) return sessions except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to list sessions: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to list sessions: {e.details()}") async def open_session(self, session_id: SessionID) -> "Session": """Open a session.""" @@ -447,26 +397,16 @@ async def open_session(self, session_id: SessionID) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to open session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to open session: {e.details()}") async def get_session(self, session_id: SessionID) -> "Session": """Get a session by ID.""" @@ -481,26 +421,16 @@ async def get_session(self, session_id: SessionID) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to get session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to get session: {e.details()}") async def close_session(self, session_id: SessionID) -> "Session": """Close a session.""" @@ -515,31 +445,20 @@ async def close_session(self, session_id: SessionID) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to close session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to close session: {e.details()}") class Session: connection: Connection - """Represents a computing session.""" id: SessionID application: str @@ -551,7 +470,6 @@ class Session: succeed: int = 0 failed: int = 0 completion_time: Optional[datetime] = None - """Client for session-specific operations.""" def __init__( @@ -581,9 +499,11 @@ def __init__( self.completion_time = completion_time self.mutex = threading.Lock() - async def create_task(self, input_data: TaskInput) -> Task: + async def create_task(self, input_data: Any) -> Task: """Create a new task in the session.""" - task_spec = TaskSpec(session_id=self.id, input=input_data) + input_bin = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL) + + task_spec = TaskSpec(session_id=self.id, input=input_bin) request = CreateTaskRequest(task=task_spec) @@ -594,33 +514,21 @@ async def create_task(self, input_data: TaskInput) -> Task: id=response.metadata.id, session_id=self.id, state=TaskState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), input=input_data, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), events=[ Event( code=event.code, message=event.message, - creation_time=datetime.fromtimestamp( - event.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc), ) for event in response.status.events ], ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to create task: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to create task: {e.details()}") async def get_task(self, task_id: TaskID) -> Task: """Get a task by ID.""" @@ -633,34 +541,22 @@ async def get_task(self, task_id: TaskID) -> Task: id=response.metadata.id, session_id=self.id, state=TaskState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), - input=response.spec.input, - output=response.spec.output, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), + input=pickle.loads(response.spec.input) if response.spec.input is not None else None, + output=pickle.loads(response.spec.output) if response.spec.output is not None else None, + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), events=[ Event( code=event.code, message=event.message, - creation_time=datetime.fromtimestamp( - event.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc), ) for event in response.status.events ], ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to get task: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to get task: {e.details()}") async def watch_task(self, task_id: TaskID) -> "TaskWatcher": """Watch a task for updates.""" @@ -671,21 +567,10 @@ async def watch_task(self, task_id: TaskID) -> "TaskWatcher": return TaskWatcher(stream) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to watch task: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to watch task: {e.details()}") - async def invoke( - self, input_data, informer: Optional[TaskInformer] = None - ) -> TaskOutput: + async def invoke(self, input_data: Any, informer: Optional[TaskInformer] = None) -> Any: """Invoke a task with the given input and optional informer.""" - if input_data is None: - pass - if isinstance(input_data, FlameRequest): - input_data = input_data.to_json() - elif not isinstance(input_data, TaskInput): - raise ValueError("Invalid input data type, must be a Request or TaskInput") - task = await self.create_task(input_data) watcher = await self.watch_task(task.id) @@ -729,25 +614,15 @@ async def __anext__(self) -> Task: id=response.metadata.id, session_id=response.spec.session_id, state=TaskState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), - input=response.spec.input, - output=response.spec.output, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), + input=pickle.loads(response.spec.input) if response.spec.HasField("input") else None, + output=pickle.loads(response.spec.output) if response.spec.HasField("output") else None, + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), events=[ Event( code=event.code, message=event.message, - creation_time=datetime.fromtimestamp( - event.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc), ) for event in response.status.events ], diff --git a/sdk/python/src/flamepy/instance.py b/sdk/python/src/flamepy/instance.py index 8ee05dda..2873ec6c 100644 --- a/sdk/python/src/flamepy/instance.py +++ b/sdk/python/src/flamepy/instance.py @@ -16,7 +16,6 @@ import uvicorn import os import time -from pydantic import BaseModel from typing import Optional, Dict, Any, Union from fastapi import FastAPI, Request as FastAPIRequest, Response as FastAPIResponse @@ -38,38 +37,22 @@ class FlameInstance(FlameService): + def __init__(self): self.session_id = None self.task_id = None self._entrypoint = None self._parameter = None - self._return_type = None - self._input_schema = None - self._output_schema = None - self._context = None - self._context_schema = None - self._context_parameter = None + self._context: Any = None self._queue = None - def context(self, func): + def context(self) -> Any: logger = logging.getLogger(__name__) - logger.debug(f"context: {func.__name__}") - - sig = inspect.signature(func) - self._context = func - assert ( - len(sig.parameters) == 1 or len(sig.parameters) == 0 - ), "Context must have exactly zero or one parameter" - for param in sig.parameters.values(): - assert ( - param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ), "Parameter must be positional or keyword" - if param.annotation is not inspect._empty: - self._context_schema = param.annotation.model_json_schema() - self._context_parameter = param + logger.debug(f"context: {self._context}") + return self._context def entrypoint(self, func): logger = logging.getLogger(__name__) @@ -77,49 +60,20 @@ def entrypoint(self, func): sig = inspect.signature(func) self._entrypoint = func - assert ( - len(sig.parameters) == 1 or len(sig.parameters) == 0 - ), "Entrypoint must have exactly zero or one parameter" + assert len(sig.parameters) == 1 or len(sig.parameters) == 0, "Entrypoint must have exactly zero or one parameter" for param in sig.parameters.values(): - assert ( - param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ), "Parameter must be positional or keyword" - if param.annotation is not inspect._empty: - self._input_schema = param.annotation.model_json_schema() + assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD, "Parameter must be positional or keyword" self._parameter = param - if sig.return_annotation is not inspect._empty: - self._return_type = sig.return_annotation - self._output_schema = self._return_type.model_json_schema() - async def on_session_enter(self, context: SessionContext): logger = logging.getLogger(__name__) logger.debug("on_session_enter") - if self._context is None: - logger.warning("No context function defined") - return self.session_id = context.session_id if self._queue is None: self._queue = context._queue - if self._context_parameter is None: - if inspect.iscoroutinefunction(self._context): - await self._context() - else: - self._context() - else: - obj = ( - self._context_parameter.annotation.model_validate_json( - context.common_data - ) - if context.common_data is not None - else None - ) - if inspect.iscoroutinefunction(self._context): - await self._context(obj) - else: - self._context(obj) + self._context = context.common_data async def on_task_invoke(self, context: TaskContext) -> TaskOutput: logger = logging.getLogger(__name__) @@ -129,31 +83,26 @@ async def on_task_invoke(self, context: TaskContext) -> TaskOutput: return self.task_id = context.task_id + if self._queue is None: self._queue = context._queue if self._parameter is not None: - obj = ( - self._parameter.annotation.model_validate_json(context.input) - if context.input is not None - else None - ) if inspect.iscoroutinefunction(self._entrypoint): - res = await self._entrypoint(obj) + res = await self._entrypoint(context.input) else: - res = self._entrypoint(obj) + res = self._entrypoint(context.input) else: if inspect.iscoroutinefunction(self._entrypoint): res = await self._entrypoint() else: res = self._entrypoint() - res = self._return_type.model_validate(res).model_dump_json() logger.debug(f"on_task_invoke: {res}") self.task_id = None - return TaskOutput(data=res.encode("utf-8")) + return TaskOutput(data=res) async def on_session_leave(self): logger = logging.getLogger(__name__) @@ -166,9 +115,7 @@ async def record_event(self, code: int, message: Optional[str] = None): if self._queue is not None: await self._queue.put( WatchEventResponseProto( - owner=EventOwnerProto( - session_id=self.session_id, task_id=self.task_id - ), + owner=EventOwnerProto(session_id=self.session_id, task_id=self.task_id), event=EventProto( code=code, message=message, @@ -206,40 +153,13 @@ def run_debug_service(instance: FlameInstance): debug_service = FastAPI() debug_service.state.instance = instance - if instance._context is not None: - context_name = instance._context.__name__ - debug_service.add_api_route( - f"/{context_name}", context_local_api, methods=["POST"] - ) - if instance._entrypoint is not None: entrypoint_name = instance._entrypoint.__name__ - debug_service.add_api_route( - f"/{entrypoint_name}", entrypoint_local_api, methods=["POST"] - ) + debug_service.add_api_route(f"/{entrypoint_name}", entrypoint_local_api, methods=["POST"]) uvicorn.run(debug_service, host="0.0.0.0", port=5050) -async def context_local_api(s: FastAPIRequest): - instance = s.app.state.instance - body_str = await s.body() - - await instance.on_session_enter( - SessionContext( - session_id=s.query_params.get("session_id") or "0", - application=ApplicationContext( - name="test", - shim=Shim.Host, - image=None, - command=None, - ), - common_data=body_str, - ) - ) - return FastAPIResponse(status_code=200, content="OK") - - async def entrypoint_local_api(s: FastAPIRequest): instance = s.app.state.instance body_str = await s.body() diff --git a/sdk/python/src/flamepy/service.py b/sdk/python/src/flamepy/service.py index b7b54f0b..642a98ea 100644 --- a/sdk/python/src/flamepy/service.py +++ b/sdk/python/src/flamepy/service.py @@ -15,6 +15,7 @@ import os import time import grpc +import pickle from abc import ABC, abstractmethod from typing import Optional, Dict, Any, Union from dataclasses import dataclass @@ -37,8 +38,9 @@ FLAME_INSTANCE_ENDPOINT = "FLAME_INSTANCE_ENDPOINT" + class TraceFn: - def __init__(self, name:str): + def __init__(self, name: str): self.name = name logger.debug(f"{name} Enter") @@ -63,15 +65,13 @@ class SessionContext: _queue: asyncio.Queue session_id: str application: ApplicationContext - common_data: Optional[bytes] = None + common_data: Any async def record_event(self, code: int, message: Optional[str] = None): """Record an event.""" event = WatchEventResponseProto( owner=EventOwnerProto(session_id=self.session_id, task_id=None), - event=EventProto( - code=code, message=message, creation_time=int(time.time() * 1000) - ), + event=EventProto(code=code, message=message, creation_time=int(time.time() * 1000)), ) await self._queue.put(event) @@ -83,15 +83,13 @@ class TaskContext: _queue: asyncio.Queue task_id: str session_id: str - input: Optional[bytes] = None + input: Any async def record_event(self, code: int, message: Optional[str] = None): """Record an event.""" event = WatchEventResponseProto( owner=EventOwnerProto(session_id=self.session_id, task_id=self.task_id), - event=EventProto( - code=code, message=message, creation_time=int(time.time() * 1000) - ), + event=EventProto(code=code, message=message, creation_time=int(time.time() * 1000)), ) await self._queue.put(event) @@ -100,7 +98,7 @@ async def record_event(self, code: int, message: Optional[str] = None): class TaskOutput: """Output from a task.""" - data: Optional[bytes] = None + data: Any class FlameService: @@ -163,30 +161,22 @@ async def OnSessionEnter(self, request, context): app_context = ApplicationContext( name=request.application.name, shim=Shim(request.application.shim), - image=( - request.application.image - if request.application.HasField("image") - else None - ), - command=( - request.application.command - if request.application.HasField("command") - else None - ), + image=(request.application.image if request.application.HasField("image") else None), + command=(request.application.command if request.application.HasField("command") else None), ) logger.debug(f"app_context: {app_context}") - self._common_data_expr = DataExpr.from_json(request.common_data) if request.HasField("common_data") else None + self._common_data_expr = DataExpr.decode(request.common_data) if request.HasField("common_data") else None + self._common_data_expr = await get_object(self._common_data_expr) - if self._common_data_expr is not None: - self._common_data_expr = await get_object(self._common_data_expr) + common_data = pickle.loads(self._common_data_expr.data) if self._common_data_expr is not None else None session_context = SessionContext( _queue=self._queue, session_id=request.session_id, application=app_context, - common_data=self._common_data_expr.data if self._common_data_expr is not None else None, + common_data=common_data, ) logger.debug(f"session_context: {session_context}") @@ -214,7 +204,7 @@ async def OnTaskInvoke(self, request, context): _queue=self._queue, task_id=request.task_id, session_id=request.session_id, - input=request.input if request.HasField("input") else None, + input=pickle.loads(request.input) if request.HasField("input") else None, ) logger.debug(f"task_context: {task_context}") @@ -223,8 +213,12 @@ async def OnTaskInvoke(self, request, context): output = await self._service.on_task_invoke(task_context) logger.debug("on_task_invoke completed successfully") + output_data = None + if output is not None and output.data is not None: + output_data = pickle.dumps(output.data, protocol=pickle.HIGHEST_PROTOCOL) + # Return task output - return TaskResultProto(return_code=0, output=output.data, message=None) + return TaskResultProto(return_code=0, output=output_data, message=None) except Exception as e: logger.error(f"Error in OnTaskInvoke: {e}") @@ -297,13 +291,9 @@ async def start(self): endpoint = os.getenv(FLAME_INSTANCE_ENDPOINT) if endpoint is not None: self._server.add_insecure_port(f"unix://{endpoint}") - logger.debug( - f"Flame Python instance service started on Unix socket: {endpoint}" - ) + logger.debug(f"Flame Python instance service started on Unix socket: {endpoint}") else: - raise FlameError( - FlameErrorCode.INVALID_CONFIG, "FLAME_INSTANCE_ENDPOINT not found" - ) + raise FlameError(FlameErrorCode.INVALID_CONFIG, "FLAME_INSTANCE_ENDPOINT not found") # Start server await self._server.start() diff --git a/sdk/python/src/flamepy/types.py b/sdk/python/src/flamepy/types.py index 5756b712..9c134f4e 100644 --- a/sdk/python/src/flamepy/types.py +++ b/sdk/python/src/flamepy/types.py @@ -101,10 +101,10 @@ class Event: class SessionAttributes: """Attributes for creating a session.""" - id: str application: str slots: int - common_data: Optional[bytes] = None + id: Optional[str] = None + common_data: Any = None @dataclass @@ -141,8 +141,8 @@ class Task: session_id: SessionID state: TaskState creation_time: datetime - input: Optional[bytes] = None - output: Optional[bytes] = None + input: Any = None + output: Any = None completion_time: Optional[datetime] = None events: Optional[List[Event]] = None @@ -188,24 +188,6 @@ def on_error(self, error: FlameError) -> None: pass -class Request(BaseModel): - @classmethod - def from_json(cls, json_data): - return cls.model_validate_json(json_data.decode("utf-8")) - - def to_json(self) -> bytes: - return self.model_dump_json().encode("utf-8") - - -class Response(BaseModel): - @classmethod - def from_json(cls, json_data): - return cls.model_validate_json(json_data.decode("utf-8")) - - def to_json(self) -> bytes: - return self.model_dump_json().encode("utf-8") - - def short_name(prefix: str, length: int = 6) -> str: """Generate a short name with a prefix.""" alphabet = string.ascii_letters + string.digits @@ -263,7 +245,7 @@ class DataExpr: version: int = 0 data: Optional[bytes] = None - def to_json(self) -> bytes: + def encode(self) -> bytes: data = asdict(self) # For remote data, the data is not included in the JSON if self.source == DataSource.REMOTE: @@ -272,6 +254,6 @@ def to_json(self) -> bytes: return bson.dumps(data) @classmethod - def from_json(cls, json_data: bytes) -> "DataExpr": + def decode(cls, json_data: bytes) -> "DataExpr": data = bson.loads(json_data) return cls(**data) \ No newline at end of file