From a7f00f0623f882251e7836aa76dbbb9e5135719f Mon Sep 17 00:00:00 2001 From: Olivier Desenfans Date: Mon, 12 Jan 2026 23:26:54 +0100 Subject: [PATCH] feature: list, create and revoke authorizations Added methods to `AlephClient` to interact with the authorization system of Aleph Cloud. This enables to delegate specific operations to addresses other than your own, whether it is to enable other people to perform specific actions on your behalf or simply avoid signing messages frequently from your main token-holding wallet. --- src/aleph/sdk/client/abstract.py | 58 ++- src/aleph/sdk/types.py | 71 ++- tests/unit/services/test_authorizations.py | 562 +++++++++++++++++++++ 3 files changed, 688 insertions(+), 3 deletions(-) create mode 100644 tests/unit/services/test_authorizations.py diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index fea2db3f..894717a2 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -38,7 +38,7 @@ from typing_extensions import deprecated from aleph.sdk.conf import settings -from aleph.sdk.types import Account +from aleph.sdk.types import Account, Authorization, SecurityAggregateContent from aleph.sdk.utils import extended_json_encoder from ..query.filters import MessageFilter, PostFilter @@ -295,6 +295,30 @@ def get_program_price( """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + async def get_authorizations(self, address: str) -> list[Authorization]: + """ + Retrieves all authorizations for a specific address. + """ + # TODO: update this implementation to use `get_aggregate()` once + # https://github.com/aleph-im/aleph-sdk-python/pull/273 is merged. + # There's currently no way to detect a nonexistent aggregate in generic code just yet. + # fetch_aggregate() throws an implementation-specific ClientResponseError in case of 404. + import aiohttp + + try: + security_aggregate_dict = await self.fetch_aggregate( + address=address, key="security" + ) + except aiohttp.ClientResponseError as e: + if e.status == 404: + return [] + raise + + security_aggregate = SecurityAggregateContent.model_validate( + security_aggregate_dict + ) + return security_aggregate.authorizations + class AuthenticatedAlephClient(AlephClient): account: Account @@ -617,3 +641,35 @@ async def storage_push(self, content: Mapping) -> str: :param content: The dict-like content to upload """ raise NotImplementedError() + + async def update_all_authorizations(self, authorizations: list[Authorization]): + """ + Updates all authorizations for the current account. + Danger! This will replace all authorizations for the account. Use with care. + + :param authorizations: List of authorizations to set. These authorizations will replace the existing ones. + """ + security_aggregate = SecurityAggregateContent(authorizations=authorizations) + await self.create_aggregate( + key="security", content=security_aggregate.model_dump() + ) + + async def add_authorization(self, authorization: Authorization): + """ + Adds a specific authorization for the current account. + """ + authorizations = await self.get_authorizations(self.account.get_address()) + authorizations.append(authorization) + await self.update_all_authorizations(authorizations) + + async def revoke_all_authorizations(self, address: str): + """ + Revokes all authorizations for a specific address. + """ + authorizations = await self.get_authorizations(self.account.get_address()) + filtered_authorizations = [ + authorization + for authorization in authorizations + if authorization.address != address + ] + await self.update_all_authorizations(filtered_authorizations) diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index ed2524cb..1bbe66d1 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -14,7 +14,7 @@ Union, ) -from aleph_message.models import ItemHash +from aleph_message.models import ItemHash, MessageType from pydantic import ( BaseModel, ConfigDict, @@ -24,9 +24,11 @@ TypeAdapter, field_validator, ) -from typing_extensions import runtime_checkable +from typing_extensions import Self, runtime_checkable __all__ = ( + "Authorization", + "AuthorizationBuilder", "StorageEnum", "Account", "AccountFromPrivateKey", @@ -406,3 +408,68 @@ class VmResources(BaseModel): vcpus: PositiveInt memory: PositiveInt disk_mib: PositiveInt + + +class Authorization(BaseModel): + """A single authorization entry for delegated access.""" + + address: str + chain: Optional[Chain] = None + channels: list[str] = [] + types: list[MessageType] = [] + post_types: list[str] = [] + aggregate_keys: list[str] = [] + + +class AuthorizationBuilder: + def __init__(self, address: str): + self._address: str = address + self._chain: Optional[Chain] = None + self._channels: list[str] = [] + self._message_types: list[MessageType] = [] + self._post_types: list[str] = [] + self._aggregate_keys: list[str] = [] + + def chain(self, chain: Chain) -> Self: + self._chain = chain + return self + + def channel(self, channel: str) -> Self: + self._channels.append(channel) + return self + + def message_type(self, message_type: MessageType) -> Self: + self._message_types.append(message_type) + return self + + def post_type(self, post_type: str) -> Self: + if MessageType.post not in self._message_types: + raise ValueError( + "Cannot set post_type without allowing POST message type first" + ) + self._post_types.append(post_type) + return self + + def aggregate_key(self, aggregate_key: str) -> Self: + if MessageType.aggregate not in self._message_types: + raise ValueError( + "Cannot set post_type without allowing AGGREGATE message type first" + ) + self._aggregate_keys.append(aggregate_key) + return self + + def build(self) -> Authorization: + return Authorization( + address=self._address, + chain=self._chain, + channels=self._channels, + types=self._message_types, + post_types=self._post_types, + aggregate_keys=self._aggregate_keys, + ) + + +class SecurityAggregateContent(BaseModel): + """Content schema for the 'security' aggregate.""" + + authorizations: list[Authorization] = [] diff --git a/tests/unit/services/test_authorizations.py b/tests/unit/services/test_authorizations.py new file mode 100644 index 00000000..7ab2b7ee --- /dev/null +++ b/tests/unit/services/test_authorizations.py @@ -0,0 +1,562 @@ +""" +Tests for authorization methods in AlephClient. +""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import pytest +from aleph_message.models import AggregateMessage, Chain, MessageType +from aleph_message.status import MessageStatus + +from aleph.sdk.client.abstract import AuthenticatedAlephClient +from aleph.sdk.types import ( + Account, + Authorization, + AuthorizationBuilder, + SecurityAggregateContent, +) + + +class FakeAccount: + """Minimal fake account for testing.""" + + CHAIN = "ETH" + CURVE = "secp256k1" + + def __init__(self, address: str = "0xTestAddress1234567890123456789012345678"): + self._address = address + + async def sign_message(self, message: Dict) -> Dict: + message["signature"] = "0x" + "ab" * 65 + return message + + async def sign_raw(self, buffer: bytes) -> bytes: + return b"fake_signature" + + def get_address(self) -> str: + return self._address + + def get_public_key(self) -> str: + return "0x" + "cd" * 33 + + +class MockAlephClient(AuthenticatedAlephClient): + """ + A fake authenticated client that maintains an in-memory aggregate store. + Aggregates are dictionaries that get merged/updated with each create_aggregate call. + """ + + def __init__(self, account: Optional[Account] = None): + self.account = account or FakeAccount() + # Storage: {address: {key: content}} + self._aggregates: Dict[str, Dict[str, Any]] = {} + + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Any]: + """Fetch a single aggregate by address and key.""" + if address not in self._aggregates: + return {"authorizations": []} + return self._aggregates[address].get(key, {"authorizations": []}) + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Dict[str, Dict]: + """Fetch multiple aggregates.""" + if address not in self._aggregates: + return {} + if keys is None: + return self._aggregates[address] + return {k: v for k, v in self._aggregates[address].items() if k in keys} + + async def create_aggregate( + self, + key: str, + content: Dict[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create/update an aggregate. Merges content into existing aggregate. + """ + address = address or self.account.get_address() + + if address not in self._aggregates: + self._aggregates[address] = {} + + # Aggregates merge content (like a dict update) + if key in self._aggregates[address]: + self._aggregates[address][key].update(content) + else: + self._aggregates[address][key] = content + + # Return a minimal mock message + mock_message = AggregateMessage.model_validate( + { + "item_hash": "44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", + "type": "AGGREGATE", + "chain": "ETH", + "sender": address, + "signature": "0x" + "ab" * 65, + "item_type": "inline", + "item_content": "{}", + "content": { + "key": key, + "address": address, + "content": content, + "time": 0, + }, + "time": 0, + "channel": channel or "TEST", + } + ) + return mock_message, MessageStatus.PROCESSED + + # Stub implementations for abstract methods we don't need + async def create_post(self, *args, **kwargs): + raise NotImplementedError + + async def create_store(self, *args, **kwargs): + raise NotImplementedError + + async def create_program(self, *args, **kwargs): + raise NotImplementedError + + async def create_instance(self, *args, **kwargs): + raise NotImplementedError + + async def forget(self, *args, **kwargs): + raise NotImplementedError + + async def submit(self, *args, **kwargs): + raise NotImplementedError + + async def get_posts(self, *args, **kwargs): + raise NotImplementedError + + async def download_file(self, *args, **kwargs): + raise NotImplementedError + + async def download_file_to_path(self, *args, **kwargs): + raise NotImplementedError + + async def get_messages(self, *args, **kwargs): + raise NotImplementedError + + async def get_message(self, *args, **kwargs): + raise NotImplementedError + + def watch_messages(self, *args, **kwargs): + raise NotImplementedError + + def get_estimated_price(self, *args, **kwargs): + raise NotImplementedError + + def get_program_price(self, *args, **kwargs): + raise NotImplementedError + + +# Fixtures +@pytest.fixture +def mock_client() -> MockAlephClient: + """Create a fresh fake client for each test.""" + return MockAlephClient() + + +@pytest.fixture +def mock_client_with_existing_auth() -> MockAlephClient: + """Create a fake client with pre-existing authorizations.""" + client = MockAlephClient() + client._aggregates[client.account.get_address()] = { + "security": { + "authorizations": [ + { + "address": "0xExistingAddress123456789012345678901234", + "chain": "ETH", + "channels": ["existing_channel"], + "types": ["POST"], + "post_types": [], + "aggregate_keys": [], + } + ] + } + } + return client + + +# Tests for get_authorizations +class TestGetAuthorizations: + @pytest.mark.asyncio + async def test_get_authorizations_empty(self, mock_client: MockAlephClient): + """When no authorizations exist, returns empty list.""" + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert authorizations == [] + + @pytest.mark.asyncio + async def test_get_authorizations_returns_existing( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Returns existing authorizations from aggregate store.""" + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + + assert len(authorizations) == 1 + assert authorizations[0].address == "0xExistingAddress123456789012345678901234" + assert authorizations[0].chain == Chain.ETH + assert authorizations[0].channels == ["existing_channel"] + + +# Tests for update_all_authorizations +class TestUpdateAllAuthorizations: + @pytest.mark.asyncio + async def test_update_replaces_all_authorizations( + self, mock_client: MockAlephClient + ): + """update_all_authorizations replaces the entire authorization list.""" + auth1 = Authorization(address="0xAddress1111111111111111111111111111111111") + auth2 = Authorization(address="0xAddress2222222222222222222222222222222222") + + await mock_client.update_all_authorizations([auth1, auth2]) + + # Verify stored content + stored = mock_client._aggregates[mock_client.account.get_address()]["security"] + assert len(stored["authorizations"]) == 2 + + @pytest.mark.asyncio + async def test_update_with_empty_list_clears_authorizations( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Passing an empty list removes all authorizations.""" + await mock_client_with_existing_auth.update_all_authorizations([]) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert authorizations == [] + + @pytest.mark.asyncio + async def test_update_preserves_authorization_fields( + self, mock_client: MockAlephClient + ): + """All authorization fields are preserved when storing.""" + auth = Authorization( + address="0xFullAuth111111111111111111111111111111111", + chain=Chain.ETH, + channels=["channel1", "channel2"], + types=[MessageType.post, MessageType.aggregate], + post_types=["blog", "comment"], + aggregate_keys=["settings"], + ) + + await mock_client.update_all_authorizations([auth]) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + retrieved = authorizations[0] + assert retrieved.address == auth.address + assert retrieved.chain == Chain.ETH + assert retrieved.channels == ["channel1", "channel2"] + assert MessageType.post in retrieved.types + assert "blog" in retrieved.post_types + + +# Tests for add_authorization +class TestAddAuthorization: + @pytest.mark.asyncio + async def test_add_to_empty(self, mock_client: MockAlephClient): + """Adding authorization when none exist.""" + auth = Authorization(address="0xNewAddress1111111111111111111111111111111") + + await mock_client.add_authorization(auth) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert ( + authorizations[0].address == "0xNewAddress1111111111111111111111111111111" + ) + + @pytest.mark.asyncio + async def test_add_appends_to_existing( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Adding authorization appends to existing list.""" + new_auth = Authorization( + address="0xNewAddress2222222222222222222222222222222", + channels=["new_channel"], + ) + + await mock_client_with_existing_auth.add_authorization(new_auth) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert len(authorizations) == 2 + addresses = [a.address for a in authorizations] + assert "0xExistingAddress123456789012345678901234" in addresses + assert "0xNewAddress2222222222222222222222222222222" in addresses + + @pytest.mark.asyncio + async def test_add_multiple_authorizations_sequentially( + self, mock_client: MockAlephClient + ): + """Adding multiple authorizations one by one.""" + auth1 = Authorization(address="0xFirst11111111111111111111111111111111111") + auth2 = Authorization(address="0xSecond2222222222222222222222222222222222") + auth3 = Authorization(address="0xThird33333333333333333333333333333333333") + + await mock_client.add_authorization(auth1) + await mock_client.add_authorization(auth2) + await mock_client.add_authorization(auth3) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 3 + + +# Tests for revoke_all_authorizations +class TestRevokeAllAuthorizations: + @pytest.mark.asyncio + async def test_revoke_removes_matching_address( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Revoking removes all authorizations for the specified address.""" + await mock_client_with_existing_auth.revoke_all_authorizations( + "0xExistingAddress123456789012345678901234" + ) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert len(authorizations) == 0 + + @pytest.mark.asyncio + async def test_revoke_keeps_other_addresses(self, mock_client: MockAlephClient): + """Revoking only removes authorizations for the specified address.""" + auth1 = Authorization(address="0xToRevoke111111111111111111111111111111111") + auth2 = Authorization(address="0xToKeep22222222222222222222222222222222222") + auth3 = Authorization( + address="0xToRevoke111111111111111111111111111111111" + ) # Duplicate + + await mock_client.update_all_authorizations([auth1, auth2, auth3]) + + await mock_client.revoke_all_authorizations( + "0xToRevoke111111111111111111111111111111111" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert ( + authorizations[0].address == "0xToKeep22222222222222222222222222222222222" + ) + + @pytest.mark.asyncio + async def test_revoke_nonexistent_address_is_noop( + self, mock_client: MockAlephClient + ): + """Revoking an address that doesn't exist does nothing.""" + auth = Authorization(address="0xExisting1111111111111111111111111111111111") + await mock_client.add_authorization(auth) + + await mock_client.revoke_all_authorizations( + "0xNonExistent22222222222222222222222222222" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + + @pytest.mark.asyncio + async def test_revoke_from_empty_is_noop(self, mock_client: MockAlephClient): + """Revoking when no authorizations exist doesn't error.""" + await mock_client.revoke_all_authorizations( + "0xAnyAddress111111111111111111111111111111111" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert authorizations == [] + + +# Integration tests - full workflows +class TestAuthorizationWorkflows: + @pytest.mark.asyncio + async def test_full_lifecycle(self, mock_client: MockAlephClient): + """Test complete authorization lifecycle: add, verify, revoke.""" + delegate_address = "0xDelegate111111111111111111111111111111111" + + # Initially empty + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 0 + + # Add authorization + auth = Authorization( + address=delegate_address, + channels=["MY_APP"], + types=[MessageType.post], + ) + await mock_client.add_authorization(auth) + + # Verify it exists + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert authorizations[0].address == delegate_address + assert "MY_APP" in authorizations[0].channels + + # Revoke + await mock_client.revoke_all_authorizations(delegate_address) + + # Verify it's gone + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 0 + + @pytest.mark.asyncio + async def test_multiple_delegates_workflow(self, mock_client: MockAlephClient): + """Test managing authorizations for multiple delegate addresses.""" + delegate1 = "0xDelegate1111111111111111111111111111111111" + delegate2 = "0xDelegate2222222222222222222222222222222222" + + # Add two delegates + await mock_client.add_authorization( + Authorization(address=delegate1, channels=["channel_a"]) + ) + await mock_client.add_authorization( + Authorization(address=delegate2, channels=["channel_b"]) + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 2 + + # Revoke first delegate + await mock_client.revoke_all_authorizations(delegate1) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert authorizations[0].address == delegate2 + + @pytest.mark.asyncio + async def test_replace_all_authorizations(self, mock_client: MockAlephClient): + """Test replacing all authorizations at once.""" + # Add initial authorizations + await mock_client.add_authorization( + Authorization(address="0xOld111111111111111111111111111111111111111") + ) + await mock_client.add_authorization( + Authorization(address="0xOld222222222222222222222222222222222222222") + ) + + # Replace with new set + new_auths = [ + Authorization(address="0xNew111111111111111111111111111111111111111"), + Authorization(address="0xNew222222222222222222222222222222222222222"), + Authorization(address="0xNew333333333333333333333333333333333333333"), + ] + await mock_client.update_all_authorizations(new_auths) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 3 + addresses = {a.address for a in authorizations} + assert "0xOld111111111111111111111111111111111111111" not in addresses + assert "0xNew111111111111111111111111111111111111111" in addresses + + +# Model tests +class TestAuthorizationModel: + def test_minimal_authorization(self): + """Authorization can be created with just an address.""" + auth = Authorization(address="0x1234567890123456789012345678901234567890") + assert auth.address == "0x1234567890123456789012345678901234567890" + assert auth.chain is None + assert auth.channels == [] + assert auth.types == [] + + def test_full_authorization(self): + """Authorization with all fields set.""" + auth = Authorization( + address="0x1234567890123456789012345678901234567890", + chain=Chain.ETH, + channels=["ch1", "ch2"], + types=[MessageType.post, MessageType.store], + post_types=["blog"], + aggregate_keys=["settings"], + ) + assert auth.chain == Chain.ETH + assert len(auth.channels) == 2 + assert len(auth.types) == 2 + + def test_security_aggregate_serialization(self): + """SecurityAggregateContent serializes correctly.""" + auth = Authorization( + address="0x1234567890123456789012345678901234567890", + channels=["test"], + ) + content = SecurityAggregateContent(authorizations=[auth]) + dumped = content.model_dump() + + assert "authorizations" in dumped + assert len(dumped["authorizations"]) == 1 + assert dumped["authorizations"][0]["address"] == auth.address + + +class TestAuthorizationBuilder: + def test_authorization_builder_only_address(self): + """Test the AuthorizationBuilder.""" + auth = AuthorizationBuilder( + address="0x1234567890123456789012345678901234567890" + ).build() + assert auth.address == "0x1234567890123456789012345678901234567890" + assert auth.chain is None + assert auth.channels == [] + assert auth.types == [] + assert auth.post_types == [] + assert auth.aggregate_keys == [] + + def test_authorization_builder(self): + """Test the AuthorizationBuilder with a detailed configuration.""" + sample_authorization = Authorization( + address="0xFullAuth111111111111111111111111111111111", + chain=Chain.ETH, + channels=["channel1", "channel2"], + types=[MessageType.post, MessageType.aggregate], + post_types=["blog", "comment"], + aggregate_keys=["settings"], + ) + + auth = AuthorizationBuilder(address=sample_authorization.address).chain( + sample_authorization.chain + ) + for channel in sample_authorization.channels: + auth = auth.channel(channel) + for message_type in sample_authorization.types: + auth = auth.message_type(message_type) + for post_type in sample_authorization.post_types: + auth = auth.post_type(post_type) + for aggregate_key in sample_authorization.aggregate_keys: + auth = auth.aggregate_key(aggregate_key) + auth = auth.build() + + assert auth == sample_authorization