diff --git a/requirements.txt b/requirements.txt
index 4a7e68ad..92fdbb0b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ aiohttp
 toml
 twilio
 prometheus_client
+GitPython
\ No newline at end of file
diff --git a/server/api/__init__.py b/server/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/api/api.py b/server/api/api.py
new file mode 100644
index 00000000..4b22da60
--- /dev/null
+++ b/server/api/api.py
@@ -0,0 +1,306 @@
+from aiohttp import web
+import asyncio
+
+from api.nodes.nodes import list_nodes, install_node, delete_node
+from api.models.models import list_models, add_model, delete_model
+from api.settings.settings import set_twilio_account_info
+from pipeline import Pipeline
+
+from comfy.nodes.package import _comfy_nodes, import_all_nodes_in_workspace
+
+from api.nodes.nodes import force_import_all_nodes_in_workspace
+# use a different node import that re-imports nodes on every call
+import_all_nodes_in_workspace = force_import_all_nodes_in_workspace
+
+def add_routes(app):
+    app.router.add_get("/env/list_nodes", nodes)
+    app.router.add_post("/env/install_nodes", install_nodes)
+    app.router.add_post("/env/delete_nodes", delete_nodes)
+
+    app.router.add_get("/env/list_models", models)
+    app.router.add_post("/env/add_models", add_models)
+    app.router.add_post("/env/delete_models", delete_models)
+
+    app.router.add_post("/env/reload", reload)
+    app.router.add_post("/env/set_account_info", set_account_info)
+
+async def reload(request):
+    '''
+    Reload the ComfyUI environment
+    '''
+
+    # reset embedded client
+    await request.app["pipeline"].cleanup()
+
+    # re-import nodes to clear previously imported custom nodes
+    global _comfy_nodes
+    _comfy_nodes = force_import_all_nodes_in_workspace()
+
+    # reload pipeline
+    request.app["pipeline"] = Pipeline(cwd=request.app["workspace"], disable_cuda_malloc=True, gpu_only=True)
+
+    # reset webrtc connections
+    pcs = request.app["pcs"]
+    coros = [pc.close() for pc in pcs]
+    await asyncio.gather(*coros)
+    pcs.clear()
+
+    return web.json_response({"success": True, "error": None})
+
+async def nodes(request):
+    '''
+    List all custom nodes in the workspace
+
+    # Example response:
+    {
+        "error": null,
+        "nodes": [
+            {
+                "name": "ComfyUI-Custom-Node",
+                "version": "0.0.1",
+                "url": "https://github.com/custom-node-maker/ComfyUI-Custom-Node",
+                "branch": "main",
+                "commit": "uasfg98",
+                "update_available": false
+            },
+            { ... },
+            { ... }
+        ]
+    }
+    '''
+    workspace_dir = request.app["workspace"]
+    try:
+        nodes = await list_nodes(workspace_dir)
+        return web.json_response({"error": None, "nodes": nodes})
+    except Exception as e:
+        return web.json_response({"error": str(e), "nodes": []}, status=500)
+
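For reference, a minimal client sketch of calling the listing endpoint above. The `localhost:8888` base URL is an assumption; `app.py` later also registers the same handlers under a `/live` prefix, so `/live/env/list_nodes` works too.

```python
# Minimal sketch: query the management API for installed custom nodes.
# Assumes the server is reachable at localhost:8888; adjust to your deployment.
import asyncio
from aiohttp import ClientSession

async def main():
    async with ClientSession("http://localhost:8888") as session:
        async with session.get("/env/list_nodes") as resp:
            print(await resp.json())

asyncio.run(main())
```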
+async def install_nodes(request):
+    '''
+    Install ComfyUI custom nodes from git repositories.
+
+    Installs requirements.txt from the repository if present
+
+    # Parameters:
+    url: url of the git repository
+    branch: branch of the git repository
+    dependencies: comma separated list of dependencies to install with pip (optional)
+
+    # Example request:
+    [
+        {
+            "url": "https://github.com/custom-node-maker/ComfyUI-Custom-Node",
+            "branch": "main"
+        },
+        {
+            "url": "https://github.com/custom-node-maker/ComfyUI-Custom-Node",
+            "branch": "main",
+            "dependencies": "requests, numpy"
+        }
+    ]
+    '''
+    workspace_dir = request.app["workspace"]
+    installed_nodes = []
+    try:
+        nodes = await request.json()
+        for node in nodes:
+            await install_node(node, workspace_dir)
+            installed_nodes.append(node['url'])
+
+        return web.json_response({"success": True, "error": None, "installed_nodes": installed_nodes})
+    except Exception as e:
+        return web.json_response({"success": False, "error": str(e), "installed_nodes": installed_nodes}, status=500)
+
+async def delete_nodes(request):
+    '''
+    Delete ComfyUI custom nodes
+
+    # Parameters:
+    name: name of the repository (e.g. ComfyUI-Custom-Node for url "https://github.com/custom-node-maker/ComfyUI-Custom-Node")
+
+    # Example request:
+    [
+        {
+            "name": "ComfyUI-Custom-Node"
+        },
+        { ... }
+    ]
+    '''
+    workspace_dir = request.app["workspace"]
+    deleted_nodes = []
+    try:
+        nodes = await request.json()
+        for node in nodes:
+            await delete_node(node, workspace_dir)
+            deleted_nodes.append(node['name'])
+        return web.json_response({"success": True, "error": None, "deleted_nodes": deleted_nodes})
+    except Exception as e:
+        return web.json_response({"success": False, "error": str(e), "deleted_nodes": deleted_nodes}, status=500)
+
+async def models(request):
+    '''
+    List all custom models in the workspace
+
+    # Example response:
+    {
+        "error": null,
+        "models": {
+            "checkpoints": [
+                {
+                    "name": "dreamshaper-8.safetensors",
+                    "path": "SD1.5/dreamshaper-8.safetensors",
+                    "type": "checkpoints",
+                    "downloading": false
+                }
+            ],
+            "controlnet": [
+                {
+                    "name": "controlnet.sd15.safetensors",
+                    "path": "SD1.5/controlnet.sd15.safetensors",
+                    "type": "controlnet",
+                    "downloading": false
+                }
+            ],
+            "unet": [
+                {
+                    "name": "unet.sd15.safetensors",
+                    "path": "SD1.5/unet.sd15.safetensors",
+                    "type": "unet",
+                    "downloading": false
+                }
+            ],
+            "vae": [
+                {
+                    "name": "vae.safetensors",
+                    "path": "vae.safetensors",
+                    "type": "vae",
+                    "downloading": false
+                }
+            ],
+            "tensorrt": [
+                {
+                    "name": "model.trt",
+                    "path": "model.trt",
+                    "type": "tensorrt",
+                    "downloading": false
+                }
+            ]
+        }
+    }
+    '''
+    workspace_dir = request.app["workspace"]
+    try:
+        models = await list_models(workspace_dir)
+        return web.json_response({"error": None, "models": models})
+    except Exception as e:
+        return web.json_response({"error": str(e), "models": {}}, status=500)
+
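A matching sketch for the install endpoint, using the docstring's placeholder repository and the same assumed base URL:

```python
# Sketch: install a custom node via the management API.
import asyncio
from aiohttp import ClientSession

async def main():
    payload = [{"url": "https://github.com/custom-node-maker/ComfyUI-Custom-Node",
                "branch": "main"}]
    async with ClientSession("http://localhost:8888") as session:
        async with session.post("/env/install_nodes", json=payload) as resp:
            print(await resp.json())

asyncio.run(main())
```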
+async def add_models(request):
+    '''
+    Download models from urls
+
+    # Parameters:
+    url: url of the model file
+    type: type of model (e.g. checkpoints, controlnet, unet, vae, onnx, tensorrt)
+    path: path of the model. supports up to 1 subfolder (e.g. SD1.5/newmodel.safetensors)
+
+    # Example request:
+    [
+        {
+            "url": "http://url.to/model.safetensors",
+            "type": "checkpoints"
+        },
+        {
+            "url": "http://url.to/controlnet.super.safetensors",
+            "type": "controlnet",
+            "path": "SD1.5/controlnet.super.safetensors"
+        }
+    ]
+    '''
+    workspace_dir = request.app["workspace"]
+    added_models = []
+    try:
+        models = await request.json()
+        for model in models:
+            await add_model(model, workspace_dir)
+            added_models.append(model['url'])
+        return web.json_response({"success": True, "error": None, "added_models": added_models})
+    except Exception as e:
+        return web.json_response({"success": False, "error": str(e), "added_models": added_models}, status=500)
+
+async def delete_models(request):
+    '''
+    Delete models
+
+    # Parameters:
+    type: type of model (e.g. checkpoints, controlnet, unet, vae, onnx, tensorrt)
+    path: path of the model. supports up to 1 subfolder (e.g. SD1.5/newmodel.safetensors)
+
+    # Example request:
+    [
+        {
+            "type": "checkpoints",
+            "path": "model.safetensors"
+        },
+        {
+            "type": "controlnet",
+            "path": "SD1.5/controlnet.super.safetensors"
+        }
+    ]
+    '''
+    workspace_dir = request.app["workspace"]
+    deleted_models = []
+    try:
+        models = await request.json()
+        for model in models:
+            await delete_model(model, workspace_dir)
+            deleted_models.append(model['path'])
+        return web.json_response({"success": True, "error": None, "deleted_models": deleted_models})
+    except Exception as e:
+        return web.json_response({"success": False, "error": str(e), "deleted_models": deleted_models}, status=500)
+
+async def set_account_info(request):
+    '''
+    Set account info for ice server providers
+
+    # Parameters:
+    type: account type (e.g. twilio)
+    account_id: account id from provider
+    auth_token: auth token from provider
+
+    # Example request:
+    [
+        {
+            "type": "twilio",
+            "account_id": "ACXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX",
+            "auth_token": "your_auth_token"
+        },
+        { ... }
+    ]
+    '''
+    accounts_updated = []
+    try:
+        accounts = await request.json()
+        for account in accounts:
+            if account.get('type') == 'twilio':
+                await set_twilio_account_info(account.get('account_id'), account.get('auth_token'))
+                accounts_updated.append(account['type'])
+        return web.json_response({"success": True, "error": None, "accounts_updated": accounts_updated})
+    except Exception as e:
+        return web.json_response({"success": False, "error": str(e), "accounts_updated": accounts_updated}, status=500)
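Note that `add_models` returns as soon as downloads are queued (they run as background tasks in `models.py`), so a client that needs the file has to poll `list_models` until the `downloading` flags clear. A sketch, with placeholder base and model URLs:

```python
# Sketch: queue a model download, then poll until it finishes.
import asyncio
from aiohttp import ClientSession

async def main():
    payload = [{"url": "http://url.to/model.safetensors", "type": "checkpoints"}]
    async with ClientSession("http://localhost:8888") as session:
        async with session.post("/env/add_models", json=payload) as resp:
            print(await resp.json())
        while True:
            async with session.get("/env/list_models") as resp:
                models = (await resp.json())["models"]
            if not any(m["downloading"] for group in models.values() for m in group):
                break
            await asyncio.sleep(2)

asyncio.run(main())
```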
diff --git a/server/api/models/__init__.py b/server/api/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/api/models/models.py b/server/api/models/models.py
new file mode 100644
index 00000000..447b9486
--- /dev/null
+++ b/server/api/models/models.py
@@ -0,0 +1,114 @@
+import asyncio
+from pathlib import Path
+import os
+import logging
+from aiohttp import ClientSession
+
+logger = logging.getLogger(__name__)
+
+async def list_models(workspace_dir):
+    models_path = Path(workspace_dir) / "models"
+    models_path.mkdir(parents=True, exist_ok=True)
+
+    model_types = ["checkpoints", "controlnet", "unet", "vae", "onnx", "tensorrt"]
+
+    models = {}
+    for model_type in model_types:
+        models[model_type] = []
+        model_path = models_path / model_type
+        if not model_path.exists():
+            continue
+        for model in model_path.iterdir():
+            if model.is_dir():
+                for submodel in model.iterdir():
+                    if submodel.is_file():
+                        model_info = {
+                            "name": submodel.name,
+                            "path": f"{model.name}/{submodel.name}",
+                            "type": model_type,
+                            # a "<file>.downloading" sentinel marks an in-progress download
+                            "downloading": os.path.exists(f"{submodel}.downloading")
+                        }
+                        models[model_type].append(model_info)
+
+            if model.is_file():
+                model_info = {
+                    "name": model.name,
+                    "path": f"{model.name}",
+                    "type": model_type,
+                    "downloading": os.path.exists(f"{model}.downloading")
+                }
+                models[model_type].append(model_info)
+
+    return models
+
+async def add_model(model, workspace_dir):
+    if 'url' not in model:
+        raise Exception("model url is required")
+    if 'type' not in model:
+        raise Exception("model type is required (e.g. checkpoints, controlnet, unet, vae, onnx, tensorrt)")
+
+    try:
+        model_name = model['url'].split("/")[-1]
+        model_path = Path(workspace_dir) / "models" / model['type'] / model_name
+        if 'path' in model:
+            model_path = Path(workspace_dir) / "models" / model['type'] / model['path']
+        logger.info(f"model path: {model_path}")
+
+        # check path is inside the workspace models dir; raises ValueError if not
+        model_path.resolve().relative_to(Path(workspace_dir) / "models")
+        os.makedirs(model_path.parent, exist_ok=True)
+        # start downloading the model in the background without blocking
+        asyncio.create_task(download_model(model['url'], model_path))
+    except Exception as e:
+        raise Exception(f"error adding model: {e}")
+
+async def delete_model(model, workspace_dir):
+    if 'type' not in model:
+        raise Exception("model type is required (e.g. checkpoints, controlnet, unet, vae, onnx, tensorrt)")
+    if 'path' not in model:
+        raise Exception("model path is required")
+    try:
+        model_path = Path(workspace_dir) / "models" / model['type'] / model['path']
+        # check path is inside the workspace models dir; raises ValueError if not
+        model_path.resolve().relative_to(Path(workspace_dir) / "models")
+
+        os.remove(model_path)
+    except Exception as e:
+        raise Exception(f"error deleting model: {e}")
+
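The `resolve().relative_to(...)` guard above is what blocks path traversal. A standalone sketch of the same check, with an illustrative root:

```python
# Sketch of the containment check used by add_model/delete_model:
# resolve() normalizes ".." segments, and relative_to() raises ValueError
# for any path that escapes the models directory.
from pathlib import Path

models_root = Path("/workspace/models")  # illustrative root

safe = (models_root / "checkpoints/SD1.5/model.safetensors").resolve()
safe.relative_to(models_root)  # fine

evil = (models_root / "checkpoints/../../../etc/passwd").resolve()
evil.relative_to(models_root)  # raises ValueError -> request is rejected
```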
+async def download_model(url: str, save_path: Path):
+    temp_file = save_path.with_suffix(save_path.suffix + ".downloading")
+    try:
+        async with ClientSession() as session:
+            logger.info(f"downloading model from {url} to {save_path}")
+            model_name = os.path.basename(save_path)
+
+            # create an empty sentinel file to mark the download as in progress
+            open(temp_file, "w").close()
+
+            async with session.get(url) as response:
+                if response.status == 200:
+                    total_size = int(response.headers.get('Content-Length', 0))
+                    total_downloaded = 0
+                    last_logged_percentage = -1  # ensures a log line on the first chunk
+                    with open(save_path, "wb") as f:
+                        while chunk := await response.content.read(4096):  # read in 4 KB chunks
+                            f.write(chunk)
+                            total_downloaded += len(chunk)
+                            # log only when the integer percentage increases
+                            if total_size > 0:
+                                percentage = (total_downloaded / total_size) * 100
+                                if int(percentage) > last_logged_percentage:
+                                    last_logged_percentage = int(percentage)
+                                    logger.info(f"Downloaded {total_downloaded} of {total_size} bytes ({percentage:.2f}%) of {model_name}")
+
+                    # remove the in-progress sentinel file
+                    os.remove(temp_file)
+                    logger.info(f"Model downloaded and saved to {save_path}")
+                else:
+                    raise Exception(f"Failed to download model. HTTP Status: {response.status}")
+    except Exception as e:
+        logger.error(f"error downloading model from {url}: {e}")
+        # remove the in-progress sentinel file
+        if temp_file.exists():
+            os.remove(temp_file)
\ No newline at end of file
diff --git a/server/api/nodes/__init__.py b/server/api/nodes/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/api/nodes/nodes.py b/server/api/nodes/nodes.py
new file mode 100644
index 00000000..8aab1b36
--- /dev/null
+++ b/server/api/nodes/nodes.py
@@ -0,0 +1,182 @@
+from pathlib import Path
+import os
+from git import Repo
+import logging
+import subprocess
+import shutil
+
+logger = logging.getLogger(__name__)
+
+async def list_nodes(workspace_dir):
+    custom_nodes_path = Path(workspace_dir) / "custom_nodes"
+    custom_nodes_path.mkdir(parents=True, exist_ok=True)
+
+    nodes = []
+    for node in custom_nodes_path.iterdir():
+        if node.name == "__pycache__":
+            continue
+
+        if node.is_dir():
+            logger.info(f"getting info for node: {node.name}")
+            node_info = {
+                "name": node.name,
+                "version": "unknown",
+                "url": "unknown",
+                "branch": "unknown",
+                "commit": "unknown",
+                "update_available": "unknown",
+            }
+
+            # include VERSION if set in file
+            version_file = node / "VERSION"
+            if version_file.exists():
+                with open(version_file) as f:
+                    node_info["version"] = f.readline().strip()
+
+            # include git info if available
+            try:
+                repo = Repo(node)
+                node_info["url"] = repo.remotes.origin.url.replace(".git", "")
+                node_info["commit"] = repo.head.commit.hexsha[:7]
+                if not repo.head.is_detached:
+                    node_info["branch"] = repo.active_branch.name
+                    fetch_info = repo.remotes.origin.fetch(repo.active_branch.name)
+                    node_info["update_available"] = repo.head.commit.hexsha[:7] != fetch_info[0].commit.hexsha[:7]
+                else:
+                    node_info["branch"] = "detached"
+
+            except Exception as e:
+                logger.info(f"error getting repo info for {node.name}: {e}")
+
+            nodes.append(node_info)
+
+    return nodes
+
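The `update_available` logic above reduces to a fetch plus a SHA comparison. A standalone sketch, with an illustrative checkout path:

```python
# Sketch of the per-node update check list_nodes performs: fetch the
# tracking branch and compare local vs. remote commit SHAs.
from git import Repo

repo = Repo("custom_nodes/ComfyUI-Custom-Node")  # illustrative checkout
if not repo.head.is_detached:
    branch = repo.active_branch.name
    fetched = repo.remotes.origin.fetch(branch)
    update_available = repo.head.commit.hexsha != fetched[0].commit.hexsha
    print(f"{branch}: update available = {update_available}")
```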
logger.info(f"error getting repo info for {node.name} {e}") + + nodes.append(node_info) + + return nodes + +async def install_node(node, workspace_dir): + ''' + install ComfyUI custom node in git repository. + + installs requirements.txt from repository if present + + # Paramaters + url: url of the git repository + branch: branch to install + dependencies: comma separated list of pip dependencies to install + ''' + + custom_nodes_path = Path(os.path.join(workspace_dir, "custom_nodes")) + custom_nodes_path.mkdir(parents=True, exist_ok=True) + os.chdir(custom_nodes_path) + node_url = node.get("url", "") + if node_url == "": + raise ValueError("url is required") + + if not ".git" in node_url: + node_url = f"{node_url}.git" + + try: + dir_name = node_url.split("/")[-1].replace(".git", "") + node_path = custom_nodes_path / dir_name + if not node_path.exists(): + # Clone and install the repository if it doesn't already exist + logger.info(f"installing {dir_name}...") + repo = Repo.clone_from(node["url"], node_path, depth=1) + if "branch" in node: + repo.git.checkout(node['branch']) + else: + # Update the repository if it already exists + logger.info(f"updating node {dir_name}") + repo = Repo(node_path) + repo.remotes.origin.fetch() + branch = node.get("branch", repo.remotes.origin.refs[0].remote_head) + + repo.remotes.origin.pull(branch) + + # Install requirements if present + requirements_file = node_path / "requirements.txt" + if requirements_file.exists(): + subprocess.run(["conda", "run", "-n", "comfystream", "pip", "install", "-r", str(requirements_file)], check=True) + + # Install additional dependencies if specified + if "dependencies" in node: + for dep in node["dependencies"].split(','): + subprocess.run(["conda", "run", "-n", "comfystream", "pip", "install", dep.strip()], check=True) + + except Exception as e: + logger.error(f"Error installing {dir_name} {e}") + raise e + +async def delete_node(node, workspace_dir): + custom_nodes_path = Path(os.path.join(workspace_dir, "custom_nodes")) + custom_nodes_path.mkdir(parents=True, exist_ok=True) + os.chdir(custom_nodes_path) + if "name" not in node: + raise ValueError("name is required") + + node_path = custom_nodes_path / node["name"] + if not node_path.exists(): + raise ValueError(f"node {node['name']} does not exist") + try: + #delete the folder and all its contents. 
+
+
+from comfy.nodes.package import ExportedNodes
+from comfy.nodes.package import _comfy_nodes, _import_and_enumerate_nodes_in_module
+from functools import reduce
+from importlib.metadata import entry_points
+import types
+
+def force_import_all_nodes_in_workspace(vanilla_custom_nodes=True, raise_on_failure=False) -> ExportedNodes:
+    # now actually import the nodes, to improve control of node loading order
+    from comfy_extras import nodes as comfy_extras_nodes  # pylint: disable=absolute-import-used
+    from comfy.cli_args import args
+    from comfy.nodes import base_nodes
+    from comfy.nodes.vanilla_node_importing import mitigated_import_of_vanilla_custom_nodes
+
+    # only load these nodes once
+    base_and_extra = reduce(lambda x, y: x.update(y),
+                            map(lambda module_inner: _import_and_enumerate_nodes_in_module(module_inner, raise_on_failure=raise_on_failure), [
+                                # this is the list of default nodes to import
+                                base_nodes,
+                                comfy_extras_nodes
+                            ]),
+                            ExportedNodes())
+    custom_nodes_mappings = ExportedNodes()
+
+    if args.disable_all_custom_nodes:
+        logger.info("Loading custom nodes was disabled, only base and extra nodes were loaded")
+        _comfy_nodes.update(base_and_extra)
+        return _comfy_nodes
+
+    # load from entry points
+    for entry_point in entry_points().select(group='comfyui.custom_nodes'):
+        # Load the module associated with the current entry point
+        try:
+            module = entry_point.load()
+        except ModuleNotFoundError as module_not_found_error:
+            logger.error(f"A module was not found while importing nodes via an entry point: {entry_point}. Please ensure the entry point in setup.py is named correctly", exc_info=module_not_found_error)
+            continue
+
+        # Ensure that what we've loaded is indeed a module
+        if isinstance(module, types.ModuleType):
+            custom_nodes_mappings.update(
+                _import_and_enumerate_nodes_in_module(module, print_import_times=True))
+
+    # load the vanilla custom nodes last
+    if vanilla_custom_nodes:
+        custom_nodes_mappings += mitigated_import_of_vanilla_custom_nodes()
+
+    # don't allow custom nodes to overwrite base nodes
+    custom_nodes_mappings -= base_and_extra
+
+    _comfy_nodes.update(base_and_extra + custom_nodes_mappings)
+
+    return _comfy_nodes
+
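For context, a package could hook into the `comfyui.custom_nodes` entry-point group scanned above roughly like this (package and module names are illustrative, not part of this PR):

```python
# Sketch: a setup.py exposing a node module through the entry-point group
# that force_import_all_nodes_in_workspace scans (names are illustrative).
from setuptools import setup

setup(
    name="my-comfy-nodes",
    version="0.0.1",
    packages=["my_comfy_nodes"],
    entry_points={
        "comfyui.custom_nodes": [
            "my_comfy_nodes = my_comfy_nodes",
        ],
    },
)
```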
diff --git a/server/api/settings/__init__.py b/server/api/settings/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/server/api/settings/settings.py b/server/api/settings/settings.py
new file mode 100644
index 00000000..253bd012
--- /dev/null
+++ b/server/api/settings/settings.py
@@ -0,0 +1,7 @@
+import os
+
+async def set_twilio_account_info(account_sid, auth_token):
+    if account_sid is not None:
+        os.environ["TWILIO_ACCOUNT_SID"] = account_sid
+    if auth_token is not None:
+        os.environ["TWILIO_AUTH_TOKEN"] = auth_token
diff --git a/server/app.py b/server/app.py
index 0d8ffa75..b4637822 100644
--- a/server/app.py
+++ b/server/app.py
@@ -427,6 +427,10 @@ async def on_shutdown(app: web.Application):
     )
     app.router.add_get("/metrics", app["metrics_manager"].metrics_handler)
 
+    # Add management API routes.
+    from api.api import add_routes
+    add_routes(app)
+
     # Add hosted platform route prefix.
     # NOTE: This ensures that the local and hosted experiences have consistent routes.
     add_prefix_to_app_routes(app, "/live")
@@ -435,4 +439,5 @@ def force_print(*args, **kwargs):
     print(*args, **kwargs, flush=True)
     sys.stdout.flush()
 
+
 web.run_app(app, host=args.host, port=int(args.port), print=force_print)
diff --git a/server/utils/utils.py b/server/utils/utils.py
index 63c61a58..996a5b8b 100644
--- a/server/utils/utils.py
+++ b/server/utils/utils.py
@@ -1,12 +1,21 @@
 """General utility functions."""
 
 import asyncio
+import json
 import random
 import types
 import logging
 from aiohttp import web
+import os
+from pathlib import Path
+import subprocess
+import sys
+
 from typing import List, Tuple
 
+from git import Repo
+
 logger = logging.getLogger(__name__)
@@ -63,3 +72,74 @@ def add_prefix_to_app_routes(app: web.Application, prefix: str):
     for route in list(app.router.routes()):
         new_path = prefix + route.resource.canonical
         app.router.add_route(route.method, new_path, route.handler)
+
+def list_nodes(workspace_dir):
+    custom_nodes_path = Path(workspace_dir) / "custom_nodes"
+    custom_nodes_path.mkdir(parents=True, exist_ok=True)
+
+    nodes = []
+    for node in custom_nodes_path.iterdir():
+        if node.name == "__pycache__":
+            continue
+
+        if node.is_dir():
+            print(f"checking custom_node: {node.name}")
+            repo = Repo(node)
+            fetch_info = repo.remotes.origin.fetch(repo.active_branch.name)
+
+            node_info = {
+                "name": node.name,
+                "url": repo.remotes.origin.url,
+                "branch": repo.active_branch.name,
+                "commit": repo.head.commit.hexsha[:7],
+                "update_available": repo.head.commit.hexsha != fetch_info[0].commit.hexsha,
+            }
+
+            try:
+                with open(node / "node_info.json") as f:
+                    node_info.update(json.load(f))
+            except FileNotFoundError:
+                pass
+
+            nodes.append(node_info)
+
+    return nodes
+
+
+def install_node(node, workspace_dir):
+    custom_nodes_path = Path(workspace_dir) / "custom_nodes"
+    custom_nodes_path.mkdir(parents=True, exist_ok=True)
+    os.chdir(custom_nodes_path)
+
+    try:
+        dir_name = node['url'].split("/")[-1].replace(".git", "")
+        node_path = custom_nodes_path / dir_name
+
+        print(f"Installing {dir_name}...")
+
+        # Clone the repository if it doesn't already exist
+        if not node_path.exists():
+            cmd = ["git", "clone", node['url']]
+            if 'branch' in node:
+                cmd.extend(["-b", node['branch']])
+            subprocess.run(cmd, check=True)
+        else:
+            print(f"{dir_name} already exists, skipping clone.")
+
+        # Checkout a specific commit if "branch" is actually a commit hash
+        if 'branch' in node and len(node['branch']) == 40:  # SHA-1 hash length
+            subprocess.run(["git", "-C", dir_name, "checkout", node['branch']], check=True)
+
+        # Install requirements if present
+        requirements_file = node_path / "requirements.txt"
+        if requirements_file.exists():
+            subprocess.run([sys.executable, "-m", "pip", "install", "-r", str(requirements_file)], check=True)
+
+        # Install additional dependencies if specified
+        if 'dependencies' in node:
+            for dep in node['dependencies'].split(','):
+                subprocess.run([sys.executable, "-m", "pip", "install", dep.strip()], check=True)
+
+        print(f"Installed {dir_name}")
+    except Exception as e:
+        print(f"Error installing {dir_name}: {e}")
+        raise e
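One note on the `len(node['branch']) == 40` heuristic in `install_node` above: any 40-character string passes. A stricter hedged variant also checks that the characters are hex digits:

```python
# Sketch: stricter version of the branch-vs-commit heuristic used in
# install_node; a full SHA-1 is exactly 40 hexadecimal characters.
import re

def looks_like_commit_sha(ref: str) -> bool:
    return re.fullmatch(r"[0-9a-fA-F]{40}", ref) is not None

assert looks_like_commit_sha("0123456789abcdef0123456789abcdef01234567")
assert not looks_like_commit_sha("main")
```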