Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 14:50:49 +08:00)

Improve compilation of models, adding support for triton

This commit is contained in:
parent 02d186b0c6
commit 31eacb6ac9
@@ -123,16 +123,16 @@ When using Windows, open the **Windows Powershell** app. Then observe you are at
```shell
pip install "comfyui[withtorch]@git+https://github.com/hiddenswitch/ComfyUI.git"
```
**Recommended**: Currently, `torch 2.4.1` is the last version that `xformers` is compatible with. On Windows, install it first, along with `xformers`, for maximum compatibility and the best performance without advanced techniques in ComfyUI:
**Recommended**: Currently, `torch 2.5.1` is the latest version that `xformers` (0.0.28.post3) is compatible with. On Windows, install it first, along with `xformers`, for maximum compatibility and the best performance without advanced techniques in ComfyUI:
```shell
pip install torch==2.4.1+cu121 torchvision --index-url https://download.pytorch.org/whl/cu121
pip install --no-build-isolation --no-deps xformers==0.0.28.post1 --index-url https://download.pytorch.org/whl/
pip install torch==2.5.1+cu121 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install --no-build-isolation --no-deps xformers==0.0.28.post3 --index-url https://download.pytorch.org/whl/
pip install comfyui@git+https://github.com/hiddenswitch/ComfyUI.git
```

To enable `torchaudio` support on Windows, install it directly:
```shell
pip install torchaudio==2.4.1+cu121 --index-url https://download.pytorch.org/whl/cu121
pip install torchaudio==2.5.1+cu121 --index-url https://download.pytorch.org/whl/cu121
```
**Advanced**: If you are running in Google Colab or another environment which has already installed `torch` for you; or, if you are an application developer:
```shell
```
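Once either set of commands has run, a quick check that the CUDA build of `torch`, `xformers`, and `triton` (the kernel compiler used by `torch.compile`'s inductor backend on CUDA) all import cleanly can save a debugging round-trip. This is a hedged sketch of mine, not part of the README diff:

```python
import torch

print(torch.__version__, torch.cuda.is_available())

try:
    import xformers  # optional attention backend
    print("xformers", xformers.__version__)
except ImportError:
    print("xformers not installed")

try:
    import triton  # used by torch.compile's inductor backend on CUDA
    print("triton", triton.__version__)
except ImportError:
    print("triton not installed; torch.compile on CUDA will fall back or fail")
```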
@@ -24,6 +24,8 @@ from ..distributed.distributed_prompt_queue import DistributedPromptQueue
from ..distributed.server_stub import ServerStub
from ..nodes.package import import_all_nodes_in_workspace

logger = logging.getLogger(__name__)


def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
from ..cmd.execution import PromptExecutor
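The recurring change in the hunks that follow is swapping module-level `logging.*` calls for a per-module `logger` obtained from `logging.getLogger(__name__)`. A minimal sketch of the pattern, with an illustrative function name that is not from the diff:

```python
import logging

# One logger per module: records carry the importing module's dotted path,
# so they can be filtered or re-leveled per module without touching the root logger.
logger = logging.getLogger(__name__)


def do_work():
    logger.debug("starting work")  # routed through this module's named logger
    logging.debug("also works, but goes through the root logger instead")
```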
@@ -58,7 +60,7 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):

current_time = time.perf_counter()
execution_time = current_time - execution_start_time
logging.debug("Prompt executed in {:.2f} seconds".format(execution_time))
logger.debug("Prompt executed in {:.2f} seconds".format(execution_time))

flags = q.get_flags()
free_memory = flags.get("free_memory", False)
@@ -109,7 +111,7 @@ def cuda_malloc_warning():
if b in device_name:
cuda_malloc_warning = True
if cuda_malloc_warning:
logging.warning(
logger.warning(
"\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")

@@ -125,13 +127,13 @@ async def main(from_script_dir: Optional[Path] = None):

if args.temp_directory:
temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
logging.debug(f"Setting temp directory to: {temp_dir}")
logger.debug(f"Setting temp directory to: {temp_dir}")
folder_paths.set_temp_directory(temp_dir)
cleanup_temp()

if args.user_directory:
user_dir = os.path.abspath(args.user_directory)
logging.info(f"Setting user directory to: {user_dir}")
logger.info(f"Setting user directory to: {user_dir}")
folder_paths.set_user_directory(user_dir)

# configure extra model paths earlier
@@ -193,7 +195,7 @@ async def main(from_script_dir: Optional[Path] = None):
worker_thread_server = server if not distributed else ServerStub()
if not distributed or args.distributed_queue_worker:
if distributed:
logging.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
logger.warning(f"Distributed workers started in the default thread loop cannot notify clients of progress updates. Instead of comfyui or main.py, use comfyui-worker.")
# todo: this should really be using an executor instead of doing things this jankilicious way
ctx = contextvars.copy_context()
threading.Thread(target=lambda _q, _worker_thread_server: ctx.run(prompt_worker, _q, _worker_thread_server), daemon=True, args=(q, worker_thread_server,)).start()
@@ -203,7 +205,7 @@ async def main(from_script_dir: Optional[Path] = None):

if args.output_directory:
output_dir = os.path.abspath(args.output_directory)
logging.debug(f"Setting output directory to: {output_dir}")
logger.debug(f"Setting output directory to: {output_dir}")
folder_paths.set_output_directory(output_dir)

# These are the default folders that checkpoints, clip and vae models will be saved to when using CheckpointSave, etc.. nodes
@@ -215,7 +217,7 @@ async def main(from_script_dir: Optional[Path] = None):

if args.input_directory:
input_dir = os.path.abspath(args.input_directory)
logging.debug(f"Setting input directory to: {input_dir}")
logger.debug(f"Setting input directory to: {input_dir}")
folder_paths.set_input_directory(input_dir)

if args.quick_test_for_ci:
@@ -243,7 +245,7 @@ async def main(from_script_dir: Optional[Path] = None):
await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,
call_on_start=call_on_start)
except (asyncio.CancelledError, KeyboardInterrupt):
logging.debug("\nStopped server")
logger.debug("\nStopped server")
finally:
if distributed:
await q.close()
@@ -254,7 +256,7 @@ def entrypoint():
try:
asyncio.run(main())
except KeyboardInterrupt:
logging.info(f"Gracefully shutting down due to KeyboardInterrupt")
logger.info(f"Gracefully shutting down due to KeyboardInterrupt")


if __name__ == "__main__":

@@ -112,7 +112,7 @@ def _create_tracer():

def _configure_logging():
logging_level = args.logging_level
logging.basicConfig(format="%(message)s", level=logging_level)
logging.basicConfig(level=logging_level)


_configure_logging()

@@ -51,6 +51,7 @@ from ..model_filemanager import download_model, DownloadModelStatus
from ..model_management import get_torch_device, get_torch_device_name, get_total_memory, get_free_memory, torch_version
from ..nodes.package_typing import ExportedNodes

logger = logging.getLogger(__name__)

class HeuristicPath(NamedTuple):
filename_heuristic: str
@@ -61,7 +62,7 @@ async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
logger.warning("send error: {}".format(err))


def get_comfyui_version():
@@ -144,7 +145,7 @@ def create_origin_only_middleware():

if loopback and host_domain is not None and origin_domain is not None and len(host_domain) > 0 and len(origin_domain) > 0:
if host_domain != origin_domain:
logging.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
logger.warning("WARNING: request with non matching host and origin {} != {}, returning 403".format(host_domain, origin_domain))
return web.Response(status=403)

if request.method == "OPTIONS":
@@ -227,7 +228,7 @@ class PromptServer(ExecutorToClientProgress):
await self.send("executing", {"node": self.last_node_id}, sid)
async for msg in ws:
if msg.type == aiohttp.WSMsgType.ERROR:
logging.warning('ws connection closed with exception %s' % ws.exception())
logger.warning('ws connection closed with exception %s' % ws.exception())
finally:
self.sockets.pop(sid, None)
return ws
@@ -573,8 +574,8 @@ class PromptServer(ExecutorToClientProgress):
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
logger.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logger.error(traceback.format_exc())
return web.json_response(out)

@routes.get("/object_info/{node_class}")
@@ -638,7 +639,7 @@ class PromptServer(ExecutorToClientProgress):
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
return web.json_response(response)
else:
logging.warning("invalid prompt: {}".format(valid[1]))
logger.warning("invalid prompt: {}".format(valid[1]))
return web.json_response({"error": valid[1], "node_errors": valid[3]}, status=400)
else:
return web.json_response({"error": "no prompt", "node_errors": []}, status=400)
@@ -708,7 +709,7 @@ class PromptServer(ExecutorToClientProgress):

session = self.client_session
if session is None:
logging.error("Client session is not initialized")
logger.error("Client session is not initialized")
return web.Response(status=500)

task = asyncio.create_task(download_model(lambda url: session.get(url), model_filename, url, model_directory, folder_path, report_progress, progress_interval))
@@ -1026,6 +1027,9 @@ class PromptServer(ExecutorToClientProgress):
await self.start_multi_address([(address, port)], call_on_start=call_on_start, verbose=verbose)

async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
address_print = "localhost"
port = 8188
address: str = None
runner = web.AppRunner(self.app, access_log=None, keepalive_timeout=900)
await runner.setup()
for addr in addresses:
@@ -1038,14 +1042,15 @@ class PromptServer(ExecutorToClientProgress):
self.address = address # TODO: remove this
self.port = port

if ':' in address:
if address == '::':
address_print = "localhost"
elif ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address

if verbose:
logging.info("Starting server")
logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address_print == "0.0.0.0" else address, port))
logger.info(f"Server ready. To see the GUI go to: http://{address_print}:{port}")
if call_on_start is not None:
call_on_start(address, port)
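The revised startup banner derives a printable host from the bind address: the IPv6 wildcard `::` is shown as `localhost`, other IPv6 literals are bracketed so the printed URL stays valid, and IPv4 addresses or hostnames pass through unchanged. A small standalone sketch of that mapping (the function name is mine, not from the diff):

```python
def printable_host(address: str) -> str:
    # "::" binds all interfaces; print a URL people can actually open.
    if address == "::":
        return "localhost"
    # Bare IPv6 literals must be bracketed inside a URL authority.
    if ":" in address:
        return "[{}]".format(address)
    return address


assert printable_host("::") == "localhost"
assert printable_host("fe80::1") == "[fe80::1]"
assert printable_host("0.0.0.0") == "0.0.0.0"
```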
@@ -1057,8 +1062,8 @@ class PromptServer(ExecutorToClientProgress):
try:
json_data = handler(json_data)
except Exception as e:
logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
logging.warning(traceback.format_exc())
logger.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
logger.warning(traceback.format_exc())

return json_data

@@ -1,6 +1,6 @@
import torch

from .. import ops
from comfy import ops


def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):

@@ -40,6 +40,7 @@ from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ops import Operations


class ModelType(Enum):
@@ -106,18 +107,21 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device: torch.device = device

self.operations: Optional[Operations]
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
self.operations = operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
if model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
else:
self.operations = None
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)

@@ -933,6 +933,8 @@ def pick_weight_dtype(dtype, fallback_dtype, device=None):


def device_supports_non_blocking(device):
if torch.jit.is_tracing() or torch.jit.is_scripting():
return True
if is_device_mps(device):
return False # pytorch bug? mps doesn't support non blocking
if is_intel_xpu():
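Both this helper and `cast_bias_weight` further down now short-circuit when TorchScript tracing or scripting is active, presumably so the `non_blocking` flag becomes a plain constant instead of a value derived from runtime device introspection while a graph is being captured. A minimal, illustrative sketch of that guard pattern (not the project's code):

```python
import torch


def pick_non_blocking(device: torch.device) -> bool:
    # Under torch.jit tracing/scripting, skip device introspection and return a
    # constant so the captured graph does not depend on it.
    if torch.jit.is_tracing() or torch.jit.is_scripting():
        return True
    # Eager mode: be conservative on backends where non_blocking copies misbehave.
    if device.type == "mps":
        return False
    return True


x = torch.ones(4)
y = x.to("cpu", non_blocking=pick_non_blocking(torch.device("cpu")))
```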
@@ -101,7 +101,7 @@ class ModelManageable(Protocol):
return utils.get_attr(self.model, name)

@property
def model_options(self) -> dict:
def model_options(self) -> ModelOptions:
if not hasattr(self, "_model_options"):
setattr(self, "_model_options", {"transformer_options": {}})
return getattr(self, "_model_options")

@@ -131,14 +131,14 @@ def get_key_weight(model, key):


class ModelPatcher(ModelManageable):
def __init__(self, model: torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
def __init__(self, model: BaseModel | torch.nn.Module, load_device: torch.device, offload_device: torch.device, size=0, weight_inplace_update=False, ckpt_name: Optional[str] = None):
self.size = size
self.model: torch.nn.Module | BaseModel = model
self.model: BaseModel | torch.nn.Module = model
self.patches = {}
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
self._model_options = {"transformer_options": {}}
self._model_options: ModelOptions = {"transformer_options": {}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
@@ -601,7 +601,11 @@ class ModelPatcher(ModelManageable):
return self.current_loaded_device()

def __str__(self):
info_str = f"{self.model_dtype()} {self.model_device} {naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)}"
if hasattr(self.model, "operations"):
operations_str = self.model.operations.__name__
else:
operations_str = None
info_str = f"model_dtype={self.model_dtype()} device={self.model_device} size={naturalsize(self._memory_measurements.model_loaded_weight_memory, binary=True)} operations={operations_str}"
if self.ckpt_name is not None:
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__} {info_str})>"
else:

comfy/ops.py (15 lines changed)
@@ -15,7 +15,7 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional
from typing import Optional, Type, Union

import torch
from torch import Tensor
@@ -42,7 +42,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
device = input.device

bias = None
non_blocking = model_management.device_supports_non_blocking(device)
non_blocking = True if torch.jit.is_tracing() or torch.jit.is_scripting() else model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = s.bias_function is not None
bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
@@ -358,8 +358,12 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)


class scaled_fp8_op_base(manual_cast):
pass


def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
class scaled_fp8_op(manual_cast):
class scaled_fp8_op(scaled_fp8_op_base):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
@@ -407,7 +411,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op


def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, inference_mode: Optional[bool] = None):
Operations = Type[Union[manual_cast, fp8_ops, disable_weight_init, skip_init, scaled_fp8_op_base]]


def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8: Optional[torch.dtype] = None, inference_mode: Optional[bool] = None) -> Operations:
if inference_mode is None:
# todo: check a context here, since this isn't being used by any callers yet
inference_mode = current_execution_context().inference_mode
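The factory `scaled_fp8_ops` used to return classes derived directly from `manual_cast`, so there was no common type callers could name; the empty `scaled_fp8_op_base` gives the dynamically built classes a shared ancestor, which is what lets the new `Operations` alias enumerate every ops family `pick_operations` may return. A small standalone sketch of the pattern, with simplified names that are not the real classes:

```python
from typing import Optional, Type, Union


class ManualCast:                      # stand-in for ops.manual_cast
    pass


class ScaledFp8Base(ManualCast):       # shared ancestor for factory-built classes
    pass


def make_scaled_fp8_ops(override_dtype=None) -> Type[ScaledFp8Base]:
    # Each call builds a fresh class closing over its configuration,
    # but every product is still a ScaledFp8Base subclass.
    class ScaledFp8(ScaledFp8Base):
        dtype: Optional[object] = override_dtype
    return ScaledFp8


Operations = Type[Union[ManualCast, ScaledFp8Base]]


def pick_operations(use_scaled_fp8: bool) -> Operations:
    return make_scaled_fp8_ops() if use_scaled_fp8 else ManualCast


ops_cls = pick_operations(True)
assert issubclass(ops_cls, ScaledFp8Base)  # annotations and isinstance checks now have a stable target
```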
@@ -640,7 +640,7 @@ class Flux(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Flux

memory_usage_factor = 2.8
memory_usage_factor = 2.4

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

@@ -15,11 +15,14 @@
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from typing import Optional

import torch
from . import model_base
from . import utils
from . import latent_formats
from .ops import Operations


class ClipTarget:
def __init__(self, tokenizer, clip):
@@ -36,8 +39,8 @@ class BASE:

required_keys = {}

clip_prefix = []
clip_vision_prefix = None
clip_prefix: list[str] = []
clip_vision_prefix: Optional[str] = None
noise_aug_config = None
sampling_settings = {}
latent_format = latent_formats.LatentFormat
@@ -47,9 +50,9 @@ class BASE:

memory_usage_factor = 2.0

manual_cast_dtype = None
custom_operations = None
scaled_fp8 = None
manual_cast_dtype: Optional[torch.dtype] = None
custom_operations: Optional[Operations] = None
scaled_fp8: Optional[torch.dtype] = None
optimizations = {"fp8": False}

@classmethod

@@ -11,6 +11,8 @@ from comfy import model_management
from comfy.model_patcher import ModelPatcher
from comfy.nodes.package_typing import CustomNode, InputTypes

logger = logging.getLogger(__name__)

DIFFUSION_MODEL = "diffusion_model"
TORCH_COMPILE_BACKENDS = [
"inductor",
@@ -55,18 +57,18 @@ class TorchCompileModel(CustomNode):
"fullgraph": ("BOOLEAN", {"default": False}),
"dynamic": ("BOOLEAN", {"default": False}),
"backend": (TORCH_COMPILE_BACKENDS, {"default": "inductor"}),
"mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"})
"mode": (TORCH_COMPILE_MODES, {"default": "max-autotune"}),
"torch_tensorrt_optimization_level": ("INT", {"default": 3, "min": 1, "max": 5})
}
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
# INFERENCE_MODE = False

CATEGORY = "_for_testing"
EXPERIMENTAL = True

def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune") -> tuple[ModelPatcher]:
def patch(self, model: ModelPatcher, object_patch: str | None = DIFFUSION_MODEL, fullgraph: bool = False, dynamic: bool = False, backend: str = "inductor", mode: str = "max-autotune", torch_tensorrt_optimization_level: int = 3) -> tuple[ModelPatcher]:
if object_patch is None:
object_patch = DIFFUSION_MODEL
compile_kwargs = {
@@ -75,19 +77,41 @@ class TorchCompileModel(CustomNode):
"backend": backend,
"mode": mode,
}
move_to_gpu = False
try:
if backend == "torch_tensorrt":
try:
import torch_tensorrt
except (ImportError, ModuleNotFoundError):
logger.error(f"Install torch-tensorrt and modelopt")
raise
compile_kwargs["options"] = {
# https://pytorch.org/TensorRT/dynamo/torch_compile.html
# Quantization/INT8 support is slated for a future release; currently, we support FP16 and FP32 precision layers.
"enabled_precisions": {torch.float, torch.half}
"enabled_precisions": {torch.float, torch.half},
"optimization_level": torch_tensorrt_optimization_level,
"cache_built_engines": True,
"reuse_cached_engines": True,
"enable_weight_streaming": True,
"make_refittable": True,
}
move_to_gpu = True
del compile_kwargs["mode"]
if isinstance(model, ModelPatcher):
m = model.clone()
if move_to_gpu:
model_management.load_models_gpu([m])
m.add_object_patch(object_patch, torch.compile(model=m.get_model_object(object_patch), **compile_kwargs))
if move_to_gpu:
model_management.unload_model_clones(m)
return (m,)
elif isinstance(model, torch.nn.Module):
return torch.compile(model=model, **compile_kwargs),
if move_to_gpu:
model.to(device=model_management.get_torch_device())
res = torch.compile(model=model, **compile_kwargs),
if move_to_gpu:
model.to(device=model_management.unet_offload_device())
return res
else:
logging.warning("Encountered a model that cannot be compiled")
return model,
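The node ultimately funnels its inputs into a single `torch.compile` call and swaps the compiled module in as an object patch; the default `inductor` backend is what generates Triton kernels on CUDA, while the `torch_tensorrt` backend drops `mode` and takes its configuration through `options`. A reduced sketch of how those keyword arguments end up being used, with an illustrative stand-in module rather than a real diffusion model:

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in for the patched diffusion_model object

compile_kwargs = dict(fullgraph=False, dynamic=False, backend="inductor", mode="max-autotune")

if compile_kwargs["backend"] == "torch_tensorrt":
    # TensorRT is configured through `options` and does not accept `mode`.
    compile_kwargs["options"] = {
        "enabled_precisions": {torch.float, torch.half},
        "optimization_level": 3,
    }
    del compile_kwargs["mode"]

compiled = torch.compile(model, **compile_kwargs)
out = compiled(torch.randn(2, 8))  # the first call triggers compilation
```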
@@ -119,7 +143,6 @@ class QuantizeModel(CustomNode):
FUNCTION = "execute"
CATEGORY = "_for_testing"
EXPERIMENTAL = True
# INFERENCE_MODE = False

RETURN_TYPES = ("MODEL",)

@@ -152,10 +175,10 @@ class QuantizeModel(CustomNode):
quantize(unet, weights=qint8, activations=qint8, exclude=exclusion_list)
_in_place_fixme = unet
elif "torchao" in strategy:
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, unwrap_tensor_subclass, autoquant # pylint: disable=import-error
model = model.clone()
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight, autoquant # pylint: disable=import-error
from torchao.utils import unwrap_tensor_subclass
self.warn_in_place(model)
unet = model.get_model_object("diffusion_model")
model_management.load_models_gpu([model])

def filter(module: torch.nn.Module, fqn: str) -> bool:
return isinstance(module, torch.nn.Linear) and not any(prefix in fqn for prefix in always_exclude_these)
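With the torchao strategy, modules are quantized in place via `quantize_`, and the filter callback decides which submodules qualify (here, `Linear` layers whose qualified names are not excluded). A minimal, hedged sketch of that API on a toy model; the exclusion prefixes are illustrative and `torchao` must be installed:

```python
import torch
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 4),
)

always_exclude_these = ("proj_out",)  # hypothetical prefixes to leave unquantized


def filter_fn(module: torch.nn.Module, fqn: str) -> bool:
    # Quantize only Linear layers whose fully qualified name is not excluded.
    return isinstance(module, torch.nn.Linear) and not any(p in fqn for p in always_exclude_these)


# Rewrites matching Linear weights with int8 dynamic-activation / int8-weight subclasses, in place.
quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
print(model)
```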
tests/inference/workflows/flux-compile-0.json (new file, 205 lines)
@@ -0,0 +1,205 @@
{
  "8": {
    "inputs": {
      "samples": ["16", 0],
      "vae": ["14", 0]
    },
    "class_type": "VAEDecode",
    "_meta": {"title": "VAE Decode"}
  },
  "9": {
    "inputs": {
      "filename_prefix": "nike/nike_images_",
      "images": ["8", 0]
    },
    "class_type": "SaveImage",
    "_meta": {"title": "Save Image"}
  },
  "13": {
    "inputs": {
      "clip_name1": "clip_l.safetensors",
      "clip_name2": "t5xxl_fp16.safetensors",
      "type": "flux"
    },
    "class_type": "DualCLIPLoader",
    "_meta": {"title": "DualCLIPLoader"}
  },
  "14": {
    "inputs": {"vae_name": "ae.safetensors"},
    "class_type": "VAELoader",
    "_meta": {"title": "Load VAE"}
  },
  "15": {
    "inputs": {
      "text": "A photoreal image of a Nike Air Force 1 shoe in black, with a red Nike swoosh and red sole. The interior of the shoe is blue, and the laces are bright green.",
      "clip": ["13", 0]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "16": {
    "inputs": {
      "noise": ["17", 0],
      "guider": ["18", 0],
      "sampler": ["21", 0],
      "sigmas": ["22", 0],
      "latent_image": ["49", 0]
    },
    "class_type": "SamplerCustomAdvanced",
    "_meta": {"title": "SamplerCustomAdvanced"}
  },
  "17": {
    "inputs": {"noise_seed": 1},
    "class_type": "RandomNoise",
    "_meta": {"title": "RandomNoise"}
  },
  "18": {
    "inputs": {
      "model": ["51", 0],
      "conditioning": ["19", 0]
    },
    "class_type": "BasicGuider",
    "_meta": {"title": "BasicGuider"}
  },
  "19": {
    "inputs": {
      "guidance": 5,
      "conditioning": ["15", 0]
    },
    "class_type": "FluxGuidance",
    "_meta": {"title": "FluxGuidance"}
  },
  "21": {
    "inputs": {"sampler_name": "euler"},
    "class_type": "KSamplerSelect",
    "_meta": {"title": "KSamplerSelect"}
  },
  "22": {
    "inputs": {
      "scheduler": "normal",
      "steps": 20,
      "denoise": 1,
      "model": ["51", 0]
    },
    "class_type": "BasicScheduler",
    "_meta": {"title": "BasicScheduler"}
  },
  "23": {
    "inputs": {
      "vae": ["14", 0]
    },
    "class_type": "VAEEncode",
    "_meta": {"title": "VAE Encode"}
  },
  "49": {
    "inputs": {
      "width": 1024,
      "height": 1024,
      "batch_size": 1
    },
    "class_type": "EmptySD3LatentImage",
    "_meta": {"title": "EmptySD3LatentImage"}
  },
  "51": {
    "inputs": {
      "object_patch": "diffusion_model",
      "fullgraph": false,
      "dynamic": false,
      "backend": "inductor",
      "mode": "reduce-overhead",
      "torch_tensorrt_optimization_level": 3,
      "model": ["53", 0]
    },
    "class_type": "TorchCompileModel",
    "_meta": {"title": "TorchCompileModel"}
  },
  "53": {
    "inputs": {
      "unet_name": "flux1-dev.safetensors",
      "weight_dtype": "default"
    },
    "class_type": "UNETLoader",
    "_meta": {"title": "Load Diffusion Model"}
  }
}
tests/inference/workflows/flux-compile-1.json (new file, 205 lines)
@@ -0,0 +1,205 @@
{
  "8": {
    "inputs": {
      "samples": ["16", 0],
      "vae": ["14", 0]
    },
    "class_type": "VAEDecode",
    "_meta": {"title": "VAE Decode"}
  },
  "9": {
    "inputs": {
      "filename_prefix": "nike/nike_images_",
      "images": ["8", 0]
    },
    "class_type": "SaveImage",
    "_meta": {"title": "Save Image"}
  },
  "13": {
    "inputs": {
      "clip_name1": "clip_l.safetensors",
      "clip_name2": "t5xxl_fp16.safetensors",
      "type": "flux"
    },
    "class_type": "DualCLIPLoader",
    "_meta": {"title": "DualCLIPLoader"}
  },
  "14": {
    "inputs": {"vae_name": "ae.safetensors"},
    "class_type": "VAELoader",
    "_meta": {"title": "Load VAE"}
  },
  "15": {
    "inputs": {
      "text": "A photoreal image of a Nike Air Force 1 shoe in black, with a red Nike swoosh and red sole. The interior of the shoe is blue, and the laces are bright green.",
      "clip": ["13", 0]
    },
    "class_type": "CLIPTextEncode",
    "_meta": {"title": "CLIP Text Encode (Prompt)"}
  },
  "16": {
    "inputs": {
      "noise": ["17", 0],
      "guider": ["18", 0],
      "sampler": ["21", 0],
      "sigmas": ["22", 0],
      "latent_image": ["49", 0]
    },
    "class_type": "SamplerCustomAdvanced",
    "_meta": {"title": "SamplerCustomAdvanced"}
  },
  "17": {
    "inputs": {"noise_seed": 2},
    "class_type": "RandomNoise",
    "_meta": {"title": "RandomNoise"}
  },
  "18": {
    "inputs": {
      "model": ["51", 0],
      "conditioning": ["19", 0]
    },
    "class_type": "BasicGuider",
    "_meta": {"title": "BasicGuider"}
  },
  "19": {
    "inputs": {
      "guidance": 5,
      "conditioning": ["15", 0]
    },
    "class_type": "FluxGuidance",
    "_meta": {"title": "FluxGuidance"}
  },
  "21": {
    "inputs": {"sampler_name": "euler"},
    "class_type": "KSamplerSelect",
    "_meta": {"title": "KSamplerSelect"}
  },
  "22": {
    "inputs": {
      "scheduler": "normal",
      "steps": 20,
      "denoise": 1,
      "model": ["51", 0]
    },
    "class_type": "BasicScheduler",
    "_meta": {"title": "BasicScheduler"}
  },
  "23": {
    "inputs": {
      "vae": ["14", 0]
    },
    "class_type": "VAEEncode",
    "_meta": {"title": "VAE Encode"}
  },
  "49": {
    "inputs": {
      "width": 1024,
      "height": 1024,
      "batch_size": 1
    },
    "class_type": "EmptySD3LatentImage",
    "_meta": {"title": "EmptySD3LatentImage"}
  },
  "51": {
    "inputs": {
      "object_patch": "diffusion_model",
      "fullgraph": false,
      "dynamic": false,
      "backend": "inductor",
      "mode": "reduce-overhead",
      "torch_tensorrt_optimization_level": 3,
      "model": ["53", 0]
    },
    "class_type": "TorchCompileModel",
    "_meta": {"title": "TorchCompileModel"}
  },
  "53": {
    "inputs": {
      "unet_name": "flux1-dev-fp8.safetensors",
      "weight_dtype": "fp8_e4m3fn"
    },
    "class_type": "UNETLoader",
    "_meta": {"title": "Load Diffusion Model"}
  }
}
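Both workflows exercise `TorchCompileModel` ahead of the Flux UNet: node 51 compiles the `diffusion_model` object patch before the guider and scheduler in nodes 18 and 22 use it. To run one outside the test harness, the API-format JSON can be posted to a running server's `/prompt` endpoint; a short, hedged sketch in which the host, port, and file path are assumptions:

```python
import json
import urllib.request

# Assumes a ComfyUI server listening on localhost:8188 and the workflow file on disk.
with open("tests/inference/workflows/flux-compile-0.json") as f:
    workflow = json.load(f)

req = urllib.request.Request(
    "http://localhost:8188/prompt",
    data=json.dumps({"prompt": workflow}).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(resp.read().decode())  # e.g. {"prompt_id": ..., "number": ..., "node_errors": {}}
```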