1+ from torch import nn , long , argmax , optim , save
2+ from torch import no_grad
3+ from transformers import BertModel
4+ from torch import cuda
5+ from Datasets .dataloader import create_dataloader
6+ from loss import calc_loss
7+
8+ import matplotlib .pyplot as plt
9+ from sklearn .metrics import f1_score , confusion_matrix , ConfusionMatrixDisplay
10+
11+ class BERTModule (nn .Module ):
12+ def __init__ (self , epochs = 10 , learning_rate = 1e-05 , dropout_p = 0.3 ):
13+ super (BERTModule , self ).__init__ ()
14+ self .bert = BertModel .from_pretrained ('bert-base-uncased' )
15+ # for param in self.bert.parameters():
16+ # param.requires_grad = False
17+ self .dropout = nn .Dropout (p = dropout_p )
18+ self .fc = nn .Linear (768 , 3 )
19+
20+ self .epochs = epochs
21+ self .learning_rate = learning_rate
22+
23+ def forward (self , ids , masks , ttis ):
24+ _ , pooled_output = self .bert (ids , attention_mask = masks , token_type_ids = ttis , return_dict = False )
25+ output_2 = self .dropout (pooled_output )
26+ output = self .fc (output_2 )
27+
28+ return output
29+
30+ def fit (self , train_loader , test_loader ):
31+ self .device = 'cuda' if cuda .is_available () else 'cpu'
32+ self .to (self .device )
33+
34+ self .train ()
35+
36+ criterion = nn .CrossEntropyLoss ()
37+ optimizer = optim .Adam (params = self .parameters (), lr = self .learning_rate )
38+
39+ print ('Begin training...' )
40+
41+ train_losses = []
42+ test_losses = []
43+
44+ for epoch in range (self .epochs ):
45+ train_loss = 0.
46+
47+ for inputs , labels in train_loader :
48+ optimizer .zero_grad ()
49+
50+ ids = inputs [:, 0 ].to (self .device , dtype = long )
51+ masks = inputs [:, 1 ].to (self .device , dtype = long )
52+ tti = inputs [:, 2 ].to (self .device , dtype = long )
53+ labels = labels .squeeze ().to (self .device , dtype = long )
54+
55+ assert ids .shape == masks .shape , 'Ids != Masks'
56+ assert masks .shape == tti .shape , 'Masks != Ttis'
57+ assert ids .shape == tti .shape , 'Ids != Ttis'
58+
59+ assert ids .shape [0 ] == labels .shape [0 ], 'inputs and labels are incompatible'
60+
61+ outputs = self (ids , masks , tti )
62+
63+ loss = criterion (outputs , labels )
64+
65+ loss .backward ()
66+ optimizer .step ()
67+
68+ train_loss += loss .item ()
69+
70+ avg_train_loss = train_loss / len (train_loader )
71+ avg_test_loss = calc_loss (self , test_loader , criterion , self .device )
72+
73+ train_losses .append (avg_train_loss )
74+ test_losses .append (avg_test_loss )
75+
76+ print (f'Epoch { epoch + 1 } /{ self .epochs } Train Loss: { avg_train_loss } Test Loss: { avg_test_loss } ' )
77+
78+ print ('Ending training...' )
79+
80+ model_name = 'model' + '_' + 'ep' + str (self .epochs ) + '_' + 'lr' + str (self .learning_rate ) + '.pth'
81+ save (self .state_dict (), model_name )
82+
83+ return train_losses , test_losses
84+
85+ def evaluate (self , dataloader ):
86+ self .eval ()
87+
88+ data_labels = []
89+ data_outputs = []
90+
91+ with no_grad ():
92+ for inputs , labels in dataloader :
93+
94+ ids = inputs [:, 0 ].to (self .device , dtype = long )
95+ masks = inputs [:, 1 ].to (self .device , dtype = long )
96+ tti = inputs [:, 2 ].to (self .device , dtype = long )
97+ labels = labels .squeeze ().to (self .device , dtype = long )
98+
99+ assert ids .shape == masks .shape , 'Ids != Masks'
100+ assert masks .shape == tti .shape , 'Masks != Ttis'
101+ assert ids .shape == tti .shape , 'Ids != Ttis'
102+
103+ assert ids .shape [0 ] == labels .shape [0 ], 'inputs and labels are incompatible'
104+
105+ outputs = self (ids , masks , tti )
106+ outputs = nn .functional .softmax (outputs , dim = 1 )
107+ outputs = argmax (outputs , dim = 1 )
108+
109+ data_labels .extend (labels .cpu ().detach ().numpy ().tolist ())
110+ data_outputs .extend (outputs .cpu ().detach ().numpy ().tolist ())
111+
112+
113+ target_names = ['Easy' , 'Medium' , 'Hard' ]
114+ macro_f1 = f1_score (data_labels , data_outputs , average = 'macro' )
115+ cm = confusion_matrix (data_labels , data_outputs )
116+ disp = ConfusionMatrixDisplay (confusion_matrix = cm , display_labels = ['Easy' , 'Medium' , 'Hard' ])
117+ print (f'Macro F1: { macro_f1 } ' )
118+ disp .plot ()
119+ plt .show ()
120+
121+
122+ def predict ():
123+ return 1
0 commit comments