@@ -19,8 +19,8 @@ class BERTTrainer:

     """

-    def __init__(self, bert: BERT, vocab_size,
-                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
+    def __init__(self, bert, vocab_size,
+                 train_dataloader, test_dataloader=None,
                  lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01,
                  with_cuda: bool = True, log_freq: int = 10):
         """
@@ -40,18 +40,18 @@ def __init__(self, bert: BERT, vocab_size,
         self.device = torch.device("cuda:0" if cuda_condition else "cpu")

         # This BERT model will be saved every epoch
-        self.bert: BERT = bert
+        self.bert = bert
         # Initialize the BERT Language Model, with BERT model
-        self.model: BERTLM = BERTLM(bert, vocab_size).to(self.device)
+        self.model = BERTLM(bert, vocab_size).to(self.device)

         # Distributed GPU training if CUDA can detect more than 1 GPU
         if torch.cuda.device_count() > 1:
             print("Using %d GPUS for BERT" % torch.cuda.device_count())
             self.model = nn.DataParallel(self.model)

         # Setting the train and test data loader
-        self.train_data: DataLoader = train_dataloader
-        self.test_data: DataLoader = test_dataloader
+        self.train_data = train_dataloader
+        self.test_data = test_dataloader

         # Setting the Adam optimizer with hyper-param
         self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
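This commit only removes the type annotations (bert: BERT, train_dataloader: DataLoader, and so on); annotations are hints with no runtime effect in Python, so the trainer behaves exactly as before. A minimal construction sketch under that assumption follows; the BERT constructor arguments and the train_dataset/test_dataset objects are illustrative placeholders, not taken from this commit:

import torch
from torch.utils.data import DataLoader

# Hypothetical setup: vocab size, model, and datasets are placeholders.
VOCAB_SIZE = 30000
bert = BERT(vocab_size=VOCAB_SIZE)  # assumed constructor of this repo's BERT class
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64)

# With the annotations gone, the arguments are duck-typed, so any iterable
# of batches with the DataLoader interface would also be accepted.
trainer = BERTTrainer(bert, VOCAB_SIZE,
                      train_dataloader=train_loader,
                      test_dataloader=test_loader,
                      lr=1e-4, with_cuda=True, log_freq=10)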