@@ -127,7 +127,7 @@ def __init__(self,
127127 weight_decay_schedule = None ,
128128 num_pairs_eval_random = 1000 ,
129129 agree_by_default = False ,
130- percent_val = 0.2 ,
130+ percent_val = 0.1 ,
131131 max_num_samples_val = 10000 ,
132132 seed = None ,
133133 lr_decay_steps = None ,
@@ -485,103 +485,6 @@ def _eval_validation(self, data, labeled_nodes_val, ratio_pos_to_neg,
485485 cummulative_val_acc /= samples_seen
486486 return cummulative_val_acc
487487
def _train_iterator(self, labeled_samples, neighbors_val, data,
                    ratio_pos_to_neg=None):
  """Yields an endless stream of training batches for the agreement model.

  Each batch contains pairs of labeled sample indices together with a binary
  agreement target (whether the two samples carry the same label). Pairs that
  belong to the validation set are never emitted, and a sample is never paired
  with itself.

  Arguments:
    labeled_samples: An array of integers representing the indices of the
      labeled nodes. Must contain at least two distinct indices.
    neighbors_val: An iterable of index pairs (e.g. an array of shape
      (num_samples, 2)) used for validation. The order of the two indices
      within a pair is irrelevant.
    data: An object exposing `get_labels(index)`, used to look up the label of
      each sampled node.
    ratio_pos_to_neg: A float representing the ratio of positive to negative
      samples in the training set. If provided, rejection sampling is applied
      to the over-represented class so that the emitted pairs are
      approximately balanced. If None, pairs are sampled uniformly.

  Yields:
    neighbors_batch: An int32 array of shape (batch_size, 2), where each row
      is a pair of sample indices used for training.
    agreement_batch: A float32 array of shape (batch_size,), where each entry
      is 1.0 if the labels of the corresponding pair agree and 0.0 otherwise.
    Fresh arrays are allocated for every batch, so callers may safely keep
    references to previously yielded batches.

  Raises:
    ValueError: If `labeled_samples` has fewer than two distinct indices, in
      which case no valid pair exists. Because this is a generator, the error
      surfaces when the first batch is requested.
  """
  if np.unique(labeled_samples).size < 2:
    raise ValueError(
        'At least two distinct labeled samples are required to form pairs.')

  # Store validation pairs in canonical (small, large) order so that
  # membership tests do not depend on the order within a pair.
  held_out_pairs = {
      (pair[0], pair[1]) if pair[0] < pair[1] else (pair[1], pair[0])
      for pair in neighbors_val
  }

  while True:
    # Allocate new buffers per batch; reusing a single buffer would overwrite
    # batches that the consumer still holds a reference to.
    neighbors_batch = np.empty(shape=(self.batch_size, 2), dtype=np.int32)
    agreement_batch = np.empty(shape=(self.batch_size,), dtype=np.float32)
    num_added = 0
    while num_added < self.batch_size:
      first, second = self.rng.choice(labeled_samples, 2)
      # `choice` samples with replacement; a node paired with itself trivially
      # agrees and would bias the positive class, so discard it.
      if first == second:
        continue
      ordered_pair = (first, second) if first < second else (second, first)
      if ordered_pair in held_out_pairs:
        continue
      agreement = data.get_labels(first) == data.get_labels(second)
      if ratio_pos_to_neg is not None:
        # Rejection-sample the over-represented class to keep positives and
        # negatives balanced.
        if ratio_pos_to_neg < 1 and not agreement:
          # Negatives dominate: keep a negative with prob. ratio_pos_to_neg.
          if self.rng.rand(1)[0] > ratio_pos_to_neg:
            continue
        elif ratio_pos_to_neg > 1 and agreement:
          # Positives dominate: keep a positive with prob. 1/ratio_pos_to_neg.
          if self.rng.rand(1)[0] > 1.0 / ratio_pos_to_neg:
            continue
      neighbors_batch[num_added][0] = first
      neighbors_batch[num_added][1] = second
      agreement_batch[num_added] = agreement
      num_added += 1
    yield neighbors_batch, agreement_batch
584-
585488 def _select_val_set (self , labeled_samples , num_samples , data ,
586489 ratio_pos_to_neg = None ):
587490 """Select a validation set for the agreement model.
0 commit comments