1616from .const import LOGGER_PATH
1717from .eval import AstEval
1818from .event import Event
19+ from .mqtt import Mqtt
1920from .function import Function
2021from .state import STATE_VIRTUAL_ATTRS , State
2122
@@ -149,13 +150,14 @@ async def wait_until(
149150 state_check_now = True ,
150151 time_trigger = None ,
151152 event_trigger = None ,
153+ mqtt_trigger = None ,
152154 timeout = None ,
153155 state_hold = None ,
154156 state_hold_false = None ,
155157 __test_handshake__ = None ,
156158 ):
157159 """Wait for zero or more triggers, until an optional timeout."""
158- if state_trigger is None and time_trigger is None and event_trigger is None :
160+ if state_trigger is None and time_trigger is None and event_trigger is None and mqtt_trigger is None :
159161 if timeout is not None :
160162 await asyncio .sleep (timeout )
161163 return {"trigger_type" : "timeout" }
@@ -164,6 +166,7 @@ async def wait_until(
164166 state_trig_ident_any = set ()
165167 state_trig_eval = None
166168 event_trig_expr = None
169+ mqtt_trig_expr = None
167170 exc = None
168171 notify_q = asyncio .Queue (0 )
169172
@@ -260,6 +263,23 @@ async def wait_until(
260263 State .notify_del (state_trig_ident , notify_q )
261264 raise exc
262265 Event .notify_add (event_trigger [0 ], notify_q )
266+ if mqtt_trigger is not None :
267+ if isinstance (mqtt_trigger , str ):
268+ mqtt_trigger = [mqtt_trigger ]
269+ if len (mqtt_trigger ) > 1 :
270+ mqtt_trig_expr = AstEval (
271+ f"{ ast_ctx .name } mqtt_trigger" ,
272+ ast_ctx .get_global_ctx (),
273+ logger_name = ast_ctx .get_logger_name (),
274+ )
275+ Function .install_ast_funcs (mqtt_trig_expr )
276+ mqtt_trig_expr .parse (mqtt_trigger [1 ], mode = "eval" )
277+ exc = mqtt_trig_expr .get_exception_obj ()
278+ if exc is not None :
279+ if len (state_trig_ident ) > 0 :
280+ State .notify_del (state_trig_ident , notify_q )
281+ raise exc
282+ await Mqtt .notify_add (mqtt_trigger [0 ], notify_q )
263283 time0 = time .monotonic ()
264284
265285 if __test_handshake__ :
@@ -297,7 +317,7 @@ async def wait_until(
297317 this_timeout = time_left
298318 state_trig_timeout = True
299319 if this_timeout is None :
300- if state_trigger is None and event_trigger is None :
320+ if state_trigger is None and event_trigger is None and mqtt_trigger is None :
301321 _LOGGER .debug (
302322 "trigger %s wait_until no next time - returning with none" , ast_ctx .name ,
303323 )
@@ -403,6 +423,17 @@ async def wait_until(
403423 if event_trig_ok :
404424 ret = notify_info
405425 break
426+ elif notify_type == "mqtt" :
427+ if mqtt_trig_expr is None :
428+ ret = notify_info
429+ break
430+ mqtt_trig_ok = await mqtt_trig_expr .eval (notify_info )
431+ exc = mqtt_trig_expr .get_exception_obj ()
432+ if exc is not None :
433+ break
434+ if mqtt_trig_ok :
435+ ret = notify_info
436+ break
406437 else :
407438 _LOGGER .error (
408439 "trigger %s wait_until got unexpected queue message %s" , ast_ctx .name , notify_type ,
@@ -412,6 +443,8 @@ async def wait_until(
412443 State .notify_del (state_trig_ident , notify_q )
413444 if event_trigger is not None :
414445 Event .notify_del (event_trigger [0 ], notify_q )
446+ if mqtt_trigger is not None :
447+ Mqtt .notify_del (mqtt_trigger [0 ], notify_q )
415448 if exc :
416449 raise exc
417450 return ret
@@ -641,6 +674,7 @@ def __init__(
641674 self .state_check_now = self .state_trigger_kwargs .get ("state_check_now" , False )
642675 self .time_trigger = trig_cfg .get ("time_trigger" , {}).get ("args" , None )
643676 self .event_trigger = trig_cfg .get ("event_trigger" , {}).get ("args" , None )
677+ self .mqtt_trigger = trig_cfg .get ("mqtt_trigger" , {}).get ("args" , None )
644678 self .state_active = trig_cfg .get ("state_active" , {}).get ("args" , None )
645679 self .time_active = trig_cfg .get ("time_active" , {}).get ("args" , None )
646680 self .time_active_hold_off = trig_cfg .get ("time_active" , {}).get ("kwargs" , {}).get ("hold_off" , None )
@@ -656,6 +690,7 @@ def __init__(
656690 self .state_trig_ident = None
657691 self .state_trig_ident_any = set ()
658692 self .event_trig_expr = None
693+ self .mqtt_trig_expr = None
659694 self .have_trigger = False
660695 self .setup_ok = False
661696 self .run_on_startup = False
@@ -726,6 +761,19 @@ def __init__(
726761 return
727762 self .have_trigger = True
728763
764+ if self .mqtt_trigger is not None :
765+ if len (self .mqtt_trigger ) == 2 :
766+ self .mqtt_trig_expr = AstEval (
767+ f"{ self .name } @mqtt_trigger()" , self .global_ctx , logger_name = self .name ,
768+ )
769+ Function .install_ast_funcs (self .mqtt_trig_expr )
770+ self .mqtt_trig_expr .parse (self .mqtt_trigger [1 ], mode = "eval" )
771+ exc = self .mqtt_trig_expr .get_exception_long ()
772+ if exc is not None :
773+ self .mqtt_trig_expr .get_logger ().error (exc )
774+ return
775+ self .have_trigger = True
776+
729777 self .setup_ok = True
730778
731779 def stop (self ):
@@ -736,6 +784,8 @@ def stop(self):
736784 State .notify_del (self .state_trig_ident , self .notify_q )
737785 if self .event_trigger is not None :
738786 Event .notify_del (self .event_trigger [0 ], self .notify_q )
787+ if self .mqtt_trigger is not None :
788+ Mqtt .notify_del (self .mqtt_trigger [0 ], self .notify_q )
739789 if self .task :
740790 Function .task_cancel (self .task )
741791
@@ -765,6 +815,9 @@ async def trigger_watch(self):
765815 if self .event_trigger is not None :
766816 _LOGGER .debug ("trigger %s adding event_trigger %s" , self .name , self .event_trigger [0 ])
767817 Event .notify_add (self .event_trigger [0 ], self .notify_q )
818+ if self .mqtt_trigger is not None :
819+ _LOGGER .debug ("trigger %s adding mqtt_trigger %s" , self .name , self .mqtt_trigger [0 ])
820+ await Mqtt .notify_add (self .mqtt_trigger [0 ], self .notify_q )
768821
769822 last_trig_time = None
770823 last_state_trig_time = None
@@ -924,6 +977,10 @@ async def trigger_watch(self):
924977 func_args = notify_info
925978 if self .event_trig_expr :
926979 trig_ok = await self .event_trig_expr .eval (notify_info )
980+ elif notify_type == "mqtt" :
981+ func_args = notify_info
982+ if self .mqtt_trig_expr :
983+ trig_ok = await self .mqtt_trig_expr .eval (notify_info )
927984
928985 else :
929986 func_args = notify_info
@@ -1038,4 +1095,6 @@ async def do_func_call(func, ast_ctx, task_unique, task_unique_func, hass_contex
10381095 State .notify_del (self .state_trig_ident , self .notify_q )
10391096 if self .event_trigger is not None :
10401097 Event .notify_del (self .event_trigger [0 ], self .notify_q )
1098+ if self .mqtt_trigger is not None :
1099+ Mqtt .notify_del (self .mqtt_trigger [0 ], self .notify_q )
10411100 return
0 commit comments