diff --git a/.ci/windows_intel_base_files/run_intel_gpu.bat b/.ci/windows_intel_base_files/run_intel_gpu.bat new file mode 100755 index 000000000..274d7c948 --- /dev/null +++ b/.ci/windows_intel_base_files/run_intel_gpu.bat @@ -0,0 +1,2 @@ +.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build +pause diff --git a/QUANTIZATION.md b/QUANTIZATION.md index 1693e13f3..300822029 100644 --- a/QUANTIZATION.md +++ b/QUANTIZATION.md @@ -139,9 +139,9 @@ Example: "_quantization_metadata": { "format_version": "1.0", "layers": { - "model.layers.0.mlp.up_proj": "float8_e4m3fn", - "model.layers.0.mlp.down_proj": "float8_e4m3fn", - "model.layers.1.mlp.up_proj": "float8_e4m3fn" + "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"}, + "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"}, + "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"} } } } @@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s 3. **Compute scales**: Derive `input_scale` from collected statistics 4. **Store in checkpoint**: Save `input_scale` parameters alongside weights -The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. \ No newline at end of file +The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. diff --git a/README.md b/README.md index 1eeb810de..f05311421 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat #### Alternative Downloads: -[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) +[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) + +[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). 
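As a quick illustration of the updated per-layer metadata shape in the QUANTIZATION.md hunk above, here is a minimal sketch (not part of this patch) of reading it back. It assumes the `_quantization_metadata` blob is stored as a JSON string in the safetensors header metadata, and `model.safetensors` is a placeholder path.

```python
import json
from safetensors import safe_open

# Assumption: the checkpoint stores "_quantization_metadata" as a JSON string
# in the safetensors header metadata; "model.safetensors" is a placeholder path.
with safe_open("model.safetensors", framework="pt") as f:
    header = f.metadata() or {}

quant_meta = json.loads(header.get("_quantization_metadata", "{}"))
# Per-layer entries are now dicts, e.g. {"format": "float8_e4m3fn"}, instead of bare strings.
layer_formats = {name: entry.get("format") for name, entry in quant_meta.get("layers", {}).items()}
print(layer_formats)
```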
diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index b224306da..1477afa01 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -67,7 +67,7 @@ class InternalRoutes: (entry for entry in os.scandir(directory) if is_visible_file(entry)), key=lambda entry: -entry.stat().st_mtime ) - return web.json_response([entry.name for entry in sorted_files], status=200) + return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200) def get_app(self): diff --git a/blueprints/Brightness and Contrast.json b/blueprints/Brightness and Contrast.json index 6a234139d..90bfe999d 100644 --- a/blueprints/Brightness and Contrast.json +++ b/blueprints/Brightness and Contrast.json @@ -182,7 +182,7 @@ ] }, "widgets_values": [ - 50 + 0 ] }, { diff --git a/blueprints/Glow.json b/blueprints/Glow.json index 42cf63e8a..8c690fc68 100644 --- a/blueprints/Glow.json +++ b/blueprints/Glow.json @@ -316,7 +316,7 @@ "step": 1 }, "widgets_values": [ - 30 + 0 ] }, { diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py new file mode 100644 index 000000000..eba661aec --- /dev/null +++ b/comfy/ldm/ernie/model.py @@ -0,0 +1,301 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: + assert dim % 2 == 0 + if not comfy.model_management.supports_fp64(pos.device): + device = torch.device("cpu") + else: + device = pos.device + + scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim + omega = 1.0 / (theta**scale) + out = torch.einsum("...n,d->...nd", pos.to(device), omega) + out = torch.stack([torch.cos(out), torch.sin(out)], dim=0) + return out.to(dtype=torch.float32, device=pos.device) + +def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + rot_dim = freqs_cis.shape[-1] + x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] + cos_ = freqs_cis[0] + sin_ = freqs_cis[1] + x1, x2 = x.chunk(2, dim=-1) + x_rotated = torch.cat((-x2, x1), dim=-1) + return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) + +class ErnieImageEmbedND3(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: tuple): + super().__init__() + self.dim = dim + self.theta = theta + self.axes_dim = list(axes_dim) + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) + emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2] + return torch.cat([emb, emb], dim=-1) # [2, B, S, 1, head_dim], tiled halves to match the rotate-half form in apply_rotary_emb + +class ErnieImagePatchEmbedDynamic(nn.Module): + def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None): + super().__init__() + self.patch_size = patch_size + self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + batch_size, dim, height, width = x.shape + return x.reshape(batch_size, dim, height * width).transpose(1, 2).contiguous() + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, flip_sin_to_cos: bool = False): + super().__init__() + self.num_channels = num_channels + self.flip_sin_to_cos = flip_sin_to_cos + + def 
forward(self, timesteps: torch.Tensor) -> torch.Tensor: + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) / half_dim + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + if self.flip_sin_to_cos: + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + else: + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + return emb + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, operations, device=None, dtype=None): + super().__init__() + Linear = operations.Linear + self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, device=device, dtype=dtype) + self.act = nn.SiLU() + self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, device=device, dtype=dtype) + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + return sample + +class ErnieImageAttention(nn.Module): + def __init__(self, query_dim: int, heads: int, dim_head: int, eps: float = 1e-6, operations=None, device=None, dtype=None): + super().__init__() + self.heads = heads + self.head_dim = dim_head + self.inner_dim = heads * dim_head + + Linear = operations.Linear + RMSNorm = operations.RMSNorm + + self.to_q = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype) + self.to_k = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype) + self.to_v = Linear(query_dim, self.inner_dim, bias=False, device=device, dtype=dtype) + + self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype) + self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=True, device=device, dtype=dtype) + + self.to_out = nn.ModuleList([Linear(self.inner_dim, query_dim, bias=False, device=device, dtype=dtype)]) + + def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_rotary_emb: torch.Tensor = None) -> torch.Tensor: + B, S, _ = x.shape + + q_flat = self.to_q(x) + k_flat = self.to_k(x) + v_flat = self.to_v(x) + + query = q_flat.view(B, S, self.heads, self.head_dim) + key = k_flat.view(B, S, self.heads, self.head_dim) + + query = self.norm_q(query) + key = self.norm_k(key) + + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + q_flat = query.reshape(B, S, -1) + k_flat = key.reshape(B, S, -1) + + hidden_states = optimized_attention(q_flat, k_flat, v_flat, self.heads, mask=attention_mask) + + return self.to_out[0](hidden_states) + +class ErnieImageFeedForward(nn.Module): + def __init__(self, hidden_size: int, ffn_hidden_size: int, operations, device=None, dtype=None): + super().__init__() + Linear = operations.Linear + self.gate_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype) + self.up_proj = Linear(hidden_size, ffn_hidden_size, bias=False, device=device, dtype=dtype) + self.linear_fc2 = Linear(ffn_hidden_size, hidden_size, bias=False, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.up_proj(x) * F.gelu(self.gate_proj(x))) + +class ErnieImageSharedAdaLNBlock(nn.Module): + def __init__(self, hidden_size: int, num_heads: int, ffn_hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): + super().__init__() + RMSNorm = operations.RMSNorm + + self.adaLN_sa_ln = RMSNorm(hidden_size, 
eps=eps, device=device, dtype=dtype) + self.self_attention = ErnieImageAttention( + query_dim=hidden_size, + dim_head=hidden_size // num_heads, + heads=num_heads, + eps=eps, + operations=operations, + device=device, + dtype=dtype + ) + self.adaLN_mlp_ln = RMSNorm(hidden_size, eps=eps, device=device, dtype=dtype) + self.mlp = ErnieImageFeedForward(hidden_size, ffn_hidden_size, operations=operations, device=device, dtype=dtype) + + def forward(self, x, rotary_pos_emb, temb, attention_mask=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = temb + + residual = x + x_norm = self.adaLN_sa_ln(x) + x_norm = x_norm * (1 + scale_msa) + shift_msa + + attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) + x = residual + gate_msa * attn_out + + residual = x + x_norm = self.adaLN_mlp_ln(x) + x_norm = x_norm * (1 + scale_mlp) + shift_mlp + + return residual + gate_mlp * self.mlp(x_norm) + +class ErnieImageAdaLNContinuous(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): + super().__init__() + LayerNorm = operations.LayerNorm + Linear = operations.Linear + self.norm = LayerNorm(hidden_size, elementwise_affine=False, eps=eps, device=device, dtype=dtype) + self.linear = Linear(hidden_size, hidden_size * 2, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + scale, shift = self.linear(conditioning).chunk(2, dim=-1) + x = self.norm(x) + x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)) + return x + +class ErnieImageModel(nn.Module): + def __init__( + self, + hidden_size: int = 4096, + num_attention_heads: int = 32, + num_layers: int = 36, + ffn_hidden_size: int = 12288, + in_channels: int = 128, + out_channels: int = 128, + patch_size: int = 1, + text_in_dim: int = 3072, + rope_theta: int = 256, + rope_axes_dim: tuple = (32, 48, 48), + eps: float = 1e-6, + qk_layernorm: bool = True, + device=None, + dtype=None, + operations=None, + **kwargs + ): + super().__init__() + self.dtype = dtype + self.hidden_size = hidden_size + self.num_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.patch_size = patch_size + self.out_channels = out_channels + + Linear = operations.Linear + + self.x_embedder = ErnieImagePatchEmbedDynamic(in_channels, hidden_size, patch_size, operations, device, dtype) + self.text_proj = Linear(text_in_dim, hidden_size, bias=False, device=device, dtype=dtype) if text_in_dim != hidden_size else None + + self.time_proj = Timesteps(hidden_size, flip_sin_to_cos=False) + self.time_embedding = TimestepEmbedding(hidden_size, hidden_size, operations, device, dtype) + + self.pos_embed = ErnieImageEmbedND3(dim=self.head_dim, theta=rope_theta, axes_dim=rope_axes_dim) + + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + Linear(hidden_size, 6 * hidden_size, device=device, dtype=dtype) + ) + + self.layers = nn.ModuleList([ + ErnieImageSharedAdaLNBlock(hidden_size, num_attention_heads, ffn_hidden_size, eps, operations, device, dtype) + for _ in range(num_layers) + ]) + + self.final_norm = ErnieImageAdaLNContinuous(hidden_size, eps, operations, device, dtype) + self.final_linear = Linear(hidden_size, patch_size * patch_size * out_channels, device=device, dtype=dtype) + + def forward(self, x, timesteps, context, **kwargs): + device, dtype = x.device, x.dtype + B, C, H, W = x.shape + p, Hp, Wp = self.patch_size, H // self.patch_size, W // self.patch_size + N_img = 
Hp * Wp + + img_bsh = self.x_embedder(x) + + text_bth = context + if self.text_proj is not None and text_bth.numel() > 0: + text_bth = self.text_proj(text_bth) + Tmax = text_bth.shape[1] + + hidden_states = torch.cat([img_bsh, text_bth], dim=1) + + text_ids = torch.zeros((B, Tmax, 3), device=device, dtype=torch.float32) + text_ids[:, :, 0] = torch.linspace(0, Tmax - 1, steps=Tmax, device=x.device, dtype=torch.float32) + index = float(Tmax) + + transformer_options = kwargs.get("transformer_options", {}) + rope_options = transformer_options.get("rope_options", None) + + h_len, w_len = float(Hp), float(Wp) + h_offset, w_offset = 0.0, 0.0 + + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + image_ids = torch.zeros((Hp, Wp, 3), device=device, dtype=torch.float32) + image_ids[:, :, 0] = image_ids[:, :, 1] + index + image_ids[:, :, 1] = image_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=Hp, device=device, dtype=torch.float32).unsqueeze(1) + image_ids[:, :, 2] = image_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=Wp, device=device, dtype=torch.float32).unsqueeze(0) + + image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1) + + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype) + del image_ids, text_ids + + sample = self.time_proj(timesteps).to(dtype) + c = self.time_embedding(sample) + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ + t.unsqueeze(1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1) + ] + + temb = [shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp] + for layer in self.layers: + hidden_states = layer(hidden_states, rotary_pos_emb, temb) + + hidden_states = self.final_norm(hidden_states, c).type_as(hidden_states) + + patches = self.final_linear(hidden_states)[:, :N_img, :] + output = ( + patches.view(B, Hp, Wp, p, p, self.out_channels) + .permute(0, 5, 1, 3, 2, 4) + .contiguous() + .view(B, self.out_channels, H, W) + ) + + return output diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 824daf5e6..6d0aed827 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -16,7 +16,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 - if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): + if not comfy.model_management.supports_fp64(pos.device): device = torch.device("cpu") else: device = pos.device diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 295310df6..4b92c44cf 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -34,6 +34,16 @@ class TimestepBlock(nn.Module): #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: + if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: + 
found_patched = False + for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: + if isinstance(layer, class_type): + x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) + found_patched = True + break + if found_patched: + continue + if isinstance(layer, VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(layer, TimestepBlock): @@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: - if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: - found_patched = False - for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: - if isinstance(layer, class_type): - x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) - found_patched = True - break - if found_patched: - continue x = layer(x) return x @@ -894,6 +895,12 @@ class UNetModel(nn.Module): h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + if "middle_block_after_patch" in transformer_patches: + patch = transformer_patches["middle_block_after_patch"] + for p in patch: + out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y, + "timesteps": timesteps, "transformer_options": transformer_options}) + h = out["h"] for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) @@ -905,8 +912,9 @@ class UNetModel(nn.Module): for p in patch: h, hsp = p(h, hsp, transformer_options) - h = th.cat([h, hsp], dim=1) - del hsp + if hsp is not None: + h = th.cat([h, hsp], dim=1) + del hsp if len(hs) > 0: output_shape = hs[-1].shape else: diff --git a/comfy/ldm/modules/sdpose.py b/comfy/ldm/modules/sdpose.py index d67b60b76..1a9585fc2 100644 --- a/comfy/ldm/modules/sdpose.py +++ b/comfy/ldm/modules/sdpose.py @@ -90,7 +90,7 @@ class HeatmapHead(torch.nn.Module): origin_max = np.max(hm[k]) dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) dr[border:-border, border:-border] = hm[k].copy() - dr = gaussian_filter(dr, sigma=2.0) + dr = gaussian_filter(dr, sigma=2.0, truncate=2.5) hm[k] = dr[border:-border, border:-border].copy() cur_max = np.max(hm[k]) if cur_max > 0: diff --git a/comfy/ldm/supir/__init__.py b/comfy/ldm/supir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/supir/supir_modules.py b/comfy/ldm/supir/supir_modules.py new file mode 100644 index 000000000..7389b01d2 --- /dev/null +++ b/comfy/ldm/supir/supir_modules.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding +from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer +from comfy.ldm.modules.attention import optimized_attention + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None): + super().__init__() + + ks = 3 + pw = ks // 2 + + self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device) + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + 
operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device), + nn.SiLU() + ) + self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + + self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device) + self.pre_concat = bool(concat_channels != 0) + + def forward(self, c, h, h_ori=None, control_scale=1): + if h_ori is not None and self.pre_concat: + h_raw = torch.cat([h_ori, h], dim=1) + else: + h_raw = h + + h = h + self.zero_conv(c) + if h_ori is not None and self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + h = self.param_free_norm(h) + h = torch.addcmul(h + beta, h, gamma) + if h_ori is not None and not self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + return torch.lerp(h_raw, h, control_scale) + + +class _CrossAttnInner(nn.Module): + """Inner cross-attention module matching the state_dict layout of the original CrossAttention.""" + def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), + ) + + def forward(self, x, context): + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + return self.to_out(optimized_attention(q, k, v, self.heads)) + + +class ZeroCrossAttn(nn.Module): + def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None): + super().__init__() + heads = query_dim // 64 + dim_head = 64 + self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations) + self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device) + self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device) + + def forward(self, context, x, control_scale=1): + b, c, h, w = x.shape + x_in = x + + x = self.attn( + self.norm1(x).flatten(2).transpose(1, 2), + self.norm2(context).flatten(2).transpose(1, 2), + ).transpose(1, 2).unflatten(2, (h, w)) + + return x_in + x * control_scale + + +class GLVControl(nn.Module): + """SUPIR's Guided Latent Vector control encoder. 
Truncated UNet (input + middle blocks only).""" + def __init__( + self, + in_channels=4, + model_channels=320, + num_res_blocks=2, + attention_resolutions=(4, 2), + channel_mult=(1, 2, 4), + num_head_channels=64, + transformer_depth=(1, 2, 10), + context_dim=2048, + adm_in_channels=2816, + use_linear_in_transformer=True, + use_checkpoint=False, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__() + self.model_channels = model_channels + time_embed_dim = model_channels * 4 + + self.time_embed = nn.Sequential( + operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + + self.label_emb = nn.Sequential( + nn.Sequential( + operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + ) + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + ]) + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(num_res_blocks): + layers = [ + ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels, + dtype=dtype, device=device, operations=operations) + ] + ch = mult * model_channels + if ds in attention_resolutions: + num_heads = ch // num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[level], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + if level != len(channel_mult) - 1: + self.input_blocks.append( + TimestepEmbedSequential( + Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations) + ) + ) + ds *= 2 + + num_heads = ch // num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[-1], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations), + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + ) + + self.input_hint_block = TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + + def forward(self, x, timesteps, xt, context=None, y=None, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + self.label_emb(y) + + guided_hint = self.input_hint_block(x, emb, context) + + hs = [] + h = xt + for module in self.input_blocks: + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + hs.append(h) + return hs + + +class SUPIR(nn.Module): + """ + SUPIR model containing GLVControl (control encoder) and project_modules (adapters). 
+ State dict keys match the original SUPIR checkpoint layout: + control_model.* -> GLVControl + project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn + """ + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + + self.control_model = GLVControl(dtype=dtype, device=device, operations=operations) + + project_channel_scale = 2 + cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3 + project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3] + concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0] + cross_attn_insert_idx = [6, 3] + + self.project_modules = nn.ModuleList() + for i in range(len(cond_output_channels)): + self.project_modules.append(ZeroSFT( + project_channels[i], cond_output_channels[i], + concat_channels=concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) + + for i in cross_attn_insert_idx: + self.project_modules.insert(i, ZeroCrossAttn( + cond_output_channels[i], concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) diff --git a/comfy/ldm/supir/supir_patch.py b/comfy/ldm/supir/supir_patch.py new file mode 100644 index 000000000..b67ab4cd8 --- /dev/null +++ b/comfy/ldm/supir/supir_patch.py @@ -0,0 +1,103 @@ +import torch +from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample + + +class SUPIRPatch: + """ + Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters). + Runs GLVControl lazily on first patch invocation per step, applies adapters through + middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch. + """ + SIGMA_MAX = 14.6146 + + def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end): + self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl + self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn + self.hint_latent = hint_latent # encoded LQ image latent + self.strength_start = strength_start + self.strength_end = strength_end + self.cached_features = None + self.adapter_idx = 0 + self.control_idx = 0 + self.current_control_idx = 0 + self.active = True + + def _ensure_features(self, kwargs): + """Run GLVControl on first call per step, cache results.""" + if self.cached_features is not None: + return + x = kwargs["x"] + b = x.shape[0] + hint = self.hint_latent.to(device=x.device, dtype=x.dtype) + if hint.shape[0] != b: + hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b] + self.cached_features = self.model_patch.model.control_model( + hint, kwargs["timesteps"], x, + kwargs["context"], kwargs["y"] + ) + self.adapter_idx = len(self.project_modules) - 1 + self.control_idx = len(self.cached_features) - 1 + + def _get_control_scale(self, kwargs): + if self.strength_start == self.strength_end: + return self.strength_end + sigma = kwargs["transformer_options"].get("sigmas") + if sigma is None: + return self.strength_end + s = sigma[0].item() if sigma.dim() > 0 else sigma.item() + t = min(s / self.SIGMA_MAX, 1.0) + return t * (self.strength_start - self.strength_end) + self.strength_end + + def middle_after(self, kwargs): + """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block.""" + self.cached_features = None # reset from previous step + self.current_scale = self._get_control_scale(kwargs) + self.active = self.current_scale > 0 + if not self.active: + return {"h": kwargs["h"]} + 
self._ensure_features(kwargs) + h = kwargs["h"] + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return {"h": h} + + def output_block(self, h, hsp, transformer_options): + """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat.""" + if not self.active: + return h, hsp + self.current_control_idx = self.control_idx + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return h, None + + def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw): + """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample.""" + block_type, _ = transformer_options["block"] + if block_type == "output" and self.active and self.cached_features is not None: + x = self.project_modules[self.adapter_idx]( + self.cached_features[self.current_control_idx], x, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + return layer(x, output_shape=output_shape) + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.cached_features = None + if self.hint_latent is not None: + self.hint_latent = self.hint_latent.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + def register(self, model_patcher): + """Register all patches on a cloned model patcher.""" + model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch") + model_patcher.set_model_output_block_patch(self.output_block) + model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch") diff --git a/comfy/model_base.py b/comfy/model_base.py index c2ae646aa..5c2668ba9 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -53,6 +53,7 @@ import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 +import comfy.ldm.ernie.model import comfy.model_management import comfy.patcher_extension @@ -1962,3 +1963,14 @@ class Kandinsky5Image(Kandinsky5): class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class ErnieImage(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8bed6828d..ca06cdd1e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -713,6 +713,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config + if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image + dit_config = {} + dit_config["image_model"] = "ernie" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git 
a/comfy/model_management.py b/comfy/model_management.py index 14d9f80fb..46261a0ed 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1762,6 +1762,21 @@ def supports_mxfp8_compute(device=None): return True +def supports_fp64(device=None): + if is_device_mps(device): + return False + + if is_intel_xpu(): + return False + + if is_directml_enabled(): + return False + + if is_ixuca(): + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index a74a51902..092bc6a79 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -595,6 +595,10 @@ class ModelPatcher: def set_model_noise_refiner_patch(self, patch): self.set_model_patch(patch, "noise_refiner") + def set_model_middle_block_after_patch(self, patch): + self.set_model_patch(patch, "middle_block_after_patch") + + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options["scale_x"] = scale_x diff --git a/comfy/ops.py b/comfy/ops.py index b5cd1d47e..7a9b4b84c 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1151,7 +1151,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec if param is None: continue p = fn(param) - if p.is_inference(): + if (not torch.is_inference_mode_enabled()) and p.is_inference(): p = p.clone() self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) for key, buf in self._buffers.items(): diff --git a/comfy/sd.py b/comfy/sd.py index b177f8c89..0ce450ace 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -62,6 +62,7 @@ import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 +import comfy.text_encoders.ernie import comfy.model_patcher import comfy.lora @@ -1235,6 +1236,7 @@ class TEModel(Enum): QWEN35_4B = 25 QWEN35_9B = 26 QWEN35_27B = 27 + MINISTRAL_3_3B = 28 def detect_te_model(sd): @@ -1301,6 +1303,8 @@ def detect_te_model(sd): return TEModel.MISTRAL3_24B else: return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + if weight.shape[0] == 3072: + return TEModel.MINISTRAL_3_3B return TEModel.LLAMA3_8 return None @@ -1458,6 +1462,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif te_model == TEModel.QWEN3_06B: clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer + elif te_model == TEModel.MINISTRAL_3_3B: + clip_target.clip = comfy.text_encoders.ernie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ernie.ErnieTokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) else: # clip_l if clip_type == CLIPType.SD3: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9a5612716..58d4ce731 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -26,6 +26,7 @@ import comfy.text_encoders.z_image import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image +import comfy.text_encoders.ernie from . import supported_models_base from . 
import latent_formats @@ -1749,6 +1750,37 @@ class RT_DETR_v4(supported_models_base.BASE): def clip_target(self, state_dict={}): return None -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] + +class ErnieImage(supported_models_base.BASE): + unet_config = { + "image_model": "ernie", + } + + sampling_settings = { + "multiplier": 1000.0, + "shift": 3.0, + } + + memory_usage_factor = 10.0 + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.ErnieImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}ministral3_3b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/ernie.py b/comfy/text_encoders/ernie.py new file mode 100644 index 000000000..46d24d222 --- /dev/null +++ b/comfy/text_encoders/ernie.py @@ -0,0 +1,38 @@ +from .flux import Mistral3Tokenizer +from comfy import sd1_clip +import comfy.text_encoders.llama + +class Ministral3_3BTokenizer(Mistral3Tokenizer): + def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}): + return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, 
embedding_key=embedding_key, tokenizer_data=tokenizer_data) + +class ErnieTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="ministral3_3b", tokenizer=Ministral3_3BTokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + tokens = super().tokenize_with_weights(text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Ministral3_3BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + textmodel_json_config = {} + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ministral3_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ErnieTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="ministral3_3b", clip_model=Ministral3_3BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ErnieTEModel_(ErnieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ErnieTEModel_ diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 1ae398789..d5eb91dcb 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -116,9 +116,9 @@ class MistralTokenizerClass: return LlamaTokenizerFast(**kwargs) class Mistral3Tokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, disable_weights=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 06f2fbf74..6ea8e36b1 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -60,6 +60,30 @@ class Mistral3Small24BConfig: final_norm: bool = True lm_head: bool = False +@dataclass 
+class Ministral3_3BConfig: + vocab_size: int = 131072 + hidden_size: int = 3072 + intermediate_size: int = 9216 + num_hidden_layers: int = 26 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 262144 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True + lm_head: bool = False + stop_tokens = [2] + @dataclass class Qwen25_3BConfig: vocab_size: int = 151936 @@ -946,6 +970,15 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ministral3_3BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 18455396d..3755323ac 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -52,6 +52,26 @@ class TaskImageContent(BaseModel): role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) +class TaskVideoContentUrl(BaseModel): + url: str = Field(...) + + +class TaskVideoContent(BaseModel): + type: str = Field("video_url") + video_url: TaskVideoContentUrl = Field(...) + role: str = Field("reference_video") + + +class TaskAudioContentUrl(BaseModel): + url: str = Field(...) + + +class TaskAudioContent(BaseModel): + type: str = Field("audio_url") + audio_url: TaskAudioContentUrl = Field(...) + role: str = Field("reference_audio") + + class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent] = Field(..., min_length=1) @@ -64,6 +84,17 @@ class Image2VideoTaskCreationRequest(BaseModel): generate_audio: bool | None = Field(...) +class Seedance2TaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1) + generate_audio: bool | None = Field(None) + resolution: str | None = Field(None) + ratio: str | None = Field(None) + duration: int | None = Field(None, ge=4, le=15) + seed: int | None = Field(None, ge=0, le=2147483647) + watermark: bool | None = Field(None) + + class TaskCreationResponse(BaseModel): id: str = Field(...) @@ -77,12 +108,27 @@ class TaskStatusResult(BaseModel): video_url: str = Field(...) +class TaskStatusUsage(BaseModel): + completion_tokens: int = Field(0) + total_tokens: int = Field(0) + + class TaskStatusResponse(BaseModel): id: str = Field(...) model: str = Field(...) status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) error: TaskStatusError | None = Field(None) content: TaskStatusResult | None = Field(None) + usage: TaskStatusUsage | None = Field(None) + + +# Dollars per 1K tokens, keyed by (model_id, has_video_input). 
+SEEDANCE2_PRICE_PER_1K_TOKENS = { + ("dreamina-seedance-2-0-260128", False): 0.007, + ("dreamina-seedance-2-0-260128", True): 0.0043, + ("dreamina-seedance-2-0-fast-260128", False): 0.0056, + ("dreamina-seedance-2-0-fast-260128", True): 0.0033, +} RECOMMENDED_PRESETS = [ @@ -112,6 +158,12 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [ ("Custom", None, None), ] +# Seedance 2.0 reference video pixel count limits per model. +SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = { + "dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408}, + "dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408}, +} + # The time in this dictionary are given for 10 seconds duration. VIDEO_TASKS_EXECUTION_TIME = { "seedance-1-0-lite-t2v-250428": { diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index de0c22e70..429c32444 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -8,16 +8,23 @@ from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, + SEEDANCE2_PRICE_PER_1K_TOKENS, + SEEDANCE2_REF_VIDEO_PIXEL_LIMITS, VIDEO_TASKS_EXECUTION_TIME, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, + Seedance2TaskCreationRequest, Seedream4Options, Seedream4TaskCreationRequest, + TaskAudioContent, + TaskAudioContentUrl, TaskCreationResponse, TaskImageContent, TaskImageContentUrl, TaskStatusResponse, TaskTextContent, + TaskVideoContent, + TaskVideoContentUrl, Text2ImageTaskCreationRequest, Text2VideoTaskCreationRequest, ) @@ -29,7 +36,10 @@ from comfy_api_nodes.util import ( image_tensor_pair_to_batch, poll_op, sync_op, + upload_audio_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, validate_image_dimensions, validate_string, @@ -46,12 +56,56 @@ SEEDREAM_MODELS = { # Long-running tasks endpoints(e.g., video) BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} +BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT = "/proxy/byteplus-seedance2/api/v3/contents/generations/tasks" # + /{task_id} + +SEEDANCE_MODELS = { + "Seedance 2.0": "dreamina-seedance-2-0-260128", + "Seedance 2.0 Fast": "dreamina-seedance-2-0-fast-260128", +} DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} + logger = logging.getLogger(__name__) +def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None: + """Validate reference video pixel count against Seedance 2.0 model limits.""" + limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) + if not limits: + return + try: + w, h = video.get_dimensions() + except Exception: + return + pixels = w * h + min_px = limits.get("min") + max_px = limits.get("max") + if min_px and pixels < min_px: + raise ValueError( + f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model." + ) + if max_px and pixels > max_px: + raise ValueError( + f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. " + f"Maximum is {max_px:,}px for this model. Try downscaling the video." 
+ ) + + +def _seedance2_price_extractor(model_id: str, has_video_input: bool): + """Returns a price_extractor closure for Seedance 2.0 poll_op.""" + rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) + if rate is None: + return None + + def extractor(response: TaskStatusResponse) -> float | None: + if response.usage is None: + return None + return response.usage.total_tokens * 1.43 * rate / 1_000.0 + + return extractor + + def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" @@ -335,8 +389,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): mp_provided = out_num_pixels / 1_000_000.0 if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400: raise ValueError( - f"Minimum image resolution for the selected model is 3.68MP, " - f"but {mp_provided:.2f}MP provided." + f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided." ) if "seedream-4-0" in model and out_num_pixels < 921600: raise ValueError( @@ -952,33 +1005,6 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ) -async def process_video_task( - cls: type[IO.ComfyNode], - payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, - estimated_duration: int | None, -) -> IO.NodeOutput: - if payload.model in DEPRECATED_MODELS: - logger.warning( - "Model '%s' is deprecated and will be deactivated on May 13, 2026. " - "Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.", - payload.model, - ) - initial_response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), - data=payload, - response_model=TaskCreationResponse, - ) - response = await poll_op( - cls, - ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), - status_extractor=lambda r: r.status, - estimated_duration=estimated_duration, - response_model=TaskStatusResponse, - ) - return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) - - def raise_if_text_params(prompt: str, text_params: list[str]) -> None: for i in text_params: if f"--{i} " in prompt: @@ -1040,6 +1066,542 @@ PRICE_BADGE_VIDEO = IO.PriceBadge( ) +def _seedance2_text_inputs(resolutions: list[str]): + return [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for video generation.", + ), + IO.Combo.Input( + "resolution", + options=resolutions, + tooltip="Resolution of the output video.", + ), + IO.Combo.Input( + "ratio", + options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"], + tooltip="Aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=7, + min=4, + max=15, + step=1, + tooltip="Duration of the output video in seconds (4-15).", + display_mode=IO.NumberDisplay.slider, + ), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="Enable audio generation for the output video.", + ), + ] + + +class ByteDance2TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDance2TextToVideoNode", + display_name="ByteDance Seedance 2.0 Text to Video", + category="api node/video/ByteDance", + description="Generate video using Seedance 2.0 models based on a text prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 
2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), + ], + tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add a watermark to the video.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]), + expr=""" + ( + $rate480 := 10044; + $rate720 := 21600; + $rate1080 := 48800; + $m := widgets.model; + $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $rate := $res = "1080p" ? $rate1080 : + $res = "720p" ? $rate720 : + $rate480; + $cost := $dur * $rate * $pricePer1K / 1000; + {"type": "usd", "usd": $cost, "format": {"approximate": true}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + seed: int, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(model["prompt"], strip_whitespace=True, min_length=1) + model_id = SEEDANCE_MODELS[model["model"]] + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=Seedance2TaskCreationRequest( + model=model_id, + content=[TaskTextContent(text=model["prompt"])], + generate_audio=model["generate_audio"], + resolution=model["resolution"], + ratio=model["ratio"], + duration=model["duration"], + seed=seed, + watermark=watermark, + ), + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status, + price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), + poll_interval=9, + max_poll_attempts=180, + ) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) + + +class ByteDance2FirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDance2FirstLastFrameNode", + display_name="ByteDance Seedance 2.0 First-Last-Frame to Video", + category="api node/video/ByteDance", + description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), + ], + tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", + ), + IO.Image.Input( + "first_frame", + tooltip="First frame image for the video.", + ), + IO.Image.Input( + "last_frame", + tooltip="Last frame image for the video.", + optional=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + 
IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add a watermark to the video.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]), + expr=""" + ( + $rate480 := 10044; + $rate720 := 21600; + $rate1080 := 48800; + $m := widgets.model; + $pricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $rate := $res = "1080p" ? $rate1080 : + $res = "720p" ? $rate720 : + $rate480; + $cost := $dur * $rate * $pricePer1K / 1000; + {"type": "usd", "usd": $cost, "format": {"approximate": true}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + first_frame: Input.Image, + seed: int, + watermark: bool, + last_frame: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(model["prompt"], strip_whitespace=True, min_length=1) + model_id = SEEDANCE_MODELS[model["model"]] + + content: list[TaskTextContent | TaskImageContent] = [ + TaskTextContent(text=model["prompt"]), + TaskImageContent( + image_url=TaskImageContentUrl( + url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") + ), + role="first_frame", + ), + ] + if last_frame is not None: + content.append( + TaskImageContent( + image_url=TaskImageContentUrl( + url=await upload_image_to_comfyapi(cls, last_frame, wait_label="Uploading last frame.") + ), + role="last_frame", + ), + ) + + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=Seedance2TaskCreationRequest( + model=model_id, + content=content, + generate_audio=model["generate_audio"], + resolution=model["resolution"], + ratio=model["ratio"], + duration=model["duration"], + seed=seed, + watermark=watermark, + ), + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status, + price_extractor=_seedance2_price_extractor(model_id, has_video_input=False), + poll_interval=9, + max_poll_attempts=180, + ) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) + + +def _seedance2_reference_inputs(resolutions: list[str]): + return [ + *_seedance2_text_inputs(resolutions), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_image"), + names=[ + "image_1", + "image_2", + "image_3", + "image_4", + "image_5", + "image_6", + "image_7", + "image_8", + "image_9", + ], + min=0, + ), + ), + IO.Autogrow.Input( + "reference_videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("reference_video"), + names=["video_1", "video_2", "video_3"], + min=0, + ), + ), + IO.Autogrow.Input( + "reference_audios", + template=IO.Autogrow.TemplateNames( + IO.Audio.Input("reference_audio"), + names=["audio_1", "audio_2", "audio_3"], + min=0, + ), + ), + ] + + +class ByteDance2ReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ByteDance2ReferenceNode", + display_name="ByteDance Seedance 2.0 Reference to Video", + category="api node/video/ByteDance", + description="Generate, edit, or extend video using Seedance 2.0 
with reference images, " + "videos, and audio. Supports multimodal reference, video editing, and video extension.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])), + IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])), + ], + tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add a watermark to the video.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model", "model.resolution", "model.duration"], + input_groups=["model.reference_videos"], + ), + expr=""" + ( + $rate480 := 10044; + $rate720 := 21600; + $rate1080 := 48800; + $m := widgets.model; + $hasVideo := $lookup(inputGroups, "model.reference_videos") > 0; + $noVideoPricePer1K := $contains($m, "fast") ? 0.008008 : 0.01001; + $videoPricePer1K := $contains($m, "fast") ? 0.004719 : 0.006149; + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $rate := $res = "1080p" ? $rate1080 : + $res = "720p" ? $rate720 : + $rate480; + $noVideoCost := $dur * $rate * $noVideoPricePer1K / 1000; + $minVideoFactor := $ceil($dur * 5 / 3); + $minVideoCost := $minVideoFactor * $rate * $videoPricePer1K / 1000; + $maxVideoCost := (15 + $dur) * $rate * $videoPricePer1K / 1000; + $hasVideo + ? { + "type": "range_usd", + "min_usd": $minVideoCost, + "max_usd": $maxVideoCost, + "format": {"approximate": true} + } + : { + "type": "usd", + "usd": $noVideoCost, + "format": {"approximate": true} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + seed: int, + watermark: bool, + ) -> IO.NodeOutput: + validate_string(model["prompt"], strip_whitespace=True, min_length=1) + + reference_images = model.get("reference_images", {}) + reference_videos = model.get("reference_videos", {}) + reference_audios = model.get("reference_audios", {}) + + if not reference_images and not reference_videos: + raise ValueError("At least one reference image or video is required.") + + model_id = SEEDANCE_MODELS[model["model"]] + has_video_input = len(reference_videos) > 0 + total_video_duration = 0.0 + for i, key in enumerate(reference_videos, 1): + video = reference_videos[key] + _validate_ref_video_pixels(video, model_id, i) + try: + dur = video.get_duration() + if dur < 1.8: + raise ValueError(f"Reference video {i} is too short: {dur:.1f}s. Minimum duration is 1.8 seconds.") + total_video_duration += dur + except ValueError: + raise + except Exception: + pass + if total_video_duration > 15.1: + raise ValueError(f"Total reference video duration is {total_video_duration:.1f}s. Maximum is 15.1 seconds.") + + total_audio_duration = 0.0 + for i, key in enumerate(reference_audios, 1): + audio = reference_audios[key] + dur = int(audio["waveform"].shape[-1]) / int(audio["sample_rate"]) + if dur < 1.8: + raise ValueError(f"Reference audio {i} is too short: {dur:.1f}s. 
Minimum duration is 1.8 seconds.") + total_audio_duration += dur + if total_audio_duration > 15.1: + raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.") + + content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [ + TaskTextContent(text=model["prompt"]), + ] + for i, key in enumerate(reference_images, 1): + content.append( + TaskImageContent( + image_url=TaskImageContentUrl( + url=await upload_image_to_comfyapi( + cls, + image=reference_images[key], + wait_label=f"Uploading image {i}", + ), + ), + role="reference_image", + ), + ) + for i, key in enumerate(reference_videos, 1): + content.append( + TaskVideoContent( + video_url=TaskVideoContentUrl( + url=await upload_video_to_comfyapi( + cls, + reference_videos[key], + wait_label=f"Uploading video {i}", + ), + ), + ), + ) + for key in reference_audios: + content.append( + TaskAudioContent( + audio_url=TaskAudioContentUrl( + url=await upload_audio_to_comfyapi( + cls, + reference_audios[key], + container_format="mp3", + codec_name="libmp3lame", + mime_type="audio/mpeg", + ), + ), + ), + ) + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=Seedance2TaskCreationRequest( + model=model_id, + content=content, + generate_audio=model["generate_audio"], + resolution=model["resolution"], + ratio=model["ratio"], + duration=model["duration"], + seed=seed, + watermark=watermark, + ), + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_SEEDANCE2_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status, + price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input), + poll_interval=9, + max_poll_attempts=180, + ) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) + + +async def process_video_task( + cls: type[IO.ComfyNode], + payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, + estimated_duration: int | None, +) -> IO.NodeOutput: + if payload.model in DEPRECATED_MODELS: + logger.warning( + "Model '%s' is deprecated and will be deactivated on May 13, 2026. " + "Please switch to a newer model. 
Recommended: seedance-1-0-pro-fast-251015.", + payload.model, + ) + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + status_extractor=lambda r: r.status, + estimated_duration=estimated_duration, + response_model=TaskStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1050,6 +1612,9 @@ class ByteDanceExtension(ComfyExtension): ByteDanceImageToVideoNode, ByteDanceFirstLastFrameNode, ByteDanceImageReferenceNode, + ByteDance2TextToVideoNode, + ByteDance2FirstLastFrameNode, + ByteDance2ReferenceNode, ] diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index dabc899d6..f42d84616 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -558,7 +558,7 @@ class GrokVideoReferenceNode(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $refs := inputGroups["model.reference_images"]; + $refs := $lookup(inputGroups, "model.reference_images"); $rate := $res = "720p" ? 0.07 : 0.05; $price := ($rate * $dur + 0.002 * $refs) * 1.43; {"type":"usd","usd": $price} diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 44c94a98e..5fc31bccd 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode): response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) + obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False) + obj_result = None + if obj_file_response: + obj_result = await download_and_extract_obj_zip(obj_file_response.Url) return IO.NodeOutput( f"{task_id}.glb", await download_url_to_file_3d( get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id ), - obj_result.obj, - obj_result.texture, + obj_result.obj if obj_result else None, + obj_result.texture if obj_result else None, ) @@ -378,17 +381,30 @@ class TencentImageToModelNode(IO.ComfyNode): response_model=To3DProTaskResultResponse, status_extractor=lambda r: r.Status, ) - obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) + obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False) + if obj_file_response: + obj_result = await download_and_extract_obj_zip(obj_file_response.Url) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + ), + obj_result.obj, + obj_result.texture, + obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), + obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), + obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), + ) return IO.NodeOutput( f"{task_id}.glb", await download_url_to_file_3d( get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id ), - obj_result.obj, - obj_result.texture, - obj_result.metallic 
if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), - obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), - obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), + None, + None, + None, + None, + None, ) diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py index 61533263f..28862e368 100644 --- a/comfy_api_nodes/nodes_quiver.py +++ b/comfy_api_nodes/nodes_quiver.py @@ -17,6 +17,44 @@ from comfy_api_nodes.util import ( ) from comfy_extras.nodes_images import SVG +_ARROW_MODELS = ["arrow-1.1", "arrow-1.1-max", "arrow-preview"] + + +def _arrow_sampling_inputs(): + """Shared sampling inputs for all Arrow model variants.""" + return [ + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ] + class QuiverTextToSVGNode(IO.ComfyNode): @classmethod @@ -39,6 +77,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): default="", tooltip="Additional style or formatting guidance.", optional=True, + advanced=True, ), IO.Autogrow.Input( "reference_images", @@ -53,43 +92,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): ), IO.DynamicCombo.Input( "model", - options=[ - IO.DynamicCombo.Option( - "arrow-preview", - [ - IO.Float.Input( - "temperature", - default=1.0, - min=0.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Randomness control. Higher values increase randomness.", - advanced=True, - ), - IO.Float.Input( - "top_p", - default=1.0, - min=0.05, - max=1.0, - step=0.05, - display_mode=IO.NumberDisplay.slider, - tooltip="Nucleus sampling parameter.", - advanced=True, - ), - IO.Float.Input( - "presence_penalty", - default=0.0, - min=-2.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Token presence penalty.", - advanced=True, - ), - ], - ), - ], + options=[IO.DynamicCombo.Option(m, _arrow_sampling_inputs()) for m in _ARROW_MODELS], tooltip="Model to use for SVG generation.", ), IO.Int.Input( @@ -112,7 +115,16 @@ class QuiverTextToSVGNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.429}""", + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model, "max") + ? {"type":"usd","usd":0.3575} + : $contains(widgets.model, "preview") + ? {"type":"usd","usd":0.429} + : {"type":"usd","usd":0.286} + ) + """, ), ) @@ -176,12 +188,13 @@ class QuiverImageToSVGNode(IO.ComfyNode): "auto_crop", default=False, tooltip="Automatically crop to the dominant subject.", + advanced=True, ), IO.DynamicCombo.Input( "model", options=[ IO.DynamicCombo.Option( - "arrow-preview", + m, [ IO.Int.Input( "target_size", @@ -189,39 +202,12 @@ class QuiverImageToSVGNode(IO.ComfyNode): min=128, max=4096, tooltip="Square resize target in pixels.", - ), - IO.Float.Input( - "temperature", - default=1.0, - min=0.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Randomness control. 
Higher values increase randomness.", - advanced=True, - ), - IO.Float.Input( - "top_p", - default=1.0, - min=0.05, - max=1.0, - step=0.05, - display_mode=IO.NumberDisplay.slider, - tooltip="Nucleus sampling parameter.", - advanced=True, - ), - IO.Float.Input( - "presence_penalty", - default=0.0, - min=-2.0, - max=2.0, - step=0.1, - display_mode=IO.NumberDisplay.slider, - tooltip="Token presence penalty.", advanced=True, ), + *_arrow_sampling_inputs(), ], - ), + ) + for m in _ARROW_MODELS ], tooltip="Model to use for SVG vectorization.", ), @@ -245,7 +231,16 @@ class QuiverImageToSVGNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.429}""", + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model, "max") + ? {"type":"usd","usd":0.3575} + : $contains(widgets.model, "preview") + ? {"type":"usd","usd":0.429} + : {"type":"usd","usd":0.286} + ) + """, ), ) diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py new file mode 100644 index 000000000..5518f5902 --- /dev/null +++ b/comfy_api_nodes/nodes_sonilo.py @@ -0,0 +1,287 @@ +import base64 +import json +import logging +import time +from urllib.parse import urljoin + +import aiohttp +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_bytes_to_audio_input, + upload_video_to_comfyapi, + validate_string, +) +from comfy_api_nodes.util._helpers import ( + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, +) +from comfy_api_nodes.util.common_exceptions import ProcessingInterrupted +from server import PromptServer + +logger = logging.getLogger(__name__) + + +class SoniloVideoToMusic(IO.ComfyNode): + """Generate music from video using Sonilo's AI model.""" + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="SoniloVideoToMusic", + display_name="Sonilo Video to Music", + category="api node/audio/Sonilo", + description="Generate music from video content using Sonilo's AI model. " + "Analyzes the video and creates matching music.", + inputs=[ + IO.Video.Input( + "video", + tooltip="Input video to generate music from. Maximum duration: 6 minutes.", + ), + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Optional text prompt to guide music generation. " + "Leave empty for best quality - the model will fully analyze the video content.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed for reproducibility. 
Currently ignored by the Sonilo " + "service but kept for graph consistency.", + ), + ], + outputs=[IO.Audio.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr='{"type":"usd","usd":0.009,"format":{"suffix":"/second"}}', + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + prompt: str = "", + seed: int = 0, + ) -> IO.NodeOutput: + video_url = await upload_video_to_comfyapi(cls, video, max_duration=360) + form = aiohttp.FormData() + form.add_field("video_url", video_url) + if prompt.strip(): + form.add_field("prompt", prompt.strip()) + audio_bytes = await _stream_sonilo_music( + cls, + ApiEndpoint(path="/proxy/sonilo/v2m/generate", method="POST"), + form, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes)) + + +class SoniloTextToMusic(IO.ComfyNode): + """Generate music from a text prompt using Sonilo's AI model.""" + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="SoniloTextToMusic", + display_name="Sonilo Text to Music", + category="api node/audio/Sonilo", + description="Generate music from a text prompt using Sonilo's AI model. " + "Leave duration at 0 to let the model infer it from the prompt.", + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt describing the music to generate.", + ), + IO.Int.Input( + "duration", + default=0, + min=0, + max=360, + tooltip="Target duration in seconds. Set to 0 to let the model " + "infer the duration from the prompt. Maximum: 6 minutes.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed for reproducibility. Currently ignored by the Sonilo " + "service but kept for graph consistency.", + ), + ], + outputs=[IO.Audio.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr=""" + ( + widgets.duration > 0 + ? 
{"type":"usd","usd": 0.005 * widgets.duration} + : {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + duration: int = 0, + seed: int = 0, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + form = aiohttp.FormData() + form.add_field("prompt", prompt) + if duration > 0: + form.add_field("duration", str(duration)) + audio_bytes = await _stream_sonilo_music( + cls, + ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"), + form, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(audio_bytes)) + + +async def _stream_sonilo_music( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + form: aiohttp.FormData, +) -> bytes: + """POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes.""" + url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/")) + + headers: dict[str, str] = {} + headers.update(get_auth_header(cls)) + headers.update(endpoint.headers) + + node_id = get_node_id(cls) + start_ts = time.monotonic() + last_chunk_status_ts = 0.0 + audio_streams: dict[int, list[bytes]] = {} + title: str | None = None + + timeout = aiohttp.ClientTimeout(total=1200.0, sock_read=300.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + PromptServer.instance.send_progress_text("Status: Queued", node_id) + async with session.post(url, data=form, headers=headers) as resp: + if resp.status >= 400: + msg = await _extract_error_message(resp) + raise Exception(f"Sonilo API error ({resp.status}): {msg}") + + while True: + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + raw_line = await resp.content.readline() + if not raw_line: + break + + line = raw_line.decode("utf-8").strip() + if not line: + continue + + try: + evt = json.loads(line) + except json.JSONDecodeError: + logger.warning("Sonilo: skipping malformed NDJSON line") + continue + + evt_type = evt.get("type") + if evt_type == "error": + code = evt.get("code", "UNKNOWN") + message = evt.get("message", "Unknown error") + raise Exception(f"Sonilo generation error ({code}): {message}") + if evt_type == "duration": + duration_sec = evt.get("duration_sec") + if duration_sec is not None: + PromptServer.instance.send_progress_text( + f"Status: Generating\nVideo duration: {duration_sec:.1f}s", + node_id, + ) + elif evt_type in ("titles", "title"): + # v2m sends a "titles" list, t2m sends a scalar "title" + if evt_type == "titles": + titles = evt.get("titles", []) + if titles: + title = titles[0] + else: + title = evt.get("title") or title + if title: + PromptServer.instance.send_progress_text( + f"Status: Generating\nTitle: {title}", + node_id, + ) + elif evt_type == "audio_chunk": + stream_idx = evt.get("stream_index", 0) + chunk_data = base64.b64decode(evt["data"]) + + if stream_idx not in audio_streams: + audio_streams[stream_idx] = [] + audio_streams[stream_idx].append(chunk_data) + + now = time.monotonic() + if now - last_chunk_status_ts >= 1.0: + total_chunks = sum(len(chunks) for chunks in audio_streams.values()) + elapsed = int(now - start_ts) + status_lines = ["Status: Receiving audio"] + if title: + status_lines.append(f"Title: {title}") + status_lines.append(f"Chunks received: {total_chunks}") + status_lines.append(f"Time elapsed: {elapsed}s") + PromptServer.instance.send_progress_text("\n".join(status_lines), node_id) + last_chunk_status_ts = now + elif evt_type == "complete": + break + + if not audio_streams: + 
raise Exception("Sonilo API returned no audio data.") + + PromptServer.instance.send_progress_text("Status: Completed", node_id) + selected_stream = 0 if 0 in audio_streams else min(audio_streams) + return b"".join(audio_streams[selected_stream]) + + +async def _extract_error_message(resp: aiohttp.ClientResponse) -> str: + """Extract a human-readable error message from an HTTP error response.""" + try: + error_body = await resp.json() + detail = error_body.get("detail", {}) + if isinstance(detail, dict): + return detail.get("message", str(detail)) + return str(detail) + except Exception: + return await resp.text() + + +class SoniloExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [SoniloVideoToMusic, SoniloTextToMusic] + + +async def comfy_entrypoint() -> SoniloExtension: + return SoniloExtension() diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 9ef13c83b..906d8ff35 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -401,7 +401,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", + expr="""{"type":"usd","usd":0.4}""", ), ) @@ -510,7 +510,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.25}""", + expr="""{"type":"usd","usd":0.6}""", ), ) @@ -593,7 +593,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.01}""", + expr="""{"type":"usd","usd":0.02}""", ), ) diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 176e6bc2f..748559a6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,7 +7,10 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +import comfy.ldm.supir.supir_modules from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel +from comfy_api.latest import io +from comfy.ldm.supir.supir_patch import SUPIRPatch class BlockWiseControlBlock(torch.nn.Module): @@ -266,6 +269,27 @@ class ModelPatchLoader: out_dim=sd["audio_proj.norm.weight"].shape[0], device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) + elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd: + prefix_replace = {} + if 'model.control_model.input_hint_block.0.weight' in sd: + prefix_replace["model.control_model."] = "control_model." + prefix_replace["model.diffusion_model.project_modules."] = "project_modules." + else: + prefix_replace["control_model."] = "control_model." + prefix_replace["project_modules."] = "project_modules." + + # Extract denoise_encoder weights before filter_keys discards them + de_prefix = "first_stage_model.denoise_encoder." 
+ denoise_encoder_sd = {} + for k in list(sd.keys()): + if k.startswith(de_prefix): + denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k) + + sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True) + sd.pop("control_model.mask_LQ", None) + model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + if denoise_encoder_sd: + model.denoise_encoder_sd = denoise_encoder_sd model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model.load_state_dict(sd, assign=model_patcher.is_dynamic()) @@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module): ) +class SUPIRApply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SUPIRApply", + category="model_patches/supir", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the start of sampling (high sigma)."), + io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."), + io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True, + tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."), + io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Sigma threshold below which restore_cfg is disabled."), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def _encode_with_denoise_encoder(cls, vae, model_patch, image): + """Encode using denoise_encoder weights from SUPIR checkpoint if available.""" + denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None) + if not denoise_sd: + return vae.encode(image) + + # Clone VAE patcher, apply denoise_encoder weights to clone, encode + orig_patcher = vae.patcher + vae.patcher = orig_patcher.clone() + patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()} + vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0) + try: + return vae.encode(image) + finally: + vae.patcher = orig_patcher + + @classmethod + def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type, + strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput: + model_patched = model.clone() + hint_latent = model.get_model_object("latent_format").process_in( + cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3])) + patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end) + patch.register(model_patched) + + if restore_cfg > 0.0: + # Round-trip to match original pipeline: decode hint, re-encode with regular VAE + latent_format = model.get_model_object("latent_format") + decoded = vae.decode(latent_format.process_out(hint_latent)) + x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3])) + sigma_max = 14.6146 + + def restore_cfg_function(args): + denoised = args["denoised"] + sigma = args["sigma"] + if sigma.dim() > 0: + s = sigma[0].item() + else: + s = sigma.item() + if s > restore_cfg_s_tmin: + ref = 
x_center.to(device=denoised.device, dtype=denoised.dtype) + b = denoised.shape[0] + if ref.shape[0] != b: + ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b] + sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma + d_center = denoised - ref + denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg) + return denoised + + model_patched.set_model_sampler_post_cfg_function(restore_cfg_function) + + return io.NodeOutput(model_patched) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, "ZImageFunControlnet": ZImageFunControlnet, "USOStyleReference": USOStyleReference, + "SUPIRApply": SUPIRApply, } diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 9037c3d20..c932b747a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,6 +6,7 @@ from PIL import Image import math from enum import Enum from typing import TypedDict, Literal +import kornia import comfy.utils import comfy.model_management @@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): return io.NodeOutput(batched) +class ColorTransfer(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ColorTransfer", + category="image/postprocessing", + description="Match the colors of one image to another using various algorithms.", + search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], + inputs=[ + io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), + io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), + io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), + io.DynamicCombo.Input("source_stats", + tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. 
target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", + options=[ + io.DynamicCombo.Option("per_frame", []), + io.DynamicCombo.Option("uniform", []), + io.DynamicCombo.Option("target_frame", [ + io.Int.Input("target_index", default=0, min=0, max=10000, + tooltip="Frame index used as the source baseline for computing the transform to image_ref"), + ]), + ]), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Image.Output(display_name="image"), + ], + ) + + @staticmethod + def _to_lab(images, i, device): + return kornia.color.rgb_to_lab( + images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2)) + + @staticmethod + def _pool_stats(images, device, is_reinhard, eps): + """Two-pass pooled mean + std/cov across all frames.""" + N, C = images.shape[0], images.shape[3] + HW = images.shape[1] * images.shape[2] + mean = torch.zeros(C, 1, device=device, dtype=torch.float32) + for i in range(N): + mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True) + mean /= N + acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32) + for i in range(N): + centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean + if is_reinhard: + acc += (centered * centered).mean(dim=-1, keepdim=True) + else: + acc += centered @ centered.T / HW + if is_reinhard: + return mean, torch.sqrt(acc / N).clamp_min_(eps) + return mean, acc / N + + @staticmethod + def _frame_stats(lab_flat, hw, is_reinhard, eps): + """Per-frame mean + std/cov.""" + mean = lab_flat.mean(dim=-1, keepdim=True) + if is_reinhard: + return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps) + centered = lab_flat - mean + return mean, centered @ centered.T / hw + + @staticmethod + def _mkl_matrix(cov_s, cov_r, eps): + """Compute MKL 3x3 transform matrix from source and ref covariances.""" + eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s) + sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps) + + scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0) + mid = scaled_V.T @ cov_r @ scaled_V + eig_val_m, eig_vec_m = torch.linalg.eigh(mid) + sqrt_m = torch.sqrt(eig_val_m.clamp_min(0)) + + inv_sqrt_s = 1.0 / sqrt_val_s + inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0) + M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T + return inv_scaled_V @ M_half @ inv_scaled_V.T + + @staticmethod + def _histogram_lut(src, ref, bins=256): + """Build per-channel LUT from source and ref histograms. 
src/ref: (C, HW) in [0,1].""" + s_bins = (src * (bins - 1)).long().clamp(0, bins - 1) + r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1) + s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype) + r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype) + ones_s = torch.ones_like(src) + ones_r = torch.ones_like(ref) + s_hist.scatter_add_(1, s_bins, ones_s) + r_hist.scatter_add_(1, r_bins, ones_r) + s_cdf = s_hist.cumsum(1) + s_cdf = s_cdf / s_cdf[:, -1:] + r_cdf = r_hist.cumsum(1) + r_cdf = r_cdf / r_cdf[:, -1:] + return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1) + + @classmethod + def _pooled_cdf(cls, images, device, num_bins=256): + """Build pooled CDF across all frames, one frame at a time.""" + C = images.shape[3] + hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32) + for i in range(images.shape[0]): + frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1) + bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1) + hist.scatter_add_(1, bins, torch.ones_like(frame)) + cdf = hist.cumsum(1) + return cdf / cdf[:, -1:] + + @classmethod + def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B): + """Build per-frame or uniform LUT transform for histogram mode.""" + if stats_mode == 'per_frame': + return None # LUT computed per-frame in the apply loop + + r_cdf = cls._pooled_cdf(image_ref, device) + if stats_mode == 'target_frame': + ti = min(target_index, B - 1) + s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device) + else: + s_cdf = cls._pooled_cdf(image_target, device) + return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0 + + @classmethod + def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard): + """Build transform parameters for Lab-based methods. 
Returns a transform function.""" + eps = 1e-6 + B, H, W, C = image_target.shape + B_ref = image_ref.shape[0] + single_ref = B_ref == 1 + HW = H * W + HW_ref = image_ref.shape[1] * image_ref.shape[2] + + # Precompute ref stats + if single_ref or stats_mode in ('uniform', 'target_frame'): + ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps) + + # Uniform/target_frame: precompute single affine transform + if stats_mode in ('uniform', 'target_frame'): + if stats_mode == 'target_frame': + ti = min(target_index, B - 1) + s_lab = cls._to_lab(image_target, ti, device).view(C, -1) + s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps) + else: + s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps) + + if is_reinhard: + scale = ref_sc / s_sc + offset = ref_mean - scale * s_mean + return lambda src_flat, **_: src_flat * scale + offset + T = cls._mkl_matrix(s_sc, ref_sc, eps) + offset = ref_mean - T @ s_mean + return lambda src_flat, **_: T @ src_flat + offset + + # per_frame + def per_frame_transform(src_flat, frame_idx): + s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps) + + if single_ref: + r_mean, r_sc = ref_mean, ref_sc + else: + ri = min(frame_idx, B_ref - 1) + r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps) + + centered = src_flat - s_mean + if is_reinhard: + return centered * (r_sc / s_sc) + r_mean + T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps) + return T @ centered + r_mean + + return per_frame_transform + + @classmethod + def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput: + stats_mode = source_stats["source_stats"] + target_index = source_stats.get("target_index", 0) + + if strength == 0 or image_ref is None: + return io.NodeOutput(image_target) + + device = comfy.model_management.get_torch_device() + intermediate_device = comfy.model_management.intermediate_device() + intermediate_dtype = comfy.model_management.intermediate_dtype() + + B, H, W, C = image_target.shape + B_ref = image_ref.shape[0] + pbar = comfy.utils.ProgressBar(B) + out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype) + + if method == 'histogram': + uniform_lut = cls._build_histogram_transform( + image_target, image_ref, device, stats_mode, target_index, B) + + for i in range(B): + src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1) + src_flat = src.reshape(C, -1) + if uniform_lut is not None: + lut = uniform_lut + else: + ri = min(i, B_ref - 1) + ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1) + lut = cls._histogram_lut(src_flat, ref) + bin_idx = (src_flat * 255).long().clamp(0, 255) + matched = lut.gather(1, bin_idx).view(C, H, W) + result = matched if strength == 1.0 else torch.lerp(src, matched, strength) + out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype) + pbar.update(1) + else: + transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab") + + for i in range(B): + src_frame = cls._to_lab(image_target, i, device) + corrected = transform(src_frame.view(C, -1), frame_idx=i) + if strength == 1.0: + result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W)) + else: + result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength)) + out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 
1).to(device=intermediate_device, dtype=intermediate_dtype) + pbar.update(1) + + return io.NodeOutput(out) + + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension): BatchImagesNode, BatchMasksNode, BatchLatentsNode, + ColorTransfer, # BatchImagesMasksLatentsNode, ] diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index b0a6f279d..0a1558f2b 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -11,7 +11,7 @@ class PreviewAny(): "required": {"source": (IO.ANY, {})}, } - RETURN_TYPES = () + RETURN_TYPES = (IO.STRING,) FUNCTION = "main" OUTPUT_NODE = True @@ -33,7 +33,7 @@ class PreviewAny(): except Exception: value = 'source exists, but could not be serialized.' - return {"ui": {"text": (value,)}} + return {"ui": {"text": (value,)}, "result": (value,)} NODE_CLASS_MAPPINGS = { "PreviewAny": PreviewAny, diff --git a/comfy_extras/nodes_rtdetr.py b/comfy_extras/nodes_rtdetr.py index 61307e268..7feaf3ab3 100644 --- a/comfy_extras/nodes_rtdetr.py +++ b/comfy_extras/nodes_rtdetr.py @@ -32,10 +32,12 @@ class RTDETR_detect(io.ComfyNode): def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput: B, H, W, C = image.shape - image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") - comfy.model_management.load_model_gpu(model) - results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts + results = [] + for i in range(0, B, 32): + batch = image[i:i + 32] + image_in = comfy.utils.common_upscale(batch.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") + results.extend(model.model.diffusion_model(image_in, (W, H))) all_bbox_dicts = [] diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py index 46b5fb226..7d54967d5 100644 --- a/comfy_extras/nodes_sdpose.py +++ b/comfy_extras/nodes_sdpose.py @@ -1,5 +1,6 @@ import torch import comfy.utils +import comfy.model_management import numpy as np import math import colorsys @@ -410,7 +411,9 @@ class SDPoseDrawKeypoints(io.ComfyNode): pose_outputs.append(canvas) pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else np.expand_dims(pose_outputs[0], 0) - final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0 + final_pose_output = torch.from_numpy(pose_outputs_np).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype()) / 255.0 return io.NodeOutput(final_pose_output) class SDPoseKeypointExtractor(io.ComfyNode): @@ -459,6 +462,27 @@ class SDPoseKeypointExtractor(io.ComfyNode): model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + def _resize_to_model(imgs): + """Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). 
Returns (resized_bhwc, scale, pad_top, pad_left).""" + h, w = imgs.shape[-3], imgs.shape[-2] + scale = min(model_h / h, model_w / w) + sh, sw = int(round(h * scale)), int(round(w * scale)) + pt, pl = (model_h - sh) // 2, (model_w - sw) // 2 + chw = imgs.permute(0, 3, 1, 2).float() + scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled") + padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) + padded[:, :, pt:pt + sh, pl:pl + sw] = scaled + return padded.permute(0, 2, 3, 1), scale, pt, pl + + def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0): + """Remap keypoints from model space back to original image space.""" + kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) + invalid = kp[..., 0] < 0 + kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x + kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y + kp[invalid] = -1 + return kp + def _run_on_latent(latent_batch): """Run one forward pass and return (keypoints_list, scores_list) for the batch.""" nonlocal captured_feat @@ -504,36 +528,19 @@ class SDPoseKeypointExtractor(io.ComfyNode): if x2 <= x1 or y2 <= y1: continue - crop_h_px, crop_w_px = y2 - y1, x2 - x1 crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) - - # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size. - scale = min(model_h / crop_h_px, model_w / crop_w_px) - scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale)) - pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2 - - crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW - scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") - padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) - padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled - crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC + crop_resized, scale, pad_top, pad_left = _resize_to_model(crop) latent_crop = vae.encode(crop_resized) kp_batch, sc_batch = _run_on_latent(latent_crop) - kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space - - # remove padding offset, undo scale, offset to full-image coordinates. 
- kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) - kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1 - kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1 - + kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1) img_keypoints.append(kp) - img_scores.append(sc) + img_scores.append(sc_batch[0]) else: - # No bboxes for this image – run on the full image - latent_img = vae.encode(img) + img_resized, scale, pad_top, pad_left = _resize_to_model(img) + latent_img = vae.encode(img_resized) kp_batch, sc_batch = _run_on_latent(latent_img) - img_keypoints.append(kp_batch[0]) + img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left)) img_scores.append(sc_batch[0]) all_keypoints.append(img_keypoints) @@ -541,19 +548,16 @@ class SDPoseKeypointExtractor(io.ComfyNode): pbar.update(1) else: # full-image mode, batched - tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints") - for batch_start in range(0, total_images, batch_size): - batch_end = min(batch_start + batch_size, total_images) - latent_batch = vae.encode(image[batch_start:batch_end]) - + for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"): + batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size]) + latent_batch = vae.encode(batch_resized) kp_batch, sc_batch = _run_on_latent(latent_batch) for kp, sc in zip(kp_batch, sc_batch): - all_keypoints.append([kp]) + all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)]) all_scores.append([sc]) - tqdm_pbar.update(1) - pbar.update(batch_end - batch_start) + pbar.update(len(kp_batch)) openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width) return io.NodeOutput(openpose_frames) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 75a8bb4ee..604076c4e 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,4 +1,5 @@ import re +import json from typing_extensions import override from comfy_api.latest import ComfyExtension, io @@ -375,6 +376,39 @@ class RegexReplace(io.ComfyNode): return io.NodeOutput(result) +class JsonExtractString(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="JsonExtractString", + display_name="Extract String from JSON", + category="utils/string", + search_aliases=["json", "extract json", "parse json", "json value", "read json"], + inputs=[ + io.String.Input("json_string", multiline=True), + io.String.Input("key", multiline=False), + ], + outputs=[ + io.String.Output(), + ] + ) + + @classmethod + def execute(cls, json_string, key): + try: + data = json.loads(json_string) + if isinstance(data, dict) and key in data: + value = data[key] + if value is None: + return io.NodeOutput("") + + return io.NodeOutput(str(value)) + + return io.NodeOutput("") + + except (json.JSONDecodeError, TypeError): + return io.NodeOutput("") + class StringExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -390,6 +424,7 @@ class StringExtension(ComfyExtension): RegexMatch, RegexExtract, RegexReplace, + JsonExtractString, ] async def comfy_entrypoint() -> StringExtension: diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py index f1aeb63fa..1f46d820f 100644 --- a/comfy_extras/nodes_textgen.py +++ b/comfy_extras/nodes_textgen.py @@ -35,6 +35,7 @@ class TextGenerate(io.ComfyNode): io.Int.Input("max_length", default=256, min=1, max=2048), 
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), + io.Boolean.Input("use_default_template", optional=True, default=True, tooltip="Use the built in system prompt/template if the model has one.", advanced=True), ], outputs=[ io.String.Output(display_name="generated_text"), @@ -42,9 +43,9 @@ class TextGenerate(io.ComfyNode): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: - tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1, thinking=thinking) + tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking) # Get sampling parameters from dynamic combo do_sample = sampling_mode.get("sampling_mode") == "on" @@ -160,12 +161,12 @@ class TextGenerateLTX2Prompt(TextGenerate): ) @classmethod - def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False) -> io.NodeOutput: + def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: if image is None: formatted_prompt = f"system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}\nuser\nUser Raw Input Prompt: {prompt}.\nmodel\n" else: formatted_prompt = f"system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}\nuser\n\n\n\nUser Raw Input Prompt: {prompt}.\nmodel\n" - return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking) + return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template) class TextgenExtension(ComfyExtension): diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index db4f9d231..d3ee3f1c1 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -6,6 +6,7 @@ import comfy.utils import folder_paths from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import comfy.model_management try: from spandrel_extra_arches import EXTRA_REGISTRY @@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode): tile = 512 overlap = 32 + output_device = comfy.model_management.intermediate_device() + oom = True try: while oom: try: steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) oom = False except Exception as e: model_management.raise_non_oom(e) @@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode): finally: upscale_model.to("cpu") - s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) + s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) return io.NodeOutput(s) upscale = execute # TODO: remove diff --git a/comfyui_version.py b/comfyui_version.py index 61d7672ca..2a1eb9905 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # 
This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.18.1" +__version__ = "0.19.3" diff --git a/pyproject.toml b/pyproject.toml index 1fc9402a1..8fa92ecbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.18.1" +version = "0.19.3" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index c60219a88..a8e4f9bf6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.42.8 -comfyui-workflow-templates==0.9.44 +comfyui-frontend-package==1.42.11 +comfyui-workflow-templates==0.9.57 comfyui-embedded-docs==0.4.3 torch torchsde