diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index fea2db3f..894717a2 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -38,7 +38,7 @@ from typing_extensions import deprecated from aleph.sdk.conf import settings -from aleph.sdk.types import Account +from aleph.sdk.types import Account, Authorization, SecurityAggregateContent from aleph.sdk.utils import extended_json_encoder from ..query.filters import MessageFilter, PostFilter @@ -295,6 +295,30 @@ def get_program_price( """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + async def get_authorizations(self, address: str) -> list[Authorization]: + """ + Retrieves all authorizations for a specific address. + """ + # TODO: update this implementation to use `get_aggregate()` once + # https://github.com/aleph-im/aleph-sdk-python/pull/273 is merged. + # There's currently no way to detect a nonexistent aggregate in generic code just yet. + # fetch_aggregate() throws an implementation-specific ClientResponseError in case of 404. + import aiohttp + + try: + security_aggregate_dict = await self.fetch_aggregate( + address=address, key="security" + ) + except aiohttp.ClientResponseError as e: + if e.status == 404: + return [] + raise + + security_aggregate = SecurityAggregateContent.model_validate( + security_aggregate_dict + ) + return security_aggregate.authorizations + class AuthenticatedAlephClient(AlephClient): account: Account @@ -617,3 +641,35 @@ async def storage_push(self, content: Mapping) -> str: :param content: The dict-like content to upload """ raise NotImplementedError() + + async def update_all_authorizations(self, authorizations: list[Authorization]): + """ + Updates all authorizations for the current account. + Danger! This will replace all authorizations for the account. Use with care. + + :param authorizations: List of authorizations to set. These authorizations will replace the existing ones. + """ + security_aggregate = SecurityAggregateContent(authorizations=authorizations) + await self.create_aggregate( + key="security", content=security_aggregate.model_dump() + ) + + async def add_authorization(self, authorization: Authorization): + """ + Adds a specific authorization for the current account. + """ + authorizations = await self.get_authorizations(self.account.get_address()) + authorizations.append(authorization) + await self.update_all_authorizations(authorizations) + + async def revoke_all_authorizations(self, address: str): + """ + Revokes all authorizations for a specific address. + """ + authorizations = await self.get_authorizations(self.account.get_address()) + filtered_authorizations = [ + authorization + for authorization in authorizations + if authorization.address != address + ] + await self.update_all_authorizations(filtered_authorizations) diff --git a/src/aleph/sdk/types.py b/src/aleph/sdk/types.py index ed2524cb..1bbe66d1 100644 --- a/src/aleph/sdk/types.py +++ b/src/aleph/sdk/types.py @@ -14,7 +14,7 @@ Union, ) -from aleph_message.models import ItemHash +from aleph_message.models import ItemHash, MessageType from pydantic import ( BaseModel, ConfigDict, @@ -24,9 +24,11 @@ TypeAdapter, field_validator, ) -from typing_extensions import runtime_checkable +from typing_extensions import Self, runtime_checkable __all__ = ( + "Authorization", + "AuthorizationBuilder", "StorageEnum", "Account", "AccountFromPrivateKey", @@ -406,3 +408,68 @@ class VmResources(BaseModel): vcpus: PositiveInt memory: PositiveInt disk_mib: PositiveInt + + +class Authorization(BaseModel): + """A single authorization entry for delegated access.""" + + address: str + chain: Optional[Chain] = None + channels: list[str] = [] + types: list[MessageType] = [] + post_types: list[str] = [] + aggregate_keys: list[str] = [] + + +class AuthorizationBuilder: + def __init__(self, address: str): + self._address: str = address + self._chain: Optional[Chain] = None + self._channels: list[str] = [] + self._message_types: list[MessageType] = [] + self._post_types: list[str] = [] + self._aggregate_keys: list[str] = [] + + def chain(self, chain: Chain) -> Self: + self._chain = chain + return self + + def channel(self, channel: str) -> Self: + self._channels.append(channel) + return self + + def message_type(self, message_type: MessageType) -> Self: + self._message_types.append(message_type) + return self + + def post_type(self, post_type: str) -> Self: + if MessageType.post not in self._message_types: + raise ValueError( + "Cannot set post_type without allowing POST message type first" + ) + self._post_types.append(post_type) + return self + + def aggregate_key(self, aggregate_key: str) -> Self: + if MessageType.aggregate not in self._message_types: + raise ValueError( + "Cannot set post_type without allowing AGGREGATE message type first" + ) + self._aggregate_keys.append(aggregate_key) + return self + + def build(self) -> Authorization: + return Authorization( + address=self._address, + chain=self._chain, + channels=self._channels, + types=self._message_types, + post_types=self._post_types, + aggregate_keys=self._aggregate_keys, + ) + + +class SecurityAggregateContent(BaseModel): + """Content schema for the 'security' aggregate.""" + + authorizations: list[Authorization] = [] diff --git a/tests/unit/services/test_authorizations.py b/tests/unit/services/test_authorizations.py new file mode 100644 index 00000000..7ab2b7ee --- /dev/null +++ b/tests/unit/services/test_authorizations.py @@ -0,0 +1,562 @@ +""" +Tests for authorization methods in AlephClient. +""" + +from typing import Any, Dict, Iterable, Optional, Tuple + +import pytest +from aleph_message.models import AggregateMessage, Chain, MessageType +from aleph_message.status import MessageStatus + +from aleph.sdk.client.abstract import AuthenticatedAlephClient +from aleph.sdk.types import ( + Account, + Authorization, + AuthorizationBuilder, + SecurityAggregateContent, +) + + +class FakeAccount: + """Minimal fake account for testing.""" + + CHAIN = "ETH" + CURVE = "secp256k1" + + def __init__(self, address: str = "0xTestAddress1234567890123456789012345678"): + self._address = address + + async def sign_message(self, message: Dict) -> Dict: + message["signature"] = "0x" + "ab" * 65 + return message + + async def sign_raw(self, buffer: bytes) -> bytes: + return b"fake_signature" + + def get_address(self) -> str: + return self._address + + def get_public_key(self) -> str: + return "0x" + "cd" * 33 + + +class MockAlephClient(AuthenticatedAlephClient): + """ + A fake authenticated client that maintains an in-memory aggregate store. + Aggregates are dictionaries that get merged/updated with each create_aggregate call. + """ + + def __init__(self, account: Optional[Account] = None): + self.account = account or FakeAccount() + # Storage: {address: {key: content}} + self._aggregates: Dict[str, Dict[str, Any]] = {} + + async def fetch_aggregate(self, address: str, key: str) -> Dict[str, Any]: + """Fetch a single aggregate by address and key.""" + if address not in self._aggregates: + return {"authorizations": []} + return self._aggregates[address].get(key, {"authorizations": []}) + + async def fetch_aggregates( + self, address: str, keys: Optional[Iterable[str]] = None + ) -> Dict[str, Dict]: + """Fetch multiple aggregates.""" + if address not in self._aggregates: + return {} + if keys is None: + return self._aggregates[address] + return {k: v for k, v in self._aggregates[address].items() if k in keys} + + async def create_aggregate( + self, + key: str, + content: Dict[str, Any], + address: Optional[str] = None, + channel: Optional[str] = None, + inline: bool = True, + sync: bool = False, + ) -> Tuple[AggregateMessage, MessageStatus]: + """ + Create/update an aggregate. Merges content into existing aggregate. + """ + address = address or self.account.get_address() + + if address not in self._aggregates: + self._aggregates[address] = {} + + # Aggregates merge content (like a dict update) + if key in self._aggregates[address]: + self._aggregates[address][key].update(content) + else: + self._aggregates[address][key] = content + + # Return a minimal mock message + mock_message = AggregateMessage.model_validate( + { + "item_hash": "44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a", + "type": "AGGREGATE", + "chain": "ETH", + "sender": address, + "signature": "0x" + "ab" * 65, + "item_type": "inline", + "item_content": "{}", + "content": { + "key": key, + "address": address, + "content": content, + "time": 0, + }, + "time": 0, + "channel": channel or "TEST", + } + ) + return mock_message, MessageStatus.PROCESSED + + # Stub implementations for abstract methods we don't need + async def create_post(self, *args, **kwargs): + raise NotImplementedError + + async def create_store(self, *args, **kwargs): + raise NotImplementedError + + async def create_program(self, *args, **kwargs): + raise NotImplementedError + + async def create_instance(self, *args, **kwargs): + raise NotImplementedError + + async def forget(self, *args, **kwargs): + raise NotImplementedError + + async def submit(self, *args, **kwargs): + raise NotImplementedError + + async def get_posts(self, *args, **kwargs): + raise NotImplementedError + + async def download_file(self, *args, **kwargs): + raise NotImplementedError + + async def download_file_to_path(self, *args, **kwargs): + raise NotImplementedError + + async def get_messages(self, *args, **kwargs): + raise NotImplementedError + + async def get_message(self, *args, **kwargs): + raise NotImplementedError + + def watch_messages(self, *args, **kwargs): + raise NotImplementedError + + def get_estimated_price(self, *args, **kwargs): + raise NotImplementedError + + def get_program_price(self, *args, **kwargs): + raise NotImplementedError + + +# Fixtures +@pytest.fixture +def mock_client() -> MockAlephClient: + """Create a fresh fake client for each test.""" + return MockAlephClient() + + +@pytest.fixture +def mock_client_with_existing_auth() -> MockAlephClient: + """Create a fake client with pre-existing authorizations.""" + client = MockAlephClient() + client._aggregates[client.account.get_address()] = { + "security": { + "authorizations": [ + { + "address": "0xExistingAddress123456789012345678901234", + "chain": "ETH", + "channels": ["existing_channel"], + "types": ["POST"], + "post_types": [], + "aggregate_keys": [], + } + ] + } + } + return client + + +# Tests for get_authorizations +class TestGetAuthorizations: + @pytest.mark.asyncio + async def test_get_authorizations_empty(self, mock_client: MockAlephClient): + """When no authorizations exist, returns empty list.""" + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert authorizations == [] + + @pytest.mark.asyncio + async def test_get_authorizations_returns_existing( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Returns existing authorizations from aggregate store.""" + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + + assert len(authorizations) == 1 + assert authorizations[0].address == "0xExistingAddress123456789012345678901234" + assert authorizations[0].chain == Chain.ETH + assert authorizations[0].channels == ["existing_channel"] + + +# Tests for update_all_authorizations +class TestUpdateAllAuthorizations: + @pytest.mark.asyncio + async def test_update_replaces_all_authorizations( + self, mock_client: MockAlephClient + ): + """update_all_authorizations replaces the entire authorization list.""" + auth1 = Authorization(address="0xAddress1111111111111111111111111111111111") + auth2 = Authorization(address="0xAddress2222222222222222222222222222222222") + + await mock_client.update_all_authorizations([auth1, auth2]) + + # Verify stored content + stored = mock_client._aggregates[mock_client.account.get_address()]["security"] + assert len(stored["authorizations"]) == 2 + + @pytest.mark.asyncio + async def test_update_with_empty_list_clears_authorizations( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Passing an empty list removes all authorizations.""" + await mock_client_with_existing_auth.update_all_authorizations([]) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert authorizations == [] + + @pytest.mark.asyncio + async def test_update_preserves_authorization_fields( + self, mock_client: MockAlephClient + ): + """All authorization fields are preserved when storing.""" + auth = Authorization( + address="0xFullAuth111111111111111111111111111111111", + chain=Chain.ETH, + channels=["channel1", "channel2"], + types=[MessageType.post, MessageType.aggregate], + post_types=["blog", "comment"], + aggregate_keys=["settings"], + ) + + await mock_client.update_all_authorizations([auth]) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + retrieved = authorizations[0] + assert retrieved.address == auth.address + assert retrieved.chain == Chain.ETH + assert retrieved.channels == ["channel1", "channel2"] + assert MessageType.post in retrieved.types + assert "blog" in retrieved.post_types + + +# Tests for add_authorization +class TestAddAuthorization: + @pytest.mark.asyncio + async def test_add_to_empty(self, mock_client: MockAlephClient): + """Adding authorization when none exist.""" + auth = Authorization(address="0xNewAddress1111111111111111111111111111111") + + await mock_client.add_authorization(auth) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert ( + authorizations[0].address == "0xNewAddress1111111111111111111111111111111" + ) + + @pytest.mark.asyncio + async def test_add_appends_to_existing( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Adding authorization appends to existing list.""" + new_auth = Authorization( + address="0xNewAddress2222222222222222222222222222222", + channels=["new_channel"], + ) + + await mock_client_with_existing_auth.add_authorization(new_auth) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert len(authorizations) == 2 + addresses = [a.address for a in authorizations] + assert "0xExistingAddress123456789012345678901234" in addresses + assert "0xNewAddress2222222222222222222222222222222" in addresses + + @pytest.mark.asyncio + async def test_add_multiple_authorizations_sequentially( + self, mock_client: MockAlephClient + ): + """Adding multiple authorizations one by one.""" + auth1 = Authorization(address="0xFirst11111111111111111111111111111111111") + auth2 = Authorization(address="0xSecond2222222222222222222222222222222222") + auth3 = Authorization(address="0xThird33333333333333333333333333333333333") + + await mock_client.add_authorization(auth1) + await mock_client.add_authorization(auth2) + await mock_client.add_authorization(auth3) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 3 + + +# Tests for revoke_all_authorizations +class TestRevokeAllAuthorizations: + @pytest.mark.asyncio + async def test_revoke_removes_matching_address( + self, mock_client_with_existing_auth: MockAlephClient + ): + """Revoking removes all authorizations for the specified address.""" + await mock_client_with_existing_auth.revoke_all_authorizations( + "0xExistingAddress123456789012345678901234" + ) + + authorizations = await mock_client_with_existing_auth.get_authorizations( + mock_client_with_existing_auth.account.get_address() + ) + assert len(authorizations) == 0 + + @pytest.mark.asyncio + async def test_revoke_keeps_other_addresses(self, mock_client: MockAlephClient): + """Revoking only removes authorizations for the specified address.""" + auth1 = Authorization(address="0xToRevoke111111111111111111111111111111111") + auth2 = Authorization(address="0xToKeep22222222222222222222222222222222222") + auth3 = Authorization( + address="0xToRevoke111111111111111111111111111111111" + ) # Duplicate + + await mock_client.update_all_authorizations([auth1, auth2, auth3]) + + await mock_client.revoke_all_authorizations( + "0xToRevoke111111111111111111111111111111111" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert ( + authorizations[0].address == "0xToKeep22222222222222222222222222222222222" + ) + + @pytest.mark.asyncio + async def test_revoke_nonexistent_address_is_noop( + self, mock_client: MockAlephClient + ): + """Revoking an address that doesn't exist does nothing.""" + auth = Authorization(address="0xExisting1111111111111111111111111111111111") + await mock_client.add_authorization(auth) + + await mock_client.revoke_all_authorizations( + "0xNonExistent22222222222222222222222222222" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + + @pytest.mark.asyncio + async def test_revoke_from_empty_is_noop(self, mock_client: MockAlephClient): + """Revoking when no authorizations exist doesn't error.""" + await mock_client.revoke_all_authorizations( + "0xAnyAddress111111111111111111111111111111111" + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert authorizations == [] + + +# Integration tests - full workflows +class TestAuthorizationWorkflows: + @pytest.mark.asyncio + async def test_full_lifecycle(self, mock_client: MockAlephClient): + """Test complete authorization lifecycle: add, verify, revoke.""" + delegate_address = "0xDelegate111111111111111111111111111111111" + + # Initially empty + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 0 + + # Add authorization + auth = Authorization( + address=delegate_address, + channels=["MY_APP"], + types=[MessageType.post], + ) + await mock_client.add_authorization(auth) + + # Verify it exists + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert authorizations[0].address == delegate_address + assert "MY_APP" in authorizations[0].channels + + # Revoke + await mock_client.revoke_all_authorizations(delegate_address) + + # Verify it's gone + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 0 + + @pytest.mark.asyncio + async def test_multiple_delegates_workflow(self, mock_client: MockAlephClient): + """Test managing authorizations for multiple delegate addresses.""" + delegate1 = "0xDelegate1111111111111111111111111111111111" + delegate2 = "0xDelegate2222222222222222222222222222222222" + + # Add two delegates + await mock_client.add_authorization( + Authorization(address=delegate1, channels=["channel_a"]) + ) + await mock_client.add_authorization( + Authorization(address=delegate2, channels=["channel_b"]) + ) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 2 + + # Revoke first delegate + await mock_client.revoke_all_authorizations(delegate1) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 1 + assert authorizations[0].address == delegate2 + + @pytest.mark.asyncio + async def test_replace_all_authorizations(self, mock_client: MockAlephClient): + """Test replacing all authorizations at once.""" + # Add initial authorizations + await mock_client.add_authorization( + Authorization(address="0xOld111111111111111111111111111111111111111") + ) + await mock_client.add_authorization( + Authorization(address="0xOld222222222222222222222222222222222222222") + ) + + # Replace with new set + new_auths = [ + Authorization(address="0xNew111111111111111111111111111111111111111"), + Authorization(address="0xNew222222222222222222222222222222222222222"), + Authorization(address="0xNew333333333333333333333333333333333333333"), + ] + await mock_client.update_all_authorizations(new_auths) + + authorizations = await mock_client.get_authorizations( + mock_client.account.get_address() + ) + assert len(authorizations) == 3 + addresses = {a.address for a in authorizations} + assert "0xOld111111111111111111111111111111111111111" not in addresses + assert "0xNew111111111111111111111111111111111111111" in addresses + + +# Model tests +class TestAuthorizationModel: + def test_minimal_authorization(self): + """Authorization can be created with just an address.""" + auth = Authorization(address="0x1234567890123456789012345678901234567890") + assert auth.address == "0x1234567890123456789012345678901234567890" + assert auth.chain is None + assert auth.channels == [] + assert auth.types == [] + + def test_full_authorization(self): + """Authorization with all fields set.""" + auth = Authorization( + address="0x1234567890123456789012345678901234567890", + chain=Chain.ETH, + channels=["ch1", "ch2"], + types=[MessageType.post, MessageType.store], + post_types=["blog"], + aggregate_keys=["settings"], + ) + assert auth.chain == Chain.ETH + assert len(auth.channels) == 2 + assert len(auth.types) == 2 + + def test_security_aggregate_serialization(self): + """SecurityAggregateContent serializes correctly.""" + auth = Authorization( + address="0x1234567890123456789012345678901234567890", + channels=["test"], + ) + content = SecurityAggregateContent(authorizations=[auth]) + dumped = content.model_dump() + + assert "authorizations" in dumped + assert len(dumped["authorizations"]) == 1 + assert dumped["authorizations"][0]["address"] == auth.address + + +class TestAuthorizationBuilder: + def test_authorization_builder_only_address(self): + """Test the AuthorizationBuilder.""" + auth = AuthorizationBuilder( + address="0x1234567890123456789012345678901234567890" + ).build() + assert auth.address == "0x1234567890123456789012345678901234567890" + assert auth.chain is None + assert auth.channels == [] + assert auth.types == [] + assert auth.post_types == [] + assert auth.aggregate_keys == [] + + def test_authorization_builder(self): + """Test the AuthorizationBuilder with a detailed configuration.""" + sample_authorization = Authorization( + address="0xFullAuth111111111111111111111111111111111", + chain=Chain.ETH, + channels=["channel1", "channel2"], + types=[MessageType.post, MessageType.aggregate], + post_types=["blog", "comment"], + aggregate_keys=["settings"], + ) + + auth = AuthorizationBuilder(address=sample_authorization.address).chain( + sample_authorization.chain + ) + for channel in sample_authorization.channels: + auth = auth.channel(channel) + for message_type in sample_authorization.types: + auth = auth.message_type(message_type) + for post_type in sample_authorization.post_types: + auth = auth.post_type(post_type) + for aggregate_key in sample_authorization.aggregate_keys: + auth = auth.aggregate_key(aggregate_key) + auth = auth.build() + + assert auth == sample_authorization