diff --git a/flaskclient.py b/flaskclient.py
new file mode 100644
index 00000000..0070a1d7
--- /dev/null
+++ b/flaskclient.py
@@ -0,0 +1,267 @@
+import threading
+import time
+from typing import Tuple
+
+import requests
+import os
+import uuid
+from tqdm import tqdm
+
+class Trellis3DClient:
+    def __init__(self, base_url='http://localhost:5000'):
+        self.base_url = base_url
+
+    def generate_from_single_image(self, image_path, params=None):
+        """
+        Generate a 3D model from a single image.
+
+        Args:
+            image_path (str): Path to the input image
+            params (dict): Optional parameters including:
+                - seed (int): Random seed
+                - ss_guidance_strength (float): Guidance strength for sparse structure generation
+                - ss_sampling_steps (int): Sampling steps for sparse structure generation
+                - slat_guidance_strength (float): Guidance strength for structured latent generation
+                - slat_sampling_steps (int): Sampling steps for structured latent generation
+
+        Returns:
+            dict: Response containing session_id and download URLs
+        """
+        if params is None:
+            params = {}
+
+        url = f"{self.base_url}/generate_from_single_image"
+        # Open the image in a context manager so the handle is closed after the request
+        with open(image_path, 'rb') as image_file:
+            files = {'image': image_file}
+            response = requests.post(url, files=files, data=params)
+        return response.json()
+
+    def generate_from_multiple_images(self, image_paths, params=None):
+        """
+        Generate a 3D model from multiple images.
+
+        Args:
+            image_paths (str or list): Either:
+                - A list of paths to input images, or
+                - A directory path containing images (will load all image files from the directory)
+            params (dict): Optional parameters including:
+                - seed (int): Random seed
+                - ss_guidance_strength (float): Guidance strength for sparse structure generation
+                - ss_sampling_steps (int): Sampling steps for sparse structure generation
+                - slat_guidance_strength (float): Guidance strength for structured latent generation
+                - slat_sampling_steps (int): Sampling steps for structured latent generation
+                - multiimage_algo (str): Algorithm for multi-image generation ('stochastic' or 'multidiffusion')
+
+        Returns:
+            dict: Response containing session_id and download URLs
+        """
+        if params is None:
+            params = {}
+
+        # If image_paths is a directory, find all image files in it
+        if isinstance(image_paths, str) and os.path.isdir(image_paths):
+            image_dir = image_paths
+            image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.gif')
+            image_paths = [
+                os.path.join(image_dir, f)
+                for f in os.listdir(image_dir)
+                if f.lower().endswith(image_extensions)
+            ]
+            if not image_paths:
+                raise ValueError(f"No image files found in directory: {image_dir}")
+
+        # Ensure image_paths is a list at this point
+        if not isinstance(image_paths, list):
+            raise ValueError("image_paths must be either a list of image paths or a directory path")
+
+        url = f"{self.base_url}/generate_from_multiple_images"
+        files = [('images', (os.path.basename(path), open(path, 'rb'))) for path in image_paths]
+
+        try:
+            response = requests.post(url, files=files, data=params)
+            return response.json()
+        finally:
+            # Ensure all files are closed after the request
+            for _, (_, file_obj) in files:
+                file_obj.close()
+
+    def extract_glb(self, session_id, params=None):
+        """
+        Extract a GLB file from a generated 3D model.
+
+        Args:
+            session_id (str): Session ID from the generation step
+            params (dict): Optional parameters including:
+                - mesh_simplify (float): Mesh simplification factor (0.9-0.98)
+                - texture_size (int): Texture resolution (512, 1024, 1536, or 2048)
+
+        Returns:
+            dict: Response containing the GLB download URL
+        """
+        if params is None:
+            params = {}
+
+        url = f"{self.base_url}/extract_glb"
+        data = {'session_id': session_id, **params}
+
+        response = requests.post(url, data=data)
+        return response.json()
+
+    def download_file(self, url, save_path=None):
+        """
+        Download a file from the server.
+
+        Args:
+            url (str): Full URL to download (from previous responses)
+            save_path (str): Optional path to save the file
+
+        Returns:
+            str: Path where the file was saved
+        """
+        if save_path is None:
+            save_path = os.path.basename(url)
+
+        response = requests.get(url, stream=True)
+        # Fail loudly instead of writing an error response to the output file
+        response.raise_for_status()
+        with open(save_path, 'wb') as f:
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
+
+        return save_path
+
+    def generate_and_download_from_single_image(self, image_path, target_dir=None, params=None):
+        """
+        Generate a 3D model from a single image and download all artifacts.
+
+        Args:
+            image_path (str): Path to the input image
+            target_dir (str): Optional target directory to save files (defaults to /tmp/random_uuid)
+            params (dict): Optional generation parameters
+
+        Returns:
+            dict: Paths to downloaded files
+        """
+        if target_dir is None:
+            target_dir = f"/tmp/{uuid.uuid4()}"
+        os.makedirs(target_dir, exist_ok=True)
+
+        # Generate the 3D model
+        et = start_spinner_thread("Generating 3D model...")
+        gen_result = self.generate_from_single_image(image_path, params)
+
+        # Download preview
+        preview_path = os.path.join(target_dir, 'preview.mp4')
+        self.download_file(
+            f"{self.base_url}{gen_result['preview_url']}",
+            preview_path
+        )
+        stop_spinner_thread(*et)
+
+        # Extract and download GLB
+        et = start_spinner_thread("Extracting 3D model...")
+        glb_result = self.extract_glb(gen_result['session_id'], params)
+        glb_path = os.path.join(target_dir, 'model.glb')
+        self.download_file(
+            f"{self.base_url}{glb_result['glb_url']}",
+            glb_path
+        )
+        stop_spinner_thread(*et)
+
+        return {
+            'preview_path': preview_path,
+            'glb_path': glb_path,
+            'session_id': gen_result['session_id'],
+            'target_dir': target_dir
+        }
+
+    def generate_and_download_from_multiple_images(self, image_paths, target_dir=None, params=None):
+        """
+        Generate a 3D model from multiple images and download all artifacts.
+
+        Args:
+            image_paths (str or list): List of paths to input images, or a directory of images
+            target_dir (str): Optional target directory to save files (defaults to /tmp/random_uuid)
+            params (dict): Optional generation parameters
+
+        Returns:
+            dict: Paths to downloaded files
+        """
+        if target_dir is None:
+            target_dir = f"/tmp/{uuid.uuid4()}"
+        os.makedirs(target_dir, exist_ok=True)
+
+        # Generate the 3D model
+        et = start_spinner_thread("Generating 3D model...")
+        gen_result = self.generate_from_multiple_images(image_paths, params)
+
+        # Download preview
+        preview_path = os.path.join(target_dir, 'preview.mp4')
+        self.download_file(
+            f"{self.base_url}{gen_result['preview_url']}",
+            preview_path
+        )
+        stop_spinner_thread(*et)
+
+        # Extract and download GLB
+        et = start_spinner_thread("Extracting 3D model...")
+        glb_result = self.extract_glb(gen_result['session_id'], params)
+        glb_path = os.path.join(target_dir, 'model.glb')
+        self.download_file(
+            f"{self.base_url}{glb_result['glb_url']}",
+            glb_path
+        )
+        stop_spinner_thread(*et)
+
+        return {
+            'preview_path': preview_path,
+            'glb_path': glb_path,
+            'session_id': gen_result['session_id'],
+            'target_dir': target_dir
+        }
+
+def spinner(desc, stop_event):
+    from itertools import cycle
+
+    glyphs = cycle(['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'])
+    with tqdm(total=None, desc=desc, bar_format='{desc}') as pbar:
+        while not stop_event.is_set():
+            pbar.set_description(f"{desc} - {next(glyphs)}")
+            time.sleep(0.1)
+
+def start_spinner_thread(desc) -> Tuple[threading.Event, threading.Thread]:
+    stop_event = threading.Event()
+    spinner_thread = threading.Thread(
+        target=spinner,
+        args=(desc, stop_event)
+    )
+    spinner_thread.start()
+
+    return stop_event, spinner_thread
+
+def stop_spinner_thread(stop_event, spinner_thread):
+    stop_event.set()
+    spinner_thread.join(timeout=5)
+
+
+# Example usage
+if __name__ == '__main__':
+    client = Trellis3DClient()
+
+    multi_result = client.generate_and_download_from_multiple_images(
+        '/home/charlie/Desktop/Holodeck/hippo/datasets/sacha_kitchen/segments/6/rgb',
+        target_dir="./blo",
+        params={
+            'multiimage_algo': 'stochastic',
+            'seed': 123
+        }
+    )
diff --git a/flaskserver.py b/flaskserver.py
new file mode 100644
index 00000000..5d67d450
--- /dev/null
+++ b/flaskserver.py
@@ -0,0 +1,294 @@
+import os
+import uuid
+import shutil
+from typing import Tuple
+
+from flask import Flask, request, jsonify, send_file
+from werkzeug.utils import secure_filename
+from PIL import Image
+import numpy as np
+import torch
+from easydict import EasyDict as edict
+import imageio
+from trellis.pipelines import TrellisImageTo3DPipeline
+from trellis.representations import Gaussian, MeshExtractResult
+from trellis.utils import render_utils, postprocessing_utils
+
+app = Flask(__name__)
+app.config['UPLOAD_FOLDER'] = 'uploads'
+app.config['OUTPUT_FOLDER'] = 'outputs'
+os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+os.makedirs(app.config['OUTPUT_FOLDER'], exist_ok=True)
+
+# Initialize pipeline
+pipeline = TrellisImageTo3DPipeline.from_pretrained("gqk/TRELLIS-image-large-fork")
+pipeline.cuda()
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+    return {
+        'gaussian': {
+            **gs.init_params,
+            '_xyz': gs._xyz.cpu().numpy(),
+            '_features_dc': gs._features_dc.cpu().numpy(),
+            '_scaling': gs._scaling.cpu().numpy(),
+            '_rotation': gs._rotation.cpu().numpy(),
+            '_opacity': gs._opacity.cpu().numpy(),
+        },
+        'mesh': {
+            'vertices': mesh.vertices.cpu().numpy(),
+            'faces': mesh.faces.cpu().numpy(),
+        },
+    }
+
+def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
+    gs = Gaussian(
+        aabb=state['gaussian']['aabb'],
+        sh_degree=state['gaussian']['sh_degree'],
+        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+        scaling_bias=state['gaussian']['scaling_bias'],
+        opacity_bias=state['gaussian']['opacity_bias'],
+        scaling_activation=state['gaussian']['scaling_activation'],
+    )
+    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+
+    mesh = edict(
+        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+    )
+
+    return gs, mesh
+
+def preprocess_image(image_path: str) -> Image.Image:
+    image = Image.open(image_path)
+    processed_image = pipeline.preprocess_image(image)
+    return processed_image
+
+@app.route('/generate_from_single_image', methods=['POST'])
+def generate_from_single_image():
+    try:
+        # Get parameters
+        seed = int(request.form.get('seed', 0))
+        ss_guidance_strength = float(request.form.get('ss_guidance_strength', 7.5))
+        ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12))
+        slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0))
+        slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12))
+        preprocess_image = request.form.get('preprocess_image', 'True').lower() == 'true'
+
+        # Handle file upload
+        if 'image' not in request.files:
+            return jsonify({'error': 'No image provided'}), 400
+
+        file = request.files['image']
+        if file.filename == '':
+            return jsonify({'error': 'No selected file'}), 400
+
+        # Create session directory
+        session_id = str(uuid.uuid4())
+        session_dir = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
+        os.makedirs(session_dir, exist_ok=True)
+
+        # Save uploaded file
+        filename = secure_filename(file.filename)
+        image_path = os.path.join(session_dir, filename)
+        file.save(image_path)
+
+        # Generate 3D model
+        outputs = pipeline.run(
+            Image.open(image_path),
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=preprocess_image,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+        )
+
+        # Create output directory
+        output_dir = os.path.join(app.config['OUTPUT_FOLDER'], session_id)
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save video preview
+        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+        video_path = os.path.join(output_dir, 'preview.mp4')
+        imageio.mimsave(video_path, video, fps=15)
+
+        # Save state
+        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+        state_path = os.path.join(output_dir, 'state.pkl')
+        torch.save(state, state_path)
+
+        torch.cuda.empty_cache()
+
+        return jsonify({
+            'session_id': session_id,
+            'preview_url': f'/download/preview/{session_id}',
+            'state_url': f'/download/state/{session_id}'
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+@app.route('/generate_from_multiple_images', methods=['POST'])
+def generate_from_multiple_images():
+    try:
+        # Get parameters
+        seed = int(request.form.get('seed', 0))
+        ss_guidance_strength = float(request.form.get('ss_guidance_strength', 7.5))
+        ss_sampling_steps = int(request.form.get('ss_sampling_steps', 12))
+        slat_guidance_strength = float(request.form.get('slat_guidance_strength', 3.0))
+        slat_sampling_steps = int(request.form.get('slat_sampling_steps', 12))
+        multiimage_algo = request.form.get('multiimage_algo', 'stochastic')
+        preprocess_image = request.form.get('preprocess_image', 'True').lower() == 'true'
+
+        # Handle file uploads
+        if 'images' not in request.files:
+            return jsonify({'error': 'No images provided'}), 400
+
+        files = request.files.getlist('images')
+        if len(files) == 0:
+            return jsonify({'error': 'No selected files'}), 400
+
+        # Create session directory
+        session_id = str(uuid.uuid4())
+        session_dir = os.path.join(app.config['UPLOAD_FOLDER'], session_id)
+        os.makedirs(session_dir, exist_ok=True)
+
+        # Save uploaded files
+        images = []
+        for file in files:
+            if file.filename == '':
+                continue
+            filename = secure_filename(file.filename)
+            image_path = os.path.join(session_dir, filename)
+            file.save(image_path)
+            images.append(Image.open(image_path))
+
+        # Generate 3D model
+        outputs = pipeline.run_multi_image(
+            images,
+            seed=seed,
+            formats=["gaussian", "mesh"],
+            preprocess_image=preprocess_image,
+            sparse_structure_sampler_params={
+                "steps": ss_sampling_steps,
+                "cfg_strength": ss_guidance_strength,
+            },
+            slat_sampler_params={
+                "steps": slat_sampling_steps,
+                "cfg_strength": slat_guidance_strength,
+            },
+            mode=multiimage_algo,
+        )
+
+        # Create output directory
+        output_dir = os.path.join(app.config['OUTPUT_FOLDER'], session_id)
+        os.makedirs(output_dir, exist_ok=True)
+
+        # Save video preview
+        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+        video_path = os.path.join(output_dir, 'preview.mp4')
+        imageio.mimsave(video_path, video, fps=15)
+
+        # Save state
+        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+        state_path = os.path.join(output_dir, 'state.pkl')
+        torch.save(state, state_path)
+
+        torch.cuda.empty_cache()
+
+        return jsonify({
+            'session_id': session_id,
+            'preview_url': f'/download/preview/{session_id}',
+            'state_url': f'/download/state/{session_id}'
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+@app.route('/extract_glb', methods=['POST'])
+def extract_glb():
+    try:
+        # Get parameters
+        session_id = request.form.get('session_id')
+        mesh_simplify = float(request.form.get('mesh_simplify', 0.95))
+        texture_size = int(request.form.get('texture_size', 1024))
+
+        if not session_id:
+            return jsonify({'error': 'session_id is required'}), 400
+
+        # Load state
+        state_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'state.pkl')
+        if not os.path.exists(state_path):
+            return jsonify({'error': 'Invalid session_id'}), 404
+
+        state = torch.load(state_path)
+        gs, mesh = unpack_state(state)
+
+        # Extract GLB
+        glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+
+        # Save GLB
+        glb_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'model.glb')
+        glb.export(glb_path)
+
+        torch.cuda.empty_cache()
+
+        return jsonify({
+            'glb_url': f'/download/glb/{session_id}'
+        }), 200
+
+    except Exception as e:
+        return jsonify({'error': str(e)}), 500
+
+@app.route('/download/preview/<session_id>', methods=['GET'])
+def download_preview(session_id):
+    preview_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'preview.mp4')
+    if not os.path.exists(preview_path):
+        return jsonify({'error': 'Preview not found'}), 404
+    return send_file(preview_path, as_attachment=True)
+
+@app.route('/download/state/<session_id>', methods=['GET'])
+def download_state(session_id):
+    state_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'state.pkl')
+    if not os.path.exists(state_path):
+        return jsonify({'error': 'State not found'}), 404
+    return send_file(state_path, as_attachment=True)
+
+@app.route('/download/glb/<session_id>', methods=['GET'])
+def download_glb(session_id):
+    glb_path = os.path.join(app.config['OUTPUT_FOLDER'], session_id, 'model.glb')
+    if not os.path.exists(glb_path):
+        return jsonify({'error': 'GLB not found'}), 404
+    return send_file(glb_path, as_attachment=True)
+
+if __name__ == '__main__':
+
+    #def find_free_port():
+    #    # todo
+    #    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+    #        s.bind(('0.0.0.0', 0))  # Bind to a free port provided by the host
+    #        return s.getsockname()[1]  # Return the port number assigned
+    import sys
+    port = 5000  # default; matches the client's default base_url (http://localhost:5000)
+    for arg in sys.argv:
+        if arg.startswith('--port='):
+            port = int(arg.split('=')[1])
+            break
+
+    app.run(host='0.0.0.0', port=port, threaded=True)
\ No newline at end of file