Commit 1fa5376

Added NGM agreement.
1 parent dbfdcf5 commit 1fa5376

2 files changed (+104, -32 lines)

neural_structured_learning/research/gam/trainer/trainer_agreement.py

Lines changed: 68 additions & 0 deletions
@@ -1119,3 +1119,71 @@ def predict_label_by_agreement(self, indices, num_neighbors=100,
     acc /= len(indices)
     logging.info('Majority vote accuracy: %.2f.', acc)
     return acc
+
+
+class TrainerAgreementAlwaysAgree(object):
+  """Trainer for an agreement model that always predicts that samples agree.
+
+  The goal of this class is to simulate the behavior of the Neural Graph
+  Machines model, which assumes that two nodes connected by an edge in the
+  graph always have the same label.
+  """
+
+  def __init__(self, data, **unused_kwargs):
+    self.data = data
+    self.vars_to_save = []
+
+  def train(self, *unused_args, **unused_kwargs):
+    logging.info('Using NGM, agreement always returns 1. No need to train...')
+
+  def predict(self, unused_session, unused_src_features, unused_tgt_features,
+              src_indices, tgt_indices):
+    """Predicts agreement for the provided pairs of samples.
+
+    This function has several unused arguments in order to conform to the
+    interface of the TrainerAgreement class.
+
+    Arguments:
+      unused_session: A TensorFlow session in which to run the model.
+      unused_src_features: An array of shape (num_samples, num_features)
+        containing the features of the first element of each pair.
+      unused_tgt_features: An array of shape (num_samples, num_features)
+        containing the features of the second element of each pair.
+      src_indices: An array of integers containing the index in self.data of
+        each sample in src_features.
+      tgt_indices: An array of integers containing the index in self.data of
+        each sample in tgt_features.
+
+    Returns:
+      An array containing the predicted agreement value for each pair of
+      provided samples.
+    """
+    num_samples = src_indices.shape[0]
+    return np.ones((num_samples,), dtype=np.float32)
+
+  def create_agreement_prediction(self, src_indices, *unused_args,
+                                  **unused_kwargs):
+    """Creates the agreement prediction TensorFlow subgraph.
+
+    This function is the equivalent of `create_agreement_prediction` in
+    TrainerAgreement, but here we always predict 1.0.
+
+    Arguments:
+      src_indices: A Tensor or placeholder of shape (batch_size,) containing
+        the indices of the samples that are the sources of the edges.
+      unused_args: Other unused arguments, allowed in order to keep a common
+        interface with TrainerAgreement.
+      unused_kwargs: Other unused keyword arguments, allowed in order to keep
+        a common interface with TrainerAgreement.
+    Returns:
+      predictions: None, because this model does not compute logits, but we
+        still return a value in order to keep the same outputs as
+        TrainerAgreement.
+      normalized_predictions: A Tensor of shape (batch_size,) containing the
+        agreement prediction probabilities (always 1.0 for this model).
+      variables: An empty dictionary, because this model has no trainable
+        variables.
+      reg_params: An empty dictionary of variables used in the regularization
+        weight decay term, because this model has no regularization variables.
+    """
+    return None, tf.ones((tf.shape(src_indices)[0],), tf.float32), {}, {}
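
The following is a minimal usage sketch, not part of the commit, illustrating the behavior of the new class: predict() ignores the session and feature arguments and returns 1.0 for every pair, which is what lets it emulate the Neural Graph Machines assumption. The data=None argument is a placeholder for illustration only; in the co-training loop the trainer receives a CotrainDataset.

import numpy as np

from gam.trainer.trainer_agreement import TrainerAgreementAlwaysAgree

# Any dataset object works here; the class stores it but never uses it when
# predicting. None is passed purely for illustration.
trainer = TrainerAgreementAlwaysAgree(data=None)

src_indices = np.array([0, 1, 2])
tgt_indices = np.array([7, 8, 9])

# Session and features are ignored, so None is passed for them as well.
agreement = trainer.predict(None, None, None, src_indices, tgt_indices)
print(agreement)  # [1. 1. 1.] -- every pair is predicted to agree.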

neural_structured_learning/research/gam/trainer/trainer_cotrain.py

Lines changed: 36 additions & 32 deletions
@@ -32,6 +32,7 @@
 
 from gam.data.dataset import CotrainDataset
 from gam.trainer.trainer_agreement import TrainerAgreement
+from gam.trainer.trainer_agreement import TrainerAgreementAlwaysAgree
 from gam.trainer.trainer_agreement import TrainerPerfectAgreement
 from gam.trainer.trainer_base import Trainer
 from gam.trainer.trainer_classification import TrainerClassification
@@ -443,38 +444,41 @@ def train(self, data, **kwargs):
       trainer_agr = TrainerPerfectAgreement(data=data)
     else:
       with tf.variable_scope('AgreementModel'):
-        trainer_agr = TrainerAgreement(
-            model=self.model_agr,
-            data=data,
-            optimizer=self.optimizer,
-            gradient_clip=self.gradient_clip,
-            min_num_iter=self.min_num_iter_agr,
-            max_num_iter=self.max_num_iter_agr,
-            num_iter_after_best_val=self.num_iter_after_best_val_agr,
-            max_num_iter_cotrain=self.max_num_iter_cotrain,
-            num_warm_up_iter=self.num_warm_up_iter_agr,
-            warm_start=self.warm_start_agr,
-            batch_size=self.batch_size_agr,
-            enable_summaries=self.enable_summaries_per_model,
-            summary_step=self.summary_step_agr,
-            summary_dir=self.summary_dir,
-            logging_step=self.logging_step_agr,
-            eval_step=self.eval_step_agr,
-            abs_loss_chg_tol=self.abs_loss_chg_tol,
-            rel_loss_chg_tol=self.rel_loss_chg_tol,
-            loss_chg_iter_below_tol=self.loss_chg_iter_below_tol,
-            checkpoints_dir=self.checkpoints_dir,
-            weight_decay=self.weight_decay_agr,
-            weight_decay_schedule=self.weight_decay_schedule_agr,
-            agree_by_default=False,
-            percent_val=self.ratio_valid_agr,
-            max_num_samples_val=self.max_samples_valid_agr,
-            seed=self.seed,
-            lr_decay_rate=self.lr_decay_rate_agr,
-            lr_decay_steps=self.lr_decay_steps_agr,
-            lr_initial=self.learning_rate_agr,
-            use_graph=self.use_graph,
-            add_negative_edges=self.add_negative_edges_agr)
+        if self.always_agree:
+          trainer_agr = TrainerAgreementAlwaysAgree(data=data)
+        else:
+          trainer_agr = TrainerAgreement(
+              model=self.model_agr,
+              data=data,
+              optimizer=self.optimizer,
+              gradient_clip=self.gradient_clip,
+              min_num_iter=self.min_num_iter_agr,
+              max_num_iter=self.max_num_iter_agr,
+              num_iter_after_best_val=self.num_iter_after_best_val_agr,
+              max_num_iter_cotrain=self.max_num_iter_cotrain,
+              num_warm_up_iter=self.num_warm_up_iter_agr,
+              warm_start=self.warm_start_agr,
+              batch_size=self.batch_size_agr,
+              enable_summaries=self.enable_summaries_per_model,
+              summary_step=self.summary_step_agr,
+              summary_dir=self.summary_dir,
+              logging_step=self.logging_step_agr,
+              eval_step=self.eval_step_agr,
+              abs_loss_chg_tol=self.abs_loss_chg_tol,
+              rel_loss_chg_tol=self.rel_loss_chg_tol,
+              loss_chg_iter_below_tol=self.loss_chg_iter_below_tol,
+              checkpoints_dir=self.checkpoints_dir,
+              weight_decay=self.weight_decay_agr,
+              weight_decay_schedule=self.weight_decay_schedule_agr,
+              agree_by_default=False,
+              percent_val=self.ratio_valid_agr,
+              max_num_samples_val=self.max_samples_valid_agr,
+              seed=self.seed,
+              lr_decay_rate=self.lr_decay_rate_agr,
+              lr_decay_steps=self.lr_decay_steps_agr,
+              lr_initial=self.learning_rate_agr,
+              use_graph=self.use_graph,
+              add_negative_edges=self.add_negative_edges_agr)
 
     if self.use_perfect_cls:
       # A perfect classification model used for debugging purposes.
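
For orientation, here is a condensed sketch of the new selection logic (condensed from the hunk above; the full hyperparameter list is elided, and it is assumed that self.always_agree is a boolean attribute already available on the co-training trainer):

with tf.variable_scope('AgreementModel'):
  if self.always_agree:
    # NGM-style behavior: every connected pair is assumed to agree.
    trainer_agr = TrainerAgreementAlwaysAgree(data=data)
  else:
    # Original behavior: train a full agreement model.
    trainer_agr = TrainerAgreement(
        model=self.model_agr,
        data=data,
        # ... remaining hyperparameters unchanged from the previous version ...
        add_negative_edges=self.add_negative_edges_agr)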
