diff --git a/taskiq_nats/broker.py b/taskiq_nats/broker.py index db22c61..879d444 100644 --- a/taskiq_nats/broker.py +++ b/taskiq_nats/broker.py @@ -7,16 +7,15 @@ from nats.errors import TimeoutError as NatsTimeoutError from nats.js import JetStreamContext from nats.js.api import ConsumerConfig, StreamConfig +from nats.js.errors import NotFoundError from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage _T = typing.TypeVar("_T") # (Too short) - JetStreamConsumerType = typing.TypeVar( "JetStreamConsumerType", ) - logger = getLogger("taskiq_nats") @@ -138,6 +137,23 @@ def __init__( self.consumer: JetStreamConsumerType + async def _ensure_stream_exists(self) -> None: + """Ensure stream exists, create if it doesn't.""" + if self.stream_config.name is None: + self.stream_config.name = self.stream_name + if not self.stream_config.subjects: + self.stream_config.subjects = [self.subject] + + try: + # Check if stream already exists + await self.js.stream_info(self.stream_config.name) + logger.debug("Stream %s already exists", self.stream_config.name) + except NotFoundError: + logger.debug("stream %s does not exist", self.stream_config.name) + # Stream doesn't exist, create it + await self.js.add_stream(config=self.stream_config) + logger.info("Created stream %s", self.stream_config.name) + async def startup(self) -> None: """ Startup event handler. @@ -148,11 +164,9 @@ async def startup(self) -> None: await super().startup() await self.client.connect(self.servers, **self.connection_kwargs) self.js = self.client.jetstream() - if self.stream_config.name is None: - self.stream_config.name = self.stream_name - if not self.stream_config.subjects: - self.stream_config.subjects = [self.subject] - await self.js.add_stream(config=self.stream_config) + + # Ensure stream exists (won't recreate if it exists) + await self._ensure_stream_exists() await self._startup_consumer() async def shutdown(self) -> None: