diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py index f781e348..64ea6f1f 100644 --- a/trellis/pipelines/trellis_image_to_3d.py +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -105,7 +105,8 @@ def preprocess_image(self, input: Image.Image) -> Image.Image: output = rembg.remove(input, session=self.rembg_session) output_np = np.array(output) alpha = output_np[:, :, 3] - bbox = np.argwhere(alpha > 0.8 * 255) + alpha_cutoff = 0.8 + bbox = np.argwhere(alpha > alpha_cutoff * 255) bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) @@ -114,7 +115,7 @@ def preprocess_image(self, input: Image.Image) -> Image.Image: output = output.crop(bbox) # type: ignore output = output.resize((518, 518), Image.Resampling.LANCZOS) output = np.array(output).astype(np.float32) / 255 - output = output[:, :, :3] * output[:, :, 3:4] + output = output[:, :, :3] * (output[:, :, 3:4] > alpha_cutoff) output = Image.fromarray((output * 255).astype(np.uint8)) return output