diff --git a/python/fate/ml/glm/homo/lr/client.py b/python/fate/ml/glm/homo/lr/client.py index e5578d9a06..d5eb16aafe 100644 --- a/python/fate/ml/glm/homo/lr/client.py +++ b/python/fate/ml/glm/homo/lr/client.py @@ -399,20 +399,17 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame: if self.model is None: raise ValueError("model is not initialized") self.predict_set = self._make_dataset(predict_data) - if self.trainer is None: - batch_size = len(self.predict_set) if self.batch_size is None else self.batch_size - train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size) - trainer = FedAVGClient( - ctx, - train_set=self.predict_set, - model=self.model, - training_args=train_arg, - fed_args=FedAVGArguments(), - data_collator=default_data_collator, - ) - trainer.set_local_mode() - else: - trainer = self.trainer + batch_size = len(self.predict_set) if self.batch_size is None else self.batch_size + train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size) + trainer = FedAVGClient( + ctx, + train_set=self.predict_set, + model=self.model, + training_args=train_arg, + fed_args=FedAVGArguments(), + data_collator=default_data_collator, + ) + trainer.set_local_mode() predict_rs = trainer.predict(self.predict_set) predict_out_df = self._make_output_df(ctx, predict_rs, self.predict_set, self.threshold) return predict_out_df