 
 
 def _assert_multinomial_distribution(input_tensor, axis):
-  """Assert input has valid multinomial distribution along specified axis."""
+  """Assert input has valid multinomial distribution along `axis`."""
   sum_of_multinomial_distribution = tf.reduce_sum(
       input_tensor=input_tensor, axis=axis)
   return [
@@ -36,7 +36,7 @@ def _assert_multinomial_distribution(input_tensor, axis):
 
 
 def _assert_valid_axis(ndims, axis):
-  """Assert the condition `-ndims < axis <= ndims` if axis is not None."""
39+ """Assert the condition `-ndims < axis <= ndims` if ` axis` is not ` None` ."""
   if axis and (axis < -ndims or axis >= ndims):
     raise ValueError('axis = %d not in [%d, %d)' % (axis, -ndims, ndims))
 
@@ -58,39 +58,41 @@ def kl_divergence(
5858 """Adds a KL-divergence to the training procedure.
5959
6060 For brevity, let `P = labels` and `Q = predictions`. The
61- Kullback-Leibler divergence`KL(P||Q)` is
61+ Kullback-Leibler divergence `KL(P||Q)` is:
6262
63- losses = P * log(P) - P * log(Q)
63+ ```
64+ KL(P||Q) = P * log(P) - P * log(Q)
65+ ```
 
-  Note, the function assumes that `predictions` and `labels` are the values of
-  multinomial distribution, i.e., each value is the probability of the
+  Note: the function assumes that `predictions` and `labels` are the values of
+  a multinomial distribution, i.e., each value is the probability of the
   corresponding class.
 
-  For the usage of `weights` and `reduction`, please refer to tf.losses.
+  For the usage of `weights` and `reduction`, please refer to `tf.losses`.
 
   Args:
-    labels: `Tensor` of type float32 or float64, with shape `[d1, ..., dN,
-      num_classes]`, represents target distribution.
+    labels: `Tensor` of type `float32` or `float64`, with shape `[d1, ..., dN,
+      num_classes]`, represents the target distribution.
     predictions: `Tensor` of the same type and shape as `labels`, represents
-      predicted distribution.
-    axis: The dimension along which the KL divergence is computed. Note, the
-      values of `labels` and `predictions` along the `axis` should meet the
-      condition of multinomial distribution.
-    weights: (optional) `Tensor` whose rank is either 0, or the same rank as
+      the predicted distribution.
+    axis: The dimension along which the KL divergence is computed. The values
+      of `labels` and `predictions` along `axis` should meet the requirements
+      of a multinomial distribution.
+    weights: (optional) `Tensor` whose rank is either 0, or the same as that of
       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
       be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
-    loss_collection: collection to which the loss will be added.
-    reduction: Type of reduction to apply to loss.
+    loss_collection: Collection to which the loss will be added.
+    reduction: Type of reduction to apply to the loss.
 
   Returns:
-    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
-    shape as `labels`; otherwise, it is scalar.
+    Weighted loss `float` `Tensor`. If `reduction` is `NONE`, this has the same
+    shape as `labels`; otherwise, it is a scalar.
   Raises:
-    InvalidArgumentError: If `labels` or `predictions` doesn't meet the
-      condition of multinomial distribution.
-    ValueError: If `axis` is None, or the shape of `predictions` doesn't match
-      that of `labels` or if the shape of `weights` is invalid.
+    InvalidArgumentError: If `labels` or `predictions` don't meet the
+      requirements of a multinomial distribution.
+    ValueError: If `axis` is `None`, if the shape of `predictions` doesn't
+      match that of `labels`, or if the shape of `weights` is invalid.
   """
   with tf.compat.v1.name_scope(scope, 'kl_divergence',
                                (predictions, labels, weights)) as scope:
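As a quick reference for the formula in the updated docstring (this is an illustrative sketch, not part of the diff; the tensor values are made up), the elementwise KL terms can be computed in plain TensorFlow as follows:

```python
import tensorflow as tf

# Hypothetical distributions with shape [batch, num_classes]; each row sums to 1.
labels = tf.constant([[0.2, 0.8], [0.5, 0.5]])          # P
predictions = tf.constant([[0.25, 0.75], [0.9, 0.1]])   # Q

# Elementwise terms from the docstring: P * log(P) - P * log(Q).
kl_terms = labels * tf.math.log(labels) - labels * tf.math.log(predictions)

# Summing over the class axis gives KL(P||Q) for each example in the batch.
kl_per_example = tf.reduce_sum(kl_terms, axis=-1)
```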
@@ -121,40 +123,44 @@ def jensen_shannon_divergence(
121123 """Adds a Jensen-Shannon divergence to the training procedure.
122124
123125 For brevity, let `P = labels`, `Q = predictions`, `KL(P||Q)` be the
-  Kullback-Leibler divergence. The Jensen-Shannon divergence (JSD) is
+  Kullback-Leibler divergence as defined in the description of the
+  `nsl.lib.kl_divergence` function. The Jensen-Shannon divergence (JSD) is:
 
-    M = (P + Q) / 2
-    JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2
+  ```
+  M = (P + Q) / 2
+  JSD(P||Q) = KL(P||M) / 2 + KL(Q||M) / 2
+  ```
 
-  Note, the function assumes that `predictions` and `labels` are the values of
+  This function assumes that `predictions` and `labels` are the values of a
   multinomial distribution, i.e., each value is the probability of the
   corresponding class.
 
-  For the usage of `weights` and `reduction`, please refer to tf.losses.
+  For the usage of `weights` and `reduction`, please refer to `tf.losses`.
 
   Args:
-    labels: `Tensor` of type float32 or float64, with shape `[d1, ..., dN,
-      num_classes]`, represents target distribution.
+    labels: `Tensor` of type `float32` or `float64`, with shape `[d1, ..., dN,
+      num_classes]`, represents the target distribution.
     predictions: `Tensor` of the same type and shape as `labels`, represents
-      predicted distribution.
+      the predicted distribution.
     axis: The dimension along which the Jensen-Shannon divergence is computed.
-      Note, the values of `labels` and `predictions` along the `axis` should
-      meet the condition of multinomial distribution.
-    weights: (optional) `Tensor` whose rank is either 0, or the same rank as
+      The values of `labels` and `predictions` along `axis` should meet the
+      requirements of a multinomial distribution.
+    weights: (optional) `Tensor` whose rank is either 0, or the same as that of
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
     scope: The scope for the operations performed in computing the loss.
-    loss_collection: collection to which the loss will be added.
-    reduction: Type of reduction to apply to loss.
+    loss_collection: Collection to which the loss will be added.
+    reduction: Type of reduction to apply to the loss.
 
   Returns:
-    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
-    shape as `labels`; otherwise, it is scalar.
+    Weighted loss `float` `Tensor`. If `reduction` is
+    `tf.compat.v1.losses.Reduction.NONE`, this has the same shape as `labels`;
+    otherwise, it is a scalar.
   Raises:
-    InvalidArgumentError: If `labels` or `predictions` doesn't meet the
-      condition of multinomial distribution.
-    ValueError: If `axis` is None, or the shape of `predictions` doesn't match
-      that of `labels` or if the shape of `weights` is invalid.
+    InvalidArgumentError: If `labels` or `predictions` don't meet the
+      requirements of a multinomial distribution.
+    ValueError: If `axis` is `None`, if the shape of `predictions` doesn't
+      match that of `labels`, or if the shape of `weights` is invalid.
158164 """
159165 with tf .compat .v1 .name_scope (scope , 'jensen_shannon_divergence' ,
160166 (predictions , labels , weights )) as scope :
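For context (not part of this change), here is a minimal sketch of the JSD formula above in plain TensorFlow, reusing an elementwise KL helper; the tensors are made-up examples:

```python
import tensorflow as tf

def kl(p, q, axis=-1):
  # Elementwise P * log(P) - P * log(Q), summed over the class axis.
  return tf.reduce_sum(p * tf.math.log(p) - p * tf.math.log(q), axis=axis)

p = tf.constant([[0.2, 0.8], [0.5, 0.5]])
q = tf.constant([[0.25, 0.75], [0.9, 0.1]])

m = (p + q) / 2.0
jsd = kl(p, m) / 2.0 + kl(q, m) / 2.0  # JSD(P||Q) for each example
```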
@@ -177,7 +183,7 @@ def jensen_shannon_divergence(
 
 
 def _apply_transform(batched_tensor, transform_type, axis=None):
-  """Applys the given transform function to the batched_tensor along axis."""
+  """Applies the given transform function to `batched_tensor` along `axis`."""
   if transform_type == configs.TransformType.SOFTMAX:
     return tf.nn.softmax(batched_tensor, axis=axis)
   else:
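A tiny illustration (not part of the diff) of what the SOFTMAX branch shown above does, using made-up logits:

```python
import tensorflow as tf

# Made-up logits with shape [batch, num_classes].
logits = tf.constant([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])

# Equivalent to the configs.TransformType.SOFTMAX branch above: each row of
# the result is a probability distribution along the last axis.
probs = tf.nn.softmax(logits, axis=-1)
```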
@@ -217,60 +223,60 @@ def pairwise_distance_wrapper(sources,
                               targets,
                               weights=1.0,
                               distance_config=None):
-  """A wrapper to compute pairwise distance between sources and targets.
+  """A wrapper to compute the pairwise distance between `sources` and `targets`.
 
-  distances = weights * distance_type(sources, targets)
+  `distances = weights * distance_config.distance_type(sources, targets)`
 
   This wrapper calculates the weighted distance between `(sources, targets)`
   pairs, and provides an option to return the distance as the sum over the
   difference along the given axis, when vector based distance is needed.
 
-  For the usage of `weights` and `reduction`, please refer to tf.losses. For the
-  usage of `sum_over_axis`, see the following examples:
+  For the usage of `weights` and `reduction`, please refer to `tf.losses`. For
+  the usage of `sum_over_axis`, see the following examples:
 
-  Given target tensors with shape `[batch_size, features]`, reduction set to
-  be MEAN, and `sum_over_axis` set to be last dimension, the weighted average
-  distance of `sample pairs` will be returned. For example:
-  With a distance_config('L2', sum_over_axis=-1), the distance between
-  [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] will be
-  {(0+0) + (4+0) + (16+4) + (16+1)}/4 = 10.25
+  Given target tensors with shape `[batch_size, features]`, the reduction set to
+  `tf.compat.v1.losses.Reduction.MEAN`, and `sum_over_axis` set to the last
+  dimension, the weighted average distance of sample pairs will be returned.
+  For example: With a distance_config('L2', sum_over_axis=-1), the distance
+  between [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]]
+  will be {(0+0) + (4+0) + (16+4) + (16+1)}/4 = 10.25
 
-  If `sum_over_axis` is None, the weighted average distance of `feature pairs`
-  (instead of sample pairs) will be returned. For example:
-  With a distance_config('L2'), the distance between
+  If `sum_over_axis` is `None`, the weighted average distance of feature pairs
+  (instead of sample pairs) will be returned. For example: With a
+  distance_config('L2'), the distance between
   [[1, 1], [2, 2], [0, 2], [5, 5]] and [[1, 1], [0, 2], [4, 4], [1, 4]] will be
   {(0+0) + (4+0) + (16+4) + (16+1)}/8 = 5.125
 
-  If `transform_fn` is not None, the transform function is applied to both
-  sources and targets before computing the distance. For example:
-  distance_config('KL_DIVERGENCE', sum_over_axis=-1, transform_fn='SOFTMAX')
+  If `transform_fn` is not `None`, the transform function is applied to both
+  `sources` and `targets` before computing the distance. For example:
+  `distance_config('KL_DIVERGENCE', sum_over_axis=-1, transform_fn='SOFTMAX')`
   treats `sources` and `targets` as logits, and computes the KL-divergence
-  between the probability distributions.
+  between the two probability distributions.
 
   Args:
-    sources: `Tensor` of type float32 or float64.
-    targets: `Tensor` of the same type and shape as sources.
-    weights: (optional) `Tensor` whose rank is either 0, or the same rank as
+    sources: `Tensor` of type `float32` or `float64`.
+    targets: `Tensor` of the same type and shape as `sources`.
+    weights: (optional) `Tensor` whose rank is either 0, or the same as that of
       `targets`, and must be broadcastable to `targets` (i.e., all dimensions
-      must be either `1`, or the same as the corresponding `distance`
-      dimension).
-    distance_config: DistanceConfig contains the following configs (or
-      hyper-parameters) for computing distances:
-      (a) 'distance_type': Type of distance function to apply.
-      (b) 'reduction': Type of distance reduction. Refer to tf.losses.Reduction.
-      (c) 'sum_over_axis': (optional) The distance is sum over the difference
-        along the axis. Note, if `sum_over_axis` is not None and the rank of
-        `weights` is nonzero, the size of `weights` along the `sum_over_axis`
-        must be 1.
-      (d) 'transform_fn': (optional) If set, both sources and targets will be
-        transformed before calculating the distance. If set to 'SOFTMAX', it
-        will be performed on the axis specified by 'sum_over_axis', or -1 if
-        that is not specified. If None, the default distance config will be
+      must be either `1`, or the same as the corresponding distance dimension).
+    distance_config: An instance of `nsl.configs.DistanceConfig` that contains
+      the following configuration (or hyperparameters) for computing distances:
+      (a) `distance_type`: Type of distance function to apply.
+      (b) `reduction`: Type of distance reduction. See `tf.losses.Reduction`.
+      (c) `sum_over_axis`: (optional) The distance is the sum over the
+        difference along the specified axis. Note that if `sum_over_axis` is not
+        `None` and the rank of `weights` is non-zero, then the size of `weights`
+        along `sum_over_axis` must be 1.
+      (d) `transform_fn`: (optional) If set, both `sources` and `targets` will
+        be transformed before calculating the distance. If set to 'SOFTMAX', it
+        will be performed on the axis specified by 'sum_over_axis', or -1 if the
+        axis is not specified. If `None`, the default distance config will be
         used.
 
   Returns:
-    Weighted distance scalar `Tensor`. If `reduction` is `NONE`, this has the
-    same shape as `targets`.
+    Weighted distance scalar `Tensor`. If `reduction` is
+    `tf.compat.v1.losses.Reduction.NONE`, this has the same shape as
+    `targets`.
   Raises:
     ValueError: If the shape of targets doesn't match that of sources, or if the
       shape of weights is invalid.
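To make the worked numbers in the docstring example concrete (illustrative only, not part of this diff; it reproduces the arithmetic with plain TensorFlow rather than calling the wrapper):

```python
import tensorflow as tf

sources = tf.constant([[1., 1.], [2., 2.], [0., 2.], [5., 5.]])
targets = tf.constant([[1., 1.], [0., 2.], [4., 4.], [1., 4.]])

# Elementwise squared differences, i.e. the 'L2' distance used in the example.
sq_diff = tf.math.squared_difference(sources, targets)

# sum_over_axis=-1 with MEAN reduction: average over the 4 sample pairs.
per_sample = tf.reduce_sum(sq_diff, axis=-1)   # [0., 4., 20., 17.]
print(tf.reduce_mean(per_sample).numpy())      # 10.25

# Without sum_over_axis: average over all 8 feature pairs.
print(tf.reduce_mean(sq_diff).numpy())         # 5.125
```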