From 31eacb6ac9a9aef32e513de3257fdadb82c81652 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Fri, 1 Nov 2024 10:40:58 -0700
Subject: [PATCH] Improve compilation of models, adding support for triton

---
 README.md                                     |   8 +-
 comfy/cmd/main.py                             |  20 +-
 comfy/cmd/main_pre.py                         |   2 +-
 comfy/cmd/server.py                           |  29 ++-
 comfy/ldm/common_dit.py                       |   2 +-
 comfy/model_base.py                           |   6 +-
 comfy/model_management.py                     |   2 +
 comfy/model_management_types.py               |   2 +-
 comfy/model_patcher.py                        |  12 +-
 comfy/ops.py                                  |  15 +-
 comfy/supported_models.py                     |   2 +-
 comfy/supported_models_base.py                |  13 +-
 comfy_extras/nodes/nodes_torch_compile.py     |  41 +++-
 tests/inference/workflows/flux-compile-0.json | 205 ++++++++++++++++++
 tests/inference/workflows/flux-compile-1.json | 205 ++++++++++++++++++
 15 files changed, 512 insertions(+), 52 deletions(-)
 create mode 100644 tests/inference/workflows/flux-compile-0.json
 create mode 100644 tests/inference/workflows/flux-compile-1.json

diff --git a/README.md b/README.md
index 4966f879f..12518d6db 100644
--- a/README.md
+++ b/README.md
@@ -123,16 +123,16 @@ When using Windows, open the **Windows Powershell** app. Then observe you are at
    ```shell
    pip install "comfyui[withtorch]@git+https://github.com/hiddenswitch/ComfyUI.git"
    ```
-   **Recommended**: Currently, `torch 2.4.1` is the last version that `xformers` is compatible with. On Windows, install it first, along with `xformers`, for maximum compatibility and the best performance without advanced techniques in ComfyUI:
+   **Recommended**: Currently, `torch 2.5.1` is the latest version that `xformers` is compatible with. On Windows, install it first, along with `xformers`, for maximum compatibility and the best performance without advanced techniques in ComfyUI:
    ```shell
-   pip install torch==2.4.1+cu121 torchvision --index-url https://download.pytorch.org/whl/cu121
-   pip install --no-build-isolation --no-deps xformers==0.0.28.post1 --index-url https://download.pytorch.org/whl/
+   pip install torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
+   pip install --no-build-isolation --no-deps xformers==0.0.28.post3 --index-url https://download.pytorch.org/whl/
    pip install comfyui@git+https://github.com/hiddenswitch/ComfyUI.git
    ```
    To enable `torchaudio` support on Windows, install it directly:
    ```shell
-   pip install torchaudio==2.4.1+cu121 --index-url https://download.pytorch.org/whl/cu121
+   pip install torchaudio==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121
    ```
    **Advanced**: If you are running in Google Collab or another environment which has already installed `torch` for you; or, if you are an application developer:
    ```shell
diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py
index 2e3fccf75..8bbcd6906 100644
--- a/comfy/cmd/main.py
+++ b/comfy/cmd/main.py
@@ -24,6 +24,8 @@ from ..distributed.distributed_prompt_queue import DistributedPromptQueue
 from ..distributed.server_stub import ServerStub
 from ..nodes.package import import_all_nodes_in_workspace
 
+logger = logging.getLogger(__name__)
+
 
 def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
     from ..cmd.execution import PromptExecutor
@@ -58,7 +60,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
-            logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))
+            logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))
 
         flags = q.get_flags()
         free_memory = flags.get("free_memory", False)
flags.get("free_memory", False) @@ -109,7 +111,7 @@ def cuda_malloc_warning(): if b in device_name: cuda_malloc_warning = True if cuda_malloc_warning: - logging.warning( + logger.warning( "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") @@ -125,13 +127,13 @@ async def main(from_script_dir: Optional[Path] = None): if args.temp_directory: temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp") - logging.debug(f"Setting temp directory to: {temp_dir}") + logger.debug(f"Setting temp directory to: {temp_dir}") folder_paths.set_temp_directory(temp_dir) cleanup_temp() if args.user_directory: user_dir = os.path.abspath(args.user_directory) - logging.info(f"Setting user directory to: {user_dir}") + logger.info(f"Setting user directory to: {user_dir}") folder_paths.set_user_directory(user_dir) # configure extra model paths earlier @@ -193,7 +195,7 @@ async def main(from_script_dir: Optional[Path] = None): worker_thread_server = server if not distributed else ServerStub() if not distributed or args.distributed_queue_worker: if distributed: - logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.") + logger.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.") # todo: this should really be using an executor instead of doing things this jankilicious way ctx = contextvars.copy_context() threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start() @@ -203,7 +205,7 @@ async def main(from_script_dir: Optional[Path] = None): if args.output_directory: output_dir = os.path.abspath(args.output_directory) - logging.debug(f"Setting output directory to: {output_dir}") + logger.debug(f"Setting output directory to: {output_dir}") folder_paths.set_output_directory(output_dir) # These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. 
@@ -215,7 +217,7 @@ async def main(from_script_dir: Optional[Path] = None):
 
     if args.input_directory:
         input_dir = os.path.abspath(args.input_directory)
-        logging.debug(f"Setting input directory to: {input_dir}")
+        logger.debug(f"Setting input directory to: {input_dir}")
         folder_paths.set_input_directory(input_dir)
 
     if args.quick_test_for_ci:
@@ -243,7 +245,7 @@ async def main(from_script_dir: Optional[Path] = None):
         await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
     except (asyncio.CancelledError, KeyboardInterrupt):
-        logging.debug("\nStopped server")
+        logger.debug("\nStopped server")
     finally:
         if distributed:
             await q.close()
@@ -254,7 +256,7 @@ def entrypoint():
     try:
         asyncio.run(main())
     except KeyboardInterrupt:
-        logging.info(f"Gracefully shutting down due to KeyboardInterrupt")
+        logger.info(f"Gracefully shutting down due to KeyboardInterrupt")
 
 
 if __name__ == "__main__":
diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py
index 3c7b2e552..9fba54260 100644
--- a/comfy/cmd/main_pre.py
+++ b/comfy/cmd/main_pre.py
@@ -112,7 +112,7 @@ def _create_tracer():
 
 def _configure_logging():
     logging_level = args.logging_level
-    logging.basicConfig(format="%(message)s", level=logging_level)
+    logging.basicConfig(level=logging_level)
 
 
 _configure_logging()
diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py
index 06f3f990b..c47228fce 100644
--- a/comfy/cmd/server.py
+++ b/comfy/cmd/server.py
@@ -51,6 +51,7 @@ from ..model_filemanager import download_model, DownloadModelStatus
 from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
 from ..nodes.package_typing import ExportedNodes
 
+logger = logging.getLogger(__name__)
 
 class HeuristicPath(NamedTuple):
     filename_heuristic: str
@@ -61,7 +62,7 @@ async def send_socket_catch_exception(function, message):
     try:
         await function(message)
     except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
-        logging.warning("send error: {}".format(err))
+        logger.warning("send error: {}".format(err))
 
 
 def get_comfyui_version():
@@ -144,7 +145,7 @@ def create_origin_only_middleware():
 
         if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
             if host_domain != origin_domain:
-                logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
+                logger.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
                 return web.Response(status=403)
 
         if request.method == "OPTIONS":
@@ -227,7 +228,7 @@ class PromptServer(ExecutorToClientProgress):
                 await self.send("executing", {"node": self.last_node_id}, sid)
             async for msg in ws:
                 if msg.type == aiohttp.WSMsgType.ERROR:
-                    logging.warning('ws connection closed with exception %s' % ws.exception())
+                    logger.warning('ws connection closed with exception %s' % ws.exception())
         finally:
             self.sockets.pop(sid, None)
         return ws
@@ -573,8 +574,8 @@ class PromptServer(ExecutorToClientProgress):
                     try:
                         out[x] = node_info(x)
                     except Exception as e:
-                        logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
-                        logging.error(traceback.format_exc())
+                        logger.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
+                        logger.error(traceback.format_exc())
                 return web.json_response(out)
@routes.get("/object_info/{node_class}") @@ -638,7 +639,7 @@ class PromptServer(ExecutorToClientProgress): response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) else: - logging.warning("invalid prompt: {}".format(valid[1])) + logger.warning("invalid prompt: {}".format(valid[1])) return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400) else: return web.json_response({"error": "no prompt", "node_errors": []}, status=400) @@ -708,7 +709,7 @@ class PromptServer(ExecutorToClientProgress): session = self.client_session if session is None: - logging.error("Client session is not initialized") + logger.error("Client session is not initialized") return web.Response(status=500) task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval)) @@ -1026,6 +1027,9 @@ class PromptServer(ExecutorToClientProgress): await self.start_multi_address([(address, port)], call_on_start=call_on_start, verbose=verbose) async def start_multi_address(self, addresses, call_on_start=None, verbose=True): + address_print = "localhost" + port = 8188 + address: str = None runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900) await runner.setup() for addr in addresses: @@ -1038,14 +1042,15 @@ class PromptServer(ExecutorToClientProgress): self.address = address # TODO: remove this self.port = port - if ':' in address: + if address == '::': + address_print = "localhost" + elif ':' in address: address_print = "[{}]".format(address) else: address_print = address if verbose: - logging.info("Starting server") - logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address_print == "0.0.0.0" else address, port)) + logger.info(f"Server ready. To see the GUI go to: http://{address_print}:{port}") if call_on_start is not None: call_on_start(address, port) @@ -1057,8 +1062,8 @@ class PromptServer(ExecutorToClientProgress): try: json_data = handler(json_data) except Exception as e: - logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing") - logging.warning(traceback.format_exc()) + logger.warning(f"[ERROR] An error occurred during the on_prompt_handler processing") + logger.warning(traceback.format_exc()) return json_data diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py index 3cd69fa73..5bc80d64c 100644 --- a/comfy/ldm/common_dit.py +++ b/comfy/ldm/common_dit.py @@ -1,6 +1,6 @@ import torch -from .. 
+from comfy import ops
 
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index a1e5eebe7..6e72522dd 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -40,6 +40,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
 from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
 from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
+from .ops import Operations
 
 
 class ModelType(Enum):
@@ -106,18 +107,21 @@ class BaseModel(torch.nn.Module):
         self.model_config = model_config
         self.manual_cast_dtype = model_config.manual_cast_dtype
         self.device: torch.device = device
-
+        self.operations: Optional[Operations]
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
                 operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
             else:
                 operations = model_config.custom_operations
+            self.operations = operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
             logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
+        else:
+            self.operations = None
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b20756637..c4e441b09 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -933,6 +933,8 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
 
 
 def device_supports_non_blocking(device):
+    if torch.jit.is_tracing() or torch.jit.is_scripting():
+        return True
     if is_device_mps(device):
         return False  # pytorch bug? mps doesn't support non blocking
     if is_intel_xpu():
diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py
index b86bafa9e..af804631e 100644
--- a/comfy/model_management_types.py
+++ b/comfy/model_management_types.py
@@ -101,7 +101,7 @@ class ModelManageable(Protocol):
         return utils.get_attr(self.model, name)
 
     @property
-    def model_options(self) -> dict:
+    def model_options(self) -> ModelOptions:
         if not hasattr(self, "_model_options"):
             setattr(self, "_model_options", {"transformer_options": {}})
         return getattr(self, "_model_options")
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index bdf818582..52a5b9264 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -131,14 +131,14 @@ def get_key_weight(model, key):
 
 
 class ModelPatcher(ModelManageable):
-    def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
+    def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
         self.size = size
-        self.model: torch.nn.Module | BaseModel = model
+        self.model: BaseModel | torch.nn.Module = model
         self.patches = {}
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
-        self._model_options = {"transformer_options": {}}
+        self._model_options: ModelOptions = {"transformer_options": {}}
         self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
@@ -601,7 +601,11 @@ class ModelPatcher(ModelManageable):
         return self.current_loaded_device()
 
     def __str__(self):
-        info_str = f"{self.model_dtype()} {self.model_device} {naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)}"
+        if getattr(self.model, "operations", None) is not None:
+            operations_str = self.model.operations.__name__
+        else:
+            operations_str = None
+        info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"
         if self.ckpt_name is not None:
             return f"<ModelPatcher for {self.ckpt_name} ({info_str})>"
         else:
diff --git a/comfy/ops.py b/comfy/ops.py
index 0ff1ed748..941c3f338 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -15,7 +15,7 @@ You should have received a copy of the GNU General Public License
     along with this program.  If not, see <https://www.gnu.org/licenses/>.
""" -from typing import Optional +from typing import Optional, Type, Union import torch from torch import Tensor @@ -42,7 +42,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): device = input.device bias = None - non_blocking = model_management.device_supports_non_blocking(device) + non_blocking = True if torch.jit.is_tracing() or torch.jit.is_scripting() else model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = s.bias_function is not None bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) @@ -358,8 +358,12 @@ class fp8_ops(manual_cast): return torch.nn.functional.linear(input, weight, bias) +class scaled_fp8_op_base(manual_cast): + pass + + def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - class scaled_fp8_op(manual_cast): + class scaled_fp8_op(scaled_fp8_op_base): class Linear(manual_cast.Linear): def __init__(self, *args, **kwargs): if override_dtype is not None: @@ -407,7 +411,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None return scaled_fp8_op -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, inference_mode: Optional[bool] = None): +Operations = Type[Union[manual_cast, fp8_ops, disable_weight_init, skip_init, scaled_fp8_op_base]] + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8: Optional[torch.dtype] = None, inference_mode: Optional[bool] = None) -> Operations: if inference_mode is None: # todo: check a context here, since this isn't being used by any callers yet inference_mode = current_execution_context().inference_mode diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c68f85b2a..9159fdcbf 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -640,7 +640,7 @@ class Flux(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.Flux - memory_usage_factor = 2.8 + memory_usage_factor = 2.4 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 54573abb1..e8aa64ad2 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -15,11 +15,14 @@ You should have received a copy of the GNU General Public License along with this program. If not, see . """ +from typing import Optional import torch from . import model_base from . import utils from . 
+from .ops import Operations
+
 
 class ClipTarget:
     def __init__(self, tokenizer, clip):
@@ -36,8 +39,8 @@ class BASE:
 
     required_keys = {}
 
-    clip_prefix = []
-    clip_vision_prefix = None
+    clip_prefix: list[str] = []
+    clip_vision_prefix: Optional[str] = None
     noise_aug_config = None
     sampling_settings = {}
     latent_format = latent_formats.LatentFormat
@@ -47,9 +50,9 @@ class BASE:
 
     memory_usage_factor = 2.0
 
-    manual_cast_dtype = None
-    custom_operations = None
-    scaled_fp8 = None
+    manual_cast_dtype: Optional[torch.dtype] = None
+    custom_operations: Optional[Operations] = None
+    scaled_fp8: Optional[torch.dtype] = None
     optimizations = {"fp8": False}
 
     @classmethod
diff --git a/comfy_extras/nodes/nodes_torch_compile.py b/comfy_extras/nodes/nodes_torch_compile.py
index fe0d683a2..d158c058c 100644
--- a/comfy_extras/nodes/nodes_torch_compile.py
+++ b/comfy_extras/nodes/nodes_torch_compile.py
@@ -11,6 +11,8 @@ from comfy import model_management
 from comfy.model_patcher import ModelPatcher
 from comfy.nodes.package_typing import CustomNode, InputTypes
 
+logger = logging.getLogger(__name__)
+
 DIFFUSION_MODEL = "diffusion_model"
 TORCH_COMPILE_BACKENDS = [
     "inductor",
@@ -55,18 +57,18 @@ class TorchCompileModel(CustomNode):
                 "fullgraph": ("BOOLEAN", {"default": False}),
                 "dynamic": ("BOOLEAN", {"default": False}),
                 "backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
-                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"})
+                "mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"}),
+                "torch_tensorrt_optimization_level": ("INT", {"default": 3, "min": 1, "max": 5})
             }
         }
 
     RETURN_TYPES = ("MODEL",)
     FUNCTION = "patch"
-    # INFERENCE_MODE = False
 
     CATEGORY = "_for_testing"
     EXPERIMENTAL = True
 
-    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune") -> tuple[ModelPatcher]:
+    def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[ModelPatcher]:
         if object_patch is None:
             object_patch = DIFFUSION_MODEL
         compile_kwargs = {
@@ -75,19 +77,41 @@ class TorchCompileModel(CustomNode):
             "backend": backend,
             "mode": mode,
         }
+        move_to_gpu = False
        try:
             if backend == "torch_tensorrt":
+                try:
+                    import torch_tensorrt
+                except (ImportError, ModuleNotFoundError):
+                    logger.error(f"Install torch-tensorrt and modelopt")
+                    raise
                 compile_kwargs["options"] = {
                     # https://pytorch.org/TensorRT/dynamo/torch_compile.html
                     # Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
- "enabled_precisions": {torch.float, torch.half} + "enabled_precisions": {torch.float, torch.half}, + "optimization_level": torch_tensorrt_optimization_level, + "cache_built_engines": True, + "reuse_cached_engines": True, + "enable_weight_streaming": True, + "make_refittable": True, } + move_to_gpu = True + del compile_kwargs["mode"] if isinstance(model, ModelPatcher): m = model.clone() + if move_to_gpu: + model_management.load_models_gpu([m]) m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs)) + if move_to_gpu: + model_management.unload_model_clones(m) return (m,) elif isinstance(model, torch.nn.Module): - return torch.compile(model=model, **compile_kwargs), + if move_to_gpu: + model.to(device=model_management.get_torch_device()) + res = torch.compile(model=model, **compile_kwargs), + if move_to_gpu: + model.to(device=model_management.unet_offload_device()) + return res else: logging.warning("Encountered a model that cannot be compiled") return model, @@ -119,7 +143,6 @@ class QuantizeModel(CustomNode): FUNCTION = "execute" CATEGORY = "_for_testing" EXPERIMENTAL = True - # INFERENCE_MODE = False RETURN_TYPES = ("MODEL",) @@ -152,10 +175,10 @@ class QuantizeModel(CustomNode): quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list) _in_place_fixme = unet elif "torchao" in strategy: - from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass, autoquant # pylint: disable=import-error - model = model.clone() + from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant # pylint: disable=import-error + from torchao.utils import unwrap_tensor_subclass self.warn_in_place(model) - unet = model.get_model_object("diffusion_model") + model_management.load_models_gpu([model]) def filter(module: torch.nn.Module, fqn: str) -> bool: return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these) diff --git a/tests/inference/workflows/flux-compile-0.json b/tests/inference/workflows/flux-compile-0.json new file mode 100644 index 000000000..7295e338b --- /dev/null +++ b/tests/inference/workflows/flux-compile-0.json @@ -0,0 +1,205 @@ +{ + "8": { + "inputs": { + "samples": [ + "16", + 0 + ], + "vae": [ + "14", + 0 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "9": { + "inputs": { + "filename_prefix": "nike/nike_images_", + "images": [ + "8", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + }, + "13": { + "inputs": { + "clip_name1": "clip_l.safetensors", + "clip_name2": "t5xxl_fp16.safetensors", + "type": "flux" + }, + "class_type": "DualCLIPLoader", + "_meta": { + "title": "DualCLIPLoader" + } + }, + "14": { + "inputs": { + "vae_name": "ae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "15": { + "inputs": { + "text": "A photoreal image of a Nike Air Force 1 shoe in black, with a red Nike swoosh and red sole. 
+      "clip": [
+        "13",
+        0
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "16": {
+    "inputs": {
+      "noise": [
+        "17",
+        0
+      ],
+      "guider": [
+        "18",
+        0
+      ],
+      "sampler": [
+        "21",
+        0
+      ],
+      "sigmas": [
+        "22",
+        0
+      ],
+      "latent_image": [
+        "49",
+        0
+      ]
+    },
+    "class_type": "SamplerCustomAdvanced",
+    "_meta": {
+      "title": "SamplerCustomAdvanced"
+    }
+  },
+  "17": {
+    "inputs": {
+      "noise_seed": 1
+    },
+    "class_type": "RandomNoise",
+    "_meta": {
+      "title": "RandomNoise"
+    }
+  },
+  "18": {
+    "inputs": {
+      "model": [
+        "51",
+        0
+      ],
+      "conditioning": [
+        "19",
+        0
+      ]
+    },
+    "class_type": "BasicGuider",
+    "_meta": {
+      "title": "BasicGuider"
+    }
+  },
+  "19": {
+    "inputs": {
+      "guidance": 5,
+      "conditioning": [
+        "15",
+        0
+      ]
+    },
+    "class_type": "FluxGuidance",
+    "_meta": {
+      "title": "FluxGuidance"
+    }
+  },
+  "21": {
+    "inputs": {
+      "sampler_name": "euler"
+    },
+    "class_type": "KSamplerSelect",
+    "_meta": {
+      "title": "KSamplerSelect"
+    }
+  },
+  "22": {
+    "inputs": {
+      "scheduler": "normal",
+      "steps": 20,
+      "denoise": 1,
+      "model": [
+        "51",
+        0
+      ]
+    },
+    "class_type": "BasicScheduler",
+    "_meta": {
+      "title": "BasicScheduler"
+    }
+  },
+  "23": {
+    "inputs": {
+      "vae": [
+        "14",
+        0
+      ]
+    },
+    "class_type": "VAEEncode",
+    "_meta": {
+      "title": "VAE Encode"
+    }
+  },
+  "49": {
+    "inputs": {
+      "width": 1024,
+      "height": 1024,
+      "batch_size": 1
+    },
+    "class_type": "EmptySD3LatentImage",
+    "_meta": {
+      "title": "EmptySD3LatentImage"
+    }
+  },
+  "51": {
+    "inputs": {
+      "object_patch": "diffusion_model",
+      "fullgraph": false,
+      "dynamic": false,
+      "backend": "inductor",
+      "mode": "reduce-overhead",
+      "torch_tensorrt_optimization_level": 3,
+      "model": [
+        "53",
+        0
+      ]
+    },
+    "class_type": "TorchCompileModel",
+    "_meta": {
+      "title": "TorchCompileModel"
+    }
+  },
+  "53": {
+    "inputs": {
+      "unet_name": "flux1-dev.safetensors",
+      "weight_dtype": "default"
+    },
+    "class_type": "UNETLoader",
+    "_meta": {
+      "title": "Load Diffusion Model"
+    }
+  }
+}
\ No newline at end of file
diff --git a/tests/inference/workflows/flux-compile-1.json b/tests/inference/workflows/flux-compile-1.json
new file mode 100644
index 000000000..1a2cc97cf
--- /dev/null
+++ b/tests/inference/workflows/flux-compile-1.json
@@ -0,0 +1,205 @@
+{
+  "8": {
+    "inputs": {
+      "samples": [
+        "16",
+        0
+      ],
+      "vae": [
+        "14",
+        0
+      ]
+    },
+    "class_type": "VAEDecode",
+    "_meta": {
+      "title": "VAE Decode"
+    }
+  },
+  "9": {
+    "inputs": {
+      "filename_prefix": "nike/nike_images_",
+      "images": [
+        "8",
+        0
+      ]
+    },
+    "class_type": "SaveImage",
+    "_meta": {
+      "title": "Save Image"
+    }
+  },
+  "13": {
+    "inputs": {
+      "clip_name1": "clip_l.safetensors",
+      "clip_name2": "t5xxl_fp16.safetensors",
+      "type": "flux"
+    },
+    "class_type": "DualCLIPLoader",
+    "_meta": {
+      "title": "DualCLIPLoader"
+    }
+  },
+  "14": {
+    "inputs": {
+      "vae_name": "ae.safetensors"
+    },
+    "class_type": "VAELoader",
+    "_meta": {
+      "title": "Load VAE"
+    }
+  },
+  "15": {
+    "inputs": {
+      "text": "A photoreal image of a Nike Air Force 1 shoe in black, with a red Nike swoosh and red sole. The interior of the shoe is blue, and the laces are bright green.",
+      "clip": [
+        "13",
+        0
+      ]
+    },
+    "class_type": "CLIPTextEncode",
+    "_meta": {
+      "title": "CLIP Text Encode (Prompt)"
+    }
+  },
+  "16": {
+    "inputs": {
+      "noise": [
+        "17",
+        0
+      ],
+      "guider": [
+        "18",
+        0
+      ],
+      "sampler": [
+        "21",
+        0
+      ],
+      "sigmas": [
+        "22",
+        0
+      ],
+      "latent_image": [
+        "49",
+        0
+      ]
+    },
+    "class_type": "SamplerCustomAdvanced",
+    "_meta": {
+      "title": "SamplerCustomAdvanced"
+    }
+  },
+  "17": {
+    "inputs": {
+      "noise_seed": 2
+    },
+    "class_type": "RandomNoise",
+    "_meta": {
+      "title": "RandomNoise"
+    }
+  },
+  "18": {
+    "inputs": {
+      "model": [
+        "51",
+        0
+      ],
+      "conditioning": [
+        "19",
+        0
+      ]
+    },
+    "class_type": "BasicGuider",
+    "_meta": {
+      "title": "BasicGuider"
+    }
+  },
+  "19": {
+    "inputs": {
+      "guidance": 5,
+      "conditioning": [
+        "15",
+        0
+      ]
+    },
+    "class_type": "FluxGuidance",
+    "_meta": {
+      "title": "FluxGuidance"
+    }
+  },
+  "21": {
+    "inputs": {
+      "sampler_name": "euler"
+    },
+    "class_type": "KSamplerSelect",
+    "_meta": {
+      "title": "KSamplerSelect"
+    }
+  },
+  "22": {
+    "inputs": {
+      "scheduler": "normal",
+      "steps": 20,
+      "denoise": 1,
+      "model": [
+        "51",
+        0
+      ]
+    },
+    "class_type": "BasicScheduler",
+    "_meta": {
+      "title": "BasicScheduler"
+    }
+  },
+  "23": {
+    "inputs": {
+      "vae": [
+        "14",
+        0
+      ]
+    },
+    "class_type": "VAEEncode",
+    "_meta": {
+      "title": "VAE Encode"
+    }
+  },
+  "49": {
+    "inputs": {
+      "width": 1024,
+      "height": 1024,
+      "batch_size": 1
+    },
+    "class_type": "EmptySD3LatentImage",
+    "_meta": {
+      "title": "EmptySD3LatentImage"
+    }
+  },
+  "51": {
+    "inputs": {
+      "object_patch": "diffusion_model",
+      "fullgraph": false,
+      "dynamic": false,
+      "backend": "inductor",
+      "mode": "reduce-overhead",
+      "torch_tensorrt_optimization_level": 3,
+      "model": [
+        "53",
+        0
+      ]
+    },
+    "class_type": "TorchCompileModel",
+    "_meta": {
+      "title": "TorchCompileModel"
+    }
+  },
+  "53": {
+    "inputs": {
+      "unet_name": "flux1-dev-fp8.safetensors",
+      "weight_dtype": "fp8_e4m3fn"
+    },
+    "class_type": "UNETLoader",
+    "_meta": {
+      "title": "Load Diffusion Model"
+    }
+  }
+}
\ No newline at end of file