Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data 2025-10-07 14:30:16 +09:00
commit 6b20418ad1
15 changed files with 1561 additions and 1576 deletions

View File

@ -23,8 +23,6 @@ class MusicDCAE(torch.nn.Module):
else:
self.source_sample_rate = source_sample_rate
# self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
self.transform = transforms.Compose([
transforms.Normalize(0.5, 0.5),
])
@ -37,10 +35,6 @@ class MusicDCAE(torch.nn.Module):
self.scale_factor = 0.1786
self.shift_factor = -1.9091
def load_audio(self, audio_path):
audio, sr = torchaudio.load(audio_path)
return audio, sr
def forward_mel(self, audios):
mels = []
for i in range(len(audios)):
@ -73,10 +67,8 @@ class MusicDCAE(torch.nn.Module):
latent = self.dcae.encoder(mel.unsqueeze(0))
latents.append(latent)
latents = torch.cat(latents, dim=0)
# latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
latents = (latents - self.shift_factor) * self.scale_factor
return latents
# return latents, latent_lengths
@torch.no_grad()
def decode(self, latents, audio_lengths=None, sr=None):
@ -91,9 +83,7 @@ class MusicDCAE(torch.nn.Module):
wav = self.vocoder.decode(mels[0]).squeeze(1)
if sr is not None:
# resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
wav = torchaudio.functional.resample(wav, 44100, sr)
# wav = resampler(wav)
else:
sr = 44100
pred_wavs.append(wav)
@ -101,7 +91,6 @@ class MusicDCAE(torch.nn.Module):
if audio_lengths is not None:
pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
return torch.stack(pred_wavs)
# return sr, pred_wavs
def forward(self, audios, audio_lengths=None, sr=None):
latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)

View File

@ -365,8 +365,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["patch_size"] = 2
dit_config["in_channels"] = 16
dit_config["dim"] = 2304
dit_config["cap_feat_dim"] = 2304
dit_config["n_layers"] = 26
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True

View File

@ -890,6 +890,7 @@ class TEModel(Enum):
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -912,6 +913,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@ -1016,6 +1019,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)

View File

@ -3,6 +3,7 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
import logging
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@ -28,6 +29,9 @@ class Llama2Config:
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_3BConfig:
@ -46,6 +50,9 @@ class Qwen25_3BConfig:
mlp_activation = "silu"
qkv_bias = True
rope_dims = None
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Qwen25_7BVLI_Config:
@ -64,6 +71,9 @@ class Qwen25_7BVLI_Config:
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]
q_norm = None
k_norm = None
rope_scale = None
@dataclass
class Gemma2_2B_Config:
@ -82,6 +92,32 @@ class Gemma2_2B_Config:
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = None
k_norm = None
sliding_attention = None
rope_scale = None
@dataclass
class Gemma3_4B_Config:
vocab_size: int = 262208
hidden_size: int = 2560
intermediate_size: int = 10240
num_hidden_layers: int = 34
num_attention_heads: int = 8
num_key_value_heads: int = 4
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [10000.0, 1000000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [False, False, False, False, False, 1024]
rope_scale = [1.0, 8.0]
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -106,25 +142,40 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out = []
for index, t in enumerate(theta):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (t ** (theta_numerator / head_dim))
return (cos, sin)
if rope_scale is not None:
if isinstance(rope_scale, list):
inv_freq /= rope_scale[index]
else:
inv_freq /= rope_scale
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
out.append((cos, sin))
if len(out) == 1:
return out[0]
return out
def apply_rope(xq, xk, freqs_cis):
@ -152,6 +203,14 @@ class Attention(nn.Module):
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
self.q_norm = None
self.k_norm = None
if config.q_norm == "gemma3":
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.k_norm == "gemma3":
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
@ -168,6 +227,11 @@ class Attention(nn.Module):
xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
if self.q_norm is not None:
xq = self.q_norm(xq)
if self.k_norm is not None:
xk = self.k_norm(xk)
xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@ -192,7 +256,7 @@ class MLP(nn.Module):
return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -226,7 +290,7 @@ class TransformerBlock(nn.Module):
return x
class TransformerBlockGemma2(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
@ -235,6 +299,13 @@ class TransformerBlockGemma2(nn.Module):
self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
if config.sliding_attention is not None: # TODO: implement. (Not that necessary since models are trained on less than 1024 tokens)
self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
else:
self.sliding_attention = False
self.transformer_type = config.transformer_type
def forward(
self,
x: torch.Tensor,
@ -242,6 +313,14 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
):
if self.transformer_type == 'gemma3':
if self.sliding_attention:
if x.shape[1] > self.sliding_attention:
logging.warning("Warning: sliding attention not implemented, results may be incorrect")
freqs_cis = freqs_cis[1]
else:
freqs_cis = freqs_cis[0]
# Self Attention
residual = x
x = self.input_layernorm(x)
@ -276,7 +355,7 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2":
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
else:
@ -284,8 +363,8 @@ class Llama2_(nn.Module):
self.normalize_in = False
self.layers = nn.ModuleList([
transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
@ -305,6 +384,7 @@ class Llama2_(nn.Module):
freqs_cis = precompute_freqs_cis(self.config.head_dim,
position_ids,
self.config.rope_theta,
self.config.rope_scale,
self.config.rope_dims,
device=x.device)
@ -433,3 +513,12 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype

View File

@ -11,23 +11,41 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
class NTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer)
class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_scaled_fp8=None):
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
@ -35,5 +53,5 @@ def te(dtype_llama=None, llama_scaled_fp8=None):
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
return LuminaTEModel_

View File

@ -152,7 +152,7 @@ def validate_aspect_ratio(
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
elif calculated_ratio > maximum_ratio:
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)

File diff suppressed because it is too large Load Diff

View File

@ -473,7 +473,7 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
height=width_height["height"],
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
@ -591,7 +591,6 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
# Only include motion_intensity for Motion Transfer

File diff suppressed because it is too large Load Diff

View File

@ -146,7 +146,7 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
@ -284,7 +284,7 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),
@ -425,7 +425,7 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
comfy_io.String.Input(
"negative_prompt",
default="",
force_input=True,
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
optional=True,
),

View File

@ -107,7 +107,7 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
# if list already exists exists, just extend list with data
for check_list in lists_to_check:
for conv_tuple in check_list:
if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list:
if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list):
conv_tuple[1].append(formatter(data))
return True
return False
@ -119,7 +119,7 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
if formatter is None:
formatter = lambda v: v # Multipart representation of value
if type(data) is not dict:
if not isinstance(data, dict):
# if list already exists exists, just extend list with data
added = handle_converted_lists(data, parent_key, converted_to_check)
if added:
@ -136,9 +136,9 @@ def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, co
for key, value in data.items():
current_key = key if parent_key is None else f"{parent_key}[{key}]"
if type(value) is dict:
if isinstance(value, dict):
converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
elif type(value) is list:
elif isinstance(value, list):
for ind, list_value in enumerate(value):
iter_key = f"{current_key}[]"
converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())

View File

@ -360,7 +360,7 @@ class RecordAudio:
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.62"
__version__ = "0.3.63"

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.62"
version = "0.3.63"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
@ -57,18 +57,14 @@ messages_control.disable = [
"redefined-builtin",
"unnecessary-lambda",
"dangerous-default-value",
"invalid-overridden-method",
# next warnings should be fixed in future
"bad-classmethod-argument", # Class method should have 'cls' as first argument
"wrong-import-order", # Standard imports should be placed before third party imports
"logging-fstring-interpolation", # Use lazy % formatting in logging functions
"ungrouped-imports",
"unnecessary-pass",
"unidiomatic-typecheck",
"unnecessary-lambda-assignment",
"no-else-return",
"no-else-raise",
"invalid-overridden-method",
"unused-variable",
"pointless-string-statement",
"redefined-outer-name",
]

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.27.7
comfyui-workflow-templates==0.1.91
comfyui-workflow-templates==0.1.93
comfyui-embedded-docs==0.2.6
comfyui_manager==4.0.2
torch
@ -26,6 +26,5 @@ av>=14.2.0
#non essential dependencies:
kornia>=0.7.1
spandrel
soundfile
pydantic~=2.0
pydantic-settings~=2.0