Merge branch 'comfyanonymous:master' into master

This commit is contained in:
RandomGitUser321 2025-08-21 11:15:06 -04:00 committed by GitHub
commit 0a320efbae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 1825 additions and 77 deletions

View File

@ -71,6 +71,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
- [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11) - [HiDream E1.1](https://comfyanonymous.github.io/ComfyUI_examples/hidream/#hidream-e11)
- [Qwen Image Edit](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/#edit-model)
- Video Models - Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/) - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/) - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)

View File

@ -363,10 +363,17 @@ class UserManager():
if not overwrite and os.path.exists(path): if not overwrite and os.path.exists(path):
return web.Response(status=409, text="File already exists") return web.Response(status=409, text="File already exists")
body = await request.read() try:
body = await request.read()
with open(path, "wb") as f: with open(path, "wb") as f:
f.write(body) f.write(body)
except OSError as e:
logging.warning(f"Error saving file '{path}': {e}")
return web.Response(
status=400,
reason="Invalid filename. Please avoid special characters like :\\/*?\"<>|"
)
user_path = self.get_request_user_filepath(request, None) user_path = self.get_request_user_filepath(request, None)
if full_info: if full_info:

View File

@ -97,7 +97,7 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device) self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32): def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32, embeds_info=[]):
if embeds is not None: if embeds is not None:
x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device) x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
else: else:

View File

@ -164,8 +164,11 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]): def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
indexes = torch.where(model_options["transformer_options"]["sample_sigmas"] == timestep[0]) mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
self._step = int(indexes[0]) matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
self._step = int(matches[0].item())
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model full_length = x_in.size(self.dim) # TODO: choose dim based on model

View File

@ -224,19 +224,27 @@ class Flux(nn.Module):
if ref_latents is not None: if ref_latents is not None:
h = 0 h = 0
w = 0 w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
for ref in ref_latents: for ref in ref_latents:
h_offset = 0 if index_ref_method:
w_offset = 0 index += 1
if ref.shape[-2] + h > ref.shape[-1] + w: h_offset = 0
w_offset = w w_offset = 0
else: else:
h_offset = h index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1) img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))

View File

@ -333,21 +333,25 @@ class QwenImageTransformer2DModel(nn.Module):
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device) self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
self.gradient_checkpointing = False self.gradient_checkpointing = False
def pos_embeds(self, x, context): def process_img(self, x, index=0, h_offset=0, w_offset=0):
bs, c, t, h, w = x.shape bs, c, t, h, w = x.shape
patch_size = self.patch_size patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
h_len = ((h + (patch_size // 2)) // patch_size) h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) h_offset = ((h_offset + (patch_size // 2)) // patch_size)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_start = round(max(h_len, w_len)) img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
txt_ids = torch.linspace(txt_start, txt_start + context.shape[1], steps=context.shape[1], device=x.device, dtype=x.dtype).reshape(1, -1, 1).repeat(bs, 1, 3) img_ids[:, :, 0] = img_ids[:, :, 1] + index
ids = torch.cat((txt_ids, img_ids), dim=1) img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
return self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
def forward( def forward(
self, self,
@ -356,6 +360,7 @@ class QwenImageTransformer2DModel(nn.Module):
context, context,
attention_mask=None, attention_mask=None,
guidance: torch.Tensor = None, guidance: torch.Tensor = None,
ref_latents=None,
transformer_options={}, transformer_options={},
**kwargs **kwargs
): ):
@ -363,13 +368,39 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = context encoder_hidden_states = context
encoder_hidden_states_mask = attention_mask encoder_hidden_states_mask = attention_mask
image_rotary_emb = self.pos_embeds(x, context) hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) if ref_latents is not None:
orig_shape = hidden_states.shape h = 0
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) w = 0
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) index = 0
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
w_offset = 0
if ref.shape[-2] + h > ref.shape[-1] + w:
w_offset = w
else:
h_offset = h
h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@ -385,6 +416,7 @@ class QwenImageTransformer2DModel(nn.Module):
) )
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
@ -405,9 +437,15 @@ class QwenImageTransformer2DModel(nn.Module):
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
) )
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -768,7 +768,12 @@ class CameraWanModel(WanModel):
operations=None, operations=None,
): ):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations) if model_type == 'camera':
model_type = 'i2v'
else:
model_type = 't2v'
super().__init__(model_type=model_type, patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype} operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings) self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)

View File

@ -890,6 +890,10 @@ class Flux(BaseModel):
for lat in ref_latents: for lat in ref_latents:
latents.append(self.process_latent_in(lat)) latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents) out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out return out
def extra_conds_shapes(self, **kwargs): def extra_conds_shapes(self, **kwargs):
@ -1321,10 +1325,28 @@ class Omnigen2(BaseModel):
class QwenImage(BaseModel): class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None): def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
ref_latents_method = kwargs.get("reference_latents_method", None)
if ref_latents_method is not None:
out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out return out

View File

@ -364,7 +364,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1] dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.') dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys: elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera" if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
dit_config["model_type"] = "camera_2.2"
else: else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v" dit_config["model_type"] = "i2v"

View File

@ -593,7 +593,13 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else: else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models) models_temp = set()
for m in models:
models_temp.add(m)
for mm in m.model_patches_models():
models_temp.add(mm)
models = models_temp
models_to_load = [] models_to_load = []

View File

@ -430,6 +430,9 @@ class ModelPatcher:
def set_model_forward_timestep_embed_patch(self, patch): def set_model_forward_timestep_embed_patch(self, patch):
self.set_model_patch(patch, "forward_timestep_embed_patch") self.set_model_patch(patch, "forward_timestep_embed_patch")
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
def add_object_patch(self, name, obj): def add_object_patch(self, name, obj):
self.object_patches[name] = obj self.object_patches[name] = obj
@ -486,6 +489,30 @@ class ModelPatcher:
if hasattr(wrap_func, "to"): if hasattr(wrap_func, "to"):
self.model_options["model_function_wrapper"] = wrap_func.to(device) self.model_options["model_function_wrapper"] = wrap_func.to(device)
def model_patches_models(self):
to = self.model_options["transformer_options"]
models = []
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], "models"):
models += patch_list[i].models()
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], "models"):
models += patch_list[k].models()
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, "models"):
models += wrap_func.models()
return models
def model_dtype(self): def model_dtype(self):
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):
return self.model.get_dtype() return self.model.get_dtype()

View File

@ -204,17 +204,19 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32) tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
index = 0 index = 0
pad_extra = 0 pad_extra = 0
embeds_info = []
for o in other_embeds: for o in other_embeds:
emb = o[1] emb = o[1]
if torch.is_tensor(emb): if torch.is_tensor(emb):
emb = {"type": "embedding", "data": emb} emb = {"type": "embedding", "data": emb}
extra = None
emb_type = emb.get("type", None) emb_type = emb.get("type", None)
if emb_type == "embedding": if emb_type == "embedding":
emb = emb.get("data", None) emb = emb.get("data", None)
else: else:
if hasattr(self.transformer, "preprocess_embed"): if hasattr(self.transformer, "preprocess_embed"):
emb = self.transformer.preprocess_embed(emb, device=device) emb, extra = self.transformer.preprocess_embed(emb, device=device)
else: else:
emb = None emb = None
@ -229,6 +231,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1) tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:] attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
index += emb_shape - 1 index += emb_shape - 1
embeds_info.append({"type": emb_type, "index": ind, "size": emb_shape, "extra": extra})
else: else:
index += -1 index += -1
pad_extra += emb_shape pad_extra += emb_shape
@ -243,11 +246,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
attention_masks.append(attention_mask) attention_masks.append(attention_mask)
num_tokens.append(sum(attention_mask)) num_tokens.append(sum(attention_mask))
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens): def forward(self, tokens):
device = self.transformer.get_input_embeddings().weight.device device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens = self.process_tokens(tokens, device) embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None attention_mask_model = None
if self.enable_attention_masks: if self.enable_attention_masks:
@ -258,7 +261,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
else: else:
intermediate_output = self.layer_idx intermediate_output = self.layer_idx
outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
if self.layer == "last": if self.layer == "last":
z = outputs[0].float() z = outputs[0].float()
@ -531,7 +534,10 @@ class SDTokenizer:
min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text) text = escape_important(text)
parsed_weights = token_weights(text, 1.0) if kwargs.get("disable_weights", False):
parsed_weights = [(text, 1.0)]
else:
parsed_weights = token_weights(text, 1.0)
# tokenize words # tokenize words
tokens = [] tokens = []

View File

@ -1046,6 +1046,18 @@ class WAN21_Camera(WAN21_T2V):
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device) out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out return out
class WAN22_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera_2.2",
"in_dim": 36,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V): class WAN21_Vace(WAN21_T2V):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1260,6 +1272,6 @@ class QwenImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -116,7 +116,7 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations) self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations) self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype) x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None mask = None
if attention_mask is not None: if attention_mask is not None:

View File

@ -2,12 +2,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any from typing import Optional, Any
import math
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management import comfy.model_management
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.model_management import comfy.model_management
from . import qwen_vl
@dataclass @dataclass
class Llama2Config: class Llama2Config:
@ -25,6 +27,7 @@ class Llama2Config:
rms_norm_add = False rms_norm_add = False
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = False qkv_bias = False
rope_dims = None
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
@ -42,6 +45,7 @@ class Qwen25_3BConfig:
rms_norm_add = False rms_norm_add = False
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = True qkv_bias = True
rope_dims = None
@dataclass @dataclass
class Qwen25_7BVLI_Config: class Qwen25_7BVLI_Config:
@ -59,6 +63,7 @@ class Qwen25_7BVLI_Config:
rms_norm_add = False rms_norm_add = False
mlp_activation = "silu" mlp_activation = "silu"
qkv_bias = True qkv_bias = True
rope_dims = [16, 24, 24]
@dataclass @dataclass
class Gemma2_2B_Config: class Gemma2_2B_Config:
@ -76,6 +81,7 @@ class Gemma2_2B_Config:
rms_norm_add = True rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh" mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False qkv_bias = False
rope_dims = None
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@ -100,24 +106,30 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
def precompute_freqs_cis(head_dim, seq_len, theta, device=None): def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=None):
theta_numerator = torch.arange(0, head_dim, 2, device=device).float() theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
inv_freq = 1.0 / (theta ** (theta_numerator / head_dim)) inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))
position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float() position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() cos = emb.cos()
sin = emb.sin() sin = emb.sin()
if rope_dims is not None and position_ids.shape[0] > 1:
mrope_section = rope_dims * 2
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(0)
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
return (cos, sin) return (cos, sin)
def apply_rope(xq, xk, freqs_cis): def apply_rope(xq, xk, freqs_cis):
cos = freqs_cis[0].unsqueeze(1) cos = freqs_cis[0]
sin = freqs_cis[1].unsqueeze(1) sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin) q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed return q_embed, k_embed
@ -277,7 +289,7 @@ class Llama2_(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
if embeds is not None: if embeds is not None:
x = embeds x = embeds
else: else:
@ -286,9 +298,13 @@ class Llama2_(nn.Module):
if self.normalize_in: if self.normalize_in:
x *= self.config.hidden_size ** 0.5 x *= self.config.hidden_size ** 0.5
if position_ids is None:
position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0)
freqs_cis = precompute_freqs_cis(self.config.head_dim, freqs_cis = precompute_freqs_cis(self.config.head_dim,
x.shape[1], position_ids,
self.config.rope_theta, self.config.rope_theta,
self.config.rope_dims,
device=x.device) device=x.device)
mask = None mask = None
@ -372,8 +388,38 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.num_layers = config.num_hidden_layers self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
return self.visual(image.to(device, dtype=torch.float32), grid), grid
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
grid = None
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
start = e.get("index")
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next, start_next + (embeds.shape[1] - end), device=embeds.device)
position_ids[0, start:end] = start
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start, start + max_d, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
if grid is None:
position_ids = None
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids)
class Gemma2_2B(BaseLlama, torch.nn.Module): class Gemma2_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()

View File

@ -15,13 +15,27 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_7b", tokenizer=Qwen25_7BVLITokenizer)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
if llama_template is None: if llama_template is None:
llama_text = self.llama_template.format(text) if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else: else:
llama_text = llama_template.format(text) llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
embed_count = 0
qwen_tokens = tokens[key_name]
for r in qwen_tokens:
for i in range(len(r)):
if r[i][0] == 151655:
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
class Qwen25_7BVLIModel(sd1_clip.SDClipModel): class Qwen25_7BVLIModel(sd1_clip.SDClipModel):

View File

@ -0,0 +1,428 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
def process_qwen2vl_images(
images: torch.Tensor,
min_pixels: int = 3136,
max_pixels: int = 12845056,
patch_size: int = 14,
temporal_patch_size: int = 2,
merge_size: int = 2,
image_mean: list = None,
image_std: list = None,
):
if image_mean is None:
image_mean = [0.48145466, 0.4578275, 0.40821073]
if image_std is None:
image_std = [0.26862954, 0.26130258, 0.27577711]
batch_size, height, width, channels = images.shape
device = images.device
# dtype = images.dtype
images = images.permute(0, 3, 1, 2)
grid_thw_list = []
img = images[0]
factor = patch_size * merge_size
h_bar = round(height / factor) * factor
w_bar = round(width / factor) * factor
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, math.floor(height / beta / factor) * factor)
w_bar = max(factor, math.floor(width / beta / factor) * factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = math.ceil(height * beta / factor) * factor
w_bar = math.ceil(width * beta / factor) * factor
img_resized = F.interpolate(
img.unsqueeze(0),
size=(h_bar, w_bar),
mode='bilinear',
align_corners=False
).squeeze(0)
normalized = img_resized.clone()
for c in range(3):
normalized[c] = (img_resized[c] - image_mean[c]) / image_std[c]
grid_h = h_bar // patch_size
grid_w = w_bar // patch_size
grid_thw = torch.tensor([1, grid_h, grid_w], device=device, dtype=torch.long)
pixel_values = normalized
grid_thw_list.append(grid_thw)
image_grid_thw = torch.stack(grid_thw_list)
grid_t = 1
channel = pixel_values.shape[0]
pixel_values = pixel_values.unsqueeze(0).repeat(2, 1, 1, 1)
patches = pixel_values.reshape(
grid_t,
temporal_patch_size,
channel,
grid_h // merge_size,
merge_size,
patch_size,
grid_w // merge_size,
merge_size,
patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
flatten_patches = patches.reshape(
grid_t * grid_h * grid_w,
channel * temporal_patch_size * patch_size * patch_size
)
return flatten_patches, image_grid_thw
class VisionPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 14,
temporal_patch_size: int = 2,
in_channels: int = 3,
embed_dim: int = 3584,
device=None,
dtype=None,
ops=None,
):
super().__init__()
self.patch_size = patch_size
self.temporal_patch_size = temporal_patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = ops.Conv3d(
in_channels,
embed_dim,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
device=device,
dtype=dtype
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
)
hidden_states = self.proj(hidden_states)
return hidden_states.view(-1, self.embed_dim)
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(q, k, cos, sin):
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0):
super().__init__()
self.dim = dim
self.theta = theta
def forward(self, seqlen: int, device) -> torch.Tensor:
inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim))
seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.outer(seq, inv_freq)
return freqs
class PatchMerger(nn.Module):
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2, device=None, dtype=None, ops=None):
super().__init__()
self.hidden_size = context_dim * (spatial_merge_size ** 2)
self.ln_q = ops.RMSNorm(context_dim, eps=1e-6, device=device, dtype=dtype)
self.mlp = nn.Sequential(
ops.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype),
nn.GELU(),
ops.Linear(self.hidden_size, dim, device=device, dtype=dtype),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.ln_q(x).reshape(-1, self.hidden_size)
x = self.mlp(x)
return x
class VisionAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, device=None, dtype=None, ops=None):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scaling = self.head_dim ** -0.5
self.qkv = ops.Linear(hidden_size, hidden_size * 3, bias=True, device=device, dtype=dtype)
self.proj = ops.Linear(hidden_size, hidden_size, bias=True, device=device, dtype=dtype)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cu_seqlens=None,
optimized_attention=None,
) -> torch.Tensor:
if hidden_states.dim() == 2:
seq_length, _ = hidden_states.shape
batch_size = 1
hidden_states = hidden_states.unsqueeze(0)
else:
batch_size, seq_length, _ = hidden_states.shape
qkv = self.qkv(hidden_states)
qkv = qkv.reshape(batch_size, seq_length, 3, self.num_heads, self.head_dim)
query_states, key_states, value_states = qkv.reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
query_states = query_states.transpose(0, 1).unsqueeze(0)
key_states = key_states.transpose(0, 1).unsqueeze(0)
value_states = value_states.transpose(0, 1).unsqueeze(0)
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
splits = [
torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
]
attn_outputs = [
optimized_attention(q, k, v, self.num_heads, skip_reshape=True)
for q, k, v in zip(*splits)
]
attn_output = torch.cat(attn_outputs, dim=1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.proj(attn_output)
return attn_output
class VisionMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, device=None, dtype=None, ops=None):
super().__init__()
self.gate_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.up_proj = ops.Linear(hidden_size, intermediate_size, bias=True, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, hidden_size, bias=True, device=device, dtype=dtype)
self.act_fn = nn.SiLU()
def forward(self, hidden_state):
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
class VisionBlock(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, num_heads: int, device=None, dtype=None, ops=None):
super().__init__()
self.norm1 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.norm2 = ops.RMSNorm(hidden_size, eps=1e-6, device=device, dtype=dtype)
self.attn = VisionAttention(hidden_size, num_heads, device=device, dtype=dtype, ops=ops)
self.mlp = VisionMLP(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cu_seqlens=None,
optimized_attention=None,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attn(hidden_states, position_embeddings, cu_seqlens, optimized_attention)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class Qwen2VLVisionTransformer(nn.Module):
def __init__(
self,
hidden_size: int = 3584,
output_hidden_size: int = 3584,
intermediate_size: int = 3420,
num_heads: int = 16,
num_layers: int = 32,
patch_size: int = 14,
temporal_patch_size: int = 2,
spatial_merge_size: int = 2,
window_size: int = 112,
device=None,
dtype=None,
ops=None
):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.spatial_merge_size = spatial_merge_size
self.window_size = window_size
self.fullatt_block_indexes = [7, 15, 23, 31]
self.patch_embed = VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=3,
embed_dim=hidden_size,
device=device,
dtype=dtype,
ops=ops,
)
head_dim = hidden_size // num_heads
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
self.blocks = nn.ModuleList([
VisionBlock(hidden_size, intermediate_size, num_heads, device, dtype, ops)
for _ in range(num_layers)
])
self.merger = PatchMerger(
dim=output_hidden_size,
context_dim=hidden_size,
spatial_merge_size=spatial_merge_size,
device=device,
dtype=dtype,
ops=ops,
)
def get_window_index(self, grid_thw):
window_index = []
cu_window_seqlens = [0]
window_index_id = 0
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
for grid_t, grid_h, grid_w in grid_thw:
llm_grid_h = grid_h // self.spatial_merge_size
llm_grid_w = grid_w // self.spatial_merge_size
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_size * self.spatial_merge_size + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
def get_position_embeddings(self, grid_thw, device):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h, device=device).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3).flatten()
wpos_ids = torch.arange(w, device=device).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3).flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device)
return rotary_pos_emb_full[pos_ids].flatten(1)
def forward(
self,
pixel_values: torch.Tensor,
image_grid_thw: Optional[torch.Tensor] = None,
) -> torch.Tensor:
optimized_attention = optimized_attention_for_device(pixel_values.device, mask=False, small_input=True)
hidden_states = self.patch_embed(pixel_values)
window_index, cu_window_seqlens = self.get_window_index(image_grid_thw)
cu_window_seqlens = torch.tensor(cu_window_seqlens, device=hidden_states.device)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
position_embeddings = self.get_position_embeddings(image_grid_thw, hidden_states.device)
seq_len, _ = hidden_states.size()
spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
hidden_states = hidden_states.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
position_embeddings = position_embeddings.reshape(seq_len // spatial_merge_unit, spatial_merge_unit, -1)
position_embeddings = position_embeddings[window_index, :, :]
position_embeddings = position_embeddings.reshape(seq_len, -1)
position_embeddings = torch.cat((position_embeddings, position_embeddings), dim=-1)
position_embeddings = (position_embeddings.cos(), position_embeddings.sin())
cu_seqlens = torch.repeat_interleave(image_grid_thw[:, 1] * image_grid_thw[:, 2], image_grid_thw[:, 0]).cumsum(
dim=0,
dtype=torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for i, block in enumerate(self.blocks):
if i in self.fullatt_block_indexes:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states)
return hidden_states

View File

@ -199,7 +199,7 @@ class T5Stack(torch.nn.Module):
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate) # self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]):
mask = None mask = None
if attention_mask is not None: if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])

View File

@ -726,6 +726,10 @@ class SEGS(ComfyTypeIO):
class AnyType(ComfyTypeIO): class AnyType(ComfyTypeIO):
Type = Any Type = Any
@comfytype(io_type="MODEL_PATCH")
class MODEL_PATCH(ComfyTypeIO):
Type = Any
@comfytype(io_type="COMFY_MULTITYPED_V3") @comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType: class MultiType:
Type = Any Type = Any

View File

@ -1315,6 +1315,7 @@ class KlingTaskStatus(str, Enum):
class KlingTextToVideoModelName(str, Enum): class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1' kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6' kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
class KlingVideoGenAspectRatio(str, Enum): class KlingVideoGenAspectRatio(str, Enum):
@ -1347,6 +1348,8 @@ class KlingVideoGenModelName(str, Enum):
kling_v1_5 = 'kling-v1-5' kling_v1_5 = 'kling-v1-5'
kling_v1_6 = 'kling-v1-6' kling_v1_6 = 'kling-v1-6'
kling_v2_master = 'kling-v2-master' kling_v2_master = 'kling-v2-master'
kling_v2_1 = 'kling-v2-1'
kling_v2_1_master = 'kling-v2-1-master'
class KlingVideoResult(BaseModel): class KlingVideoResult(BaseModel):
@ -1620,13 +1623,14 @@ class MinimaxTaskResultResponse(BaseModel):
task_id: str = Field(..., description='The task ID being queried.') task_id: str = Field(..., description='The task ID being queried.')
class Model(str, Enum): class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director' T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director' I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01' S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01' I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live' I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01' T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class SubjectReferenceItem(BaseModel): class SubjectReferenceItem(BaseModel):
@ -1648,7 +1652,7 @@ class MinimaxVideoGenerationRequest(BaseModel):
None, None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
) )
model: Model = Field( model: MiniMaxModel = Field(
..., ...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
) )
@ -1665,6 +1669,14 @@ class MinimaxVideoGenerationRequest(BaseModel):
None, None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
) )
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel): class MinimaxVideoGenerationResponse(BaseModel):

View File

@ -46,6 +46,8 @@ class GeminiModel(str, Enum):
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06" gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
gemini_2_5_pro = "gemini-2.5-pro"
gemini_2_5_flash = "gemini-2.5-flash"
def get_gemini_endpoint( def get_gemini_endpoint(
@ -97,7 +99,7 @@ class GeminiNode(ComfyNodeABC):
{ {
"tooltip": "The Gemini model to use for generating responses.", "tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel], "options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value, "default": GeminiModel.gemini_2_5_pro.value,
}, },
), ),
"seed": ( "seed": (

View File

@ -421,6 +421,8 @@ class KlingTextToVideoNode(KlingNodeBase):
"pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"), "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
"standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"), "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
"standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"), "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
"pro mode / 5s duration / kling-v2-1-master": ("pro", "5", "kling-v2-1-master"),
"pro mode / 10s duration / kling-v2-1-master": ("pro", "10", "kling-v2-1-master"),
} }
@classmethod @classmethod

View File

@ -1,3 +1,4 @@
from inspect import cleandoc
from typing import Union from typing import Union
import logging import logging
import torch import torch
@ -10,7 +11,7 @@ from comfy_api_nodes.apis import (
MinimaxFileRetrieveResponse, MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse, MinimaxTaskResultResponse,
SubjectReferenceItem, SubjectReferenceItem,
Model MiniMaxModel
) )
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.apis.client import (
ApiEndpoint, ApiEndpoint,
@ -84,7 +85,6 @@ class MinimaxTextToVideoNode:
FUNCTION = "generate_video" FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax" CATEGORY = "api node/video/MiniMax"
API_NODE = True API_NODE = True
OUTPUT_NODE = True
async def generate_video( async def generate_video(
self, self,
@ -121,7 +121,7 @@ class MinimaxTextToVideoNode:
response_model=MinimaxVideoGenerationResponse, response_model=MinimaxVideoGenerationResponse,
), ),
request=MinimaxVideoGenerationRequest( request=MinimaxVideoGenerationRequest(
model=Model(model), model=MiniMaxModel(model),
prompt=prompt_text, prompt=prompt_text,
callback_url=None, callback_url=None,
first_frame_image=image_url, first_frame_image=image_url,
@ -251,7 +251,6 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
FUNCTION = "generate_video" FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax" CATEGORY = "api node/video/MiniMax"
API_NODE = True API_NODE = True
OUTPUT_NODE = True
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode): class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
@ -313,7 +312,181 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
FUNCTION = "generate_video" FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax" CATEGORY = "api node/video/MiniMax"
API_NODE = True API_NODE = True
OUTPUT_NODE = True
class MinimaxHailuoVideoNode:
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation.",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"first_frame_image": (
IO.IMAGE,
{
"tooltip": "Optional image to use as the first frame to generate a video."
},
),
"prompt_optimizer": (
IO.BOOLEAN,
{
"tooltip": "Optimize prompt to improve generation quality when needed.",
"default": True,
},
),
"duration": (
IO.COMBO,
{
"tooltip": "The length of the output video in seconds.",
"default": 6,
"options": [6, 10],
},
),
"resolution": (
IO.COMBO,
{
"tooltip": "The dimensions of the video display. "
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
"default": "768P",
"options": ["768P", "1080P"],
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
async def generate_video(
self,
prompt_text,
seed=0,
first_frame_image: torch.Tensor=None, # used for ImageToVideo
prompt_optimizer=True,
duration=6,
resolution="768P",
model="MiniMax-Hailuo-02",
unique_id: Union[str, None]=None,
**kwargs,
):
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
if model == "MiniMax-Hailuo-02" and resolution.upper() == "1080P" and duration != 6:
raise Exception(
"When model is MiniMax-Hailuo-02 and resolution is 1080P, duration is limited to 6 seconds."
)
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/video_generation",
method=HttpMethod.POST,
request_model=MinimaxVideoGenerationRequest,
response_model=MinimaxVideoGenerationResponse,
),
request=MinimaxVideoGenerationRequest(
model=MiniMaxModel(model),
prompt=prompt_text,
callback_url=None,
first_frame_image=image_url,
prompt_optimizer=prompt_optimizer,
duration=duration,
resolution=resolution,
),
auth_kwargs=kwargs,
)
response = await video_generate_operation.execute()
task_id = response.task_id
if not task_id:
raise Exception(f"MiniMax generation failed: {response.base_resp}")
average_duration = 120 if resolution == "768P" else 240
video_generate_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path="/proxy/minimax/query/video_generation",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxTaskResultResponse,
query_params={"task_id": task_id},
),
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = await video_generate_operation.execute()
file_id = task_result.file_id
if file_id is None:
raise Exception("Request was not successful. Missing file ID.")
file_retrieve_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/minimax/files/retrieve",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=MinimaxFileRetrieveResponse,
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
)
file_result = await file_retrieve_operation.execute()
file_url = file_result.file.download_url
if file_url is None:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
# A dictionary that contains all nodes you want to export with their names # A dictionary that contains all nodes you want to export with their names
@ -322,6 +495,7 @@ NODE_CLASS_MAPPINGS = {
"MinimaxTextToVideoNode": MinimaxTextToVideoNode, "MinimaxTextToVideoNode": MinimaxTextToVideoNode,
"MinimaxImageToVideoNode": MinimaxImageToVideoNode, "MinimaxImageToVideoNode": MinimaxImageToVideoNode,
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode, # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
} }
# A dictionary that contains the friendly/humanly readable titles for the nodes # A dictionary that contains the friendly/humanly readable titles for the nodes
@ -329,4 +503,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"MinimaxTextToVideoNode": "MiniMax Text to Video", "MinimaxTextToVideoNode": "MiniMax Text to Video",
"MinimaxImageToVideoNode": "MiniMax Image to Video", "MinimaxImageToVideoNode": "MiniMax Image to Video",
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video", "MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
} }

View File

@ -80,6 +80,9 @@ class SupportedOpenAIModel(str, Enum):
gpt_4_1 = "gpt-4.1" gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano" gpt_4_1_nano = "gpt-4.1-nano"
gpt_5 = "gpt-5"
gpt_5_mini = "gpt-5-mini"
gpt_5_nano = "gpt-5-nano"
class OpenAIDalle2(ComfyNodeABC): class OpenAIDalle2(ComfyNodeABC):
@ -464,8 +467,6 @@ class OpenAIGPTImage1(ComfyNodeABC):
path = "/proxy/openai/images/generations" path = "/proxy/openai/images/generations"
content_type = "application/json" content_type = "application/json"
request_class = OpenAIImageGenerationRequest request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
files = [] files = []
if image is not None: if image is not None:
@ -484,14 +485,11 @@ class OpenAIGPTImage1(ComfyNodeABC):
img_byte_arr = io.BytesIO() img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG") img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0) img_byte_arr.seek(0)
img_binary = img_byte_arr
img_binary.name = f"image_{i}.png"
img_binaries.append(img_binary)
if batch_size == 1: if batch_size == 1:
files.append(("image", img_binary)) files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png")))
else: else:
files.append(("image[]", img_binary)) files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))
if mask is not None: if mask is not None:
if image is None: if image is None:
@ -511,9 +509,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask_img_byte_arr = io.BytesIO() mask_img_byte_arr = io.BytesIO()
mask_img.save(mask_img_byte_arr, format="PNG") mask_img.save(mask_img_byte_arr, format="PNG")
mask_img_byte_arr.seek(0) mask_img_byte_arr.seek(0)
mask_binary = mask_img_byte_arr files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))
mask_binary.name = "mask.png"
files.append(("mask", mask_binary))
# Build the operation # Build the operation
operation = SynchronousOperation( operation = SynchronousOperation(

View File

@ -0,0 +1,622 @@
import logging
from enum import Enum
from typing import Any, Callable, Optional, Literal, TypeVar
from typing_extensions import override
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import (
validate_aspect_ratio_closeness,
validate_image_dimensions,
validate_image_aspect_ratio_range,
get_number_of_images,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
R = TypeVar("R")
class VideoModelName(str, Enum):
vidu_q1 = 'viduq1'
class AspectRatio(str, Enum):
r_16_9 = "16:9"
r_9_16 = "9:16"
r_1_1 = "1:1"
class Resolution(str, Enum):
r_1080p = "1080p"
class MovementAmplitude(str, Enum):
auto = "auto"
small = "small"
medium = "medium"
large = "large"
class TaskCreationRequest(BaseModel):
model: VideoModelName = VideoModelName.vidu_q1
prompt: Optional[str] = Field(None, max_length=1500)
duration: Optional[Literal[5]] = 5
seed: Optional[int] = Field(0, ge=0, le=2147483647)
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
resolution: Optional[Resolution] = Resolution.r_1080p
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
class TaskStatus(str, Enum):
created = "created"
queueing = "queueing"
processing = "processing"
success = "success"
failed = "failed"
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: TaskStatus = Field(...)
created_at: str = Field(...)
code: Optional[int] = Field(None, description="Error code")
class TaskResult(BaseModel):
id: str = Field(..., description="Creation id")
url: str = Field(..., description="The URL of the generated results, valid for one hour")
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
class TaskStatusResponse(BaseModel):
state: TaskStatus = Field(...)
err_code: Optional[str] = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[TaskStatus.success.value],
failed_statuses=[TaskStatus.failed.value],
status_extractor=lambda response: response.state.value,
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
poll_interval=16.0,
max_poll_attempts=256,
).execute()
def get_video_url_from_response(response) -> Optional[str]:
if response.creations:
return response.creations[0].url
return None
def get_video_from_response(response) -> TaskResult:
if not response.creations:
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
logging.info(error_msg)
raise RuntimeError(error_msg)
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
return response.creations[0]
async def execute_task(
vidu_endpoint: str,
auth_kwargs: Optional[dict[str, str]],
payload: TaskCreationRequest,
estimated_duration: int,
node_id: str,
) -> R:
response = await SynchronousOperation(
endpoint=ApiEndpoint(
path=vidu_endpoint,
method=HttpMethod.POST,
request_model=TaskCreationRequest,
response_model=TaskCreationResponse,
),
request=payload,
auth_kwargs=auth_kwargs,
).execute()
if response.state == TaskStatus.failed:
error_msg = f"Vidu request failed. Code: {response.code}"
logging.error(error_msg)
raise RuntimeError(error_msg)
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=VIDU_GET_GENERATION_STATUS % response.task_id,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=estimated_duration,
node_id=node_id,
)
class ViduTextToVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="api node/video/Vidu",
description="Generate video from text prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduImageToVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation",
category="api node/video/Vidu",
description="Generate video from image and optional prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"image",
tooltip="An image to be used as the start frame of the generated video",
),
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="A textual description for video generation",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
image,
max_images=1,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduReferenceVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu",
description="Generate video from multiple images and prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"images",
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"aspect_ratio",
options=[model.value for model in AspectRatio],
default=AspectRatio.r_16_9.value,
tooltip="The aspect ratio of the output video",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
images: torch.Tensor,
prompt: str,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
a = get_number_of_images(images)
if a > 7:
raise ValueError("Too many images, maximum allowed is 7.")
for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi(
images,
max_images=7,
mime_type="image/png",
auth_kwargs=auth,
)
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation",
category="api node/video/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
comfy_io.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
comfy_io.Image.Input(
"first_frame",
tooltip="Start frame",
),
comfy_io.Image.Input(
"end_frame",
tooltip="End frame",
),
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation",
optional=True,
),
comfy_io.Int.Input(
"duration",
default=5,
min=5,
max=5,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Duration of the output video in seconds",
optional=True,
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=comfy_io.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed for video generation (0 for random)",
optional=True,
),
comfy_io.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration",
optional=True,
),
comfy_io.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
first_frame: torch.Tensor,
end_frame: torch.Tensor,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> comfy_io.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = [
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
for frame in (first_frame, end_frame)
]
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
class ViduExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
ViduTextToVideoNode,
ViduImageToVideoNode,
ViduReferenceVideoNode,
ViduStartEndToVideoNode,
]
async def comfy_entrypoint() -> ViduExtension:
return ViduExtension()

View File

@ -53,6 +53,53 @@ def validate_image_aspect_ratio(
) )
def validate_image_aspect_ratio_range(
image: torch.Tensor,
min_ratio: tuple[float, float], # e.g. (1, 4)
max_ratio: tuple[float, float], # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
a1, b1 = min_ratio
a2, b2 = max_ratio
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
lo, hi = (a1 / b1), (a2 / b2)
if lo > hi:
lo, hi = hi, lo
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
w, h = get_image_dimensions(image)
if w <= 0 or h <= 0:
raise ValueError(f"Invalid image dimensions: {w}x{h}")
ar = w / h
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
if not ok:
op = "<" if strict else ""
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
return ar
def validate_aspect_ratio_closeness(
start_img,
end_img,
min_rel: float,
max_rel: float,
*,
strict: bool = False, # True => exclusive, False => inclusive
) -> None:
w1, h1 = get_image_dimensions(start_img)
w2, h2 = get_image_dimensions(end_img)
if min(w1, h1, w2, h2) <= 0:
raise ValueError("Invalid image dimensions")
ar1 = w1 / h1
ar2 = w2 / h2
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
closeness = max(ar1, ar2) / min(ar1, ar2)
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
if (closeness >= limit) if strict else (closeness > limit):
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}{max_rel}.")
def validate_video_dimensions( def validate_video_dimensions(
video: VideoInput, video: VideoInput,
min_width: Optional[int] = None, min_width: Optional[int] = None,
@ -98,3 +145,9 @@ def validate_video_duration(
raise ValueError( raise ValueError(
f"Video duration must be at most {max_duration}s, got {duration}s" f"Video duration must be at most {max_duration}s, got {duration}s"
) )
def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
return len(images)

View File

@ -346,6 +346,24 @@ class LoadAudio:
return "Invalid audio file: {}".format(audio) return "Invalid audio file: {}".format(audio)
return True return True
class RecordAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {"audio": ("AUDIO_RECORD", {})}}
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio, "EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio, "VAEEncodeAudio": VAEEncodeAudio,
@ -356,6 +374,7 @@ NODE_CLASS_MAPPINGS = {
"LoadAudio": LoadAudio, "LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio, "PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio, "ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -367,4 +386,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudio": "Save Audio (FLAC)", "SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)", "SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)", "SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
} }

View File

@ -100,9 +100,28 @@ class FluxKontextImageScale:
return (image, ) return (image, )
class FluxKontextMultiReferenceLatentMethod:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
EXPERIMENTAL = True
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux, "CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance, "FluxGuidance": FluxGuidance,
"FluxDisableGuidance": FluxDisableGuidance, "FluxDisableGuidance": FluxDisableGuidance,
"FluxKontextImageScale": FluxKontextImageScale, "FluxKontextImageScale": FluxKontextImageScale,
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
} }

View File

@ -166,7 +166,7 @@ class LTXVAddGuide:
negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) negative = self.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], 1, 1), (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength, 1.0 - strength,
dtype=noise_mask.dtype, dtype=noise_mask.dtype,
device=noise_mask.device, device=noise_mask.device,

View File

@ -0,0 +1,161 @@
import torch
import folder_paths
import comfy.utils
import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
class BlockWiseControlBlock(torch.nn.Module):
# [linear, gelu, linear]
def __init__(self, dim: int = 3072, device=None, dtype=None, operations=None):
super().__init__()
self.x_rms = operations.RMSNorm(dim, eps=1e-6)
self.y_rms = operations.RMSNorm(dim, eps=1e-6)
self.input_proj = operations.Linear(dim, dim)
self.act = torch.nn.GELU()
self.output_proj = operations.Linear(dim, dim)
def forward(self, x, y):
x, y = self.x_rms(x), self.y_rms(y)
x = self.input_proj(x + y)
x = self.act(x)
x = self.output_proj(x)
return x
class QwenImageBlockWiseControlNet(torch.nn.Module):
def __init__(
self,
num_layers: int = 60,
in_dim: int = 64,
additional_in_dim: int = 0,
dim: int = 3072,
device=None, dtype=None, operations=None
):
super().__init__()
self.additional_in_dim = additional_in_dim
self.img_in = operations.Linear(in_dim + additional_in_dim, dim, device=device, dtype=dtype)
self.controlnet_blocks = torch.nn.ModuleList(
[
BlockWiseControlBlock(dim, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
]
)
def process_input_latent_image(self, latent_image):
latent_image[:, :16] = comfy.latent_formats.Wan21().process_in(latent_image[:, :16])
patch_size = 2
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(latent_image, (1, patch_size, patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
return self.img_in(hidden_states)
def control_block(self, img, controlnet_conditioning, block_id):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "name": (folder_paths.get_filename_list("model_patches"), ),
}}
RETURN_TYPES = ("MODEL_PATCH",)
FUNCTION = "load_model_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/loaders"
def load_model_patch(self, name):
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
class DiffSynthCnetPatch:
def __init__(self, model_patch, vae, image, strength, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.strength = strength
self.mask = mask
self.encoded_image = model_patch.model.process_input_latent_image(self.encode_latent_cond(image))
def encode_latent_cond(self, image):
latent_image = self.vae.encode(image)
if self.model_patch.model.additional_in_dim > 0:
if self.mask is None:
mask_ = torch.ones_like(latent_image)[:, :self.model_patch.model.additional_in_dim // 4]
else:
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
return torch.cat([latent_image, mask_], dim=1)
else:
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
block_index = kwargs.get("block_index")
if self.encoded_image is None or self.encoded_image.shape[1:] != img.shape[1:]:
spacial_compression = self.vae.spacial_compression_encode()
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.model_patch.model.process_input_latent_image(self.encode_latent_cond(image_scaled.movedim(1, -1)))
comfy.model_management.load_models_gpu(loaded_models)
img = img + (self.model_patch.model.control_block(img, self.encoded_image.to(img.dtype), block_index) * self.strength)
kwargs['img'] = img
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"image": ("IMAGE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"mask": ("MASK",)}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "diffsynth_controlnet"
EXPERIMENTAL = True
CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
model_patched = model.clone()
image = image[:, :, :, :3]
if mask is not None:
if mask.ndim == 3:
mask = mask.unsqueeze(1)
if mask.ndim == 4:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
}

View File

@ -0,0 +1,48 @@
import node_helpers
import comfy.utils
import math
class TextEncodeQwenImageEdit:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
},
"optional": {"vae": ("VAE", ),
"image": ("IMAGE", ),}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, prompt, vae=None, image=None):
ref_latent = None
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
image = s.movedim(1, -1)
images = [image[:, :, :, :3]]
if vae is not None:
ref_latent = vae.encode(image[:, :, :, :3])
tokens = clip.tokenize(prompt, images=images)
conditioning = clip.encode_from_tokens_scheduled(tokens)
if ref_latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return (conditioning, )
NODE_CLASS_MAPPINGS = {
"TextEncodeQwenImageEdit": TextEncodeQwenImageEdit,
}

View File

@ -422,9 +422,12 @@ class WanCameraImageToVideo(io.ComfyNode):
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3]) concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]] concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
mask = torch.ones((1, 1, latent.shape[2] * 4, latent.shape[-2], latent.shape[-1]))
mask[:, :, :start_image.shape[0] + 3] = 0.0
mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent}) positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent}) negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent, "concat_mask": mask})
if camera_conditions is not None: if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions}) positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
@ -696,7 +699,7 @@ class WanTrackToVideo(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="WanPhantomSubjectToVideo", node_id="WanTrackToVideo",
category="conditioning/video_models", category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input("positive"), io.Conditioning.Input("positive"),

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.50" __version__ = "0.3.51"

View File

@ -46,6 +46,8 @@ folder_names_and_paths["photomaker"] = ([os.path.join(models_dir, "photomaker")]
folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patches")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input") input_directory = os.path.join(base_path, "input")

View File

@ -2321,6 +2321,8 @@ async def init_builtin_extra_nodes():
"nodes_edit_model.py", "nodes_edit_model.py",
"nodes_tcfg.py", "nodes_tcfg.py",
"nodes_context_windows.py", "nodes_context_windows.py",
"nodes_qwen.py",
"nodes_model_patch.py"
] ]
import_failed = [] import_failed = []
@ -2350,6 +2352,7 @@ async def init_builtin_api_nodes():
"nodes_moonvalley.py", "nodes_moonvalley.py",
"nodes_rodin.py", "nodes_rodin.py",
"nodes_gemini.py", "nodes_gemini.py",
"nodes_vidu.py",
] ]
if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.50" version = "0.3.51"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.24.4 comfyui-frontend-package==1.25.9
comfyui-workflow-templates==0.1.59 comfyui-workflow-templates==0.1.62
comfyui-embedded-docs==0.2.6 comfyui-embedded-docs==0.2.6
torch torch
torchsde torchsde