From 0f82fa2ba95e69d21eef356bb28ea879e734ad1d Mon Sep 17 00:00:00 2001
From: Velythyl
Date: Fri, 16 May 2025 14:16:35 -0400
Subject: [PATCH 1/8] flask api

---
 flaskclient.py | 166 +++++++++++++++++++++++++++++
 flaskserver.py | 279 +++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 445 insertions(+)
 create mode 100644 flaskclient.py
 create mode 100644 flaskserver.py

diff --git a/flaskclient.py b/flaskclient.py
new file mode 100644
index 00000000..fa279fb3
--- /dev/null
+++ b/flaskclient.py
@@ -0,0 +1,166 @@
+import requests
+import os
+import uuid
+
+class Trellis3DClient:
+    def __init__(self, base_url='http://localhost:5000'):
+        self.base_url = base_url
+
+    def generate_from_single_image(self, image_path, params=None):
+        """
+        Generate 3D model from a single image
+
+        Args:
+            image_path (str): Path to the input image
+            params (dict): Optional parameters including:
+                - seed (int): Random seed
+                - ss_guidance_strength (float): Guidance strength for sparse structure generation
+                - ss_sampling_steps (int): Sampling steps for sparse structure generation
+                - slat_guidance_strength (float): Guidance strength for structured latent generation
+                - slat_sampling_steps (int): Sampling steps for structured latent generation
+
+        Returns:
+            dict: Response containing session_id and download URLs
+        """
+        if params is None:
+            params = {}
+
+        url = f"{self.base_url}/generate_from_single_image"
+        files = {'image': open(image_path, 'rb')}
+
+        response = requests.post(url, files=files, data=params)
+        return response.json()
+
+    def generate_from_multiple_images(self, image_paths, params=None):
+        """
+        Generate 3D model from multiple images
+
+        Args:
+            image_paths (list): List of paths to input images
+            params (dict): Optional parameters including:
+                - seed (int): Random seed
+                - ss_guidance_strength (float): Guidance strength for sparse structure generation
+                - ss_sampling_steps (int): Sampling steps for sparse structure generation
+                - slat_guidance_strength (float): Guidance strength for structured latent generation
+                - slat_sampling_steps (int): Sampling steps for structured latent generation
+                - multiimage_algo (str): Algorithm for multi-image generation ('stochastic' or 'multidiffusion')
+
+        Returns:
+            dict: Response containing session_id and download URLs
+        """
+        if params is None:
+            params = {}
+
+        url = f"{self.base_url}/generate_from_multiple_images"
+        files = [('images', (os.path.basename(path), open(path, 'rb')) for path in image_paths]
+
+        response = requests.post(url, files=files, data=params)
+        return response.json()
+
+    def extract_glb(self, session_id, params=None):
+        """
+        Extract GLB file from generated 3D model
+
+        Args:
+            session_id (str): Session ID from generation step
+            params (dict): Optional parameters including:
+                - mesh_simplify (float): Mesh simplification factor (0.9-0.98)
+                - texture_size (int): Texture resolution (512, 1024, 1536, or 2048)
+
+        Returns:
+            dict: Response containing GLB download URL
+        """
+        if params is None:
+            params = {}
+
+        url = f"{self.base_url}/extract_glb"
+        data = {'session_id': session_id, **params}
+
+        response = requests.post(url, data=data)
+        return response.json()
+
+    def download_file(self, url, save_path=None):
+        """
+        Download a file from the server
+
+        Args:
+            url (str): Full URL to download (from previous responses)
+            save_path (str): Optional path to save the file
+
+        Returns:
+            str: Path where file was saved
+        """
+        if save_path is None:
+            save_path = os.path.basename(url)
+
+        response = requests.get(url, stream=True)
+        with open(save_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        return save_path
+
+# Example usage
+if __name__ == '__main__':
+    client = Trellis3DClient()
+
+    # Example 1: Single image generation
+    print("Generating from single image...")
+    single_result = client.generate_from_single_image(
+        'example_image.png',
+        params={
+            'seed': 42,
+            'ss_guidance_strength': 7.5,
+            'slat_guidance_strength': 3.0
+        }
+    )
+    print(single_result)
+
+    # Download preview
+    client.download_file(
+        f"http://localhost:5000{single_result['preview_url']}",
+        'single_preview.mp4'
+    )
+
+    # Extract GLB
+    glb_result = client.extract_glb(
+        single_result['session_id'],
+        params={'mesh_simplify': 0.95}
+    )
+    print(glb_result)
+
+    # Download GLB
+    client.download_file(
+        f"http://localhost:5000{glb_result['glb_url']}",
+        'single_model.glb'
+    )
+
+    # Example 2: Multiple image generation
+    print("\nGenerating from multiple images...")
+    multi_result = client.generate_from_multiple_images(
+        ['view1.png', 'view2.png', 'view3.png'],
+        params={
+            'multiimage_algo': 'stochastic',
+            'seed': 123
+        }
+    )
+    print(multi_result)
+
+    # Download preview
+    client.download_file(
+        f"http://localhost:5000{multi_result['preview_url']}",
+        'multi_preview.mp4'
+    )
+
+    # Extract GLB
+    glb_result = client.extract_glb(
+        multi_result['session_id'],
+        params={'texture_size': 2048}
+    )
+    print(glb_result)
+
+    # Download GLB
+    client.download_file(
+        f"http://localhost:5000{glb_result['glb_url']}",
+        'multi_model.glb'
+    )
\ No newline at end of file
diff --git a/flaskserver.py b/flaskserver.py
new file mode 100644
index 00000000..bb5bd0b6
--- /dev/null
+++ b/flaskserver.py
@@ -0,0 +1,279 @@
+import os
+import uuid
+import shutil
+from flask import Flask, request, jsonify, send_file
+from werkzeug.utils import secure_filename
+from PIL import Image
+import numpy as np
+import torch
+from easydict import EasyDict as edict
+import imageio
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = 'uploads'
+app.config['OUTPUT_FOLDER'] = 'outputs'
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)
+
+# Initialize pipeline
+pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
+pipeline.cuda()
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        'gaussian': {
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
+        },
+        'mesh': {
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
+        },
+    }
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
+    gs = Gaussian(
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
+    )
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
+    mesh = edict(
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+    )
+
+    return gs, mesh
+
+def preprocess_image(image_path: str) -> Image.Image:
+    image = Image.open(image_path)
+    processed_image = pipeline.preprocess_image(image)
+    return processed_image
+
+@app.route('/generate_from_single_image', methods=['POST'])
+def generate_from_single_image():
+    try:
+        # Get parameters
+        seed = int(request.form.get('seed', 0))
+        ss_guidance_strength = float(request.form.get('ss_guidance_strength', 7.5))
+        ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12))
+        slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0))
+        slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12))
+
+        # Handle file upload
+        if 'image' not in request.files:
+            return jsonify({'error': 'No image provided'}), 400
+
+        file = request.files['image']
+        if file.filename == '':
+            return jsonify({'error': 'No selected file'}), 400
+
+        # Create session directory
+        session_id = str(uuid.uuid4())
+        session_dir = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
+        os.makedirs(session_dir, exist_ok=True)
+
+        # Save uploaded file
+        filename = secure_filename(file.filename)
+        image_path = os.path.join(session_dir, filename)
+        file.save(image_path)
+
+        # Preprocess image
+        image = preprocess_image(image_path)
+
+        # Generate 3D model
+        outputs = pipeline.run(
+            image,
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=False,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+        )
+
+        # Create output directory
+        output_dir = os.path.join(app.config['OUTPUT_FOLDER'], session_id)
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save video preview
+        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+        video_path = os.path.join(output_dir, 'preview.mp4')
+        imageio.mimsave(video_path, video, fps=15)
+
+        # Save state
+        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+        state_path = os.path.join(output_dir, 'state.pkl')
+        torch.save(state, state_path)
+
+        torch.cuda.empty_cache()
+
+        return jsonify({
+            'session_id': session_id,
+            'preview_url': f'/download/preview/{session_id}',
+            'state_url': f'/download/state/{session_id}'
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+@app.route('/generate_from_multiple_images', methods=['POST'])
+def generate_from_multiple_images():
+    try:
+        # Get parameters
+        seed = int(request.form.get('seed', 0))
+        ss_guidance_strength = float(request.form.get('ss_guidance_strength', 7.5))
+        ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12))
+        slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0))
+        slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12))
+        multiimage_algo = request.form.get('multiimage_algo', 'stochastic')
+
+        # Handle file uploads
+        if 'images' not in request.files:
+            return jsonify({'error': 'No images provided'}), 400
+
+        files = request.files.getlist('images')
+        if len(files) == 0:
+            return jsonify({'error': 'No selected files'}), 400
+
+        # Create session directory
+        session_id = str(uuid.uuid4())
+        session_dir = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
+        os.makedirs(session_dir, exist_ok=True)
+
+        # Save uploaded files
+        images = []
+        for file in files:
+            if file.filename == '':
+                continue
+            filename = secure_filename(file.filename)
+            image_path = os.path.join(session_dir, filename)
+            file.save(image_path)
+            images.append(Image.open(image_path))
+
+        # Generate 3D model
+        outputs = pipeline.run_multi_image(
+            images,
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=False,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+            mode=multiimage_algo,
+        )
+
+        # Create output directory
+        output_dir = os.path.join(app.config['OUTPUT_FOLDER'], session_id)
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save video preview
+        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+        video_path = os.path.join(output_dir, 'preview.mp4')
+        imageio.mimsave(video_path, video, fps=15)
+
+        # Save state
+        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+        state_path = os.path.join(output_dir, 'state.pkl')
+        torch.save(state, state_path)
+
+        torch.cuda.empty_cache()
+
+        return jsonify({
+            'session_id': session_id,
+            'preview_url': f'/download/preview/{session_id}',
+            'state_url': f'/download/state/{session_id}'
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+@app.route('/extract_glb', methods=['POST'])
+def extract_glb():
+    try:
+        # Get parameters
+        session_id = request.form.get('session_id')
+        mesh_simplify = float(request.form.get('mesh_simplify', 0.95))
+        texture_size = int(request.form.get('texture_size', 1024))
+
+        if not session_id:
+            return jsonify({'error': 'session_id is required'}), 400
+
+        # Load state
+        state_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'state.pkl')
+        if not os.path.exists(state_path):
+            return jsonify({'error': 'Invalid session_id'}), 404
+
+        state = torch.load(state_path)
+        gs, mesh = unpack_state(state)
+
+        # Extract GLB
+        glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+
+        # Save GLB
+        glb_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'model.glb')
+        glb.export(glb_path)
+
+        torch.cuda.empty_cache()
+
+        return jsonify({
+            'glb_url': f'/download/glb/{session_id}'
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+@app.route('/download/preview/<session_id>', methods=['GET'])
+def download_preview(session_id):
+    preview_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'preview.mp4')
+    if not os.path.exists(preview_path):
+        return jsonify({'error': 'Preview not found'}), 404
+    return send_file(preview_path, as_attachment=True)
+
+@app.route('/download/state/<session_id>', methods=['GET'])
+def download_state(session_id):
+    state_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'state.pkl')
+    if not os.path.exists(state_path):
+        return jsonify({'error': 'State not found'}), 404
+    return send_file(state_path, as_attachment=True)
+
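+# Example round trip against the routes above (a sketch, run client-side with
+# requests; 'input.png' and the localhost URL are placeholders, everything
+# else maps to endpoints defined in this file):
+#   r = requests.post('http://localhost:5000/generate_from_single_image',
+#                     files={'image': open('input.png', 'rb')}, data={'seed': 0})
+#   sid = r.json()['session_id']
+#   requests.post('http://localhost:5000/extract_glb', data={'session_id': sid})
+#   glb = requests.get(f'http://localhost:5000/download/glb/{sid}')
+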
+@app.route('/download/glb/<session_id>', methods=['GET'])
+def download_glb(session_id):
+    glb_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'model.glb')
+    if not os.path.exists(glb_path):
+        return jsonify({'error': 'GLB not found'}), 404
+    return send_file(glb_path, as_attachment=True)
+
+if __name__ == '__main__':
+    app.run(host='0.0.0.0', port=5000, threaded=True)
\ No newline at end of file

From 6916dc1c1c84c343e3fecf2a33733bf3a31eea93 Mon Sep 17 00:00:00 2001
From: Charlie Gauthier
Date: Fri, 16 May 2025 16:17:41 -0400
Subject: [PATCH 2/8] progress

---
 app_cli.py       |  47 ++++++++++
 cleanup_mesh.py  |  47 ++++++++++
 cleanup_mesh2.py | 190 +++++++++++++++++++++++++++++++++++++
 cleanup_mesh3.py | 167 +++++++++++++++++++++++++++++
 cleanup_mesh4.py |  97 +++++++++++++++++++
 cleanup_mesh5.py | 237 +++++++++++++++++++++++++++++++++++++++++++++++
 flaskclient.py   | 146 +++++++++++++++++++++++++++--
 7 files changed, 921 insertions(+), 10 deletions(-)
 create mode 100644 app_cli.py
 create mode 100644 cleanup_mesh.py
 create mode 100644 cleanup_mesh2.py
 create mode 100644 cleanup_mesh3.py
 create mode 100644 cleanup_mesh4.py
 create mode 100644 cleanup_mesh5.py

diff --git a/app_cli.py b/app_cli.py
new file mode 100644
index 00000000..7f89a4eb
--- /dev/null
+++ b/app_cli.py
@@ -0,0 +1,47 @@
+import requests
+from PIL import Image
+import io
+import os
+
+class Gradio3DClient:
+    def __init__(self, base_url: str):
+        self.base_url = base_url.rstrip('/')
+
+    def send_image(self, image_path: str) -> dict:
+        """Sends an image to the Gradio app and generates a 3D model."""
+        with open(image_path, "rb") as f:
+            response = requests.post(f"{self.base_url}/run/image_to_3d", files={"image": f})
+        response.raise_for_status()
+        return response.json()["data"][0]
+
+    def extract_glb(self, state: dict, mesh_simplify: float = 0.95, texture_size: int = 1024) -> str:
+        """Extracts a GLB file from the generated 3D model."""
+        response = requests.post(f"{self.base_url}/run/extract_glb", json={"state": state, "mesh_simplify": mesh_simplify, "texture_size": texture_size})
+        response.raise_for_status()
+        return response.json()["data"][0]
+
+    def generate_views(self, state: dict, num_views: int = 4) -> list:
+        """Generates views of the 3D model from different angles."""
+        images = []
+        for angle in range(0, 360, 360 // num_views):
+            response = requests.post(f"{self.base_url}/run/render_view", json={"state": state, "angle": angle})
+            response.raise_for_status()
+            images.append(response.content)
+        return images
+
+    def process_image(self, image_path: str) -> dict:
+        """Full pipeline: Send image, generate model, extract GLB, and generate views."""
+        model_info = self.send_image(image_path)
+        glb_path = self.extract_glb(model_info)
+        views = self.generate_views(model_info)
+
+        return {
+            "original_image": image_path,
+            "glb_path": glb_path,
+            "views": views
+        }
+
+if __name__ == "__main__":
+    client = Gradio3DClient("http://localhost:7860")
+    result = client.process_image("/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/2/rgb/000.png")
+    print("Process complete. Files:", result)
diff --git a/cleanup_mesh.py b/cleanup_mesh.py
new file mode 100644
index 00000000..63943900
--- /dev/null
+++ b/cleanup_mesh.py
@@ -0,0 +1,47 @@
+import trimesh
+import numpy as np
+
+# Function to clean the thin base plane of a mesh
+def clean_base_plane(mesh_path: str, output_path: str, threshold_ratio: float = 0.01):
+    # Load the mesh
+    mesh = trimesh.load(mesh_path)
+
+    # If the mesh is a Scene, merge it into a single Trimesh
+    if isinstance(mesh, trimesh.Scene):
+        mesh = mesh.dump(concatenate=True)
+
+    # Ensure it is a valid Trimesh
+    if not isinstance(mesh, trimesh.Trimesh):
+        raise ValueError("Loaded file is not a valid Trimesh object.")
+
+    # Calculate the average Z position of each face
+    face_z_heights = mesh.vertices[mesh.faces].mean(axis=1)[:, 2]
+
+    # Determine the Z threshold (bottom 1% by default)
+    z_threshold = np.quantile(face_z_heights, threshold_ratio)
+
+    # Identify faces below the threshold
+    face_mask = face_z_heights > z_threshold
+
+    # Keep only the faces above the threshold
+    cleaned_faces = mesh.faces[face_mask]
+
+    # Get the unique vertices used by these faces
+    unique_vertices = np.unique(cleaned_faces)
+    cleaned_vertices = mesh.vertices[unique_vertices]
+
+    # Remap face indices to the cleaned vertices
+    vertex_map = {old_idx: new_idx for new_idx, old_idx in enumerate(unique_vertices)}
+    remapped_faces = np.vectorize(vertex_map.get)(cleaned_faces)
+
+    # Create a new mesh using the cleaned faces and vertices
+    cleaned_mesh = trimesh.Trimesh(vertices=cleaned_vertices, faces=remapped_faces)
+
+    # Export the cleaned mesh
+    cleaned_mesh.export(output_path)
+
+    print(f"Cleaned mesh saved to {output_path}")
+
+
+if __name__ == "__main__":
+    clean_base_plane("./blo/model.glb", "./blo/model_clean.glb")
\ No newline at end of file
diff --git a/cleanup_mesh2.py b/cleanup_mesh2.py
new file mode 100644
index 00000000..6fdeb290
--- /dev/null
+++ b/cleanup_mesh2.py
@@ -0,0 +1,190 @@
+import trimesh
+import numpy as np
+from typing import Optional, Dict, List
+import os
+
+
+def find_and_remove_bottom_planes(
+    scene: trimesh.Scene,
+    z_threshold: Optional[float] = None,
+    angle_threshold: float = 15.0,
+    min_area_ratio: float = 0.05,
+    naming_pattern: str = "world/geometry_"
+) -> Dict[str, bool]:
+    """
+    Find and remove bottom plane artefacts from all meshes in a scene that match the naming pattern.
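+    A face is treated as part of the base plane when it lies in the bottom Z-band,
+    its normal is within angle_threshold of vertical (i.e. the face is roughly
+    horizontal), and the connected component it belongs to covers at least
+    min_area_ratio of the total mesh area.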
+
+    Args:
+        scene: The loaded trimesh Scene object
+        z_threshold: Absolute Z threshold for bottom detection (auto-detected if None)
+        angle_threshold: Maximum angle (degrees) from horizontal to consider as bottom plane
+        min_area_ratio: Minimum area ratio (compared to bounding box) to consider as artefact plane
+        naming_pattern: Pattern to identify geometry nodes to process
+
+    Returns:
+        Dictionary mapping geometry names to whether they were modified
+    """
+    results = {}
+
+    # Iterate through all geometry nodes in the scene
+    for node_name in scene.graph.nodes_geometry:
+        if naming_pattern not in node_name:
+            continue
+
+        # Get the mesh instance
+        geometry = scene.geometry[node_name]
+        if not isinstance(geometry, trimesh.Trimesh):
+            continue
+
+        print(f"\nProcessing {node_name}...")
+        mesh = geometry.copy()
+        original_face_count = len(mesh.faces)
+
+        # Calculate mesh properties
+        bounds = mesh.bounds
+        height = bounds[1][2] - bounds[0][2]
+
+        # Auto-detect z-threshold if not provided
+        if z_threshold is None:
+            z_threshold = bounds[0][2] + height * 0.1  # Bottom 10% of mesh
+
+        # Find faces in the bottom region
+        face_z_coords = mesh.vertices[mesh.faces][:, :, 2]
+        bottom_faces = np.all(face_z_coords < z_threshold, axis=1)
+
+        if not np.any(bottom_faces):
+            print(f"No bottom faces found in {node_name}")
+            results[node_name] = False
+            continue
+
+        # Get normals of bottom faces
+        bottom_normals = mesh.face_normals[bottom_faces]
+
+        # Calculate angle from vertical (we want faces that are roughly horizontal)
+        angles = np.degrees(np.arccos(np.abs(bottom_normals[:, 2])))
+        horizontal_faces = angles < angle_threshold
+
+        if not np.any(horizontal_faces):
+            print(f"No horizontal faces found in bottom region of {node_name}")
+            results[node_name] = False
+            continue
+
+        # Get all face indices that meet our criteria
+        candidate_faces = np.where(bottom_faces)[0][horizontal_faces]
+
+        # Calculate area of candidate faces
+        candidate_area = mesh.area_faces[candidate_faces].sum()
+        total_area = mesh.area
+
+        if candidate_area / total_area < min_area_ratio:
+            print(f"Bottom plane area ({candidate_area:.4f}) too small in {node_name}")
+            results[node_name] = False
+            continue
+
+        # Create a submesh of just the candidate faces
+        plane_mesh = mesh.submesh([candidate_faces], append=True)[0]
+
+        # Split the plane mesh into connected components
+        components = plane_mesh.split(only_watertight=False)
+
+        if not components:
+            print(f"No components found in candidate faces of {node_name}")
+            results[node_name] = False
+            continue
+
+        # Find the largest component (likely our plane)
+        largest_component = max(components, key=lambda x: x.area)
+
+        # Verify this component is large enough
+        if largest_component.area / total_area < min_area_ratio:
+            print(f"Largest component in bottom region is too small in {node_name}")
+            results[node_name] = False
+            continue
+
+        # Get the vertex indices of the largest component
+        component_verts = set(largest_component.vertices.view(np.ndarray).flatten())
+
+        # Find all faces in the original mesh that exclusively use these vertices
+        original_faces_to_remove = []
+        for face_idx in candidate_faces:
+            face_verts = mesh.faces[face_idx]
+            if all(v in component_verts for v in face_verts):
+                original_faces_to_remove.append(face_idx)
+
+        if not original_faces_to_remove:
+            print(f"Could not map component back to original faces in {node_name}")
+            results[node_name] = False
+            continue
+
+        # Remove the faces from the original mesh
+        mask = np.ones(len(mesh.faces), dtype=bool)
+        mask[original_faces_to_remove] = False
+        mesh.update_faces(mask)
+        mesh.remove_unreferenced_vertices()
+
+        # Update the scene with the modified mesh
+        scene.geometry[node_name] = mesh
+
+        removed_faces = original_face_count - len(mesh.faces)
+        print(f"Removed {removed_faces} faces from {node_name}")
+        results[node_name] = True
+
+    return results
+
+
+def process_glb_files(
+    input_path: str,
+    output_path: str,
+    z_threshold: Optional[float] = None,
+    angle_threshold: float = 15.0,
+    min_area_ratio: float = 0.05,
+    naming_pattern: str = "world/geometry_"
+):
+    """
+    Process a GLB file to remove bottom plane artefacts from specific geometry nodes.
+
+    Args:
+        input_path: Path to input GLB file
+        output_path: Path to save cleaned GLB file
+        z_threshold: Absolute Z threshold for bottom detection (auto-detected if None)
+        angle_threshold: Maximum angle (degrees) from horizontal to consider as bottom plane
+        min_area_ratio: Minimum area ratio (compared to bounding box) to consider as artefact plane
+        naming_pattern: Pattern to identify geometry nodes to process
+    """
+    # Load the scene
+    scene = trimesh.load(input_path)
+
+    if not isinstance(scene, trimesh.Scene):
+        print("Input file is not a scene, using single mesh handling")
+        # Handle as single mesh (using previous approach)
+        remove_bottom_plane(input_path, output_path, z_threshold, angle_threshold, min_area_ratio)
+        return
+
+    print(f"Loaded scene with {len(scene.geometry)} geometries")
+
+    # Process all matching geometry nodes
+    results = find_and_remove_bottom_planes(
+        scene,
+        z_threshold,
+        angle_threshold,
+        min_area_ratio,
+        naming_pattern
+    )
+
+    # Count how many were modified
+    modified_count = sum(results.values())
+    print(f"\nModified {modified_count} out of {len(results)} matching geometries")
+
+    if modified_count > 0:
+        # Save the modified scene
+        scene.export(output_path)
+        print(f"Saved cleaned scene to {output_path}")
+    else:
+        print("No modifications made - output file not created")
+
+# Example usage
+if __name__ == "__main__":
+    input_glb = "./blo/model.glb"
+    output_glb = "./blo/output_cleaned.glb"
+
+    process_glb_files(input_glb, output_glb)
\ No newline at end of file
diff --git a/cleanup_mesh3.py b/cleanup_mesh3.py
new file mode 100644
index 00000000..f4ee135f
--- /dev/null
+++ b/cleanup_mesh3.py
@@ -0,0 +1,167 @@
+import os
+import numpy as np
+import trimesh
+from pygltflib import GLTF2
+from tqdm import tqdm
+
+
+def is_bottom_plane(vertices, faces, z_threshold=0.1, plane_area_threshold=0.5):
+    """
+    Identify if there's a large planar surface at the bottom of the mesh.
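+    Candidate faces must touch the bottom Z-band (within z_threshold of the lowest
+    vertex) and have near-vertical normals (|n_z| > 0.9); the plane is only reported
+    if their combined area exceeds plane_area_threshold.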
+
+    Args:
+        vertices: Mesh vertices
+        faces: Mesh faces
+        z_threshold: Height threshold to consider as "bottom"
+        plane_area_threshold: Minimum area to consider as significant plane
+
+    Returns:
+        Tuple of (bool indicating if plane found, face indices of the plane)
+    """
+    # Find vertices near the bottom
+    min_z = np.min(vertices[:, 2])
+    bottom_vertices = vertices[:, 2] < min_z + z_threshold
+
+    # Get faces that use these bottom vertices
+    bottom_faces_mask = np.any(bottom_vertices[faces], axis=1)
+    bottom_faces = faces[bottom_faces_mask]
+
+    if len(bottom_faces) == 0:
+        return False, np.array([])
+
+    # Calculate normals of bottom faces
+    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
+    face_normals = mesh.face_normals[bottom_faces_mask]
+
+    # Find faces that are roughly horizontal (normal points mostly up/down)
+    vertical_normals = np.abs(face_normals[:, 2]) > 0.9  # 0.9 means ~25 degree tolerance
+    horizontal_faces = bottom_faces[vertical_normals]
+
+    if len(horizontal_faces) == 0:
+        return False, np.array([])
+
+    # Calculate area of horizontal faces
+    horizontal_face_indices = np.where(bottom_faces_mask)[0][vertical_normals]
+    areas = mesh.area_faces[horizontal_face_indices]
+    total_area = np.sum(areas)
+
+    if total_area < plane_area_threshold:
+        return False, np.array([])
+
+    return True, horizontal_face_indices
+
+
+def remove_plane_from_glb(glb_path, output_path=None):
+    """
+    Load a GLB file, detect and remove bottom plane, then save it back.
+
+    Args:
+        glb_path: Path to input GLB file
+        output_path: Path to save cleaned GLB (None to overwrite)
+    """
+    if output_path is None:
+        output_path = glb_path
+
+    # Load GLB file
+    gltf = GLTF2().load(glb_path)
+
+    # Process each mesh in the GLB
+    for mesh in gltf.meshes:
+        for primitive in mesh.primitives:
+            # Get accessor indices
+            pos_accessor = gltf.accessors[primitive.attributes.POSITION]
+            normal_accessor = gltf.accessors[primitive.attributes.NORMAL]
+            indices_accessor = gltf.accessors[primitive.indices]
+
+            # Get buffer views
+            pos_view = gltf.bufferViews[pos_accessor.bufferView]
+            normal_view = gltf.bufferViews[normal_accessor.bufferView]
+            indices_view = gltf.bufferViews[indices_accessor.bufferView]
+
+            # Get actual data
+            buffer = gltf.buffers[pos_view.buffer]
+            pos_data = gltf.get_data_from_buffer_uri(buffer.uri)
+            pos_array = np.frombuffer(pos_data, dtype=np.float32,
+                                      count=pos_accessor.count * 3,
+                                      offset=pos_view.byteOffset)
+            pos_array = pos_array.reshape(-1, 3)
+
+            indices_data = gltf.get_data_from_buffer_uri(buffer.uri)
+            indices_array = np.frombuffer(indices_data,
+                                          dtype=np.uint16 if indices_accessor.componentType == 5123 else np.uint32,
+                                          count=indices_accessor.count,
+                                          offset=indices_view.byteOffset)
+            indices_array = indices_array.reshape(-1, 3)
+
+            # Check for bottom plane
+            has_plane, plane_face_indices = is_bottom_plane(pos_array, indices_array)
+
+            if has_plane:
+                print(f"Found bottom plane with {len(plane_face_indices)} faces in {glb_path}")
+
+                # Remove the plane faces
+                mask = np.ones(len(indices_array), dtype=bool)
+                mask[plane_face_indices] = False
+                new_indices = indices_array[mask]
+
+                # Update the indices in the GLB structure
+                indices_accessor.count = len(new_indices) * 3
+                indices_data = new_indices.tobytes()
+
+                # Update buffer view byte length
+                indices_view.byteLength = len(indices_data)
+
+                # Update the buffer
+                buffer_data = gltf.get_data_from_buffer_uri(buffer.uri)
+                buffer_data = buffer_data[:indices_view.byteOffset] + indices_data + buffer_data[
+                    indices_view.byteOffset + len(indices_data):]
+                gltf.set_data_from_buffer_uri(buffer.uri, buffer_data)
+
+    # Save the cleaned GLB
+    gltf.save(output_path)
+
+
+def process_glb_directory(root_dir, output_dir=None):
+    """
+    Process all GLB files in a directory and its subdirectories.
+
+    Args:
+        root_dir: Root directory containing GLB files
+        output_dir: Output directory (None to overwrite original files)
+    """
+    if output_dir is not None and not os.path.exists(output_dir):
+        os.makedirs(output_dir)
+
+    # Find all GLB files
+    glb_files = []
+    for root, _, files in os.walk(root_dir):
+        for file in files:
+            if file.endswith('.glb'):
+                glb_files.append(os.path.join(root, file))
+
+    # Process each file
+    for glb_path in tqdm(glb_files, desc="Processing GLB files"):
+        if output_dir is not None:
+            rel_path = os.path.relpath(glb_path, root_dir)
+            output_path = os.path.join(output_dir, rel_path)
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        else:
+            output_path = None
+
+        try:
+            remove_plane_from_glb(glb_path, output_path)
+        except Exception as e:
+            print(f"Error processing {glb_path}: {str(e)}")
+
+
+if __name__ == "__main__":
+    import argparse
+
+    #parser = argparse.ArgumentParser(description="Remove bottom planes from GLB meshes")
+    #parser.add_argument("input_dir", help="Input directory containing GLB files")
+    #parser.add_argument("--output_dir", help="Output directory (optional, overwrites by default)")
+
+    #args = parser.parse_args()
+
+    process_glb_directory("./blo/model.glb", "./blo/modelCLEANED.glb")
\ No newline at end of file
diff --git a/cleanup_mesh4.py b/cleanup_mesh4.py
new file mode 100644
index 00000000..d134cdab
--- /dev/null
+++ b/cleanup_mesh4.py
@@ -0,0 +1,97 @@
+import numpy as np
+import trimesh
+from pygltflib import GLTF2
+from pygltflib.utils import gltf2glb
+
+
+def remove_bottom_plane(glb_path, output_path, height_threshold=0.1, plane_normal_tolerance=0.9):
+    """
+    Remove unwanted planar artifacts at the bottom of a GLB mesh.
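+    The bottom vertices are analysed with PCA: an eigenvector of their covariance
+    is taken as the plane normal, and faces touching the bottom band are dropped
+    when that normal is mostly vertical.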
+
+    Args:
+        glb_path (str): Path to input GLB file
+        output_path (str): Path to save cleaned GLB file
+        height_threshold (float): Height threshold to consider vertices as "bottom"
+        plane_normal_tolerance (float): Tolerance for detecting planar surfaces (0.9 = mostly aligned with XY plane)
+    """
+    # Load the GLB file
+    gltf = GLTF2().load(glb_path)
+
+    # Process each mesh in the GLB
+    for mesh_index, mesh in enumerate(gltf.meshes):
+        for primitive in mesh.primitives:
+            # Get the position accessor and buffer view
+            pos_accessor = gltf.accessors[primitive.attributes.POSITION]
+            pos_buffer_view = gltf.bufferViews[pos_accessor.bufferView]
+            pos_buffer = gltf.buffers[pos_buffer_view.buffer]
+
+            # Get the position data as numpy array
+            pos_data = gltf.get_data_from_buffer_uri(pos_buffer.uri)
+            pos_array = np.frombuffer(pos_data, dtype=np.float32,
+                                      count=pos_accessor.count * 3,
+                                      offset=pos_buffer_view.byteOffset)
+            pos_array = pos_array.reshape(-1, 3)
+
+            # Get indices if they exist
+            if hasattr(primitive, 'indices'):
+                indices_accessor = gltf.accessors[primitive.indices]
+                indices_buffer_view = gltf.bufferViews[indices_accessor.bufferView]
+                indices_buffer = gltf.buffers[indices_buffer_view.buffer]
+                indices_data = gltf.get_data_from_buffer_uri(indices_buffer.uri)
+                indices_array = np.frombuffer(indices_data,
+                                              dtype=np.uint16 if indices_accessor.componentType == 5123 else np.uint32,
+                                              count=indices_accessor.count,
+                                              offset=indices_buffer_view.byteOffset)
+            else:
+                indices_array = np.arange(len(pos_array))
+
+            # Find bottom vertices
+            min_z = np.min(pos_array[:, 2])
+            bottom_vertices = pos_array[:, 2] < (min_z + height_threshold)
+
+            if np.sum(bottom_vertices) < 3:  # Not enough vertices to form a plane
+                continue
+
+            # Check if these vertices form a plane (normal mostly in Z direction)
+            bottom_pos = pos_array[bottom_vertices]
+            centroid = np.mean(bottom_pos, axis=0)
+            cov = np.cov((bottom_pos - centroid).T)
+            _, eig_vecs = np.linalg.eig(cov)
+            normal = eig_vecs[:, np.argmin(np.abs(eig_vecs))]
+            normal = normal / np.linalg.norm(normal)
+
+            # If the normal is mostly vertical (aligned with Z axis)
+            if abs(normal[2]) > plane_normal_tolerance:
+                print(f"Found bottom plane in mesh {mesh_index} with normal {normal}")
+
+                # Create a mask of faces that contain bottom vertices
+                if len(indices_array) % 3 == 0:  # Assuming triangles
+                    faces = indices_array.reshape(-1, 3)
+                    bottom_faces_mask = np.any(bottom_vertices[faces], axis=1)
+
+                    # Keep only non-bottom faces
+                    if hasattr(primitive, 'indices'):
+                        # For indexed geometry
+                        valid_faces = faces[~bottom_faces_mask]
+                        new_indices = valid_faces.flatten()
+
+                        # Update the indices buffer
+                        indices_buffer.data = new_indices.tobytes()
+                        indices_accessor.count = len(new_indices)
+                    else:
+                        # For non-indexed geometry
+                        valid_vertices_mask = ~bottom_vertices
+                        new_pos_array = pos_array[valid_vertices_mask]
+
+                        # Update the position buffer
+                        pos_buffer.data = new_pos_array.tobytes()
+                        pos_accessor.count = len(new_pos_array)
+
+    # Save the cleaned GLB
+    gltf.save(output_path)
+
+
+# Example usage
+input_glb = "./blo/model.glb"
+output_glb = "./blo/OUT.glb"
+remove_bottom_plane(input_glb, output_glb)
\ No newline at end of file
diff --git a/cleanup_mesh5.py b/cleanup_mesh5.py
new file mode 100644
index 00000000..caa1e2ab
--- /dev/null
+++ b/cleanup_mesh5.py
@@ -0,0 +1,237 @@
+import trimesh
+import numpy as np
+
+# Load the GLB file with error handling
+def load_mesh(filename: str) -> trimesh.Trimesh:
+    try:
+        mesh = trimesh.load(filename)
+        mesh = mesh.geometry["geometry_0"]
+        if not isinstance(mesh, trimesh.Trimesh):
+            raise ValueError("Loaded object is not a single mesh.")
+        return mesh
+    except Exception as e:
+        raise ValueError(f"Failed to load mesh: {e}")
+
+
+import numpy as np
+
+
+def remove_ground_plane(points, height_step=0.01, centroid_tolerance=0.1, max_iterations=100):
+    """
+    Remove ground plane from a point cloud using statistical analysis.
+
+    Args:
+        points: numpy array of shape (N, 3) containing the point cloud
+        height_step: how much to increase the height of the removal rectangle each iteration
+        centroid_tolerance: how close the centroid needs to be to the middle height range
+        max_iterations: maximum number of iterations to perform
+
+    Returns:
+        numpy array of points with ground plane removed
+    """
+    original_points = points.copy()
+    points = points.copy()
+
+    # Calculate initial metrics
+    min_coords = np.min(points, axis=0)
+    max_coords = np.max(points, axis=0)
+    height_range = max_coords[2] - min_coords[2]
+
+    for iteration in range(max_iterations):
+        # 1. Calculate current centroid
+        centroid = np.mean(points, axis=0)
+
+        # 2. Find the 4 lowest points that are furthest from centroid in x-y plane
+        # Get points in the bottom 10% height
+        height_threshold = min_coords[2] + 0.1 * height_range
+        bottom_points = points[points[:, 2] < height_threshold]
+
+        if len(bottom_points) == 0:
+            break
+
+        # Calculate x-y distances from centroid
+        xy_distances = np.linalg.norm(bottom_points[:, :2] - centroid[:2], axis=1)
+
+        # Get indices of 4 furthest points in x-y plane
+        furthest_indices = np.argpartition(xy_distances, -4)[-4:]
+        ground_corners = bottom_points[furthest_indices]
+
+        # 3. Create a bounding box from these corners
+        ground_min = np.min(ground_corners, axis=0)
+        ground_max = np.max(ground_corners, axis=0)
+
+        # Expand the height of the bounding box
+        ground_max[2] += height_step
+
+        # 4. Remove points within this bounding box
+        in_ground = np.all((points >= ground_min) & (points <= ground_max), axis=1)
+        points = points[~in_ground]
+
+        # Check termination conditions
+        new_centroid = np.mean(points, axis=0)
+        new_height_range = np.max(points[:, 2]) - np.min(points[:, 2])
+
+        # Condition 1: Centroid is adequately centered vertically
+        centroid_height_ratio = (new_centroid[2] - np.min(points[:, 2])) / new_height_range
+        centroid_centered = abs(centroid_height_ratio - 0.5) < centroid_tolerance
+
+        # Condition 2: The furthest corners are now closer to centroid (ground removed)
+        new_bottom_points = points[points[:, 2] < (np.min(points[:, 2]) + 0.1 * new_height_range)]
+        if len(new_bottom_points) > 0:
+            new_xy_distances = np.linalg.norm(new_bottom_points[:, :2] - new_centroid[:2], axis=1)
+            avg_distance_reduced = np.mean(new_xy_distances) < 0.5 * np.mean(xy_distances)
+        else:
+            avg_distance_reduced = True
+
+        if centroid_centered and avg_distance_reduced:
+            break
+
+        # Update for next iteration
+        min_coords = np.min(points, axis=0)
+        max_coords = np.max(points, axis=0)
+        height_range = max_coords[2] - min_coords[2]
+
+    return points
+
+# Identify and remove the bottom plane artifact
+def remove_bottom_plane(mesh: trimesh.Trimesh, height_threshold: float = 0.5) -> trimesh.Trimesh:
+    # Make a copy of the original mesh
+    mesh = mesh.copy()
+
+    # Calculate mesh characteristics
+    vertices = mesh.vertices
+    min_z = np.min(vertices[:, 2])
+    z_range = np.max(vertices[:, 2]) - min_z
+
+    # Adaptive threshold based on mesh size
+    adaptive_threshold = min_z + height_threshold * z_range
+
+    # Find all faces that are entirely within the bottom plane region
+    face_z_values = vertices[mesh.faces][:, :, 2]  # Z-coordinates of all face vertices
+    max_face_z = np.max(face_z_values, axis=1)  # Max z for each face
+
+    # Faces where all vertices are below the threshold
+    faces_to_remove = max_face_z <= adaptive_threshold
+
+    # Remove these faces
+    mesh.update_faces(~faces_to_remove)
+
+    # Remove unreferenced vertices to clean up
+    mesh.remove_unreferenced_vertices()
+
+    return mesh
+
+
+def remove_ground_plane_from_mesh(mesh: trimesh.Trimesh,
+                                  height_step=0.01,
+                                  centroid_tolerance=0.1,
+                                  max_iterations=100) -> trimesh.Trimesh:
+    """
+    Remove ground plane from a mesh using statistical analysis of its vertices.
+
+    Args:
+        mesh: Input mesh to process
+        height_step: How much to increase the height of the removal volume each iteration
+        centroid_tolerance: How close the centroid needs to be to the middle height range
+        max_iterations: Maximum number of iterations to perform
+
+    Returns:
+        Processed mesh with ground plane removed
+    """
+    # Make a copy of the original mesh
+    mesh = mesh.copy()
+    vertices = mesh.vertices
+
+    # Calculate initial metrics
+    min_coords = np.min(vertices, axis=0)
+    max_coords = np.max(vertices, axis=0)
+    height_range = max_coords[2] - min_coords[2]
+
+    # Initialize removal volume
+    removal_min = None
+    removal_max = None
+
+    for iteration in range(max_iterations):
+        # 1. Calculate current centroid
+        centroid = np.mean(vertices, axis=0)
+
+        # 2. Find the 4 lowest points that are furthest from centroid in x-y plane
+        # Get points in the bottom 10% height
+        height_threshold = min_coords[2] + 0.1 * height_range
+        bottom_points = vertices[vertices[:, 2] < height_threshold]
+
+        if len(bottom_points) == 0:
+            break
+
+        # Calculate x-y distances from centroid
+        xy_distances = np.linalg.norm(bottom_points[:, :2] - centroid[:2], axis=1)
+
+        # Get indices of 4 furthest points in x-y plane
+        furthest_indices = np.argpartition(xy_distances, -4)[-4:]
+        ground_corners = bottom_points[furthest_indices]
+
+        # 3. Create a bounding box from these corners
+        ground_min = np.min(ground_corners, axis=0)
+        ground_max = np.max(ground_corners, axis=0)
+
+        # Expand the height of the bounding box
+        ground_max[2] += height_step
+
+        # 4. Find faces that are entirely within this bounding box
+        # Get all vertices of each face
+        face_vertices = vertices[mesh.faces]
+
+        # Check if all 3 vertices of each face are within the bounding box
+        in_ground = np.all(
+            np.all((face_vertices >= ground_min) & (face_vertices <= ground_max), axis=2),
+            axis=1)
+
+        # Remove these faces
+        mesh.update_faces(~in_ground)
+
+        # Check termination conditions
+        vertices = mesh.vertices  # Get updated vertices after face removal
+        if len(vertices) == 0:
+            break
+
+        new_centroid = np.mean(vertices, axis=0)
+        new_height_range = np.max(vertices[:, 2]) - np.min(vertices[:, 2])
+
+        # Condition 1: Centroid is adequately centered vertically
+        centroid_height_ratio = (new_centroid[2] - np.min(vertices[:, 2])) / new_height_range
+        centroid_centered = abs(centroid_height_ratio - 0.5) < centroid_tolerance
+
+        # Condition 2: The furthest corners are now closer to centroid (ground removed)
+        new_bottom_points = vertices[vertices[:, 2] < (np.min(vertices[:, 2]) + 0.1 * new_height_range)]
+        if len(new_bottom_points) > 0:
+            new_xy_distances = np.linalg.norm(new_bottom_points[:, :2] - new_centroid[:2], axis=1)
+            avg_distance_reduced = np.mean(new_xy_distances) < 0.5 * np.mean(xy_distances)
+        else:
+            avg_distance_reduced = True
+
+        if centroid_centered and avg_distance_reduced:
+            break
+
+        # Update for next iteration
+        min_coords = np.min(vertices, axis=0)
+        max_coords = np.max(vertices, axis=0)
+        height_range = max_coords[2] - min_coords[2]
+
+    # Clean up any unreferenced vertices
+    mesh.remove_unreferenced_vertices()
+
+    return mesh
+
+# Save the cleaned mesh
+def save_mesh(mesh: trimesh.Trimesh, filename: str, file_format: str = "glb"):
+    mesh.export(filename, file_type=file_format)
+
+if __name__ == "__main__":
+    input_file = "./bla/model.glb"
+    output_file = "./bla/model2.glb"
+
+    mesh = load_mesh(input_file)
+    cleaned_mesh = remove_ground_plane_from_mesh(mesh)
+    save_mesh(cleaned_mesh, output_file)
+
+    #print(f"Cleaned mesh saved to {args.output_file}.")
\ No newline at end of file
diff --git a/flaskclient.py b/flaskclient.py
index fa279fb3..dcb63549 100644
--- a/flaskclient.py
+++ b/flaskclient.py
@@ -30,13 +30,15 @@ def generate_from_single_image(self, image_path, params=None):
 
         response = requests.post(url, files=files, data=params)
         return response.json()
-    
+
     def generate_from_multiple_images(self, image_paths, params=None):
         """
         Generate 3D model from multiple images
-        
+
         Args:
-            image_paths (list): List of paths to input images
+            image_paths (str or list): Either:
+                - A list of paths to input images, or
+                - A directory path containing images (will load all image files from the directory)
             params (dict): Optional parameters including:
                 - seed (int): Random seed
                 - ss_guidance_strength (float): Guidance strength for sparse structure generation
@@ -44,18 +46,38 @@ def generate_from_multiple_images(self, image_paths, params=None):
                 - slat_guidance_strength (float): Guidance strength for structured latent generation
                - slat_sampling_steps (int): Sampling steps for structured latent generation
                - multiimage_algo (str): Algorithm for multi-image generation ('stochastic' or 'multidiffusion')
-        
+
         Returns:
             dict: Response containing session_id and download URLs
         """
        if params is None:
            params = {}
-        
+
+        # If image_paths is a directory, find all image files in it
+        if isinstance(image_paths, str) and os.path.isdir(image_paths):
+            image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif')
+            image_paths = [
+                os.path.join(image_paths, f)
+                for f in os.listdir(image_paths)
+                if f.lower().endswith(image_extensions)
+            ]
+            if not image_paths:
+                raise ValueError(f"No image files found in directory: {image_paths}")
+
+        # Ensure image_paths is a list at this point
+        if not isinstance(image_paths, list):
+            raise ValueError("image_paths must be either a list of image paths or a directory path")
+
         url = f"{self.base_url}/generate_from_multiple_images"
-        files = [('images', (os.path.basename(path), open(path, 'rb')) for path in image_paths]
-
-        response = requests.post(url, files=files, data=params)
-        return response.json()
+        files = [('images', (os.path.basename(path), open(path, 'rb'))) for path in image_paths]
+
+        try:
+            response = requests.post(url, files=files, data=params)
+            return response.json()
+        finally:
+            # Ensure all files are closed after the request
+            for _, (_, file_obj) in files:
+                file_obj.close()
 
     def extract_glb(self, session_id, params=None):
         """
@@ -100,14 +122,118 @@ def download_file(self, url, save_path=None):
 
         return save_path
 
+    def generate_and_download_from_single_image(self, image_path, target_dir=None, params=None):
+        """
+        Generate 3D model from a single image and download all artifacts
+
+        Args:
+            image_path (str): Path to the input image
+            target_dir (str): Optional target directory to save files (defaults to /tmp/random_uuid)
+            params (dict): Optional generation parameters
+
+        Returns:
+            dict: Paths to downloaded files
+        """
+        if target_dir is None:
+            target_dir = f"/tmp/{uuid.uuid4()}"
+        os.makedirs(target_dir, exist_ok=True)
+
+        # Generate the 3D model
+        gen_result = self.generate_from_single_image(image_path, params)
+
+        # Download preview
+        preview_path = os.path.join(target_dir, 'preview.mp4')
+        self.download_file(
+            f"{self.base_url}{gen_result['preview_url']}",
+            preview_path
+        )
+
+        # Extract and download GLB
+        glb_result = self.extract_glb(gen_result['session_id'], params)
+        glb_path = os.path.join(target_dir, 'model.glb')
+        self.download_file(
+            f"{self.base_url}{glb_result['glb_url']}",
+            glb_path
+        )
+
+        return {
+            'preview_path': preview_path,
+            'glb_path': glb_path,
+            'session_id': gen_result['session_id'],
+            'target_dir': target_dir
+        }
+
+    def generate_and_download_from_multiple_images(self, image_paths, target_dir=None, params=None):
+        """
+        Generate 3D model from multiple images and download all artifacts
+
+        Args:
+            image_paths (list): List of paths to input images
+            target_dir (str): Optional target directory to save files (defaults to /tmp/random_uuid)
+            params (dict): Optional generation parameters
+
+        Returns:
+            dict: Paths to downloaded files
+        """
+        if target_dir is None:
+            target_dir = f"/tmp/{uuid.uuid4()}"
+        os.makedirs(target_dir, exist_ok=True)
+
+        # Generate the 3D model
+        gen_result = self.generate_from_multiple_images(image_paths, params)
+
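+        # Note: the same params dict is forwarded to extract_glb below, so GLB
+        # options (mesh_simplify, texture_size) can ride along with the
+        # generation settings.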
+        # Download preview
+        preview_path = os.path.join(target_dir, 'preview.mp4')
+        self.download_file(
+            f"{self.base_url}{gen_result['preview_url']}",
+            preview_path
+        )
+
+        # Extract and download GLB
+        glb_result = self.extract_glb(gen_result['session_id'], params)
+        glb_path = os.path.join(target_dir, 'model.glb')
+        self.download_file(
+            f"{self.base_url}{glb_result['glb_url']}",
+            glb_path
+        )
+
+        return {
+            'preview_path': preview_path,
+            'glb_path': glb_path,
+            'session_id': gen_result['session_id'],
+            'target_dir': target_dir
+        }
+
 # Example usage
 if __name__ == '__main__':
     client = Trellis3DClient()
+
+    multi_result = client.generate_and_download_from_multiple_images(
+        '/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/6/rgb',
+        target_dir="./blo",
+        params={
+            'multiimage_algo': 'stochastic',
+            'seed': 123
+        }
+    )
+    exit()
+
+    single_result = client.generate_and_download_from_single_image(
+        'test/000.png',
+        target_dir="./bla",
+        params={
+            'seed': 42,
+            'ss_guidance_strength': 7.5,
+            'slat_guidance_strength': 3.0
+        }
+    )
+
+    exit()
 
     # Example 1: Single image generation
     print("Generating from single image...")
     single_result = client.generate_from_single_image(
-        'example_image.png',
+        'test/000.png',
         params={
             'seed': 42,
             'ss_guidance_strength': 7.5,

From 7d83026a9d8de91dd31c9e6af3bba715ca67a9e6 Mon Sep 17 00:00:00 2001
From: Velythyl
Date: Tue, 20 May 2025 14:14:58 -0400
Subject: [PATCH 3/8] flask server and client

---
 flaskserver.py                           | 6 ++++--
 trellis/pipelines/trellis_image_to_3d.py | 3 +++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/flaskserver.py b/flaskserver.py
index bb5bd0b6..fc23f690 100644
--- a/flaskserver.py
+++ b/flaskserver.py
@@ -40,6 +40,8 @@ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
         },
     }
 
+from typing import Tuple
+
 def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
     gs = Gaussian(
         aabb=state['gaussian']['aabb'],
@@ -103,7 +105,7 @@ def generate_from_single_image():
             image,
             seed=seed,
             formats=["gaussian", "mesh"],
-            preprocess_image=False,
+            preprocess_image=True,
             sparse_structure_sampler_params={
                 "steps": ss_sampling_steps,
                 "cfg_strength": ss_guidance_strength,
@@ -180,7 +182,7 @@ def generate_from_multiple_images():
             images,
             seed=seed,
             formats=["gaussian", "mesh"],
-            preprocess_image=False,
+            preprocess_image=True,
             sparse_structure_sampler_params={
                 "steps": ss_sampling_steps,
                 "cfg_strength": ss_guidance_strength,
diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py
index 3725da20..b6b6dcac 100644
--- a/trellis/pipelines/trellis_image_to_3d.py
+++ b/trellis/pipelines/trellis_image_to_3d.py
@@ -84,6 +84,8 @@ def preprocess_image(self, input: Image.Image) -> Image.Image:
         Preprocess the input image.
         """
         # if has alpha channel, use it directly; otherwise, remove background
+
+        print("Processing images...")
         has_alpha = False
         if input.mode == 'RGBA':
             alpha = np.array(input)[:, :, 3]
@@ -92,6 +94,7 @@ def preprocess_image(self, input: Image.Image) -> Image.Image:
         if has_alpha:
             output = input
         else:
+            print("Removing background...")
             input = input.convert('RGB')
             max_size = max(input.size)
             scale = min(1, 1024 / max_size)

From 6d6bf15f700cb295acd31ff10f18841e81dfd9bd Mon Sep 17 00:00:00 2001
From: Charlie Gauthier
Date: Tue, 20 May 2025 14:15:51 -0400
Subject: [PATCH 4/8] flask client fix

---
 app_cli.py       |  47 ----------
 cleanup_mesh.py  |  47 ----------
 cleanup_mesh2.py | 190 -------------------------------------
 cleanup_mesh3.py | 167 ---------------------------------
 cleanup_mesh4.py |  97 -------------------
 cleanup_mesh5.py | 237 -----------------------------------------------
 flaskclient.py   |  11 +++
 7 files changed, 11 insertions(+), 785 deletions(-)
 delete mode 100644 app_cli.py
 delete mode 100644 cleanup_mesh.py
 delete mode 100644 cleanup_mesh2.py
 delete mode 100644 cleanup_mesh3.py
 delete mode 100644 cleanup_mesh4.py
 delete mode 100644 cleanup_mesh5.py

diff --git a/app_cli.py b/app_cli.py
deleted file mode 100644
index 7f89a4eb..00000000
--- a/app_cli.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import requests
-from PIL import Image
-import io
-import os
-
-class Gradio3DClient:
-    def __init__(self, base_url: str):
-        self.base_url = base_url.rstrip('/')
-
-    def send_image(self, image_path: str) -> dict:
-        """Sends an image to the Gradio app and generates a 3D model."""
-        with open(image_path, "rb") as f:
-            response = requests.post(f"{self.base_url}/run/image_to_3d", files={"image": f})
-        response.raise_for_status()
-        return response.json()["data"][0]
-
-    def extract_glb(self, state: dict, mesh_simplify: float = 0.95, texture_size: int = 1024) -> str:
-        """Extracts a GLB file from the generated 3D model."""
-        response = requests.post(f"{self.base_url}/run/extract_glb", json={"state": state, "mesh_simplify": mesh_simplify, "texture_size": texture_size})
-        response.raise_for_status()
-        return response.json()["data"][0]
-
-    def generate_views(self, state: dict, num_views: int = 4) -> list:
-        """Generates views of the 3D model from different angles."""
-        images = []
-        for angle in range(0, 360, 360 // num_views):
-            response = requests.post(f"{self.base_url}/run/render_view", json={"state": state, "angle": angle})
-            response.raise_for_status()
-            images.append(response.content)
-        return images
-
-    def process_image(self, image_path: str) -> dict:
-        """Full pipeline: Send image, generate model, extract GLB, and generate views."""
-        model_info = self.send_image(image_path)
-        glb_path = self.extract_glb(model_info)
-        views = self.generate_views(model_info)
-
-        return {
-            "original_image": image_path,
-            "glb_path": glb_path,
-            "views": views
-        }
-
-if __name__ == "__main__":
-    client = Gradio3DClient("http://localhost:7860")
-    result = client.process_image("/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/2/rgb/000.png")
-    print("Process complete. Files:", result)
diff --git a/cleanup_mesh.py b/cleanup_mesh.py
deleted file mode 100644
index 63943900..00000000
--- a/cleanup_mesh.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import trimesh
-import numpy as np
-
-# Function to clean the thin base plane of a mesh
-def clean_base_plane(mesh_path: str, output_path: str, threshold_ratio: float = 0.01):
-    # Load the mesh
-    mesh = trimesh.load(mesh_path)
-
-    # If the mesh is a Scene, merge it into a single Trimesh
-    if isinstance(mesh, trimesh.Scene):
-        mesh = mesh.dump(concatenate=True)
-
-    # Ensure it is a valid Trimesh
-    if not isinstance(mesh, trimesh.Trimesh):
-        raise ValueError("Loaded file is not a valid Trimesh object.")
-
-    # Calculate the average Z position of each face
-    face_z_heights = mesh.vertices[mesh.faces].mean(axis=1)[:, 2]
-
-    # Determine the Z threshold (bottom 1% by default)
-    z_threshold = np.quantile(face_z_heights, threshold_ratio)
-
-    # Identify faces below the threshold
-    face_mask = face_z_heights > z_threshold
-
-    # Keep only the faces above the threshold
-    cleaned_faces = mesh.faces[face_mask]
-
-    # Get the unique vertices used by these faces
-    unique_vertices = np.unique(cleaned_faces)
-    cleaned_vertices = mesh.vertices[unique_vertices]
-
-    # Remap face indices to the cleaned vertices
-    vertex_map = {old_idx: new_idx for new_idx, old_idx in enumerate(unique_vertices)}
-    remapped_faces = np.vectorize(vertex_map.get)(cleaned_faces)
-
-    # Create a new mesh using the cleaned faces and vertices
-    cleaned_mesh = trimesh.Trimesh(vertices=cleaned_vertices, faces=remapped_faces)
-
-    # Export the cleaned mesh
-    cleaned_mesh.export(output_path)
-
-    print(f"Cleaned mesh saved to {output_path}")
-
-
-if __name__ == "__main__":
-    clean_base_plane("./blo/model.glb", "./blo/model_clean.glb")
\ No newline at end of file
diff --git a/cleanup_mesh2.py b/cleanup_mesh2.py
deleted file mode 100644
index 6fdeb290..00000000
--- a/cleanup_mesh2.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import trimesh
-import numpy as np
-from typing import Optional, Dict, List
-import os
-
-
-def find_and_remove_bottom_planes(
-    scene: trimesh.Scene,
-    z_threshold: Optional[float] = None,
-    angle_threshold: float = 15.0,
-    min_area_ratio: float = 0.05,
-    naming_pattern: str = "world/geometry_"
-) -> Dict[str, bool]:
-    """
-    Find and remove bottom plane artefacts from all meshes in a scene that match the naming pattern.
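-    A face is treated as part of the base plane when it lies in the bottom Z-band,
-    its normal is within angle_threshold of vertical (i.e. the face is roughly
-    horizontal), and the connected component it belongs to covers at least
-    min_area_ratio of the total mesh area.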
- - Args: - scene: The loaded trimesh Scene object - z_threshold: Absolute Z threshold for bottom detection (auto-detected if None) - angle_threshold: Maximum angle (degrees) from horizontal to consider as bottom plane - min_area_ratio: Minimum area ratio (compared to bounding box) to consider as artefact plane - naming_pattern: Pattern to identify geometry nodes to process - - Returns: - Dictionary mapping geometry names to whether they were modified - """ - results = {} - - # Iterate through all geometry nodes in the scene - for node_name in scene.graph.nodes_geometry: - if naming_pattern not in node_name: - continue - - # Get the mesh instance - geometry = scene.geometry[node_name] - if not isinstance(geometry, trimesh.Trimesh): - continue - - print(f"\nProcessing {node_name}...") - mesh = geometry.copy() - original_face_count = len(mesh.faces) - - # Calculate mesh properties - bounds = mesh.bounds - height = bounds[1][2] - bounds[0][2] - - # Auto-detect z-threshold if not provided - if z_threshold is None: - z_threshold = bounds[0][2] + height * 0.1 # Bottom 10% of mesh - - # Find faces in the bottom region - face_z_coords = mesh.vertices[mesh.faces][:, :, 2] - bottom_faces = np.all(face_z_coords < z_threshold, axis=1) - - if not np.any(bottom_faces): - print(f"No bottom faces found in {node_name}") - results[node_name] = False - continue - - # Get normals of bottom faces - bottom_normals = mesh.face_normals[bottom_faces] - - # Calculate angle from vertical (we want faces that are roughly horizontal) - angles = np.degrees(np.arccos(np.abs(bottom_normals[:, 2]))) - horizontal_faces = angles < angle_threshold - - if not np.any(horizontal_faces): - print(f"No horizontal faces found in bottom region of {node_name}") - results[node_name] = False - continue - - # Get all face indices that meet our criteria - candidate_faces = np.where(bottom_faces)[0][horizontal_faces] - - # Calculate area of candidate faces - candidate_area = mesh.area_faces[candidate_faces].sum() - total_area = mesh.area - - if candidate_area / total_area < min_area_ratio: - print(f"Bottom plane area ({candidate_area:.4f}) too small in {node_name}") - results[node_name] = False - continue - - # Create a submesh of just the candidate faces - plane_mesh = mesh.submesh([candidate_faces], append=True)[0] - - # Split the plane mesh into connected components - components = plane_mesh.split(only_watertight=False) - - if not components: - print(f"No components found in candidate faces of {node_name}") - results[node_name] = False - continue - - # Find the largest component (likely our plane) - largest_component = max(components, key=lambda x: x.area) - - # Verify this component is large enough - if largest_component.area / total_area < min_area_ratio: - print(f"Largest component in bottom region is too small in {node_name}") - results[node_name] = False - continue - - # Get the vertex indices of the largest component - component_verts = set(largest_component.vertices.view(np.ndarray).flatten()) - - # Find all faces in the original mesh that exclusively use these vertices - original_faces_to_remove = [] - for face_idx in candidate_faces: - face_verts = mesh.faces[face_idx] - if all(v in component_verts for v in face_verts): - original_faces_to_remove.append(face_idx) - - if not original_faces_to_remove: - print(f"Could not map component back to original faces in {node_name}") - results[node_name] = False - continue - - # Remove the faces from the original mesh - mask = np.ones(len(mesh.faces), dtype=bool) - 
mask[original_faces_to_remove] = False - mesh.update_faces(mask) - mesh.remove_unreferenced_vertices() - - # Update the scene with the modified mesh - scene.geometry[node_name] = mesh - - removed_faces = original_face_count - len(mesh.faces) - print(f"Removed {removed_faces} faces from {node_name}") - results[node_name] = True - - return results - - -def process_glb_files( - input_path: str, - output_path: str, - z_threshold: Optional[float] = None, - angle_threshold: float = 15.0, - min_area_ratio: float = 0.05, - naming_pattern: str = "world/geometry_" -): - """ - Process a GLB file to remove bottom plane artefacts from specific geometry nodes. - - Args: - input_path: Path to input GLB file - output_path: Path to save cleaned GLB file - z_threshold: Absolute Z threshold for bottom detection (auto-detected if None) - angle_threshold: Maximum angle (degrees) from horizontal to consider as bottom plane - min_area_ratio: Minimum area ratio (compared to bounding box) to consider as artefact plane - naming_pattern: Pattern to identify geometry nodes to process - """ - # Load the scene - scene = trimesh.load(input_path) - - if not isinstance(scene, trimesh.Scene): - print("Input file is not a scene, using single mesh handling") - # Handle as single mesh (using previous approach) - remove_bottom_plane(input_path, output_path, z_threshold, angle_threshold, min_area_ratio) - return - - print(f"Loaded scene with {len(scene.geometry)} geometries") - - # Process all matching geometry nodes - results = find_and_remove_bottom_planes( - scene, - z_threshold, - angle_threshold, - min_area_ratio, - naming_pattern - ) - - # Count how many were modified - modified_count = sum(results.values()) - print(f"\nModified {modified_count} out of {len(results)} matching geometries") - - if modified_count > 0: - # Save the modified scene - scene.export(output_path) - print(f"Saved cleaned scene to {output_path}") - else: - print("No modifications made - output file not created") - -# Example usage -if __name__ == "__main__": - input_glb = "./blo/model.glb" - output_glb = "./blo/output_cleaned.glb" - - process_glb_files(input_glb, output_glb) \ No newline at end of file diff --git a/cleanup_mesh3.py b/cleanup_mesh3.py deleted file mode 100644 index f4ee135f..00000000 --- a/cleanup_mesh3.py +++ /dev/null @@ -1,167 +0,0 @@ -import os -import numpy as np -import trimesh -from pygltflib import GLTF2 -from tqdm import tqdm - - -def is_bottom_plane(vertices, faces, z_threshold=0.1, plane_area_threshold=0.5): - """ - Identify if there's a large planar surface at the bottom of the mesh. 
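[Editor's note: both deleted detectors gate faces by how far the unit face normal tilts from vertical. cleanup_mesh2.py above compares arccos(|n_z|) in degrees against angle_threshold; cleanup_mesh3.py below uses the raw |n_z| > 0.9 test, and arccos(0.9) is about 25.8 degrees, matching its own comment. A small self-contained numeric check of that geometry:

    import numpy as np

    # |n_z| is the cosine of the angle between a unit normal and the z axis,
    # so arccos(|n_z|) is the tilt away from horizontal.
    normals = np.array([[0.0, 0.0, -1.0],        # floor face: 0 deg tilt
                        [0.0, 0.7071, 0.7071]])  # 45 deg tilt
    tilt = np.degrees(np.arccos(np.clip(np.abs(normals[:, 2]), 0.0, 1.0)))
    print(tilt < 15.0)  # [ True False ]
]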
- - Args: - vertices: Mesh vertices - faces: Mesh faces - z_threshold: Height threshold to consider as "bottom" - plane_area_threshold: Minimum area to consider as significant plane - - Returns: - Tuple of (bool indicating if plane found, face indices of the plane) - """ - # Find vertices near the bottom - min_z = np.min(vertices[:, 2]) - bottom_vertices = vertices[:, 2] < min_z + z_threshold - - # Get faces that use these bottom vertices - bottom_faces_mask = np.any(bottom_vertices[faces], axis=1) - bottom_faces = faces[bottom_faces_mask] - - if len(bottom_faces) == 0: - return False, np.array([]) - - # Calculate normals of bottom faces - mesh = trimesh.Trimesh(vertices=vertices, faces=faces) - face_normals = mesh.face_normals[bottom_faces_mask] - - # Find faces that are roughly horizontal (normal points mostly up/down) - vertical_normals = np.abs(face_normals[:, 2]) > 0.9 # 0.9 means ~25 degree tolerance - horizontal_faces = bottom_faces[vertical_normals] - - if len(horizontal_faces) == 0: - return False, np.array([]) - - # Calculate area of horizontal faces - horizontal_face_indices = np.where(bottom_faces_mask)[0][vertical_normals] - areas = mesh.area_faces[horizontal_face_indices] - total_area = np.sum(areas) - - if total_area < plane_area_threshold: - return False, np.array([]) - - return True, horizontal_face_indices - - -def remove_plane_from_glb(glb_path, output_path=None): - """ - Load a GLB file, detect and remove bottom plane, then save it back. - - Args: - glb_path: Path to input GLB file - output_path: Path to save cleaned GLB (None to overwrite) - """ - if output_path is None: - output_path = glb_path - - # Load GLB file - gltf = GLTF2().load(glb_path) - - # Process each mesh in the GLB - for mesh in gltf.meshes: - for primitive in mesh.primitives: - # Get accessor indices - pos_accessor = gltf.accessors[primitive.attributes.POSITION] - normal_accessor = gltf.accessors[primitive.attributes.NORMAL] - indices_accessor = gltf.accessors[primitive.indices] - - # Get buffer views - pos_view = gltf.bufferViews[pos_accessor.bufferView] - normal_view = gltf.bufferViews[normal_accessor.bufferView] - indices_view = gltf.bufferViews[indices_accessor.bufferView] - - # Get actual data - buffer = gltf.buffers[pos_view.buffer] - pos_data = gltf.get_data_from_buffer_uri(buffer.uri) - pos_array = np.frombuffer(pos_data, dtype=np.float32, - count=pos_accessor.count * 3, - offset=pos_view.byteOffset) - pos_array = pos_array.reshape(-1, 3) - - indices_data = gltf.get_data_from_buffer_uri(buffer.uri) - indices_array = np.frombuffer(indices_data, - dtype=np.uint16 if indices_accessor.componentType == 5123 else np.uint32, - count=indices_accessor.count, - offset=indices_view.byteOffset) - indices_array = indices_array.reshape(-1, 3) - - # Check for bottom plane - has_plane, plane_face_indices = is_bottom_plane(pos_array, indices_array) - - if has_plane: - print(f"Found bottom plane with {len(plane_face_indices)} faces in {glb_path}") - - # Remove the plane faces - mask = np.ones(len(indices_array), dtype=bool) - mask[plane_face_indices] = False - new_indices = indices_array[mask] - - # Update the indices in the GLB structure - indices_accessor.count = len(new_indices) * 3 - indices_data = new_indices.tobytes() - - # Update buffer view byte length - indices_view.byteLength = len(indices_data) - - # Update the buffer - buffer_data = gltf.get_data_from_buffer_uri(buffer.uri) - buffer_data = buffer_data[:indices_view.byteOffset] + indices_data + buffer_data[ - indices_view.byteOffset + len( - 
indices_data):] - gltf.set_data_from_buffer_uri(buffer.uri, buffer_data) - - # Save the cleaned GLB - gltf.save(output_path) - - -def process_glb_directory(root_dir, output_dir=None): - """ - Process all GLB files in a directory and its subdirectories. - - Args: - root_dir: Root directory containing GLB files - output_dir: Output directory (None to overwrite original files) - """ - if output_dir is not None and not os.path.exists(output_dir): - os.makedirs(output_dir) - - # Find all GLB files - glb_files = [] - for root, _, files in os.walk(root_dir): - for file in files: - if file.endswith('.glb'): - glb_files.append(os.path.join(root, file)) - - # Process each file - for glb_path in tqdm(glb_files, desc="Processing GLB files"): - if output_dir is not None: - rel_path = os.path.relpath(glb_path, root_dir) - output_path = os.path.join(output_dir, rel_path) - os.makedirs(os.path.dirname(output_path), exist_ok=True) - else: - output_path = None - - try: - remove_plane_from_glb(glb_path, output_path) - except Exception as e: - print(f"Error processing {glb_path}: {str(e)}") - - -if __name__ == "__main__": - import argparse - - #parser = argparse.ArgumentParser(description="Remove bottom planes from GLB meshes") - #parser.add_argument("input_dir", help="Input directory containing GLB files") - #parser.add_argument("--output_dir", help="Output directory (optional, overwrites by default)") - - #args = parser.parse_args() - - process_glb_directory("./blo/model.glb", "./blo/modelCLEANED.glb") \ No newline at end of file diff --git a/cleanup_mesh4.py b/cleanup_mesh4.py deleted file mode 100644 index d134cdab..00000000 --- a/cleanup_mesh4.py +++ /dev/null @@ -1,97 +0,0 @@ -import numpy as np -import trimesh -from pygltflib import GLTF2 -from pygltflib.utils import gltf2glb - - -def remove_bottom_plane(glb_path, output_path, height_threshold=0.1, plane_normal_tolerance=0.9): - """ - Remove unwanted planar artifacts at the bottom of a GLB mesh. 
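[Editor's note: the in-place buffer patching above has to keep every accessor count, bufferView byteLength, and byte offset consistent by hand, and it reads the index bytes back through the position buffer's handle, which only works while both views share a single buffer. When exact glTF preservation is not required, a trimesh round-trip is a much smaller surface for bugs; a sketch, with paths reusing the ./blo examples from these scripts:

    import trimesh

    scene = trimesh.load("./blo/model.glb")
    for name, geom in scene.geometry.items():
        if isinstance(geom, trimesh.Trimesh):
            # edit geom here, e.g. geom.update_faces(mask) as in the
            # scene-based cleanup_mesh2.py approach
            pass
    scene.export("./blo/model_clean.glb")  # trimesh re-encodes the GLB
]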
- - Args: - glb_path (str): Path to input GLB file - output_path (str): Path to save cleaned GLB file - height_threshold (float): Height threshold to consider vertices as "bottom" - plane_normal_tolerance (float): Tolerance for detecting planar surfaces (0.9 = mostly aligned with XY plane) - """ - # Load the GLB file - gltf = GLTF2().load(glb_path) - - # Process each mesh in the GLB - for mesh_index, mesh in enumerate(gltf.meshes): - for primitive in mesh.primitives: - # Get the position accessor and buffer view - pos_accessor = gltf.accessors[primitive.attributes.POSITION] - pos_buffer_view = gltf.bufferViews[pos_accessor.bufferView] - pos_buffer = gltf.buffers[pos_buffer_view.buffer] - - # Get the position data as numpy array - pos_data = gltf.get_data_from_buffer_uri(pos_buffer.uri) - pos_array = np.frombuffer(pos_data, dtype=np.float32, - count=pos_accessor.count * 3, - offset=pos_buffer_view.byteOffset) - pos_array = pos_array.reshape(-1, 3) - - # Get indices if they exist - if hasattr(primitive, 'indices'): - indices_accessor = gltf.accessors[primitive.indices] - indices_buffer_view = gltf.bufferViews[indices_accessor.bufferView] - indices_buffer = gltf.buffers[indices_buffer_view.buffer] - indices_data = gltf.get_data_from_buffer_uri(indices_buffer.uri) - indices_array = np.frombuffer(indices_data, - dtype=np.uint16 if indices_accessor.componentType == 5123 else np.uint32, - count=indices_accessor.count, - offset=indices_buffer_view.byteOffset) - else: - indices_array = np.arange(len(pos_array)) - - # Find bottom vertices - min_z = np.min(pos_array[:, 2]) - bottom_vertices = pos_array[:, 2] < (min_z + height_threshold) - - if np.sum(bottom_vertices) < 3: # Not enough vertices to form a plane - continue - - # Check if these vertices form a plane (normal mostly in Z direction) - bottom_pos = pos_array[bottom_vertices] - centroid = np.mean(bottom_pos, axis=0) - cov = np.cov((bottom_pos - centroid).T) - _, eig_vecs = np.linalg.eig(cov) - normal = eig_vecs[:, np.argmin(np.abs(eig_vecs))] - normal = normal / np.linalg.norm(normal) - - # If the normal is mostly vertical (aligned with Z axis) - if abs(normal[2]) > plane_normal_tolerance: - print(f"Found bottom plane in mesh {mesh_index} with normal {normal}") - - # Create a mask of faces that contain bottom vertices - if len(indices_array) % 3 == 0: # Assuming triangles - faces = indices_array.reshape(-1, 3) - bottom_faces_mask = np.any(bottom_vertices[faces], axis=1) - - # Keep only non-bottom faces - if hasattr(primitive, 'indices'): - # For indexed geometry - valid_faces = faces[~bottom_faces_mask] - new_indices = valid_faces.flatten() - - # Update the indices buffer - indices_buffer.data = new_indices.tobytes() - indices_accessor.count = len(new_indices) - else: - # For non-indexed geometry - valid_vertices_mask = ~bottom_vertices - new_pos_array = pos_array[valid_vertices_mask] - - # Update the position buffer - pos_buffer.data = new_pos_array.tobytes() - pos_accessor.count = len(new_pos_array) - - # Save the cleaned GLB - gltf.save(output_path) - - -# Example usage -input_glb = "./blo/model.glb" -output_glb = "./blo/OUT.glb" -remove_bottom_plane(input_glb, output_glb) \ No newline at end of file diff --git a/cleanup_mesh5.py b/cleanup_mesh5.py deleted file mode 100644 index caa1e2ab..00000000 --- a/cleanup_mesh5.py +++ /dev/null @@ -1,237 +0,0 @@ -import trimesh -import numpy as np - -# Load the GLB file with error handling -def load_mesh(filename: str) -> trimesh.Trimesh: - try: - mesh = trimesh.load(filename) - return 
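[Editor's note: in remove_bottom_plane below, the eigen-decomposition discards the eigenvalues ("_, eig_vecs = np.linalg.eig(cov)") and then selects a column with np.argmin(np.abs(eig_vecs)), an argmin over the eigenvector entries themselves. That is a flat index into the 3x3 matrix, so the chosen "normal" is effectively arbitrary and can even index out of range. The conventional PCA plane fit pairs the normal with the smallest eigenvalue; a corrected sketch:

    import numpy as np

    def plane_normal(points: np.ndarray) -> np.ndarray:
        centered = points - points.mean(axis=0)
        cov = np.cov(centered.T)
        eig_vals, eig_vecs = np.linalg.eigh(cov)   # eigh: cov is symmetric
        normal = eig_vecs[:, np.argmin(eig_vals)]  # least-variance direction
        return normal / np.linalg.norm(normal)
]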
mesh.geometry["geometry_0"] - if not isinstance(mesh, trimesh.Trimesh): - raise ValueError("Loaded object is not a single mesh.") - return mesh - except Exception as e: - raise ValueError(f"Failed to load mesh: {e}") - - -import numpy as np - - -def remove_ground_plane(points, height_step=0.01, centroid_tolerance=0.1, max_iterations=100): - """ - Remove ground plane from a point cloud using statistical analysis. - - Args: - points: numpy array of shape (N, 3) containing the point cloud - height_step: how much to increase the height of the removal rectangle each iteration - centroid_tolerance: how close the centroid needs to be to the middle height range - max_iterations: maximum number of iterations to perform - - Returns: - numpy array of points with ground plane removed - """ - original_points = points.copy() - points = points.copy() - - # Calculate initial metrics - min_coords = np.min(points, axis=0) - max_coords = np.max(points, axis=0) - height_range = max_coords[2] - min_coords[2] - - for iteration in range(max_iterations): - # 1. Calculate current centroid - centroid = np.mean(points, axis=0) - - # 2. Find the 4 lowest points that are furthest from centroid in x-y plane - # Get points in the bottom 10% height - height_threshold = min_coords[2] + 0.1 * height_range - bottom_points = points[points[:, 2] < height_threshold] - - if len(bottom_points) == 0: - break - - # Calculate x-y distances from centroid - xy_distances = np.linalg.norm(bottom_points[:, :2] - centroid[:2], axis=1) - - # Get indices of 4 furthest points in x-y plane - furthest_indices = np.argpartition(xy_distances, -4)[-4:] - ground_corners = bottom_points[furthest_indices] - - # 3. Create a bounding box from these corners - ground_min = np.min(ground_corners, axis=0) - ground_max = np.max(ground_corners, axis=0) - - # Expand the height of the bounding box - ground_max[2] += height_step - - # 4. 
Remove points within this bounding box - in_ground = np.all((points >= ground_min) & (points <= ground_max), axis=1) - points = points[~in_ground] - - # Check termination conditions - new_centroid = np.mean(points, axis=0) - new_height_range = np.max(points[:, 2]) - np.min(points[:, 2]) - - # Condition 1: Centroid is adequately centered vertically - centroid_height_ratio = (new_centroid[2] - np.min(points[:, 2])) / new_height_range - centroid_centered = abs(centroid_height_ratio - 0.5) < centroid_tolerance - - # Condition 2: The furthest corners are now closer to centroid (ground removed) - new_bottom_points = points[points[:, 2] < (np.min(points[:, 2]) + 0.1 * new_height_range)] - if len(new_bottom_points) > 0: - new_xy_distances = np.linalg.norm(new_bottom_points[:, :2] - new_centroid[:2], axis=1) - avg_distance_reduced = np.mean(new_xy_distances) < 0.5 * np.mean(xy_distances) - else: - avg_distance_reduced = True - - if centroid_centered and avg_distance_reduced: - break - - # Update for next iteration - min_coords = np.min(points, axis=0) - max_coords = np.max(points, axis=0) - height_range = max_coords[2] - min_coords[2] - - return points - -# Identify and remove the bottom plane artifact -def remove_bottom_plane(mesh: trimesh.Trimesh, height_threshold: float = 0.5) -> trimesh.Trimesh: - # Make a copy of the original mesh - mesh = mesh.copy() - - # Calculate mesh characteristics - vertices = mesh.vertices - min_z = np.min(vertices[:, 2]) - z_range = np.max(vertices[:, 2]) - min_z - - # Adaptive threshold based on mesh size - adaptive_threshold = min_z + height_threshold * z_range - - # Find all faces that are entirely within the bottom plane region - face_z_values = vertices[mesh.faces][:, :, 2] # Z-coordinates of all face vertices - max_face_z = np.max(face_z_values, axis=1) # Max z for each face - - # Faces where all vertices are below the threshold - faces_to_remove = max_face_z <= adaptive_threshold - - # Remove these faces - mesh.update_faces(~faces_to_remove) - - # Remove unreferenced vertices to clean up - mesh.remove_unreferenced_vertices() - - return mesh - - -def remove_ground_plane_from_mesh(mesh: trimesh.Trimesh, - height_step=0.01, - centroid_tolerance=0.1, - max_iterations=100) -> trimesh.Trimesh: - """ - Remove ground plane from a mesh using statistical analysis of its vertices. - - Args: - mesh: Input mesh to process - height_step: How much to increase the height of the removal volume each iteration - centroid_tolerance: How close the centroid needs to be to the middle height range - max_iterations: Maximum number of iterations to perform - - Returns: - Processed mesh with ground plane removed - """ - # Make a copy of the original mesh - mesh = mesh.copy() - vertices = mesh.vertices - - # Calculate initial metrics - min_coords = np.min(vertices, axis=0) - max_coords = np.max(vertices, axis=0) - height_range = max_coords[2] - min_coords[2] - - # Initialize removal volume - removal_min = None - removal_max = None - - for iteration in range(max_iterations): - # 1. Calculate current centroid - centroid = np.mean(vertices, axis=0) - - # 2. 
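[Editor's note: the stop rule in both remove_ground_plane variants keys off where the centroid sits in the height range. A flat slab glued under an object drags the centroid well below mid-height, and removing it pushes the ratio back toward 0.5. A quick illustration with synthetic points; the counts and values are arbitrary and approximate:

    import numpy as np

    rng = np.random.default_rng(0)
    cube = rng.random((1000, 3))                                # z in [0, 1]
    slab = rng.random((1000, 3)) * [1, 1, 0.02] - [0, 0, 0.02]  # thin base
    pts = np.vstack([cube, slab])

    def height_ratio(p):
        return (p[:, 2].mean() - p[:, 2].min()) / np.ptp(p[:, 2])

    print(height_ratio(pts))   # ~0.26: centroid dragged down by the slab
    print(height_ratio(cube))  # ~0.5: passes |ratio - 0.5| < tolerance
]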
Find the 4 lowest points that are furthest from centroid in x-y plane - # Get points in the bottom 10% height - height_threshold = min_coords[2] + 0.1 * height_range - bottom_points = vertices[vertices[:, 2] < height_threshold] - - if len(bottom_points) == 0: - break - - # Calculate x-y distances from centroid - xy_distances = np.linalg.norm(bottom_points[:, :2] - centroid[:2], axis=1) - - # Get indices of 4 furthest points in x-y plane - furthest_indices = np.argpartition(xy_distances, -4)[-4:] - ground_corners = bottom_points[furthest_indices] - - # 3. Create a bounding box from these corners - ground_min = np.min(ground_corners, axis=0) - ground_max = np.max(ground_corners, axis=0) - - # Expand the height of the bounding box - ground_max[2] += height_step - - # 4. Find faces that are entirely within this bounding box - # Get all vertices of each face - face_vertices = vertices[mesh.faces] - - # Check if all 3 vertices of each face are within the bounding box - in_ground = np.all( - np.all((face_vertices >= ground_min) & (face_vertices <= ground_max), axis=2), - axis=1) - - # Remove these faces - mesh.update_faces(~in_ground) - - # Check termination conditions - vertices = mesh.vertices # Get updated vertices after face removal - if len(vertices) == 0: - break - - new_centroid = np.mean(vertices, axis=0) - new_height_range = np.max(vertices[:, 2]) - np.min(vertices[:, 2]) - - # Condition 1: Centroid is adequately centered vertically - centroid_height_ratio = (new_centroid[2] - np.min(vertices[:, 2])) / new_height_range - centroid_centered = abs(centroid_height_ratio - 0.5) < centroid_tolerance - - # Condition 2: The furthest corners are now closer to centroid (ground removed) - new_bottom_points = vertices[vertices[:, 2] < (np.min(vertices[:, 2]) + 0.1 * new_height_range)] - if len(new_bottom_points) > 0: - new_xy_distances = np.linalg.norm(new_bottom_points[:, :2] - new_centroid[:2], axis=1) - avg_distance_reduced = np.mean(new_xy_distances) < 0.5 * np.mean(xy_distances) - else: - avg_distance_reduced = True - - if centroid_centered and avg_distance_reduced: - break - - # Update for next iteration - min_coords = np.min(vertices, axis=0) - max_coords = np.max(vertices, axis=0) - height_range = max_coords[2] - min_coords[2] - - # Clean up any unreferenced vertices - mesh.remove_unreferenced_vertices() - - return mesh - -# Save the cleaned mesh -def save_mesh(mesh: trimesh.Trimesh, filename: str, file_format: str = "glb"): - mesh.export(filename, file_type=file_format) - -if __name__ == "__main__": - input_file = "./bla/model.glb" - output_file = "./bla/model2.glb" - - mesh = load_mesh(input_file) - cleaned_mesh = remove_ground_plane_from_mesh(mesh) - save_mesh(cleaned_mesh, output_file) - - #print(f"Cleaned mesh saved to {args.output_file}.") \ No newline at end of file diff --git a/flaskclient.py b/flaskclient.py index dcb63549..332121a6 100644 --- a/flaskclient.py +++ b/flaskclient.py @@ -208,6 +208,17 @@ def generate_and_download_from_multiple_images(self, image_paths, target_dir=Non if __name__ == '__main__': client = Trellis3DClient() + multi_result = client.generate_and_download_from_multiple_images( + '/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/6/rgb', + target_dir="./blo", + params={ + 'multiimage_algo': 'stochastic', + 'seed': 123 + } + ) + exit() + + multi_result = client.generate_and_download_from_multiple_images( '/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/6/rgb', target_dir="./blo", From 
79fe26e9af2778354da7f8286a0ab8a1e878520b Mon Sep 17 00:00:00 2001 From: Velythyl Date: Tue, 20 May 2025 14:17:24 -0400 Subject: [PATCH 5/8] remove print statements --- trellis/pipelines/trellis_image_to_3d.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/trellis/pipelines/trellis_image_to_3d.py b/trellis/pipelines/trellis_image_to_3d.py index b6b6dcac..3725da20 100644 --- a/trellis/pipelines/trellis_image_to_3d.py +++ b/trellis/pipelines/trellis_image_to_3d.py @@ -84,8 +84,6 @@ def preprocess_image(self, input: Image.Image) -> Image.Image: Preprocess the input image. """ # if has alpha channel, use it directly; otherwise, remove background - - print("Processing images...") has_alpha = False if input.mode == 'RGBA': alpha = np.array(input)[:, :, 3] @@ -94,7 +92,6 @@ def preprocess_image(self, input: Image.Image) -> Image.Image: if has_alpha: output = input else: - print("Removing background...") input = input.convert('RGB') max_size = max(input.size) scale = min(1, 1024 / max_size) From cce9290407978abd710ccf30b765c7c913dfb3c2 Mon Sep 17 00:00:00 2001 From: Charlie Gauthier Date: Wed, 21 May 2025 16:50:46 -0400 Subject: [PATCH 6/8] spinner --- flaskclient.py | 110 +++++++++++++++++-------------------------------- 1 file changed, 37 insertions(+), 73 deletions(-) diff --git a/flaskclient.py b/flaskclient.py index 332121a6..0070a1d7 100644 --- a/flaskclient.py +++ b/flaskclient.py @@ -1,6 +1,10 @@ +import threading +import time + import requests import os import uuid +from tqdm import tqdm class Trellis3DClient: def __init__(self, base_url='http://localhost:5000'): @@ -139,6 +143,7 @@ def generate_and_download_from_single_image(self, image_path, target_dir=None, p os.makedirs(target_dir, exist_ok=True) # Generate the 3D model + et = start_spinner_thread("Generating 3D model...") gen_result = self.generate_from_single_image(image_path, params) # Download preview @@ -147,14 +152,17 @@ def generate_and_download_from_single_image(self, image_path, target_dir=None, p f"{self.base_url}{gen_result['preview_url']}", preview_path ) + stop_spinner_thread(*et) # Extract and download GLB + et = start_spinner_thread("Extracting 3D model...") glb_result = self.extract_glb(gen_result['session_id'], params) glb_path = os.path.join(target_dir, 'model.glb') self.download_file( f"{self.base_url}{glb_result['glb_url']}", glb_path ) + stop_spinner_thread(*et) return { 'preview_path': preview_path, @@ -180,6 +188,7 @@ def generate_and_download_from_multiple_images(self, image_paths, target_dir=Non os.makedirs(target_dir, exist_ok=True) # Generate the 3D model + et = start_spinner_thread("Generating 3D model...") gen_result = self.generate_from_multiple_images(image_paths, params) # Download preview @@ -188,14 +197,17 @@ def generate_and_download_from_multiple_images(self, image_paths, target_dir=Non f"{self.base_url}{gen_result['preview_url']}", preview_path ) + stop_spinner_thread(*et) # Extract and download GLB + et = start_spinner_thread("Extracting 3D model...") glb_result = self.extract_glb(gen_result['session_id'], params) glb_path = os.path.join(target_dir, 'model.glb') self.download_file( f"{self.base_url}{glb_result['glb_url']}", glb_path ) + stop_spinner_thread(*et) return { 'preview_path': preview_path, @@ -204,6 +216,31 @@ def generate_and_download_from_multiple_images(self, image_paths, target_dir=Non 'target_dir': target_dir } +def spinner(desc, stop_event): + from itertools import cycle + + spinner = cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']) + with tqdm(total=None, desc=desc, 
bar_format='{desc}') as pbar: + while not stop_event.is_set(): + pbar.set_description(f"{desc} - {next(spinner)}") + time.sleep(0.1) + #pbar.update() + +def start_spinner_thread(desc) -> (threading.Event, threading.Thread): + stop_event = threading.Event() + spinner_thread = threading.Thread( + target=spinner, + args=(desc, stop_event) + ) + spinner_thread.start() + + return stop_event, spinner_thread + +def stop_spinner_thread(stop_event, spinner_thread): + stop_event.set() + spinner_thread.join(timeout=5) + + # Example usage if __name__ == '__main__': client = Trellis3DClient() @@ -228,76 +265,3 @@ def generate_and_download_from_multiple_images(self, image_paths, target_dir=Non } ) exit() - - single_result = client.generate_and_download_from_single_image( - 'test/000.png', - target_dir="./bla", - params={ - 'seed': 42, - 'ss_guidance_strength': 7.5, - 'slat_guidance_strength': 3.0 - } - ) - - exit() - - # Example 1: Single image generation - print("Generating from single image...") - single_result = client.generate_from_single_image( - 'test/000.png', - params={ - 'seed': 42, - 'ss_guidance_strength': 7.5, - 'slat_guidance_strength': 3.0 - } - ) - print(single_result) - - # Download preview - client.download_file( - f"http://localhost:5000{single_result['preview_url']}", - 'single_preview.mp4' - ) - - # Extract GLB - glb_result = client.extract_glb( - single_result['session_id'], - params={'mesh_simplify': 0.95} - ) - print(glb_result) - - # Download GLB - client.download_file( - f"http://localhost:5000{glb_result['glb_url']}", - 'single_model.glb' - ) - - # Example 2: Multiple image generation - print("\nGenerating from multiple images...") - multi_result = client.generate_from_multiple_images( - ['view1.png', 'view2.png', 'view3.png'], - params={ - 'multiimage_algo': 'stochastic', - 'seed': 123 - } - ) - print(multi_result) - - # Download preview - client.download_file( - f"http://localhost:5000{multi_result['preview_url']}", - 'multi_preview.mp4' - ) - - # Extract GLB - glb_result = client.extract_glb( - multi_result['session_id'], - params={'texture_size': 2048} - ) - print(glb_result) - - # Download GLB - client.download_file( - f"http://localhost:5000{glb_result['glb_url']}", - 'multi_model.glb' - ) \ No newline at end of file From f8f1cc4e572e98a0d6a55c15d8781ab1b9d8374a Mon Sep 17 00:00:00 2001 From: Velythyl Date: Fri, 11 Jul 2025 10:57:33 -0400 Subject: [PATCH 7/8] port --- flaskserver.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/flaskserver.py b/flaskserver.py index fc23f690..eea060b4 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -19,7 +19,7 @@ os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True) # Initialize pipeline -pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large") +pipeline = TrellisImageTo3DPipeline.from_pretrained("gqk/TRELLIS-image-large-fork") pipeline.cuda() MAX_SEED = np.iinfo(np.int32).max @@ -78,6 +78,7 @@ def generate_from_single_image(): ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12)) slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0)) slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12)) + preprocess_image = request.form.get('preprocess_image', 'True').lower() == 'true' # Handle file upload if 'image' not in request.files: @@ -98,14 +99,13 @@ def generate_from_single_image(): file.save(image_path) # Preprocess image - image = preprocess_image(image_path) # Generate 3D model outputs = pipeline.run( - 
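[Editor's note on the spinner pair from patch 6/8: it works, but if the wrapped call raises between start_spinner_thread and stop_spinner_thread the event is never set and the non-daemon thread keeps the process alive. A context-manager wrapper around the patch's own helpers would close that hole; a sketch, not part of the series:

    from contextlib import contextmanager

    @contextmanager
    def spinner_ctx(desc):
        et = start_spinner_thread(desc)
        try:
            yield
        finally:
            stop_spinner_thread(*et)  # runs even when the body raises

    # usage:
    # with spinner_ctx("Generating 3D model..."):
    #     gen_result = self.generate_from_single_image(image_path, params)
]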
image, + Image.open(image_path), seed=seed, formats=["gaussian", "mesh"], - preprocess_image=True, + preprocess_image=preprocess_image, sparse_structure_sampler_params={ "steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength, @@ -153,6 +153,7 @@ def generate_from_multiple_images(): slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0)) slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12)) multiimage_algo = request.form.get('multiimage_algo', 'stochastic') + preprocess_image = request.form.get('preprocess_image', 'True').lower() == 'true' # Handle file uploads if 'images' not in request.files: @@ -182,7 +183,7 @@ def generate_from_multiple_images(): images, seed=seed, formats=["gaussian", "mesh"], - preprocess_image=True, + preprocess_image=preprocess_image, sparse_structure_sampler_params={ "steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength, @@ -278,4 +279,16 @@ def download_glb(session_id): return send_file(glb_path, as_attachment=True) if __name__ == '__main__': - app.run(host='0.0.0.0', port=5000, threaded=True) \ No newline at end of file + + #def find_free_port(): + # # todo + # with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # s.bind(('0.0.0.0', 0)) # Bind to a free port provided by the host + # return s.getsockname()[1] # Return the port number assigned + import sys + for i, arg in sys.argv: + if arg.startswith('--port='): + port = int(arg.split('=')[1]) + break + + app.run(host='0.0.0.0', port=port, threaded=True) \ No newline at end of file From 77fa0598f40e4cd5a8dc24d7bbc29979c41bf751 Mon Sep 17 00:00:00 2001 From: Velythyl Date: Fri, 11 Jul 2025 11:09:48 -0400 Subject: [PATCH 8/8] port --- flaskserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flaskserver.py b/flaskserver.py index eea060b4..5d67d450 100644 --- a/flaskserver.py +++ b/flaskserver.py @@ -286,7 +286,7 @@ def download_glb(session_id): # s.bind(('0.0.0.0', 0)) # Bind to a free port provided by the host # return s.getsockname()[1] # Return the port number assigned import sys - for i, arg in sys.argv: + for arg in sys.argv: if arg.startswith('--port='): port = int(arg.split('=')[1]) break
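[Editor's note: even after patch 8/8 corrects the loop to "for arg in sys.argv:", port stays unbound when no --port= flag is passed, so app.run raises a NameError. A fallback the series leaves open, using argparse with a default of 5000 to match the port the server originally hard-coded; not part of the patches:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--port', type=int, default=5000)
    args, _ = parser.parse_known_args()  # tolerate unrelated flags
    app.run(host='0.0.0.0', port=args.port, threaded=True)
]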