@@ -1,8 +1,11 @@
 import torch
+import torch.nn as nn
+from copy import deepcopy
 import torch.utils.model_zoo as model_zoo
 import os
 import logging
 from collections import OrderedDict
+from timm.models.layers.conv2d_same import Conv2dSame
 
 
 def load_state_dict(checkpoint_path, use_ema=False):
@@ -101,4 +104,91 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non |
 
 
 
-
+def extract_layer(model, layer):
+    layer = layer.split('.')
+    module = model
+    if hasattr(model, 'module') and layer[0] != 'module':
+        module = model.module
+    if not hasattr(model, 'module') and layer[0] == 'module':
+        layer = layer[1:]
+    for l in layer:
+        if hasattr(module, l):
+            if not l.isdigit():
+                module = getattr(module, l)
+            else:
+                module = module[int(l)]
+        else:
+            return module
+    return module
+
+
+def set_layer(model, layer, val):
+    layer = layer.split('.')
+    module = model
+    if hasattr(model, 'module') and layer[0] != 'module':
+        module = model.module
+    lst_index = 0
+    module2 = module
+    for l in layer:
+        if hasattr(module2, l):
+            if not l.isdigit():
+                module2 = getattr(module2, l)
+            else:
+                module2 = module2[int(l)]
+            lst_index += 1
+    lst_index -= 1
+    for l in layer[:lst_index]:
+        if not l.isdigit():
+            module = getattr(module, l)
+        else:
+            module = module[int(l)]
+    l = layer[lst_index]
+    setattr(module, l, val)
+
+
+def adapt_model_from_string(parent_module, model_string):
+    separator = '***'
+    state_dict = {}
+    lst_shape = model_string.split(separator)
+    for k in lst_shape:
+        k = k.split(':')
+        key = k[0]
+        shape = k[1][1:-1].split(',')
+        if shape[0] != '':
+            state_dict[key] = [int(i) for i in shape]
+
+    new_module = deepcopy(parent_module)
+    for n, m in parent_module.named_modules():
+        old_module = extract_layer(parent_module, n)
+        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
+            if isinstance(old_module, Conv2dSame):
+                conv = Conv2dSame
+            else:
+                conv = nn.Conv2d
+            s = state_dict[n + '.weight']
+            in_channels = s[1]
+            out_channels = s[0]
+            if old_module.groups > 1:
+                in_channels = out_channels
+                g = in_channels
+            else:
+                g = 1
+            new_conv = conv(in_channels=in_channels, out_channels=out_channels,
+                            kernel_size=old_module.kernel_size, bias=old_module.bias is not None,
+                            padding=old_module.padding, dilation=old_module.dilation,
+                            groups=g, stride=old_module.stride)
+            set_layer(new_module, n, new_conv)
+        if isinstance(old_module, nn.BatchNorm2d):
+            new_bn = nn.BatchNorm2d(num_features=state_dict[n + '.weight'][0], eps=old_module.eps,
+                                    momentum=old_module.momentum,
+                                    affine=old_module.affine,
+                                    track_running_stats=True)
+            set_layer(new_module, n, new_bn)
+        if isinstance(old_module, nn.Linear):
+            new_fc = nn.Linear(in_features=state_dict[n + '.weight'][1], out_features=old_module.out_features,
+                               bias=old_module.bias is not None)
+            set_layer(new_module, n, new_fc)
+    new_module.eval()
+    parent_module.eval()
+
+    return new_module
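
For context, a minimal sketch (not part of the diff) of how the two dotted-path helpers behave, assuming `extract_layer` and `set_layer` are in scope from the patched helpers module; the toy Sequential and the narrower replacement conv are illustrative assumptions:

```python
import torch.nn as nn

# Toy nested module; names and channel sizes are illustrative, not from the diff.
net = nn.Sequential(
    nn.Conv2d(3, 16, 3),
    nn.Sequential(nn.ReLU(), nn.Conv2d(16, 32, 3)),
)

# extract_layer walks the dotted path: getattr for names, integer indexing
# for digit components (e.g. positions inside an nn.Sequential).
inner = extract_layer(net, '1.1')   # -> the Conv2d(16, 32, 3)

# set_layer walks to the parent of the last path component and swaps it out.
set_layer(net, '1.1', nn.Conv2d(16, 8, 3))
assert net[1][1].out_channels == 8
```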
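And a usage sketch of `adapt_model_from_string` (also an illustration, not from the diff): the string is a '***'-joined list of 'param_name:(shape)' entries, and only the '<module>.weight' shapes are read when rebuilding Conv2d/BatchNorm2d/Linear layers; the toy model and the halved channel count below are assumptions:

```python
import torch
import torch.nn as nn
from collections import OrderedDict

# Illustrative parent model; module names are arbitrary for this sketch.
parent = nn.Sequential(OrderedDict([
    ('conv', nn.Conv2d(3, 64, kernel_size=3, bias=False)),
    ('bn', nn.BatchNorm2d(64)),
]))

# Shrink the conv from 64 to 32 output channels; the batchnorm follows.
model_string = '***'.join([
    'conv.weight:(32, 3, 3, 3)',
    'bn.weight:(32)',
])

pruned = adapt_model_from_string(parent, model_string)
out = pruned(torch.randn(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 32, 6, 6]); pruned is returned in eval mode
```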