import collections
import logging

from torch import nn
from tqdm import tqdm

logger = logging.getLogger(__name__)


def single_loop(
    train_loader,
    val_loader,
    devices,
    network,
    n_epochs,
    optimizer,
    train_function,
    eval_function,
    task,
    n_tasks,
    scheduler=None,
    disable_progressbar=False,
    eval_every_x_epochs=None,
    early_stopping=None
):
    """Trains `network` for `n_epochs` on `train_loader`, optionally evaluating it
    on `val_loader` every `eval_every_x_epochs` epochs with early stopping."""
    best_epoch, best_acc = -1, -1.
    wait = 0

    # Replicate the model over several GPUs when more than one device is given.
    if len(devices) > 1:
        logger.info("Duplicating model on {} gpus.".format(len(devices)))
        training_network = nn.DataParallel(network, devices)
    else:
        training_network = network

    for epoch in range(n_epochs):
        metrics = collections.defaultdict(float)

        prog_bar = tqdm(
            train_loader,
            disable=disable_progressbar,
            ascii=True,
            bar_format="{desc}: {percentage:3.0f}% | {n_fmt}/{total_fmt} | {rate_fmt}{postfix}"
        )
        for batch_index, input_dict in enumerate(prog_bar, start=1):
            inputs, targets = input_dict["inputs"], input_dict["targets"]
            memory_flags = input_dict["memory_flags"]

            # train_function accumulates its metrics in-place and returns the loss.
            optimizer.zero_grad()
            loss = train_function(training_network, inputs, targets, memory_flags, metrics)
            loss.backward()
            optimizer.step()

            _print_metrics(metrics, prog_bar, epoch, n_epochs, batch_index, task, n_tasks)

        if scheduler:
            scheduler.step(epoch)

        if eval_every_x_epochs and epoch != 0 and epoch % eval_every_x_epochs == 0:
            training_network.eval()
            accuracy = eval_function(training_network, val_loader)
            training_network.train()

            logger.info("Val accuracy: {}".format(accuracy))

            if accuracy > best_acc:
                best_epoch = epoch
                best_acc = accuracy
                wait = 0
            else:
                wait += 1

            # Stop once validation accuracy has not improved for `patience` evaluations.
            if early_stopping and wait >= early_stopping["patience"]:
                logger.warning("Early stopping!")
                break

    if eval_every_x_epochs:
        logger.info("Best accuracy reached at epoch {} with {}%.".format(best_epoch, best_acc))


def _print_metrics(metrics, prog_bar, epoch, nb_epochs, nb_batches, task, n_tasks):
    # Show the running average of each accumulated metric on the progress bar.
    pretty_metrics = ", ".join(
        "{}: {}".format(metric_name, round(metric_value / nb_batches, 3))
        for metric_name, metric_value in metrics.items()
    )

    prog_bar.set_description(
        "T{}/{}, E{}/{} => {}".format(task + 1, n_tasks, epoch + 1, nb_epochs, pretty_metrics)
    )
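
# A minimal usage sketch (not part of the original module), assuming hypothetical
# `my_train_step` / `my_eval` callables that match the signatures expected above:
# train_function(network, inputs, targets, memory_flags, metrics) -> loss, and
# eval_function(network, loader) -> accuracy. The loaders are assumed to yield
# dicts with "inputs", "targets" and "memory_flags" keys, as in the training loop.
#
#     import torch
#
#     def my_train_step(network, inputs, targets, memory_flags, metrics):
#         logits = network(inputs.to("cuda:0"))
#         loss = nn.functional.cross_entropy(logits, targets.to("cuda:0"))
#         metrics["loss"] += loss.item()
#         return loss
#
#     def my_eval(network, loader):
#         correct, total = 0, 0
#         with torch.no_grad():
#             for batch in loader:
#                 preds = network(batch["inputs"].to("cuda:0")).argmax(dim=1)
#                 correct += (preds == batch["targets"].to("cuda:0")).sum().item()
#                 total += batch["targets"].shape[0]
#         return correct / total
#
#     single_loop(
#         train_loader, val_loader, devices=[0], network=model, n_epochs=70,
#         optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
#         train_function=my_train_step, eval_function=my_eval,
#         task=0, n_tasks=1, eval_every_x_epochs=10,
#         early_stopping={"patience": 3},
#     )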