Skip to content

Commit 75801ac

Browse files
authored
Allow custom entity names (#105)
* Allow custom entity names
1 parent a491469 commit 75801ac

File tree

5 files changed

+28
-21
lines changed

5 files changed

+28
-21
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
CHANGED
1111

12+
- Allow entities with custom names
1213
- Add/update type-hinting for various worker methods
1314

1415
## v1.2.0

durabletask/task.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,14 @@ def retry_timeout(self) -> Optional[timedelta]:
615615
return self._retry_timeout
616616

617617

618+
def get_entity_name(fn: Entity) -> str:
619+
if hasattr(fn, "__durable_entity_name__"):
620+
return getattr(fn, "__durable_entity_name__")
621+
if isinstance(fn, type) and issubclass(fn, DurableEntity):
622+
return fn.__name__
623+
return get_name(fn)
624+
625+
618626
def get_name(fn: Callable) -> str:
619627
"""Returns the name of the provided function"""
620628
name = fn.__name__

durabletask/worker.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,16 +188,14 @@ def add_named_activity(self, name: str, fn: task.Activity[TInput, TOutput]) -> N
188188
def get_activity(self, name: str) -> Optional[task.Activity[Any, Any]]:
189189
return self.activities.get(name)
190190

191-
def add_entity(self, fn: task.Entity) -> str:
191+
def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
192192
if fn is None:
193193
raise ValueError("An entity function argument is required.")
194194

195-
if isinstance(fn, type) and issubclass(fn, DurableEntity):
196-
name = fn.__name__
197-
self.add_named_entity(name, fn)
198-
else:
199-
name = task.get_name(fn)
200-
self.add_named_entity(name, fn)
195+
if name is None:
196+
name = task.get_entity_name(fn)
197+
198+
self.add_named_entity(name, fn)
201199
return name
202200

203201
def add_named_entity(self, name: str, fn: task.Entity) -> None:
@@ -378,13 +376,13 @@ def add_activity(self, fn: task.Activity) -> str:
378376
)
379377
return self._registry.add_activity(fn)
380378

381-
def add_entity(self, fn: task.Entity) -> str:
379+
def add_entity(self, fn: task.Entity, name: Optional[str] = None) -> str:
382380
"""Registers an entity function with the worker."""
383381
if self._is_running:
384382
raise RuntimeError(
385383
"Entities cannot be added while the worker is running."
386384
)
387-
return self._registry.add_entity(fn)
385+
return self._registry.add_entity(fn, name)
388386

389387
def use_versioning(self, version: VersioningOptions) -> None:
390388
"""Initializes versioning options for sub-orchestrators and activities."""

tests/durabletask-azuremanaged/test_dts_class_based_entities_e2e.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
1818

1919

20-
def test_client_signal_class_entity():
20+
def test_client_signal_class_entity_and_custom_name():
2121
invoked = False
2222

2323
class EmptyEntity(entities.DurableEntity):
@@ -28,12 +28,12 @@ def do_nothing(self, _):
2828
# Start a worker, which will connect to the sidecar in a background thread
2929
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
3030
taskhub=taskhub_name, token_credential=None) as w:
31-
w.add_entity(EmptyEntity)
31+
w.add_entity(EmptyEntity, name="EntityNameCustom")
3232
w.start()
3333

3434
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
3535
taskhub=taskhub_name, token_credential=None)
36-
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
36+
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
3737
c.signal_entity(entity_id, "do_nothing")
3838
time.sleep(2) # wait for the signal to be processed
3939

@@ -70,7 +70,7 @@ def do_nothing(self, _):
7070
assert invoked
7171

7272

73-
def test_orchestration_signal_class_entity():
73+
def test_orchestration_signal_class_entity_and_custom_name():
7474
invoked = False
7575

7676
class EmptyEntity(entities.DurableEntity):
@@ -79,14 +79,14 @@ def do_nothing(self, _):
7979
invoked = True
8080

8181
def empty_orchestrator(ctx: task.OrchestrationContext, _):
82-
entity_id = entities.EntityInstanceId("EmptyEntity", "testEntity")
82+
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
8383
ctx.signal_entity(entity_id, "do_nothing")
8484

8585
# Start a worker, which will connect to the sidecar in a background thread
8686
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
8787
taskhub=taskhub_name, token_credential=None) as w:
8888
w.add_orchestrator(empty_orchestrator)
89-
w.add_entity(EmptyEntity)
89+
w.add_entity(EmptyEntity, name="EntityNameCustom")
9090
w.start()
9191

9292
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,

tests/durabletask-azuremanaged/test_dts_function_based_entities_e2e.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
endpoint = os.getenv("ENDPOINT", "http://localhost:8080")
1818

1919

20-
def test_client_signal_entity():
20+
def test_client_signal_entity_and_custom_name():
2121
invoked = False
2222

2323
def empty_entity(ctx: entities.EntityContext, _):
@@ -28,12 +28,12 @@ def empty_entity(ctx: entities.EntityContext, _):
2828
# Start a worker, which will connect to the sidecar in a background thread
2929
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
3030
taskhub=taskhub_name, token_credential=None) as w:
31-
w.add_entity(empty_entity)
31+
w.add_entity(empty_entity, name="EntityNameCustom")
3232
w.start()
3333

3434
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
3535
taskhub=taskhub_name, token_credential=None)
36-
entity_id = entities.EntityInstanceId("empty_entity", "testEntity")
36+
entity_id = entities.EntityInstanceId("EntityNameCustom", "testEntity")
3737
c.signal_entity(entity_id, "do_nothing")
3838
time.sleep(2) # wait for the signal to be processed
3939

@@ -70,7 +70,7 @@ def empty_entity(ctx: entities.EntityContext, _):
7070
assert invoked
7171

7272

73-
def test_orchestration_signal_entity():
73+
def test_orchestration_signal_entity_and_custom_name():
7474
invoked = False
7575

7676
def empty_entity(ctx: entities.EntityContext, _):
@@ -79,14 +79,14 @@ def empty_entity(ctx: entities.EntityContext, _):
7979
invoked = True
8080

8181
def empty_orchestrator(ctx: task.OrchestrationContext, _):
82-
entity_id = entities.EntityInstanceId("empty_entity", f"{ctx.instance_id}_testEntity")
82+
entity_id = entities.EntityInstanceId("EntityNameCustom", f"{ctx.instance_id}_testEntity")
8383
ctx.signal_entity(entity_id, "do_nothing")
8484

8585
# Start a worker, which will connect to the sidecar in a background thread
8686
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
8787
taskhub=taskhub_name, token_credential=None) as w:
8888
w.add_orchestrator(empty_orchestrator)
89-
w.add_entity(empty_entity)
89+
w.add_entity(empty_entity, name="EntityNameCustom")
9090
w.start()
9191

9292
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,

0 commit comments

Comments
 (0)