@@ -149,6 +149,10 @@ def testDo(
149149 artifact_utils .encode_split_names (['train' , 'eval' ]), stats .split_names )
150150 self .assertEqual (
151151 stats .get_string_custom_property (executor .STATS_DASHBOARD_LINK ), '' )
152+ self .assertEqual (
153+ stats .has_custom_property (executor .SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME ),
154+ True ,
155+ )
152156 self .assertEqual (stats .span , _TEST_SPAN_NUMBER )
153157
154158 # Check statistics_gen outputs.
@@ -228,6 +232,101 @@ def testDoWithSchemaAndStatsOptions(self):
228232 self ._validate_stats_output (
229233 os .path .join (stats .uri , 'Split-eval' , 'FeatureStats.pb' ))
230234
@parameterized.named_parameters(
    {
        'testcase_name': 'sample_rate_only',
        'sample_rate': 0.2,
        'sample_rate_by_split': 'null',
        'expected_sample_rate_by_split_property': {
            'train': 0.2,
            'eval': 0.2,
        },
    },
    {
        'testcase_name': 'sample_rate_by_split_only',
        'sample_rate': None,
        'sample_rate_by_split': '{"train": 0.4, "eval": 0.6}',
        'expected_sample_rate_by_split_property': {
            'train': 0.4,
            'eval': 0.6,
        },
    },
    {
        'testcase_name': 'sample_rate_for_some_split_only',
        'sample_rate': None,
        'sample_rate_by_split': '{"train": 0.4}',
        'expected_sample_rate_by_split_property': {
            'train': 0.4,
            'eval': 1.0,
        },
    },
    {
        'testcase_name': 'sample_rate_by_split_override',
        'sample_rate': 0.2,
        'sample_rate_by_split': '{"train": 0.4}',
        'expected_sample_rate_by_split_property': {
            'train': 0.4,
            'eval': 0.2,
        },
    },
    {
        'testcase_name': 'sample_rate_by_split_invalid',
        'sample_rate': 0.2,
        'sample_rate_by_split': '{"test": 0.4}',
        'expected_sample_rate_by_split_property': {
            'train': 0.2,
            'eval': 0.2,
        },
    },
)
def testDoWithSamplingProperty(
    self,
    sample_rate,
    sample_rate_by_split,
    expected_sample_rate_by_split_property
):
    """Runs the executor and checks the sample-rate-by-split property.

    Each parameterized case passes a `sample_rate` through the stats
    options JSON and a `sample_rate_by_split` exec property, then asserts
    that the custom property stored on the statistics artifact decodes to
    `expected_sample_rate_by_split_property`. The per-split statistics
    outputs for 'train' and 'eval' are validated as well.

    Args:
      sample_rate: Value passed to `tfdv.StatsOptions(sample_rate=...)`.
      sample_rate_by_split: Value passed as the SAMPLE_RATE_BY_SPLIT_KEY
        exec property (a JSON string in the cases above).
      expected_sample_rate_by_split_property: Dict the decoded artifact
        custom property is compared against.
    """
    # Locate the checked-in test data and a per-test output directory.
    testdata_root = os.path.join(
        os.path.dirname(os.path.dirname(__file__)), 'testdata'
    )
    outputs_root = os.environ.get(
        'TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()
    )
    stats_output_dir = os.path.join(outputs_root, self._testMethodName)
    fileio.makedirs(stats_output_dir)

    # Input artifacts: examples with two splits, plus a schema.
    examples_artifact = standard_artifacts.Examples()
    examples_artifact.uri = os.path.join(testdata_root, 'csv_example_gen')
    examples_artifact.split_names = artifact_utils.encode_split_names(
        ['train', 'eval']
    )

    schema_artifact = standard_artifacts.Schema()
    schema_artifact.uri = os.path.join(testdata_root, 'schema_gen')

    executor_inputs = {
        standard_component_specs.EXAMPLES_KEY: [examples_artifact],
        standard_component_specs.SCHEMA_KEY: [schema_artifact],
    }

    # Execution properties carrying both sampling settings under test.
    stats_options_json = tfdv.StatsOptions(sample_rate=sample_rate).to_json()
    executor_properties = {
        standard_component_specs.STATS_OPTIONS_JSON_KEY: stats_options_json,
        standard_component_specs.EXCLUDE_SPLITS_KEY: json_utils.dumps([]),
        standard_component_specs.SAMPLE_RATE_BY_SPLIT_KEY: (
            sample_rate_by_split
        ),
    }

    # Output artifact that the executor populates.
    statistics_artifact = standard_artifacts.ExampleStatistics()
    statistics_artifact.uri = stats_output_dir
    executor_outputs = {
        standard_component_specs.STATISTICS_KEY: [statistics_artifact],
    }

    # Run executor.
    executor.Executor().Do(
        executor_inputs, executor_outputs, executor_properties
    )

    # Check statistics artifact sample_rate_by_split property.
    recorded_property = statistics_artifact.get_json_value_custom_property(
        executor.SAMPLE_RATE_BY_SPLIT_PROPERTY_NAME
    )
    self.assertEqual(
        json_utils.loads(recorded_property),
        expected_sample_rate_by_split_property,
    )

    # Check statistics_gen outputs, train split first and then eval.
    for split_dir in ('Split-train', 'Split-eval'):
        self._validate_stats_output(
            os.path.join(statistics_artifact.uri, split_dir, 'FeatureStats.pb')
        )

329+
231330 def testDoWithTwoSchemas (self ):
232331 source_data_dir = os .path .join (
233332 os .path .dirname (os .path .dirname (__file__ )), 'testdata' )
0 commit comments