Improve compilation of models, adding support for triton

This commit is contained in:
doctorpangloss 2024-11-01 10:40:58 -07:00
parent 02d186b0c6
commit 31eacb6ac9
15 changed files with 512 additions and 52 deletions

View File

@ -123,16 +123,16 @@ When using Windows, open the **Windows Powershell** app. Then observe you are at
```shell
pip install "comfyui[withtorch]@git+https://github.com/hiddenswitch/ComfyUI.git"
```
**Recommended**: Currently, `torch 2.4.1` is the last version that `xformers` is compatible with. On Windows, install it first, along with `xformers`, for maximum compatibility and the best performance without advanced techniques in ComfyUI:
**Recommended**: Currently, `torch 2.5.1` is the latest version that `xformers` is compatible with. On Windows, install it first, along with `xformers`, for maximum compatibility and the best performance in ComfyUI without requiring advanced techniques:
```shell
pip install torch==2.4.1+cu121 torchvision --index-url https://download.pytorch.org/whl/cu121
pip install --no-build-isolation --no-deps xformers==0.0.28.post1 --index-url https://download.pytorch.org/whl/
pip install torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install --no-build-isolation --no-deps xformers==0.0.28.post3 --index-url https://download.pytorch.org/whl/
pip install comfyui@git+https://github.com/hiddenswitch/ComfyUI.git
```
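To confirm the pinned versions resolved as expected, a quick sanity check (illustrative; adjust for your environment):
```shell
python -c "import torch, xformers; print(torch.__version__, xformers.__version__, torch.cuda.is_available())"
```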
To enable `torchaudio` support on Windows, install it directly:
```shell
pip install torchaudio==2.4.1+cu121 --index-url https://download.pytorch.org/whl/cu121
pip install torchaudio==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121
```
**Advanced**: If you are running in Google Colab or another environment that has already installed `torch` for you, or if you are an application developer:
```shell

View File

@ -24,6 +24,8 @@ from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..distributed.server_stub import ServerStub
from ..nodes.package import import_all_nodes_in_workspace
logger = logging.getLogger(__name__)
def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
from ..cmd.execution import PromptExecutor
@ -58,7 +60,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
current_time = time.perf_counter()
execution_time = current_time - execution_start_time
logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))
logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))
flags = q.get_flags()
free_memory = flags.get("free_memory", False)
@ -109,7 +111,7 @@ def cuda_malloc_warning():
if b in device_name:
cuda_malloc_warning = True
if cuda_malloc_warning:
logging.warning(
logger.warning(
"\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
@ -125,13 +127,13 @@ async def main(from_script_dir: Optional[Path] = None):
if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.debug(f"Setting temp directory to: {temp_dir}")
logger.debug(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir)
cleanup_temp()
if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
logger.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)
# configure extra model paths earlier
@ -193,7 +195,7 @@ async def main(from_script_dir: Optional[Path] = None):
worker_thread_server = server if not distributed else ServerStub()
if not distributed or args.distributed_queue_worker:
if distributed:
logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
logger.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
# todo: this should really be using an executor instead of doing things this jankilicious way
ctx = contextvars.copy_context()
threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start()
@ -203,7 +205,7 @@ async def main(from_script_dir: Optional[Path] = None):
if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.debug(f"Setting output directory to: {output_dir}")
logger.debug(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)
# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@ -215,7 +217,7 @@ async def main(from_script_dir: Optional[Path] = None):
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.debug(f"Setting input directory to: {input_dir}")
logger.debug(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)
if args.quick_test_for_ci:
@ -243,7 +245,7 @@ async def main(from_script_dir: Optional[Path] = None):
await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start)
except (asyncio.CancelledError, KeyboardInterrupt):
logging.debug("\nStopped server")
logger.debug("\nStopped server")
finally:
if distributed:
await q.close()
@ -254,7 +256,7 @@ def entrypoint():
try:
asyncio.run(main())
except KeyboardInterrupt:
logging.info(f"Gracefully shutting down due to KeyboardInterrupt")
logger.info(f"Gracefully shutting down due to KeyboardInterrupt")
if __name__ == "__main__":

View File

@ -112,7 +112,7 @@ def _create_tracer():
def _configure_logging():
logging_level = args.logging_level
logging.basicConfig(format="%(message)s", level=logging_level)
logging.basicConfig(level=logging_level)
_configure_logging()
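Dropping the `format="%(message)s"` override means records fall back to Python's default `basicConfig` format, which includes the logger name; combined with the module-level `logger = logging.getLogger(__name__)` pattern introduced throughout this commit, log lines now identify which module emitted them. A minimal standalone sketch (the logger name here is illustrative, not taken from the codebase):
```python
import logging

# With no explicit format, basicConfig uses "%(levelname)s:%(name)s:%(message)s",
# so the name passed to getLogger(__name__) appears in every record.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger("comfy.cmd.main")
logger.debug("Prompt executed in %.2f seconds", 1.23)
# DEBUG:comfy.cmd.main:Prompt executed in 1.23 seconds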

View File

@ -51,6 +51,7 @@ from ..model_filemanager import download_model, DownloadModelStatus
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
from ..nodes.package_typing import ExportedNodes
logger = logging.getLogger(__name__)
class HeuristicPath(NamedTuple):
filename_heuristic: str
@ -61,7 +62,7 @@ async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
logger.warning("send error: {}".format(err))
def get_comfyui_version():
@ -144,7 +145,7 @@ def create_origin_only_middleware():
if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
logger.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)
if request.method == "OPTIONS":
@ -227,7 +228,7 @@ class PromptServer(ExecutorToClientProgress):
await self.send("executing", {"node": self.last_node_id}, sid)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
logging.warning('ws connection closed with exception %s' % ws.exception())
logger.warning('ws connection closed with exception %s' % ws.exception())
finally:
self.sockets.pop(sid, None)
return ws
@ -573,8 +574,8 @@ class PromptServer(ExecutorToClientProgress):
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
logger.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logger.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
@ -638,7 +639,7 @@ class PromptServer(ExecutorToClientProgress):
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
logging.warning("invalid prompt: {}".format(valid[1]))
logger.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@ -708,7 +709,7 @@ class PromptServer(ExecutorToClientProgress):
session = self.client_session
if session is None:
logging.error("Client session is not initialized")
logger.error("Client session is not initialized")
return web.Response(status=500)
task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
@ -1026,6 +1027,9 @@ class PromptServer(ExecutorToClientProgress):
await self.start_multi_address([(address, port)], call_on_start=call_on_start, verbose=verbose)
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
address_print = "localhost"
port = 8188
address: str = None
runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900)
await runner.setup()
for addr in addresses:
@ -1038,14 +1042,15 @@ class PromptServer(ExecutorToClientProgress):
self.address = address # TODO: remove this
self.port = port
if ':' in address:
if address == '::':
address_print = "localhost"
elif ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address
if verbose:
logging.info("Starting server")
logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address_print == "0.0.0.0" else address, port))
logger.info(f"Server ready. To see the GUI go to: http://{address_print}:{port}")
if call_on_start is not None:
call_on_start(address, port)
@ -1057,8 +1062,8 @@ class PromptServer(ExecutorToClientProgress):
try:
json_data = handler(json_data)
except Exception as e:
logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
logging.warning(traceback.format_exc())
logger.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
logger.warning(traceback.format_exc())
return json_data

View File

@ -1,6 +1,6 @@
import torch
from .. import ops
from comfy import ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):

View File

@ -40,6 +40,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ops import Operations
class ModelType(Enum):
@ -106,18 +107,21 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device: torch.device = device
self.operations: Optional[Operations]
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.operations = operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
else:
self.operations = None
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

View File

@ -933,6 +933,8 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):
def device_supports_non_blocking(device):
if torch.jit.is_tracing() or torch.jit.is_scripting():
return True
if is_device_mps(device):
return False # pytorch bug? mps doesn't support non blocking
if is_intel_xpu():

View File

@ -101,7 +101,7 @@ class ModelManageable(Protocol):
return utils.get_attr(self.model, name)
@property
def model_options(self) -> dict:
def model_options(self) -> ModelOptions:
if not hasattr(self, "_model_options"):
setattr(self, "_model_options", {"transformer_options": {}})
return getattr(self, "_model_options")

View File

@ -131,14 +131,14 @@ def get_key_weight(model, key):
class ModelPatcher(ModelManageable):
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
self.model: torch.nn.Module | BaseModel = model
self.model: BaseModel | torch.nn.Module = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self._model_options = {"transformer_options": {}}
self._model_options: ModelOptions = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
@ -601,7 +601,11 @@ class ModelPatcher(ModelManageable):
return self.current_loaded_device()
def __str__(self):
info_str = f"{self.model_dtype()} {self.model_device} {naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)}"
if getattr(self.model, "operations", None) is not None:
operations_str = self.model.operations.__name__
else:
operations_str = None
info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"
if self.ckpt_name is not None:
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__} {info_str})>"
else:

View File

@ -15,7 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional
from typing import Optional, Type, Union
import torch
from torch import Tensor
@ -42,7 +42,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
device = input.device
bias = None
non_blocking = model_management.device_supports_non_blocking(device)
non_blocking = True if torch.jit.is_tracing() or torch.jit.is_scripting() else model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = s.bias_function is not None
bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
@ -358,8 +358,12 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)
class scaled_fp8_op_base(manual_cast):
pass
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
class scaled_fp8_op(manual_cast):
class scaled_fp8_op(scaled_fp8_op_base):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
@ -407,7 +411,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, inference_mode: Optional[bool] = None):
Operations = Type[Union[manual_cast, fp8_ops, disable_weight_init, skip_init, scaled_fp8_op_base]]
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8: Optional[torch.dtype] = None, inference_mode: Optional[bool] = None) -> Operations:
if inference_mode is None:
# todo: check a context here, since this isn't being used by any callers yet
inference_mode = current_execution_context().inference_mode
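The new module-level `Operations` alias documents that `pick_operations` returns an ops *class*, not an instance; model code then instantiates layers from whichever class was chosen. A hedged sketch of a caller (dtypes and layer sizes are placeholders):
```python
import torch
from comfy import ops

# pick_operations selects one of the ops families (disable_weight_init,
# manual_cast, fp8_ops, ...) based on the weight and compute dtypes.
operations: ops.Operations = ops.pick_operations(torch.float16, torch.float16)

# The chosen class is used as a drop-in replacement for torch.nn layers.
linear = operations.Linear(320, 1280, bias=True, device="cpu", dtype=torch.float16)
```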

View File

@ -640,7 +640,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.8
memory_usage_factor = 2.4
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

View File

@ -15,11 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional
import torch
from . import model_base
from . import utils
from . import latent_formats
from .ops import Operations
class ClipTarget:
def __init__(self, tokenizer, clip):
@ -36,8 +39,8 @@ class BASE:
required_keys = {}
clip_prefix = []
clip_vision_prefix = None
clip_prefix: list[str] = []
clip_vision_prefix: Optional[str] = None
noise_aug_config = None
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@ -47,9 +50,9 @@ class BASE:
memory_usage_factor = 2.0
manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
manual_cast_dtype: Optional[torch.dtype] = None
custom_operations: Optional[Operations] = None
scaled_fp8: Optional[torch.dtype] = None
optimizations = {"fp8": False}
@classmethod

View File

@ -11,6 +11,8 @@ from comfy import model_management
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes
logger = logging.getLogger(__name__)
DIFFUSION_MODEL = "diffusion_model"
TORCH_COMPILE_BACKENDS = [
"inductor",
@ -55,18 +57,18 @@ class TorchCompileModel(CustomNode):
"fullgraph": ("BOOLEAN", {"default": False}),
"dynamic": ("BOOLEAN", {"default": False}),
"backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
"mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"})
"mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"}),
"torch_tensorrt_optimization_level": ("INT", {"default": 3, "min": 1, "max": 5})
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
# INFERENCE_MODE = False
CATEGORY = "_for_testing"
EXPERIMENTAL = True
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune") -> tuple[ModelPatcher]:
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[ModelPatcher]:
if object_patch is None:
object_patch = DIFFUSION_MODEL
compile_kwargs = {
@ -75,19 +77,41 @@ class TorchCompileModel(CustomNode):
"backend": backend,
"mode": mode,
}
move_to_gpu = False
try:
if backend == "torch_tensorrt":
try:
import torch_tensorrt
except (ImportError, ModuleNotFoundError):
logger.error(f"Install torch-tensorrt and modelopt")
raise
compile_kwargs["options"] = {
# https://pytorch.org/TensorRT/dynamo/torch_compile.html
# Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
"enabled_precisions": {torch.float, torch.half}
"enabled_precisions": {torch.float, torch.half},
"optimization_level": torch_tensorrt_optimization_level,
"cache_built_engines": True,
"reuse_cached_engines": True,
"enable_weight_streaming": True,
"make_refittable": True,
}
move_to_gpu = True
del compile_kwargs["mode"]
if isinstance(model, ModelPatcher):
m = model.clone()
if move_to_gpu:
model_management.load_models_gpu([m])
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
if move_to_gpu:
model_management.unload_model_clones(m)
return (m,)
elif isinstance(model, torch.nn.Module):
return torch.compile(model=model, **compile_kwargs),
if move_to_gpu:
model.to(device=model_management.get_torch_device())
res = torch.compile(model=model, **compile_kwargs),
if move_to_gpu:
model.to(device=model_management.unet_offload_device())
return res
else:
logging.warning("Encountered a model that cannot be compiled")
return model,
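For the default `inductor` backend (which lowers CUDA kernels through Triton), the object patch the node installs boils down to a plain `torch.compile` call; a minimal hedged sketch outside the node (the module and names are placeholders):
```python
import torch

# Roughly what TorchCompileModel assembles with its default inputs; the real
# node fetches the module via model.get_model_object("diffusion_model") and
# stores the compiled result with add_object_patch.
def compile_diffusion_model(diffusion_model: torch.nn.Module) -> torch.nn.Module:
    return torch.compile(
        diffusion_model,
        fullgraph=False,
        dynamic=False,
        backend="inductor",
        mode="max-autotune",
    )
```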
@ -119,7 +143,6 @@ class QuantizeModel(CustomNode):
FUNCTION = "execute"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
# INFERENCE_MODE = False
RETURN_TYPES = ("MODEL",)
@ -152,10 +175,10 @@ class QuantizeModel(CustomNode):
quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
_in_place_fixme = unet
elif "torchao" in strategy:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass, autoquant # pylint: disable=import-error
model = model.clone()
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant # pylint: disable=import-error
from torchao.utils import unwrap_tensor_subclass
self.warn_in_place(model)
unet = model.get_model_object("diffusion_model")
model_management.load_models_gpu([model])
def filter(module: torch.nn.Module, fqn: str) -> bool:
return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
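The torchao branch presumably feeds this filter into `quantize_`; a hedged sketch of that API (the exclusion prefixes stand in for `always_exclude_these`, and the helper name is hypothetical):
```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Hypothetical exclusion prefixes standing in for always_exclude_these.
ALWAYS_EXCLUDE = ("time_embed", "label_emb")

def quantize_unet_int8(unet: torch.nn.Module) -> torch.nn.Module:
    def filter(module: torch.nn.Module, fqn: str) -> bool:
        # Only quantize Linear layers whose fully-qualified name is not excluded.
        return isinstance(module, torch.nn.Linear) and not any(p in fqn for p in ALWAYS_EXCLUDE)

    # quantize_ rewrites the matching Linear weights in place.
    quantize_(unet, int8_dynamic_activation_int8_weight(), filter_fn=filter)
    return unet
```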

View File

@ -0,0 +1,205 @@
{
"8": {
"inputs": {
"samples": [
"16",
0
],
"vae": [
"14",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "nike/nike_images_",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"13": {
"inputs": {
"clip_name1": "clip_l.safetensors",
"clip_name2": "t5xxl_fp16.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"14": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"15": {
"inputs": {
"text": "A photoreal image of a Nike Air Force 1 shoe in black, with a red Nike swoosh and red sole. The interior of the shoe is blue, and the laces are bright green.",
"clip": [
"13",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"16": {
"inputs": {
"noise": [
"17",
0
],
"guider": [
"18",
0
],
"sampler": [
"21",
0
],
"sigmas": [
"22",
0
],
"latent_image": [
"49",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"17": {
"inputs": {
"noise_seed": 1
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
},
"18": {
"inputs": {
"model": [
"51",
0
],
"conditioning": [
"19",
0
]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"19": {
"inputs": {
"guidance": 5,
"conditioning": [
"15",
0
]
},
"class_type": "FluxGuidance",
"_meta": {
"title": "FluxGuidance"
}
},
"21": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"22": {
"inputs": {
"scheduler": "normal",
"steps": 20,
"denoise": 1,
"model": [
"51",
0
]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"23": {
"inputs": {
"vae": [
"14",
0
]
},
"class_type": "VAEEncode",
"_meta": {
"title": "VAE Encode"
}
},
"49": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"51": {
"inputs": {
"object_patch": "diffusion_model",
"fullgraph": false,
"dynamic": false,
"backend": "inductor",
"mode": "reduce-overhead",
"torch_tensorrt_optimization_level": 3,
"model": [
"53",
0
]
},
"class_type": "TorchCompileModel",
"_meta": {
"title": "TorchCompileModel"
}
},
"53": {
"inputs": {
"unet_name": "flux1-dev.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
}
}

View File

@ -0,0 +1,205 @@
{
"8": {
"inputs": {
"samples": [
"16",
0
],
"vae": [
"14",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "nike/nike_images_",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
},
"13": {
"inputs": {
"clip_name1": "clip_l.safetensors",
"clip_name2": "t5xxl_fp16.safetensors",
"type": "flux"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"14": {
"inputs": {
"vae_name": "ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"15": {
"inputs": {
"text": "A photoreal image of a Nike Air Force 1 shoe in black, with a red Nike swoosh and red sole. The interior of the shoe is blue, and the laces are bright green.",
"clip": [
"13",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"16": {
"inputs": {
"noise": [
"17",
0
],
"guider": [
"18",
0
],
"sampler": [
"21",
0
],
"sigmas": [
"22",
0
],
"latent_image": [
"49",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"17": {
"inputs": {
"noise_seed": 2
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
},
"18": {
"inputs": {
"model": [
"51",
0
],
"conditioning": [
"19",
0
]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"19": {
"inputs": {
"guidance": 5,
"conditioning": [
"15",
0
]
},
"class_type": "FluxGuidance",
"_meta": {
"title": "FluxGuidance"
}
},
"21": {
"inputs": {
"sampler_name": "euler"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"22": {
"inputs": {
"scheduler": "normal",
"steps": 20,
"denoise": 1,
"model": [
"51",
0
]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"23": {
"inputs": {
"vae": [
"14",
0
]
},
"class_type": "VAEEncode",
"_meta": {
"title": "VAE Encode"
}
},
"49": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptySD3LatentImage",
"_meta": {
"title": "EmptySD3LatentImage"
}
},
"51": {
"inputs": {
"object_patch": "diffusion_model",
"fullgraph": false,
"dynamic": false,
"backend": "inductor",
"mode": "reduce-overhead",
"torch_tensorrt_optimization_level": 3,
"model": [
"53",
0
]
},
"class_type": "TorchCompileModel",
"_meta": {
"title": "TorchCompileModel"
}
},
"53": {
"inputs": {
"unet_name": "flux1-dev-fp8.safetensors",
"weight_dtype": "fp8_e4m3fn"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
}
}