Skip to content

Commit aadfecf

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Add missing build dependencies (which had been indirect dependencies previously).
Access Keras using tf.keras rather than importing keras itself. PiperOrigin-RevId: 274047775
1 parent 3434e37 commit aadfecf

File tree

8 files changed

+27
-19
lines changed

8 files changed

+27
-19
lines changed

neural_structured_learning/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# Description:
1616
# Build rules for TensorFlow Neural Structured Learning.
1717

18+
# Placeholder for internal Python strict compatibility macro.
1819
# Internal annotation for sync
1920

2021
package(default_visibility = ["//visibility:public"])
@@ -28,6 +29,7 @@ py_library(
2829
srcs = ["__init__.py"],
2930
deps = [
3031
":version",
32+
"//neural_structured_learning/configs",
3133
"//neural_structured_learning/estimator",
3234
"//neural_structured_learning/keras",
3335
"//neural_structured_learning/lib",

neural_structured_learning/configs/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# Description:
1616
# Build rules for config libraries in Neural Structured Learning.
1717

18+
# Placeholder for internal Python strict compatibility macro.
19+
1820
package(
1921
default_visibility = ["//visibility:public"],
2022
licenses = ["notice"], # Apache 2.0

neural_structured_learning/estimator/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# Description:
1616
# Build rules for Estimator APIs in Neural Structured Learning.
1717

18+
# Placeholder for internal Python strict compatibility macro.
1819
# Placeholder for internal Python version compatibility macro.
1920

2021
package(

neural_structured_learning/keras/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# Placeholder for internal Python strict compatibility macro.
1516
# Placeholder for internal Python version compatibility macro.
1617

1718
# Description:
@@ -30,6 +31,7 @@ py_library(
3031
deps = [
3132
":adversarial_regularization",
3233
":graph_regularization",
34+
"//neural_structured_learning/keras/layers",
3335
],
3436
)
3537

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import neural_structured_learning.lib as nsl_lib
2727
import six
2828
import tensorflow as tf
29-
import tensorflow.keras as keras
3029

3130

3231
def adversarial_loss(features,
@@ -193,16 +192,15 @@ def __call__(self, *args, **kwargs):
193192
return loss_value
194193

195194
def _is_sparse_categorical_loss(self):
196-
return self.loss_fn == keras.losses.sparse_categorical_crossentropy or (
197-
isinstance(self.loss_fn, keras.losses.SparseCategoricalCrossentropy))
195+
return self.loss_fn == tf.keras.losses.sparse_categorical_crossentropy or (
196+
isinstance(self.loss_fn, tf.keras.losses.SparseCategoricalCrossentropy))
198197

199198
def _is_binary_classification_loss(self):
200-
return self.loss_fn in (keras.losses.binary_crossentropy,
201-
keras.losses.hinge,
202-
keras.losses.squared_hinge) or isinstance(
203-
self.loss_fn,
204-
(keras.losses.BinaryCrossentropy,
205-
keras.losses.Hinge, keras.losses.SquaredHinge))
199+
return self.loss_fn in (
200+
tf.keras.losses.binary_crossentropy,
201+
tf.keras.losses.hinge, tf.keras.losses.squared_hinge) or isinstance(
202+
self.loss_fn, (tf.keras.losses.BinaryCrossentropy,
203+
tf.keras.losses.Hinge, tf.keras.losses.SquaredHinge))
206204

207205
def resolve_metric(self, metric):
208206
"""Resolves potentially ambiguous metric name based on the loss function."""
@@ -237,21 +235,21 @@ def _prepare_loss_fns(loss, output_names):
237235
if name not in loss:
238236
raise ValueError(
239237
'Loss for {} not found in `loss` dictionary.'.format(name))
240-
return [keras.losses.get(loss[name]) for name in output_names]
238+
return [tf.keras.losses.get(loss[name]) for name in output_names]
241239

242240
# loss for single output, or shared loss fn for multiple outputs
243241
if isinstance(loss, six.string_types):
244-
return [keras.losses.get(loss) for _ in output_names]
242+
return [tf.keras.losses.get(loss) for _ in output_names]
245243

246244
# losses for multiple outputs indexed by position
247245
if isinstance(loss, collections.Sequence):
248246
if len(loss) != len(output_names):
249247
raise ValueError('`loss` should have the same number of elements as '
250248
'model output')
251-
return six.moves.map(keras.losses.get, loss)
249+
return six.moves.map(tf.keras.losses.get, loss)
252250

253251
# loss for single output, or shared loss fn for multiple outputs
254-
return [keras.losses.get(loss) for _ in output_names]
252+
return [tf.keras.losses.get(loss) for _ in output_names]
255253

256254

257255
def _prepare_loss_weights(loss_weights, output_names):
@@ -294,7 +292,7 @@ def clone(metric):
294292
# adversarial-regularized models, and also on multiple outputs in one model.
295293
# The cloning logic is the same as the `clone_metric` function in
296294
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/metrics.py
297-
if not isinstance(metric, keras.metrics.Metric):
295+
if not isinstance(metric, tf.keras.metrics.Metric):
298296
return metric
299297
with tf.init_scope():
300298
return metric.__class__.from_config(metric.get_config())
@@ -348,7 +346,7 @@ def _prepare_metric_fns(metrics, output_names, loss_wrappers):
348346
metric_fns = []
349347
for per_output_metrics, loss_wrapper in zip(metrics, loss_wrappers):
350348
metric_fns.append([
351-
keras.metrics.get(loss_wrapper.resolve_metric(metric))
349+
tf.keras.metrics.get(loss_wrapper.resolve_metric(metric))
352350
for metric in to_list(per_output_metrics)
353351
])
354352
return metric_fns
@@ -390,15 +388,15 @@ def _compute_loss_and_metrics(losses,
390388
value = metric_fn(label, output)
391389
# Metric objects always return an aggregated result, and shouldn't be
392390
# aggregated again.
393-
if isinstance(metric_fn, keras.metrics.Metric):
391+
if isinstance(metric_fn, tf.keras.metrics.Metric):
394392
aggregation = None
395393
else:
396394
aggregation = 'mean'
397395
output_metrics.append((value, aggregation, metric_name))
398396
return tf.add_n(total_loss), output_metrics
399397

400398

401-
class AdversarialRegularization(keras.Model):
399+
class AdversarialRegularization(tf.keras.Model):
402400
"""Wrapper thats adds adversarial regularization to a given `tf.keras.Model`.
403401
404402
This model will reuse the layers and variables as the given `base_model`, so
@@ -567,7 +565,7 @@ def _build_labeled_metrics(self, output_names, labeled_losses):
567565
per_output_metrics = []
568566
for metric_fn in metric_fns:
569567
metric_name = self._make_metric_name(metric_fn, label_key)
570-
if isinstance(metric_fn, keras.metrics.Metric):
568+
if isinstance(metric_fn, tf.keras.metrics.Metric):
571569
# Updates the name of the Metric object to make sure it is unique.
572570
metric_fn._name = metric_name # pylint: disable=protected-access
573571
per_output_metrics.append((metric_fn, metric_name))
@@ -682,7 +680,7 @@ def perturb_on_batch(self, x, **config_kwargs):
682680
adv_inputs,
683681
expand_composites=False)
684682
else:
685-
adv_inputs = keras.backend.function([], adv_inputs)([])
683+
adv_inputs = tf.keras.backend.function([], adv_inputs)([])
686684

687685
# Inserts the labels and sample_weights back to the input dictionary, so
688686
# the returned input has the same structure as the original input.

neural_structured_learning/keras/layers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# Placeholder for internal Python strict compatibility macro.
1516
# Placeholder for internal Python version compatibility macro.
1617

1718
# Description:

neural_structured_learning/lib/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# Placeholder for internal Python strict compatibility macro.
1516
# Placeholder for internal Python version compatibility macro.
1617

1718
# Description:

neural_structured_learning/tools/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# Placeholder for internal Python strict compatibility macro.
1516
# Placeholder for internal Python version compatibility macro.
1617

1718
# Description:

0 commit comments

Comments
 (0)