diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 09c367f29..68ba27f5e 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -273,7 +273,7 @@ class HunyuanVideo(nn.Module):

         # HunyuanVideo 1.5 specific modules
         if self.vision_in_dim is not None:
-            from comfy.ldm.wan.model import MLPProj # todo move
+            from comfy.ldm.wan.model import MLPProj
             self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
         else:
             self.vision_in = None
@@ -293,7 +293,7 @@ class HunyuanVideo(nn.Module):
         timesteps: Tensor,
         y: Tensor = None,
         txt_byt5=None,
-        vision_states=None, #todo hunyuan video 1.5 vision encoder states input
+        clip_fea=None,
         guidance: Tensor = None,
         guiding_frame_index=None,
         ref_latent=None,
@@ -365,13 +365,11 @@ class HunyuanVideo(nn.Module):
             txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
             txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)

-        #todo vision_in
-        if vision_states is not None:
-            txt_vision_states = self.vision_in(vision_states)
+        if clip_fea is not None:
+            txt_vision_states = self.vision_in(clip_fea)
             if self.cond_type_embedding is not None:
                 cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
                 txt_vision_states = txt_vision_states + cond_emb
-            #print("txt_vision_states shape:", txt_vision_states.shape)
             txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
             extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
             txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
@@ -469,14 +467,14 @@ class HunyuanVideo(nn.Module):
         img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
         return repeat(img_ids, "h w c -> b (h w) c", b=bs)

-    def forward(self, x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
             self._forward,
             self,
             comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, txt_byt5, vision_states, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)

-    def _forward(self, x, timestep, context, y=None, txt_byt5=None, vision_states=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
         bs = x.shape[0]
         if len(self.patch_size) == 3:
             img_ids = self.img_ids(x)
@@ -484,5 +482,5 @@ class HunyuanVideo(nn.Module):
         else:
             img_ids = self.img_ids_2d(x)
             txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, vision_states, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
         return out
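For orientation, the renamed `clip_fea` input follows the same path as before: `vision_in` (the `MLPProj` imported from the Wan model) projects the CLIP vision features to the transformer width, an optional condition-type embedding is added, and the result is prepended to the text tokens with matching zero position ids. The sketch below is a minimal stand-alone illustration only: the `nn.Sequential` is a hand-rolled stand-in for `MLPProj`, and all dimensions are invented.

```python
# Illustrative sketch of the vision-token path; the projection is a
# stand-in for comfy.ldm.wan.model.MLPProj and all sizes are assumptions.
import torch
import torch.nn as nn

hidden_size = 128       # assumed transformer width
vision_in_dim = 64      # assumed CLIP vision feature width

vision_in = nn.Sequential(
    nn.LayerNorm(vision_in_dim),
    nn.Linear(vision_in_dim, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)

txt = torch.randn(1, 10, hidden_size)        # text tokens
clip_fea = torch.randn(1, 4, vision_in_dim)  # CLIP vision tokens

txt_vision_states = vision_in(clip_fea)
# Vision tokens go in front of the text tokens; txt_ids is extended
# with zeros of the same length in forward_orig.
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
print(txt.shape)  # torch.Size([1, 14, 128])
```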
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 9549693b3..c080dd0ea 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1537,7 +1537,7 @@ class HunyuanImage21Refiner(HunyuanImage21):
         out['disable_time_r'] = comfy.conds.CONDConstant(True)
         return out

-class HunyuanVideo15(HunyuanImage21):
+class HunyuanVideo15(HunyuanVideo):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device)

@@ -1565,7 +1565,34 @@ class HunyuanVideo15(HunyuanImage21):
             if mask is None:
                 mask = torch.zeros_like(noise)[:, :1]
         else:
-            mask = torch.zeros_like(noise)[:, :1]
-            mask[:, :, 1:] = 1.0
+            mask = 1.0 - mask
+            mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+            if mask.shape[-3] < noise.shape[-3]:
+                mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+            mask = utils.resize_to_batch_size(mask, noise.shape[0])
         return torch.cat((image, mask), dim=1)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            if torch.numel(attention_mask) != attention_mask.sum():
+                out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
+        if conditioning_byt5small is not None:
+            out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
+
+        guidance = kwargs.get("guidance", 6.0)
+        if guidance is not None:
+            out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+        clip_vision_output = kwargs.get("clip_vision_output", None)
+        if clip_vision_output is not None:
+            out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
+
+        return out
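The new `concat_cond` branch builds the inpaint channels as follows: invert the mask (so 1.0 marks regions to regenerate), bilinearly resize it to the latent's spatial size, zero-pad it along time to the latent frame count, match the batch size, and concatenate it to the image latent on the channel axis. Below is a rough stand-alone equivalent, with plain torch calls standing in for `utils.common_upscale` and `utils.resize_to_batch_size` and invented shapes:

```python
# Sketch of the mask construction; torch calls stand in for the comfy
# utils helpers, and all shapes are illustrative assumptions.
import torch
import torch.nn.functional as F

b, c, t, h, w = 1, 32, 8, 30, 40          # assumed latent layout (b, c, t, h, w)
noise = torch.randn(b, c, t, h, w)
image = torch.zeros_like(noise)
mask = torch.rand(1, 1, 2, 240, 320)      # assumed input mask: fewer frames, pixel resolution

mask = 1.0 - mask                         # invert: 1.0 = denoise, 0.0 = keep
# Resize spatially to the latent resolution (stand-in for utils.common_upscale):
bm, cm, tm = mask.shape[:3]
m = mask.reshape(bm * cm * tm, 1, mask.shape[-2], mask.shape[-1])
m = F.interpolate(m, size=(h, w), mode="bilinear", align_corners=False)
mask = m.reshape(bm, cm, tm, h, w)
# Zero-pad the time axis up to the latent frame count:
if mask.shape[-3] < noise.shape[-3]:
    mask = F.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode="constant", value=0)
# Match the batch size (stand-in for utils.resize_to_batch_size):
mask = mask.expand(noise.shape[0], -1, -1, -1, -1)

cond = torch.cat((image, mask), dim=1)
print(cond.shape)  # torch.Size([1, 33, 8, 30, 40])
```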
diff --git a/comfy/sd.py b/comfy/sd.py
index 4755f2111..1bf5bbc0a 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -441,13 +441,13 @@ class VAE:
         elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
             ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
-            self.latent_channels = 64
+            self.latent_channels = 32
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
             self.upscale_index_formula = (4, 16, 16)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
             self.downscale_index_formula = (4, 16, 16)
             self.latent_dim = 3
-            self.not_video = True
+            self.not_video = False
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
             self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
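The corrected VAE config makes `latent_channels` match the `z_channels` inferred from the checkpoint (the `elif` requires `shape[1] == 32`) and treats this as a video autoencoder (`not_video = False`) with 16x spatial compression and a causal 4x temporal factor: `t` latent frames decode to `4*t - 3` video frames, and encoding rounds up. A quick sanity check of the two temporal lambdas from this hunk:

```python
# Round-trip check of the temporal ratio lambdas set in this hunk.
import math

def latent_to_frames(t):
    return max(0, t * 4 - 3)                  # upscale_ratio's temporal lambda

def frames_to_latent(f):
    return max(0, math.floor((f + 3) / 4))    # downscale_ratio's temporal lambda

for f in (1, 4, 5, 33, 121):
    t = frames_to_latent(f)
    print(f"{f} frames -> {t} latent frames -> {latent_to_frames(t)} frames")
```

Frame counts of the form 4k+1 round-trip exactly (1, 5, 33, 121 above); other counts lose their trailing frames on encode, which is the usual behavior for causal video VAEs with this factorization.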