mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-13 01:30:32 +08:00
Merge remote-tracking branch 'upstream/master' into scail
This commit is contained in:
commit
da10a522cc
@ -46,6 +46,8 @@ class NodeReplaceManager:
|
|||||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||||
need_replacement: set[str] = set()
|
need_replacement: set[str] = set()
|
||||||
for node_number, node_struct in prompt.items():
|
for node_number, node_struct in prompt.items():
|
||||||
|
if "class_type" not in node_struct or "inputs" not in node_struct:
|
||||||
|
continue
|
||||||
class_type = node_struct["class_type"]
|
class_type = node_struct["class_type"]
|
||||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||||
|
|||||||
@ -4,6 +4,25 @@ import comfy.utils
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
def is_equal(x, y):
|
||||||
|
if torch.is_tensor(x) and torch.is_tensor(y):
|
||||||
|
return torch.equal(x, y)
|
||||||
|
elif isinstance(x, dict) and isinstance(y, dict):
|
||||||
|
if x.keys() != y.keys():
|
||||||
|
return False
|
||||||
|
return all(is_equal(x[k], y[k]) for k in x)
|
||||||
|
elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)):
|
||||||
|
if type(x) is not type(y) or len(x) != len(y):
|
||||||
|
return False
|
||||||
|
return all(is_equal(a, b) for a, b in zip(x, y))
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return x == y
|
||||||
|
except Exception:
|
||||||
|
logging.warning("comparison issue with COND")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class CONDRegular:
|
class CONDRegular:
|
||||||
def __init__(self, cond):
|
def __init__(self, cond):
|
||||||
self.cond = cond
|
self.cond = cond
|
||||||
@ -84,7 +103,7 @@ class CONDConstant(CONDRegular):
|
|||||||
return self._copy_with(self.cond)
|
return self._copy_with(self.cond)
|
||||||
|
|
||||||
def can_concat(self, other):
|
def can_concat(self, other):
|
||||||
if self.cond != other.cond:
|
if not is_equal(self.cond, other.cond):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -218,7 +218,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None,
|
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
run_vx = transformer_options.get("run_vx", True)
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
run_ax = transformer_options.get("run_ax", True)
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
@ -234,7 +234,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2)))
|
||||||
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
|
||||||
del vshift_msa, vscale_msa
|
del vshift_msa, vscale_msa
|
||||||
attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options)
|
attn1_out = self.attn1(norm_vx, pe=v_pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||||
del norm_vx
|
del norm_vx
|
||||||
# video cross-attention
|
# video cross-attention
|
||||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||||
@ -726,7 +726,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
|
||||||
|
|
||||||
def _process_transformer_blocks(
|
def _process_transformer_blocks(
|
||||||
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
|
self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs
|
||||||
):
|
):
|
||||||
vx = x[0]
|
vx = x[0]
|
||||||
ax = x[1]
|
ax = x[1]
|
||||||
@ -770,6 +770,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
v_cross_gate_timestep=args["v_cross_gate_timestep"],
|
||||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
transformer_options=args["transformer_options"],
|
transformer_options=args["transformer_options"],
|
||||||
|
self_attention_mask=args.get("self_attention_mask"),
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -790,6 +791,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
|
||||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
"transformer_options": transformer_options,
|
"transformer_options": transformer_options,
|
||||||
|
"self_attention_mask": self_attention_mask,
|
||||||
},
|
},
|
||||||
{"original_block": block_wrap},
|
{"original_block": block_wrap},
|
||||||
)
|
)
|
||||||
@ -811,6 +813,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
|
||||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
|
self_attention_mask=self_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import functools
|
import functools
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
@ -14,6 +15,8 @@ import comfy.ldm.common_dit
|
|||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def _log_base(x, base):
|
def _log_base(x, base):
|
||||||
return np.log(x) / np.log(base)
|
return np.log(x) / np.log(base)
|
||||||
|
|
||||||
@ -415,12 +418,12 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||||
|
|
||||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||||
x.addcmul_(attn1_input, gate_msa)
|
x.addcmul_(attn1_input, gate_msa)
|
||||||
del attn1_input
|
del attn1_input
|
||||||
|
|
||||||
@ -638,8 +641,16 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
"""Process input data. Must be implemented by subclasses."""
|
"""Process input data. Must be implemented by subclasses."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||||
|
"""Build self-attention mask for per-guide attention attenuation.
|
||||||
|
|
||||||
|
Base implementation returns None (no attenuation). Subclasses that
|
||||||
|
support guide-based attention control should override this.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
|
||||||
"""Process transformer blocks. Must be implemented by subclasses."""
|
"""Process transformer blocks. Must be implemented by subclasses."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -788,9 +799,17 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
|
||||||
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
|
||||||
|
|
||||||
|
# Build self-attention mask for per-guide attenuation
|
||||||
|
self_attention_mask = self._build_guide_self_attention_mask(
|
||||||
|
x, transformer_options, merged_args
|
||||||
|
)
|
||||||
|
|
||||||
# Process transformer blocks
|
# Process transformer blocks
|
||||||
x = self._process_transformer_blocks(
|
x = self._process_transformer_blocks(
|
||||||
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
|
x, context, attention_mask, timestep, pe,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
self_attention_mask=self_attention_mask,
|
||||||
|
**merged_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Process output
|
# Process output
|
||||||
@ -890,13 +909,243 @@ class LTXVModel(LTXBaseModel):
|
|||||||
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
pixel_coords = pixel_coords[:, :, grid_mask, ...]
|
||||||
|
|
||||||
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
|
||||||
|
|
||||||
|
# Compute per-guide surviving token counts from guide_attention_entries.
|
||||||
|
# Each entry tracks one guide reference; they are appended in order and
|
||||||
|
# their pre_filter_counts partition the kf_grid_mask.
|
||||||
|
guide_entries = kwargs.get("guide_attention_entries", None)
|
||||||
|
if guide_entries:
|
||||||
|
total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
|
||||||
|
if total_pfc != len(kf_grid_mask):
|
||||||
|
raise ValueError(
|
||||||
|
f"guide pre_filter_counts ({total_pfc}) != "
|
||||||
|
f"keyframe grid mask length ({len(kf_grid_mask)})"
|
||||||
|
)
|
||||||
|
resolved_entries = []
|
||||||
|
offset = 0
|
||||||
|
for entry in guide_entries:
|
||||||
|
pfc = entry["pre_filter_count"]
|
||||||
|
entry_mask = kf_grid_mask[offset:offset + pfc]
|
||||||
|
surviving = int(entry_mask.sum().item())
|
||||||
|
resolved_entries.append({
|
||||||
|
**entry,
|
||||||
|
"surviving_count": surviving,
|
||||||
|
})
|
||||||
|
offset += pfc
|
||||||
|
additional_args["resolved_guide_entries"] = resolved_entries
|
||||||
|
|
||||||
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
|
||||||
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
|
||||||
|
|
||||||
|
# Total surviving guide tokens (all guides)
|
||||||
|
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
|
||||||
|
|
||||||
x = self.patchify_proj(x)
|
x = self.patchify_proj(x)
|
||||||
return x, pixel_coords, additional_args
|
return x, pixel_coords, additional_args
|
||||||
|
|
||||||
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
|
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
||||||
|
"""Build self-attention mask for per-guide attention attenuation.
|
||||||
|
|
||||||
|
Reads resolved_guide_entries from merged_args (computed in _process_input)
|
||||||
|
to build a log-space additive bias mask that attenuates noisy ↔ guide
|
||||||
|
attention for each guide reference independently.
|
||||||
|
|
||||||
|
Returns None if no attenuation is needed (all strengths == 1.0 and no
|
||||||
|
spatial masks, or no guide tokens).
|
||||||
|
"""
|
||||||
|
if isinstance(x, list):
|
||||||
|
# AV model: x = [vx, ax]; use vx for token count and device
|
||||||
|
total_tokens = x[0].shape[1]
|
||||||
|
device = x[0].device
|
||||||
|
dtype = x[0].dtype
|
||||||
|
else:
|
||||||
|
total_tokens = x.shape[1]
|
||||||
|
device = x.device
|
||||||
|
dtype = x.dtype
|
||||||
|
|
||||||
|
num_guide_tokens = merged_args.get("num_guide_tokens", 0)
|
||||||
|
if num_guide_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
resolved_entries = merged_args.get("resolved_guide_entries", None)
|
||||||
|
if not resolved_entries:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check if any attenuation is actually needed
|
||||||
|
needs_attenuation = any(
|
||||||
|
e["strength"] < 1.0 or e.get("pixel_mask") is not None
|
||||||
|
for e in resolved_entries
|
||||||
|
)
|
||||||
|
if not needs_attenuation:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build per-guide-token weights for all tracked guide tokens.
|
||||||
|
# Guides are appended in order at the end of the sequence.
|
||||||
|
guide_start = total_tokens - num_guide_tokens
|
||||||
|
all_weights = []
|
||||||
|
total_tracked = 0
|
||||||
|
|
||||||
|
for entry in resolved_entries:
|
||||||
|
surviving = entry["surviving_count"]
|
||||||
|
if surviving == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
strength = entry["strength"]
|
||||||
|
pixel_mask = entry.get("pixel_mask")
|
||||||
|
latent_shape = entry.get("latent_shape")
|
||||||
|
|
||||||
|
if pixel_mask is not None and latent_shape is not None:
|
||||||
|
f_lat, h_lat, w_lat = latent_shape
|
||||||
|
per_token = self._downsample_mask_to_latent(
|
||||||
|
pixel_mask.to(device=device, dtype=dtype),
|
||||||
|
f_lat, h_lat, w_lat,
|
||||||
|
)
|
||||||
|
# per_token shape: (B, f_lat*h_lat*w_lat).
|
||||||
|
# Collapse batch dim — the mask is assumed identical across the
|
||||||
|
# batch; validate and take the first element to get (1, tokens).
|
||||||
|
if per_token.shape[0] > 1:
|
||||||
|
ref = per_token[0]
|
||||||
|
for bi in range(1, per_token.shape[0]):
|
||||||
|
if not torch.equal(ref, per_token[bi]):
|
||||||
|
logger.warning(
|
||||||
|
"pixel_mask differs across batch elements; "
|
||||||
|
"using first element only."
|
||||||
|
)
|
||||||
|
break
|
||||||
|
per_token = per_token[:1]
|
||||||
|
# `surviving` is the post-grid_mask token count.
|
||||||
|
# Clamp to surviving to handle any mismatch safely.
|
||||||
|
n_weights = min(per_token.shape[1], surviving)
|
||||||
|
weights = per_token[:, :n_weights] * strength # (1, n_weights)
|
||||||
|
else:
|
||||||
|
weights = torch.full(
|
||||||
|
(1, surviving), strength, device=device, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
all_weights.append(weights)
|
||||||
|
total_tracked += weights.shape[1]
|
||||||
|
|
||||||
|
if not all_weights:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Concatenate per-token weights for all tracked guides
|
||||||
|
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
|
||||||
|
|
||||||
|
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
|
||||||
|
if (tracked_weights >= 1.0).all():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build the mask: guide tokens are at the end of the sequence.
|
||||||
|
# Tracked guides come first (in order), untracked follow.
|
||||||
|
return self._build_self_attention_mask(
|
||||||
|
total_tokens, num_guide_tokens, total_tracked,
|
||||||
|
tracked_weights, guide_start, device, dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
|
||||||
|
"""Downsample a pixel-space mask to per-token latent weights.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
|
||||||
|
f_lat: Number of latent frames (pre-dilation original count).
|
||||||
|
h_lat: Latent height (pre-dilation original height).
|
||||||
|
w_lat: Latent width (pre-dilation original width).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(B, F_lat * H_lat * W_lat) flattened per-token weights.
|
||||||
|
"""
|
||||||
|
b = mask.shape[0]
|
||||||
|
f_pix = mask.shape[2]
|
||||||
|
|
||||||
|
# Spatial downsampling: area interpolation per frame
|
||||||
|
spatial_down = torch.nn.functional.interpolate(
|
||||||
|
rearrange(mask, "b 1 f h w -> (b f) 1 h w"),
|
||||||
|
size=(h_lat, w_lat),
|
||||||
|
mode="area",
|
||||||
|
)
|
||||||
|
spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b)
|
||||||
|
|
||||||
|
# Temporal downsampling: first pixel frame maps to first latent frame,
|
||||||
|
# remaining pixel frames are averaged in groups for causal temporal structure.
|
||||||
|
first_frame = spatial_down[:, :, :1, :, :]
|
||||||
|
if f_pix > 1 and f_lat > 1:
|
||||||
|
remaining_pix = f_pix - 1
|
||||||
|
remaining_lat = f_lat - 1
|
||||||
|
t = remaining_pix // remaining_lat
|
||||||
|
if t < 1:
|
||||||
|
# Fewer pixel frames than latent frames — upsample by repeating
|
||||||
|
# the available pixel frames via nearest interpolation.
|
||||||
|
rest_flat = rearrange(
|
||||||
|
spatial_down[:, :, 1:, :, :],
|
||||||
|
"b 1 f h w -> (b h w) 1 f",
|
||||||
|
)
|
||||||
|
rest_up = torch.nn.functional.interpolate(
|
||||||
|
rest_flat, size=remaining_lat, mode="nearest",
|
||||||
|
)
|
||||||
|
rest = rearrange(
|
||||||
|
rest_up, "(b h w) 1 f -> b 1 f h w",
|
||||||
|
b=b, h=h_lat, w=w_lat,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Trim trailing pixel frames that don't fill a complete group
|
||||||
|
usable = remaining_lat * t
|
||||||
|
rest = rearrange(
|
||||||
|
spatial_down[:, :, 1:1 + usable, :, :],
|
||||||
|
"b 1 (f t) h w -> b 1 f t h w",
|
||||||
|
t=t,
|
||||||
|
)
|
||||||
|
rest = rest.mean(dim=3)
|
||||||
|
latent_mask = torch.cat([first_frame, rest], dim=2)
|
||||||
|
elif f_lat > 1:
|
||||||
|
# Single pixel frame but multiple latent frames — repeat the
|
||||||
|
# single frame across all latent frames.
|
||||||
|
latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1)
|
||||||
|
else:
|
||||||
|
latent_mask = first_frame
|
||||||
|
|
||||||
|
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
||||||
|
tracked_weights, guide_start, device, dtype):
|
||||||
|
"""Build a log-space additive self-attention bias mask.
|
||||||
|
|
||||||
|
Attenuates attention between noisy tokens and tracked guide tokens.
|
||||||
|
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_tokens: Total sequence length.
|
||||||
|
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
||||||
|
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
||||||
|
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
||||||
|
guide_start: Index where guide tokens begin in the sequence.
|
||||||
|
device: Target device.
|
||||||
|
dtype: Target dtype.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(1, 1, total_tokens, total_tokens) additive bias mask.
|
||||||
|
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
||||||
|
"""
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
||||||
|
tracked_end = guide_start + tracked_count
|
||||||
|
|
||||||
|
# Convert weights to log-space bias
|
||||||
|
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
||||||
|
log_w = torch.full_like(w, finfo.min)
|
||||||
|
positive_mask = w > 0
|
||||||
|
if positive_mask.any():
|
||||||
|
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
||||||
|
|
||||||
|
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
||||||
|
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
||||||
|
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
||||||
|
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
|
||||||
"""Process transformer blocks for LTXV."""
|
"""Process transformer blocks for LTXV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@ -906,10 +1155,10 @@ class LTXVModel(LTXBaseModel):
|
|||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(
|
x = block(
|
||||||
@ -919,6 +1168,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
pe=pe,
|
pe=pe,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
|
self_attention_mask=self_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -459,6 +459,7 @@ class WanVAE(nn.Module):
|
|||||||
attn_scales=[],
|
attn_scales=[],
|
||||||
temperal_downsample=[True, True, False],
|
temperal_downsample=[True, True, False],
|
||||||
image_channels=3,
|
image_channels=3,
|
||||||
|
conv_out_channels=3,
|
||||||
dropout=0.0):
|
dropout=0.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
@ -474,7 +475,7 @@ class WanVAE(nn.Module):
|
|||||||
attn_scales, self.temperal_downsample, dropout)
|
attn_scales, self.temperal_downsample, dropout)
|
||||||
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
|
||||||
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
self.conv2 = CausalConv3d(z_dim, z_dim, 1)
|
||||||
self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks,
|
self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks,
|
||||||
attn_scales, self.temperal_upsample, dropout)
|
attn_scales, self.temperal_upsample, dropout)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class ModelType(Enum):
|
|||||||
FLUX = 8
|
FLUX = 8
|
||||||
IMG_TO_IMG = 9
|
IMG_TO_IMG = 9
|
||||||
FLOW_COSMOS = 10
|
FLOW_COSMOS = 10
|
||||||
|
IMG_TO_IMG_FLOW = 11
|
||||||
|
|
||||||
|
|
||||||
def model_sampling(model_config, model_type):
|
def model_sampling(model_config, model_type):
|
||||||
@ -108,6 +109,8 @@ def model_sampling(model_config, model_type):
|
|||||||
elif model_type == ModelType.FLOW_COSMOS:
|
elif model_type == ModelType.FLOW_COSMOS:
|
||||||
c = comfy.model_sampling.COSMOS_RFLOW
|
c = comfy.model_sampling.COSMOS_RFLOW
|
||||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||||
|
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||||
|
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||||
|
|
||||||
class ModelSampling(s, c):
|
class ModelSampling(s, c):
|
||||||
pass
|
pass
|
||||||
@ -971,6 +974,10 @@ class LTXV(BaseModel):
|
|||||||
if keyframe_idxs is not None:
|
if keyframe_idxs is not None:
|
||||||
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
|
||||||
|
|
||||||
|
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||||
|
if guide_attention_entries is not None:
|
||||||
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
|
||||||
@ -1023,6 +1030,10 @@ class LTXAV(BaseModel):
|
|||||||
if latent_shapes is not None:
|
if latent_shapes is not None:
|
||||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||||
|
|
||||||
|
guide_attention_entries = kwargs.get("guide_attention_entries", None)
|
||||||
|
if guide_attention_entries is not None:
|
||||||
|
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
|
||||||
@ -1466,6 +1477,12 @@ class WAN22(WAN21):
|
|||||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
class WAN21_FlowRVS(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||||
|
model_config.unet_config["model_type"] = "t2v"
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
self.image_to_video = image_to_video
|
||||||
|
|
||||||
class WAN21_SCAIL(WAN21):
|
class WAN21_SCAIL(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel)
|
||||||
|
|||||||
@ -511,6 +511,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
if ref_conv_weight is not None:
|
if ref_conv_weight is not None:
|
||||||
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1]
|
||||||
|
|
||||||
|
if metadata is not None and "config" in metadata:
|
||||||
|
dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
|
||||||
|
|||||||
@ -271,6 +271,7 @@ class ModelPatcher:
|
|||||||
self.is_clip = False
|
self.is_clip = False
|
||||||
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
|
||||||
|
|
||||||
|
self.cached_patcher_init: tuple[Callable, tuple] | None = None
|
||||||
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
self.model.model_loaded_weight_memory = 0
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
@ -307,8 +308,15 @@ class ModelPatcher:
|
|||||||
def get_free_memory(self, device):
|
def get_free_memory(self, device):
|
||||||
return comfy.model_management.get_free_memory(device)
|
return comfy.model_management.get_free_memory(device)
|
||||||
|
|
||||||
def clone(self):
|
def clone(self, disable_dynamic=False):
|
||||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
class_ = self.__class__
|
||||||
|
model = self.model
|
||||||
|
if self.is_dynamic() and disable_dynamic:
|
||||||
|
class_ = ModelPatcher
|
||||||
|
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
|
||||||
|
model = temp_model_patcher.model
|
||||||
|
|
||||||
|
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
for k in self.patches:
|
for k in self.patches:
|
||||||
n.patches[k] = self.patches[k][:]
|
n.patches[k] = self.patches[k][:]
|
||||||
@ -362,6 +370,8 @@ class ModelPatcher:
|
|||||||
n.is_clip = self.is_clip
|
n.is_clip = self.is_clip
|
||||||
n.hook_mode = self.hook_mode
|
n.hook_mode = self.hook_mode
|
||||||
|
|
||||||
|
n.cached_patcher_init = self.cached_patcher_init
|
||||||
|
|
||||||
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
|
||||||
callback(self, n)
|
callback(self, n)
|
||||||
return n
|
return n
|
||||||
|
|||||||
@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
|
|||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
|
class IMG_TO_IMG_FLOW(CONST):
|
||||||
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
|
return latent_image
|
||||||
|
|
||||||
|
def inverse_noise_scaling(self, sigma, latent):
|
||||||
|
return 1.0 - latent
|
||||||
|
|
||||||
class COSMOS_RFLOW:
|
class COSMOS_RFLOW:
|
||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
sigma = (sigma / (sigma + 1))
|
sigma = (sigma / (sigma + 1))
|
||||||
|
|||||||
10
comfy/ops.py
10
comfy/ops.py
@ -19,7 +19,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import json
|
import json
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
@ -296,7 +296,7 @@ class disable_weight_init:
|
|||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
super().__init__(in_features, out_features, bias, device, dtype)
|
super().__init__(in_features, out_features, bias, device, dtype)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -317,7 +317,7 @@ class disable_weight_init:
|
|||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
if not comfy.model_management.WINDOWS or not enables_dynamic_vram():
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
@ -827,6 +827,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
else:
|
else:
|
||||||
sd = {}
|
sd = {}
|
||||||
|
|
||||||
|
if not hasattr(self, 'weight'):
|
||||||
|
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||||
|
return sd
|
||||||
|
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
sd["{}bias".format(prefix)] = self.bias
|
sd["{}bias".format(prefix)] = self.bias
|
||||||
|
|
||||||
|
|||||||
32
comfy/sd.py
32
comfy/sd.py
@ -694,8 +694,9 @@ class VAE:
|
|||||||
self.latent_dim = 3
|
self.latent_dim = 3
|
||||||
self.latent_channels = 16
|
self.latent_channels = 16
|
||||||
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
self.output_channels = sd["encoder.conv1.weight"].shape[1]
|
||||||
|
self.conv_out_channels = sd["decoder.head.2.weight"].shape[0]
|
||||||
self.pad_channel_value = 1.0
|
self.pad_channel_value = 1.0
|
||||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0}
|
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0}
|
||||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||||
@ -1530,14 +1531,24 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
return (model, clip, vae)
|
return (model, clip, vae)
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if out is None:
|
if out is None:
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
|
||||||
|
if output_model:
|
||||||
|
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
|
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
|
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
model_options=model_options,
|
||||||
|
te_model_options=te_model_options,
|
||||||
|
disable_dynamic=disable_dynamic)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
|
||||||
clip = None
|
clip = None
|
||||||
clipvision = None
|
clipvision = None
|
||||||
vae = None
|
vae = None
|
||||||
@ -1586,7 +1597,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
if output_model:
|
if output_model:
|
||||||
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
|
||||||
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
|
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
|
||||||
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
|
||||||
|
|
||||||
if output_vae:
|
if output_vae:
|
||||||
@ -1637,7 +1649,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
|
||||||
"""
|
"""
|
||||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||||
|
|
||||||
@ -1721,7 +1733,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||||||
model_config.optimizations["fp8"] = True
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
model = model_config.get_model(new_sd, "")
|
model = model_config.get_model(new_sd, "")
|
||||||
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
|
||||||
|
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
|
||||||
if not model_management.is_device_cpu(offload_device):
|
if not model_management.is_device_cpu(offload_device):
|
||||||
model.to(offload_device)
|
model.to(offload_device)
|
||||||
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
|
||||||
@ -1730,12 +1743,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
|||||||
logging.info("left over keys in diffusion model: {}".format(left_over))
|
logging.info("left over keys in diffusion model: {}".format(left_over))
|
||||||
return model_patcher
|
return model_patcher
|
||||||
|
|
||||||
def load_diffusion_model(unet_path, model_options={}):
|
def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
|
||||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||||
|
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_unet(unet_path, dtype=None):
|
def load_unet(unet_path, dtype=None):
|
||||||
|
|||||||
@ -1256,6 +1256,16 @@ class WAN22_T2V(WAN21_T2V):
|
|||||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class WAN21_FlowRVS(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "flow_rvs",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
class WAN21_SCAIL(WAN21_T2V):
|
class WAN21_SCAIL(WAN21_T2V):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1677,6 +1687,6 @@ class ACEStep15(supported_models_base.BASE):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import math
|
import math
|
||||||
|
import itertools
|
||||||
|
|
||||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
@ -72,7 +73,7 @@ class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
|||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
|
||||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
@ -199,8 +200,10 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
constant /= 2.0
|
constant /= 2.0
|
||||||
|
|
||||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
|
||||||
num_tokens = max(num_tokens, 64)
|
|
||||||
|
num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
|
||||||
|
num_tokens = max(num_tokens, 642)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
|||||||
@ -29,7 +29,7 @@ import itertools
|
|||||||
from torch.nn.functional import interpolate
|
from torch.nn.functional import interpolate
|
||||||
from tqdm.auto import trange
|
from tqdm.auto import trange
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import mmap
|
import mmap
|
||||||
@ -113,7 +113,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
|||||||
metadata = None
|
metadata = None
|
||||||
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
|
||||||
try:
|
try:
|
||||||
if enables_dynamic_vram():
|
if comfy.memory_management.aimdo_enabled:
|
||||||
sd, metadata = load_safetensors(ckpt)
|
sd, metadata = load_safetensors(ckpt)
|
||||||
if not return_metadata:
|
if not return_metadata:
|
||||||
metadata = None
|
metadata = None
|
||||||
|
|||||||
@ -27,6 +27,7 @@ class Seedream4TaskCreationRequest(BaseModel):
|
|||||||
sequential_image_generation: str = Field("disabled")
|
sequential_image_generation: str = Field("disabled")
|
||||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||||
watermark: bool = Field(False)
|
watermark: bool = Field(False)
|
||||||
|
output_format: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskCreationResponse(BaseModel):
|
class ImageTaskCreationResponse(BaseModel):
|
||||||
@ -106,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
|||||||
("2496x1664 (3:2)", 2496, 1664),
|
("2496x1664 (3:2)", 2496, 1664),
|
||||||
("1664x2496 (2:3)", 1664, 2496),
|
("1664x2496 (2:3)", 1664, 2496),
|
||||||
("3024x1296 (21:9)", 3024, 1296),
|
("3024x1296 (21:9)", 3024, 1296),
|
||||||
|
("3072x3072 (1:1)", 3072, 3072),
|
||||||
("4096x4096 (1:1)", 4096, 4096),
|
("4096x4096 (1:1)", 4096, 4096),
|
||||||
("Custom", None, None),
|
("Custom", None, None),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -37,6 +37,12 @@ from comfy_api_nodes.util import (
|
|||||||
|
|
||||||
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
|
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
|
||||||
|
|
||||||
|
SEEDREAM_MODELS = {
|
||||||
|
"seedream 5.0 lite": "seedream-5-0-260128",
|
||||||
|
"seedream-4-5-251128": "seedream-4-5-251128",
|
||||||
|
"seedream-4-0-250828": "seedream-4-0-250828",
|
||||||
|
}
|
||||||
|
|
||||||
# Long-running tasks endpoints(e.g., video)
|
# Long-running tasks endpoints(e.g., video)
|
||||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||||
@ -180,14 +186,13 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ByteDanceSeedreamNode",
|
node_id="ByteDanceSeedreamNode",
|
||||||
display_name="ByteDance Seedream 4.5",
|
display_name="ByteDance Seedream 5.0",
|
||||||
category="api node/image/ByteDance",
|
category="api node/image/ByteDance",
|
||||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
options=list(SEEDREAM_MODELS.keys()),
|
||||||
tooltip="Model name",
|
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -198,7 +203,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Input image(s) for image-to-image generation. "
|
tooltip="Input image(s) for image-to-image generation. "
|
||||||
"List of 1-10 images for single or multi-reference generation.",
|
"Reference image(s) for single or multi-reference generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -210,8 +215,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
"width",
|
"width",
|
||||||
default=2048,
|
default=2048,
|
||||||
min=1024,
|
min=1024,
|
||||||
max=4096,
|
max=6240,
|
||||||
step=8,
|
step=2,
|
||||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -219,8 +224,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
"height",
|
"height",
|
||||||
default=2048,
|
default=2048,
|
||||||
min=1024,
|
min=1024,
|
||||||
max=4096,
|
max=4992,
|
||||||
step=8,
|
step=2,
|
||||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
@ -283,7 +288,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03;
|
$price := $contains(widgets.model, "5.0 lite") ? 0.035 :
|
||||||
|
$contains(widgets.model, "4-5") ? 0.04 : 0.03;
|
||||||
{
|
{
|
||||||
"type":"usd",
|
"type":"usd",
|
||||||
"usd": $price,
|
"usd": $price,
|
||||||
@ -309,6 +315,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
watermark: bool = False,
|
watermark: bool = False,
|
||||||
fail_on_partial: bool = True,
|
fail_on_partial: bool = True,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
model = SEEDREAM_MODELS[model]
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
w = h = None
|
w = h = None
|
||||||
for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4:
|
for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4:
|
||||||
@ -318,15 +325,12 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
|
|
||||||
if w is None or h is None:
|
if w is None or h is None:
|
||||||
w, h = width, height
|
w, h = width, height
|
||||||
if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
|
|
||||||
raise ValueError(
|
|
||||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
|
||||||
)
|
|
||||||
out_num_pixels = w * h
|
out_num_pixels = w * h
|
||||||
mp_provided = out_num_pixels / 1_000_000.0
|
mp_provided = out_num_pixels / 1_000_000.0
|
||||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
f"Minimum image resolution for the selected model is 3.68MP, "
|
||||||
f"but {mp_provided:.2f}MP provided."
|
f"but {mp_provided:.2f}MP provided."
|
||||||
)
|
)
|
||||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||||
@ -334,9 +338,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||||
f"but {mp_provided:.2f}MP provided."
|
f"but {mp_provided:.2f}MP provided."
|
||||||
)
|
)
|
||||||
|
max_pixels = 10_404_496 if "seedream-5-0" in model else 16_777_216
|
||||||
|
if out_num_pixels > max_pixels:
|
||||||
|
raise ValueError(
|
||||||
|
f"Maximum image resolution for the selected model is {max_pixels / 1_000_000:.2f}MP, "
|
||||||
|
f"but {mp_provided:.2f}MP provided."
|
||||||
|
)
|
||||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||||
if n_input_images > 10:
|
max_num_of_images = 14 if model == "seedream-5-0-260128" else 10
|
||||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
if n_input_images > max_num_of_images:
|
||||||
|
raise ValueError(
|
||||||
|
f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received."
|
||||||
|
)
|
||||||
if sequential_image_generation == "auto" and n_input_images + max_images > 15:
|
if sequential_image_generation == "auto" and n_input_images + max_images > 15:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The maximum number of generated images plus the number of reference images cannot exceed 15."
|
"The maximum number of generated images plus the number of reference images cannot exceed 15."
|
||||||
@ -364,6 +377,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
|||||||
sequential_image_generation=sequential_image_generation,
|
sequential_image_generation=sequential_image_generation,
|
||||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
output_format="png" if model == "seedream-5-0-260128" else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if len(response.data) == 1:
|
if len(response.data) == 1:
|
||||||
|
|||||||
@ -134,6 +134,36 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
|||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
|
||||||
|
"""Append a guide_attention_entry to both positive and negative conditioning.
|
||||||
|
|
||||||
|
Each entry tracks one guide reference for per-reference attention control.
|
||||||
|
Entries are derived independently from each conditioning to avoid cross-contamination.
|
||||||
|
"""
|
||||||
|
new_entry = {
|
||||||
|
"pre_filter_count": pre_filter_count,
|
||||||
|
"strength": strength,
|
||||||
|
"pixel_mask": None,
|
||||||
|
"latent_shape": latent_shape,
|
||||||
|
}
|
||||||
|
results = []
|
||||||
|
for cond in (positive, negative):
|
||||||
|
# Read existing entries from this specific conditioning
|
||||||
|
existing = []
|
||||||
|
for t in cond:
|
||||||
|
found = t[1].get("guide_attention_entries", None)
|
||||||
|
if found is not None:
|
||||||
|
existing = found
|
||||||
|
break
|
||||||
|
# Shallow copy and append (no deepcopy needed — entries contain
|
||||||
|
# only scalars and None for pixel_mask at this call site).
|
||||||
|
entries = [*existing, new_entry]
|
||||||
|
results.append(node_helpers.conditioning_set_values(
|
||||||
|
cond, {"guide_attention_entries": entries}
|
||||||
|
))
|
||||||
|
return results[0], results[1]
|
||||||
|
|
||||||
|
|
||||||
def conditioning_get_any_value(conditioning, key, default=None):
|
def conditioning_get_any_value(conditioning, key, default=None):
|
||||||
for t in conditioning:
|
for t in conditioning:
|
||||||
if key in t[1]:
|
if key in t[1]:
|
||||||
@ -324,6 +354,13 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
scale_factors,
|
scale_factors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Track this guide for per-reference attention control.
|
||||||
|
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
|
||||||
|
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
|
||||||
|
positive, negative = _append_guide_attention_entry(
|
||||||
|
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
|
||||||
|
)
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||||
|
|
||||||
generate = execute # TODO: remove
|
generate = execute # TODO: remove
|
||||||
@ -359,8 +396,14 @@ class LTXVCropGuides(io.ComfyNode):
|
|||||||
latent_image = latent_image[:, :, :-num_keyframes]
|
latent_image = latent_image[:, :, :-num_keyframes]
|
||||||
noise_mask = noise_mask[:, :, :-num_keyframes]
|
noise_mask = noise_mask[:, :, :-num_keyframes]
|
||||||
|
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None})
|
positive = node_helpers.conditioning_set_values(positive, {
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None})
|
"keyframe_idxs": None,
|
||||||
|
"guide_attention_entries": None,
|
||||||
|
})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {
|
||||||
|
"keyframe_idxs": None,
|
||||||
|
"guide_attention_entries": None,
|
||||||
|
})
|
||||||
|
|
||||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
|
||||||
|
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class ModelSamplingDiscrete:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "model": ("MODEL",),
|
return {"required": { "model": ("MODEL",),
|
||||||
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],),
|
"sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],),
|
||||||
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
|
"zsnr": ("BOOLEAN", {"default": False, "advanced": True}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
@ -76,6 +76,8 @@ class ModelSamplingDiscrete:
|
|||||||
sampling_type = comfy.model_sampling.X0
|
sampling_type = comfy.model_sampling.X0
|
||||||
elif sampling == "img_to_img":
|
elif sampling == "img_to_img":
|
||||||
sampling_type = comfy.model_sampling.IMG_TO_IMG
|
sampling_type = comfy.model_sampling.IMG_TO_IMG
|
||||||
|
elif sampling == "img_to_img_flow":
|
||||||
|
sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||||
|
|
||||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -79,7 +79,6 @@ class Blur(io.ComfyNode):
|
|||||||
node_id="ImageBlur",
|
node_id="ImageBlur",
|
||||||
display_name="Image Blur",
|
display_name="Image Blur",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
essentials_category="Image Tools",
|
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
|
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
|
||||||
@ -568,6 +567,7 @@ class BatchImagesNode(io.ComfyNode):
|
|||||||
node_id="BatchImagesNode",
|
node_id="BatchImagesNode",
|
||||||
display_name="Batch Images",
|
display_name="Batch Images",
|
||||||
category="image",
|
category="image",
|
||||||
|
essentials_category="Image Tools",
|
||||||
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Autogrow.Input("images", template=autogrow_template)
|
io.Autogrow.Input("images", template=autogrow_template)
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, backend) -> io.NodeOutput:
|
def execute(cls, model, backend) -> io.NodeOutput:
|
||||||
m = model.clone()
|
m = model.clone(disable_dynamic=True)
|
||||||
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
|
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
|
||||||
return io.NodeOutput(m)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|||||||
@ -147,7 +147,6 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||||
display_name="Get Video Components",
|
display_name="Get Video Components",
|
||||||
category="image/video",
|
category="image/video",
|
||||||
essentials_category="Video Tools",
|
|
||||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||||
@ -218,6 +217,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
"start time",
|
"start time",
|
||||||
],
|
],
|
||||||
category="image/video",
|
category="image/video",
|
||||||
|
essentials_category="Video Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video"),
|
io.Video.Input("video"),
|
||||||
io.Float.Input(
|
io.Float.Input(
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.14.1"
|
__version__ = "0.15.0"
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@ -1925,7 +1925,6 @@ class ImageInvert:
|
|||||||
|
|
||||||
class ImageBatch:
|
class ImageBatch:
|
||||||
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
||||||
ESSENTIALS_CATEGORY = "Image Tools"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.14.1"
|
version = "0.15.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
comfyui-frontend-package==1.39.16
|
comfyui-frontend-package==1.39.19
|
||||||
comfyui-workflow-templates==0.9.2
|
comfyui-workflow-templates==0.9.3
|
||||||
comfyui-embedded-docs==0.4.1
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
@ -22,7 +22,7 @@ alembic
|
|||||||
SQLAlchemy
|
SQLAlchemy
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.7
|
comfy-kitchen>=0.2.7
|
||||||
comfy-aimdo>=0.2.0
|
comfy-aimdo>=0.2.2
|
||||||
requests
|
requests
|
||||||
|
|
||||||
#non essential dependencies:
|
#non essential dependencies:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user