diff --git a/train.py b/train.py index f5c6ecf..2f97fbf 100644 --- a/train.py +++ b/train.py @@ -77,18 +77,22 @@ # Create trainer logger = TensorBoardLogger("logs", name="helmnet") + # parser = pl.Trainer.add_argparse_args(parser) + args = parser.parse_args() + gpu_list = [int(i) for i in args.gpus.split(',')] + check_val_every_n_epoch = int(args.check_val_every_n_epoch) + checkpoint_callback = ModelCheckpoint( dirpath = os.getcwd() + "/checkpoints/", save_top_k = 3, + every_n_epochs=check_val_every_n_epoch, verbose = True, monitor = "val_loss", mode = "min", save_last = True, ) - # parser = pl.Trainer.add_argparse_args(parser) - args = parser.parse_args() - gpu_list = [int(i) for i in args.gpus.split(',')] + # Make trainer trainer = pl.Trainer.from_argparse_args(