Skip to content

Commit 12051eb

Browse files
committed
no-up
PiperOrigin-RevId: 653377052
1 parent ef54551 commit 12051eb

File tree

2 files changed

+118
-7
lines changed

2 files changed

+118
-7
lines changed

tfx/components/statistics_gen/executor.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
_TELEMETRY_DESCRIPTORS = ['StatisticsGen']
3737
STATS_DASHBOARD_LINK = 'stats_dashboard_link'
38+
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME = 'sample_rate_by_split'
3839

3940

4041
class Executor(base_beam_executor.BaseBeamExecutor):
@@ -132,13 +133,6 @@ def Do(
132133

133134
split_names = [split for split in splits if split not in exclude_splits]
134135

135-
# Check if sample_rate_by_split contains invalid split names
136-
for split in sample_rate_by_split:
137-
if split not in split_names:
138-
logging.error(
139-
'Split %s provided in sample_rate_by_split is not valid.', split
140-
)
141-
142136
statistics_artifact = artifact_utils.get_single_instance(
143137
output_dict[standard_component_specs.STATISTICS_KEY]
144138
)
@@ -169,6 +163,24 @@ def Do(
169163
# json_utils
170164
stats_options = options.StatsOptions.from_json(stats_options_json)
171165

166+
sample_rate_by_split_property = {
167+
split: stats_options.sample_rate or 1.0 for split in split_names
168+
}
169+
for split in sample_rate_by_split:
170+
# Check if sample_rate_by_split contains invalid split names
171+
if split not in split_names:
172+
logging.error(
173+
'Split %s provided in sample_rate_by_split is not valid.', split
174+
)
175+
continue
176+
sample_rate_by_split_property[split] = sample_rate_by_split[split]
177+
178+
# Add sample_rate_by_split property to statistics artifact
179+
statistics_artifact.set_json_value_custom_property(
180+
SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME,
181+
json_utils.dumps(sample_rate_by_split_property),
182+
)
183+
172184
write_sharded_output = exec_properties.get(
173185
standard_component_specs.SHARDED_STATS_OUTPUT_KEY, False
174186
)

tfx/components/statistics_gen/executor_test.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ def testDo(
149149
artifact_utils.encode_split_names(['train', 'eval']), stats.split_names)
150150
self.assertEqual(
151151
stats.get_string_custom_property(executor.STATS_DASHBOARD_LINK), '')
152+
self.assertEqual(
153+
stats.has_custom_property(executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME),
154+
True,
155+
)
152156
self.assertEqual(stats.span, _TEST_SPAN_NUMBER)
153157

154158
# Check statistics_gen outputs.
@@ -228,6 +232,101 @@ def testDoWithSchemaAndStatsOptions(self):
228232
self._validate_stats_output(
229233
os.path.join(stats.uri, 'Split-eval', 'FeatureStats.pb'))
230234

235+
@parameterized.named_parameters(
    # Only the global sample_rate is set: both splits are expected to report it.
    dict(
        testcase_name='sample_rate_only',
        sample_rate=0.2,
        sample_rate_by_split='null',
        expected_sample_rate_by_split_property={'train': 0.2, 'eval': 0.2},
    ),
    # Every split has its own rate and no global rate is given.
    dict(
        testcase_name='sample_rate_by_split_only',
        sample_rate=None,
        sample_rate_by_split='{"train": 0.4, "eval": 0.6}',
        expected_sample_rate_by_split_property={'train': 0.4, 'eval': 0.6},
    ),
    # A split with no per-split rate and no global rate is expected to be 1.0.
    dict(
        testcase_name='sample_rate_for_some_split_only',
        sample_rate=None,
        sample_rate_by_split='{"train": 0.4}',
        expected_sample_rate_by_split_property={'train': 0.4, 'eval': 1.0},
    ),
    # A per-split rate is expected to win over the global rate for that split.
    dict(
        testcase_name='sample_rate_by_split_override',
        sample_rate=0.2,
        sample_rate_by_split='{"train": 0.4}',
        expected_sample_rate_by_split_property={'train': 0.4, 'eval': 0.2},
    ),
    # A rate for a split that is not in the examples ('test') is expected to
    # leave the recorded property untouched.
    dict(
        testcase_name='sample_rate_by_split_invalid',
        sample_rate=0.2,
        sample_rate_by_split='{"test": 0.4}',
        expected_sample_rate_by_split_property={'train': 0.2, 'eval': 0.2},
    ),
)
def testDoWithSamplingProperty(
    self,
    sample_rate,
    sample_rate_by_split,
    expected_sample_rate_by_split_property,
):
  """Runs the executor and checks the recorded sample_rate_by_split property.

  Args:
    sample_rate: Value passed as `sample_rate` to `tfdv.StatsOptions`; may be
      None.
    sample_rate_by_split: JSON string passed as the SAMPLE_RATE_BY_SPLIT_KEY
      exec property.
    expected_sample_rate_by_split_property: Dict expected to be read back from
      the statistics artifact's custom property after `Do` has run.
  """
  testdata_root = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata'
  )
  # The fallback temp dir is requested unconditionally, as in the env lookup.
  output_root = os.environ.get(
      'TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()
  )
  output_data_dir = os.path.join(output_root, self._testMethodName)
  fileio.makedirs(output_data_dir)

  # Input artifacts: examples with 'train' and 'eval' splits, plus a schema.
  examples_artifact = standard_artifacts.Examples()
  examples_artifact.uri = os.path.join(testdata_root, 'csv_example_gen')
  examples_artifact.split_names = artifact_utils.encode_split_names(
      ['train', 'eval']
  )

  schema_artifact = standard_artifacts.Schema()
  schema_artifact.uri = os.path.join(testdata_root, 'schema_gen')

  input_dict = {
      standard_component_specs.EXAMPLES_KEY: [examples_artifact],
      standard_component_specs.SCHEMA_KEY: [schema_artifact],
  }

  # Execution properties: the global rate travels inside the stats options
  # JSON, the per-split rates are passed as their own JSON string.
  stats_options_json = tfdv.StatsOptions(sample_rate=sample_rate).to_json()
  exec_properties = {
      standard_component_specs.STATS_OPTIONS_JSON_KEY: stats_options_json,
      standard_component_specs.EXCLUDE_SPLITS_KEY: json_utils.dumps([]),
      standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: sample_rate_by_split,
  }

  # Output artifact that the executor writes to and annotates.
  stats_artifact = standard_artifacts.ExampleStatistics()
  stats_artifact.uri = output_data_dir
  output_dict = {
      standard_component_specs.STATISTICS_KEY: [stats_artifact],
  }

  # Run executor.
  executor.Executor().Do(input_dict, output_dict, exec_properties)

  # Check statistics artifact sample_rate_by_split property.
  recorded_property = stats_artifact.get_json_value_custom_property(
      executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME
  )
  self.assertEqual(
      json_utils.loads(recorded_property),
      expected_sample_rate_by_split_property,
  )

  # Check statistics_gen outputs for both splits, train first.
  for split_dir in ('Split-train', 'Split-eval'):
    self._validate_stats_output(
        os.path.join(stats_artifact.uri, split_dir, 'FeatureStats.pb')
    )

329+
231330
def testDoWithTwoSchemas(self):
232331
source_data_dir = os.path.join(
233332
os.path.dirname(os.path.dirname(__file__)), 'testdata')

0 commit comments

Comments
 (0)