26 | 26 | import neural_structured_learning.lib as nsl_lib |
27 | 27 | import six |
28 | 28 | import tensorflow as tf |
29 | | -import tensorflow.keras as keras |
30 | 29 |
31 | 30 |
32 | 31 | def adversarial_loss(features, |
@@ -193,16 +192,15 @@ def __call__(self, *args, **kwargs): |
193 | 192 | return loss_value |
194 | 193 |
195 | 194 | def _is_sparse_categorical_loss(self): |
196 | | - return self.loss_fn == keras.losses.sparse_categorical_crossentropy or ( |
197 | | - isinstance(self.loss_fn, keras.losses.SparseCategoricalCrossentropy)) |
| 195 | + return self.loss_fn == tf.keras.losses.sparse_categorical_crossentropy or ( |
| 196 | + isinstance(self.loss_fn, tf.keras.losses.SparseCategoricalCrossentropy)) |
198 | 197 |
199 | 198 | def _is_binary_classification_loss(self): |
200 | | - return self.loss_fn in (keras.losses.binary_crossentropy, |
201 | | - keras.losses.hinge, |
202 | | - keras.losses.squared_hinge) or isinstance( |
203 | | - self.loss_fn, |
204 | | - (keras.losses.BinaryCrossentropy, |
205 | | - keras.losses.Hinge, keras.losses.SquaredHinge)) |
| 199 | + return self.loss_fn in ( |
| 200 | + tf.keras.losses.binary_crossentropy, |
| 201 | + tf.keras.losses.hinge, tf.keras.losses.squared_hinge) or isinstance( |
| 202 | + self.loss_fn, (tf.keras.losses.BinaryCrossentropy, |
| 203 | + tf.keras.losses.Hinge, tf.keras.losses.SquaredHinge)) |
206 | 204 |
207 | 205 | def resolve_metric(self, metric): |
208 | 206 | """Resolves potentially ambiguous metric name based on the loss function.""" |
@@ -237,21 +235,21 @@ def _prepare_loss_fns(loss, output_names): |
237 | 235 | if name not in loss: |
238 | 236 | raise ValueError( |
239 | 237 | 'Loss for {} not found in `loss` dictionary.'.format(name)) |
240 | | - return [keras.losses.get(loss[name]) for name in output_names] |
| 238 | + return [tf.keras.losses.get(loss[name]) for name in output_names] |
241 | 239 |
242 | 240 | # loss for single output, or shared loss fn for multiple outputs |
243 | 241 | if isinstance(loss, six.string_types): |
244 | | - return [keras.losses.get(loss) for _ in output_names] |
| 242 | + return [tf.keras.losses.get(loss) for _ in output_names] |
245 | 243 |
246 | 244 | # losses for multiple outputs indexed by position |
247 | 245 | if isinstance(loss, collections.Sequence): |
248 | 246 | if len(loss) != len(output_names): |
249 | 247 | raise ValueError('`loss` should have the same number of elements as ' |
250 | 248 | 'model output') |
251 | | - return six.moves.map(keras.losses.get, loss) |
| 249 | + return six.moves.map(tf.keras.losses.get, loss) |
252 | 250 |
253 | 251 | # loss for single output, or shared loss fn for multiple outputs |
254 | | - return [keras.losses.get(loss) for _ in output_names] |
| 252 | + return [tf.keras.losses.get(loss) for _ in output_names] |
255 | 253 |
256 | 254 |
257 | 255 | def _prepare_loss_weights(loss_weights, output_names): |
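For context, `_prepare_loss_fns` above accepts the same `loss` shapes as `tf.keras.Model.compile`. A hedged sketch of the three forms it resolves (the output names are made up):

    import tensorflow as tf

    output_names = ['digit', 'parity']  # hypothetical model outputs

    # 1) A dict keyed by output name.
    spec = {'digit': 'sparse_categorical_crossentropy',
            'parity': 'binary_crossentropy'}
    fns = [tf.keras.losses.get(spec[name]) for name in output_names]
    # 2) A single name shared by every output.
    fns = [tf.keras.losses.get('mse') for _ in output_names]
    # 3) A list aligned with the outputs by position.
    fns = list(map(tf.keras.losses.get, ['mse', 'binary_crossentropy']))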
@@ -294,7 +292,7 @@ def clone(metric): |
294 | 292 | # adversarial-regularized models, and also on multiple outputs in one model. |
295 | 293 | # The cloning logic is the same as the `clone_metric` function in |
296 | 294 | # https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py |
297 | | - if not isinstance(metric, keras.metrics.Metric): |
| 295 | + if not isinstance(metric, tf.keras.metrics.Metric): |
298 | 296 | return metric |
299 | 297 | with tf.init_scope(): |
300 | 298 | return metric.__class__.from_config(metric.get_config()) |
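The clone matters because a `tf.keras.metrics.Metric` instance accumulates state across every call; round-tripping through `get_config`/`from_config` gives each consumer a fresh accumulator. A small sketch:

    import tensorflow as tf

    acc = tf.keras.metrics.SparseCategoricalAccuracy()
    # A fresh, independent metric: base and adversarial evaluations can now
    # update their own totals without polluting each other.
    acc_clone = acc.__class__.from_config(acc.get_config())
    assert acc is not acc_clone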
@@ -348,7 +346,7 @@ def _prepare_metric_fns(metrics, output_names, loss_wrappers): |
348 | 346 | metric_fns = [] |
349 | 347 | for per_output_metrics, loss_wrapper in zip(metrics, loss_wrappers): |
350 | 348 | metric_fns.append([ |
351 | | - keras.metrics.get(loss_wrapper.resolve_metric(metric)) |
| 349 | + tf.keras.metrics.get(loss_wrapper.resolve_metric(metric)) |
352 | 350 | for metric in to_list(per_output_metrics) |
353 | 351 | ]) |
354 | 352 | return metric_fns |
@@ -390,15 +388,15 @@ def _compute_loss_and_metrics(losses, |
390 | 388 | value = metric_fn(label, output) |
391 | 389 | # Metric objects always return an aggregated result, and shouldn't be |
392 | 390 | # aggregated again. |
393 | | - if isinstance(metric_fn, keras.metrics.Metric): |
| 391 | + if isinstance(metric_fn, tf.keras.metrics.Metric): |
394 | 392 | aggregation = None |
395 | 393 | else: |
396 | 394 | aggregation = 'mean' |
397 | 395 | output_metrics.append((value, aggregation, metric_name)) |
398 | 396 | return tf.add_n(total_loss), output_metrics |
399 | 397 |
400 | 398 |
401 | | -class AdversarialRegularization(keras.Model): |
| 399 | +class AdversarialRegularization(tf.keras.Model): |
402 | 400 | """Wrapper thats adds adversarial regularization to a given `tf.keras.Model`. |
403 | 401 |
404 | 402 | This model will reuse the layers and variables of the given `base_model`, so |
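Since this class is the file's public entry point, a minimal usage sketch may help situate the diff (the layer sizes, the 'feature'/'label' keys, and the hyperparameters are illustrative only):

    import neural_structured_learning as nsl
    import numpy as np
    import tensorflow as tf

    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(16,)),
        tf.keras.layers.Dense(3, activation='softmax'),
    ])
    adv_model = nsl.keras.AdversarialRegularization(
        base_model,
        label_keys=['label'],
        adv_config=nsl.configs.make_adv_reg_config(
            multiplier=0.2, adv_step_size=0.05))
    adv_model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    # Features and labels travel together in a single dictionary.
    x = np.random.random((64, 16)).astype('float32')
    y = np.random.randint(0, 3, size=(64,))
    adv_model.fit({'feature': x, 'label': y}, batch_size=32, epochs=1)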
@@ -567,7 +565,7 @@ def _build_labeled_metrics(self, output_names, labeled_losses): |
567 | 565 | per_output_metrics = [] |
568 | 566 | for metric_fn in metric_fns: |
569 | 567 | metric_name = self._make_metric_name(metric_fn, label_key) |
570 | | - if isinstance(metric_fn, keras.metrics.Metric): |
| 568 | + if isinstance(metric_fn, tf.keras.metrics.Metric): |
571 | 569 | # Updates the name of the Metric object to make sure it is unique. |
572 | 570 | metric_fn._name = metric_name # pylint: disable=protected-access |
573 | 571 | per_output_metrics.append((metric_fn, metric_name)) |
@@ -682,7 +680,7 @@ def perturb_on_batch(self, x, **config_kwargs): |
682 | 680 | adv_inputs, |
683 | 681 | expand_composites=False) |
684 | 682 | else: |
685 | | - adv_inputs = keras.backend.function([], adv_inputs)([]) |
| 683 | + adv_inputs = tf.keras.backend.function([], adv_inputs)([]) |
686 | 684 |
687 | 685 | # Inserts the labels and sample_weights back to the input dictionary, so |
688 | 686 | # the returned input has the same structure as the original input. |
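`perturb_on_batch` is the user-facing way to exercise this path; a short follow-on sketch, reusing the hypothetical `adv_model`, `x`, and `y` from the earlier example:

    # Returns a batch whose features carry adversarial perturbations; labels
    # and sample weights are passed through unchanged, mirroring the input.
    perturbed = adv_model.perturb_on_batch({'feature': x, 'label': y})
    assert set(perturbed.keys()) == {'feature', 'label'}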