Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions e2e/src/e2e/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
limitations under the License.
"""

import flamepy
from typing import Optional
from dataclasses import dataclass

class TestContext(flamepy.Request):
@dataclass
class TestContext:
common_data: Optional[str] = None

class TestRequest(flamepy.Request):
@dataclass
class TestRequest:
input: Optional[str] = None

class TestResponse(flamepy.Response):
@dataclass
class TestResponse:
output: Optional[str] = None
common_data: Optional[str] = None
15 changes: 5 additions & 10 deletions e2e/src/e2e/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,16 @@

import flamepy

from e2e import TestRequest, TestResponse, TestContext
from e2e.api import TestRequest, TestResponse

instance = flamepy.FlameInstance()

sys_context = None

@instance.entrypoint
def e2e_service_entrypoint(req: TestRequest) -> TestResponse:
return TestResponse(output=req.input, common_data=sys_context)
cxt = instance.context()
data = cxt.common_data if cxt is not None else None

@instance.context
def e2e_service_context(ctx: TestContext = None):
global sys_context
if ctx is not None:
sys_context = ctx.common_data
return TestResponse(output=req.input, common_data=data)

if __name__ == "__main__":
instance.run()
instance.run()
9 changes: 3 additions & 6 deletions e2e/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __init__(self, expected_output):
def on_update(self, task):
self.latest_state = task.state
if task.state == flamepy.TaskState.SUCCEED:
resp = TestResponse.from_json(task.output)
assert resp.output == self.expected_output
assert task.output.output == self.expected_output, f"Task output: {task.output.output}, Expected: {self.expected_output}"
elif task.state == flamepy.TaskState.FAILED:
for event in task.events:
if event.code == flamepy.TaskState.FAILED:
Expand Down Expand Up @@ -99,8 +98,7 @@ async def test_invoke_task_without_common_data():

input = random_string()

resp = await session.invoke(TestRequest(input=input))
output = TestResponse.from_json(resp)
output = await session.invoke(TestRequest(input=input))
assert output.output == input
assert output.common_data is None

Expand All @@ -122,8 +120,7 @@ async def test_invoke_task_with_common_data():
assert ssn_list[0].application == FLM_TEST_APP
assert ssn_list[0].state == SessionState.OPEN

resp = await session.invoke(TestRequest(input=input))
output = TestResponse.from_json(resp)
output = await session.invoke(TestRequest(input=input))
assert output.output == input
assert output.common_data == sys_context

Expand Down
7 changes: 0 additions & 7 deletions sdk/python/example/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ class Summary(flamepy.Response):
You are a helpful assistant.
"""


@ins.context
def sys_context(sp: SysPrompt):
global sys_prompt
sys_prompt = sp.prompt


@ins.entrypoint
def summarize_blog(bl: Blog) -> Summary:
global sys_prompt
Expand Down
4 changes: 0 additions & 4 deletions sdk/python/src/flamepy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@
Application,
FlameContext,
TaskInformer,
Request,
Response,
)

from .client import (
Expand Down Expand Up @@ -113,8 +111,6 @@
"Application",
"FlameContext",
"TaskInformer",
"Request",
"Response",
# Client classes
"Connection",
"connect",
Expand Down
9 changes: 8 additions & 1 deletion sdk/python/src/flamepy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class Object(BaseModel):
"""Object."""

version: int
data: list
data: list

class ObjectMetadata(BaseModel):
"""Object metadata."""
Expand All @@ -45,8 +45,12 @@ async def put_object(session_id: str, data: bytes) -> "DataExpr":

return DataExpr(source=DataSource.REMOTE, url=metadata.endpoint, data=data, version=metadata.version)


async def get_object(de: DataExpr) -> "DataExpr":
"""Get an object from the cache."""
if de is None:
return None

if de.source != DataSource.REMOTE:
return de

Expand All @@ -63,6 +67,9 @@ async def get_object(de: DataExpr) -> "DataExpr":

async def update_object(de: DataExpr) -> "DataExpr":
"""Update an object in the cache."""
if de is None:
return None

if de.source != DataSource.REMOTE:
return de

Expand Down
Loading
Loading