Skip to content

Commit 12fb6d3

Browse files
Neural-Link Teamtensorflow-copybara
authored andcommitted
Adds a make_graph_reg_config() factory helper function.
PiperOrigin-RevId: 267607230
1 parent 08d2244 commit 12fb6d3

File tree

9 files changed

+110
-55
lines changed

9 files changed

+110
-55
lines changed

g3doc/tutorials/graph_keras_lstm_imdb.ipynb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,12 +1263,11 @@
12631263
"outputs": [],
12641264
"source": [
12651265
"# Wrap the base model with graph regularization.\n",
1266-
"graph_reg_config = nsl.configs.GraphRegConfig(\n",
1267-
" neighbor_config=nsl.configs.GraphNeighborConfig(\n",
1268-
" max_neighbors=HPARAMS.num_neighbors),\n",
1266+
"graph_reg_config = nsl.configs.make_graph_reg_config(\n",
1267+
" max_neighbors=HPARAMS.num_neighbors,\n",
12691268
" multiplier=HPARAMS.graph_regularization_multiplier,\n",
1270-
" distance_config=nsl.configs.DistanceConfig(\n",
1271-
" distance_type=HPARAMS.distance_type, sum_over_axis=-1))\n",
1269+
" distance_type=HPARAMS.distance_type,\n",
1270+
" sum_over_axis=-1)\n",
12721271
"graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,\n",
12731272
" graph_reg_config)\n",
12741273
"graph_reg_model.compile(\n",

g3doc/tutorials/graph_keras_mlp_cora.ipynb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -852,12 +852,11 @@
852852
"outputs": [],
853853
"source": [
854854
"# Wrap the base MLP model with graph regularization.\n",
855-
"graph_reg_config = nsl.configs.GraphRegConfig(\n",
856-
" neighbor_config=nsl.configs.GraphNeighborConfig(\n",
857-
" max_neighbors=HPARAMS.num_neighbors),\n",
855+
"graph_reg_config = nsl.configs.make_graph_reg_config(\n",
856+
" max_neighbors=HPARAMS.num_neighbors,\n",
858857
" multiplier=HPARAMS.graph_regularization_multiplier,\n",
859-
" distance_config=nsl.configs.DistanceConfig(\n",
860-
" distance_type=HPARAMS.distance_type, sum_over_axis=-1))\n",
858+
" distance_type=HPARAMS.distance_type,\n",
859+
" sum_over_axis=-1)\n",
861860
"graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,\n",
862861
" graph_reg_config)\n",
863862
"graph_reg_model.compile(\n",

neural_structured_learning/configs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from neural_structured_learning.configs.configs import IntegrationConfig
1616
from neural_structured_learning.configs.configs import IntegrationType
1717
from neural_structured_learning.configs.configs import make_adv_reg_config
18+
from neural_structured_learning.configs.configs import make_graph_reg_config
1819
from neural_structured_learning.configs.configs import NormType
1920
from neural_structured_learning.configs.configs import TransformType
2021
from neural_structured_learning.configs.configs import VirtualAdvConfig
@@ -35,6 +36,7 @@
3536
'IntegrationConfig',
3637
'IntegrationType',
3738
'make_adv_reg_config',
39+
'make_graph_reg_config',
3840
'NormType',
3941
'TransformType',
4042
'VirtualAdvConfig',

neural_structured_learning/configs/configs.py

Lines changed: 85 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def all(cls):
3535

3636
@attr.s
3737
class AdvNeighborConfig(object):
38-
"""AdvNeighborConfig contains configs for generating adversarial neighbors.
38+
"""Contains configuration for generating adversarial neighbors.
3939
4040
Attributes:
4141
feature_mask: mask (w/ 0-1 values) applied on gradient. The shape should be
@@ -54,7 +54,7 @@ class AdvNeighborConfig(object):
5454

5555
@attr.s
5656
class AdvRegConfig(object):
57-
"""AdvRegConfig contains configs for adversarial regularization.
57+
"""Contains configuration for adversarial regularization.
5858
5959
Attributes:
6060
multiplier: multiplier to adversarial regularization loss. Default set to
@@ -71,22 +71,20 @@ def make_adv_reg_config(
7171
feature_mask=attr.fields(AdvNeighborConfig).feature_mask.default,
7272
adv_step_size=attr.fields(AdvNeighborConfig).adv_step_size.default,
7373
adv_grad_norm=attr.fields(AdvNeighborConfig).adv_grad_norm.default):
74-
"""Creates AdvRegConfig object.
74+
"""Creates an `nsl.configs.AdvRegConfig` object.
7575
7676
Args:
77-
multiplier: multiplier to adversarial regularization loss. Default set to
78-
0.2.
79-
feature_mask: mask (w/ 0-1 values) applied on gradient. The shape should be
80-
the same as (or broadcastable to) input features. If set to None, no
77+
multiplier: multiplier to adversarial regularization loss. Defaults to 0.2.
78+
feature_mask: mask (w/ 0-1 values) applied on the gradient. The shape should
79+
be the same as (or broadcastable to) input features. If set to `None`, no
8180
feature mask will be applied.
82-
adv_step_size: step size to find the adversarial sample. Default set to
83-
0.001.
81+
adv_step_size: step size to find the adversarial sample. Defaults to 0.001.
8482
adv_grad_norm: type of tensor norm to normalize the gradient. Input will be
85-
converted to `NormType` when applicable (e.g., 'l2' -> NormType.L2).
86-
Default set to L2 norm.
83+
converted to `NormType` when applicable (e.g., a value of 'l2' will be
84+
converted to `nsl.configs.NormType.L2`). Defaults to L2 norm.
8785
8886
Returns:
89-
An AdvRegConfig object.
87+
An `nsl.configs.AdvRegConfig` object.
9088
"""
9189
return AdvRegConfig(
9290
multiplier=multiplier,
@@ -110,7 +108,7 @@ def all(cls):
110108

111109
@attr.s
112110
class AdvTargetConfig(object):
113-
"""AdvTargetConfig contains configs for selecting targets to be attacked.
111+
"""Contains configuration for selecting targets to be attacked.
114112
115113
Attributes:
116114
target_method: type of adversarial targeting method. The value needs to be
@@ -142,20 +140,21 @@ def all(cls):
142140

143141
@attr.s
144142
class DistanceConfig(object):
145-
"""DistanceConfig contains configs for computing distances.
143+
"""Contains configuration for computing distances between tensors.
146144
147145
Attributes:
148146
distance_type: type of distance function. Input type will be converted to
149-
'DistanceType' when applicable (e.g., 'l2' -> DistanceType.L2). Default
150-
set to L2 norm.
151-
reduction: type of distance reduction. See tf.compat.v1.losses.Reduction for
152-
details. Default set to tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS.
147+
the appropriate `nsl.configs.DistanceType` value (e.g., the value 'l2' is
148+
converted to `nsl.configs.DistanceType.L2`). Defaults to the L2 norm.
149+
reduction: type of distance reduction. See `tf.compat.v1.losses.Reduction`
150+
for details. Defaults to `tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS`.
153151
sum_over_axis: the distance is the sum over the difference along the axis.
154-
Default set to None.
152+
See `nsl.lib.pairwise_distance_wrapper` for how this field is used.
153+
Defaults to `None`.
155154
transform_fn: type of transform function to be applied on each side before
156155
computing the pairwise distance. Input type will be converted to
157-
'TransformType' when applicable (e.g., 'softmax' ->
158-
TransformType.SOFTMAX). Default set to 'none'.
156+
`nsl.configs.TransformType` when applicable (e.g., the value 'softmax'
157+
maps to `nsl.configs.TransformType.SOFTMAX`). Defaults to 'none'.
159158
"""
160159
distance_type = attr.ib(converter=DistanceType, default=DistanceType.L2)
161160
reduction = attr.ib(
@@ -177,7 +176,7 @@ def all(cls):
177176

178177
@attr.s
179178
class DecayConfig(object):
180-
"""DecayConfig contains configs for computing decayed value.
179+
"""Contains configuration for computing decayed value.
181180
182181
Attributes:
183182
decay_steps: A scalar int32 or int64 Tensor or a Python number. How often to
@@ -207,7 +206,7 @@ def all(cls):
207206

208207
@attr.s
209208
class IntegrationConfig(object):
210-
"""IntegrationConfig contains configs for computing multimodal integration.
209+
"""Contains configuration for computing multimodal integration.
211210
212211
Attributes:
213212
integration_type: Type of integration function to apply.
@@ -222,7 +221,7 @@ class IntegrationConfig(object):
222221

223222
@attr.s
224223
class VirtualAdvConfig(object):
225-
"""VirtualAdvConfig contains configs for virtual adversarial training.
224+
"""Contains configuration for virtual adversarial training.
226225
227226
Attributes:
228227
adv_neighbor_config: an AdvNeighborConfig object for generating virtual
@@ -245,7 +244,7 @@ class VirtualAdvConfig(object):
245244

246245
@attr.s
247246
class GraphNeighborConfig(object):
248-
"""GraphNeighborConfig specifies neighbor attributes for graph regularization.
247+
"""Specifies neighbor attributes for graph regularization.
249248
250249
Attributes:
251250
prefix: The prefix in feature names that identifies neighbor-specific
@@ -268,21 +267,78 @@ class GraphNeighborConfig(object):
268267

269268
@attr.s
270269
class GraphRegConfig(object):
271-
"""GraphRegConfig contains the configuration for graph regularization.
270+
"""Contains the configuration for graph regularization.
272271
273272
Attributes:
274273
neighbor_config: An instance of `GraphNeighborConfig` that describes
275274
neighbor attributes for graph regularization.
276275
multiplier: The multiplier or weight factor applied on the graph
277-
regularization loss term. Defaults to 0.01. This value has to be greater
278-
than or equal to 0.
276+
regularization loss term. This value has to be non-negative. Defaults to
277+
0.01.
279278
distance_config: An instance of `DistanceConfig` to calculate the graph
280-
regularization loss term. Defaults to `DistanceConfig()`.
279+
regularization loss term. Defaults to `nsl.configs.DistanceConfig()`.
281280
"""
282281
neighbor_config = attr.ib(default=GraphNeighborConfig())
283282
multiplier = attr.ib(default=0.01)
284283
distance_config = attr.ib(default=DistanceConfig())
285284

286285

286+
def make_graph_reg_config(
287+
neighbor_prefix=attr.fields(GraphNeighborConfig).prefix.default,
288+
neighbor_weight_suffix=attr.fields(
289+
GraphNeighborConfig).weight_suffix.default,
290+
max_neighbors=attr.fields(GraphNeighborConfig).max_neighbors.default,
291+
multiplier=attr.fields(GraphRegConfig).multiplier.default,
292+
distance_type=attr.fields(DistanceConfig).distance_type.default,
293+
reduction=attr.fields(DistanceConfig).reduction.default,
294+
sum_over_axis=attr.fields(DistanceConfig).sum_over_axis.default,
295+
transform_fn=attr.fields(DistanceConfig).transform_fn.default):
296+
"""Creates an `nsl.configs.GraphRegConfig` object.
297+
298+
Args:
299+
neighbor_prefix: The prefix in feature names that identifies
300+
neighbor-specific features. Defaults to 'NL_nbr_'.
301+
neighbor_weight_suffix: The suffix in feature names that identifies the
302+
neighbor weight value. Defaults to '_weight'. Note that neighbor weight
303+
features will have `prefix` as a prefix and `weight_suffix` as a suffix.
304+
For example, based on the default values of `prefix` and `weight_suffix`,
305+
a valid neighbor weight feature is 'NL_nbr_0_weight', where 0 corresponds
306+
to the first neighbor of the sample.
307+
max_neighbors: The maximum number of neighbors to be used for graph
308+
regularization. Defaults to 0, which disables graph regularization. Note
309+
that this value has to be less than or equal to the actual number of
310+
neighbors in each sample.
311+
multiplier: The multiplier or weight factor applied on the graph
312+
regularization loss term. This value has to be non-negative. Defaults to
313+
0.01.
314+
distance_type: type of distance function. Input type will be converted to
315+
the appropriate `nsl.configs.DistanceType` value (e.g., the value 'l2' is
316+
converted to `nsl.configs.DistanceType.L2`). Defaults to the L2 norm.
317+
reduction: type of distance reduction. See `tf.compat.v1.losses.Reduction`
318+
for details. Defaults to `tf.losses.Reduction.SUM_BY_NONZERO_WEIGHTS`.
319+
sum_over_axis: the distance is the sum over the difference along the axis.
320+
See `nsl.lib.pairwise_distance_wrapper` for how this field is used.
321+
Defaults to `None`.
322+
transform_fn: type of transform function to be applied on each side before
323+
computing the pairwise distance. Input type will be converted to
324+
`nsl.configs.TransformType` when applicable (e.g., the value 'softmax'
325+
maps to `nsl.configs.TransformType.SOFTMAX`). Defaults to 'none'.
326+
327+
Returns:
328+
An `nsl.configs.GraphRegConfig` object.
329+
"""
330+
return GraphRegConfig(
331+
neighbor_config=GraphNeighborConfig(
332+
prefix=neighbor_prefix,
333+
weight_suffix=neighbor_weight_suffix,
334+
max_neighbors=max_neighbors),
335+
multiplier=multiplier,
336+
distance_config=DistanceConfig(
337+
distance_type=distance_type,
338+
reduction=reduction,
339+
sum_over_axis=sum_over_axis,
340+
transform_fn=transform_fn))
341+
342+
287343
DEFAULT_DISTANCE_PARAMS = attr.asdict(DistanceConfig())
288344
DEFAULT_ADVERSARIAL_PARAMS = attr.asdict(AdvNeighborConfig())

neural_structured_learning/examples/graph_keras_mlp_cora.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -265,14 +265,13 @@ def main(argv):
265265
train_dataset, test_dataset = make_datasets(argv[1], argv[2], hparams)
266266

267267
# Graph regularization configuration.
268-
graph_reg_config = nsl.configs.GraphRegConfig(
269-
neighbor_config=nsl.configs.GraphNeighborConfig(
270-
prefix=NBR_FEATURE_PREFIX,
271-
weight_suffix=NBR_WEIGHT_SUFFIX,
272-
max_neighbors=hparams.num_neighbors),
268+
graph_reg_config = nsl.configs.make_graph_reg_config(
269+
prefix=NBR_FEATURE_PREFIX,
270+
weight_suffix=NBR_WEIGHT_SUFFIX,
271+
max_neighbors=hparams.num_neighbors,
273272
multiplier=hparams.graph_regularization_multiplier,
274-
distance_config=nsl.configs.DistanceConfig(
275-
distance_type=hparams.distance_type, sum_over_axis=-1))
273+
distance_type=hparams.distance_type,
274+
sum_over_axis=-1)
276275

277276
# Create the base MLP models.
278277
base_models = {

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ def adversarial_loss(features,
9999
sample_weights: (optional) A 1-D `Tensor` of weights for the examples, with
100100
the same length as the first dimension of `features`.
101101
adv_config: (optional) An `nsl.configs.AdvRegConfig` object for adversarial
102-
regularization hyperparameters.
102+
regularization hyperparameters. Use `nsl.configs.make_adv_reg_config` to
103+
construct one.
103104
predictions: (optional) Precomputed value of `model(features)`. If set, the
104105
value will be reused when calculating adversarial regularization. In eager
105106
mode, the `gradient_tape` has to be set as well.

neural_structured_learning/keras/graph_regularization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ class GraphRegularization(tf.keras.Model):
4242
4343
# Wrap the base model to include graph regularization using up to 1 neighbor
4444
# per sample.
45-
graph_config = nsl.configs.GraphRegConfig(
46-
neighbor_config=nsl.configs.GraphNeighborConfig(max_neighbors=1))
45+
graph_config = nsl.configs.make_graph_reg_config(max_neighbors=1)
4746
graph_model = nsl.keras.GraphRegularization(base_model, graph_config)
4847
4948
# Compile, train, and evaluate the graph-regularized model as usual.
@@ -62,8 +61,9 @@ def __init__(self, base_model, graph_reg_config=None):
6261
Args:
6362
base_model: Unregularized model to which the loss term resulting from
6463
graph regularization will be added.
65-
graph_reg_config: Instance of `GraphRegConfig` that contains configuration
66-
for graph regularization.
64+
graph_reg_config: Instance of `nsl.configs.GraphRegConfig` that contains
65+
configuration for graph regularization. Use
66+
`nsl.configs.make_graph_reg_config` to construct one.
6767
"""
6868

6969
super(GraphRegularization, self).__init__(name='GraphRegularization')

neural_structured_learning/keras/graph_regularization_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,8 @@ def _create_and_compile_graph_reg_model(model_fn, weight, max_neighbors):
195195
model as `tf.keras.Model` instances.
196196
"""
197197
model = model_fn((2,), weight)
198-
graph_reg_config = configs.GraphRegConfig(
199-
configs.GraphNeighborConfig(max_neighbors=max_neighbors),
200-
multiplier=1)
198+
graph_reg_config = configs.make_graph_reg_config(
199+
max_neighbors=max_neighbors, multiplier=1)
201200
graph_reg_model = graph_regularization.GraphRegularization(
202201
model, graph_reg_config)
203202
graph_reg_model.compile(

neural_structured_learning/lib/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ def unpack_neighbor_features(features, neighbor_config, keep_rank=False):
416416
tensors. The shape of each neighbor weight tensor is expected to be `[B,
417417
1]`, where `B` is the batch size. Neighbor weight tensors cannot be sparse
418418
tensors.
419-
neighbor_config: An instance of `GraphNeighborConfig`.
419+
neighbor_config: An instance of `nsl.configs.GraphNeighborConfig`.
420420
keep_rank: Whether to preserve the neighborhood size dimension. Defaults to
421421
False.
422422

0 commit comments

Comments
 (0)