docs/tutorials/examples.md: 36 additions & 0 deletions
@@ -35,6 +35,7 @@ output = model(data)

#### Complete - Float32

+[//]: # (train_single_fp32_complete)
```
import torch
import torchvision
@@ -78,9 +79,11 @@ torch.save({
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```
+[//]: # (train_single_fp32_complete)
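
The diff elides the body of this example. For orientation, a minimal FP32 training sketch in the tutorial's style could look like the following; the CIFAR10 dataset, ResNet50 model, and hyperparameters are illustrative assumptions, not the file's exact content:

```
import torch
import torchvision
import intel_extension_for_pytorch as ipex

# Illustrative dataset and loader; the real example may differ.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
])
train_dataset = torchvision.datasets.CIFAR10(root='datasets/cifar10/', train=True,
                                             transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# Apply the extension's optimizations to both model and optimizer.
model, optimizer = ipex.optimize(model, optimizer=optimizer)

for data, target in train_loader:
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```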

#### Complete - BFloat16

+[//]: # (train_single_bf16_complete)
```
import torch
import torchvision
@@ -125,13 +128,15 @@ torch.save({
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```
+[//]: # (train_single_bf16_complete)
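
The middle of this block is likewise elided. A minimal BF16 variant, assuming the same ResNet50 setup and using synthetic batches so the sketch is self-contained, differs mainly in the `dtype=torch.bfloat16` argument and the autocast context:

```
import torch
import torchvision
import intel_extension_for_pytorch as ipex

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model.train()
# Request BF16 optimizations; the extension manages master weights.
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

# Synthetic stand-in batches so the sketch runs without a dataset download.
batches = [(torch.rand(8, 3, 224, 224), torch.randint(0, 1000, (8,))) for _ in range(2)]

for data, target in batches:
    optimizer.zero_grad()
    # Run the forward pass and loss under CPU autocast to use BF16.
    with torch.cpu.amp.autocast():
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```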

### Distributed Training

Distributed training with PyTorch DDP is accelerated by oneAPI Collective Communications Library Bindings for Pytorch\* (oneCCL Bindings for Pytorch\*). The extension supports FP32 and BF16 data types. More detailed information and examples are available at its [Github repo](https://github.com/intel/torch-ccl).

**Note:** When performing distributed training with BF16 data type, use oneCCL Bindings for Pytorch\*. Due to a PyTorch limitation, BF16 distributed training with Intel® Extension for PyTorch\* alone is not supported.

+[//]: # (train_ddp_complete)
```
import os
import torch
@@ -189,6 +194,7 @@ torch.save({
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
```
+[//]: # (train_ddp_complete)
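
The DDP example body is elided as well. A minimal sketch of CCL-backed distributed training, assuming the `oneccl_bindings_for_pytorch` module name and PMI-style environment variables used by common MPI launchers:

```
import os
import torch
import torch.distributed as dist
import torchvision
import oneccl_bindings_for_pytorch  # registers the 'ccl' backend
import intel_extension_for_pytorch as ipex

# Rank/world-size wiring is an assumption; launchers set these differently.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
rank = int(os.environ.get('PMI_RANK', 0))
world_size = int(os.environ.get('PMI_SIZE', 1))
dist.init_process_group(backend='ccl', rank=rank, world_size=world_size)

model = torchvision.models.resnet50()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model.train()
model, optimizer = ipex.optimize(model, optimizer=optimizer)
# Wrap with DDP after ipex.optimize so gradients are synchronized across ranks.
model = torch.nn.parallel.DistributedDataParallel(model)

for data, target in [(torch.rand(8, 3, 224, 224), torch.randint(0, 1000, (8,)))]:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'checkpoint.pth')
dist.destroy_process_group()
```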

## Inference

@@ -200,6 +206,7 @@ The `optimize` function of Intel® Extension for PyTorch\* applies optimizations

##### Resnet50

+[//]: # (inf_rn50_imp_fp32)
```
import torch
import torchvision.models as models
@@ -216,9 +223,11 @@ model = ipex.optimize(model)
with torch.no_grad():
  model(data)
```
+[//]: # (inf_rn50_imp_fp32)

##### BERT

+[//]: # (inf_bert_imp_fp32)
```
import torch
from transformers import BertModel
@@ -239,13 +248,15 @@ model = ipex.optimize(model)
with torch.no_grad():
  model(data)
```
+[//]: # (inf_bert_imp_fp32)
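
For orientation, a complete imperative-mode FP32 sketch consistent with the visible lines could read as follows; the untrained `resnet50` and random input are illustrative stand-ins:

```
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50()          # pretrained weights omitted in this sketch
model.eval()
data = torch.rand(1, 3, 224, 224)  # dummy input

# Apply the extension's inference optimizations (FP32 by default).
model = ipex.optimize(model)

with torch.no_grad():
    model(data)
```

The BERT variant follows the same pattern, with `BertModel.from_pretrained(...)` as the model and a `torch.randint` token tensor as the example input.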

#### TorchScript Mode

We recommend you take advantage of Intel® Extension for PyTorch\* with [TorchScript](https://pytorch.org/docs/stable/jit.html) for further optimizations.

##### Resnet50

+[//]: # (inf_rn50_ts_fp32)
```
import torch
import torchvision.models as models
@@ -266,9 +277,11 @@ with torch.no_grad():

  model(data)
```
+[//]: # (inf_rn50_ts_fp32)

##### BERT

+[//]: # (inf_bert_ts_fp32)
```
import torch
from transformers import BertModel
@@ -293,11 +306,13 @@ with torch.no_grad():

  model(data)
```
+[//]: # (inf_bert_ts_fp32)
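
A complete TorchScript-mode sketch consistent with the visible trace/freeze calls, again with an untrained model and random input as stand-ins:

```
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50()
model.eval()
data = torch.rand(1, 3, 224, 224)

model = ipex.optimize(model)

with torch.no_grad():
    # Trace once with a representative input, then freeze for best performance.
    model = torch.jit.trace(model, data)
    model = torch.jit.freeze(model)

    model(data)
```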

#### TorchDynamo Mode (Experimental, _NEW feature from 2.0.0_)

##### Resnet50

+[//]: # (inf_rn50_dynamo_fp32)
```
import torch
import torchvision.models as models
@@ -315,9 +330,11 @@ model = torch.compile(model, backend="ipex")
with torch.no_grad():
  model(data)
```
+[//]: # (inf_rn50_dynamo_fp32)

##### BERT

+[//]: # (inf_bert_dynamo_fp32)
```
import torch
from transformers import BertModel
@@ -339,6 +356,7 @@ model = torch.compile(model, backend="ipex")
with torch.no_grad():
  model(data)
```
+[//]: # (inf_bert_dynamo_fp32)
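
A complete TorchDynamo-mode sketch consistent with the visible `torch.compile(model, backend="ipex")` call; the optional `ipex.optimize(..., weights_prepack=False)` frontend step is an assumption here:

```
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50()
model.eval()
data = torch.rand(1, 3, 224, 224)

# Optional frontend optimizations before compilation (assumed step).
model = ipex.optimize(model, weights_prepack=False)
# Compile through TorchDynamo with the extension's backend.
model = torch.compile(model, backend="ipex")

with torch.no_grad():
    model(data)
```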

### BFloat16

@@ -349,6 +367,7 @@ We recommend using Auto Mixed Precision (AMP) with BFloat16 data type.

##### Resnet50

+[//]: # (inf_rn50_imp_bf16)
```
import torch
import torchvision.models as models
@@ -366,9 +385,11 @@ with torch.no_grad():
  with torch.cpu.amp.autocast():
    model(data)
```
+[//]: # (inf_rn50_imp_bf16)

##### BERT

+[//]: # (inf_bert_imp_bf16)
```
import torch
from transformers import BertModel
@@ -390,13 +411,15 @@ with torch.no_grad():
  with torch.cpu.amp.autocast():
    model(data)
```
+[//]: # (inf_bert_imp_bf16)
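
A complete BF16 imperative-mode sketch consistent with the visible autocast lines, with an untrained model and random input as stand-ins:

```
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50()
model.eval()
data = torch.rand(1, 3, 224, 224)

# Convert weights and enable BF16-friendly operator paths.
model = ipex.optimize(model, dtype=torch.bfloat16)

with torch.no_grad():
    with torch.cpu.amp.autocast():
        model(data)
```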

#### TorchScript Mode

We recommend you take advantage of Intel® Extension for PyTorch\* with [TorchScript](https://pytorch.org/docs/stable/jit.html) for further optimizations.

##### Resnet50

+[//]: # (inf_rn50_ts_bf16)
```
import torch
import torchvision.models as models
@@ -417,9 +440,11 @@ with torch.no_grad():

  model(data)
```
+[//]: # (inf_rn50_ts_bf16)

##### BERT

+[//]: # (inf_bert_ts_f16)
```
import torch
from transformers import BertModel
@@ -445,6 +470,7 @@ with torch.no_grad():

  model(data)
```
+[//]: # (inf_bert_ts_f16)
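
A complete BF16 TorchScript sketch; tracing inside the autocast context, as the visible lines suggest, bakes the BF16 casts into the frozen graph:

```
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex

model = models.resnet50()
model.eval()
data = torch.rand(1, 3, 224, 224)

model = ipex.optimize(model, dtype=torch.bfloat16)

with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, data)
    model = torch.jit.freeze(model)

    model(data)
```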

### INT8

@@ -465,6 +491,7 @@ Please follow the steps below to perform static calibration:
7. Save the INT8 model into a `pt` file.

+[//]: # (int8_static)
```
import os
import torch
@@ -494,6 +521,7 @@ with torch.no_grad():

traced_model.save("quantized_model.pt")
```
+[//]: # (int8_static)
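
A sketch of the full static-quantization flow matching the numbered steps above; the ResNet50 model and random calibration inputs are illustrative assumptions:

```
import torch
import torchvision.models as models
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert

model = models.resnet50()
model.eval()
data = torch.rand(1, 3, 224, 224)

# Choose a static qconfig and prepare the model with an example input.
qconfig = ipex.quantization.default_static_qconfig
prepared_model = prepare(model, qconfig, example_inputs=data, inplace=False)

# Calibration: feed representative inputs (random here, real data in practice).
for _ in range(4):
    prepared_model(torch.rand(1, 3, 224, 224))

# Convert to INT8, then trace and freeze for deployment.
converted_model = convert(prepared_model)
with torch.no_grad():
    traced_model = torch.jit.trace(converted_model, data)
    traced_model = torch.jit.freeze(traced_model)

# Save the INT8 model into a `pt` file.
traced_model.save("quantized_model.pt")
```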

##### Dynamic Quantization

@@ -507,6 +535,7 @@ Please follow the steps below to perform static calibration:
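
The dynamic-quantization body is not shown in this diff. A minimal sketch, assuming the extension's `default_dynamic_qconfig` and a BERT model as in the earlier examples; dynamic quantization computes activation scales at runtime, so no calibration pass is needed:

```
import torch
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.quantization import prepare, convert
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.eval()
data = torch.randint(model.config.vocab_size, size=(1, 128))

qconfig = ipex.quantization.default_dynamic_qconfig
prepared_model = prepare(model, qconfig, example_inputs=data)
converted_model = convert(prepared_model)

with torch.no_grad():
    # strict=False because BERT returns a dict-like output.
    traced_model = torch.jit.trace(converted_model, (data,), check_trace=False, strict=False)
    traced_model = torch.jit.freeze(traced_model)

traced_model.save("dynamic_quantized_model.pt")
```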

oneDNN provides [oneDNN Graph Compiler](https://github.com/oneapi-src/oneDNN/tree/dev-graph-preview4/doc#onednn-graph-compiler) as a prototype feature that could boost performance for selective topologies. No code change is required. Install [a binary](installation.md#installation_onednn_graph_compiler) with this feature enabled. We verified this feature with `Bert-large`, `bert-base-cased`, `roberta-base`, `xlm-roberta-base`, `google-electra-base-generator` and `google-electra-base-discriminator`.
@@ -572,6 +604,7 @@ The example code below works for all data types.