
Commit a9c218f

arjung authored and tensorflow-copybara committed
Fix documentation in distances.py and multimodal_lib.py to conform to the TensorFlow documentation style.
PiperOrigin-RevId: 268045930
1 parent 6cb6b22 commit a9c218f

File tree

2 files changed: +110 -105 lines changed

neural_structured_learning/lib/distances.py

Lines changed: 84 additions & 78 deletions
@@ -23,7 +23,7 @@
 
 
 def _assert_multinomial_distribution(input_tensor, axis):
-  """Assert input has valid multinomial distribution along specified axis."""
+  """Assert input has valid multinomial distribution along `axis`."""
   sum_of_multinomial_distribution = tf.reduce_sum(
       input_tensor=input_tensor, axis=axis)
   return [
@@ -36,7 +36,7 @@ def _assert_multinomial_distribution(input_tensor, axis):
 
 
 def _assert_valid_axis(ndims, axis):
-  """Assert the condition `-ndims < axis <= ndims` if axis is not None."""
+  """Assert the condition `-ndims < axis <= ndims` if `axis` is not `None`."""
   if axis and (axis < -ndims or axis >= ndims):
     raise ValueError('axis = %d not in [%d, %d)' % (axis, -ndims, ndims))
 
@@ -58,39 +58,41 @@ def kl_divergence(
   """Adds a KL-divergence to the training procedure.
 
   For brevity, let `P = labels` and `Q = predictions`. The
-  Kullback-Leibler divergence`KL(P||Q)` is
+  Kullback-Leibler divergence `KL(P||Q)` is:
 
-    losses = P * log(P) - P * log(Q)
+  ```
+  KL(P||Q) = P * log(P) - P * log(Q)
+  ```
 
-  Note, the function assumes that `predictions` and `labels` are the values of
-  multinomial distribution, i.e., each value is the probability of the
+  Note: the function assumes that `predictions` and `labels` are the values of
+  a multinomial distribution, i.e., each value is the probability of the
   corresponding class.
 
-  For the usage of `weights` and `reduction`, please refer to tf.losses.
+  For the usage of `weights` and `reduction`, please refer to `tf.losses`.
 
   Args:
-    labels: `Tensor` of type float32 or float64, with shape `[d1, ..., dN,
-      num_classes]`, represents target distribution.
+    labels: `Tensor` of type `float32` or `float64`, with shape `[d1, ..., dN,
+      num_classes]`, represents the target distribution.
     predictions: `Tensor` of the same type and shape as `labels`, represents
-      predicted distribution.
-    axis: The dimension along which the KL divergence is computed. Note, the
-      values of `labels` and `predictions` along the `axis` should meet the
-      condition of multinomial distribution.
-    weights: (optional) `Tensor` whose rank is either 0, or the same rank as
+      the predicted distribution.
+    axis: The dimension along which the KL divergence is computed. The values
+      of `labels` and `predictions` along `axis` should meet the requirements
+      of a multinomial distribution.
+    weights: (optional) `Tensor` whose rank is either 0, or the same as that of
       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
       be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
-    loss_collection: collection to which the loss will be added.
-    reduction: Type of reduction to apply to loss.
+    loss_collection: Collection to which the loss will be added.
+    reduction: Type of reduction to apply to the loss.
 
   Returns:
-    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
-    shape as `labels`; otherwise, it is scalar.
+    Weighted loss `float` `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`, otherwise, it is a scalar.
   Raises:
-    InvalidArgumentError: If `labels` or `predictions` doesn't meet the
-      condition of multinomial distribution.
-    ValueError: If `axis` is None, or the shape of `predictions` doesn't match
-      that of `labels` or if the shape of `weights` is invalid.
+    InvalidArgumentError: If `labels` or `predictions` don't meet the
+      requirements of a multinomial distribution.
+    ValueError: If `axis` is `None`, if the shape of `predictions` doesn't
+      match that of `labels`, or if the shape of `weights` is invalid.
   """
   with tf.compat.v1.name_scope(scope, 'kl_divergence',
                                (predictions, labels, weights)) as scope:
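
The reworked `kl_divergence` docstring above quotes the formula `KL(P||Q) = P * log(P) - P * log(Q)`, reduced along the class axis. As a quick illustration of that formula only (not of the library call itself, whose `weights`/`reduction` handling follows `tf.losses`), here is a minimal NumPy sketch; the helper name `kl_divergence_rows` is hypothetical and not part of the library:

```python
# Minimal sketch of the formula quoted in the docstring above; illustrative only.
import numpy as np

def kl_divergence_rows(labels, predictions, axis=-1):
  """Sums P * log(P) - P * log(Q) along `axis` (assumes strictly positive probabilities)."""
  p = np.asarray(labels, dtype=np.float64)
  q = np.asarray(predictions, dtype=np.float64)
  return np.sum(p * np.log(p) - p * np.log(q), axis=axis)

labels = np.array([[0.2, 0.8], [0.5, 0.5]])       # target distributions
predictions = np.array([[0.3, 0.7], [0.5, 0.5]])  # predicted distributions
print(kl_divergence_rows(labels, predictions))    # non-negative; 0.0 where P == Q
```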
@@ -121,40 +123,44 @@ def jensen_shannon_divergence(
   """Adds a Jensen-Shannon divergence to the training procedure.
 
   For brevity, let `P = labels`, `Q = predictions`, `KL(P||Q)` be the
-  Kullback-Leibler divergence. The Jensen-Shannon divergence (JSD) is
+  Kullback-Leibler divergence as defined in the description of the
+  `nsl.lib.kl_divergence` function.". The Jensen-Shannon divergence (JSD) is
 
-    M = (P + Q) / 2
-    JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2
+  ```
+  M = (P + Q) / 2
+  JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2
+  ```
 
-  Note, the function assumes that `predictions` and `labels` are the values of
+  This function assumes that `predictions` and `labels` are the values of a
   multinomial distribution, i.e., each value is the probability of the
   corresponding class.
 
-  For the usage of `weights` and `reduction`, please refer to tf.losses.
+  For the usage of `weights` and `reduction`, please refer to `tf.losses`.
 
   Args:
-    labels: `Tensor` of type float32 or float64, with shape `[d1, ..., dN,
-      num_classes]`, represents target distribution.
+    labels: `Tensor` of type `float32` or `float64`, with shape `[d1, ..., dN,
+      num_classes]`, represents the target distribution.
     predictions: `Tensor` of the same type and shape as `labels`, represents
-      predicted distribution.
+      the predicted distribution.
     axis: The dimension along which the Jensen-Shannon divergence is computed.
-      Note, the values of `labels` and `predictions` along the `axis` should
-      meet the condition of multinomial distribution.
-    weights: (optional) `Tensor` whose rank is either 0, or the same rank as
+      The values of `labels` and `predictions` along `axis` should meet the
+      requirements of a multinomial distribution.
+    weights: (optional) `Tensor` whose rank is either 0, or the same as that of
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
-    loss_collection: collection to which the loss will be added.
-    reduction: Type of reduction to apply to loss.
+    loss_collection: Collection to which the loss will be added.
+    reduction: Type of reduction to apply to the loss.
 
   Returns:
-    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
-    shape as `labels`; otherwise, it is scalar.
+    Weighted loss `float` `Tensor`. If `reduction` is
+    `tf.compat.v1.losses.Reduction.MEAN`, this has the same shape as `labels`,
+    otherwise, it is a scalar.
   Raises:
-    InvalidArgumentError: If `labels` or `predictions` doesn't meet the
-      condition of multinomial distribution.
-    ValueError: If `axis` is None, or the shape of `predictions` doesn't match
-      that of `labels` or if the shape of `weights` is invalid.
+    InvalidArgumentError: If `labels` or `predictions` don't meet the
+      requirements of a multinomial distribution.
+    ValueError: If `axis` is `None`, the shape of `predictions` doesn't match
+      that of `labels`, or if the shape of `weights` is invalid.
   """
   with tf.compat.v1.name_scope(scope, 'jensen_shannon_divergence',
                                (predictions, labels, weights)) as scope:
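
Similarly, the JSD formula added above (`M = (P + Q) / 2`, `JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2`) can be sketched directly in NumPy. Again, this only illustrates the math from the docstring, not the `jensen_shannon_divergence` API; the `jsd_rows` helper is hypothetical:

```python
# Hypothetical helper illustrating the JSD formula from the docstring above.
import numpy as np

def jsd_rows(labels, predictions, axis=-1):
  p = np.asarray(labels, dtype=np.float64)
  q = np.asarray(predictions, dtype=np.float64)
  m = (p + q) / 2.0
  kl = lambda a, b: np.sum(a * (np.log(a) - np.log(b)), axis=axis)
  return kl(p, m) / 2.0 + kl(q, m) / 2.0

p = np.array([[0.2, 0.8]])
q = np.array([[0.8, 0.2]])
print(jsd_rows(p, q), jsd_rows(q, p))  # JSD is symmetric in its arguments
```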
@@ -177,7 +183,7 @@ def jensen_shannon_divergence(
 
 
 def _apply_transform(batched_tensor, transform_type, axis=None):
-  """Applys the given transform function to the batched_tensor along axis."""
+  """Applies the given transform function to `batched_tensor` along `axis`."""
   if transform_type == configs.TransformType.SOFTMAX:
     return tf.nn.softmax(batched_tensor, axis=axis)
   else:
@@ -217,60 +223,60 @@ def pairwise_distance_wrapper(sources,
                               targets,
                               weights=1.0,
                               distance_config=None):
-  """A wrapper to compute pairwise distance between sources and targets.
+  """A wrapper to compute the pairwise distance between `sources` and `targets`.
 
-    distances = weights * distance_type(sources, targets)
+  `distances = weights * distance_config.distance_type(sources, targets)`
 
   This wrapper calculates the weighted distance between `(sources, targets)`
   pairs, and provides an option to return the distance as the sum over the
   difference along the given axis, when vector based distance is needed.
 
-  For the usage of `weights` and `reduction`, please refer to tf.losses. For the
-  usage of `sum_over_axis`, see the following examples:
+  For the usage of `weights` and `reduction`, please refer to `tf.losses`. For
+  the usage of `sum_over_axis`, see the following examples:
 
-  Given target tensors with shape `[batch_size, features]`, reduction set to
-  be MEAN, and `sum_over_axis` set to be last dimension, the weighted average
-  distance of `sample pairs` will be returned. For example:
-  With a distance_config('L2', sum_over_axis=-1), the distance between
-  [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] will be
-  {(0+0) + (4+0) + (16+4) + (16+1)}/4 = 10.25
+  Given target tensors with shape `[batch_size, features]`, the reduction set to
+  `tf.compat.v1.losses.Reduction.MEAN`, and `sum_over_axis` set to the last
+  dimension, the weighted average distance of sample pairs will be returned.
+  For example: With a distance_config('L2', sum_over_axis=-1), the distance
+  between [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]]
+  will be {(0+0) + (4+0) + (16+4) + (16+1)}/4 = 10.25
 
-  If `sum_over_axis` is None, the weighted average distance of `feature pairs`
-  (instead of sample pairs) will be returned. For example:
-  With a distance_config('L2'), the distance between
+  If `sum_over_axis` is `None`, the weighted average distance of feature pairs
+  (instead of sample pairs) will be returned. For example: With a
+  distance_config('L2'), the distance between
   [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] will be
   {(0+0) + (4+0) + (16+4) + (16+1)}/8 = 5.125
 
-  If `transform_fn` is not None, the transform function is applied to both
-  sources and targets before computing the distance. For example:
-  distance_config('KL_DIVERGENCE', sum_over_axis=-1, transform_fn='SOFTMAX')
+  If `transform_fn` is not `None`, the transform function is applied to both
+  `sources` and `targets` before computing the distance. For example:
+  `distance_config('KL_DIVERGENCE', sum_over_axis=-1, transform_fn='SOFTMAX')`
   treats `sources` and `targets` as logits, and computes the KL-divergence
-  between the probability distributions.
+  between the two probability distributions.
 
   Args:
-    sources: `Tensor` of type float32 or float64.
-    targets: `Tensor` of the same type and shape as sources.
-    weights: (optional) `Tensor` whose rank is either 0, or the same rank as
+    sources: `Tensor` of type `float32` or `float64`.
+    targets: `Tensor` of the same type and shape as `sources`.
+    weights: (optional) `Tensor` whose rank is either 0, or the same as that of
       `targets`, and must be broadcastable to `targets` (i.e., all dimensions
-      must be either `1`, or the same as the corresponding `distance`
-      dimension).
-    distance_config: DistanceConfig contains the following configs (or
-      hyper-parameters) for computing distances:
-      (a) 'distance_type': Type of distance function to apply.
-      (b) 'reduction': Type of distance reduction. Refer to tf.losses.Reduction.
-      (c) 'sum_over_axis': (optional) The distance is sum over the difference
-        along the axis. Note, if `sum_over_axis` is not None and the rank of
-        `weights` is nonzero, the size of `weights` along the `sum_over_axis`
-        must be 1.
-      (d) 'transform_fn': (optional) If set, both sources and targets will be
-        transformed before calculating the distance. If set to 'SOFTMAX', it
-        will be performed on the axis specified by 'sum_over_axis', or -1 if
-        that is not specified. If None, the default distance config will be
+      must be either `1`, or the same as the corresponding distance dimension).
+    distance_config: An instance of `nsl.configs.DistanceConfig` that contains
+      the following configuration (or hyperparameters) for computing distances:
+      (a) `distance_type`: Type of distance function to apply.
+      (b) `reduction`: Type of distance reduction. See `tf.losses.Reduction`.
+      (c) `sum_over_axis`: (optional) The distance is the sum over the
+        difference along the specified axis. Note that if `sum_over_axis` is not
+        `None` and the rank of `weights` is non-zero, then the size of `weights`
+        along `sum_over_axis` must be 1.
+      (d) `transform_fn`: (optional) If set, both `sources` and `targets` will
+        be transformed before calculating the distance. If set to 'SOFTMAX', it
+        will be performed on the axis specified by 'sum_over_axis', or -1 if the
+        axis is not specified. If `None`, the default distance config will be
         used.
 
   Returns:
-    Weighted distance scalar `Tensor`. If `reduction` is `NONE`, this has the
-    same shape as `targets`.
+    Weighted distance scalar `Tensor`. If `reduction` is
+    `tf.compat.v1.losses.Reduction.MEAN`, this has the same shape as
+    `targets`.
   Raises:
     ValueError: If the shape of targets doesn't match that of sources, or if the
       shape of weights is invalid.
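
The worked numbers in the `pairwise_distance_wrapper` docstring (10.25 with `sum_over_axis=-1`, 5.125 without) can be reproduced with a few lines of NumPy. This only checks the arithmetic of the docstring example; it does not call the wrapper itself:

```python
# Reproduces the arithmetic of the docstring example above (squared 'L2'
# differences with MEAN reduction); illustrative only.
import numpy as np

sources = np.array([[1, 1], [2, 2], [0, 2], [5, 5]], dtype=np.float64)
targets = np.array([[1, 1], [0, 2], [4, 4], [1, 4]], dtype=np.float64)

sq_diff = (sources - targets) ** 2
print(sq_diff.sum(axis=-1).mean())  # 10.25: summed per sample pair, averaged over the batch
print(sq_diff.mean())               # 5.125: averaged over every feature pair
```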

neural_structured_learning/lib/multimodal_lib.py

Lines changed: 26 additions & 27 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Libs/utils for multimodal integration for Neural Structured Learning."""
+"""Utilities for multimodal integration for Neural Structured Learning."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -23,7 +23,7 @@
 
 
 def _bimodal_op(x, y, op_config):
-  """Apply bimodal integration operation to inputs."""
+  """Applies a bimodal integration operation to the inputs `x` and `y`."""
   if op_config.integration_type == configs.IntegrationType.ADD:
     return x + y
   elif op_config.integration_type == configs.IntegrationType.MUL:
@@ -54,37 +54,36 @@ def bimodal_integration(x,
                         integration_config,
                         reuse=None,
                         scope=None):
-  """Compute the bimodal integration between x and y.
+  """Computes the bimodal integration between `x` and `y`.
 
-  The inputs `x` and `y` are usually from two different types of sources,
-  e.g., `x` represents image embeddings and `y` represent text embeddings.
-  This function will integrate bimodal inputs `x` and `y` by the following:
+  The inputs `x` and `y` are usually from two different types of input
+  sources, e.g., `x` may represent image embeddings and `y` may represent text
+  embeddings. This function will integrate bimodal inputs `x` and `y` as
+  follows:
 
-  `outputs = fc_layer(
-    activation_fn(integration_type(fc_layer(x), fc_layer(y))))`
+  ```
+  outputs = fc_layer(activation_fn(integrate(fc_layer(x), fc_layer(y))))
+  ```,
+  where `fc_layer` represents a fully connected layer.
 
-  When the integration_type is (elementwise) 'additive', this function will is
-  equivalent to concat `x` and `y` and pass them into a two-layer perception.
-  When the integration_type is (elementwise) 'multiplicative', this function
-  is equivalent to multimodal low-rank bilinear Pooling (MLB) in
-  arXiv:1610.04325.
-  When the integration_type is 'tucker_decomp', this function is equivalent to
-  multimodal tensor-based Tucker decomposition (MUTAN) in arXiv:1705.06676.
+  When the integration type is (element-wise) 'additive', this function is
+  equivalent to concatenating `x` and `y` and passing the result into a
+  two-layer perceptron. When the integration type is (element-wise)
+  'multiplicative', this function is equivalent to [multimodal low-rank
+  bilinear Pooling (MLB)](https://arxiv.org/abs/1610.04325). When the
+  integration type is 'tucker_decomp', this function is equivalent to
+  [multimodal tensor-based Tucker decomposition
+  (MUTAN)](https://arxiv.org/abs/1705.06676).
 
   Args:
-    x: A tensor of at least rank 2 and static value for the last dimension; i.e.
-      [batch_size, depth], [None, None, None, channels].
-    y: A tensor of the same type and shape as `x`, except the size of the last
-      dimension can be different.
+    x: A tensor of rank at least 2 and a static value for the last dimension.
+      For example, `[batch_size, depth]`, `[None, None, None, channels]`, etc.
+    y: A tensor of the same type and shape as `x`, except that the size of the
+      last dimension can be different.
     output_dims: Integer or long, the number of output units.
-    integration_config: IntegrationConfig contains the following configs (or
-      hyper-parameters) for computing the hidden integration of `x` and `y`:
-      (a) integration_type: Type of integration function to apply.
-      (b) hidden_dims: Integer or a list of Integer, the number of hidden units
-        in the fully-connected layer(s) before the output layer.
-      (c) activation_fn: Activation function to be applied to.
-    reuse: Whether or not the layer and its variables should be reused. To be
-      able to reuse the layer scope must be given.
+    integration_config: An instance of `nsl.configs.IntegrationConfig`.
+    reuse: Whether or not the fully-connected layers and their variables should
+      be reused. To be able to reuse them, `scope` must be specified.
     scope: Optional scope for `variable_scope`.
 
   Returns:
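
The `bimodal_integration` docstring above describes the pattern `outputs = fc_layer(activation_fn(integrate(fc_layer(x), fc_layer(y))))`. Below is a hedged Keras sketch of that pattern for the (element-wise) 'additive' case; the layer sizes, activation, and input shapes are illustrative assumptions, and this is not the library's implementation:

```python
# Sketch of fc_layer -> integrate -> activation -> fc_layer for the 'additive'
# case; all dimensions here are made-up example values.
import tensorflow as tf

hidden_dims, output_dims = 8, 4
x = tf.keras.Input(shape=(16,))  # e.g., an image embedding
y = tf.keras.Input(shape=(32,))  # e.g., a text embedding

hx = tf.keras.layers.Dense(hidden_dims)(x)      # fc_layer(x)
hy = tf.keras.layers.Dense(hidden_dims)(y)      # fc_layer(y)
integrated = tf.keras.layers.Add()([hx, hy])    # element-wise 'additive' integration
activated = tf.keras.layers.ReLU()(integrated)  # activation_fn
outputs = tf.keras.layers.Dense(output_dims)(activated)

model = tf.keras.Model(inputs=[x, y], outputs=outputs)
model.summary()
```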
