"""
    This file is part of ComfyUI.
    Copyright (C) 2024 Comfy

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import logging
import math

import torch
from enum import Enum
from typing import TypeVar, Type, Protocol, Any, Optional

from . import conds
from . import latent_formats
from . import model_management
from . import ops
from . import utils
from .conds import CONDRegular, CONDConstant
from .ldm.ace.model import ACEStepTransformer2DModel
from .ldm.audio.dit import AudioDiffusionTransformer
from .ldm.audio.embedders import NumberConditioner
from .ldm.aura.mmdit import MMDiT as AuraMMDiT
from .ldm.cascade.stage_b import StageB
from .ldm.cascade.stage_c import StageC
from .ldm.chroma import model as chroma_model
from .ldm.cosmos.model import GeneralDIT
from .ldm.cosmos.predict2 import MiniTrainDIT
from .ldm.flux import model as flux_model
from .ldm.genmo.joint_model.asymm_models_joint import AsymmDiTJoint
from .ldm.hidream.model import HiDreamImageTransformer2DModel
from .ldm.hunyuan3d.model import Hunyuan3Dv2 as Hunyuan3Dv2Model
from .ldm.hunyuan3dv2_1.hunyuandit import HunYuanDiTPlain
from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel
from .ldm.hydit.models import HunYuanDiT
from .ldm.lightricks.model import LTXVModel
from .ldm.lumina.model import NextDiT
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
from .ldm.chroma_radiance import model as chroma_radiance
from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel
from .ldm.pixart.pixartms import PixArtMS
from .ldm.kandinsky5 import model as kadinsky5_model
from .ldm.qwen_image.model import QwenImageTransformer2DModel
from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel
from .ldm.wan.model_animate import AnimateWanModel
from .model_management_types import ModelManageable
from .model_sampling import CONST, ModelSamplingDiscreteFlow, ModelSamplingFlux, IMG_TO_IMG
from .model_sampling import StableCascadeSampling, COSMOS_RFLOW, ModelSamplingCosmosRFlow, V_PREDICTION, \
    ModelSamplingContinuousEDM, ModelSamplingDiscrete, EPS, EDM, ModelSamplingContinuousV
from .ops import Operations
from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers

logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Model parameterization; `model_sampling` maps each member to a pair of sampling classes."""
    EPS = 1
    V_PREDICTION = 2
    V_PREDICTION_EDM = 3
    STABLE_CASCADE = 4
    EDM = 5
    FLOW = 6
    V_PREDICTION_CONTINUOUS = 7
    FLUX = 8
    IMG_TO_IMG = 9
    FLOW_COSMOS = 10


def model_sampling(model_config, model_type):
    """Create the model-sampling object for ``model_type``.

    Picks a schedule class ``s`` (default ``ModelSamplingDiscrete``) and a
    prediction class ``c`` (default ``EPS``), combines them into a dynamically
    created ``ModelSampling(s, c)`` subclass and instantiates it with
    ``model_config``. Members without a dedicated branch keep the defaults.
    """
    c = EPS
    s = ModelSamplingDiscrete

    if model_type == ModelType.EPS:
        c = EPS
    elif model_type == ModelType.V_PREDICTION:
        c = V_PREDICTION
    elif model_type == ModelType.V_PREDICTION_EDM:
        c = V_PREDICTION
        s = ModelSamplingContinuousEDM
    elif model_type == ModelType.STABLE_CASCADE:
        c = EPS
        s = StableCascadeSampling
    elif model_type == ModelType.EDM:
        c = EDM
        s = ModelSamplingContinuousEDM
    elif model_type == ModelType.FLOW:
        c = CONST
        s = ModelSamplingDiscreteFlow
    elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
        c = V_PREDICTION
        s = ModelSamplingContinuousV
    elif model_type == ModelType.FLUX:
        c = CONST
        s = ModelSamplingFlux
    elif model_type == ModelType.IMG_TO_IMG:
        c = IMG_TO_IMG
    elif model_type == ModelType.FLOW_COSMOS:
        c = COSMOS_RFLOW
        s = ModelSamplingCosmosRFlow

    class ModelSampling(s, c):
        pass

    return ModelSampling(model_config)


# Bound for the `unet_model` class argument of BaseModel.__init__.
TModule = TypeVar('TModule', bound=torch.nn.Module)


class ComfyUIModel(Protocol):
    """Call signature `BaseModel._apply_model` uses when invoking `diffusion_model`."""

    def __call__(self, xc: torch.Tensor, t: torch.Tensor, context: Any = None, control: Any = None, transformer_options: Optional[dict] = None, **extra_conds: Any) -> Any:
        ...


def convert_tensor(extra, dtype, device):
    """Move a conditioning value to ``device``.

    Objects with a ``dtype`` attribute are passed to
    ``model_management.cast_to_device``: int/long values keep their dtype
    (``None`` is passed as the target dtype), everything else is cast to
    ``dtype``. Objects without ``dtype`` are returned unchanged.
    """
    if hasattr(extra, "dtype"):
        if extra.dtype != torch.int and extra.dtype != torch.long:
            extra = model_management.cast_to_device(extra, device, dtype)
        else:
            extra = model_management.cast_to_device(extra, device, None)
    return extra


class BaseModel(torch.nn.Module):
    """Wrapper around a diffusion backbone (`self.diffusion_model`).

    Owns the model-sampling object, builds per-model conditioning in
    `extra_conds`, runs the backbone in `apply_model`, and handles weight
    loading/saving and memory estimation. Subclasses override
    `extra_conds` / `encode_adm` / `concat_cond` for their architectures.
    """

    def __init__(self, model_config, model_type=ModelType.EPS, device: Optional[torch.device] = None, unet_model: Type[TModule] = UNetModel):
        """
        Args:
            model_config: config object; this method reads `unet_config`,
                `latent_format`, `manual_cast_dtype`, `custom_operations`,
                `optimizations` and `memory_usage_factor` from it.
            model_type: `ModelType` passed to `model_sampling`.
            device: device forwarded to the backbone constructor.
            unet_model: backbone class, constructed as
                `unet_model(**unet_config, device=device, operations=operations)`.
        """
        super().__init__()

        unet_config = model_config.unet_config
        self.latent_format = model_config.latent_format
        self.model_config = model_config
        self.manual_cast_dtype = model_config.manual_cast_dtype
        self.device: Optional[torch.device] = device
        self.operations: Optional[Operations]
        self.current_patcher: Optional[ModelManageable] = None

        if not unet_config.get("disable_unet_model_creation", False):
            if model_config.custom_operations is None:
                fp8 = model_config.optimizations.get("fp8", False)
                operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
            else:
                operations = model_config.custom_operations
            self.operations = operations
            self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
            self.diffusion_model.eval()
            if model_management.force_channels_last():
                self.diffusion_model.to(memory_format=torch.channels_last)
                logger.debug("using channels last mode for diffusion model")
            logger.debug("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
        else:
            self.operations = None
        self.model_type = model_type
        self.model_sampling = model_sampling(model_config, model_type)

        self.adm_channels = unet_config.get("adm_in_channels", None)
        if self.adm_channels is None:
            self.adm_channels = 0

        self.concat_keys = ()
        logger.debug("model_type {}".format(model_type.name))
        logger.debug("adm {}".format(self.adm_channels))
        self.memory_usage_factor = model_config.memory_usage_factor
        # Names of conds whose shapes are added to the input shape in memory_required.
        self.memory_usage_factor_conds = ()
        # Optional per-cond callables applied to each shape before it is counted.
        self.memory_usage_shape_process = {}
        self.training = False  # todo: does this break the training nodes?

    def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        """Run `_apply_model` through any APPLY_MODEL wrappers found in `transformer_options`."""
        # A fresh dict per call: the options are handed to wrappers and the
        # backbone, so a shared default dict could leak state between calls.
        if transformer_options is None:
            transformer_options = {}
        return WrapperExecutor.new_class_executor(
            self._apply_model,
            self,
            get_all_wrappers(WrappersMP.APPLY_MODEL, transformer_options)
        ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)

    def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options=None, **kwargs):
        """Single denoising call.

        Scales `x` with `model_sampling.calculate_input`, concatenates
        `c_concat` on dim 1, casts inputs and conds to the model (or manual
        cast) dtype, converts `t` to the model timestep, runs the backbone
        and returns `model_sampling.calculate_denoised(sigma, output, x)`.
        """
        if transformer_options is None:
            transformer_options = {}
        sigma = t
        xc = self.model_sampling.calculate_input(sigma, x)
        if c_concat is not None:
            xc = torch.cat([xc] + [model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)

        context = c_crossattn
        dtype = self.get_dtype()

        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        xc = xc.to(dtype)
        device = xc.device
        t = self.model_sampling.timestep(t).float()
        if context is not None:
            context = model_management.cast_to_device(context, device, dtype)

        extra_conds = {}
        for name, extra in kwargs.items():
            if hasattr(extra, "dtype"):
                extra = convert_tensor(extra, dtype, device)
            elif isinstance(extra, list):
                extra = [convert_tensor(ext, dtype, device) for ext in extra]
            extra_conds[name] = extra

        t = self.process_timestep(t, x=x, **extra_conds)
        if "latent_shapes" in extra_conds:
            # Packed multi-latent input: split it back into its parts for the backbone.
            xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))

        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
        if len(model_output) > 1 and not torch.is_tensor(model_output):
            # Multiple non-tensor outputs are packed back into a single tensor.
            model_output, _ = utils.pack_latents(model_output)

        return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)

    def process_timestep(self, timestep, **kwargs):
        """Hook for subclasses to transform the timestep; identity here."""
        return timestep

    def get_dtype(self):
        """Return `self.diffusion_model.dtype`."""
        return self.diffusion_model.dtype

    def encode_adm(self, **kwargs):
        """Return the ADM (`y`) conditioning tensor; the base model has none."""
        return None

    def concat_cond(self, **kwargs):
        """Build the channel-concat conditioning described by `self.concat_keys`.

        Returns None when `concat_keys` is empty. Otherwise concatenates on
        dim 1, per key: "mask" / "mask_inverted" (the denoise mask or its
        inverse, or ones/zeros without a mask), "masked_image" (the latent
        image, or `blank_inpaint_image_like(noise)` without a mask) and
        "concat_image" (the latent image, or zeros when none was given).
        Requires `noise` and `device` in kwargs.
        """
        if len(self.concat_keys) > 0:
            cond_concat = []
            denoise_mask: Optional[torch.Tensor] = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
            concat_latent_image = kwargs.get("concat_latent_image", None)
            if concat_latent_image is None:
                concat_latent_image = kwargs.get("latent_image", None)
            else:
                concat_latent_image = self.process_latent_in(concat_latent_image)

            noise: Optional[torch.Tensor] = kwargs.get("noise", None)
            device = kwargs["device"]

            # Only resize when an image exists, so the zeros fallback for
            # "concat_image" below is actually reachable.
            if concat_latent_image is not None:
                if concat_latent_image.shape[1:] != noise.shape[1:]:
                    concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
                    if noise.ndim == 5:
                        # Pad or trim dim -3 so it matches the noise.
                        if concat_latent_image.shape[-3] < noise.shape[-3]:
                            concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
                        else:
                            concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]

                concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])

            if denoise_mask is not None:
                if len(denoise_mask.shape) == len(noise.shape):
                    denoise_mask = denoise_mask[:, :1]

                num_dim = noise.ndim - 2
                denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
                if denoise_mask.shape[-2:] != noise.shape[-2:]:
                    denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
                denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])

            for ck in self.concat_keys:
                if denoise_mask is not None:
                    if ck == "mask":
                        cond_concat.append(denoise_mask.to(device))
                    elif ck == "masked_image":
                        cond_concat.append(concat_latent_image.to(device))  # NOTE: the latent_image should be masked by the mask in pixel space
                    elif ck == "mask_inverted":
                        cond_concat.append(1.0 - denoise_mask.to(device))
                else:
                    if ck == "mask":
                        cond_concat.append(torch.ones_like(noise)[:, :1])
                    elif ck == "masked_image":
                        cond_concat.append(self.blank_inpaint_image_like(noise))
                    elif ck == "mask_inverted":
                        cond_concat.append(torch.zeros_like(noise)[:, :1])
                if ck == "concat_image":
                    if concat_latent_image is not None:
                        cond_concat.append(concat_latent_image.to(device))
                    else:
                        cond_concat.append(torch.zeros_like(noise))
            data = torch.cat(cond_concat, dim=1)
            return data
        return None

    def extra_conds(self, **kwargs):
        """Collect model conds: 'c_concat', 'y' (ADM), 'c_crossattn' and 'crossattn_controlnet'.

        A 'noise_concat' kwarg, when given, overrides the 'c_concat' built
        by `concat_cond`.
        """
        out = {}
        concat_cond = self.concat_cond(**kwargs)
        if concat_cond is not None:
            out['c_concat'] = conds.CONDNoiseShape(concat_cond)

        # pylint: disable=assignment-from-none
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = conds.CONDRegular(adm)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)

        cross_attn_cnet = kwargs.get("cross_attn_controlnet", None)
        if cross_attn_cnet is not None:
            out['crossattn_controlnet'] = conds.CONDCrossAttn(cross_attn_cnet)

        c_concat = kwargs.get("noise_concat", None)
        if c_concat is not None:
            out['c_concat'] = conds.CONDNoiseShape(c_concat)

        return out

    def load_model_weights(self, sd, unet_prefix=""):
        """Load backbone weights from ``sd``; returns ``self``.

        Keys starting with ``unet_prefix`` are popped from ``sd`` (the
        caller's dict is mutated), stripped of the prefix, passed through
        `model_config.process_unet_state_dict` and loaded non-strictly.
        Missing/unexpected keys are logged as warnings.
        """
        to_load = {}
        keys = list(sd.keys())
        for k in keys:
            if k.startswith(unet_prefix):
                to_load[k[len(unet_prefix):]] = sd.pop(k)

        to_load = self.model_config.process_unet_state_dict(to_load)
        m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
        if len(m) > 0:
            logger.warning("unet missing: {}".format(m))

        if len(u) > 0:
            logger.warning("unet unexpected: {}".format(u))
        del to_load
        return self

    def process_latent_in(self, latent):
        """Apply `latent_format.process_in` to ``latent``."""
        return self.latent_format.process_in(latent)

    def process_latent_out(self, latent):
        """Apply `latent_format.process_out` to ``latent``."""
        return self.latent_format.process_out(latent)

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        """Build a checkpoint state dict: backbone weights plus optional CLIP/VAE/CLIP-vision dicts.

        Each optional dict goes through the matching
        `model_config.process_*_state_dict_for_saving` hook. V_PREDICTION
        models also get an empty "v_pred" marker tensor.
        """
        extra_sds = []
        if clip_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict))
        if vae_state_dict is not None:
            extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict))
        if clip_vision_state_dict is not None:
            extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))

        unet_state_dict = self.diffusion_model.state_dict()
        unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)

        if self.model_type == ModelType.V_PREDICTION:
            unet_state_dict["v_pred"] = torch.tensor([])

        for sd in extra_sds:
            unet_state_dict.update(sd)

        return unet_state_dict

    def set_inpaint(self):
        """Turn on inpainting conds: concat keys ("mask", "masked_image") and a blank-image factory."""
        self.concat_keys = ("mask", "masked_image")

        def blank_inpaint_image_like(latent_image):
            # Scales channels 0-3, so the latent is assumed to have at least 4 channels.
            blank_image = torch.ones_like(latent_image)
            # these are the values for "zero" in pixel space translated to latent space
            blank_image[:, 0] *= 0.8223
            blank_image[:, 1] *= -0.6876
            blank_image[:, 2] *= 0.6364
            blank_image[:, 3] *= 0.1380
            return blank_image

        self.blank_inpaint_image_like = blank_inpaint_image_like

    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
        """Return `model_sampling.noise_scaling` of ``latent_image`` with ``sigma`` broadcast to the noise rank."""
        return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)

    def memory_required(self, input_shape, cond_shapes=None):
        """Estimate the memory (bytes) needed to run the model on ``input_shape``.

        Shapes of the conds named in `memory_usage_factor_conds` (taken from
        ``cond_shapes``) are counted alongside the input. Each shape
        contributes ``shape[0] * prod(shape[2:])`` to the area. The formula
        depends on whether xformers / flash attention is enabled.
        """
        if cond_shapes is None:
            cond_shapes = {}
        input_shapes = [input_shape]
        for c in self.memory_usage_factor_conds:
            shape = cond_shapes.get(c, None)
            if shape is None:
                continue
            if c in self.memory_usage_shape_process:
                shape = [self.memory_usage_shape_process[c](s) for s in shape]
            if len(shape) > 0:
                input_shapes += shape

        area = sum(shp[0] * math.prod(shp[2:]) for shp in input_shapes)
        if model_management.xformers_enabled() or model_management.pytorch_attention_flash_attention():
            dtype = self.get_dtype()
            if self.manual_cast_dtype is not None:
                dtype = self.manual_cast_dtype
            # TODO: this needs to be tweaked
            return (area * model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
        else:
            # TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
            return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)

    def extra_conds_shapes(self, **kwargs):
        """Return the shapes of extra conds for memory estimation; none for the base model."""
        return {}


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
    """Build the unCLIP ADM vector from image-embedding conditioning.

    For every entry of ``unclip_conditioning`` and every embedding in its
    ``clip_vision_output.image_embeds``, the embedding is noise-augmented at
    level ``round((max_noise_level - 1) * noise_augmentation)``, concatenated
    with the noise-level embedding and scaled by ``strength``. With several
    embeddings, the results are summed and the first
    ``noise_augmentor.time_embed.dim`` columns are re-augmented at
    ``noise_augment_merge``.

    Raises:
        ValueError: if no image embeddings were supplied.
    """
    adm_inputs = []
    for unclip_cond in unclip_conditioning:
        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
            weight = unclip_cond["strength"]
            noise_augment = unclip_cond["noise_augmentation"]
            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device), seed=seed)
            adm_inputs.append(torch.cat((c_adm, noise_level_emb), 1) * weight)

    if not adm_inputs:
        # Previously surfaced as UnboundLocalError on the return below.
        raise ValueError("unclip_conditioning contained no image embeddings")

    if len(adm_inputs) == 1:
        return adm_inputs[0]

    adm_out = torch.stack(adm_inputs).sum(0)
    noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment_merge)
    # NOTE(review): the merge pass does not forward `seed`; kept as-is to preserve outputs.
    c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
    return torch.cat((c_adm, noise_level_emb), 1)


class SD21UNCLIP(BaseModel):
    """BaseModel whose ADM input comes from `unclip_adm` (zeros when no unCLIP conditioning)."""

    def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**noise_aug_config)

    def encode_adm(self, **kwargs):
        """Return zeros of width `adm_channels` without unCLIP conditioning, else `unclip_adm(...)`."""
        unclip_conditioning = kwargs.get("unclip_conditioning", None)
        device = kwargs["device"]
        if unclip_conditioning is None:
            return torch.zeros((1, self.adm_channels), device=device)
        else:
            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05), kwargs.get("seed", 0) - 10)


def sdxl_pooled(args, noise_augmentor):
    """Return the pooled text vector, or the first 1280 columns of `unclip_adm` when unCLIP conditioning is present."""
    if "unclip_conditioning" in args:
        return unclip_adm(args.get("unclip_conditioning", None), args["device"], noise_augmentor, seed=args.get("seed", 0) - 10)[:, :1280]
    else:
        return args["pooled_output"]


class SDXLRefiner(BaseModel):
    """Model whose ADM vector is the pooled output plus size/crop/aesthetic-score embeddings."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
        """Concatenate the pooled vector with Timestep(256) embeddings of height, width, crop_h, crop_w and aesthetic score.

        The aesthetic score defaults to 2.5 for negative prompts and 6 otherwise.
        """
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)

        if kwargs.get("prompt_type", "") == "negative":
            aesthetic_score = kwargs.get("aesthetic_score", 2.5)
        else:
            aesthetic_score = kwargs.get("aesthetic_score", 6)

        out = [self.embedder(torch.Tensor([value])) for value in (height, width, crop_h, crop_w, aesthetic_score)]
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


class SDXL(BaseModel):
    """Model whose ADM vector is the pooled output plus size/crop/target-size embeddings."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)
        self.noise_augmentor = CLIPEmbeddingNoiseAugmentation(**{"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1280})

    def encode_adm(self, **kwargs):
        """Concatenate the pooled vector with Timestep(256) embeddings of height, width, crop_h, crop_w, target_height and target_width."""
        clip_pooled = sdxl_pooled(kwargs, self.noise_augmentor)
        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        crop_w = kwargs.get("crop_w", 0)
        crop_h = kwargs.get("crop_h", 0)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out = [self.embedder(torch.Tensor([value])) for value in (height, width, crop_h, crop_w, target_height, target_width)]
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
        return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


class SVD_img2vid(BaseModel):
    """Image-to-video model conditioned on fps / motion bucket / augmentation and a concat latent."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder = Timestep(256)

    def encode_adm(self, **kwargs):
        """Return Timestep(256) embeddings of (fps - 1), motion_bucket_id and augmentation_level as one row."""
        fps_id = kwargs.get("fps", 6) - 1
        motion_bucket_id = kwargs.get("motion_bucket_id", 127)
        augmentation = kwargs.get("augmentation_level", 0)

        out = [self.embedder(torch.Tensor([value])) for value in (fps_id, motion_bucket_id, augmentation)]
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat

    def extra_conds(self, **kwargs):
        """Build 'y', 'c_concat' (resized concat latent or zeros), 'c_crossattn', 'time_context' and 'num_video_frames'."""
        out = {}
        adm = self.encode_adm(**kwargs)
        if adm is not None:
            out['y'] = conds.CONDRegular(adm)

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = conds.CONDNoiseShape(latent_image)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)

        if "time_conditioning" in kwargs:
            out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])

        # The batch dimension of the noise is used as the frame count.
        out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
        return out


class SV3D_u(SVD_img2vid):
    """SVD variant whose ADM vector embeds only the augmentation level."""

    def encode_adm(self, **kwargs):
        """Return the Timestep(256) embedding of augmentation_level as one row."""
        augmentation = kwargs.get("augmentation_level", 0)

        out = [self.embedder(torch.flatten(torch.Tensor([augmentation])))]
        flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0)
        return flat


class SV3D_p(SVD_img2vid):
    """SVD variant whose ADM vector embeds augmentation, elevation and azimuth."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
        super().__init__(model_config, model_type, device=device)
        self.embedder_512 = Timestep(512)

    def encode_adm(self, **kwargs):
        """Embed augmentation (256-d) plus (90 - elevation) and azimuth in radians (512-d each), batched to the noise size."""
        augmentation = kwargs.get("augmentation_level", 0)
        elevation = kwargs.get("elevation", 0)  # elevation and azimuth are in degrees here
        azimuth = kwargs.get("azimuth", 0)
        noise = kwargs.get("noise", None)

        out = [
            self.embedder(torch.flatten(torch.Tensor([augmentation]))),
            self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(90 - torch.Tensor([elevation])), 360.0))),
            self.embedder_512(torch.deg2rad(torch.fmod(torch.flatten(torch.Tensor([azimuth])), 360.0))),
        ]
        out = [utils.resize_to_batch_size(a, noise.shape[0]) for a in out]
        return torch.cat(out, dim=1)


class Stable_Zero123(BaseModel):
    """Model with a linear `cc_projection` applied to cross-attn inputs whose last dim is not 768."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
        super().__init__(model_config, model_type, device=device)
        self.cc_projection = ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
        self.cc_projection.weight.copy_(cc_projection_weight)
        self.cc_projection.bias.copy_(cc_projection_bias)

    def extra_conds(self, **kwargs):
        """Build 'c_concat' (resized concat latent or zeros) and 'c_crossattn' (projected when its last dim != 768)."""
        out = {}

        latent_image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)

        if latent_image is None:
            latent_image = torch.zeros_like(noise)

        if latent_image.shape[1:] != noise.shape[1:]:
            latent_image = utils.common_upscale(latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        latent_image = utils.resize_to_batch_size(latent_image, noise.shape[0])

        out['c_concat'] = conds.CONDNoiseShape(latent_image)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            if cross_attn.shape[-1] != 768:
                cross_attn = self.cc_projection(cross_attn)
            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
        return out


class SD_X4Upscaler(BaseModel):
    """Upscaler conditioned on a (noise-augmented) low-resolution image and its noise level."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device)
        self.noise_augmentor = ImageConcatWithNoiseAugmentation(noise_schedule_config={"linear_start": 0.0001, "linear_end": 0.02}, max_noise_level=350)

    def extra_conds(self, **kwargs):
        """Build 'c_concat' (the resized, optionally noise-augmented image), 'y' (noise level) and 'c_crossattn'.

        Requires `device` and `seed` in kwargs.
        """
        out = {}

        image = kwargs.get("concat_image", None)
        noise = kwargs.get("noise", None)
        noise_augment = kwargs.get("noise_augmentation", 0.0)
        device = kwargs["device"]
        seed = kwargs["seed"] - 10

        noise_level = round((self.noise_augmentor.max_noise_level) * noise_augment)

        if image is None:
            image = torch.zeros_like(noise)[:, :3]

        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        noise_level = torch.tensor([noise_level], device=device)
        if noise_augment > 0:
            image, noise_level = self.noise_augmentor(image.to(device), noise_level=noise_level, seed=seed)

        image = utils.resize_to_batch_size(image, noise.shape[0])

        out['c_concat'] = conds.CONDNoiseShape(image)
        out['y'] = conds.CONDRegular(noise_level)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
        return out


class IP2P:
    """Mixin: concat conditioning from `concat_latent_image`, post-processed by `process_ip2p_image_in`."""

    def process_ip2p_image_in(self, image):
        """Hook replaced by subclasses; returning None disables the concat cond."""
        return None

    def concat_cond(self, **kwargs):
        """Resize the concat latent (or zeros) to the noise shape/batch and pass it through `process_ip2p_image_in`."""
        image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if image is None:
            image = torch.zeros_like(noise)
        else:
            image = image.to(device=device)

        if image.shape[1:] != noise.shape[1:]:
            image = utils.common_upscale(image, noise.shape[-1], noise.shape[-2], "bilinear", "center")

        image = utils.resize_to_batch_size(image, noise.shape[0])
        return self.process_ip2p_image_in(image)


class SD15_instructpix2pix(IP2P, BaseModel):
    """IP2P model that uses the concat latent unchanged."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        self.process_ip2p_image_in = lambda image: image


class SDXL_instructpix2pix(IP2P, SDXL):
    """IP2P on SDXL; V_PREDICTION_EDM models run the concat latent through the SDXL latent format first."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device)
        if model_type == ModelType.V_PREDICTION_EDM:
            self.process_ip2p_image_in = lambda image: latent_formats.SDXL().process_in(image)  # cosxl ip2p
        else:
            self.process_ip2p_image_in = lambda image: image  # diffusers ip2p


class Lotus(BaseModel):
    """Model with a fixed sin/cos task embedding as 'y'."""

    def extra_conds(self, **kwargs):
        """Build 'c_crossattn' (always set, even if None) and 'y' = [sin(1), sin(0), cos(1), cos(0)]."""
        out = {}
        cross_attn = kwargs.get("cross_attn", None)
        out['c_crossattn'] = conds.CONDCrossAttn(cross_attn)
        device = kwargs["device"]
        task_emb = torch.tensor([1, 0]).float().to(device)
        task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
        out['y'] = conds.CONDRegular(task_emb)
        return out

    def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
        super().__init__(model_config, model_type, device=device)


class StableCascade_C(BaseModel):
    """Model wrapping `StageC`."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageC)

    def extra_conds(self, **kwargs):
        """Build 'clip_text_pooled', 'clip_img' (weighted image embeds or zeros (1, 1, 768)), 'sca', 'crp' and 'clip_text'.

        Requires `pooled_output` in kwargs.
        """
        out = {}
        clip_text_pooled = kwargs["pooled_output"]
        if clip_text_pooled is not None:
            out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled)

        if "unclip_conditioning" in kwargs:
            embeds = []
            for unclip_cond in kwargs["unclip_conditioning"]:
                weight = unclip_cond["strength"]
                embeds.append(unclip_cond["clip_vision_output"].image_embeds.unsqueeze(0) * weight)
            clip_img = torch.cat(embeds, dim=1)
        else:
            clip_img = torch.zeros((1, 1, 768))
        out["clip_img"] = conds.CONDRegular(clip_img)
        out["sca"] = conds.CONDRegular(torch.zeros((1,)))
        out["crp"] = conds.CONDRegular(torch.zeros((1,)))

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['clip_text'] = conds.CONDCrossAttn(cross_attn)
        return out


class StableCascade_B(BaseModel):
    """Model wrapping `StageB`."""

    def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=StageB)

    def extra_conds(self, **kwargs):
        """Build 'clip', 'effnet' (the supplied prior, or zeros sized from the noise) and 'sca'.

        Requires `pooled_output` in kwargs.
        """
        out = {}
        noise = kwargs.get("noise", None)

        clip_text_pooled = kwargs["pooled_output"]
        if clip_text_pooled is not None:
            out['clip'] = conds.CONDRegular(clip_text_pooled)

        prior = kwargs.get("stable_cascade_prior", None)
        if prior is None:
            # size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
            # (built lazily so no throwaway tensor is allocated when a prior is supplied)
            prior = torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device)

        out["effnet"] = conds.CONDRegular(prior.to(device=noise.device))
        out["sca"] = conds.CONDRegular(torch.zeros((1,)))
        return out


class SD3(BaseModel):
    """Model wrapping `OpenAISignatureMMDITWrapper`; ADM is the pooled output."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=OpenAISignatureMMDITWrapper)

    def encode_adm(self, **kwargs):
        """Return `kwargs["pooled_output"]`."""
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs):
        """Base conds, with 'c_crossattn' as a CONDRegular."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)
        return out


class AuraFlow(BaseModel):
    """Model wrapping the Aura `MMDiT`."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=AuraMMDiT)

    def extra_conds(self, **kwargs):
        """Base conds, with 'c_crossattn' as a CONDRegular."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)
        return out


class StableAudio1(BaseModel):
    """Audio DiT model conditioned on start/total seconds embeddings."""

    def __init__(self, model_config, seconds_start_embedder_weights, seconds_total_embedder_weights, model_type=ModelType.V_PREDICTION_CONTINUOUS, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=AudioDiffusionTransformer)
        self.seconds_start_embedder = NumberConditioner(768, min_val=0, max_val=512)
        self.seconds_total_embedder = NumberConditioner(768, min_val=0, max_val=512)
        self.seconds_start_embedder.load_state_dict(seconds_start_embedder_weights)
        self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights)

    def extra_conds(self, **kwargs):
        """Build 'global_embed' (start+total seconds embeddings) and 'c_crossattn' (text with both embeddings appended on dim 1)."""
        out = {}

        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        seconds_start = kwargs.get("seconds_start", 0)
        # NOTE(review): 21.53 looks like latent frames per second — confirm.
        seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 21.53))

        seconds_start_embed = self.seconds_start_embedder([seconds_start])[0].to(device)
        seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device)

        global_embed = torch.cat([seconds_start_embed, seconds_total_embed], dim=-1).reshape((1, -1))
        out['global_embed'] = conds.CONDRegular(global_embed)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            cross_attn = torch.cat([cross_attn.to(device), seconds_start_embed.repeat((cross_attn.shape[0], 1, 1)), seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1)

        out['c_crossattn'] = conds.CONDRegular(cross_attn)
        return out

    def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None):
        """Base checkpoint dict plus the seconds embedder weights under their conditioner prefixes."""
        sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict)
        d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()}
        for prefix, sub_sd in d.items():
            for key, value in sub_sd.items():
                sd["{}{}".format(prefix, key)] = value
        return sd


class HunyuanDiT(BaseModel):
    """Model wrapping `HunYuanDiT`."""

    def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiT)

    def extra_conds(self, **kwargs):
        """Base conds plus text/mT5 embeddings, their masks, and 'image_meta_size'."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)

        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out['text_embedding_mask'] = conds.CONDRegular(attention_mask)

        conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
        if conditioning_mt5xl is not None:
            out['encoder_hidden_states_t5'] = conds.CONDRegular(conditioning_mt5xl)

        attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
        if attention_mask_mt5xl is not None:
            out['text_embedding_mask_t5'] = conds.CONDRegular(attention_mask_mt5xl)

        width = kwargs.get("width", 768)
        height = kwargs.get("height", 768)
        target_width = kwargs.get("target_width", width)
        target_height = kwargs.get("target_height", height)

        out['image_meta_size'] = conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
        return out


class PixArt(BaseModel):
    """Model wrapping `PixArtMS`."""

    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=PixArtMS)

    def extra_conds(self, **kwargs):
        """Base conds plus 'c_crossattn' and, when width and height are known, 'c_size' and 'c_ar'."""
        out = super().extra_conds(**kwargs)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)

        width = kwargs.get("width", None)
        height = kwargs.get("height", None)
        if width is not None and height is not None:
            out["c_size"] = conds.CONDRegular(torch.FloatTensor([[height, width]]))
            out["c_ar"] = conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height / width)]]))

        return out


class Flux(BaseModel):
    """Model wrapping a Flux backbone (`flux_model.Flux` by default)."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=flux_model.Flux):
        super().__init__(model_config, model_type, device=device, unet_model=unet_model)
        self.memory_usage_factor_conds = ("ref_latents",)

    def concat_cond(self, **kwargs):
        """Build the channel-concat input for backbones whose `img_in` takes more channels than they output.

        Returns None when num_channels <= out_channels, the processed concat
        latent when num_channels <= 2 * out_channels, and otherwise the latent
        plus a mask with each 8x8 block folded into channels.
        """
        try:
            # Handle Flux control loras dynamically changing the img_in weight.
            num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
        except Exception:
            # Some cases like tensorrt might not have the weights accessible
            num_channels = self.model_config.unet_config["in_channels"]

        out_channels = self.model_config.unet_config["out_channels"]

        if num_channels <= out_channels:
            return None

        image = kwargs.get("concat_latent_image", None)
        noise = kwargs.get("noise", None)
        device = kwargs["device"]

        if image is None:
            image = torch.zeros_like(noise)

        image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
        image = utils.resize_to_batch_size(image, noise.shape[0])
        image = self.process_latent_in(image)
        if num_channels <= out_channels * 2:
            return image

        # inpaint model
        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            mask = torch.ones_like(noise)[:, :1]

        mask = torch.mean(mask, dim=1, keepdim=True)
        mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
        # Fold each 8x8 block of the upscaled 1-channel mask into 64 channels at latent resolution.
        mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
        mask = utils.resize_to_batch_size(mask, noise.shape[0])
        return torch.cat((image, mask), dim=1)

    def encode_adm(self, **kwargs):
        """Return `kwargs["pooled_output"]`."""
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs):
        """Base conds plus 'c_crossattn', 'attention_mask', 'guidance' (default 3.5), 'ref_latents' and 'ref_latents_method'."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)
        # Resize the attention mask from its reference image shape to the
        # (h_tok, w_tok) token grid computed below.
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            shape = kwargs["noise"].shape
            mask_ref_size = kwargs.get("attention_mask_img_shape", None)
            if mask_ref_size is not None:
                # the model will pad to the patch size, and then divide
                # essentially dividing and rounding up
                (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
                attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
            out['attention_mask'] = conds.CONDRegular(attention_mask)

        guidance = kwargs.get("guidance", 3.5)
        if guidance is not None:
            out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance]))

        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            latents = [self.process_latent_in(lat) for lat in ref_latents]
            out['ref_latents'] = conds.CONDList(latents)

            ref_latents_method = kwargs.get("reference_latents_method", None)
            if ref_latents_method is not None:
                out['ref_latents_method'] = conds.CONDConstant(ref_latents_method)
        return out

    def extra_conds_shapes(self, **kwargs):
        """Report reference latents as one [1, 16, total_spatial_size] shape for memory estimation."""
        out = {}
        ref_latents = kwargs.get("reference_latents", None)
        if ref_latents is not None:
            out['ref_latents'] = [1, 16, sum(math.prod(a.size()[2:]) for a in ref_latents)]
        return out


class Flux2(Flux):
    """Flux variant whose text conditioning is left-padded to 512 tokens."""

    def extra_conds(self, **kwargs):
        """Flux conds, with 'c_crossattn' zero-padded at the start of dim 1 up to 512."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            target_text_len = 512
            if cross_attn.shape[1] < target_text_len:
                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
            out['c_crossattn'] = conds.CONDRegular(cross_attn)
        return out


class GenmoMochi(BaseModel):
    """Model wrapping `AsymmDiTJoint`."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=AsymmDiTJoint)

    def extra_conds(self, **kwargs):
        """Base conds plus 'attention_mask', 'num_tokens' (mask sum, at least 1) and 'c_crossattn'."""
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out['attention_mask'] = conds.CONDRegular(attention_mask)
            out['num_tokens'] = conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)
        return out


class LTXV(BaseModel):
    """Model wrapping `LTXVModel`."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=LTXVModel)  # TODO

    def extra_conds(self, **kwargs):
        """Base conds plus 'attention_mask', 'c_crossattn', 'frame_rate' (default 25), 'denoise_mask' and 'keyframe_idxs'."""
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out['attention_mask'] = conds.CONDRegular(attention_mask)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)

        out['frame_rate'] = conds.CONDConstant(kwargs.get("frame_rate", 25))

        denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if denoise_mask is not None:
            out["denoise_mask"] = conds.CONDRegular(denoise_mask)

        keyframe_idxs = kwargs.get("keyframe_idxs", None)
        if keyframe_idxs is not None:
            out['keyframe_idxs'] = conds.CONDRegular(keyframe_idxs)

        return out

    def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
        """Without a denoise mask, return ``timestep`` unchanged.

        Otherwise scale the per-sample timestep by the mask, keep channel 0,
        and return the first output of `diffusion_model.patchifier.patchify`
        (presumably a per-token timestep tensor — confirm against LTXVModel).
        """
        if denoise_mask is None:
            return timestep
        return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]

    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
        """Return ``latent_image`` unscaled."""
        return latent_image


class HunyuanVideo(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideoModel)

    def encode_adm(self, **kwargs):
        return kwargs["pooled_output"]

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out['attention_mask'] = conds.CONDRegular(attention_mask)

        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out['c_crossattn'] = conds.CONDRegular(cross_attn)

        guidance = kwargs.get("guidance", 6.0)
        if guidance is not None:
            out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance]))
guiding_frame_index = kwargs.get("guiding_frame_index", None) if guiding_frame_index is not None: out['guiding_frame_index'] = conds.CONDRegular(torch.FloatTensor([guiding_frame_index])) ref_latent = kwargs.get("ref_latent", None) if ref_latent is not None: out['ref_latent'] = conds.CONDRegular(self.process_latent_in(ref_latent)) return out def scale_latent_inpaint(self, latent_image, **kwargs): return latent_image class HunyuanVideoI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) self.concat_keys = ("concat_image", "mask_inverted") def scale_latent_inpaint(self, latent_image, **kwargs): return super().scale_latent_inpaint(latent_image=latent_image, **kwargs) class HunyuanVideoSkyreelsI2V(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) self.concat_keys = ("concat_image",) def scale_latent_inpaint(self, latent_image, **kwargs): return super().scale_latent_inpaint(latent_image=latent_image, **kwargs) class CosmosVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=GeneralDIT) self.image_to_video = image_to_video if self.image_to_video: self.concat_keys = ("mask_inverted",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: out['attention_mask'] = conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) out['fps'] = conds.CONDConstant(kwargs.get("frame_rate", None)) return out def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) sigma_noise_augmentation = 0 # 
TODO if sigma_noise_augmentation != 0: latent_image = latent_image + noise latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5) class CosmosPredict2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW_COSMOS, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=MiniTrainDIT) self.image_to_video = image_to_video if self.image_to_video: self.concat_keys = ("mask_inverted",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = CONDRegular(cross_attn) denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if denoise_mask is not None: out["denoise_mask"] = CONDRegular(denoise_mask) out['fps'] = CONDConstant(kwargs.get("frame_rate", None)) return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): if denoise_mask is None: return timestep if denoise_mask.ndim <= 4: return timestep condition_video_mask_B_1_T_1_1 = denoise_mask.mean(dim=[1, 3, 4], keepdim=True) c_noise_B_1_T_1_1 = 0.0 * (1.0 - condition_video_mask_B_1_T_1_1) + timestep.reshape(timestep.shape[0], 1, 1, 1, 1) * condition_video_mask_B_1_T_1_1 out = c_noise_B_1_T_1_1.squeeze(dim=[1, 3, 4]) return out def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)) sigma_noise_augmentation = 0 # TODO if sigma_noise_augmentation != 0: latent_image = latent_image + noise latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image) sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) class Lumina2(BaseModel): def 
__init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=NextDiT) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = conds.CONDRegular(attention_mask) out['num_tokens'] = conds.CONDConstant(max(1, torch.sum(attention_mask).item())) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) if 'num_tokens' not in out: out['num_tokens'] = conds.CONDConstant(cross_attn.shape[1]) clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled) return out class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=WanModel) self.image_to_video = image_to_video def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) device = kwargs["device"] if image is None: shape_image = list(noise.shape) shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) else: latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") for i in range(0, image.shape[1], latent_dim): image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) if extra_channels != image.shape[1] + 4: if not self.image_to_video or extra_channels == image.shape[1]: 
return image if image.shape[1] > (extra_channels - 4): image = image[:, :(extra_channels - 4)] mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :4] else: if mask.shape[1] != 4: mask = torch.mean(mask, dim=1, keepdim=True) mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) if mask.shape[1] == 1: mask = mask.repeat(1, 4, 1, 1, 1) mask = utils.resize_to_batch_size(mask, noise.shape[0]) concat_mask_index = kwargs.get("concat_mask_index", 0) if concat_mask_index != 0: return torch.cat((image[:, :concat_mask_index], mask, image[:, concat_mask_index:]), dim=1) else: return torch.cat((mask, image), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) clip_vision_output = kwargs.get("clip_vision_output", None) if clip_vision_output is not None: out['clip_fea'] = conds.CONDRegular(clip_vision_output.penultimate_hidden_states) time_dim_concat = kwargs.get("time_dim_concat", None) if time_dim_concat is not None: out['time_dim_concat'] = conds.CONDRegular(self.process_latent_in(time_dim_concat)) reference_latents = kwargs.get("reference_latents", None) if reference_latents is not None: out['reference_latent'] = conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) return out class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=VaceWanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) noise = kwargs.get("noise", None) 
noise_shape = list(noise.shape) vace_frames = kwargs.get("vace_frames", None) if vace_frames is None: noise_shape[1] = 32 vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)] mask = kwargs.get("vace_mask", None) if mask is None: noise_shape[1] = 64 mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames) vace_frames_out = [] for j in range(len(vace_frames)): vf = vace_frames[j].to(device=noise.device, dtype=noise.dtype, copy=True) for i in range(0, vf.shape[1], 16): vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16]) vf = torch.cat([vf, mask[j].to(device=noise.device, dtype=noise.dtype)], dim=1) vace_frames_out.append(vf) vace_frames = torch.stack(vace_frames_out, dim=1) out['vace_context'] = conds.CONDRegular(vace_frames) vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out)) out['vace_strength'] = conds.CONDConstant(vace_strength) return out class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=CameraWanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) camera_conditions = kwargs.get("camera_conditions", None) if camera_conditions is not None: out['camera_conditions'] = conds.CONDRegular(camera_conditions) return out class WAN21_HuMo(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=HumoWanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) noise = kwargs.get("noise", None) audio_embed = kwargs.get("audio_embed", None) if audio_embed is not None: out['audio_embed'] = conds.CONDRegular(audio_embed) if "c_concat" not in out: # 1.7B model reference_latents = 
kwargs.get("reference_latents", None) if reference_latents is not None: out['reference_latent'] = conds.CONDRegular(self.process_latent_in(reference_latents[-1])) else: noise_shape = list(noise.shape) noise_shape[1] += 4 concat_latent = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype) zero_vae_values_first = torch.tensor([0.8660, -0.4326, -0.0017, -0.4884, -0.5283, 0.9207, -0.9896, 0.4433, -0.5543, -0.0113, 0.5753, -0.6000, -0.8346, -0.3497, -0.1926, -0.6938]).view(1, 16, 1, 1, 1) zero_vae_values_second = torch.tensor([1.0869, -1.2370, 0.0206, -0.4357, -0.6411, 2.0307, -1.5972, 1.2659, -0.8595, -0.4654, 0.9638, -1.6330, -1.4310, -0.1098, -0.3856, -1.4583]).view(1, 16, 1, 1, 1) zero_vae_values = torch.tensor([0.8642, -1.8583, 0.1577, 0.1350, -0.3641, 2.5863, -1.9670, 1.6065, -1.0475, -0.8678, 1.1734, -1.8138, -1.5933, -0.7721, -0.3289, -1.3745]).view(1, 16, 1, 1, 1) concat_latent[:, 4:] = zero_vae_values concat_latent[:, 4:, :1] = zero_vae_values_first concat_latent[:, 4:, 1:2] = zero_vae_values_second out['c_concat'] = conds.CONDNoiseShape(concat_latent) reference_latents = kwargs.get("reference_latents", None) if reference_latents is not None: ref_latent = self.process_latent_in(reference_latents[-1]) ref_latent_shape = list(ref_latent.shape) ref_latent_shape[1] += 4 + ref_latent_shape[1] ref_latent_full = torch.zeros(ref_latent_shape, device=ref_latent.device, dtype=ref_latent.dtype) ref_latent_full[:, 20:] = ref_latent ref_latent_full[:, 16:20] = 1.0 out['reference_latent'] = conds.CONDRegular(ref_latent_full) return out class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=AnimateWanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) face_video_pixels = kwargs.get("face_video_pixels", None) if face_video_pixels is not None: 
out['face_pixel_values'] = conds.CONDRegular(face_video_pixels) pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: out['pose_latents'] = conds.CONDRegular(self.process_latent_in(pose_latents)) return out class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=WanModel_S2V) self.memory_usage_factor_conds = ("reference_latent", "reference_motion") self.memory_usage_shape_process = {"reference_motion": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) audio_embed = kwargs.get("audio_embed", None) if audio_embed is not None: out['audio_embed'] = conds.CONDRegular(audio_embed) reference_latents = kwargs.get("reference_latents", None) if reference_latents is not None: out['reference_latent'] = conds.CONDRegular(self.process_latent_in(reference_latents[-1])) reference_motion = kwargs.get("reference_motion", None) if reference_motion is not None: out['reference_motion'] = conds.CONDRegular(self.process_latent_in(reference_motion)) control_video = kwargs.get("control_video", None) if control_video is not None: out['control_video'] = conds.CONDRegular(self.process_latent_in(control_video)) return out def extra_conds_shapes(self, **kwargs): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: out['reference_latent'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) reference_motion = kwargs.get("reference_motion", None) if reference_motion is not None: out['reference_motion'] = reference_motion.shape return out class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=WanModel) self.image_to_video = image_to_video def extra_conds(self, **kwargs): 
out = super().extra_conds(**kwargs) denoise_mask = kwargs.get("denoise_mask", None) if denoise_mask is not None: out["denoise_mask"] = conds.CONDRegular(denoise_mask) return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): if denoise_mask is None: return timestep temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1) return temp_ts def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=Hunyuan3Dv2Model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) guidance = kwargs.get("guidance", 5.0) if guidance is not None: out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) return out class Hunyuan3Dv2_1(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=HunYuanDiTPlain) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) guidance = kwargs.get("guidance", 5.0) if guidance is not None: out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) return out class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=HiDreamImageTransformer2DModel) def encode_adm(self, **kwargs): return kwargs["pooled_output"] def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", 
None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) conditioning_llama3 = kwargs.get("conditioning_llama3", None) if conditioning_llama3 is not None: out['encoder_hidden_states_llama3'] = conds.CONDRegular(conditioning_llama3) image_cond = kwargs.get("concat_latent_image", None) if image_cond is not None: out['image_cond'] = conds.CONDNoiseShape(self.process_latent_in(image_cond)) return out class Chroma(Flux): def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=chroma_model.Chroma): super().__init__(model_config, model_type, device=device, unet_model=unet_model) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) guidance = kwargs.get("guidance", 0) if guidance is not None: out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) return out class ChromaRadiance(Chroma): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=chroma_radiance.ChromaRadiance) class ACEStep(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=ACEStepTransformer2DModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) noise: Optional[torch.Tensor] = kwargs.get("noise", None) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) conditioning_lyrics = kwargs.get("conditioning_lyrics", None) if cross_attn is not None: out['lyric_token_idx'] = conds.CONDRegular(conditioning_lyrics) out['speaker_embeds'] = conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype)) out['lyrics_strength'] = conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, 
device=device, unet_model=OmniGen2Transformer2DModel) self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask: Optional[torch.Tensor] = kwargs.get("attention_mask", None) if attention_mask is not None: if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = conds.CONDRegular(attention_mask) out['num_tokens'] = conds.CONDConstant(max(1, torch.sum(attention_mask).item())) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) ref_latents: Optional[torch.Tensor] = kwargs.get("reference_latents", None) if ref_latents is not None: latents = [] for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = conds.CONDList(latents) return out def extra_conds_shapes(self, **kwargs): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out class QwenImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=QwenImageTransformer2DModel) self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) ref_latents: Optional[torch.Tensor] = kwargs.get("reference_latents", None) if ref_latents is not None: latents = [] for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = conds.CONDList(latents) ref_latents_method = kwargs.get("reference_latents_method", None) if ref_latents_method is not None: out['ref_latents_method'] = conds.CONDConstant(ref_latents_method) return out def extra_conds_shapes(self, **kwargs): out = {} ref_latents = 
kwargs.get("reference_latents", None) if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out class HunyuanImage21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideo) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask: Optional[torch.Tensor] = kwargs.get("attention_mask", None) if attention_mask is not None: if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) conditioning_byt5small = kwargs.get("conditioning_byt5small", None) if conditioning_byt5small is not None: out['txt_byt5'] = conds.CONDRegular(conditioning_byt5small) guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) return out class HunyuanImage21Refiner(HunyuanImage21): def concat_cond(self, **kwargs): noise: Optional[torch.Tensor] = kwargs.get("noise", None) image: Optional[torch.Tensor] = kwargs.get("concat_latent_image", None) noise_augmentation = kwargs.get("noise_augmentation", 0.0) device = kwargs["device"] if image is None: shape_image = list(noise.shape) image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) else: image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") image = self.process_latent_in(image) image = utils.resize_to_batch_size(image, noise.shape[0]) if noise_augmentation > 0: generator = torch.Generator(device="cpu") generator.manual_seed(kwargs.get("seed", 0) - 10) noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) image = noise_augmentation * noise + min(1.0 - 
noise_augmentation, 0.75) * image else: image = 0.75 * image return image def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) out['disable_time_r'] = conds.CONDConstant(True) return out class HunyuanVideo15(HunyuanVideo): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 # noise 32 img cond 32 + mask 1 if extra_channels == 0: return None image = kwargs.get("concat_latent_image", None) device = kwargs["device"] if image is None: shape_image = list(noise.shape) shape_image[1] = extra_channels image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) else: latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") for i in range(0, image.shape[1], latent_dim): image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) if mask is None: mask = torch.zeros_like(noise)[:, :1] else: mask = 1.0 - mask mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") if mask.shape[-3] < noise.shape[-3]: mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) mask = utils.resize_to_batch_size(mask, noise.shape[0]) return torch.cat((image, mask), dim=1) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: if torch.numel(attention_mask) != attention_mask.sum(): out['attention_mask'] = conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if 
cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) conditioning_byt5small = kwargs.get("conditioning_byt5small", None) if conditioning_byt5small is not None: out['txt_byt5'] = conds.CONDRegular(conditioning_byt5small) guidance = kwargs.get("guidance", 6.0) if guidance is not None: out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) clip_vision_output = kwargs.get("clip_vision_output", None) if clip_vision_output is not None: out['clip_fea'] = conds.CONDRegular(clip_vision_output.last_hidden_state) return out class HunyuanVideo15_SR_Distilled(HunyuanVideo15): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) def concat_cond(self, **kwargs): noise = kwargs.get("noise", None) image = kwargs.get("concat_latent_image", None) noise_augmentation = kwargs.get("noise_augmentation", 0.0) device = kwargs["device"] if image is None: image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=model_management.intermediate_device()) else: image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") # image = self.process_latent_in(image) # scaling wasn't applied in reference code image = utils.resize_to_batch_size(image, noise.shape[0]) lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1) if noise_augmentation > 0: generator = torch.Generator(device="cpu") generator.manual_seed(kwargs.get("seed", 0) - 10) noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice] else: image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice] return image def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) out['disable_time_r'] = conds.CONDConstant(False) return out 
class Kandinsky5(BaseModel):
    """Model wrapper that builds a kadinsky5_model.Kandinsky5 network (FLOW sampling by default)."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device, unet_model=kadinsky5_model.Kandinsky5)

    def encode_adm(self, **kwargs):
        """Pass the pooled text output through unchanged as the ADM vector."""
        return kwargs["pooled_output"]

    def concat_cond(self, **kwargs):
        """Return zero image channels (shaped like the noise) followed by one mask channel.

        The mask is taken from ``concat_mask``, falling back to ``denoise_mask``. It is inverted,
        resized to the noise's last two dims, zero-padded along dim -3 when shorter than the
        noise there, and repeated to the noise batch size. With no mask, the channel is all zeros.
        """
        noise = kwargs.get("noise", None)
        target_device = kwargs["device"]
        zero_image = torch.zeros_like(noise)

        source_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if source_mask is not None:
            inverted = utils.common_upscale((1.0 - source_mask).to(target_device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
            missing = noise.shape[-3] - inverted.shape[-3]
            if missing > 0:
                inverted = torch.nn.functional.pad(inverted, (0, 0, 0, 0, 0, missing), mode='constant', value=0)
            mask_channel = utils.resize_to_batch_size(inverted, noise.shape[0])
        else:
            mask_channel = torch.zeros_like(noise)[:, :1]

        return torch.cat((zero_image, mask_channel), dim=1)

    def extra_conds(self, **kwargs):
        """Add attention mask, text conditioning and the processed time-replacement latent."""
        out = super().extra_conds(**kwargs)
        for source_key, target_key in (("attention_mask", "attention_mask"), ("cross_attn", "c_crossattn")):
            value = kwargs.get(source_key, None)
            if value is not None:
                out[target_key] = conds.CONDRegular(value)

        replacement = kwargs.get("time_dim_replace", None)
        if replacement is not None:
            out['time_dim_replace'] = conds.CONDRegular(self.process_latent_in(replacement))
        return out


class Kandinsky5Image(Kandinsky5):
    """Kandinsky5 variant that supplies no channel-concat conditioning."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device)

    def concat_cond(self, **kwargs):
        # This variant never concatenates extra channels.
        return None