1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- r"""Program to build a graph based on dense input features (embeddings).
15+ r"""Library to build a graph based on dense input features (embeddings).
1616
17- USAGE:
18-
19- `python build_graph.py` [*flags*] *input_features.tfr ... output_graph.tsv*
20-
21- This program reads input instances from one or more TFRecord files, each
22- containing `tf.train.Example` protos. Each input example is expected to
23- contain at least these 2 features:
24-
25- * `id`: A singleton `bytes_list` feature that identifies each Example.
26- * `embedding`: A `float_list` feature that contains the (dense) embedding of
27- each example.
28-
29- `id` and `embedding` are not necessarily the literal feature names; if your
30- features have different names, you can use the `--id_feature_name` and
31- `--embedding_feature_name` flags to specify them, respectively.
32-
33- The program then computes the cosine similarity between all pairs of input
34- examples based on their associated embeddings. An edge is written to the
35- *output_graph.tsv* file for each pair whose similarity is at least as large as
36- the value of the `--similarity_threshold` flag. Each output edge is
37- represented by a line in the *output_graph.tsv* file with the following form:
38-
39- ```
40- source_id<TAB>target_id<TAB>edge_weight
41- ```
42-
43- All edges in the output will be symmetric (i.e., if edge `A--w-->B` exists in
44- the output, then so will edge `B--w-->A`).
45-
46- For details about this program's flags, run `python build_graph.py --help`.
17+ A Python-based program for graph building also exists on
18+ [GitHub](https://github.com/tensorflow/neural-structured-learning/tree/master/neural_structured_learning/tools/graph_builder_main.py).
4719"""
4820
4921from __future__ import absolute_import
5426import itertools
5527import time
5628
57- from absl import app
58- from absl import flags
5929from absl import logging
6030from neural_structured_learning .tools import graph_utils
6131import numpy as np
@@ -71,7 +41,8 @@ def _read_tfrecord_examples(filenames, id_feature_name, embedding_feature_name):
7141 """Reads and returns the embeddings stored in the Examples in `filenames`.
7242
7343 Args:
74- filenames: A list of names of TFRecord files containing tensorflow.Examples.
44+ filenames: A list of names of TFRecord files containing `tf.train.Example`
45+ objects.
7546 id_feature_name: Name of the feature that identifies the Example's ID.
7647 embedding_feature_name: Name of the feature that identifies the Example's
7748 embedding.
@@ -162,39 +133,52 @@ def _add_edges(embeddings, threshold, g):
162133 edge_cnt , (time .time () - start_time ))
163134
164135
165- def _main (argv ):
166- """Main function for running the build_graph program."""
167- flag = flags .FLAGS
168- flag .showprefixforinfo = False
169- if len (argv ) < 3 :
170- raise app .UsageError (
171- 'Invalid number of arguments; expected 2 or more, got %d' %
172- (len (argv ) - 1 ))
136+ def build_graph (embedding_files ,
137+ output_graph_path ,
138+ similarity_threshold = 0.8 ,
139+ id_feature_name = 'id' ,
140+ embedding_feature_name = 'embedding' ):
141+ """Builds a graph based on dense embeddings and persists it in TSV format.
173142
174- embeddings = _read_tfrecord_examples (argv [1 :- 1 ], flag .id_feature_name ,
175- flag .embedding_feature_name )
143+ This function reads input instances from one or more TFRecord files, each
144+ containing `tf.train.Example` protos. Each input example is expected to
145+ contain at least the following 2 features:
146+
147+ * `id`: A singleton `bytes_list` feature that identifies each example.
148+ * `embedding`: A `float_list` feature that contains the (dense) embedding of
149+ each example.
150+
151+ `id` and `embedding` are not necessarily the literal feature names; if your
152+ features have different names, you can specify them using the
153+ `id_feature_name` and `embedding_feature_name` arguments, respectively.
154+
155+ This function then computes the cosine similarity between all pairs of input
156+ examples based on their associated embeddings. An edge is written to the TSV
157+ file named by `output_graph_path` for each pair whose similarity is at least
158+ as large as `similarity_threshold`. Each output edge is represented by a TSV
159+ line in the `output_graph_path` file with the following form:
160+
161+ ```
162+ source_id<TAB>target_id<TAB>edge_weight
163+ ```
164+
165+ All edges in the output will be symmetric (i.e., if edge `A--w-->B` exists in
166+ the output, then so will edge `B--w-->A`).
167+
168+ Args:
169+ embedding_files: A list of names of TFRecord files containing
170+ `tf.train.Example` objects, which in turn contain dense embeddings.
171+ output_graph_path: Name of the file to which the output graph in TSV format
172+ should be written.
173+ similarity_threshold: Threshold used to determine which edges to retain in
174+ the resulting graph.
175+ id_feature_name: The name of the feature in the input `tf.train.Example`
176+ objects representing the ID of examples.
177+ embedding_feature_name: The name of the feature in the input
178+ `tf.train.Example` objects representing the embedding of examples.
179+ """
180+ embeddings = _read_tfrecord_examples (embedding_files , id_feature_name ,
181+ embedding_feature_name )
176182 graph = collections .defaultdict (dict )
177- _add_edges (embeddings , flag .similarity_threshold , graph )
178- graph_utils .write_tsv_graph (argv [- 1 ], graph )
179-
180-
181- if __name__ == '__main__' :
182- flags .DEFINE_string (
183- 'id_feature_name' , 'id' ,
184- """Name of the singleton bytes_list feature in each input Example
185- whose value is the Example's ID."""
186- )
187- flags .DEFINE_string (
188- 'embedding_feature_name' , 'embedding' ,
189- """Name of the float_list feature in each input Example
190- whose value is the Example's (dense) embedding."""
191- )
192- flags .DEFINE_float (
193- 'similarity_threshold' , 0.8 ,
194- """Lower bound on the cosine similarity required for an edge
195- to be created between two nodes."""
196- )
197-
198- # Ensure TF 2.0 behavior even if TF 1.X is installed.
199- tf .compat .v1 .enable_v2_behavior ()
200- app .run (_main )
183+ _add_edges (embeddings , similarity_threshold , graph )
184+ graph_utils .write_tsv_graph (output_graph_path , graph )
0 commit comments