Skip to content

Commit 86df648

Browse files
committed
Apply the ph.make_proto optimization to execution parameters with use_proto=True.
PiperOrigin-RevId: 653755777
1 parent 12051eb commit 86df648

File tree

5 files changed

+126
-10
lines changed

5 files changed

+126
-10
lines changed

tfx/orchestration/experimental/core/task_gen_utils_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,21 @@ def test_generate_resolved_info_with_dynamic_exec_prop(self):
473473
resolved_info.input_and_params[0].exec_properties['input_str'],
474474
)
475475

476+
def test_generate_resolved_info_with_ph_exec_parameter(self):
477+
otu.fake_example_gen_run(self._mlmd_connection, self._example_gen, 2, 1)
478+
otu.fake_component_output(self._mlmd_connection, self._transform)
479+
resolved_info = task_gen_utils.generate_resolved_info(
480+
self._mlmd_connection_manager,
481+
node_proto_view.get_view(self._trainer),
482+
self._pipeline,
483+
)
484+
self.assertProtoEquals(
485+
"""
486+
splits: "train"
487+
""",
488+
resolved_info.input_and_params[0].exec_properties['train_args'],
489+
)
490+
476491
@parameterized.named_parameters(
477492
dict(
478493
testcase_name='per_execution_idx_latest',

tfx/orchestration/experimental/core/testing/test_async_pipeline.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
from tfx.dsl.component.experimental.decorators import component
2121
from tfx.dsl.control_flow import for_each
2222
from tfx.dsl.input_resolution.canned_resolver_functions import latest_created
23+
from tfx.dsl.placeholder import placeholder as ph
2324
from tfx.orchestration import pipeline as pipeline_lib
25+
from tfx.proto import trainer_pb2
2426
from tfx.proto.orchestration import pipeline_pb2
2527
from tfx.types import standard_artifacts
2628

@@ -82,5 +84,12 @@ def create_pipeline() -> pipeline_pb2.Pipeline:
8284
assert trainer.node_info.id == 'my_trainer'
8385
for value in trainer.inputs.inputs.values():
8486
value.min_count = 1
87+
train_args_proto = trainer_pb2.TrainArgs(splits=['train'])
88+
train_args = ph.make_proto(train_args_proto)
89+
trainer.parameters.parameters['train_args'].CopyFrom(
90+
pipeline_pb2.Value(
91+
placeholder=train_args.encode()
92+
)
93+
)
8594

8695
return compiled_pipeline

tfx/orchestration/portable/inputs_utils_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,69 @@ def test_resolve_dynamic_parameters(self):
385385
dynamic_parameters, placeholder_utils.ResolutionContext()
386386
)
387387

388+
def test_resolve_ph_execution_parameters(self):
389+
execution_parameters = pipeline_pb2.NodeParameters()
390+
text_format.Parse(
391+
r"""
392+
parameters: {
393+
key: "train_args"
394+
value: {
395+
placeholder: {
396+
operator: {
397+
proto_op: {
398+
expression: {
399+
operator: {
400+
make_proto_op: {
401+
base: {
402+
type_url: "type.googleapis.com/tensorflow.service.TrainArgs"
403+
value: "\n\005train"
404+
}
405+
file_descriptors: {
406+
file: {
407+
name: "third_party/tfx/trainer.proto"
408+
package: "tensorflow.service"
409+
message_type {
410+
name: "TrainArgs"
411+
field {
412+
name: "splits"
413+
number: 1
414+
label: LABEL_REPEATED
415+
type: TYPE_STRING
416+
}
417+
}
418+
syntax: "proto3"
419+
}
420+
}
421+
}
422+
}
423+
}
424+
}
425+
}
426+
}
427+
}
428+
}
429+
""",
430+
execution_parameters,
431+
)
432+
test_artifact = types.standard_artifacts.String()
433+
test_artifact.uri = self.create_tempfile().full_path
434+
test_artifact.value = 'testvalue'
435+
input_dict = {'_test_placeholder': [test_artifact]}
436+
exec_params_resolved = inputs_utils.resolve_dynamic_parameters(
437+
execution_parameters,
438+
placeholder_utils.ResolutionContext(
439+
exec_info=data_types.ExecutionInfo(
440+
input_dict=input_dict, pipeline_run_id='testrunid'
441+
)
442+
),
443+
)
444+
self.assertProtoEquals(
445+
"""
446+
splits: "train"
447+
""",
448+
exec_params_resolved['train_args'],
449+
)
450+
388451

389452
if __name__ == '__main__':
390453
tf.test.main()

tfx/types/component_spec.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import copy
1717
import inspect
1818
import itertools
19-
from typing import Any, Dict, List, Mapping, Optional, Type, cast
19+
from typing import Any, cast, Dict, List, Mapping, Optional, Type
2020

2121
from tfx.dsl.component.experimental.json_compat import check_strict_json_compat
2222
from tfx.dsl.placeholder import placeholder
@@ -31,6 +31,21 @@
3131
# Use Any to avoid cyclic import.
3232
_BaseNode = Any
3333

34+
# Execution parameters that have `use_proto=True` but cannot be optimized with
35+
# Placeholder ph.make_proto.
36+
# TODO(b/350820714): Placeholder needs to be supported at runtime so that
37+
# TensorflowTrainerConfig placeholder can be used to create the Trainer and
38+
# Tuner jobs.
39+
# TODO(b/349459258): ExampleDiff executor needs to be updated to support
40+
# placeholder proto fields not being present.
41+
# TODO(b/352623284); DistributionValidator test needs to be updated to
42+
# support placeholder proto.
43+
_MAKE_PROTO_EXEMPT_EXEC_PARAMETERS = [
44+
'tensorflow_trainer',
45+
'example_diff_config',
46+
'distribution_validator_config',
47+
]
48+
3449

3550
def _is_runtime_param(data: Any) -> bool:
3651
return data.__class__.__name__ == 'RuntimeParameter'
@@ -229,11 +244,16 @@ def _parse_parameters(self, raw_args: Mapping[str, Any]):
229244
if (inspect.isclass(arg.type) and issubclass(arg.type, message.Message) # pytype: disable=not-supported-yet
230245
and value and not _is_runtime_param(value)) and not isinstance(
231246
value, placeholder.Placeholder):
247+
# If the parameter is defined with use_proto=True, convert the value to
248+
# proto from dict or json string if necessary before creating the proto
249+
# placeholder.
232250
if arg.use_proto:
233251
if isinstance(value, dict):
234252
value = proto_utils.dict_to_proto(value, arg.type())
235253
elif isinstance(value, str):
236254
value = proto_utils.json_to_proto(value, arg.type())
255+
if arg_name not in _MAKE_PROTO_EXEMPT_EXEC_PARAMETERS:
256+
value = placeholder.make_proto(value)
237257
else:
238258
# Create deterministic json string as it will be stored in metadata
239259
# for cache check.

tfx/types/component_spec_test.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
import unittest
2020

2121
import tensorflow as tf
22+
from tfx.dsl.compiler import placeholder_utils
2223
from tfx.dsl.components.base.testing import test_node
2324
from tfx.dsl.placeholder import placeholder
25+
from tfx.orchestration.portable import data_types
2426
from tfx.proto import example_gen_pb2
2527
from tfx.types import artifact
2628
from tfx.types import channel
@@ -32,7 +34,6 @@
3234
from tfx.utils import proto_utils
3335

3436
from google.protobuf import json_format
35-
from google.protobuf import text_format
3637

3738

3839
class _InputArtifact(artifact.Artifact):
@@ -432,15 +433,23 @@ class SpecWithNonPrimitiveTypes(ComponentSpec):
432433
input=channel.Channel(type=_InputArtifact),
433434
output=channel.Channel(type=_OutputArtifact))
434435

435-
# Verify exec_properties store parsed value when use_proto set to True.
436-
expected_proto = text_format.Parse(
436+
# Verify exec_properties stores the correct placeholder when use_proto set
437+
# to True.
438+
resolved_proto = placeholder_utils.resolve_placeholder_expression(
439+
spec.exec_properties['config_proto'].encode(),
440+
placeholder_utils.ResolutionContext(
441+
exec_info=data_types.ExecutionInfo()
442+
)
443+
)
444+
self.assertProtoEquals(
437445
"""
438-
splits {
439-
name: "name"
440-
pattern: "pattern"
441-
}
442-
""", example_gen_pb2.Input())
443-
self.assertProtoEquals(expected_proto, spec.exec_properties['config_proto'])
446+
splits {
447+
name: "name"
448+
pattern: "pattern"
449+
}
450+
""",
451+
resolved_proto
452+
)
444453
self.assertEqual(True, spec.exec_properties['boolean'])
445454
self.assertIsInstance(spec.exec_properties['list_config_proto'], list)
446455
self.assertEqual(spec.exec_properties['list_boolean'], [False, True])

0 commit comments

Comments
 (0)