@@ -69,14 +69,13 @@ class TrainerClassification(Trainer):
     summary_step: Integer representing the summary step size.
     summary_dir: String representing the path to a directory where to save the
       variable summaries.
-    logging_step: Integer representing the number of iterations after which
-      we log the loss of the model.
+    logging_step: Integer representing the number of iterations after which we
+      log the loss of the model.
     eval_step: Integer representing the number of iterations after which we
       evaluate the model.
-    warm_start: Whether the model parameters are initialized at their
-      best value in the previous cotrain iteration. If False, they are
-      reinitialized.
-    gradient_clip=None,
+    warm_start: Whether the model parameters are initialized at their best value
+      in the previous cotrain iteration. If False, they are reinitialized.
+    gradient_clip: Float norm used to clip gradients, or None for no clipping.
     abs_loss_chg_tol: A float representing the absolute tolerance for checking
       if the training loss has converged. If the difference between the current
       loss and previous loss is less than `abs_loss_chg_tol`, we count this
@@ -89,19 +88,19 @@ class TrainerClassification(Trainer):
       iterations that pass the convergence criteria before stopping training.
     checkpoints_dir: Path to the folder where to store TensorFlow model
       checkpoints.
-    weight_decay: Weight for the weight decay term in the classification
-      model loss.
+    weight_decay: Weight for the weight decay term in the classification model
+      loss.
     weight_decay_schedule: Schedule for adjusting the classification weight
       decay after every cotrain iteration.
     penalize_neg_agr: Whether to not only encourage agreement between samples
       that the agreement model believes should have the same label, but also
       penalize agreement when two samples agree even though the agreement model
       predicts they should disagree.
-    use_l2_clssif: Whether to use L2 loss for classification, as opposed to the
-      whichever loss is specified in the provided model_cls.
     first_iter_original: A boolean specifying whether the first cotrain
       iteration trains the original classification model (with no agreement
       term).
+    use_l2_clssif: Whether to use L2 loss for classification, as opposed to
+      whichever loss is specified in the provided model_cls.
     seed: Seed used by all the random number generators in this class.
     use_graph: Boolean specifying whether the agreement loss is applied to graph
       edges, as opposed to random pairs of samples.
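
Taken together, these attributes configure a loss of the form: supervised
classification loss, plus the agreement regularizer, plus a weight decay
penalty. A minimal sketch of that composition with made-up stand-in tensors
(the real trainer's exact weighting may differ):

    import tensorflow as tf

    # Hypothetical stand-ins for tensors built later in this diff.
    loss_supervised = tf.constant(0.9)   # supervised cross-entropy / L2 loss
    loss_agr = tf.constant(0.2)          # agreement regularization loss
    reg_params = {'w': tf.ones([3, 3])}  # parameters subject to weight decay
    weight_decay_var = tf.constant(1e-4)

    # Schematic total: supervised + agreement + weight decay.
    loss_weight_decay = weight_decay_var * tf.add_n(
        [tf.nn.l2_loss(v) for v in reg_params.values()])
    loss_op = loss_supervised + loss_agr + loss_weight_decay
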
@@ -162,8 +161,9 @@ def __init__(self,
     self.gradient_clip = gradient_clip
     self.logging_step = logging_step
     self.eval_step = eval_step
-    self.checkpoint_path = (os.path.join(checkpoints_dir, 'classif_best.ckpt')
-                            if checkpoints_dir is not None else None)
+    self.checkpoint_path = (
+        os.path.join(checkpoints_dir, 'classif_best.ckpt')
+        if checkpoints_dir is not None else None)
     self.weight_decay_initial = weight_decay
     self.weight_decay_schedule = weight_decay_schedule
     self.num_pairs_reg = num_pairs_reg
@@ -186,11 +186,11 @@ def __init__(self,
     # First obtain the features shape from the dataset, and append a batch_size
     # dimension to it (i.e., `None` to allow for variable batch size).
     features_shape = [None] + list(data.features_shape)
-    input_features = tf.placeholder(tf.float32, shape=features_shape,
-                                    name='input_features')
+    input_features = tf.placeholder(
+        tf.float32, shape=features_shape, name='input_features')
     input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')
-    one_hot_labels = tf.one_hot(input_labels, data.num_classes,
-                                name='input_labels_one_hot')
+    one_hot_labels = tf.one_hot(
+        input_labels, data.num_classes, name='input_labels_one_hot')
     # Create a placeholder specifying if this is train time.
     is_train = tf.placeholder_with_default(False, shape=[], name='is_train')
 
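
In TF1-style graphs like the one above, these placeholders are bound at run
time through a feed_dict, and is_train falls back to False when it is not fed.
A self-contained toy version (the feature and class counts are made up):

    import numpy as np
    import tensorflow as tf

    features_shape = [None, 4]  # variable batch size, 4 features
    input_features = tf.placeholder(
        tf.float32, shape=features_shape, name='input_features')
    input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')
    one_hot_labels = tf.one_hot(input_labels, 3, name='input_labels_one_hot')
    is_train = tf.placeholder_with_default(False, shape=[], name='is_train')

    with tf.Session() as sess:
      # `is_train` is not fed, so it takes its default value, False.
      print(sess.run(one_hot_labels,
                     feed_dict={input_features: np.zeros((2, 4), np.float32),
                                input_labels: [0, 2]}))
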
@@ -201,8 +201,8 @@ def __init__(self,
     self.variables = variables
     self.reg_params = reg_params
     predictions, variables, reg_params = (
-        self.model.get_predictions_and_params(encoding=encoding,
-                                              is_train=is_train))
+        self.model.get_predictions_and_params(
+            encoding=encoding, is_train=is_train))
     self.variables.update(variables)
     self.reg_params.update(reg_params)
     normalized_predictions = self.model.normalize_predictions(predictions)
@@ -221,9 +221,10 @@ def __init__(self,
       loss_supervised = tf.reduce_sum(loss_supervised, axis=-1)
       loss_supervised = tf.reduce_mean(loss_supervised)
     else:
-      loss_supervised = self.model.get_loss(predictions=predictions,
-                                            targets=one_hot_labels,
-                                            weight_decay=None)
+      loss_supervised = self.model.get_loss(
+          predictions=predictions,
+          targets=one_hot_labels,
+          weight_decay=None)
 
     # Agreement regularization loss.
     loss_agr = self._get_agreement_reg_loss(data, is_train, features_shape)
@@ -280,8 +281,9 @@ def __init__(self,
       gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
     grads_and_vars = tuple(zip(gradients, variab))
     with tf.control_dependencies(
-        tf.get_collection(tf.GraphKeys.UPDATE_OPS,
-                          scope=tf.get_default_graph().get_name_scope())):
+        tf.get_collection(
+            tf.GraphKeys.UPDATE_OPS,
+            scope=tf.get_default_graph().get_name_scope())):
       train_op = self.optimizer.apply_gradients(
           grads_and_vars, global_step=self.global_step)
 
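
The hunk above is the standard TF1 recipe: clip gradients by global norm, then
apply them under a control dependency on the scope's UPDATE_OPS so that pending
ops such as batch-norm moving-average updates run with every training step. A
self-contained sketch of the same pattern (the toy loss and learning rate are
assumptions):

    import tensorflow as tf

    x = tf.get_variable('x', shape=[], initializer=tf.ones_initializer())
    loss = tf.square(x - 3.0)  # toy quadratic loss
    optimizer = tf.train.GradientDescentOptimizer(0.1)
    global_step = tf.train.get_or_create_global_step()

    gradients, variables = zip(*optimizer.compute_gradients(loss))
    gradients, _ = tf.clip_by_global_norm(gradients, 5.0)
    # Run pending UPDATE_OPS (e.g., batch-norm statistics) before stepping.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
      train_op = optimizer.apply_gradients(
          list(zip(gradients, variables)), global_step=global_step)
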
@@ -332,7 +334,7 @@ def _create_weight_decay_var(self, weight_decay_initial,
     if weight_decay_schedule is None:
       if weight_decay_initial is not None:
         weight_decay_var = tf.constant(
-          weight_decay_initial, dtype=tf.float32, name='weight_decay')
+            weight_decay_initial, dtype=tf.float32, name='weight_decay')
       else:
         weight_decay_var = None
     elif weight_decay_schedule == 'linear':
@@ -406,32 +408,28 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
 
     with tf.variable_scope('predictions', reuse=True):
       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_ll_right, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_ll_right, is_train=is_train, update_batch_stats=False)
       predictions_ll_right, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_ll_right = self.model.normalize_predictions(
           predictions_ll_right)
 
       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_lu_right, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_lu_right, is_train=is_train, update_batch_stats=False)
       predictions_lu_right, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_lu_right = self.model.normalize_predictions(
           predictions_lu_right)
 
       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_uu_left, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_uu_left, is_train=is_train, update_batch_stats=False)
       predictions_uu_left, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_uu_left = self.model.normalize_predictions(
           predictions_uu_left)
 
       encoding, _, _ = self.model.get_encoding_and_params(
-          inputs=features_uu_right, is_train=is_train,
-          update_batch_stats=False)
+          inputs=features_uu_right, is_train=is_train, update_batch_stats=False)
       predictions_uu_right, _, _ = self.model.get_predictions_and_params(
           encoding=encoding, is_train=is_train)
       predictions_uu_right = self.model.normalize_predictions(
@@ -442,8 +440,8 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
     # Stop gradients need to be added here. Also, the case where there are no
     # more uu or lu edges at the end of training (so the shapes don't match)
     # needs fixing.
-    left = tf.concat(
-        (labels_ll_left, labels_lu_left, predictions_uu_left), axis=0)
+    left = tf.concat((labels_ll_left, labels_lu_left, predictions_uu_left),
+                     axis=0)
     right = tf.concat(
         (predictions_ll_right, predictions_lu_right, predictions_uu_right),
         axis=0)
@@ -455,12 +453,16 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
     agreement_ll = tf.cast(
         tf.equal(labels_ll_left_idx, labels_ll_right_idx), dtype=tf.float32)
     _, agreement_lu, _, _ = self.trainer_agr.create_agreement_prediction(
-        src_features=features_lu_left, tgt_features=features_lu_right,
-        is_train=is_train, src_indices=indices_lu_left,
+        src_features=features_lu_left,
+        tgt_features=features_lu_right,
+        is_train=is_train,
+        src_indices=indices_lu_left,
         tgt_indices=indices_lu_right)
     _, agreement_uu, _, _ = self.trainer_agr.create_agreement_prediction(
-        src_features=features_uu_left, tgt_features=features_uu_right,
-        is_train=is_train, src_indices=indices_uu_left,
+        src_features=features_uu_left,
+        tgt_features=features_uu_right,
+        is_train=is_train,
+        src_indices=indices_uu_left,
         tgt_indices=indices_uu_right)
     agreement = tf.concat((agreement_ll, agreement_lu, agreement_uu), axis=0)
     if self.penalize_neg_agr:
@@ -476,10 +478,10 @@ def _get_agreement_reg_loss(self, data, is_train, features_shape):
     num_ll = tf.shape(predictions_ll_right)[0]
     num_lu = tf.shape(predictions_lu_right)[0]
     num_uu = tf.shape(predictions_uu_left)[0]
-    weights = tf.concat((self.reg_weight_ll * tf.ones(num_ll,),
-                         self.reg_weight_lu * tf.ones(num_lu,),
-                         self.reg_weight_uu * tf.ones(num_uu,)),
-                        axis=0)
+    weights = tf.concat(
+        (self.reg_weight_ll * tf.ones(num_ll,), self.reg_weight_lu *
+         tf.ones(num_lu,), self.reg_weight_uu * tf.ones(num_uu,)),
+        axis=0)
 
     # Scale each distance by its agreement weight and regularization weight.
     loss = tf.reduce_mean(dists * weights * agreement)
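
Numerically, the regularizer computed above is a weighted mean: each pairwise
prediction distance is scaled by the agreement score for that pair and by a
per-pair-type weight (LL, LU or UU). A NumPy sketch with made-up values:

    import numpy as np

    # Prediction distances for 2 LL, 2 LU and 2 UU pairs (made-up values).
    dists = np.array([0.1, 0.8, 0.3, 0.5, 0.9, 0.2])
    # Per-pair-type regularization weights, concatenated as in the code above.
    weights = np.concatenate([1.0 * np.ones(2),   # reg_weight_ll
                              0.5 * np.ones(2),   # reg_weight_lu
                              0.5 * np.ones(2)])  # reg_weight_uu
    # Agreement scores in [0, 1]; high means the pair should share a label.
    agreement = np.array([1.0, 1.0, 0.7, 0.2, 0.9, 0.1])

    loss = np.mean(dists * weights * agreement)  # scalar regularization loss
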
@@ -511,8 +513,9 @@ def _construct_feed_dict(self,
     input_indices = next(data_iterator)
     # Select the labels. Use the true, correct labels at test time, and the
     # self-labeled ones at train time.
-    labels = (self.data.get_original_labels(input_indices) if split == 'test'
-              else self.data.get_labels(input_indices))
+    labels = (
+        self.data.get_original_labels(input_indices)
+        if split == 'test' else self.data.get_labels(input_indices))
     feed_dict = {
         self.input_features: self.data.get_features(input_indices),
         self.input_labels: labels,
@@ -586,8 +589,8 @@ def _select_from_pool(indices):
     while True:
       indices_src, features_src, labels_src = _select_from_pool(src_indices)
       indices_tgt, features_tgt, labels_tgt = _select_from_pool(tgt_indices)
-      yield (indices_src, indices_tgt, features_src, features_tgt,
-             labels_src, labels_tgt)
+      yield (indices_src, indices_tgt, features_src, features_tgt, labels_src,
+             labels_tgt)
 
   def edge_iterator(self, data, batch_size, labeling):
     """An iterator over graph edges.
@@ -679,6 +682,7 @@ def train(self, data, session=None, **kwargs):
       data: A CotrainDataset object.
      session: A TensorFlow session or None.
       **kwargs: Other keyword arguments.
+
     Returns:
       best_test_acc: A float representing the test accuracy at the iteration
         where the validation accuracy is maximum.
@@ -742,11 +746,11 @@ def train(self, data, session=None, **kwargs):
     checkpoint_saved = False
     while not has_converged:
       feed_dict = self._construct_feed_dict(
-        data_iterator=data_iterator_train,
-        split='train',
-        pair_ll_iterator=pair_ll_iterator,
-        pair_lu_iterator=pair_lu_iterator,
-        pair_uu_iterator=pair_uu_iterator)
+          data_iterator=data_iterator_train,
+          split='train',
+          pair_ll_iterator=pair_ll_iterator,
+          pair_lu_iterator=pair_lu_iterator,
+          pair_uu_iterator=pair_uu_iterator)
       if self.enable_summaries and step % self.summary_step == 0:
         loss_val, summary, iter_cls_total, _ = session.run(
             [self.loss_op, self.summary_op, self.iter_cls_total, self.train_op],
@@ -813,8 +817,10 @@ def predict(self, session, indices, is_train):
       input_features = self.data.get_features(batch_indices)
       batch_predictions = session.run(
          self.normalized_predictions,
-          feed_dict={self.input_features: input_features,
-                     self.is_train: is_train})
+          feed_dict={
+              self.input_features: input_features,
+              self.is_train: is_train
+          })
       predictions.append(batch_predictions)
       idx_start = idx_end
     if not predictions: