Skip to content

Commit 14938b5

Browse files
[convnet] Returns a dictionary instead of a list;
1 parent f48af1c commit 14938b5

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

inclearn/convnet/my_resnet.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,11 @@ def _make_layer(self, Block, planes, increase_dim=False, n=None):
261261

262262
return Stage(layers, block_relu=self.last_relu)
263263

264-
def forward(self, x, attention_hook=False):
264+
@property
265+
def last_conv(self):
266+
return self.stage_4.conv_b
267+
268+
def forward(self, x):
265269
x = self.conv_1_3x3(x)
266270
x = F.relu(self.bn_1(x), inplace=True)
267271

@@ -278,9 +282,11 @@ def forward(self, x, attention_hook=False):
278282
else:
279283
attentions = [feats_s1[-1], feats_s2[-1], feats_s3[-1], x]
280284

281-
if attention_hook:
282-
return raw_features, features, attentions
283-
return raw_features, features
285+
return {
286+
"raw_features": raw_features,
287+
"features": features,
288+
"attention": attentions
289+
}
284290

285291
def end_features(self, x):
286292
x = self.pool(x)

inclearn/convnet/resnet.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ def _make_layer(self, block, planes, blocks, stride=1, last=False):
159159

160160
return nn.Sequential(*layers)
161161

162-
def forward(self, x, attention_hook=False):
162+
@property
163+
def last_conv(self):
164+
return self.layer4[-1].conv2
165+
166+
def forward(self, x):
163167
x = self.conv1(x)
164168
x = self.bn1(x)
165169
x = self.relu(x)
@@ -173,9 +177,11 @@ def forward(self, x, attention_hook=False):
173177
raw_features = self.end_features(x_4)
174178
features = self.end_features(F.relu(x_4, inplace=False))
175179

176-
if attention_hook:
177-
return raw_features, features, [x_1, x_2, x_3, x_4]
178-
return raw_features, features
180+
return {
181+
"raw_features": raw_features,
182+
"features": features,
183+
"attention": [x_1, x_2, x_3, x_4]
184+
}
179185

180186
def end_features(self, x):
181187
x = self.avgpool(x)

inclearn/models/still.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def _before_task(self, train_loader, val_loader):
352352
self._class_weights = None
353353

354354
def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
355-
features, logits, atts = outputs["features"], outputs["logits"], outputs["attention_maps"]
355+
features, logits, atts = outputs["raw_features"], outputs["logits"], outputs["attention"]
356356

357357
if self._post_processing_type is None:
358358
scaled_logits = self._network.post_process(logits)
@@ -362,8 +362,8 @@ def _compute_loss(self, inputs, outputs, targets, onehot_targets, memory_flags):
362362
if self._old_model is not None:
363363
with torch.no_grad():
364364
old_outputs = self._old_model(inputs)
365-
old_features = old_outputs["features"]
366-
old_atts = old_outputs["attention_maps"]
365+
old_features = old_outputs["raw_features"]
366+
old_atts = old_outputs["attention"]
367367

368368
if self._ams_config:
369369
ams_config = copy.deepcopy(self._ams_config)

0 commit comments

Comments (0)