Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-15 16:02:32 +08:00)

Merge branch 'master' into feat/core/expected_outputs

Some checks failed:
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

Commit 06408af600
@@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module):
         return encoder_hidden, encoder_mask, context_latents
 
-    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
+    def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
         text_attention_mask = None
         lyric_attention_mask = None
         refer_audio_order_mask = None

@@ -1140,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module):
             src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
         )
 
+        if replace_with_null_embeds:
+            enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
+
         out = self.decoder(hidden_states=x,
                            timestep=timestep,
                            timestep_r=timestep,

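For readers following the ACE-Step change: the new `replace_with_null_embeds` flag lets a caller ask the model to discard the computed conditioning and substitute its learned null-condition embedding, which is how an all-zero negative prompt is handled further down in this commit. A minimal sketch of the swap, with illustrative shapes:

import torch

# Illustrative shapes; enc_hidden is the packed conditioning from the encoder,
# and null_condition_emb is a learned nn.Parameter in the real model.
enc_hidden = torch.randn(2, 256, 1024)        # (batch, seq, dim)
null_condition_emb = torch.zeros(1024)

replace_with_null_embeds = True
if replace_with_null_embeds:
    # In-place fill: every position now carries the null condition;
    # .to(enc_hidden) matches its dtype and device in one call.
    enc_hidden[:] = null_condition_emb.to(enc_hidden)
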
@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
         device=None, dtype=None, operations=None
     ):
         super().__init__()
-        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.linear = operations.Linear(
             hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
         )

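Swapping `nn.LayerNorm` for `operations.LayerNorm` routes the norm through ComfyUI's operations factory, so it participates in the same dtype and device handling as the model's other layers. As a rough illustration of the idea only (not ComfyUI's actual `comfy.ops` code), a cast-on-forward LayerNorm might look like:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CastLayerNorm(nn.LayerNorm):
    # Illustration only: weights (if any) are cast to the input's dtype
    # at forward time instead of being fixed at init.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(x.dtype) if self.weight is not None else None
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.layer_norm(x, self.normalized_shape, weight, bias, self.eps)

norm = CastLayerNorm(1024, elementwise_affine=False, eps=1e-6)
out = norm(torch.randn(2, 16, 1024, dtype=torch.float16))
print(out.dtype)  # torch.float16
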
@@ -463,6 +463,8 @@ class Block(nn.Module):
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
         transformer_options: Optional[dict] = {},
     ) -> torch.Tensor:
+        residual_dtype = x_B_T_H_W_D.dtype
+        compute_dtype = emb_B_T_D.dtype
         if extra_per_block_pos_emb is not None:
             x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
 

@@ -512,7 +514,7 @@ class Block(nn.Module):
         result_B_T_H_W_D = rearrange(
             self.self_attn(
                 # normalized_x_B_T_HW_D,
-                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                 None,
                 rope_emb=rope_emb_L_1_1_D,
                 transformer_options=transformer_options,

@@ -522,7 +524,7 @@ class Block(nn.Module):
             h=H,
             w=W,
         )
-        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
+        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
 
         def _x_fn(
             _x_B_T_H_W_D: torch.Tensor,

@@ -536,7 +538,7 @@ class Block(nn.Module):
             )
             _result_B_T_H_W_D = rearrange(
                 self.cross_attn(
-                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                     crossattn_emb,
                     rope_emb=rope_emb_L_1_1_D,
                     transformer_options=transformer_options,

@@ -555,7 +557,7 @@ class Block(nn.Module):
             shift_cross_attn_B_T_1_1_D,
             transformer_options=transformer_options,
         )
-        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
+        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
 
         normalized_x_B_T_H_W_D = _fn(
             x_B_T_H_W_D,

@@ -563,8 +565,8 @@ class Block(nn.Module):
             scale_mlp_B_T_1_1_D,
             shift_mlp_B_T_1_1_D,
         )
-        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
-        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
+        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
+        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
         return x_B_T_H_W_D
 
 

@@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
             "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             "transformer_options": kwargs.get("transformer_options", {}),
         }
+
+        # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
+        # in fp32, but run attention and MLP modules in fp16.
+        # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
+        # quality degradation and visual artifacts.
+        if x_B_T_H_W_D.dtype == torch.float16:
+            x_B_T_H_W_D = x_B_T_H_W_D.float()
+
         for block in self.blocks:
             x_B_T_H_W_D = block(
                 x_B_T_H_W_D,

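The comment added above is the heart of this commit's fp16 work: the residual stream is promoted to fp32 once before the block loop, while each attention and MLP call runs in the fp16 compute dtype and its gated result is cast back for the residual add (the `.to(compute_dtype)` / `.to(residual_dtype)` casts threaded through `Block` above). A condensed sketch of one gated residual step under those assumptions, with a plain `nn.Linear` standing in for the submodule:

import torch
import torch.nn as nn

# Stand-in for an attention or MLP submodule whose weights are fp16.
module = nn.Linear(64, 64).half()
x = torch.randn(1, 8, 64, dtype=torch.float16)    # incoming fp16 activations
gate = torch.rand(1, 8, 64, dtype=torch.float16)  # adaLN gate, fp16

# Promote the residual stream to fp32 once, as the hunk above does
# before the block loop.
if x.dtype == torch.float16:
    x = x.float()
residual_dtype = x.dtype          # torch.float32
compute_dtype = torch.float16     # dtype the submodules run in

# One gated residual step: compute in fp16, accumulate in fp32.
result = module(x.to(compute_dtype))
x = x + gate.to(residual_dtype) * result.to(residual_dtype)
print(x.dtype)  # torch.float32
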
@@ -884,6 +894,6 @@ class MiniTrainDIT(nn.Module):
                 **block_kwargs,
             )
 
-        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
+        x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
         x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
         return x_B_C_Tt_Hp_Wp

@@ -1552,6 +1552,8 @@ class ACEStep15(BaseModel):
 
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
+            if torch.count_nonzero(cross_attn) == 0:
+                out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
 
         conditioning_lyrics = kwargs.get("conditioning_lyrics", None)

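On the sampling side, the commit detects the degenerate conditioning case: an all-zero cross-attention tensor (for example, the output of a ConditioningZeroOut node) carries no information, so the model is told to use its null embedding instead. The check in isolation:

import torch

cross_attn = torch.zeros(1, 77, 1024)  # what a zeroed-out prompt's conditioning looks like

# count_nonzero scans the whole tensor; zero non-zeros means "no real conditioning".
replace_with_null_embeds = torch.count_nonzero(cross_attn) == 0
print(bool(replace_with_null_embeds))  # True
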
@@ -1575,6 +1577,10 @@ class ACEStep15(BaseModel):
             else:
                 out['is_covers'] = comfy.conds.CONDConstant(False)
 
+            if refer_audio.shape[2] < noise.shape[2]:
+                pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
+                refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
+
             out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
         return out

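The padding branch handles a reference-audio latent that is shorter than the noise along the time axis, topping it up with a precomputed silence latent rather than zeros. A standalone sketch with illustrative shapes and `get_silence_latent` stubbed out (the real helper lives in `comfy.ldm.ace.ace_step15`):

import torch

def get_silence_latent(length: int, device) -> torch.Tensor:
    # Stub: the real helper returns the latent encoding of silence at this length.
    return torch.zeros(1, 8, length, device=device)

noise = torch.randn(1, 8, 1000)
refer_audio = torch.randn(1, 8, 600)

if refer_audio.shape[2] < noise.shape[2]:
    pad = get_silence_latent(noise.shape[2], noise.device)
    # Keep the real reference up front, fill the remainder with silence.
    refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)

print(refer_audio.shape)  # torch.Size([1, 8, 1000])
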
@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
 
     memory_usage_factor = 1.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def __init__(self, unet_config):
         super().__init__(unet_config)

@@ -1023,11 +1023,7 @@ class Anima(supported_models_base.BASE):
 
     memory_usage_factor = 1.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
-    def __init__(self, unet_config):
-        super().__init__(unet_config)
-        self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
-
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Anima(self, device=device)

@@ -1038,6 +1034,12 @@ class Anima(supported_models_base.BASE):
         detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
 
+    def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
+        self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
+        if dtype is torch.float16:
+            self.memory_usage_factor *= 1.4
+        return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
+
 class CosmosI2VPredict2(CosmosT2IPredict2):
     unet_config = {
         "image_model": "cosmos_predict2",

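Moving the factor out of `__init__` and into `set_inference_dtype` lets memory estimation react to the selected dtype; fp16 gets a 1.4x bump, presumably to cover the fp32 residual stream the DiT now keeps. A worked example with a hypothetical `model_channels` value (not taken from the diff):

model_channels = 3072                     # hypothetical checkpoint size
factor = (model_channels / 2048) * 0.95   # base factor: 1.425
factor *= 1.4                             # fp16 selected: ~1.995
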
@@ -23,7 +23,7 @@ class AnimaTokenizer:
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out = {}
         qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
-        out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
+        out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
         out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
         return out

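The tokenizer fix addresses a shape mismatch: entries from `tokenize_with_weights` are `(token, weight)` pairs, or `(token, weight, word_id)` triples when `return_word_ids` is set, and the old comprehension always rebuilt 2-tuples, silently dropping the word ids. Illustrated on a dummy token list:

# Entries are (token, weight) or (token, weight, word_id) depending on return_word_ids.
qwen_ids = [[(101, 0.8, 0), (2054, 1.2, 1)]]
return_word_ids = True

# Reset every weight to 1.0 while preserving word ids when they exist.
out = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list]
       for inner_list in qwen_ids]
print(out)  # [[(101, 1.0, 0), (2054, 1.0, 1)]]
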
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
 class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}

@@ -622,6 +622,7 @@ class SamplerSASolver(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="SamplerSASolver",
+            search_aliases=["sde"],
             category="sampling/custom_sampling/samplers",
             inputs=[
                 io.Model.Input("model"),

@@ -666,6 +667,7 @@ class SamplerSEEDS2(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="SamplerSEEDS2",
+            search_aliases=["sde", "exp heun"],
             category="sampling/custom_sampling/samplers",
             inputs=[
                 io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),

@@ -108,7 +108,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
     easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
     if easycache.is_past_end_timestep(timestep):
         return executor(*args, **kwargs)
-    x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
+    x: torch.Tensor = args[0][:, :easycache.output_channels]
     # prepare next x_prev
     next_x_prev = x
     input_change = None

@@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
             latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
             normalized_latent = latent / latent_vector_magnitude
 
-            mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
-            std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
+            dims = list(range(1, latent_vector_magnitude.ndim))
+            mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
+            std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
 
             top = (std * 5 + mean) * multiplier

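The hardcoded `dim=(1,2,3)` assumed a 4-D image latent; for higher-rank inputs such as 5-D video latents it reduced over only part of the tensor. Deriving the dims from `ndim` reduces over every non-batch dimension regardless of rank:

import torch

for latent_vector_magnitude in (torch.rand(2, 1, 32, 32),       # image: B,1,H,W
                                torch.rand(2, 1, 9, 32, 32)):   # video: B,1,T,H,W
    dims = list(range(1, latent_vector_magnitude.ndim))         # all non-batch dims
    mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
    std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
    print(mean.shape, std.shape)  # per-sample stats, rank preserved
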
@@ -1,6 +1,6 @@
 comfyui-frontend-package==1.38.13
 comfyui-workflow-templates==0.8.31
-comfyui-embedded-docs==0.4.0
+comfyui-embedded-docs==0.4.1
 torch
 torchsde
 torchvision