|
| 1 | +from torchbench.image_classification import ImageNet |
| 2 | +from timm import create_model, list_models |
| 3 | +from timm.data import resolve_data_config, create_transform |
| 4 | + |
| 5 | +NUM_GPU = 1 |
| 6 | +BATCH_SIZE = 256 * NUM_GPU |
| 7 | + |
| 8 | + |
| 9 | +def _attrib(paper_model_name='', paper_arxiv_id='', batch_size=BATCH_SIZE): |
| 10 | + return dict( |
| 11 | + paper_model_name=paper_model_name, |
| 12 | + paper_arxiv_id=paper_arxiv_id, |
| 13 | + batch_size=batch_size) |
| 14 | + |
| 15 | +model_map = dict( |
| 16 | + #adv_inception_v3=_attrib(paper_model_name='Adversarial Inception V3', paper_arxiv_id=), |
| 17 | + #densenet121=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 18 | + #densenet161=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 19 | + #densenet169=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 20 | + #densenet201=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 21 | + dpn68=_attrib( |
| 22 | + paper_model_name='DPN-68', paper_arxiv_id='1707.01629'), |
| 23 | + dpn68b=_attrib( |
| 24 | + paper_model_name='DPN-68b', paper_arxiv_id='1707.01629'), |
| 25 | + dpn92=_attrib( |
| 26 | + paper_model_name='DPN-92', paper_arxiv_id='1707.01629'), |
| 27 | + dpn98=_attrib( |
| 28 | + paper_model_name='DPN-98', paper_arxiv_id='1707.01629'), |
| 29 | + dpn107=_attrib( |
| 30 | + paper_model_name='DPN-107', paper_arxiv_id='1707.01629'), |
| 31 | + dpn131=_attrib( |
| 32 | + paper_model_name='DPN-131', paper_arxiv_id='1707.01629'), |
| 33 | + efficientnet_b0=_attrib( |
| 34 | + paper_model_name='EfficientNet-B0', paper_arxiv_id='1905.11946'), |
| 35 | + efficientnet_b1=_attrib( |
| 36 | + paper_model_name='EfficientNet-B1', paper_arxiv_id='1905.11946'), |
| 37 | + efficientnet_b2=_attrib( |
| 38 | + paper_model_name='EfficientNet-B2', paper_arxiv_id='1905.11946'), |
| 39 | + #ens_adv_inception_resnet_v2=_attrib(paper_model_name=, paper_arxiv_id=), |
| 40 | + fbnetc_100=_attrib( |
| 41 | + paper_model_name='FBNet-C', paper_arxiv_id='1812.03443'), |
| 42 | + gluon_inception_v3=_attrib( |
| 43 | + paper_model_name='Inception V3', paper_arxiv_id='1512.00567'), |
| 44 | + gluon_resnet18_v1b=_attrib( |
| 45 | + paper_model_name='ResNet-18', paper_arxiv_id='1812.01187'), |
| 46 | + gluon_resnet34_v1b=_attrib( |
| 47 | + paper_model_name='ResNet-34', paper_arxiv_id='1812.01187'), |
| 48 | + gluon_resnet50_v1b=_attrib( |
| 49 | + paper_model_name='ResNet-50', paper_arxiv_id='1812.01187'), |
| 50 | + gluon_resnet50_v1c=_attrib( |
| 51 | + paper_model_name='ResNet-50-C', paper_arxiv_id='1812.01187'), |
| 52 | + gluon_resnet50_v1d=_attrib( |
| 53 | + paper_model_name='ResNet-50-D', paper_arxiv_id='1812.01187'), |
| 54 | + gluon_resnet50_v1s=_attrib( |
| 55 | + paper_model_name='ResNet-50-S', paper_arxiv_id='1812.01187'), |
| 56 | + gluon_resnet101_v1b=_attrib( |
| 57 | + paper_model_name='ResNet-101', paper_arxiv_id='1812.01187'), |
| 58 | + gluon_resnet101_v1c=_attrib( |
| 59 | + paper_model_name='ResNet-101-C', paper_arxiv_id='1812.01187'), |
| 60 | + gluon_resnet101_v1d=_attrib( |
| 61 | + paper_model_name='ResNet-101-D', paper_arxiv_id='1812.01187'), |
| 62 | + gluon_resnet101_v1s=_attrib( |
| 63 | + paper_model_name='ResNet-101-S', paper_arxiv_id='1812.01187'), |
| 64 | + gluon_resnet152_v1b=_attrib( |
| 65 | + paper_model_name='ResNet-152', paper_arxiv_id='1812.01187'), |
| 66 | + gluon_resnet152_v1c=_attrib( |
| 67 | + paper_model_name='ResNet-152-C', paper_arxiv_id='1812.01187'), |
| 68 | + gluon_resnet152_v1d=_attrib( |
| 69 | + paper_model_name='ResNet-152-D', paper_arxiv_id='1812.01187'), |
| 70 | + gluon_resnet152_v1s=_attrib( |
| 71 | + paper_model_name='ResNet-152-S', paper_arxiv_id='1812.01187'), |
| 72 | + gluon_resnext50_32x4d=_attrib( |
| 73 | + paper_model_name='ResNeXt-50 32x4d', paper_arxiv_id='1812.01187'), |
| 74 | + gluon_resnext101_32x4d=_attrib( |
| 75 | + paper_model_name='ResNeXt-101 32x4d', paper_arxiv_id='1812.01187'), |
| 76 | + gluon_resnext101_64x4d=_attrib( |
| 77 | + paper_model_name='ResNeXt-101 64x4d', paper_arxiv_id='1812.01187'), |
| 78 | + gluon_senet154=_attrib( |
| 79 | + paper_model_name='SENet-154', paper_arxiv_id='1812.01187'), |
| 80 | + gluon_seresnext50_32x4d=_attrib( |
| 81 | + paper_model_name='SE-ResNeXt-50 32x4d', paper_arxiv_id='1812.01187'), |
| 82 | + gluon_seresnext101_32x4d=_attrib( |
| 83 | + paper_model_name='SE-ResNeXt-101 32x4d', paper_arxiv_id='1812.01187'), |
| 84 | + gluon_seresnext101_64x4d=_attrib( |
| 85 | + paper_model_name='SE-ResNeXt-101 64x4d', paper_arxiv_id='1812.01187'), |
| 86 | + gluon_xception65=_attrib( |
| 87 | + paper_model_name='Modified Aligned Xception', paper_arxiv_id='1802.02611', batch_size=BATCH_SIZE//2), |
| 88 | + ig_resnext101_32x8d=_attrib( |
| 89 | + paper_model_name='ResNeXt-101 32×8d', paper_arxiv_id='1805.00932'), |
| 90 | + ig_resnext101_32x16d=_attrib( |
| 91 | + paper_model_name='ResNeXt-101 32×16d', paper_arxiv_id='1805.00932'), |
| 92 | + ig_resnext101_32x32d=_attrib( |
| 93 | + paper_model_name='ResNeXt-101 32×32d', paper_arxiv_id='1805.00932', batch_size=BATCH_SIZE//2), |
| 94 | + ig_resnext101_32x48d=_attrib( |
| 95 | + paper_model_name='ResNeXt-101 32×48d', paper_arxiv_id='1805.00932', batch_size=BATCH_SIZE//4), |
| 96 | + inception_resnet_v2=_attrib( |
| 97 | + paper_model_name='Inception ResNet V2', paper_arxiv_id='1602.07261'), |
| 98 | + #inception_v3=dict(paper_model_name='Inception V3', paper_arxiv_id=), # same weights as torchvision |
| 99 | + inception_v4=_attrib( |
| 100 | + paper_model_name='Inception V4', paper_arxiv_id='1602.07261'), |
| 101 | + mixnet_l=_attrib( |
| 102 | + paper_model_name='MixNet-L', paper_arxiv_id='1907.09595'), |
| 103 | + mixnet_m=_attrib( |
| 104 | + paper_model_name='MixNet-M', paper_arxiv_id='1907.09595'), |
| 105 | + mixnet_s=_attrib( |
| 106 | + paper_model_name='MixNet-S', paper_arxiv_id='1907.09595'), |
| 107 | + mnasnet_100=_attrib( |
| 108 | + paper_model_name='MnasNet-B1', paper_arxiv_id='1807.11626'), |
| 109 | + mobilenetv3_100=_attrib( |
| 110 | + paper_model_name='MobileNet V3(1.0)', paper_arxiv_id='1905.02244'), |
| 111 | + nasnetalarge=_attrib( |
| 112 | + paper_model_name='NASNet-A Large', paper_arxiv_id='1707.07012', batch_size=BATCH_SIZE//4), |
| 113 | + pnasnet5large=_attrib( |
| 114 | + paper_model_name='PNASNet-5', paper_arxiv_id='1712.00559', batch_size=BATCH_SIZE//4), |
| 115 | + resnet18=_attrib( |
| 116 | + paper_model_name='ResNet-18', paper_arxiv_id='1812.01187'), |
| 117 | + resnet26=_attrib( |
| 118 | + paper_model_name='ResNet-26', paper_arxiv_id='1812.01187'), |
| 119 | + resnet26d=_attrib( |
| 120 | + paper_model_name='ResNet-26-D', paper_arxiv_id='1812.01187'), |
| 121 | + resnet34=_attrib( |
| 122 | + paper_model_name='ResNet-34', paper_arxiv_id='1812.01187'), |
| 123 | + resnet50=_attrib( |
| 124 | + paper_model_name='ResNet-50', paper_arxiv_id='1812.01187'), |
| 125 | + #resnet101=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 126 | + #resnet152=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 127 | + resnext50_32x4d=_attrib( |
| 128 | + paper_model_name='ResNeXt-50 32x4d', paper_arxiv_id='1812.01187'), |
| 129 | + resnext50d_32x4d=_attrib( |
| 130 | + paper_model_name='ResNeXt-50-D 32x4d', paper_arxiv_id='1812.01187'), |
| 131 | + #resnext101_32x8d=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 132 | + semnasnet_100=_attrib( |
| 133 | + paper_model_name='MnasNet-A1', paper_arxiv_id='1807.11626'), |
| 134 | + senet154=_attrib( |
| 135 | + paper_model_name='SENet-154', paper_arxiv_id='1709.01507'), |
| 136 | + seresnet18=_attrib( |
| 137 | + paper_model_name='SE-ResNet-18', paper_arxiv_id='1709.01507'), |
| 138 | + seresnet34=_attrib( |
| 139 | + paper_model_name='SE-ResNet-34', paper_arxiv_id='1709.01507'), |
| 140 | + seresnet50=_attrib( |
| 141 | + paper_model_name='SE-ResNet-50', paper_arxiv_id='1709.01507'), |
| 142 | + seresnet101=_attrib( |
| 143 | + paper_model_name='SE-ResNet-101', paper_arxiv_id='1709.01507'), |
| 144 | + seresnet152=_attrib( |
| 145 | + paper_model_name='SE-ResNet-152', paper_arxiv_id='1709.01507'), |
| 146 | + seresnext26_32x4d=_attrib( |
| 147 | + paper_model_name='SE-ResNeXt-26 32x4d', paper_arxiv_id='1709.01507'), |
| 148 | + seresnext50_32x4d=_attrib( |
| 149 | + paper_model_name='SE-ResNeXt-50 32x4d', paper_arxiv_id='1709.01507'), |
| 150 | + seresnext101_32x4d=_attrib( |
| 151 | + paper_model_name='SE-ResNeXt-101 32x4d', paper_arxiv_id='1709.01507'), |
| 152 | + spnasnet_100=_attrib( |
| 153 | + paper_model_name='Single-Path NAS', paper_arxiv_id='1904.02877'), |
| 154 | + tf_efficientnet_b0=_attrib( |
| 155 | + paper_model_name='EfficientNet-B0', paper_arxiv_id='1905.11946'), |
| 156 | + tf_efficientnet_b1=_attrib( |
| 157 | + paper_model_name='EfficientNet-B1', paper_arxiv_id='1905.11946'), |
| 158 | + tf_efficientnet_b2=_attrib( |
| 159 | + paper_model_name='EfficientNet-B2', paper_arxiv_id='1905.11946'), |
| 160 | + tf_efficientnet_b3=_attrib( |
| 161 | + paper_model_name='EfficientNet-B3', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//2), |
| 162 | + tf_efficientnet_b4=_attrib( |
| 163 | + paper_model_name='EfficientNet-B4', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//2), |
| 164 | + tf_efficientnet_b5=_attrib( |
| 165 | + paper_model_name='EfficientNet-B5', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//4), |
| 166 | + tf_efficientnet_b6=_attrib( |
| 167 | + paper_model_name='EfficientNet-B6', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//8), |
| 168 | + tf_efficientnet_b7=_attrib( |
| 169 | + paper_model_name='EfficientNet-B7', paper_arxiv_id='1905.11946', batch_size=BATCH_SIZE//8), |
| 170 | + tf_inception_v3=_attrib( |
| 171 | + paper_model_name='Inception V3', paper_arxiv_id='1512.00567'), |
| 172 | + tf_mixnet_l=_attrib( |
| 173 | + paper_model_name='MixNet-L', paper_arxiv_id='1907.09595'), |
| 174 | + tf_mixnet_m=_attrib( |
| 175 | + paper_model_name='MixNet-M', paper_arxiv_id='1907.09595'), |
| 176 | + tf_mixnet_s=_attrib( |
| 177 | + paper_model_name='MixNet-S', paper_arxiv_id='1907.09595'), |
| 178 | + #tv_resnet34=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 179 | + #tv_resnet50=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 180 | + #tv_resnext50_32x4d=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 181 | + #wide_resnet50_2=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 182 | + #wide_resnet101_2=_attrib(paper_model_name=, paper_arxiv_id=), # same weights as torchvision |
| 183 | + xception=_attrib( |
| 184 | + paper_model_name='Xception', paper_arxiv_id='1610.02357'), |
| 185 | +) |
| 186 | + |
| 187 | +model_names = list_models(pretrained=True) |
| 188 | + |
| 189 | +for model_name in model_names: |
| 190 | + if model_name not in model_map: |
| 191 | + print('Skipping %s' % model_name) |
| 192 | + continue |
| 193 | + |
| 194 | + # create model from name |
| 195 | + model = create_model(model_name, pretrained=True) |
| 196 | + param_count = sum([m.numel() for m in model.parameters()]) |
| 197 | + print('Model %s created, param count: %d' % (model_name, param_count)) |
| 198 | + |
| 199 | + # get appropriate transform for model's default pretrained config |
| 200 | + data_config = resolve_data_config(dict(), model=model, verbose=True) |
| 201 | + input_transform = create_transform(**data_config) |
| 202 | + |
| 203 | + # Run the benchmark |
| 204 | + ImageNet.benchmark( |
| 205 | + model=model, |
| 206 | + paper_model_name=model_map[model_name]['paper_model_name'], |
| 207 | + paper_arxiv_id=model_map[model_name]['paper_arxiv_id'], |
| 208 | + input_transform=input_transform, |
| 209 | + batch_size=model_map[model_name]['batch_size'], |
| 210 | + num_gpu=NUM_GPU, |
| 211 | + #data_root=DATA_ROOT |
| 212 | + ) |
| 213 | + |
| 214 | + |
0 commit comments