Merge remote-tracking branch 'upstream/master' into scail

This commit is contained in:
kijai 2026-02-26 12:04:36 +02:00
commit da10a522cc
25 changed files with 468 additions and 62 deletions

View File

@ -46,6 +46,8 @@ class NodeReplaceManager:
connections: dict[str, list[tuple[str, str, int]]] = {} connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set() need_replacement: set[str] = set()
for node_number, node_struct in prompt.items(): for node_number, node_struct in prompt.items():
if "class_type" not in node_struct or "inputs" not in node_struct:
continue
class_type = node_struct["class_type"] class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):

View File

@ -4,6 +4,25 @@ import comfy.utils
import logging import logging
def is_equal(x, y):
if torch.is_tensor(x) and torch.is_tensor(y):
return torch.equal(x, y)
elif isinstance(x, dict) and isinstance(y, dict):
if x.keys() != y.keys():
return False
return all(is_equal(x[k], y[k]) for k in x)
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
if type(x) is not type(y) or len(x) != len(y):
return False
return all(is_equal(a, b) for a, b in zip(x, y))
else:
try:
return x == y
except Exception:
logging.warning("comparison issue with COND")
return False
class CONDRegular: class CONDRegular:
def __init__(self, cond): def __init__(self, cond):
self.cond = cond self.cond = cond
@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
return self._copy_with(self.cond) return self._copy_with(self.cond)
def can_concat(self, other): def can_concat(self, other):
if self.cond != other.cond: if not is_equal(self.cond, other.cond):
return False return False
return True return True

View File

@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
def forward( def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True) run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True) run_ax = transformer_options.get("run_ax", True)
@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
del vshift_msa, vscale_msa del vshift_msa, vscale_msa
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
del norm_vx del norm_vx
# video cross-attention # video cross-attention
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel):
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks( def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
): ):
vx = x[0] vx = x[0]
ax = x[1] ax = x[1]
@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=args["v_cross_gate_timestep"], v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"], a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"], transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
) )
return out return out
@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel):
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options, "transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
}, },
{"original_block": block_wrap}, {"original_block": block_wrap},
) )
@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel):
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options, transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
) )
return [vx, ax] return [vx, ax]

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
import functools import functools
import logging
import math import math
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
@ -14,6 +15,8 @@ import comfy.ldm.common_dit
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
logger = logging.getLogger(__name__)
def _log_base(x, base): def _log_base(x, base):
return np.log(x) / np.log(base) return np.log(x) / np.log(base)
@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
attn1_input = comfy.ldm.common_dit.rms_norm(x) attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa) x.addcmul_(attn1_input, gate_msa)
del attn1_input del attn1_input
@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
"""Process input data. Must be implemented by subclasses.""" """Process input data. Must be implemented by subclasses."""
pass pass
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Base implementation returns None (no attenuation). Subclasses that
support guide-based attention control should override this.
"""
return None
@abstractmethod @abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses.""" """Process transformer blocks. Must be implemented by subclasses."""
pass pass
@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Build self-attention mask for per-guide attenuation
self_attention_mask = self._build_guide_self_attention_mask(
x, transformer_options, merged_args
)
# Process transformer blocks # Process transformer blocks
x = self._process_transformer_blocks( x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args x, context, attention_mask, timestep, pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
**merged_args,
) )
# Process output # Process output
@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
pixel_coords = pixel_coords[:, :, grid_mask, ...] pixel_coords = pixel_coords[:, :, grid_mask, ...]
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
# Compute per-guide surviving token counts from guide_attention_entries.
# Each entry tracks one guide reference; they are appended in order and
# their pre_filter_counts partition the kf_grid_mask.
guide_entries = kwargs.get("guide_attention_entries", None)
if guide_entries:
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
if total_pfc != len(kf_grid_mask):
raise ValueError(
f"guide pre_filter_counts ({total_pfc}) != "
f"keyframe grid mask length ({len(kf_grid_mask)})"
)
resolved_entries = []
offset = 0
for entry in guide_entries:
pfc = entry["pre_filter_count"]
entry_mask = kf_grid_mask[offset:offset + pfc]
surviving = int(entry_mask.sum().item())
resolved_entries.append({
**entry,
"surviving_count": surviving,
})
offset += pfc
additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
x = self.patchify_proj(x) x = self.patchify_proj(x)
return x, pixel_coords, additional_args return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
"""Build self-attention mask for per-guide attention attenuation.
Reads resolved_guide_entries from merged_args (computed in _process_input)
to build a log-space additive bias mask that attenuates noisy guide
attention for each guide reference independently.
Returns None if no attenuation is needed (all strengths == 1.0 and no
spatial masks, or no guide tokens).
"""
if isinstance(x, list):
# AV model: x = [vx, ax]; use vx for token count and device
total_tokens = x[0].shape[1]
device = x[0].device
dtype = x[0].dtype
else:
total_tokens = x.shape[1]
device = x.device
dtype = x.dtype
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
if num_guide_tokens == 0:
return None
resolved_entries = merged_args.get("resolved_guide_entries", None)
if not resolved_entries:
return None
# Check if any attenuation is actually needed
needs_attenuation = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries
)
if not needs_attenuation:
return None
# Build per-guide-token weights for all tracked guide tokens.
# Guides are appended in order at the end of the sequence.
guide_start = total_tokens - num_guide_tokens
all_weights = []
total_tracked = 0
for entry in resolved_entries:
surviving = entry["surviving_count"]
if surviving == 0:
continue
strength = entry["strength"]
pixel_mask = entry.get("pixel_mask")
latent_shape = entry.get("latent_shape")
if pixel_mask is not None and latent_shape is not None:
f_lat, h_lat, w_lat = latent_shape
per_token = self._downsample_mask_to_latent(
pixel_mask.to(device=device, dtype=dtype),
f_lat, h_lat, w_lat,
)
# per_token shape: (B, f_lat*h_lat*w_lat).
# Collapse batch dim — the mask is assumed identical across the
# batch; validate and take the first element to get (1, tokens).
if per_token.shape[0] > 1:
ref = per_token[0]
for bi in range(1, per_token.shape[0]):
if not torch.equal(ref, per_token[bi]):
logger.warning(
"pixel_mask differs across batch elements; "
"using first element only."
)
break
per_token = per_token[:1]
# `surviving` is the post-grid_mask token count.
# Clamp to surviving to handle any mismatch safely.
n_weights = min(per_token.shape[1], surviving)
weights = per_token[:, :n_weights] * strength # (1, n_weights)
else:
weights = torch.full(
(1, surviving), strength, device=device, dtype=dtype
)
all_weights.append(weights)
total_tracked += weights.shape[1]
if not all_weights:
return None
# Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
if (tracked_weights >= 1.0).all():
return None
# Build the mask: guide tokens are at the end of the sequence.
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
"""Downsample a pixel-space mask to per-token latent weights.
Args:
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
f_lat: Number of latent frames (pre-dilation original count).
h_lat: Latent height (pre-dilation original height).
w_lat: Latent width (pre-dilation original width).
Returns:
(B, F_lat * H_lat * W_lat) flattened per-token weights.
"""
b = mask.shape[0]
f_pix = mask.shape[2]
# Spatial downsampling: area interpolation per frame
spatial_down = torch.nn.functional.interpolate(
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
size=(h_lat, w_lat),
mode="area",
)
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
# Temporal downsampling: first pixel frame maps to first latent frame,
# remaining pixel frames are averaged in groups for causal temporal structure.
first_frame = spatial_down[:, :, :1, :, :]
if f_pix > 1 and f_lat > 1:
remaining_pix = f_pix - 1
remaining_lat = f_lat - 1
t = remaining_pix // remaining_lat
if t < 1:
# Fewer pixel frames than latent frames — upsample by repeating
# the available pixel frames via nearest interpolation.
rest_flat = rearrange(
spatial_down[:, :, 1:, :, :],
"b 1 f h w -> (b h w) 1 f",
)
rest_up = torch.nn.functional.interpolate(
rest_flat, size=remaining_lat, mode="nearest",
)
rest = rearrange(
rest_up, "(b h w) 1 f -> b 1 f h w",
b=b, h=h_lat, w=w_lat,
)
else:
# Trim trailing pixel frames that don't fill a complete group
usable = remaining_lat * t
rest = rearrange(
spatial_down[:, :, 1:1 + usable, :, :],
"b 1 (f t) h w -> b 1 f t h w",
t=t,
)
rest = rest.mean(dim=3)
latent_mask = torch.cat([first_frame, rest], dim=2)
elif f_lat > 1:
# Single pixel frame but multiple latent frames — repeat the
# single frame across all latent frames.
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
else:
latent_mask = first_frame
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV.""" """Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block( x = block(
@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
timestep=timestep, timestep=timestep,
pe=pe, pe=pe,
transformer_options=transformer_options, transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
) )
return x return x

View File

@ -459,6 +459,7 @@ class WanVAE(nn.Module):
attn_scales=[], attn_scales=[],
temperal_downsample=[True, True, False], temperal_downsample=[True, True, False],
image_channels=3, image_channels=3,
conv_out_channels=3,
dropout=0.0): dropout=0.0):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -474,7 +475,7 @@ class WanVAE(nn.Module):
attn_scales, self.temperal_downsample, dropout) attn_scales, self.temperal_downsample, dropout)
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
self.conv2 = CausalConv3d(z_dim, z_dim, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1)
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
attn_scales, self.temperal_upsample, dropout) attn_scales, self.temperal_upsample, dropout)
def encode(self, x): def encode(self, x):

View File

@ -76,6 +76,7 @@ class ModelType(Enum):
FLUX = 8 FLUX = 8
IMG_TO_IMG = 9 IMG_TO_IMG = 9
FLOW_COSMOS = 10 FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLOW_COSMOS: elif model_type == ModelType.FLOW_COSMOS:
c = comfy.model_sampling.COSMOS_RFLOW c = comfy.model_sampling.COSMOS_RFLOW
s = comfy.model_sampling.ModelSamplingCosmosRFlow s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -971,6 +974,10 @@ class LTXV(BaseModel):
if keyframe_idxs is not None: if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
@ -1023,6 +1030,10 @@ class LTXAV(BaseModel):
if latent_shapes is not None: if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
guide_attention_entries = kwargs.get("guide_attention_entries", None)
if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@ -1466,6 +1477,12 @@ class WAN22(WAN21):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image return latent_image
class WAN21_FlowRVS(WAN21):
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
model_config.unet_config["model_type"] = "t2v"
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
self.image_to_video = image_to_video
class WAN21_SCAIL(WAN21): class WAN21_SCAIL(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)

View File

@ -511,6 +511,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if ref_conv_weight is not None: if ref_conv_weight is not None:
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
if metadata is not None and "config" in metadata:
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
return dit_config return dit_config
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D

View File

@ -271,6 +271,7 @@ class ModelPatcher:
self.is_clip = False self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None
if not hasattr(self.model, 'model_loaded_weight_memory'): if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
@ -307,8 +308,15 @@ class ModelPatcher:
def get_free_memory(self, device): def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device) return comfy.model_management.get_free_memory(device)
def clone(self): def clone(self, disable_dynamic=False):
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) class_ = self.__class__
model = self.model
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
model = temp_model_patcher.model
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@ -362,6 +370,8 @@ class ModelPatcher:
n.is_clip = self.is_clip n.is_clip = self.is_clip
n.hook_mode = self.hook_mode n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n) callback(self, n)
return n return n

View File

@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
def calculate_input(self, sigma, noise): def calculate_input(self, sigma, noise):
return noise return noise
class IMG_TO_IMG_FLOW(CONST):
def calculate_denoised(self, sigma, model_output, model_input):
return model_output
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return latent_image
def inverse_noise_scaling(self, sigma, latent):
return 1.0 - latent
class COSMOS_RFLOW: class COSMOS_RFLOW:
def calculate_input(self, sigma, noise): def calculate_input(self, sigma, noise):
sigma = (sigma / (sigma + 1)) sigma = (sigma / (sigma + 1))

View File

@ -19,7 +19,7 @@
import torch import torch
import logging import logging
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram from comfy.cli_args import args, PerformanceFeature
import comfy.float import comfy.float
import json import json
import comfy.memory_management import comfy.memory_management
@ -296,7 +296,7 @@ class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp): class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
return return
@ -317,7 +317,7 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
@ -827,6 +827,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
else: else:
sd = {} sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None: if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias sd["{}bias".format(prefix)] = self.bias

View File

@ -694,8 +694,9 @@ class VAE:
self.latent_dim = 3 self.latent_dim = 3
self.latent_channels = 16 self.latent_channels = 16
self.output_channels = sd["encoder.conv1.weight"].shape[1] self.output_channels = sd["encoder.conv1.weight"].shape[1]
self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
self.pad_channel_value = 1.0 self.pad_channel_value = 1.0
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
@ -1530,14 +1531,24 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
return out return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory,
model_options=model_options,
te_model_options=te_model_options,
disable_dynamic=disable_dynamic)
return model
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
@ -1586,7 +1597,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_model: if output_model:
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae: if output_vae:
@ -1637,7 +1649,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
""" """
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1721,7 +1733,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
model_config.optimizations["fp8"] = True model_config.optimizations["fp8"] = True
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
if not model_management.is_device_cpu(offload_device): if not model_management.is_device_cpu(offload_device):
model.to(offload_device) model.to(offload_device)
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
@ -1730,12 +1743,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
logging.info("left over keys in diffusion model: {}".format(left_over)) logging.info("left over keys in diffusion model: {}".format(left_over))
return model_patcher return model_patcher
def load_diffusion_model(unet_path, model_options={}): def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model return model
def load_unet(unet_path, dtype=None): def load_unet(unet_path, dtype=None):

View File

@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V):
out = model_base.WAN22(self, image_to_video=True, device=device) out = model_base.WAN22(self, image_to_video=True, device=device)
return out return out
class WAN21_FlowRVS(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "flow_rvs",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
return out
class WAN21_SCAIL(WAN21_T2V): class WAN21_SCAIL(WAN21_T2V):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1677,6 +1687,6 @@ class ACEStep15(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
import torch import torch
import comfy.utils import comfy.utils
import math import math
import itertools
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None) tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106} special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module):
constant /= 2.0 constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
num_tokens = max(num_tokens, 64)
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
num_tokens = max(num_tokens, 642)
return num_tokens * constant * 1024 * 1024 return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):

View File

@ -29,7 +29,7 @@ import itertools
from torch.nn.functional import interpolate from torch.nn.functional import interpolate
from tqdm.auto import trange from tqdm.auto import trange
from einops import rearrange from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args
import json import json
import time import time
import mmap import mmap
@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
metadata = None metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
try: try:
if enables_dynamic_vram(): if comfy.memory_management.aimdo_enabled:
sd, metadata = load_safetensors(ckpt) sd, metadata = load_safetensors(ckpt)
if not return_metadata: if not return_metadata:
metadata = None metadata = None

View File

@ -27,6 +27,7 @@ class Seedream4TaskCreationRequest(BaseModel):
sequential_image_generation: str = Field("disabled") sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(False) watermark: bool = Field(False)
output_format: str | None = None
class ImageTaskCreationResponse(BaseModel): class ImageTaskCreationResponse(BaseModel):
@ -106,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2496x1664 (3:2)", 2496, 1664), ("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496), ("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296), ("3024x1296 (21:9)", 3024, 1296),
("3072x3072 (1:1)", 3072, 3072),
("4096x4096 (1:1)", 4096, 4096), ("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None), ("Custom", None, None),
] ]

View File

@ -37,6 +37,12 @@ from comfy_api_nodes.util import (
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
SEEDREAM_MODELS = {
"seedream 5.0 lite": "seedream-5-0-260128",
"seedream-4-5-251128": "seedream-4-5-251128",
"seedream-4-0-250828": "seedream-4-0-250828",
}
# Long-running tasks endpoints(e.g., video) # Long-running tasks endpoints(e.g., video)
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
@ -180,14 +186,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ByteDanceSeedreamNode", node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4.5", display_name="ByteDance Seedream 5.0",
category="api node/image/ByteDance", category="api node/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["seedream-4-5-251128", "seedream-4-0-250828"], options=list(SEEDREAM_MODELS.keys()),
tooltip="Model name",
), ),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -198,7 +203,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Image.Input( IO.Image.Input(
"image", "image",
tooltip="Input image(s) for image-to-image generation. " tooltip="Input image(s) for image-to-image generation. "
"List of 1-10 images for single or multi-reference generation.", "Reference image(s) for single or multi-reference generation.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
@ -210,8 +215,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
"width", "width",
default=2048, default=2048,
min=1024, min=1024,
max=4096, max=6240,
step=8, step=2,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True, optional=True,
), ),
@ -219,8 +224,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
"height", "height",
default=2048, default=2048,
min=1024, min=1024,
max=4096, max=4992,
step=8, step=2,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True, optional=True,
), ),
@ -283,7 +288,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["model"]), depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr=""" expr="""
( (
$price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03; $price := $contains(widgets.model, "5.0 lite") ? 0.035 :
$contains(widgets.model, "4-5") ? 0.04 : 0.03;
{ {
"type":"usd", "type":"usd",
"usd": $price, "usd": $price,
@ -309,6 +315,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
watermark: bool = False, watermark: bool = False,
fail_on_partial: bool = True, fail_on_partial: bool = True,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
model = SEEDREAM_MODELS[model]
validate_string(prompt, strip_whitespace=True, min_length=1) validate_string(prompt, strip_whitespace=True, min_length=1)
w = h = None w = h = None
for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4:
@ -318,15 +325,12 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
if w is None or h is None: if w is None or h is None:
w, h = width, height w, h = width, height
if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
out_num_pixels = w * h out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0 mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400: if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
raise ValueError( raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " f"Minimum image resolution for the selected model is 3.68MP, "
f"but {mp_provided:.2f}MP provided." f"but {mp_provided:.2f}MP provided."
) )
if "seedream-4-0" in model and out_num_pixels < 921600: if "seedream-4-0" in model and out_num_pixels < 921600:
@ -334,9 +338,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
f"Minimum image resolution that the selected model can generate is 0.92MP, " f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided." f"but {mp_provided:.2f}MP provided."
) )
max_pixels = 10_404_496 if "seedream-5-0" in model else 16_777_216
if out_num_pixels > max_pixels:
raise ValueError(
f"Maximum image resolution for the selected model is {max_pixels / 1_000_000:.2f}MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0 n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10: max_num_of_images = 14 if model == "seedream-5-0-260128" else 10
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") if n_input_images > max_num_of_images:
raise ValueError(
f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received."
)
if sequential_image_generation == "auto" and n_input_images + max_images > 15: if sequential_image_generation == "auto" and n_input_images + max_images > 15:
raise ValueError( raise ValueError(
"The maximum number of generated images plus the number of reference images cannot exceed 15." "The maximum number of generated images plus the number of reference images cannot exceed 15."
@ -364,6 +377,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
sequential_image_generation=sequential_image_generation, sequential_image_generation=sequential_image_generation,
sequential_image_generation_options=Seedream4Options(max_images=max_images), sequential_image_generation_options=Seedream4Options(max_images=max_images),
watermark=watermark, watermark=watermark,
output_format="png" if model == "seedream-5-0-260128" else None,
), ),
) )
if len(response.data) == 1: if len(response.data) == 1:

View File

@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
generate = execute # TODO: remove generate = execute # TODO: remove
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
"""Append a guide_attention_entry to both positive and negative conditioning.
Each entry tracks one guide reference for per-reference attention control.
Entries are derived independently from each conditioning to avoid cross-contamination.
"""
new_entry = {
"pre_filter_count": pre_filter_count,
"strength": strength,
"pixel_mask": None,
"latent_shape": latent_shape,
}
results = []
for cond in (positive, negative):
# Read existing entries from this specific conditioning
existing = []
for t in cond:
found = t[1].get("guide_attention_entries", None)
if found is not None:
existing = found
break
# Shallow copy and append (no deepcopy needed — entries contain
# only scalars and None for pixel_mask at this call site).
entries = [*existing, new_entry]
results.append(node_helpers.conditioning_set_values(
cond, {"guide_attention_entries": entries}
))
return results[0], results[1]
def conditioning_get_any_value(conditioning, key, default=None): def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning: for t in conditioning:
if key in t[1]: if key in t[1]:
@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
scale_factors, scale_factors,
) )
# Track this guide for per-reference attention control.
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
positive, negative = _append_guide_attention_entry(
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove generate = execute # TODO: remove
@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent_image[:, :, :-num_keyframes] latent_image = latent_image[:, :, :-num_keyframes]
noise_mask = noise_mask[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes]
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) positive = node_helpers.conditioning_set_values(positive, {
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) "keyframe_idxs": None,
"guide_attention_entries": None,
})
negative = node_helpers.conditioning_set_values(negative, {
"keyframe_idxs": None,
"guide_attention_entries": None,
})
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

View File

@ -52,7 +52,7 @@ class ModelSamplingDiscrete:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}), "zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
}} }}
@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
sampling_type = comfy.model_sampling.X0 sampling_type = comfy.model_sampling.X0
elif sampling == "img_to_img": elif sampling == "img_to_img":
sampling_type = comfy.model_sampling.IMG_TO_IMG sampling_type = comfy.model_sampling.IMG_TO_IMG
elif sampling == "img_to_img_flow":
sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW
class ModelSamplingAdvanced(sampling_base, sampling_type): class ModelSamplingAdvanced(sampling_base, sampling_type):
pass pass

View File

@ -79,7 +79,6 @@ class Blur(io.ComfyNode):
node_id="ImageBlur", node_id="ImageBlur",
display_name="Image Blur", display_name="Image Blur",
category="image/postprocessing", category="image/postprocessing",
essentials_category="Image Tools",
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1), io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
node_id="BatchImagesNode", node_id="BatchImagesNode",
display_name="Batch Images", display_name="Batch Images",
category="image", category="image",
essentials_category="Image Tools",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[ inputs=[
io.Autogrow.Input("images", template=autogrow_template) io.Autogrow.Input("images", template=autogrow_template)

View File

@ -25,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model, backend) -> io.NodeOutput: def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone() m = model.clone(disable_dynamic=True)
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict}) set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m) return io.NodeOutput(m)

View File

@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode):
search_aliases=["extract frames", "split video", "video to images", "demux"], search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components", display_name="Get Video Components",
category="image/video", category="image/video",
essentials_category="Video Tools",
description="Extracts all components from a video: frames, audio, and framerate.", description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[ inputs=[
io.Video.Input("video", tooltip="The video to extract components from."), io.Video.Input("video", tooltip="The video to extract components from."),
@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode):
"start time", "start time",
], ],
category="image/video", category="image/video",
essentials_category="Video Tools",
inputs=[ inputs=[
io.Video.Input("video"), io.Video.Input("video"),
io.Float.Input( io.Float.Input(

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.14.1" __version__ = "0.15.0"

View File

@ -1925,7 +1925,6 @@ class ImageInvert:
class ImageBatch: class ImageBatch:
SEARCH_ALIASES = ["combine images", "merge images", "stack images"] SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
ESSENTIALS_CATEGORY = "Image Tools"
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.14.1" version = "0.15.0"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.39.16 comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.2 comfyui-workflow-templates==0.9.3
comfyui-embedded-docs==0.4.1 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
torchvision torchvision
@ -22,7 +22,7 @@ alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.0 comfy-aimdo>=0.2.2
requests requests
#non essential dependencies: #non essential dependencies: