+import torch
 from torch import nn, long, argmax, optim, save
-from torch import no_grad
 from transformers import BertModel
 from torch import cuda
-from Datasets.dataloader import create_dataloader
 from loss import calc_loss
 
 import matplotlib.pyplot as plt
 from sklearn.metrics import f1_score, confusion_matrix, ConfusionMatrixDisplay
 
 class BERTModule(nn.Module):
-    def __init__(self, epochs=10, learning_rate=1e-05, dropout_p=0.3):
+    def __init__(self, n_classes, dropout_p=0.3):
         super(BERTModule, self).__init__()
         self.bert = BertModel.from_pretrained('bert-base-uncased')
         # for param in self.bert.parameters():
         #     param.requires_grad = False
         self.dropout = nn.Dropout(p=dropout_p)
-        self.fc = nn.Linear(768, 3)
+        self.fc = nn.Linear(768, n_classes)
 
-        self.epochs = epochs
-        self.learning_rate = learning_rate
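+        # hardware setup: prefer the GPU when one is available and move the module there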
+        self.device = 'cuda' if cuda.is_available() else 'cpu'
+        self.to(self.device)
 
     def forward(self, ids, masks, ttis):
         _, pooled_output = self.bert(ids, attention_mask=masks, token_type_ids=ttis, return_dict=False)
-        output_2 = self.dropout(pooled_output)
-        output = self.fc(output_2)
-
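+        # dropout on the pooled [CLS] output, then a linear head producing class logits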
+        output_drop = self.dropout(pooled_output)
+        output = self.fc(output_drop)
+
         return output
-
-    def fit(self, train_loader, test_loader):
-        self.device = 'cuda' if cuda.is_available() else 'cpu'
-        self.to(self.device)
 
-        self.train()
-
+    def fit(self, train_loader, test_loader, epochs=10, learning_rate=1e-05):
+        self.epochs = epochs
+        self.learning_rate = learning_rate
+
         criterion = nn.CrossEntropyLoss()
         optimizer = optim.Adam(params=self.parameters(), lr=self.learning_rate)
-
-        print('Begin training...')
+
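+        # ensure the model is on its device and switched to training mode before the loop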
+        self.to(self.device)
+
+        self.train()
 
         train_losses = []
         test_losses = []
 
+        print('Begin training...')
+
         for epoch in range(self.epochs):
             train_loss = 0.
 
@@ -66,9 +67,9 @@ def fit(self, train_loader, test_loader):
                 optimizer.step()
 
                 train_loss += loss.item()
-
+
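+            # epoch summary: mean batch loss for training, calc_loss for the test set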
             avg_train_loss = train_loss / len(train_loader)
-            avg_test_loss = calc_loss(self, test_loader, criterion, self.device)
+            avg_test_loss = calc_loss(self, test_loader, criterion)
 
             train_losses.append(avg_train_loss)
             test_losses.append(avg_test_loss)
@@ -88,9 +89,9 @@ def evaluate(self, dataloader):
         data_labels = []
         data_outputs = []
 
-        with no_grad():
+        with torch.no_grad():
             for inputs, labels in dataloader:
-
+
                 ids = inputs[:, 0].to(self.device, dtype=long)
                 masks = inputs[:, 1].to(self.device, dtype=long)
                 tti = inputs[:, 2].to(self.device, dtype=long)
@@ -109,15 +110,41 @@ def evaluate(self, dataloader):
                 data_labels.extend(labels.cpu().detach().numpy().tolist())
                 data_outputs.extend(outputs.cpu().detach().numpy().tolist())
 
-
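+        # summary metrics: macro F1 plus a confusion matrix over the difficulty classes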
         target_names = ['Easy', 'Medium', 'Hard']
         macro_f1 = f1_score(data_labels, data_outputs, average='macro')
         cm = confusion_matrix(data_labels, data_outputs)
-        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Easy', 'Medium', 'Hard'])
+        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=target_names)
         print(f'Macro F1: {macro_f1}')
         disp.plot()
         plt.show()
 
+    def predict(self, text):
+        self.eval()
+
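+        # build the same input encoder the datasets use to tokenise the raw text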
+        from Datasets.encoders import define_encoders
+        input_encoder, _ = define_encoders(max_len=300)
+
+        with torch.no_grad():
+            input = input_encoder(text)
+
+            ids = input[:, 0].to(self.device, dtype=long)
+            masks = input[:, 1].to(self.device, dtype=long)
+            tti = input[:, 2].to(self.device, dtype=long)
+
+            assert ids.shape == masks.shape, 'Ids != Masks'
+            assert masks.shape == tti.shape, 'Masks != Ttis'
+            assert ids.shape == tti.shape, 'Ids != Ttis'
+
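+            # forward pass, softmax over the logits, then the most likely class index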
+            outputs = self(ids, masks, tti)
+            outputs = nn.functional.softmax(outputs, dim=1)
+            outputs = argmax(outputs, dim=1)
+
+            outputs = outputs.cpu().detach().numpy().tolist()
+
+            print(len(outputs))
 
-    def predict():
-        return 1
+            target_names = ['Easy', 'Medium', 'Hard']
+            print(f'Text: {text}')
+            print(f'Difficulty: {[target_names[o] for o in outputs]}')
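+
+# Example usage (a minimal sketch; the dataloader helpers and the 3-class setup are
+# assumptions based on the rest of this repository):
+#     model = BERTModule(n_classes=3)
+#     model.fit(train_loader, test_loader, epochs=10, learning_rate=1e-05)
+#     model.evaluate(test_loader)
+#     model.predict('Some example text')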