|
7 | 7 | import torch._C |
8 | 8 | import torch.nn as nn |
9 | 9 | from torch._C import Graph, Node |
10 | | -from torch.jit import ScriptModule |
11 | 10 |
|
12 | 11 | from genes import NodeGene |
13 | 12 | from genome import OptimizerGenome |
@@ -122,37 +121,54 @@ def build_forward_graph( |
122 | 121 | return graph |
123 | 122 |
|
124 | 123 |
|
125 | | -class DynamicOptimizerModule(ScriptModule): |
126 | | - def __init__(self, genome, input_keys, output_keys): |
127 | | - super(DynamicOptimizerModule, self).__init__() |
128 | | - nn.Module.__init__(self) |
129 | | - ScriptModule.__init__(self) |
| 124 | +class DynamicOptimizerModule(nn.Module): |
| 125 | + """Simple PyTorch module implementing one round of message passing.""" |
130 | 126 |
|
131 | | - # register one Parameter per enabled connection |
| 127 | + def __init__(self, genome, input_keys, output_keys, graph_dict=None): |
| 128 | + super().__init__() |
| 129 | + self.num_nodes = len(genome.nodes) |
| 130 | + self.edges: List[Tuple[int, int]] = [] |
| 131 | + self.weights = nn.ParameterList() |
| 132 | + self.node_types: List[str] = [] |
132 | 133 | for (src, dst), conn in genome.connections.items(): |
133 | 134 | if conn.enabled: |
| 135 | + self.edges.append((src, dst)) |
134 | 136 | w = getattr(conn, "weight", 1.0) |
135 | | - self.register_parameter(f"w_{src}_{dst}", torch.nn.Parameter(torch.tensor(w))) |
136 | | - |
137 | | - # now build and attach the forward graph |
138 | | - graph = build_forward_graph( |
139 | | - num_nodes=len(genome.nodes), |
140 | | - edges=list(genome.connections.keys()), |
141 | | - input_keys=input_keys, |
142 | | - output_keys=output_keys, |
143 | | - ) |
144 | | - |
145 | | - # hook it into this ScriptModule |
146 | | - create_fn = getattr( |
147 | | - torch._C, |
148 | | - "_jit_create_method_from_graph", |
149 | | - getattr( |
150 | | - torch._C, |
151 | | - "_create_method_from_graph", |
152 | | - getattr(torch._C, "_jit_create_function_from_graph", getattr(torch._C, "_create_function_from_graph")), |
153 | | - ), |
154 | | - ) |
155 | | - create_fn("forward", graph) |
| 137 | + self.weights.append(nn.Parameter(torch.tensor(w))) |
| 138 | + for nid in range(self.num_nodes): |
| 139 | + ng = genome.nodes[nid] |
| 140 | + self.node_types.append(ng.node_type) |
| 141 | + |
| 142 | + self.input_keys = input_keys |
| 143 | + self.output_keys = output_keys |
| 144 | + self.graph_dict = graph_dict |
| 145 | + |
| 146 | + def forward( |
| 147 | + self, |
| 148 | + loss: torch.Tensor, |
| 149 | + prev_loss: torch.Tensor, |
| 150 | + named_params: List[Tuple[str, torch.Tensor]], |
| 151 | + ) -> Dict[str, torch.Tensor]: |
| 152 | + params = [p for _, p in named_params] |
| 153 | + all_inputs = [loss, prev_loss] + params |
| 154 | + features = torch.stack(all_inputs, 0) |
| 155 | + |
| 156 | + out_feats = [torch.zeros_like(loss) for _ in range(self.num_nodes)] |
| 157 | + for idx, w in enumerate(self.weights): |
| 158 | + src, dst = self.edges[idx] |
| 159 | + out_feats[dst] = out_feats[dst] + features[src] * w |
| 160 | + |
| 161 | + all_outputs = list(out_feats) |
| 162 | + all_outputs[0] = loss |
| 163 | + all_outputs[1] = prev_loss |
| 164 | + for i, p in enumerate(params): |
| 165 | + if 2 + i < len(all_outputs): |
| 166 | + all_outputs[2 + i] = p |
| 167 | + |
| 168 | + outputs = {} |
| 169 | + for ok in self.output_keys: |
| 170 | + outputs[str(ok)] = all_outputs[ok] |
| 171 | + return outputs |
156 | 172 |
|
157 | 173 |
|
158 | 174 | def rebuild_and_script(graph_dict, config, key) -> DynamicOptimizerModule: |
@@ -180,7 +196,10 @@ def rebuild_and_script(graph_dict, config, key) -> DynamicOptimizerModule: |
180 | 196 | cg.enabled = True |
181 | 197 | genome.connections[(src, dst)] = cg |
182 | 198 |
|
183 | | - # --- make a fresh ScriptModule and give it weight params --- |
| 199 | + # --- build a Python module and script it --- |
184 | 200 | if genome.connections: |
185 | | - return DynamicOptimizerModule(genome, config.input_keys, config.output_keys) |
| 201 | + module = DynamicOptimizerModule( |
| 202 | + genome, config.input_keys, config.output_keys, graph_dict |
| 203 | + ) |
| 204 | + return torch.jit.script(module) |
186 | 205 | return None |
0 commit comments