Skip to content

Commit dab85bf

Browse files
committed
Internal clean up.
PiperOrigin-RevId: 654103272
1 parent c0af966 commit dab85bf

File tree

3 files changed

+54
-44
lines changed

3 files changed

+54
-44
lines changed

tfx/orchestration/experimental/core/async_pipeline_task_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ def _generate_tasks_for_node(
477477

478478
for input_and_param in unprocessed_inputs:
479479
if backfill_token:
480+
assert input_and_param.exec_properties is not None
480481
input_and_param.exec_properties[
481482
constants.BACKFILL_TOKEN_CUSTOM_PROPERTY_KEY
482483
] = backfill_token

tfx/orchestration/experimental/core/pipeline_state.py

Lines changed: 51 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import os
2323
import threading
2424
import time
25-
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple
25+
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Set, Tuple, cast
2626
import uuid
2727

2828
from absl import logging
@@ -557,7 +557,8 @@ def __init__(
557557
self._mlmd_execution_atomic_op_context = None
558558
self._execution: Optional[metadata_store_pb2.Execution] = None
559559
self._on_commit_callbacks: List[Callable[[], None]] = []
560-
self._node_states_proxy: Optional[_NodeStatesProxy] = None
560+
# The note state proxy is assumed to be initialized before being used.
561+
self._node_states_proxy: _NodeStatesProxy = cast(_NodeStatesProxy, None)
561562

562563
@classmethod
563564
@telemetry_utils.noop_telemetry(metrics_utils.no_op_metrics)
@@ -916,26 +917,29 @@ def _load_from_context(
916917

917918
@property
918919
def execution(self) -> metadata_store_pb2.Execution:
919-
self._check_context()
920+
if self._execution is None:
921+
raise RuntimeError(
922+
'Operation must be performed within the pipeline state context.'
923+
)
920924
return self._execution
921925

922926
def is_active(self) -> bool:
923927
"""Returns `True` if pipeline is active."""
924-
self._check_context()
925-
return execution_lib.is_execution_active(self._execution)
928+
return execution_lib.is_execution_active(self.execution)
926929

927930
def initiate_stop(self, status: status_lib.Status) -> None:
928931
"""Updates pipeline state to signal stopping pipeline execution."""
929-
self._check_context()
930932
data_types_utils.set_metadata_value(
931-
self._execution.custom_properties[_STOP_INITIATED], 1)
933+
self.execution.custom_properties[_STOP_INITIATED], 1
934+
)
932935
data_types_utils.set_metadata_value(
933-
self._execution.custom_properties[_PIPELINE_STATUS_CODE],
934-
int(status.code))
936+
self.execution.custom_properties[_PIPELINE_STATUS_CODE],
937+
int(status.code),
938+
)
935939
if status.message:
936940
data_types_utils.set_metadata_value(
937-
self._execution.custom_properties[_PIPELINE_STATUS_MSG],
938-
status.message)
941+
self.execution.custom_properties[_PIPELINE_STATUS_MSG], status.message
942+
)
939943

940944
@_synchronized
941945
def initiate_resume(self) -> None:
@@ -994,21 +998,24 @@ def _structure(
994998

995999
env.get_env().prepare_orchestrator_for_pipeline_run(updated_pipeline)
9961000
data_types_utils.set_metadata_value(
997-
self._execution.custom_properties[_UPDATED_PIPELINE_IR],
998-
_PipelineIRCodec.get().encode(updated_pipeline))
1001+
self.execution.custom_properties[_UPDATED_PIPELINE_IR],
1002+
_PipelineIRCodec.get().encode(updated_pipeline),
1003+
)
9991004
data_types_utils.set_metadata_value(
1000-
self._execution.custom_properties[_UPDATE_OPTIONS],
1001-
_base64_encode(update_options))
1005+
self.execution.custom_properties[_UPDATE_OPTIONS],
1006+
_base64_encode(update_options),
1007+
)
10021008

10031009
def is_update_initiated(self) -> bool:
1004-
self._check_context()
1005-
return self.is_active() and self._execution.custom_properties.get(
1006-
_UPDATED_PIPELINE_IR) is not None
1010+
return (
1011+
self.is_active()
1012+
and self.execution.custom_properties.get(_UPDATED_PIPELINE_IR)
1013+
is not None
1014+
)
10071015

10081016
def get_update_options(self) -> pipeline_pb2.UpdateOptions:
10091017
"""Gets pipeline update option that was previously configured."""
1010-
self._check_context()
1011-
update_options = self._execution.custom_properties.get(_UPDATE_OPTIONS)
1018+
update_options = self.execution.custom_properties.get(_UPDATE_OPTIONS)
10121019
if update_options is None:
10131020
logging.warning(
10141021
'pipeline execution missing expected custom property %s, '
@@ -1019,17 +1026,18 @@ def get_update_options(self) -> pipeline_pb2.UpdateOptions:
10191026

10201027
def apply_pipeline_update(self) -> None:
10211028
"""Applies pipeline update that was previously initiated."""
1022-
self._check_context()
10231029
updated_pipeline_ir = _get_metadata_value(
1024-
self._execution.custom_properties.get(_UPDATED_PIPELINE_IR))
1030+
self.execution.custom_properties.get(_UPDATED_PIPELINE_IR)
1031+
)
10251032
if not updated_pipeline_ir:
10261033
raise status_lib.StatusNotOkError(
10271034
code=status_lib.Code.INVALID_ARGUMENT,
10281035
message='No updated pipeline IR to apply')
10291036
data_types_utils.set_metadata_value(
1030-
self._execution.properties[_PIPELINE_IR], updated_pipeline_ir)
1031-
del self._execution.custom_properties[_UPDATED_PIPELINE_IR]
1032-
del self._execution.custom_properties[_UPDATE_OPTIONS]
1037+
self.execution.properties[_PIPELINE_IR], updated_pipeline_ir
1038+
)
1039+
del self.execution.custom_properties[_UPDATED_PIPELINE_IR]
1040+
del self.execution.custom_properties[_UPDATE_OPTIONS]
10331041
self.pipeline = _PipelineIRCodec.get().decode(updated_pipeline_ir)
10341042

10351043
def is_stop_initiated(self) -> bool:
@@ -1038,8 +1046,7 @@ def is_stop_initiated(self) -> bool:
10381046

10391047
def stop_initiated_reason(self) -> Optional[status_lib.Status]:
10401048
"""Returns status object if stop initiated, `None` otherwise."""
1041-
self._check_context()
1042-
custom_properties = self._execution.custom_properties
1049+
custom_properties = self.execution.custom_properties
10431050
if _get_metadata_value(custom_properties.get(_STOP_INITIATED)) == 1:
10441051
code = _get_metadata_value(custom_properties.get(_PIPELINE_STATUS_CODE))
10451052
if code is None:
@@ -1111,45 +1118,44 @@ def get_previous_node_states_dict(self) -> Dict[task_lib.NodeUid, NodeState]:
11111118

11121119
def get_pipeline_execution_state(self) -> metadata_store_pb2.Execution.State:
11131120
"""Returns state of underlying pipeline execution."""
1114-
self._check_context()
1115-
return self._execution.last_known_state
1121+
return self.execution.last_known_state
11161122

11171123
def set_pipeline_execution_state(
11181124
self, state: metadata_store_pb2.Execution.State) -> None:
11191125
"""Sets state of underlying pipeline execution."""
1120-
self._check_context()
1121-
1122-
if self._execution.last_known_state != state:
1126+
if self.execution.last_known_state != state:
11231127
self._on_commit_callbacks.append(
1124-
functools.partial(_log_pipeline_execution_state_change,
1125-
self._execution.last_known_state, state,
1126-
self.pipeline_uid))
1127-
self._execution.last_known_state = state
1128+
functools.partial(
1129+
_log_pipeline_execution_state_change,
1130+
self.execution.last_known_state,
1131+
state,
1132+
self.pipeline_uid,
1133+
)
1134+
)
1135+
self.execution.last_known_state = state
11281136

11291137
def get_property(self, property_key: str) -> Optional[types.Property]:
11301138
"""Returns custom property value from the pipeline execution."""
11311139
return _get_metadata_value(
1132-
self._execution.custom_properties.get(property_key))
1140+
self.execution.custom_properties.get(property_key)
1141+
)
11331142

11341143
def save_property(
11351144
self, property_key: str, property_value: types.Property
11361145
) -> None:
1137-
self._check_context()
11381146
data_types_utils.set_metadata_value(
1139-
self._execution.custom_properties[property_key], property_value
1147+
self.execution.custom_properties[property_key], property_value
11401148
)
11411149

11421150
def remove_property(self, property_key: str) -> None:
11431151
"""Removes a custom property of the pipeline execution if exists."""
1144-
self._check_context()
1145-
if self._execution.custom_properties.get(property_key):
1146-
del self._execution.custom_properties[property_key]
1152+
if self.execution.custom_properties.get(property_key):
1153+
del self.execution.custom_properties[property_key]
11471154

11481155
def pipeline_creation_time_secs_since_epoch(self) -> int:
11491156
"""Returns the pipeline creation time as seconds since epoch."""
1150-
self._check_context()
11511157
# Convert from milliseconds to seconds.
1152-
return self._execution.create_time_since_epoch // 1000
1158+
return self.execution.create_time_since_epoch // 1000
11531159

11541160
def get_orchestration_options(
11551161
self) -> orchestration_options.OrchestrationOptions:
@@ -1197,6 +1203,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
11971203
self._mlmd_execution_atomic_op_context = None
11981204
self._execution = None
11991205
try:
1206+
assert mlmd_execution_atomic_op_context is not None
12001207
mlmd_execution_atomic_op_context.__exit__(exc_type, exc_val, exc_tb)
12011208
finally:
12021209
self._on_commit_callbacks.clear()

tfx/orchestration/experimental/core/sync_pipeline_task_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def __call__(self) -> List[task_lib.Task]:
169169
):
170170
successful_node_ids.add(node_id)
171171
elif node_state.is_failure():
172+
assert node_state.status is not None
172173
failed_nodes_dict[node_id] = node_state.status
173174

174175
# Collect nodes that cannot be run because they have a failed ancestor.
@@ -545,6 +546,7 @@ def _generate_tasks_from_resolved_inputs(
545546
# executions. Idempotency is guaranteed by external_id.
546547
updated_external_artifacts = []
547548
for input_and_params in resolved_info.input_and_params:
549+
assert input_and_params.input_artifacts is not None
548550
for artifacts in input_and_params.input_artifacts.values():
549551
updated_external_artifacts.extend(
550552
task_gen_utils.update_external_artifact_type(

0 commit comments

Comments
 (0)