mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-13 23:12:35 +08:00
initial checkin
This commit is contained in:
parent
e525673f72
commit
769dcd1a04
@ -130,19 +130,68 @@ class BaseModel(torch.nn.Module):
|
|||||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||||
self.device = device
|
self.device = device
|
||||||
self.current_patcher: 'ModelPatcher' = None
|
self.current_patcher: 'ModelPatcher' = None
|
||||||
|
self.enable_trt = True
|
||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
fp8 = model_config.optimizations.get("fp8", False)
|
fp8 = model_config.optimizations.get("fp8", False)
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
|
breakpoint()
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
dtype = self.diffusion_model.dtype
|
||||||
|
|
||||||
|
# trt_rtx cannot handle bfloat16, so bfloat16 weights are converted to float16; float32 weights are also always converted to float16 for trt_rtx to save memory
|
||||||
|
if self.enable_trt and dtype in (torch.float32, torch.bfloat16, torch.float16):
|
||||||
|
self.diffusion_model = self.diffusion_model.half()
|
||||||
|
unet_config["dtype"] = torch.float16
|
||||||
|
self.diffusion_model.dtype = torch.float16
|
||||||
|
logging.debug(f"converted diffusion modelfrom {dtype} to float16 for trt_rtx")
|
||||||
|
else:
|
||||||
|
self.enable_trt = False
|
||||||
|
logging.warning("trt_rtx cannot handle ${dtype}, so disabling trt_rtx")
|
||||||
|
|
||||||
self.diffusion_model.eval()
|
self.diffusion_model.eval()
|
||||||
if comfy.model_management.force_channels_last():
|
if comfy.model_management.force_channels_last():
|
||||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||||
logging.debug("using channels last mode for diffusion model")
|
logging.debug("using channels last mode for diffusion model")
|
||||||
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
|
||||||
|
|
||||||
|
if self.enable_trt:
|
||||||
|
import torch_tensorrt
|
||||||
|
settings = {
|
||||||
|
"use_python_runtime": False,
|
||||||
|
"immutable_weights": False,
|
||||||
|
"offload_to_cpu": True,
|
||||||
|
}
|
||||||
|
self.trt_compiled_diffusion_model = torch_tensorrt.MutableTorchTensorRTModule(self.diffusion_model, **settings)
|
||||||
|
# TODO: INVESTIGATE WHY DYNAMIC SHAPE IS NOT WORKING
|
||||||
|
enable_trt_dynamic_shape = False
|
||||||
|
if enable_trt_dynamic_shape:
|
||||||
|
# e.g. if the batch size is 2, then sigmas_batch is 2 and dim_batch is 4
|
||||||
|
sigmas_batch = torch.export.Dim("sigmas_batch", min=1, max=20)
|
||||||
|
dim_batch = torch.export.Dim("batch", min=1, max=40)
|
||||||
|
|
||||||
|
#dim_width = torch.export.Dim("width", min=3, max=64)
|
||||||
|
#dim_height = torch.export.Dim("height", min=5, max=64)
|
||||||
|
# args: xc, t
|
||||||
|
args_dynamic_shapes=({0: dim_batch}, {0: dim_batch},)
|
||||||
|
#args_dynamic_shapes=({0: dim_batch, 2: dim_width*4, 3: dim_height*4}, {0: dim_batch},)
|
||||||
|
# kwargs: context, transformer_options, y
|
||||||
|
kwargs_dynamic_shape = {
|
||||||
|
'context': {0: dim_batch},
|
||||||
|
'transformer_options': {
|
||||||
|
'wrappers': {},
|
||||||
|
'callbacks': {},
|
||||||
|
'sample_sigmas': {},
|
||||||
|
#'cond_or_uncond': {},
|
||||||
|
'sigmas': {0:sigmas_batch},
|
||||||
|
},
|
||||||
|
'y': {0: dim_batch,},
|
||||||
|
}
|
||||||
|
self.trt_compiled_diffusion_model.set_expected_dynamic_shape_range(args_dynamic_shapes, kwargs_dynamic_shape)
|
||||||
|
logging.debug("lan added ********** trt_model: trt_compiled_diffusion_model is created")
|
||||||
|
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.model_sampling = model_sampling(model_config, model_type)
|
self.model_sampling = model_sampling(model_config, model_type)
|
||||||
|
|
||||||
@ -199,8 +248,23 @@ class BaseModel(torch.nn.Module):
|
|||||||
t = self.process_timestep(t, x=x, **extra_conds)
|
t = self.process_timestep(t, x=x, **extra_conds)
|
||||||
if "latent_shapes" in extra_conds:
|
if "latent_shapes" in extra_conds:
|
||||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||||
|
|
||||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
logging.debug(f"lan added *** {xc.shape=} {t.shape=} {context.shape=} {extra_conds['y'].shape=} {transformer_options['sample_sigmas'].shape=} {transformer_options['sigmas'].shape=}")
|
||||||
|
logging.debug(f"lan added *** {transformer_options=}")
|
||||||
|
if control is not None:
|
||||||
|
logging.debug(f"lan added ***{control.shape=}")
|
||||||
|
else:
|
||||||
|
logging.debug("lan added ***control is None")
|
||||||
|
logging.debug(f"lan added ***{extra_conds=}")
|
||||||
|
if self.enable_trt:
|
||||||
|
transformer_options.pop("uuids", None)
|
||||||
|
transformer_options.pop("cond_or_uncond", None)
|
||||||
|
with torch.no_grad():
|
||||||
|
model_output = self.trt_compiled_diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||||
|
else:
|
||||||
|
transformer_options.pop("uuids", None)
|
||||||
|
transformer_options.pop("cond_or_uncond", None)
|
||||||
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||||
model_output, _ = utils.pack_latents(model_output)
|
model_output, _ = utils.pack_latents(model_output)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user