diff --git a/data/femnist/preprocess/group_by_writer.py b/data/femnist/preprocess/group_by_writer.py index 32b950a4..6180baa8 100644 --- a/data/femnist/preprocess/group_by_writer.py +++ b/data/femnist/preprocess/group_by_writer.py @@ -19,8 +19,10 @@ if w != cw: writers.append((cw, cimages)) cw = w - cimages = [(f, c)] - cimages.append((f, c)) + cimages = [] + cimages.append((f, c)) + else: + cimages.append((f, c)) writers.append((cw, cimages)) ibwd = os.path.join(parent_path, 'data', 'intermediate', 'images_by_writer') diff --git a/data/utils/split_data.py b/data/utils/split_data.py index 2ed04bc4..0b0c58ea 100644 --- a/data/utils/split_data.py +++ b/data/utils/split_data.py @@ -34,34 +34,35 @@ def create_jsons_for(user_files, which_set, max_users, include_hierarchy): num_samples.append(ns) user_data[u] = data['user_data'][u] user_count += 1 - - if (user_count == max_users) or (i == len(user_files) - 1): - - all_data = {} - all_data['users'] = users - all_data['num_samples'] = num_samples - all_data['user_data'] = user_data - - data_i = f.find('data') - num_i = data_i + 5 - num_to_end = f[num_i:] - param_i = num_to_end.find('_') - param_to_end = '.json' - if param_i != -1: - param_to_end = num_to_end[param_i:] - nf = '%s_%d%s' % (f[:(num_i-1)], json_index, param_to_end) - file_name = '%s_%s_%s.json' % ((nf[:-5]), which_set, arg_label) - ouf_dir = os.path.join(dir, which_set, file_name) - - print('writing %s' % file_name) - with open(ouf_dir, 'w') as outfile: - json.dump(all_data, outfile) - - user_count = 0 - json_index += 1 - users = [] - num_samples = [] - user_data = {} + + # ✅ 正确做法:每处理一个 user 都判断一次 + if (user_count == max_users) or (i == len(user_files) - 1): + + all_data = {} + all_data['users'] = users + all_data['num_samples'] = num_samples + all_data['user_data'] = user_data + + data_i = f.find('data') + num_i = data_i + 5 + num_to_end = f[num_i:] + param_i = num_to_end.find('_') + param_to_end = '.json' + if param_i != -1: + param_to_end = num_to_end[param_i:] + nf = '%s_%d%s' % (f[:(num_i-1)], json_index, param_to_end) + file_name = '%s_%s_%s.json' % ((nf[:-5]), which_set, arg_label) + ouf_dir = os.path.join(dir, which_set, file_name) + + print('writing %s' % file_name) + with open(ouf_dir, 'w') as outfile: + json.dump(all_data, outfile) + + user_count = 0 + json_index += 1 + users = [] + num_samples = [] + user_data = {} parser = argparse.ArgumentParser()