mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 01:52:59 +08:00
Fix for HunyuanVideo1.5 meanflow distil (#11212)
This commit is contained in:
parent
9d252f3b70
commit
e2a800e7ef
@ -43,6 +43,7 @@ class HunyuanVideoParams:
|
|||||||
meanflow: bool
|
meanflow: bool
|
||||||
use_cond_type_embedding: bool
|
use_cond_type_embedding: bool
|
||||||
vision_in_dim: int
|
vision_in_dim: int
|
||||||
|
meanflow_sum: bool
|
||||||
|
|
||||||
|
|
||||||
class SelfAttentionRef(nn.Module):
|
class SelfAttentionRef(nn.Module):
|
||||||
@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||||
vec = (vec + vec_r) / 2
|
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||||
|
|
||||||
if ref_latent is not None:
|
if ref_latent is not None:
|
||||||
ref_latent_ids = self.img_ids(ref_latent)
|
ref_latent_ids = self.img_ids(ref_latent)
|
||||||
|
|||||||
@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["use_cond_type_embedding"] = False
|
dit_config["use_cond_type_embedding"] = False
|
||||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||||
|
dit_config["meanflow_sum"] = True
|
||||||
else:
|
else:
|
||||||
dit_config["vision_in_dim"] = None
|
dit_config["vision_in_dim"] = None
|
||||||
|
dit_config["meanflow_sum"] = False
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user