@@ -142,6 +142,7 @@ public History fit(IDatasetV2 dataset,
142142 int verbose = 1 ,
143143 List < ICallback > callbacks = null ,
144144 IDatasetV2 validation_data = null ,
145+ int validation_step = 10 , // 间隔多少次会进行一次验证
145146 bool shuffle = true ,
146147 int initial_epoch = 0 ,
147148 int max_queue_size = 10 ,
@@ -164,11 +165,11 @@ public History fit(IDatasetV2 dataset,
164165 } ) ;
165166
166167
167- return FitInternal ( data_handler , epochs , verbose , callbacks , validation_data : validation_data ,
168+ return FitInternal ( data_handler , epochs , validation_step , verbose , callbacks , validation_data : validation_data ,
168169 train_step_func : train_step_function ) ;
169170 }
170171
171- History FitInternal ( DataHandler data_handler , int epochs , int verbose , List < ICallback > callbackList , IDatasetV2 validation_data ,
172+ History FitInternal ( DataHandler data_handler , int epochs , int validation_step , int verbose , List < ICallback > callbackList , IDatasetV2 validation_data ,
172173 Func < DataHandler , OwnedIterator , Dictionary < string , float > > train_step_func )
173174 {
174175 stop_training = false ;
@@ -207,6 +208,9 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
207208
208209 if ( validation_data != null )
209210 {
211+ if ( validation_step > 0 && epoch == 0 || ( epoch ) % validation_step != 0 )
212+ continue ;
213+
210214 var val_logs = evaluate ( validation_data ) ;
211215 foreach ( var log in val_logs )
212216 {
0 commit comments