Merge branch 'comfyanonymous:master' into master

This commit is contained in:
RandomGitUser321 2025-08-21 11:15:06 -04:00 committed by GitHub
commit 0a320efbae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 1825 additions and 77 deletions

View File

@ -71,6 +71,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)

View File

@ -363,10 +363,17 @@ class UserManager():
if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists")
body = await request.read()
try:
body = await request.read()
with open(path, "wb") as f:
f.write(body)
with open(path, "wb") as f:
f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
status=400,
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
)
user_path = self.get_request_user_filepath(request, None)
if full_info:

View File

@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else:

View File

@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0])
self._step = int(indexes[0])
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model

View File

@ -224,19 +224,27 @@ class Flux(nn.Module):
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
for ref in ref_latents:
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
h_offset = h
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))

View File

@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False
def pos_embeds(self, x, context):
def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
txt_start = round(max(h_len, w_len))
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward(
self,
@ -356,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
transformer_options={},
**kwargs
):
@ -363,13 +368,39 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context)
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@ -385,6 +416,7 @@ class QwenImageTransformer2DModel(nn.Module):
)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
@ -405,9 +437,15 @@ class QwenImageTransformer2DModel(nn.Module):
image_rotary_emb=image_rotary_emb,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -768,7 +768,12 @@ class CameraWanModel(WanModel):
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
if model_type == 'camera':
model_type = 'i2v'
else:
model_type = 't2v'
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)

View File

@ -890,6 +890,10 @@ class Flux(BaseModel):
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
def extra_conds_shapes(self, **kwargs):
@ -1321,10 +1325,28 @@ class Omnigen2(BaseModel):
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out

View File

@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
dit_config["model_type"] = "camera_2.2"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models)
models_temp = set()
for m in models:
models_temp.add(m)
for mm in m.model_patches_models():
models_temp.add(mm)
models = models_temp
models_to_load = []

View File

@ -430,6 +430,9 @@ class ModelPatcher:
def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch")
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
@ -486,6 +489,30 @@ class ModelPatcher:
if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_patches_models(self):
to = self.model_options["transformer_options"]
models = []
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "models"):
models += patch_list[i].models()
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "models"):
models += patch_list[k].models()
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "models"):
models += wrap_func.models()
return models
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()

View File

@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0
pad_extra = 0
embeds_info = []
for o in other_embeds:
emb = o[1]
if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb}
extra = None
emb_type = emb.get("type", None)
if emb_type == "embedding":
emb = emb.get("data", None)
else:
if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device)
emb, extra = self.transformer.preprocess_embed(emb, device=device)
else:
emb = None
@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
else:
index += -1
pad_extra += emb_shape
@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask))
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else:
intermediate_output = self.layer_idx
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
if self.layer == "last":
z = outputs[0].float()
@ -531,7 +534,10 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
if kwargs.get("disable_weights", False):
parsed_weights = [(text, 1.0)]
else:
parsed_weights = token_weights(text, 1.0)
# tokenize words
tokens = []

View File

@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN22_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera_2.2",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:

View File

@ -2,12 +2,14 @@ import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
import comfy.model_management
from . import qwen_vl
@dataclass
class Llama2Config:
@ -25,6 +27,7 @@ class Llama2Config:
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
@dataclass
class Qwen25_3BConfig:
@ -42,6 +45,7 @@ class Qwen25_3BConfig:
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
rope_dims = None
@dataclass
class Qwen25_7BVLI_Config:
@ -59,6 +63,7 @@ class Qwen25_7BVLI_Config:
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = True
rope_dims = [16, 24, 24]
@dataclass
class Gemma2_2B_Config:
@ -76,6 +81,7 @@ class Gemma2_2B_Config:
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -100,24 +106,30 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
return (cos, sin)
def apply_rope(xq, xk, freqs_cis):
cos = freqs_cis[0].unsqueeze(1)
sin = freqs_cis[1].unsqueeze(1)
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
@ -277,7 +289,7 @@ class Llama2_(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
if embeds is not None:
x = embeds
else:
@ -286,9 +298,13 @@ class Llama2_(nn.Module):
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
if position_ids is None:
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim,
x.shape[1],
position_ids,
self.config.rope_theta,
self.config.rope_dims,
device=x.device)
mask = None
@ -372,8 +388,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
grid = None
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
if grid is None:
position_ids = None
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs):
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 151655:
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):

View File

@ -0,0 +1,428 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
def process_qwen2vl_images(
images: torch.Tensor,
min_pixels: int = 3136,
max_pixels: int = 12845056,
patch_size: int = 14,
temporal_patch_size: int = 2,
merge_size: int = 2,
image_mean: list = None,
image_std: list = None,
):
if image_mean is None:
image_mean = [0.48145466, 0.4578275, 0.40821073]
if image_std is None:
image_std = [0.26862954, 0.26130258, 0.27577711]
batch_size, height, width, channels = images.shape
device = images.device
# dtype = images.dtype
images = images.permute(0, 3, 1, 2)
grid_thw_list = []
img = images[0]
factor = patch_size * merge_size
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
img_resized = F.interpolate(
img.unsqueeze(0),
size=(h_bar, w_bar),
mode='bilinear',
align_corners=False
).squeeze(0)
normalized = img_resized.clone()
for c in range(3):
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
grid_h = h_bar // patch_size
grid_w = w_bar // patch_size
grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long)
pixel_values = normalized
grid_thw_list.append(grid_thw)
image_grid_thw = torch.stack(grid_thw_list)
grid_t = 1
channel = pixel_values.shape[0]
pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1)
patches = pixel_values.reshape(
grid_t,
temporal_patch_size,
channel,
grid_h // merge_size,
merge_size,
patch_size,
grid_w // merge_size,
merge_size,
patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w,
channel * temporal_patch_size * patch_size * patch_size
)
return flatten_patches, image_grid_thw
class VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
embed_dim: int = 3584,
device=None,
dtype=None,
ops=None,
):
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = ops.Conv3d(
in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
device=device,
dtype=dtype
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
)
hidden_states = self.proj(hidden_states)
return hidden_states.view(-1, self.embed_dim)
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(q, k, cos, sin):
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0):
super().__init__()
self.dim = dim
self.theta = theta
def forward(self, seqlen: int, device) -> torch.Tensor:
inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.outer(seq, inv_freq)
return freqs
class PatchMerger(nn.Module):
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size ** 2)
self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
self.mlp = nn.Sequential(
ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
nn.GELU(),
ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x).reshape(-1, self.hidden_size)
x = self.mlp(x)
return x
class VisionAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scaling = self.head_dim ** -0.5
self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cu_seqlens=None,
optimized_attention=None,
) -> torch.Tensor:
if hidden_states.dim() == 2:
seq_length, _ = hidden_states.shape
batch_size = 1
hidden_states = hidden_states.unsqueeze(0)
else:
batch_size, seq_length, _ = hidden_states.shape
qkv = self.qkv(hidden_states)
qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
class VisionMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
super().__init__()
self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
self.act_fn = nn.SiLU()
def forward(self, hidden_state):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
class VisionBlock(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cu_seqlens=None,
optimized_attention=None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen2VLVisionTransformer(nn.Module):
def __init__(
self,
hidden_size: int = 3584,
output_hidden_size: int = 3584,
intermediate_size: int = 3420,
num_heads: int = 16,
num_layers: int = 32,
patch_size: int = 14,
temporal_patch_size: int = 2,
spatial_merge_size: int = 2,
window_size: int = 112,
device=None,
dtype=None,
ops=None
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.window_size = window_size
self.fullatt_block_indexes = [7, 15, 23, 31]
self.patch_embed = VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=3,
embed_dim=hidden_size,
device=device,
dtype=dtype,
ops=ops,
)
head_dim = hidden_size // num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
for _ in range(num_layers)
])
self.merger = PatchMerger(
dim=output_hidden_size,
context_dim=hidden_size,
spatial_merge_size=spatial_merge_size,
device=device,
dtype=dtype,
ops=ops,
)
def get_window_index(self, grid_thw):
window_index = []
cu_window_seqlens = [0]
window_index_id = 0
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h = grid_h // self.spatial_merge_size
llm_grid_w = grid_w // self.spatial_merge_size
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
def get_position_embeddings(self, grid_thw, device):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
return rotary_pos_emb_full[pos_ids].flatten(1)
def forward(
self,
pixel_values: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
hidden_states = self.patch_embed(pixel_values)
window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)
seq_len, _ = hidden_states.size()
spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
position_embeddings = position_embeddings[window_index, :, :]
position_embeddings = position_embeddings.reshape(seq_len, -1)
position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
position_embeddings = (position_embeddings.cos(), position_embeddings.sin())
cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
dim=0,
dtype=torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for i, block in enumerate(self.blocks):
if i in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states)
return hidden_states

View File

@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module):
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])

View File

@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO):
class AnyType(ComfyTypeIO):
Type = Any
@comfytype(io_type="MODEL_PATCH")
class MODEL_PATCH(ComfyTypeIO):
Type = Any
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any

View File

@ -1315,6 +1315,7 @@ class KlingTaskStatus(str, Enum):
class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
class KlingVideoGenAspectRatio(str, Enum):
@ -1347,6 +1348,8 @@ class KlingVideoGenModelName(str, Enum):
kling_v1_5 = 'kling-v1-5'
kling_v1_6 = 'kling-v1-6'
kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
class KlingVideoResult(BaseModel):
@ -1620,13 +1623,14 @@ class MinimaxTaskResultResponse(BaseModel):
task_id: str = Field(..., description='The task ID being queried.')
class Model(str, Enum):
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class SubjectReferenceItem(BaseModel):
@ -1648,7 +1652,7 @@ class MinimaxVideoGenerationRequest(BaseModel):
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: Model = Field(
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
@ -1665,6 +1669,14 @@ class MinimaxVideoGenerationRequest(BaseModel):
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):

View File

@ -46,6 +46,8 @@ class GeminiModel(str, Enum):
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
gemini_2_5_pro = "gemini-2.5-pro"
gemini_2_5_flash = "gemini-2.5-flash"
def get_gemini_endpoint(
@ -97,7 +99,7 @@ class GeminiNode(ComfyNodeABC):
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
"default": GeminiModel.gemini_2_5_pro.value,
},
),
"seed": (

View File

@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase):
"pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
"standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
}
@classmethod

View File

@ -1,3 +1,4 @@
from inspect import cleandoc
from typing import Union
import logging
import torch
@ -10,7 +11,7 @@ from comfy_api_nodes.apis import (
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
Model
MiniMaxModel
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@ -84,7 +85,6 @@ class MinimaxTextToVideoNode:
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
OUTPUT_NODE = True
async def generate_video(
self,
@ -121,7 +121,7 @@ class MinimaxTextToVideoNode:
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=Model(model),
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
first_frame_image=image_url,
@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
OUTPUT_NODE = True
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
OUTPUT_NODE = True
class MinimaxHailuoVideoNode:
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation.",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"first_frame_image": (
IO.IMAGE,
{
"tooltip": "Optional image to use as the first frame to generate a video."
},
),
"prompt_optimizer": (
IO.BOOLEAN,
{
"tooltip": "Optimize prompt to improve generation quality when needed.",
"default": True,
},
),
"duration": (
IO.COMBO,
{
"tooltip": "The length of the output video in seconds.",
"default": 6,
"options": [6, 10],
},
),
"resolution": (
IO.COMBO,
{
"tooltip": "The dimensions of the video display. "
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
"default": "768P",
"options": ["768P", "1080P"],
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
async def generate_video(
self,
prompt_text,
seed=0,
first_frame_image: torch.Tensor=None, # used for ImageToVideo
prompt_optimizer=True,
duration=6,
resolution="768P",
model="MiniMax-Hailuo-02",
unique_id: Union[str, None]=None,
**kwargs,
):
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6:
raise Exception(
"When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds."
)
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
first_frame_image=image_url,
prompt_optimizer=prompt_optimizer,
duration=duration,
resolution=resolution,
),
auth_kwargs=kwargs,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
# A dictionary that contains all nodes you want to export with their names
@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = {
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MinimaxTextToVideoNode": "MiniMax Text to Video",
"MinimaxImageToVideoNode": "MiniMax Image to Video",
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
}

View File

@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum):
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
gpt_5 = "gpt-5"
gpt_5_mini = "gpt-5-mini"
gpt_5_nano = "gpt-5-nano"
class OpenAIDalle2(ComfyNodeABC):
@ -464,8 +467,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
path = "/proxy/openai/images/generations"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
files = []
if image is not None:
@ -484,14 +485,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"
img_binaries.append(img_binary)
if batch_size == 1:
files.append(("image", img_binary))
files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else:
files.append(("image[]", img_binary))
files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None:
if image is None:
@ -511,9 +509,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
# Build the operation
operation = SynchronousOperation(

View File

@ -0,0 +1,622 @@
import logging
from enum import Enum
from typing import Any, Callable, Optional, Literal, TypeVar
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import (
validate_aspect_ratio_closeness,
validate_image_dimensions,
validate_image_aspect_ratio_range,
get_number_of_images,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
R = TypeVar("R")
class VideoModelName(str, Enum):
vidu_q1 = 'viduq1'
class AspectRatio(str, Enum):
r_16_9 = "16:9"
r_9_16 = "9:16"
r_1_1 = "1:1"
class Resolution(str, Enum):
r_1080p = "1080p"
class MovementAmplitude(str, Enum):
auto = "auto"
small = "small"
medium = "medium"
large = "large"
class TaskCreationRequest(BaseModel):
model: VideoModelName = VideoModelName.vidu_q1
prompt: Optional[str] = Field(None, max_length=1500)
duration: Optional[Literal[5]] = 5
seed: Optional[int] = Field(0, ge=0, le=2147483647)
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
resolution: Optional[Resolution] = Resolution.r_1080p
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
class TaskStatus(str, Enum):
created = "created"
queueing = "queueing"
processing = "processing"
success = "success"
failed = "failed"
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: TaskStatus = Field(...)
created_at: str = Field(...)
code: Optional[int] = Field(None, description="Error code")
class TaskResult(BaseModel):
id: str = Field(..., description="Creation id")
url: str = Field(..., description="The URL of the generated results, valid for one hour")
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
class TaskStatusResponse(BaseModel):
state: TaskStatus = Field(...)
err_code: Optional[str] = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[TaskStatus.success.value],
failed_statuses=[TaskStatus.failed.value],
status_extractor=lambda response: response.state.value,
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()
def get_video_url_from_response(response) -> Optional[str]:
if response.creations:
return response.creations[0].url
return None
def get_video_from_response(response) -> TaskResult:
if not response.creations:
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
logging.info(error_msg)
raise RuntimeError(error_msg)
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
return response.creations[0]
async def execute_task(
vidu_endpoint: str,
auth_kwargs: Optional[dict[str, str]],
payload: TaskCreationRequest,
estimated_duration: int,
node_id: str,
) -> R:
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=vidu_endpoint,
method=HttpMethod.POST,
request_model=TaskCreationRequest,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if response.state == TaskStatus.failed:
error_msg = f"Vidu request failed. Code: {response.code}"
logging.error(error_msg)
raise RuntimeError(error_msg)
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=VIDU_GET_GENERATION_STATUS % response.task_id,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=estimated_duration,
node_id=node_id,
)
class ViduTextToVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="api node/video/Vidu",
description="Generate video from text prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduImageToVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation",
category="api node/video/Vidu",
description="Generate video from image and optional prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"image",
tooltip="An image to be used as the start frame of the generated video",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="A textual description for video generation",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduReferenceVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu",
description="Generate video from multiple images and prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"images",
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
images: torch.Tensor,
prompt: str,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
a = get_number_of_images(images)
if a > 7:
raise ValueError("Too many images, maximum allowed is 7.")
for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
images,
max_images=7,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation",
category="api node/video/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"first_frame",
tooltip="Start frame",
),
comfy_io.Image.Input(
"end_frame",
tooltip="End frame",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
first_frame: torch.Tensor,
end_frame: torch.Tensor,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = [
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
for frame in (first_frame, end_frame)
]
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
ViduTextToVideoNode,
ViduImageToVideoNode,
ViduReferenceVideoNode,
ViduStartEndToVideoNode,
]
async def comfy_entrypoint() -> ViduExtension:
return ViduExtension()

View File

@ -53,6 +53,53 @@ def validate_image_aspect_ratio(
)
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
a1, b1 = min_ratio
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
w, h = get_image_dimensions(image)
if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
return ar
def validate_aspect_ratio_closeness(
start_img,
end_img,
min_rel: float,
max_rel: float,
*,
strict: bool = False, # True => exclusive, False => inclusive
) -> None:
w1, h1 = get_image_dimensions(start_img)
w2, h2 = get_image_dimensions(end_img)
if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions")
ar1 = w1 / h1
ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.")
def validate_video_dimensions(
video: VideoInput,
min_width: Optional[int] = None,
@ -98,3 +145,9 @@ def validate_video_duration(
raise ValueError(
f"Video duration must be at most {max_duration}s, got {duration}s"
)
def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
return len(images)

View File

@ -346,6 +346,24 @@ class LoadAudio:
return "Invalid audio file: {}".format(audio)
return True
class RecordAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"audio": ("AUDIO_RECORD", {})}}
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio,
}
NODE_DISPLAY_NAME_MAPPINGS = {
@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
}

View File

@ -100,9 +100,28 @@ class FluxKontextImageScale:
return (image, )
class FluxKontextMultiReferenceLatentMethod:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale,
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
}

View File

@ -166,7 +166,7 @@ class LTXVAddGuide:
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1),
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,

View File

@ -0,0 +1,161 @@
import torch
import folder_paths
import comfy.utils
import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
super().__init__()
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
self.input_proj = operations.Linear(dim, dim)
self.act = torch.nn.GELU()
self.output_proj = operations.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
device=None, dtype=None, operations=None
):
super().__init__()
self.additional_in_dim = additional_in_dim
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
self.controlnet_blocks = torch.nn.ModuleList(
[
BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)
def process_input_latent_image(self, latent_image):
latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
patch_size = 2
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
return self.img_in(hidden_states)
def control_block(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
}}
RETURN_TYPES = ("MODEL_PATCH",)
FUNCTION = "load_model_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/loaders"
def load_model_patch(self, name):
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
class DiffSynthCnetPatch:
def __init__(self, model_patch, vae, image, strength, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.strength = strength
self.mask = mask
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
def encode_latent_cond(self, image):
latent_image = self.vae.encode(image)
if self.model_patch.model.additional_in_dim > 0:
if self.mask is None:
mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
else:
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
return torch.cat([latent_image, mask_], dim=1)
else:
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
block_index = kwargs.get("block_index")
if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
spacial_compression = self.vae.spacial_compression_encode()
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
comfy.model_management.load_models_gpu(loaded_models)
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
kwargs['img'] = img
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"image": ("IMAGE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"mask": ("MASK",)}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "diffsynth_controlnet"
EXPERIMENTAL = True
CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
model_patched = model.clone()
image = image[:, :, :, :3]
if mask is not None:
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if mask.ndim == 4:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
}

View File

@ -0,0 +1,48 @@
import node_helpers
import comfy.utils
import math
class TextEncodeQwenImageEdit:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
},
"optional": {"vae": ("VAE", ),
"image": ("IMAGE", ),}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, prompt, vae=None, image=None):
ref_latent = None
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
image = s.movedim(1, -1)
images = [image[:, :, :, :3]]
if vae is not None:
ref_latent = vae.encode(image[:, :, :, :3])
tokens = clip.tokenize(prompt, images=images)
conditioning = clip.encode_from_tokens_scheduled(tokens)
if ref_latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return (conditioning, )
NODE_CLASS_MAPPINGS = {
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
}

View File

@ -422,9 +422,12 @@ class WanCameraImageToVideo(io.ComfyNode):
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
mask[:, :, :start_image.shape[0] + 3] = 0.0
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
@ -696,7 +699,7 @@ class WanTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanPhantomSubjectToVideo",
node_id="WanTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.50"
__version__ = "0.3.51"

View File

@ -46,6 +46,8 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input")

View File

@ -2321,6 +2321,8 @@ async def init_builtin_extra_nodes():
"nodes_edit_model.py",
"nodes_tcfg.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_model_patch.py"
]
import_failed = []
@ -2350,6 +2352,7 @@ async def init_builtin_api_nodes():
"nodes_moonvalley.py",
"nodes_rodin.py",
"nodes_gemini.py",
"nodes_vidu.py",
]
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.50"
version = "0.3.51"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.24.4
comfyui-workflow-templates==0.1.59
comfyui-frontend-package==1.25.9
comfyui-workflow-templates==0.1.62
comfyui-embedded-docs==0.2.6
torch
torchsde