diff --git a/sdk/python/src/flamepy/cache.py b/sdk/python/src/flamepy/cache.py index 79de567..3b7a287 100644 --- a/sdk/python/src/flamepy/cache.py +++ b/sdk/python/src/flamepy/cache.py @@ -45,6 +45,7 @@ async def put_object(session_id: str, data: bytes) -> "DataExpr": return DataExpr(source=DataSource.REMOTE, url=metadata.endpoint, data=data, version=metadata.version) + async def get_object(de: DataExpr) -> "DataExpr": """Get an object from the cache.""" if de.source != DataSource.REMOTE: @@ -61,6 +62,7 @@ async def get_object(de: DataExpr) -> "DataExpr": return de + async def update_object(de: DataExpr) -> "DataExpr": """Update an object in the cache.""" if de.source != DataSource.REMOTE: diff --git a/sdk/python/src/flamepy/client.py b/sdk/python/src/flamepy/client.py index 347edb5..08c3b16 100644 --- a/sdk/python/src/flamepy/client.py +++ b/sdk/python/src/flamepy/client.py @@ -69,10 +69,7 @@ async def connect(addr: str) -> "Connection": return await Connection.connect(addr) -async def create_session(application: str, - common_data: Dict[str, Any] = None, - session_id: str = None, - slots: int = 1) -> "Session": +async def create_session(application: str, common_data: Dict[str, Any] = None, session_id: str = None, slots: int = 1) -> "Session": conn = await ConnectionInstance.instance() @@ -81,19 +78,14 @@ async def create_session(application: str, elif isinstance(common_data, FlameRequest): common_data = common_data.to_json() elif not isinstance(common_data, CommonData): - raise ValueError( - "Invalid common data type, must be a Request or CommonData") + raise ValueError("Invalid common data type, must be a Request or CommonData") session_id = short_name(application) if session_id is None else session_id data_expr = await put_object(session_id, common_data) common_data = CommonData(data_expr.to_json()) - session = await conn.create_session( - SessionAttributes(id=session_id, - application=application, - common_data=common_data, - slots=slots)) + session = await conn.create_session(SessionAttributes(id=session_id, application=application, common_data=common_data, slots=slots)) return session @@ -102,8 +94,7 @@ async def open_session(session_id: SessionID) -> "Session": return await conn.open_session(session_id) -async def register_application( - name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None: +async def register_application(name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None: conn = await ConnectionInstance.instance() await conn.register_application(name, app_attrs) @@ -201,17 +192,13 @@ async def connect(cls, addr: str) -> "Connection": return cls(addr, channel, frontend) except Exception as e: - raise FlameError( - FlameErrorCode.INVALID_CONFIG, f"failed to connect to {addr}: {str(e)}" - ) + raise FlameError(FlameErrorCode.INVALID_CONFIG, f"failed to connect to {addr}: {str(e)}") async def close(self) -> None: """Close the connection.""" await self._channel.close() - async def register_application( - self, name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]] - ) -> None: + async def register_application(self, name: str, app_attrs: Union[ApplicationAttributes, Dict[str, Any]]) -> None: """Register a new application.""" if isinstance(app_attrs, dict): app_attrs = ApplicationAttributes(**app_attrs) @@ -292,9 +279,7 @@ async def list_applications(self) -> List[Application]: name=app.metadata.name, shim=Shim(app.spec.shim), state=ApplicationState(app.status.state), - creation_time=datetime.fromtimestamp( - app.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(app.status.creation_time / 1000, tz=timezone.utc), image=app.spec.image, command=app.spec.command, arguments=list(app.spec.arguments), @@ -309,9 +294,7 @@ async def list_applications(self) -> List[Application]: return applications except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to list applications: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to list applications: {e.details()}") async def get_application(self, name: str) -> Application: """Get an application by name.""" @@ -337,9 +320,7 @@ async def get_application(self, name: str) -> Application: name=response.metadata.name, shim=Shim(response.spec.shim), state=ApplicationState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), image=response.spec.image, command=response.spec.command, arguments=list(response.spec.arguments), @@ -351,9 +332,7 @@ async def get_application(self, name: str) -> Application: ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to get application: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to get application: {e.details()}") async def create_session(self, attrs: SessionAttributes) -> "Session": """Create a new session.""" @@ -374,26 +353,16 @@ async def create_session(self, attrs: SessionAttributes) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) return session except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to create session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to create session: {e.details()}") async def list_sessions(self) -> List["Session"]: """List all sessions.""" @@ -411,29 +380,19 @@ async def list_sessions(self) -> List["Session"]: application=session.spec.application, slots=session.spec.slots, state=SessionState(session.status.state), - creation_time=datetime.fromtimestamp( - session.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(session.status.creation_time / 1000, tz=timezone.utc), pending=session.status.pending, running=session.status.running, succeed=session.status.succeed, failed=session.status.failed, - completion_time=( - datetime.fromtimestamp( - session.status.completion_time / 1000, tz=timezone.utc - ) - if session.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(session.status.completion_time / 1000, tz=timezone.utc) if session.status.HasField("completion_time") else None), ) ) return sessions except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to list sessions: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to list sessions: {e.details()}") async def open_session(self, session_id: SessionID) -> "Session": """Open a session.""" @@ -447,26 +406,16 @@ async def open_session(self, session_id: SessionID) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to open session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to open session: {e.details()}") async def get_session(self, session_id: SessionID) -> "Session": """Get a session by ID.""" @@ -481,26 +430,16 @@ async def get_session(self, session_id: SessionID) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to get session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to get session: {e.details()}") async def close_session(self, session_id: SessionID) -> "Session": """Close a session.""" @@ -515,26 +454,16 @@ async def close_session(self, session_id: SessionID) -> "Session": application=response.spec.application, slots=response.spec.slots, state=SessionState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), pending=response.status.pending, running=response.status.running, succeed=response.status.succeed, failed=response.status.failed, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to close session: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to close session: {e.details()}") class Session: @@ -594,33 +523,21 @@ async def create_task(self, input_data: TaskInput) -> Task: id=response.metadata.id, session_id=self.id, state=TaskState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), input=input_data, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), events=[ Event( code=event.code, message=event.message, - creation_time=datetime.fromtimestamp( - event.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc), ) for event in response.status.events ], ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to create task: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to create task: {e.details()}") async def get_task(self, task_id: TaskID) -> Task: """Get a task by ID.""" @@ -633,34 +550,22 @@ async def get_task(self, task_id: TaskID) -> Task: id=response.metadata.id, session_id=self.id, state=TaskState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), input=response.spec.input, output=response.spec.output, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), events=[ Event( code=event.code, message=event.message, - creation_time=datetime.fromtimestamp( - event.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc), ) for event in response.status.events ], ) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to get task: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to get task: {e.details()}") async def watch_task(self, task_id: TaskID) -> "TaskWatcher": """Watch a task for updates.""" @@ -671,13 +576,9 @@ async def watch_task(self, task_id: TaskID) -> "TaskWatcher": return TaskWatcher(stream) except grpc.RpcError as e: - raise FlameError( - FlameErrorCode.INTERNAL, f"failed to watch task: {e.details()}" - ) + raise FlameError(FlameErrorCode.INTERNAL, f"failed to watch task: {e.details()}") - async def invoke( - self, input_data, informer: Optional[TaskInformer] = None - ) -> TaskOutput: + async def invoke(self, input_data, informer: Optional[TaskInformer] = None) -> TaskOutput: """Invoke a task with the given input and optional informer.""" if input_data is None: pass @@ -729,25 +630,15 @@ async def __anext__(self) -> Task: id=response.metadata.id, session_id=response.spec.session_id, state=TaskState(response.status.state), - creation_time=datetime.fromtimestamp( - response.status.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(response.status.creation_time / 1000, tz=timezone.utc), input=response.spec.input, output=response.spec.output, - completion_time=( - datetime.fromtimestamp( - response.status.completion_time / 1000, tz=timezone.utc - ) - if response.status.HasField("completion_time") - else None - ), + completion_time=(datetime.fromtimestamp(response.status.completion_time / 1000, tz=timezone.utc) if response.status.HasField("completion_time") else None), events=[ Event( code=event.code, message=event.message, - creation_time=datetime.fromtimestamp( - event.creation_time / 1000, tz=timezone.utc - ), + creation_time=datetime.fromtimestamp(event.creation_time / 1000, tz=timezone.utc), ) for event in response.status.events ], diff --git a/sdk/python/src/flamepy/instance.py b/sdk/python/src/flamepy/instance.py index 8ee05dd..c43aeee 100644 --- a/sdk/python/src/flamepy/instance.py +++ b/sdk/python/src/flamepy/instance.py @@ -45,11 +45,8 @@ def __init__(self): self._entrypoint = None self._parameter = None self._return_type = None - self._input_schema = None - self._output_schema = None self._context = None - self._context_schema = None self._context_parameter = None self._queue = None @@ -60,15 +57,9 @@ def context(self, func): sig = inspect.signature(func) self._context = func - assert ( - len(sig.parameters) == 1 or len(sig.parameters) == 0 - ), "Context must have exactly zero or one parameter" + assert len(sig.parameters) == 1 or len(sig.parameters) == 0, "Context must have exactly zero or one parameter" for param in sig.parameters.values(): - assert ( - param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ), "Parameter must be positional or keyword" - if param.annotation is not inspect._empty: - self._context_schema = param.annotation.model_json_schema() + assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD, "Parameter must be positional or keyword" self._context_parameter = param def entrypoint(self, func): @@ -77,20 +68,13 @@ def entrypoint(self, func): sig = inspect.signature(func) self._entrypoint = func - assert ( - len(sig.parameters) == 1 or len(sig.parameters) == 0 - ), "Entrypoint must have exactly zero or one parameter" + assert len(sig.parameters) == 1 or len(sig.parameters) == 0, "Entrypoint must have exactly zero or one parameter" for param in sig.parameters.values(): - assert ( - param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD - ), "Parameter must be positional or keyword" - if param.annotation is not inspect._empty: - self._input_schema = param.annotation.model_json_schema() + assert param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD, "Parameter must be positional or keyword" self._parameter = param if sig.return_annotation is not inspect._empty: self._return_type = sig.return_annotation - self._output_schema = self._return_type.model_json_schema() async def on_session_enter(self, context: SessionContext): logger = logging.getLogger(__name__) @@ -109,13 +93,7 @@ async def on_session_enter(self, context: SessionContext): else: self._context() else: - obj = ( - self._context_parameter.annotation.model_validate_json( - context.common_data - ) - if context.common_data is not None - else None - ) + obj = self._context_parameter.annotation.model_validate(context.common_data) if context.common_data is not None else None if inspect.iscoroutinefunction(self._context): await self._context(obj) else: @@ -133,11 +111,7 @@ async def on_task_invoke(self, context: TaskContext) -> TaskOutput: self._queue = context._queue if self._parameter is not None: - obj = ( - self._parameter.annotation.model_validate_json(context.input) - if context.input is not None - else None - ) + obj = self._parameter.annotation.model_validate(context.input) if context.input is not None else None if inspect.iscoroutinefunction(self._entrypoint): res = await self._entrypoint(obj) else: @@ -148,7 +122,7 @@ async def on_task_invoke(self, context: TaskContext) -> TaskOutput: else: res = self._entrypoint() - res = self._return_type.model_validate(res).model_dump_json() + res = self._return_type.model_dump(res) logger.debug(f"on_task_invoke: {res}") self.task_id = None @@ -166,9 +140,7 @@ async def record_event(self, code: int, message: Optional[str] = None): if self._queue is not None: await self._queue.put( WatchEventResponseProto( - owner=EventOwnerProto( - session_id=self.session_id, task_id=self.task_id - ), + owner=EventOwnerProto(session_id=self.session_id, task_id=self.task_id), event=EventProto( code=code, message=message, @@ -208,15 +180,11 @@ def run_debug_service(instance: FlameInstance): if instance._context is not None: context_name = instance._context.__name__ - debug_service.add_api_route( - f"/{context_name}", context_local_api, methods=["POST"] - ) + debug_service.add_api_route(f"/{context_name}", context_local_api, methods=["POST"]) if instance._entrypoint is not None: entrypoint_name = instance._entrypoint.__name__ - debug_service.add_api_route( - f"/{entrypoint_name}", entrypoint_local_api, methods=["POST"] - ) + debug_service.add_api_route(f"/{entrypoint_name}", entrypoint_local_api, methods=["POST"]) uvicorn.run(debug_service, host="0.0.0.0", port=5050) diff --git a/sdk/python/src/flamepy/service.py b/sdk/python/src/flamepy/service.py index b7b54f0..432c0f2 100644 --- a/sdk/python/src/flamepy/service.py +++ b/sdk/python/src/flamepy/service.py @@ -37,8 +37,9 @@ FLAME_INSTANCE_ENDPOINT = "FLAME_INSTANCE_ENDPOINT" + class TraceFn: - def __init__(self, name:str): + def __init__(self, name: str): self.name = name logger.debug(f"{name} Enter") @@ -69,9 +70,7 @@ async def record_event(self, code: int, message: Optional[str] = None): """Record an event.""" event = WatchEventResponseProto( owner=EventOwnerProto(session_id=self.session_id, task_id=None), - event=EventProto( - code=code, message=message, creation_time=int(time.time() * 1000) - ), + event=EventProto(code=code, message=message, creation_time=int(time.time() * 1000)), ) await self._queue.put(event) @@ -89,9 +88,7 @@ async def record_event(self, code: int, message: Optional[str] = None): """Record an event.""" event = WatchEventResponseProto( owner=EventOwnerProto(session_id=self.session_id, task_id=self.task_id), - event=EventProto( - code=code, message=message, creation_time=int(time.time() * 1000) - ), + event=EventProto(code=code, message=message, creation_time=int(time.time() * 1000)), ) await self._queue.put(event) @@ -163,16 +160,8 @@ async def OnSessionEnter(self, request, context): app_context = ApplicationContext( name=request.application.name, shim=Shim(request.application.shim), - image=( - request.application.image - if request.application.HasField("image") - else None - ), - command=( - request.application.command - if request.application.HasField("command") - else None - ), + image=(request.application.image if request.application.HasField("image") else None), + command=(request.application.command if request.application.HasField("command") else None), ) logger.debug(f"app_context: {app_context}") @@ -297,13 +286,9 @@ async def start(self): endpoint = os.getenv(FLAME_INSTANCE_ENDPOINT) if endpoint is not None: self._server.add_insecure_port(f"unix://{endpoint}") - logger.debug( - f"Flame Python instance service started on Unix socket: {endpoint}" - ) + logger.debug(f"Flame Python instance service started on Unix socket: {endpoint}") else: - raise FlameError( - FlameErrorCode.INVALID_CONFIG, "FLAME_INSTANCE_ENDPOINT not found" - ) + raise FlameError(FlameErrorCode.INVALID_CONFIG, "FLAME_INSTANCE_ENDPOINT not found") # Start server await self._server.start() diff --git a/sdk/python/src/flamepy/types.py b/sdk/python/src/flamepy/types.py index 5756b71..92a3d87 100644 --- a/sdk/python/src/flamepy/types.py +++ b/sdk/python/src/flamepy/types.py @@ -191,25 +191,25 @@ def on_error(self, error: FlameError) -> None: class Request(BaseModel): @classmethod def from_json(cls, json_data): - return cls.model_validate_json(json_data.decode("utf-8")) + return cls.model_validate(json_data) def to_json(self) -> bytes: - return self.model_dump_json().encode("utf-8") + return self.model_dump() class Response(BaseModel): @classmethod def from_json(cls, json_data): - return cls.model_validate_json(json_data.decode("utf-8")) + return cls.model_validate(json_data) def to_json(self) -> bytes: - return self.model_dump_json().encode("utf-8") + return self.model_dump() def short_name(prefix: str, length: int = 6) -> str: """Generate a short name with a prefix.""" alphabet = string.ascii_letters + string.digits - sn = ''.join(random.SystemRandom().choice(alphabet) for _ in range(length)) + sn = "".join(random.SystemRandom().choice(alphabet) for _ in range(length)) return f"{prefix}-{sn}" @@ -227,16 +227,14 @@ def __init__(self): config = yaml.safe_load(f) cc = config.get("current-cluster") if cc is None: - raise FlameError(FlameErrorCode.INVALID_CONFIG, - "current-cluster is not set") + raise FlameError(FlameErrorCode.INVALID_CONFIG, "current-cluster is not set") for cluster in config.get("clusters", []): if cc == cluster["name"]: self._endpoint = cluster.get("endpoint") self._cache_endpoint = cluster.get("cache") break else: - raise FlameError(FlameErrorCode.INVALID_CONFIG, - f"cluster <{cc}> not found") + raise FlameError(FlameErrorCode.INVALID_CONFIG, f"cluster <{cc}> not found") endpoint = os.getenv("FLAME_ENDPOINT") if endpoint is not None: @@ -274,4 +272,4 @@ def to_json(self) -> bytes: @classmethod def from_json(cls, json_data: bytes) -> "DataExpr": data = bson.loads(json_data) - return cls(**data) \ No newline at end of file + return cls(**data)