|
32 | 32 |
|
33 | 33 | from gam.data.dataset import CotrainDataset |
34 | 34 | from gam.trainer.trainer_agreement import TrainerAgreement |
| 35 | +from gam.trainer.trainer_agreement import TrainerAgreementAlwaysAgree |
35 | 36 | from gam.trainer.trainer_agreement import TrainerPerfectAgreement |
36 | 37 | from gam.trainer.trainer_base import Trainer |
37 | 38 | from gam.trainer.trainer_classification import TrainerClassification |
@@ -443,38 +444,41 @@ def train(self, data, **kwargs): |
443 | 444 | trainer_agr = TrainerPerfectAgreement(data=data) |
444 | 445 | else: |
445 | 446 | with tf.variable_scope('AgreementModel'): |
446 | | - trainer_agr = TrainerAgreement( |
447 | | - model=self.model_agr, |
448 | | - data=data, |
449 | | - optimizer=self.optimizer, |
450 | | - gradient_clip=self.gradient_clip, |
451 | | - min_num_iter=self.min_num_iter_agr, |
452 | | - max_num_iter=self.max_num_iter_agr, |
453 | | - num_iter_after_best_val=self.num_iter_after_best_val_agr, |
454 | | - max_num_iter_cotrain=self.max_num_iter_cotrain, |
455 | | - num_warm_up_iter=self.num_warm_up_iter_agr, |
456 | | - warm_start=self.warm_start_agr, |
457 | | - batch_size=self.batch_size_agr, |
458 | | - enable_summaries=self.enable_summaries_per_model, |
459 | | - summary_step=self.summary_step_agr, |
460 | | - summary_dir=self.summary_dir, |
461 | | - logging_step=self.logging_step_agr, |
462 | | - eval_step=self.eval_step_agr, |
463 | | - abs_loss_chg_tol=self.abs_loss_chg_tol, |
464 | | - rel_loss_chg_tol=self.rel_loss_chg_tol, |
465 | | - loss_chg_iter_below_tol=self.loss_chg_iter_below_tol, |
466 | | - checkpoints_dir=self.checkpoints_dir, |
467 | | - weight_decay=self.weight_decay_agr, |
468 | | - weight_decay_schedule=self.weight_decay_schedule_agr, |
469 | | - agree_by_default=False, |
470 | | - percent_val=self.ratio_valid_agr, |
471 | | - max_num_samples_val=self.max_samples_valid_agr, |
472 | | - seed=self.seed, |
473 | | - lr_decay_rate=self.lr_decay_rate_agr, |
474 | | - lr_decay_steps=self.lr_decay_steps_agr, |
475 | | - lr_initial=self.learning_rate_agr, |
476 | | - use_graph=self.use_graph, |
477 | | - add_negative_edges=self.add_negative_edges_agr) |
| 447 | + if self.always_agree: |
| 448 | + trainer_agr = TrainerAgreementAlwaysAgree(data=data) |
| 449 | + else: |
| 450 | + trainer_agr = TrainerAgreement( |
| 451 | + model=self.model_agr, |
| 452 | + data=data, |
| 453 | + optimizer=self.optimizer, |
| 454 | + gradient_clip=self.gradient_clip, |
| 455 | + min_num_iter=self.min_num_iter_agr, |
| 456 | + max_num_iter=self.max_num_iter_agr, |
| 457 | + num_iter_after_best_val=self.num_iter_after_best_val_agr, |
| 458 | + max_num_iter_cotrain=self.max_num_iter_cotrain, |
| 459 | + num_warm_up_iter=self.num_warm_up_iter_agr, |
| 460 | + warm_start=self.warm_start_agr, |
| 461 | + batch_size=self.batch_size_agr, |
| 462 | + enable_summaries=self.enable_summaries_per_model, |
| 463 | + summary_step=self.summary_step_agr, |
| 464 | + summary_dir=self.summary_dir, |
| 465 | + logging_step=self.logging_step_agr, |
| 466 | + eval_step=self.eval_step_agr, |
| 467 | + abs_loss_chg_tol=self.abs_loss_chg_tol, |
| 468 | + rel_loss_chg_tol=self.rel_loss_chg_tol, |
| 469 | + loss_chg_iter_below_tol=self.loss_chg_iter_below_tol, |
| 470 | + checkpoints_dir=self.checkpoints_dir, |
| 471 | + weight_decay=self.weight_decay_agr, |
| 472 | + weight_decay_schedule=self.weight_decay_schedule_agr, |
| 473 | + agree_by_default=False, |
| 474 | + percent_val=self.ratio_valid_agr, |
| 475 | + max_num_samples_val=self.max_samples_valid_agr, |
| 476 | + seed=self.seed, |
| 477 | + lr_decay_rate=self.lr_decay_rate_agr, |
| 478 | + lr_decay_steps=self.lr_decay_steps_agr, |
| 479 | + lr_initial=self.learning_rate_agr, |
| 480 | + use_graph=self.use_graph, |
| 481 | + add_negative_edges=self.add_negative_edges_agr) |
478 | 482 |
|
479 | 483 | if self.use_perfect_cls: |
480 | 484 | # A perfect classification model used for debugging purposes. |
|
0 commit comments