initial checkin

This commit is contained in:
lanluo-nvidia 2025-10-30 12:18:50 -07:00
parent e525673f72
commit 769dcd1a04

View File

@ -130,19 +130,68 @@ class BaseModel(torch.nn.Module):
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
self.current_patcher: 'ModelPatcher' = None
self.enable_trt = True
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
breakpoint()
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
dtype = self.diffusion_model.dtype
# trt_rtx cannot handle bfloat16, so convert to float16, float32 always convert to float16 for trt_rtx to save memory
if self.enable_trt and dtype in (torch.float32, torch.bfloat16, torch.float16):
self.diffusion_model = self.diffusion_model.half()
unet_config["dtype"] = torch.float16
self.diffusion_model.dtype = torch.float16
logging.debug(f"converted diffusion modelfrom {dtype} to float16 for trt_rtx")
else:
self.enable_trt = False
logging.warning("trt_rtx cannot handle ${dtype}, so disabling trt_rtx")
self.diffusion_model.eval()
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
if self.enable_trt:
import torch_tensorrt
settings = {
"use_python_runtime": False,
"immutable_weights": False,
"offload_to_cpu": True,
}
self.trt_compiled_diffusion_model = torch_tensorrt.MutableTorchTensorRTModule(self.diffusion_model, **settings)
# TODO: INVESTIGATE WHY DYNAMIC SHAPE IS NOT WORKING
enable_trt_dynamic_shape = False
if enable_trt_dynamic_shape:
# if batch size is 2, then sigmas-batch is 2, dim_batch is 4
sigmas_batch = torch.export.Dim("sigmas_batch", min=1, max=20)
dim_batch = torch.export.Dim("batch", min=1, max=40)
#dim_width = torch.export.Dim("width", min=3, max=64)
#dim_height = torch.export.Dim("height", min=5, max=64)
# args: xc, t
args_dynamic_shapes=({0: dim_batch}, {0: dim_batch},)
#args_dynamic_shapes=({0: dim_batch, 2: dim_width*4, 3: dim_height*4}, {0: dim_batch},)
# kwargs: context, transformer_options, y
kwargs_dynamic_shape = {
'context': {0: dim_batch},
'transformer_options': {
'wrappers': {},
'callbacks': {},
'sample_sigmas': {},
#'cond_or_uncond': {},
'sigmas': {0:sigmas_batch},
},
'y': {0: dim_batch,},
}
self.trt_compiled_diffusion_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shape)
logging.debug("lan added ********** trt_model: trt_compiled_diffusion_model is created")
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@ -199,8 +248,23 @@ class BaseModel(torch.nn.Module):
t = self.process_timestep(t, x=x, **extra_conds)
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
logging.debug(f"lan added *** {xc.shape=} {t.shape=} {context.shape=} {extra_conds['y'].shape=} {transformer_options['sample_sigmas'].shape=} {transformer_options['sigmas'].shape=}")
logging.debug(f"lan added *** {transformer_options=}")
if control is not None:
logging.debug(f"lan added ***{control.shape=}")
else:
logging.debug("lan added ***control is None")
logging.debug(f"lan added ***{extra_conds=}")
if self.enable_trt:
transformer_options.pop("uuids", None)
transformer_options.pop("cond_or_uncond", None)
with torch.no_grad():
model_output = self.trt_compiled_diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
else:
transformer_options.pop("uuids", None)
transformer_options.pop("cond_or_uncond", None)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)