Skip to content

Commit feaffd5

Browse files
arjungtensorflow-copybara
authored andcommitted
Update graph and adversarial regularization metrics to include scaled values for consistency with their respective loss term contributions.
PiperOrigin-RevId: 322626309
1 parent 6396bc5 commit feaffd5

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,8 +681,10 @@ def call(self, inputs, **kwargs):
681681
labeled_loss=labeled_loss,
682682
gradient_tape=tape,
683683
model_kwargs=kwargs)
684-
self.add_loss(self.adv_config.multiplier * adv_loss)
685-
self.add_metric(adv_loss, name='adversarial_loss', aggregation='mean')
684+
scaled_adv_loss = self.adv_config.multiplier * adv_loss
685+
self.add_loss(scaled_adv_loss)
686+
self.add_metric(
687+
scaled_adv_loss, name='scaled_adversarial_loss', aggregation='mean')
686688
return outputs
687689

688690
def save(self, *args, **kwargs):

neural_structured_learning/keras/graph_regularization.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,10 @@ def call(self, inputs, training=False, **kwargs):
132132
else:
133133
graph_loss = tf.constant(0, dtype=tf.float32)
134134

135+
scaled_graph_loss = self.graph_reg_config.multiplier * graph_loss
135136
# Note that add_metric() cannot be invoked in a control flow branch.
136-
self.add_metric(graph_loss, name='graph_loss', aggregation='mean')
137-
self.add_loss(self.graph_reg_config.multiplier * graph_loss)
137+
self.add_metric(
138+
scaled_graph_loss, name='scaled_graph_loss', aggregation='mean')
139+
self.add_loss(scaled_graph_loss)
138140

139141
return base_output

0 commit comments

Comments
 (0)