From 2bf3a9e245ad6e0dc3208e56a765aeff1caa6319 Mon Sep 17 00:00:00 2001
From: imxuebi <77877325+imxuebi@users.noreply.github.com>
Date: Mon, 29 Dec 2025 20:03:41 +0800
Subject: [PATCH] Update client.py

fix: unify HomoLR predict logic and handle batch_size

Signed-off-by: imxuebi <77877325+imxuebi@users.noreply.github.com>
---
 python/fate/ml/glm/homo/lr/client.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/python/fate/ml/glm/homo/lr/client.py b/python/fate/ml/glm/homo/lr/client.py
index e5578d9a06..d5eb16aafe 100644
--- a/python/fate/ml/glm/homo/lr/client.py
+++ b/python/fate/ml/glm/homo/lr/client.py
@@ -399,20 +399,17 @@ def predict(self, ctx: Context, predict_data: DataFrame) -> DataFrame:
         if self.model is None:
             raise ValueError("model is not initialized")
         self.predict_set = self._make_dataset(predict_data)
-        if self.trainer is None:
-            batch_size = len(self.predict_set) if self.batch_size is None else self.batch_size
-            train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size)
-            trainer = FedAVGClient(
-                ctx,
-                train_set=self.predict_set,
-                model=self.model,
-                training_args=train_arg,
-                fed_args=FedAVGArguments(),
-                data_collator=default_data_collator,
-            )
-            trainer.set_local_mode()
-        else:
-            trainer = self.trainer
+        batch_size = len(self.predict_set) if self.batch_size is None else self.batch_size
+        train_arg = TrainingArguments(num_train_epochs=self.max_iter, per_device_eval_batch_size=batch_size)
+        trainer = FedAVGClient(
+            ctx,
+            train_set=self.predict_set,
+            model=self.model,
+            training_args=train_arg,
+            fed_args=FedAVGArguments(),
+            data_collator=default_data_collator,
+        )
+        trainer.set_local_mode()
         predict_rs = trainer.predict(self.predict_set)
         predict_out_df = self._make_output_df(ctx, predict_rs, self.predict_set, self.threshold)
         return predict_out_df