mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 21:42:37 +08:00
initial checkin
This commit is contained in:
parent
e525673f72
commit
769dcd1a04
@ -130,19 +130,68 @@ class BaseModel(torch.nn.Module):
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
self.current_patcher: 'ModelPatcher' = None
|
||||
|
||||
self.enable_trt = True
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
breakpoint()
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
dtype = self.diffusion_model.dtype
|
||||
|
||||
# trt_rtx cannot handle bfloat16, so convert to float16, float32 always convert to float16 for trt_rtx to save memory
|
||||
if self.enable_trt and dtype in (torch.float32, torch.bfloat16, torch.float16):
|
||||
self.diffusion_model = self.diffusion_model.half()
|
||||
unet_config["dtype"] = torch.float16
|
||||
self.diffusion_model.dtype = torch.float16
|
||||
logging.debug(f"converted diffusion modelfrom {dtype} to float16 for trt_rtx")
|
||||
else:
|
||||
self.enable_trt = False
|
||||
logging.warning("trt_rtx cannot handle ${dtype}, so disabling trt_rtx")
|
||||
|
||||
self.diffusion_model.eval()
|
||||
if comfy.model_management.force_channels_last():
|
||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||
logging.debug("using channels last mode for diffusion model")
|
||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||
|
||||
if self.enable_trt:
|
||||
import torch_tensorrt
|
||||
settings = {
|
||||
"use_python_runtime": False,
|
||||
"immutable_weights": False,
|
||||
"offload_to_cpu": True,
|
||||
}
|
||||
self.trt_compiled_diffusion_model = torch_tensorrt.MutableTorchTensorRTModule(self.diffusion_model, **settings)
|
||||
# TODO: INVESTIGATE WHY DYNAMIC SHAPE IS NOT WORKING
|
||||
enable_trt_dynamic_shape = False
|
||||
if enable_trt_dynamic_shape:
|
||||
# if batch size is 2, then sigmas-batch is 2, dim_batch is 4
|
||||
sigmas_batch = torch.export.Dim("sigmas_batch", min=1, max=20)
|
||||
dim_batch = torch.export.Dim("batch", min=1, max=40)
|
||||
|
||||
#dim_width = torch.export.Dim("width", min=3, max=64)
|
||||
#dim_height = torch.export.Dim("height", min=5, max=64)
|
||||
# args: xc, t
|
||||
args_dynamic_shapes=({0: dim_batch}, {0: dim_batch},)
|
||||
#args_dynamic_shapes=({0: dim_batch, 2: dim_width*4, 3: dim_height*4}, {0: dim_batch},)
|
||||
# kwargs: context, transformer_options, y
|
||||
kwargs_dynamic_shape = {
|
||||
'context': {0: dim_batch},
|
||||
'transformer_options': {
|
||||
'wrappers': {},
|
||||
'callbacks': {},
|
||||
'sample_sigmas': {},
|
||||
#'cond_or_uncond': {},
|
||||
'sigmas': {0:sigmas_batch},
|
||||
},
|
||||
'y': {0: dim_batch,},
|
||||
}
|
||||
self.trt_compiled_diffusion_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shape)
|
||||
logging.debug("lan added ********** trt_model: trt_compiled_diffusion_model is created")
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_sampling = model_sampling(model_config, model_type)
|
||||
|
||||
@ -199,8 +248,23 @@ class BaseModel(torch.nn.Module):
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
|
||||
logging.debug(f"lan added *** {xc.shape=} {t.shape=} {context.shape=} {extra_conds['y'].shape=} {transformer_options['sample_sigmas'].shape=} {transformer_options['sigmas'].shape=}")
|
||||
logging.debug(f"lan added *** {transformer_options=}")
|
||||
if control is not None:
|
||||
logging.debug(f"lan added ***{control.shape=}")
|
||||
else:
|
||||
logging.debug("lan added ***control is None")
|
||||
logging.debug(f"lan added ***{extra_conds=}")
|
||||
if self.enable_trt:
|
||||
transformer_options.pop("uuids", None)
|
||||
transformer_options.pop("cond_or_uncond", None)
|
||||
with torch.no_grad():
|
||||
model_output = self.trt_compiled_diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
else:
|
||||
transformer_options.pop("uuids", None)
|
||||
transformer_options.pop("cond_or_uncond", None)
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user