 
 from datasets import get_datasets
 from trainers import create_trainers
-from handlers import get_handlers, get_logger
-from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from
+from utils import setup_logging, log_metrics, log_basic_info, initialize, resume_from, get_handlers, get_logger
 from config import get_default_parser
 
 
@@ -36,6 +35,19 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     rank = idist.get_rank()
     manual_seed(config.seed + rank)
 
+    # -----------------------
+    # create output folder
+    # -----------------------
+
+    if rank == 0:
+        now = datetime.now().strftime("%Y%m%d-%H%M%S")
+        name = f"{config.dataset}-backend-{idist.backend()}-{now}"
+        path = Path(config.output_dir, name)
+        path.mkdir(parents=True, exist_ok=True)
+        config.output_dir = path.as_posix()
+
+    config.output_dir = Path(idist.broadcast(config.output_dir, src=0))
+
     # -----------------------------
     # datasets and dataloaders
     # -----------------------------
@@ -45,7 +57,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     train_dataloader = idist.auto_dataloader(
         train_dataset,
         batch_size=config.batch_size,
-        num_workers=config.num_workers
+        num_workers=config.num_workers,
+        {% if use_distributed_training and not use_distributed_launcher %}
+        persistent_workers=True,
+        {% endif %}
     )
 
     # ------------------------------------------
@@ -58,9 +73,10 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     # -----------------------------
     # train_engine and eval_engine
     # -----------------------------
-    real_labels = torch.ones(config.batch_size, device=device)
-    fake_labels = torch.zeros(config.batch_size, device=device)
-    fixed_noise = torch.randn(config.batch_size, config.z_dim, 1, 1, device=device)
+    ws = idist.get_world_size()
+    real_labels = torch.ones(config.batch_size // ws, device=device)
+    fake_labels = torch.zeros(config.batch_size // ws, device=device)
+    fixed_noise = torch.randn(config.batch_size // ws, config.z_dim, 1, 1, device=device)
 
     train_engine = create_trainers(
         config=config,
@@ -75,7 +91,6 @@ def run(local_rank: int, config: Any, *args: Any, **kwargs: Any):
     )
 
     # -------------------------------------------
-    # update config with optimizer parameters
     # setup engines logger with python logging
     # print training configurations
     # -------------------------------------------
@@ -203,20 +218,17 @@ def main():
     parser = ArgumentParser(parents=[get_default_parser()])
     config = parser.parse_args()
 
-    if config.output_dir:
-        now = datetime.now().strftime("%Y%m%d-%H%M%S")
-        name = f'{config.dataset}-backend-{idist.backend()}-{now}'
-        path = Path(config.output_dir, name)
-        path.mkdir(parents=True, exist_ok=True)
-        config.output_dir = path
-
     with idist.Parallel(
         backend=config.backend,
+        {% if use_distributed_training and not use_distributed_launcher %}
         nproc_per_node=config.nproc_per_node,
-        nnodes=config.nnodes,
+        {% if nnodes > 1 and not use_distributed_launcher %}
         node_rank=config.node_rank,
+        nnodes=config.nnodes,
         master_addr=config.master_addr,
         master_port=config.master_port,
+        {% endif %}
+        {% endif %}
     ) as parallel:
         parallel.run(run, config=config)
 
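Taken on its own, the new output-folder logic works like this: only rank 0 builds the timestamped run directory, and the resulting path string is then broadcast so every process ends up with the same config.output_dir. Below is a minimal, self-contained sketch of that pattern, assuming pytorch-ignite's idist API; the backend, process count, and base directory are illustrative, not taken from this commit.

# sketch: rank-0 folder creation + broadcast of the chosen path to all processes
from datetime import datetime
from pathlib import Path

import ignite.distributed as idist


def training(local_rank, output_dir):
    rank = idist.get_rank()
    if rank == 0:
        # only rank 0 touches the filesystem
        run_name = datetime.now().strftime("%Y%m%d-%H%M%S")
        path = Path(output_dir, run_name)
        path.mkdir(parents=True, exist_ok=True)
        output_dir = path.as_posix()
    # every rank receives the path string chosen by rank 0
    output_dir = Path(idist.broadcast(output_dir, src=0))
    print(f"rank {rank} -> {output_dir}")


if __name__ == "__main__":
    # example launch: 2 local processes with the gloo backend
    with idist.Parallel(backend="gloo", nproc_per_node=2) as parallel:
        parallel.run(training, "/tmp/example-output")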