Skip to content

Commit 4d72a90

Browse files
Merge pull request #17 from otiliastr:improve_cnn
PiperOrigin-RevId: 271188450
2 parents 7bc62bb + 248c1a3 commit 4d72a90

File tree

9 files changed

+339
-227
lines changed

9 files changed

+339
-227
lines changed

neural_structured_learning/keras/layers/layers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
_ERR_TOL = 3e-5 # Tolerance when comparing floats.
3131

3232

33-
# TODO(pp): Update models to use NeighborFeatures
33+
# TODO(ppham27): Update models to use NeighborFeatures
3434
def _make_functional_regularized_model(distance_config):
3535
"""Makes a model with `PairwiseDistance` and the functional API."""
3636

neural_structured_learning/research/gam/data/dataset.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,9 @@ def num_train(self):
145145
def num_val(self):
146146
return self.indices_val.shape[0]
147147

148-
@property
149148
def num_test(self):
150149
return self.indices_test.shape[0]
151150

152-
@property
153151
def num_unlabeled(self):
154152
return self.indices_unlabeled.shape[0]
155153

neural_structured_learning/research/gam/experiments/run_train_mnist.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@
4949
'data_source', 'tensorflow_datasets', 'Data source. Valid options are: '
5050
'`tensorflow_datasets`, `realistic_ssl`')
5151
flags.DEFINE_integer(
52-
'target_num_train_per_class', 20,
52+
'target_num_train_per_class', 400,
5353
'Number of samples per class to use for training.')
5454
flags.DEFINE_integer(
55-
'target_num_val', 10000,
55+
'target_num_val', 1000,
5656
'Number of samples to be used for validation.')
5757
flags.DEFINE_integer(
58-
'seed', 1234,
58+
'seed', 123,
5959
'Seed used by the random number generators.')
6060
flags.DEFINE_bool(
6161
'load_preprocessed', False,
@@ -142,7 +142,7 @@
142142
'Minimum number of iterations to train the agreement model for after '
143143
'the best validation accuracy is improved.')
144144
flags.DEFINE_integer(
145-
'num_samples_to_label', 200,
145+
'num_samples_to_label', 500,
146146
'Number of samples to label after each co-train iteration.')
147147
flags.DEFINE_float(
148148
'min_confidence_new_label', 0.4,
@@ -156,7 +156,7 @@
156156
'Minimum number of co-train iterations the agreement must be trained '
157157
'before it is used in the classifier.')
158158
flags.DEFINE_float(
159-
'ratio_valid_agr', 0.2,
159+
'ratio_valid_agr', 0.1,
160160
'Ratio of edges used for validating the agreement model.')
161161
flags.DEFINE_integer(
162162
'max_samples_valid_agr', 10000,
@@ -190,9 +190,9 @@
190190
'Schedule for decaying the weight decay in the agreement model. Choose '
191191
'between None or linear.')
192192
flags.DEFINE_integer(
193-
'batch_size_agr', 32, 'Batch size for agreement model.')
193+
'batch_size_agr', 512, 'Batch size for agreement model.')
194194
flags.DEFINE_integer(
195-
'batch_size_cls', 32, 'Batch size for classification model.')
195+
'batch_size_cls', 512, 'Batch size for classification model.')
196196
flags.DEFINE_float(
197197
'gradient_clip', None,
198198
'The gradient clipping global norm value. If None, no clipping is done.')
@@ -240,7 +240,7 @@
240240
'reg_weight_uu', 0.05,
241241
'Regularization weight for unlabeled-unlabeled edges.')
242242
flags.DEFINE_integer(
243-
'num_pairs_reg', 512,
243+
'num_pairs_reg', 128,
244244
'Number of pairs of nodes to use in the agreement loss term of the '
245245
'classification model.')
246246
flags.DEFINE_string(
@@ -252,14 +252,14 @@
252252
'penalize_neg_agr', True,
253253
'Whether to encourage differences when agreement is negative.')
254254
flags.DEFINE_bool(
255-
'use_l2_cls', True,
255+
'use_l2_cls', False,
256256
'Whether to use L2 loss for the classifier, not cross entropy.')
257257
flags.DEFINE_bool(
258258
'first_iter_original', True,
259259
'Whether to use the original model in the first iteration, without self '
260260
'labeling or agreement loss.')
261261
flags.DEFINE_bool(
262-
'inductive', False,
262+
'inductive', True,
263263
'Whether to use an inductive or transductive SSL setting.')
264264
flags.DEFINE_string(
265265
'experiment_suffix', '',
@@ -277,6 +277,11 @@
277277
flags.DEFINE_string(
278278
'optimizer', 'adam',
279279
'Which optimizer to use. Valid options are `adam`, `amsgrad`.')
280+
flags.DEFINE_bool(
281+
'load_from_checkpoint', False,
282+
'Whether to load the trained model and the data that has been self-labeled '
283+
'from a previous run, if available. This is useful if a process can get '
284+
'preempted or interrupted.')
280285

281286

282287
def parse_layers_string(layers_string):
@@ -306,11 +311,12 @@ def pick_model(data):
306311
"""Picks the models depending on the provided configuration flags."""
307312
# Create model classification.
308313
if FLAGS.model_cls == 'mlp':
309-
hidden_classif = (parse_layers_string(FLAGS.hidden_cls)
310-
if FLAGS.hidden_cls is not None else [])
314+
hidden_cls = (
315+
parse_layers_string(FLAGS.hidden_cls)
316+
if FLAGS.hidden_cls is not None else [])
311317
model_cls = MLP(
312318
output_dim=data.num_classes,
313-
hidden_sizes=hidden_classif,
319+
hidden_sizes=hidden_cls,
314320
activation=tf.nn.leaky_relu,
315321
name='mlp_cls')
316322
elif FLAGS.model_cls == 'cnn':
@@ -417,22 +423,25 @@ def main(argv):
417423
logging.info('Preprocessed data saved to %s.', path)
418424

419425
# Put together parameters to create a model name.
420-
model_name = FLAGS.model_cls + (('_' + FLAGS.hidden_cls)
421-
if FLAGS.model_cls == 'mlp' else '')
422-
model_name += '-' + FLAGS.model_agr + (('_' + FLAGS.hidden_agr)
423-
if FLAGS.model_agr == 'mlp' else '')
424-
model_name += ('-aggr_' + FLAGS.aggregation_agr_inputs + '_' +
425-
FLAGS.hidden_aggreg)
426+
model_name = FLAGS.model_cls
427+
model_name += ('_' + FLAGS.hidden_cls) if FLAGS.model_cls == 'mlp' else ''
428+
model_name += '-' + FLAGS.model_agr
429+
model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
430+
model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
431+
model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
426432
model_name += ('-add_%d-conf_%.2f-iter_cls_%d-iter_agr_%d-batch_cls_%d' %
427433
(FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
428434
FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr,
429435
FLAGS.batch_size_cls))
430-
model_name += '-perfectAgr' if FLAGS.use_perfect_agreement else ''
431-
model_name += '-perfectCls' if FLAGS.use_perfect_classifier else ''
436+
model_name += '-LL_%s_LU_%s_UU_%s' % (str(
437+
FLAGS.reg_weight_ll), str(FLAGS.reg_weight_lu), str(FLAGS.reg_weight_uu))
438+
model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
439+
model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
432440
model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
433441
model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
434-
model_name += '-inductive' if FLAGS.inductive else ''
435-
model_name += '-L2Loss' if FLAGS.use_l2_cls else '-CELoss'
442+
model_name += '-transduct' if not FLAGS.inductive else ''
443+
model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
444+
model_name += '-seed_' + str(FLAGS.seed)
436445
model_name += FLAGS.experiment_suffix
437446
logging.info('Model name: %s', model_name)
438447

@@ -451,7 +460,6 @@ def main(argv):
451460
model_cls, model_agr = pick_model(data)
452461

453462
# Train.
454-
optimizer = tf.train.AdamOptimizer
455463
trainer = TrainerCotraining(
456464
model_cls=model_cls,
457465
model_agr=model_agr,
@@ -466,7 +474,7 @@ def main(argv):
466474
min_confidence_new_label=FLAGS.min_confidence_new_label,
467475
keep_label_proportions=FLAGS.keep_label_proportions,
468476
num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
469-
optimizer=optimizer,
477+
optimizer=tf.train.AdamOptimizer,
470478
gradient_clip=FLAGS.gradient_clip,
471479
batch_size_agr=FLAGS.batch_size_agr,
472480
batch_size_cls=FLAGS.batch_size_cls,
@@ -511,7 +519,8 @@ def main(argv):
511519
lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
512520
lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
513521
lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
514-
lr_decay_steps_agr=FLAGS.lr_decay_steps_agr)
522+
lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
523+
load_from_checkpoint=FLAGS.load_from_checkpoint)
515524

516525
trainer.train(data)
517526

0 commit comments

Comments
 (0)