2222import os
2323import threading
2424import time
25- from typing import Any , Callable , Dict , Iterator , List , Mapping , Optional , Set , Tuple
25+ from typing import Any , Callable , Dict , Iterator , List , Mapping , Optional , Set , Tuple , cast
2626import uuid
2727
2828from absl import logging
@@ -557,7 +557,8 @@ def __init__(
557557 self ._mlmd_execution_atomic_op_context = None
558558 self ._execution : Optional [metadata_store_pb2 .Execution ] = None
559559 self ._on_commit_callbacks : List [Callable [[], None ]] = []
560- self ._node_states_proxy : Optional [_NodeStatesProxy ] = None
560+ # The note state proxy is assumed to be initialized before being used.
561+ self ._node_states_proxy : _NodeStatesProxy = cast (_NodeStatesProxy , None )
561562
562563 @classmethod
563564 @telemetry_utils .noop_telemetry (metrics_utils .no_op_metrics )
@@ -916,26 +917,29 @@ def _load_from_context(
916917
917918 @property
918919 def execution (self ) -> metadata_store_pb2 .Execution :
919- self ._check_context ()
920+ if self ._execution is None :
921+ raise RuntimeError (
922+ 'Operation must be performed within the pipeline state context.'
923+ )
920924 return self ._execution
921925
922926 def is_active (self ) -> bool :
923927 """Returns `True` if pipeline is active."""
924- self ._check_context ()
925- return execution_lib .is_execution_active (self ._execution )
928+ return execution_lib .is_execution_active (self .execution )
926929
927930 def initiate_stop (self , status : status_lib .Status ) -> None :
928931 """Updates pipeline state to signal stopping pipeline execution."""
929- self ._check_context ()
930932 data_types_utils .set_metadata_value (
931- self ._execution .custom_properties [_STOP_INITIATED ], 1 )
933+ self .execution .custom_properties [_STOP_INITIATED ], 1
934+ )
932935 data_types_utils .set_metadata_value (
933- self ._execution .custom_properties [_PIPELINE_STATUS_CODE ],
934- int (status .code ))
936+ self .execution .custom_properties [_PIPELINE_STATUS_CODE ],
937+ int (status .code ),
938+ )
935939 if status .message :
936940 data_types_utils .set_metadata_value (
937- self ._execution .custom_properties [_PIPELINE_STATUS_MSG ],
938- status . message )
941+ self .execution .custom_properties [_PIPELINE_STATUS_MSG ], status . message
942+ )
939943
940944 @_synchronized
941945 def initiate_resume (self ) -> None :
@@ -994,21 +998,24 @@ def _structure(
994998
995999 env .get_env ().prepare_orchestrator_for_pipeline_run (updated_pipeline )
9961000 data_types_utils .set_metadata_value (
997- self ._execution .custom_properties [_UPDATED_PIPELINE_IR ],
998- _PipelineIRCodec .get ().encode (updated_pipeline ))
1001+ self .execution .custom_properties [_UPDATED_PIPELINE_IR ],
1002+ _PipelineIRCodec .get ().encode (updated_pipeline ),
1003+ )
9991004 data_types_utils .set_metadata_value (
1000- self ._execution .custom_properties [_UPDATE_OPTIONS ],
1001- _base64_encode (update_options ))
1005+ self .execution .custom_properties [_UPDATE_OPTIONS ],
1006+ _base64_encode (update_options ),
1007+ )
10021008
10031009 def is_update_initiated (self ) -> bool :
1004- self ._check_context ()
1005- return self .is_active () and self ._execution .custom_properties .get (
1006- _UPDATED_PIPELINE_IR ) is not None
1010+ return (
1011+ self .is_active ()
1012+ and self .execution .custom_properties .get (_UPDATED_PIPELINE_IR )
1013+ is not None
1014+ )
10071015
10081016 def get_update_options (self ) -> pipeline_pb2 .UpdateOptions :
10091017 """Gets pipeline update option that was previously configured."""
1010- self ._check_context ()
1011- update_options = self ._execution .custom_properties .get (_UPDATE_OPTIONS )
1018+ update_options = self .execution .custom_properties .get (_UPDATE_OPTIONS )
10121019 if update_options is None :
10131020 logging .warning (
10141021 'pipeline execution missing expected custom property %s, '
@@ -1019,17 +1026,18 @@ def get_update_options(self) -> pipeline_pb2.UpdateOptions:
10191026
10201027 def apply_pipeline_update (self ) -> None :
10211028 """Applies pipeline update that was previously initiated."""
1022- self ._check_context ()
10231029 updated_pipeline_ir = _get_metadata_value (
1024- self ._execution .custom_properties .get (_UPDATED_PIPELINE_IR ))
1030+ self .execution .custom_properties .get (_UPDATED_PIPELINE_IR )
1031+ )
10251032 if not updated_pipeline_ir :
10261033 raise status_lib .StatusNotOkError (
10271034 code = status_lib .Code .INVALID_ARGUMENT ,
10281035 message = 'No updated pipeline IR to apply' )
10291036 data_types_utils .set_metadata_value (
1030- self ._execution .properties [_PIPELINE_IR ], updated_pipeline_ir )
1031- del self ._execution .custom_properties [_UPDATED_PIPELINE_IR ]
1032- del self ._execution .custom_properties [_UPDATE_OPTIONS ]
1037+ self .execution .properties [_PIPELINE_IR ], updated_pipeline_ir
1038+ )
1039+ del self .execution .custom_properties [_UPDATED_PIPELINE_IR ]
1040+ del self .execution .custom_properties [_UPDATE_OPTIONS ]
10331041 self .pipeline = _PipelineIRCodec .get ().decode (updated_pipeline_ir )
10341042
10351043 def is_stop_initiated (self ) -> bool :
@@ -1038,8 +1046,7 @@ def is_stop_initiated(self) -> bool:
10381046
10391047 def stop_initiated_reason (self ) -> Optional [status_lib .Status ]:
10401048 """Returns status object if stop initiated, `None` otherwise."""
1041- self ._check_context ()
1042- custom_properties = self ._execution .custom_properties
1049+ custom_properties = self .execution .custom_properties
10431050 if _get_metadata_value (custom_properties .get (_STOP_INITIATED )) == 1 :
10441051 code = _get_metadata_value (custom_properties .get (_PIPELINE_STATUS_CODE ))
10451052 if code is None :
@@ -1111,45 +1118,44 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:
11111118
11121119 def get_pipeline_execution_state (self ) -> metadata_store_pb2 .Execution .State :
11131120 """Returns state of underlying pipeline execution."""
1114- self ._check_context ()
1115- return self ._execution .last_known_state
1121+ return self .execution .last_known_state
11161122
11171123 def set_pipeline_execution_state (
11181124 self , state : metadata_store_pb2 .Execution .State ) -> None :
11191125 """Sets state of underlying pipeline execution."""
1120- self ._check_context ()
1121-
1122- if self ._execution .last_known_state != state :
1126+ if self .execution .last_known_state != state :
11231127 self ._on_commit_callbacks .append (
1124- functools .partial (_log_pipeline_execution_state_change ,
1125- self ._execution .last_known_state , state ,
1126- self .pipeline_uid ))
1127- self ._execution .last_known_state = state
1128+ functools .partial (
1129+ _log_pipeline_execution_state_change ,
1130+ self .execution .last_known_state ,
1131+ state ,
1132+ self .pipeline_uid ,
1133+ )
1134+ )
1135+ self .execution .last_known_state = state
11281136
11291137 def get_property (self , property_key : str ) -> Optional [types .Property ]:
11301138 """Returns custom property value from the pipeline execution."""
11311139 return _get_metadata_value (
1132- self ._execution .custom_properties .get (property_key ))
1140+ self .execution .custom_properties .get (property_key )
1141+ )
11331142
11341143 def save_property (
11351144 self , property_key : str , property_value : types .Property
11361145 ) -> None :
1137- self ._check_context ()
11381146 data_types_utils .set_metadata_value (
1139- self ._execution .custom_properties [property_key ], property_value
1147+ self .execution .custom_properties [property_key ], property_value
11401148 )
11411149
11421150 def remove_property (self , property_key : str ) -> None :
11431151 """Removes a custom property of the pipeline execution if exists."""
1144- self ._check_context ()
1145- if self ._execution .custom_properties .get (property_key ):
1146- del self ._execution .custom_properties [property_key ]
1152+ if self .execution .custom_properties .get (property_key ):
1153+ del self .execution .custom_properties [property_key ]
11471154
11481155 def pipeline_creation_time_secs_since_epoch (self ) -> int :
11491156 """Returns the pipeline creation time as seconds since epoch."""
1150- self ._check_context ()
11511157 # Convert from milliseconds to seconds.
1152- return self ._execution .create_time_since_epoch // 1000
1158+ return self .execution .create_time_since_epoch // 1000
11531159
11541160 def get_orchestration_options (
11551161 self ) -> orchestration_options .OrchestrationOptions :
@@ -1197,6 +1203,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
11971203 self ._mlmd_execution_atomic_op_context = None
11981204 self ._execution = None
11991205 try :
1206+ assert mlmd_execution_atomic_op_context is not None
12001207 mlmd_execution_atomic_op_context .__exit__ (exc_type , exc_val , exc_tb )
12011208 finally :
12021209 self ._on_commit_callbacks .clear ()
0 commit comments