Skip to content

Commit c0af966

Browse files
kmontetfx-copybara
authored and committed
Update _PipelineIRCodec to use base dir encoded into pipeline IR
PiperOrigin-RevId: 654068608
1 parent e3ebdca commit c0af966

File tree

7 files changed

+104
-59
lines changed

7 files changed

+104
-59
lines changed

tfx/orchestration/experimental/core/pipeline_ops_test.py

Lines changed: 26 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -564,11 +564,12 @@ def test_revive_pipeline_run_active_pipeline_run_concurrent_runs_disabled(
564564

565565
def test_revive_pipeline_run_with_subpipelines(self):
566566
with self._mlmd_connection as m:
567-
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
567+
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
568+
temp_dir=self.create_tempdir().full_path
569+
)
568570
runtime_parameter_utils.substitute_runtime_parameter(
569571
pipeline,
570572
{
571-
constants.PIPELINE_ROOT_PARAMETER_NAME: '/path/to/root',
572573
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run0',
573574
},
574575
)
@@ -820,11 +821,12 @@ def test_initiate_pipeline_start_with_partial_run_and_subpipeline(
820821
self, mock_snapshot, run_subpipeline
821822
):
822823
with self._mlmd_connection as m:
823-
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
824+
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
825+
temp_dir=self.create_tempdir().full_path
826+
)
824827
runtime_parameter_utils.substitute_runtime_parameter(
825828
pipeline,
826829
{
827-
constants.PIPELINE_ROOT_PARAMETER_NAME: '/my/pipeline/root',
828830
constants.PIPELINE_RUN_ID_PARAMETER_NAME: 'run-0123',
829831
},
830832
)
@@ -1519,7 +1521,9 @@ def test_record_orchestration_time(self, pipeline, expected_run_id):
15191521
def test_record_orchestration_time_subpipeline(self):
15201522
with self._mlmd_cm as mlmd_connection_manager:
15211523
m = mlmd_connection_manager.primary_mlmd_handle
1522-
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline()
1524+
pipeline = test_sync_pipeline.create_pipeline_with_subpipeline(
1525+
temp_dir=self.create_tempdir().full_path
1526+
)
15231527
runtime_parameter_utils.substitute_runtime_parameter(
15241528
pipeline,
15251529
{
@@ -2653,22 +2657,25 @@ def test_executor_node_stop_then_start_flow(
26532657
self.assertEqual(pstate.NodeState.STARTED, node_state.state)
26542658

26552659
@parameterized.named_parameters(
2656-
dict(
2657-
testcase_name='async', pipeline=test_async_pipeline.create_pipeline()
2658-
),
2659-
dict(
2660-
testcase_name='sync',
2661-
pipeline=test_sync_pipeline.create_pipeline(),
2662-
),
2660+
dict(testcase_name='async', mode='async'),
2661+
dict(testcase_name='sync', mode='sync'),
26632662
)
26642663
@mock.patch.object(sync_pipeline_task_gen, 'SyncPipelineTaskGenerator')
26652664
@mock.patch.object(async_pipeline_task_gen, 'AsyncPipelineTaskGenerator')
26662665
def test_pure_service_node_stop_then_start_flow(
26672666
self,
26682667
mock_async_task_gen,
26692668
mock_sync_task_gen,
2670-
pipeline,
2669+
mode,
26712670
):
2671+
if mode == 'async':
2672+
pipeline = test_async_pipeline.create_pipeline(
2673+
temp_dir=self.create_tempdir().full_path
2674+
)
2675+
else:
2676+
pipeline = test_sync_pipeline.create_pipeline(
2677+
temp_dir=self.create_tempdir().full_path
2678+
)
26722679
runtime_parameter_utils.substitute_runtime_parameter(
26732680
pipeline,
26742681
{
@@ -2862,7 +2869,9 @@ def test_wait_for_predicate_timeout_secs_None(self, mock_sleep):
28622869
self.assertEqual(mock_sleep.call_count, 2)
28632870

28642871
def test_resume_manual_node(self):
2865-
pipeline = test_manual_node.create_pipeline()
2872+
pipeline = test_manual_node.create_pipeline(
2873+
temp_dir=self.create_tempdir().full_path
2874+
)
28662875
runtime_parameter_utils.substitute_runtime_parameter(
28672876
pipeline,
28682877
{
@@ -3516,7 +3525,9 @@ def health_status(self) -> status_lib.Status:
35163525
)
35173526

35183527
def test_delete_pipeline_run(self):
3519-
pipeline = test_sync_pipeline.create_pipeline()
3528+
pipeline = test_sync_pipeline.create_pipeline(
3529+
temp_dir=self.create_tempdir().full_path
3530+
)
35203531
runtime_parameter_utils.substitute_runtime_parameter(
35213532
pipeline,
35223533
{

tfx/orchestration/experimental/core/pipeline_state.py

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -424,29 +424,38 @@ def testonly_reset(cls) -> None:
424424
with cls._lock:
425425
cls._obj = None
426426

427-
def __init__(self):
428-
self.base_dir = env.get_env().get_base_dir()
429-
if self.base_dir:
430-
self.pipeline_irs_dir = os.path.join(self.base_dir,
431-
self._ORCHESTRATOR_METADATA_DIR,
432-
self._PIPELINE_IRS_DIR)
433-
fileio.makedirs(self.pipeline_irs_dir)
434-
else:
435-
self.pipeline_irs_dir = None
436-
437427
def encode(self, pipeline: pipeline_pb2.Pipeline) -> str:
438428
"""Encodes pipeline IR."""
439429
# Attempt to store as a base64 encoded string. If base_dir is provided
440430
# and the length is too large, store the IR on disk and retain the URL.
441431
# TODO(b/248786921): Always store pipeline IR to base_dir once the
442432
# accessibility issue is resolved.
433+
434+
# Note that this setup means that every *subpipeline* will have its own
435+
# "irs" dir. This is fine, though ideally we would put all pipeline IRs
436+
# under the root pipeline dir, which would require us to *also* store the
437+
# root pipeline dir in the IR.
438+
439+
base_dir = pipeline.runtime_spec.pipeline_root.field_value.string_value
440+
if base_dir:
441+
pipeline_ir_dir = os.path.join(
442+
base_dir, self._ORCHESTRATOR_METADATA_DIR, self._PIPELINE_IRS_DIR
443+
)
444+
fileio.makedirs(pipeline_ir_dir)
445+
else:
446+
pipeline_ir_dir = None
443447
pipeline_encoded = _base64_encode(pipeline)
444448
max_mlmd_str_value_len = env.get_env().max_mlmd_str_value_length()
445-
if self.base_dir and max_mlmd_str_value_len is not None and len(
446-
pipeline_encoded) > max_mlmd_str_value_len:
449+
if (
450+
base_dir
451+
and pipeline_ir_dir
452+
and max_mlmd_str_value_len is not None
453+
and len(pipeline_encoded) > max_mlmd_str_value_len
454+
):
447455
pipeline_id = task_lib.PipelineUid.from_pipeline(pipeline).pipeline_id
448-
pipeline_url = os.path.join(self.pipeline_irs_dir,
449-
f'{pipeline_id}_{uuid.uuid4()}.pb')
456+
pipeline_url = os.path.join(
457+
pipeline_ir_dir, f'{pipeline_id}_{uuid.uuid4()}.pb'
458+
)
450459
with fileio.open(pipeline_url, 'wb') as file:
451460
file.write(pipeline.SerializeToString())
452461
pipeline_encoded = json.dumps({self._PIPELINE_IR_URL_KEY: pipeline_url})

tfx/orchestration/experimental/core/pipeline_state_test.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ def _test_pipeline(
4949
param=1,
5050
pipeline_nodes: List[str] = None,
5151
pipeline_run_id: str = 'run0',
52+
pipeline_root: str = '',
5253
):
5354
pipeline = pipeline_pb2.Pipeline()
5455
pipeline.pipeline_info.id = pipeline_id
@@ -63,6 +64,7 @@ def _test_pipeline(
6364
pipeline.runtime_spec.pipeline_run_id.field_value.string_value = (
6465
pipeline_run_id
6566
)
67+
pipeline.runtime_spec.pipeline_root.field_value.string_value = pipeline_root
6668
return pipeline
6769

6870

@@ -202,7 +204,11 @@ def test_encode_decode_with_base_dir(self):
202204

203205
def test_encode_decode_exceeds_max_len(self):
204206
with TestEnv(self._pipeline_root, 0):
205-
pipeline = _test_pipeline('pipeline1', pipeline_nodes=['Trainer'])
207+
pipeline = _test_pipeline(
208+
'pipeline1',
209+
pipeline_nodes=['Trainer'],
210+
pipeline_root=self.create_tempdir().full_path,
211+
)
206212
pipeline_encoded = pstate._PipelineIRCodec.get().encode(pipeline)
207213
self.assertEqual(
208214
pipeline, pstate._PipelineIRCodec.get().decode(pipeline_encoded)

tfx/orchestration/experimental/core/sample_mlmd_creator.py

Lines changed: 23 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -52,8 +52,12 @@ def _get_mlmd_connection(path: str) -> metadata.Metadata:
5252
return metadata.Metadata(connection_config=connection_config)
5353

5454

55-
def _test_pipeline(ir_path: str, pipeline_id: str, run_id: str,
56-
deployment_config: Optional[message.Message]):
55+
def _test_pipeline(
56+
ir_path: str,
57+
pipeline_id: str,
58+
run_id: str,
59+
deployment_config: Optional[message.Message],
60+
):
5761
"""Creates test pipeline with pipeline_id and run_id."""
5862
pipeline = pipeline_pb2.Pipeline()
5963
io_utils.parse_pbtxt_file(ir_path, pipeline)
@@ -85,25 +89,30 @@ def _execute_nodes(handle: metadata.Metadata, pipeline: pipeline_pb2.Pipeline,
8589
)
8690

8791

88-
def _get_ir_path(external_ir_file: str):
92+
def _get_ir_path(external_ir_file: str, temp_dir: str = ''):
8993
if external_ir_file:
9094
return external_ir_file
9195
ir_file_path = tempfile.mktemp(suffix='.pbtxt')
92-
io_utils.write_pbtxt_file(ir_file_path, test_sync_pipeline.create_pipeline())
96+
io_utils.write_pbtxt_file(
97+
ir_file_path, test_sync_pipeline.create_pipeline(temp_dir=temp_dir)
98+
)
9399
return ir_file_path
94100

95101

96-
def create_sample_pipeline(m: metadata.Metadata,
97-
pipeline_id: str,
98-
run_num: int,
99-
export_ir_path: str = '',
100-
external_ir_file: str = '',
101-
deployment_config: Optional[message.Message] = None,
102-
execute_nodes_func: Callable[
103-
[metadata.Metadata, pipeline_pb2.Pipeline, int],
104-
None] = _execute_nodes):
102+
def create_sample_pipeline(
103+
m: metadata.Metadata,
104+
pipeline_id: str,
105+
run_num: int,
106+
export_ir_path: str = '',
107+
external_ir_file: str = '',
108+
deployment_config: Optional[message.Message] = None,
109+
execute_nodes_func: Callable[
110+
[metadata.Metadata, pipeline_pb2.Pipeline, int], None
111+
] = _execute_nodes,
112+
temp_dir: str = '',
113+
):
105114
"""Creates a list of pipeline and node execution."""
106-
ir_path = _get_ir_path(external_ir_file)
115+
ir_path = _get_ir_path(external_ir_file, temp_dir=temp_dir)
107116
for i in range(run_num):
108117
run_id = 'run%02d' % i
109118
pipeline = _test_pipeline(ir_path, pipeline_id, run_id, deployment_config)

tfx/orchestration/experimental/core/testing/test_async_pipeline.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Async pipeline for testing."""
15+
import os
1516

1617
from tfx.dsl.compiler import compiler
1718
from tfx.dsl.component.experimental.annotations import InputArtifact
@@ -51,7 +52,7 @@ def _trainer(examples: InputArtifact[standard_artifacts.Examples],
5152
del examples, transform_graph, model
5253

5354

54-
def create_pipeline() -> pipeline_pb2.Pipeline:
55+
def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline:
5556
"""Creates an async pipeline for testing."""
5657
# pylint: disable=no-value-for-parameter
5758
example_gen = _example_gen().with_id('my_example_gen')
@@ -68,13 +69,14 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
6869

6970
pipeline = pipeline_lib.Pipeline(
7071
pipeline_name='my_pipeline',
71-
pipeline_root='/path/to/root',
72+
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
7273
components=[
7374
example_gen,
7475
transform,
7576
trainer,
7677
],
77-
execution_mode=pipeline_lib.ExecutionMode.ASYNC)
78+
execution_mode=pipeline_lib.ExecutionMode.ASYNC,
79+
)
7880
dsl_compiler = compiler.Compiler(use_input_v2=True)
7981
compiled_pipeline: pipeline_pb2.Pipeline = dsl_compiler.compile(pipeline)
8082

tfx/orchestration/experimental/core/testing/test_manual_node.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,23 +12,23 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Test pipeline with only manual node."""
15+
import os
1516

1617
from tfx.dsl.compiler import compiler
1718
from tfx.dsl.components.common import manual_node
1819
from tfx.orchestration import pipeline as pipeline_lib
1920
from tfx.proto.orchestration import pipeline_pb2
2021

2122

22-
def create_pipeline() -> pipeline_pb2.Pipeline:
23+
def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline:
2324
"""Builds a test pipeline with only manual node."""
2425
manual = manual_node.ManualNode(description='Do something.')
2526

2627
pipeline = pipeline_lib.Pipeline(
2728
pipeline_name='my_pipeline',
28-
pipeline_root='/path/to/root',
29-
components=[
30-
manual
31-
],
32-
enable_cache=True)
29+
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
30+
components=[manual],
31+
enable_cache=True,
32+
)
3333
dsl_compiler = compiler.Compiler()
3434
return dsl_compiler.compile(pipeline)

tfx/orchestration/experimental/core/testing/test_sync_pipeline.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Sync pipeline for testing."""
15+
import os
1516

1617
from tfx.dsl.compiler import compiler
1718
from tfx.dsl.component.experimental.annotations import InputArtifact
@@ -82,7 +83,7 @@ def _chore():
8283
pass
8384

8485

85-
def create_pipeline() -> pipeline_pb2.Pipeline:
86+
def create_pipeline(temp_dir: str = '/') -> pipeline_pb2.Pipeline:
8687
"""Builds a test pipeline.
8788
8889
┌───────────┐
@@ -107,6 +108,10 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
107108
│chore_b │
108109
└────────┘
109110
111+
Args:
112+
temp_dir: If provided, a temporary test directory to use as prefix to the
113+
pipeline root.
114+
110115
Returns:
111116
A pipeline proto for the above DAG
112117
"""
@@ -142,7 +147,7 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
142147

143148
pipeline = pipeline_lib.Pipeline(
144149
pipeline_name='my_pipeline',
145-
pipeline_root='/path/to/root',
150+
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
146151
components=[
147152
example_gen,
148153
stats_gen,
@@ -154,7 +159,8 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
154159
chore_a,
155160
chore_b,
156161
],
157-
enable_cache=True)
162+
enable_cache=True,
163+
)
158164
dsl_compiler = compiler.Compiler()
159165
return dsl_compiler.compile(pipeline)
160166

@@ -300,7 +306,9 @@ def create_resource_lifetime_pipeline() -> pipeline_pb2.Pipeline:
300306
return dsl_compiler.compile(pipeline)
301307

302308

303-
def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline:
309+
def create_pipeline_with_subpipeline(
310+
temp_dir: str = '/',
311+
) -> pipeline_pb2.Pipeline:
304312
"""Creates a pipeline with a subpipeline."""
305313
# pylint: disable=no-value-for-parameter
306314
example_gen = _example_gen().with_id('my_example_gen')
@@ -318,7 +326,7 @@ def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline:
318326

319327
componsable_pipeline = pipeline_lib.Pipeline(
320328
pipeline_name='sub-pipeline',
321-
pipeline_root='/path/to/root/sub',
329+
pipeline_root=os.path.join(temp_dir, 'path/to/root/sub'),
322330
components=[stats_gen, schema_gen],
323331
enable_cache=True,
324332
inputs=p_in,
@@ -332,7 +340,7 @@ def create_pipeline_with_subpipeline() -> pipeline_pb2.Pipeline:
332340

333341
pipeline = pipeline_lib.Pipeline(
334342
pipeline_name='my_pipeline',
335-
pipeline_root='/path/to/root',
343+
pipeline_root=os.path.join(temp_dir, 'path/to/root'),
336344
components=[
337345
example_gen,
338346
componsable_pipeline,

0 commit comments

Comments (0)