From 769dcd1a046f6ee7964c2598ef5d9e57c4844df5 Mon Sep 17 00:00:00 2001
From: lanluo-nvidia
Date: Thu, 30 Oct 2025 12:18:50 -0700
Subject: [PATCH] initial checkin

---
 comfy/model_base.py | 68 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 65 insertions(+), 3 deletions(-)

diff --git a/comfy/model_base.py b/comfy/model_base.py
index 7c788d085..d6a7727ed 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -130,19 +130,67 @@ class BaseModel(torch.nn.Module):
         self.manual_cast_dtype = model_config.manual_cast_dtype
         self.device = device
         self.current_patcher: 'ModelPatcher' = None
-
+        self.enable_trt = True
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
                 operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
+            dtype = self.diffusion_model.dtype
+
+            # trt_rtx cannot handle bfloat16, so convert to float16; float32 is also converted to float16 for trt_rtx to save memory
+            if self.enable_trt and dtype in (torch.float32, torch.bfloat16, torch.float16):
+                self.diffusion_model = self.diffusion_model.half()
+                unet_config["dtype"] = torch.float16
+                self.diffusion_model.dtype = torch.float16
+                logging.debug(f"converted diffusion model from {dtype} to float16 for trt_rtx")
+            else:
+                self.enable_trt = False
+                logging.warning(f"trt_rtx cannot handle {dtype}, so disabling trt_rtx")
+
             self.diffusion_model.eval()
             if comfy.model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)
                 logging.debug("using channels last mode for diffusion model")
             logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
+
+            if self.enable_trt:
+                import torch_tensorrt
+                settings = {
+                    "use_python_runtime": False,
+                    "immutable_weights": False,
+                    "offload_to_cpu": True,
+                }
+                self.trt_compiled_diffusion_model = torch_tensorrt.MutableTorchTensorRTModule(self.diffusion_model, **settings)
+                # TODO: investigate why dynamic shape is not working
+                enable_trt_dynamic_shape = False
+                if enable_trt_dynamic_shape:
+                    # if batch size is 2, then sigmas_batch is 2 and dim_batch is 4
+                    sigmas_batch = torch.export.Dim("sigmas_batch", min=1, max=20)
+                    dim_batch = torch.export.Dim("batch", min=1, max=40)
+
+                    #dim_width = torch.export.Dim("width", min=3, max=64)
+                    #dim_height = torch.export.Dim("height", min=5, max=64)
+                    # args: xc, t
+                    args_dynamic_shapes = ({0: dim_batch}, {0: dim_batch},)
+                    #args_dynamic_shapes = ({0: dim_batch, 2: dim_width*4, 3: dim_height*4}, {0: dim_batch},)
+                    # kwargs: context, transformer_options, y
+                    kwargs_dynamic_shape = {
+                        'context': {0: dim_batch},
+                        'transformer_options': {
+                            'wrappers': {},
+                            'callbacks': {},
+                            'sample_sigmas': {},
+                            #'cond_or_uncond': {},
+                            'sigmas': {0: sigmas_batch},
+                        },
+                        'y': {0: dim_batch},
+                    }
+                    self.trt_compiled_diffusion_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shape)
+                logging.debug("trt_model: trt_compiled_diffusion_model is created")
+
 
         self.model_type = model_type
         self.model_sampling = model_sampling(model_config, model_type)
@@ -199,8 +247,22 @@ class BaseModel(torch.nn.Module):
         t = self.process_timestep(t, x=x, **extra_conds)
         if "latent_shapes" in extra_conds:
             xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
-
-        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
+
+        logging.debug(f"{xc.shape=} {t.shape=} {context.shape=} {extra_conds['y'].shape=} {transformer_options['sample_sigmas'].shape=} {transformer_options['sigmas'].shape=}")
+        logging.debug(f"{transformer_options=}")
+        if control is not None:
+            logging.debug(f"{control.shape=}")
+        else:
+            logging.debug("control is None")
+        logging.debug(f"{extra_conds=}")
+        # drop the non-tensor transformer_options entries that the TRT export path cannot trace
+        transformer_options.pop("uuids", None)
+        transformer_options.pop("cond_or_uncond", None)
+        if self.enable_trt:
+            with torch.no_grad():
+                model_output = self.trt_compiled_diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
+        else:
+            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
         if len(model_output) > 1 and not torch.is_tensor(model_output):
             model_output, _ = utils.pack_latents(model_output)
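
A minimal sketch of the wrap/compile/infer cycle behind trt_compiled_diffusion_model,
reduced to a toy module so it can run standalone. TinyUNet, its shapes, and the
device/dtype choices are illustrative assumptions, not ComfyUI code; the settings
dict mirrors the patch (offload_to_cpu omitted for the toy case).

    import torch
    import torch_tensorrt

    class TinyUNet(torch.nn.Module):
        """Stand-in for the real diffusion model (assumption, not ComfyUI code)."""
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(4, 4, 3, padding=1)

        def forward(self, x, t):
            # trivial "denoiser": conv the latents, broadcast-add the timestep
            return self.conv(x) + t.view(-1, 1, 1, 1)

    model = TinyUNet().half().eval().cuda()
    settings = {
        "use_python_runtime": False,   # C++ TRT runtime, as in the patch
        "immutable_weights": False,    # allow weight refit instead of engine rebuild
    }
    trt_model = torch_tensorrt.MutableTorchTensorRTModule(model, **settings)

    with torch.no_grad():
        x = torch.randn(2, 4, 64, 64, dtype=torch.float16, device="cuda")
        t = torch.randn(2, dtype=torch.float16, device="cuda")
        out = trt_model(x, t)   # first call triggers the TensorRT compilation
        out = trt_model(x, t)   # later calls with the same shapes reuse the engine

immutable_weights=False is presumably what lets ModelPatcher-style weight swaps
(e.g. LoRA) be refit into the existing engine rather than forcing a recompile.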
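The disabled enable_trt_dynamic_shape branch builds its shape spec as a pytree that
mirrors the call's args/kwargs: torch.export.Dim leaves mark dimensions allowed to
vary, and empty dicts mark entries that stay static (the nested transformer_options
dict in the patch follows the same rule). The patch's TODO says this path is not yet
working in their setup; the pattern itself, continued from the toy sketch above with
assumed bounds, looks like this:

    dim_batch = torch.export.Dim("batch", min=1, max=8)

    # one entry per positional arg of forward(x, t): axis 0 of each may vary
    args_dynamic_shapes = ({0: dim_batch}, {0: dim_batch})
    # keyword args would be keyed by name here; the toy forward takes none
    kwargs_dynamic_shapes = {}

    trt_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shapes)

    with torch.no_grad():
        for bs in (2, 4, 8):   # batch sizes inside [1, 8] share one engine
            x = torch.randn(bs, 4, 64, 64, dtype=torch.float16, device="cuda")
            t = torch.randn(bs, dtype=torch.float16, device="cuda")
            out = trt_model(x, t)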