Merge branch 'comfyanonymous:master' into master

Extraltodeus 2025-12-07 17:24:23 +01:00 committed by GitHub
commit 277e46473d
27 changed files with 1375 additions and 338 deletions

View File

@ -320,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
### Setup
1. Install the manager dependencies:
```bash
pip install -r manager_requirements.txt
```
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
```bash
python main.py --enable-manager
```
### Command Line Options
| Flag | Description |
|------|-------------|
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
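For example, to launch with the manager enabled and fall back to the legacy manager UI, combine the flags from the table above (a minimal invocation sketch; any other launch arguments you normally use stay the same):

```bash
# Enable ComfyUI-Manager and opt into the legacy manager UI
python main.py --enable-manager --enable-manager-legacy-ui
```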
# Running

```python main.py```

View File

@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = tuple([slice(None)] * dim + [self.index_list])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
window[idx] = full[idx]
return window.to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = tuple([slice(None)] * dim + [self.index_list])
full[idx] += to_add
return full
def get_region_index(self, num_regions: int) -> int:
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -94,7 +104,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
return True
return False
@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = tuple([slice(None)] * self.dim + [idx])
pos_window = tuple([slice(None)] * self.dim + [pos])
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
model_options = extra_args.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is None:
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
"ContextWindows_sampler_sample",
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
logging.info("Context windows: Applying FreeNoise")
generator = torch.Generator(device='cpu').manual_seed(seed)
latent_video_length = noise.shape[dim]
delta = context_length - context_overlap
for start_idx in range(0, latent_video_length - context_length, delta):
place_idx = start_idx + context_length
actual_delta = min(delta, latent_video_length - place_idx)
if actual_delta <= 0:
break
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
source_slice = [slice(None)] * noise.ndim
source_slice[dim] = list_idx
target_slice = [slice(None)] * noise.ndim
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
noise[tuple(target_slice)] = noise[tuple(source_slice)]
return noise

View File

@ -0,0 +1,413 @@
import torch
from torch import nn
import math
import comfy.ldm.common_dit
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
def attention(q, k, v, heads, transformer_options={}):
return optimized_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
heads=heads,
skip_reshape=True,
transformer_options=transformer_options
)
def apply_scale_shift_norm(norm, x, scale, shift):
return torch.addcmul(shift, norm(x), scale + 1.0)
def apply_gate_sum(x, out, gate):
return torch.addcmul(x, gate, out)
def get_shift_scale_gate(params):
shift, scale, gate = torch.chunk(params, 3, dim=-1)
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
def get_freqs(dim, max_period=10000.0):
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
class TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.SiLU()
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, timestep, dtype):
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, text_embed):
text_embed = self.in_layer(text_embed)
return self.norm(text_embed).type_as(text_embed)
class VisualEmbeddings(nn.Module):
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
x = x.movedim(1, -1) # B C T H W -> B T H W C
B, T, H, W, dim = x.shape
pt, ph, pw = self.patch_size
x = x.view(
B,
T // pt, pt,
H // ph, ph,
W // pw, pw,
dim,
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
return self.in_layer(x)
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
return self.out_layer(self.activation(x))
class SelfAttention(nn.Module):
def __init__(self, num_channels, head_dim, operation_settings=None):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.head_dim = head_dim
operations = operation_settings.get("operations")
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 2
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(norm_fn(result), freqs)
def _forward(self, x, freqs, transformer_options={}):
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
chunks = []
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1)
q = process_chunks(self.to_query, self.query_norm)
k = process_chunks(self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def forward(self, x, freqs, transformer_options={}):
if x.shape[1] > 8192:
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
else:
return self._forward(x, freqs, transformer_options=transformer_options)
class CrossAttention(SelfAttention):
def get_qkv(self, x, context):
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
return q, k, v
def forward(self, x, context, transformer_options={}):
q, k, v = self.get_qkv(x, context)
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
class FeedForward(nn.Module):
def __init__(self, dim, ff_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.GELU()
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 4
def _forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x):
chunks = torch.chunk(x, self.num_chunks, dim=1)
output_chunks = []
for chunk in chunks:
output_chunks.append(self._forward(chunk))
return torch.cat(output_chunks, dim=1)
def forward(self, x):
if x.shape[1] > 8192:
return self._forward_chunked(x)
else:
return self._forward(x)
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, visual_embed, time_embed):
B, T, H, W, _ = visual_embed.shape
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
scale = scale[:, None, None, None, :]
shift = shift[:, None, None, None, :]
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
x = self.out_layer(visual_embed)
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
x = x.view(
B, T, H, W,
out_dim,
self.patch_size[0], self.patch_size[1], self.patch_size[2]
)
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
class TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, x, time_embed, freqs, transformer_options={}):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
shift, scale, gate = get_shift_scale_gate(self_attn_params)
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, freqs, transformer_options=transformer_options)
x = apply_gate_sum(x, out, gate)
shift, scale, gate = get_shift_scale_gate(ff_params)
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
out = self.feed_forward(out)
x = apply_gate_sum(x, out, gate)
return x
class TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
# self attention
shift, scale, gate = get_shift_scale_gate(self_attn_params)
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# cross attention
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# feed forward
shift, scale, gate = get_shift_scale_gate(ff_params)
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
visual_out = self.feed_forward(visual_out)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
return visual_embed
class Kandinsky5(nn.Module):
def __init__(
self,
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
dtype=None, device=None, operations=None, **kwargs
):
super().__init__()
head_dim = sum(axes_dims)
self.rope_scale_factor = rope_scale_factor
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_embed_dim = visual_embed_dim
self.dtype = dtype
self.device = device
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
self.text_transformer_blocks = nn.ModuleList(
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
)
self.visual_transformer_blocks = nn.ModuleList(
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
)
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
steps = seq_len if steps is None else steps
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
return freqs
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
else:
rope_scale_factor = self.rope_scale_factor
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
if h * w >= 14080:
rope_scale_factor = (1.0, 3.16, 3.16)
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
return freqs
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
context = self.text_embeddings(context)
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
for block in self.text_transformer_blocks:
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
visual_embed = self.visual_embeddings(x)
visual_shape = visual_embed.shape[:-1]
visual_embed = visual_embed.flatten(1, -2)
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.visual_transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
else:
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
visual_embed = visual_embed.reshape(*visual_shape, -1)
return self.out_layer(visual_embed, time_embed)
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
original_dims = x.ndim
if original_dims == 4:
x = x.unsqueeze(2)
bs, c, t_len, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if time_dim_replace is not None:
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
if original_dims == 4:
out = out.squeeze(2)
return out
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File

@ -377,6 +377,7 @@ class NextDiT(nn.Module):
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
@ -447,6 +448,31 @@ class NextDiT(nn.Module):
),
)
self.clip_text_pooled_proj = None
if clip_text_dim is not None:
self.clip_text_dim = clip_text_dim
self.clip_text_pooled_proj = nn.Sequential(
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
clip_text_dim,
clip_text_dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.time_text_embed = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024) + clip_text_dim,
min(dim, 1024),
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@ -585,6 +611,15 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)

View File

@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}):
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
return key_map

View File

@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.model_management
import comfy.patcher_extension
@ -1109,6 +1110,10 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs["pooled_output"] # Newbie
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
return out
class WAN21(BaseModel):
@ -1630,3 +1635,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out
class Kandinsky5(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
device = kwargs["device"]
image = torch.zeros_like(noise)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
time_dim_replace = kwargs.get("time_dim_replace", None)
if time_dim_replace is not None:
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
return out
class Kandinsky5Image(Kandinsky5):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
return None

View File

@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None:
dit_config["clip_text_dim"] = ctd_weight.shape[0]
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@ -611,6 +614,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
if model_dim == 2560: # lite image
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["image_model"] = "kandinsky5"
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@ -1492,6 +1492,20 @@ def extended_fp16_support():
return True
LORA_COMPUTE_DTYPES = {}
def lora_compute_dtype(device):
dtype = LORA_COMPUTE_DTYPES.get(device, None)
if dtype is not None:
return dtype
if should_use_fp16(device):
dtype = torch.float16
else:
dtype = torch.float32
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@ -614,10 +614,11 @@ class ModelPatcher:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else:
temp_weight = weight.to(temp_dtype, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
@ -761,6 +762,8 @@ class ModelPatcher:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if comfy.model_management.is_device_cuda(device_to):
torch.cuda.synchronize()
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True

View File

@ -552,6 +552,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,

View File

@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.model_patcher
import comfy.lora
@ -98,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
if no_init:
return
params = target.params.copy()
@ -129,6 +130,27 @@ class CLIP:
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if len(state_dict) > 0:
if isinstance(state_dict, list):
for c in state_dict:
m, u = self.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
else:
m, u = self.load_sd(state_dict, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@ -745,6 +767,8 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -962,6 +986,8 @@ class CLIPType(Enum):
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -1210,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
elif clip_type == CLIPType.KANDINSKY5:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@ -1225,14 +1257,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
return clip
def load_gligen(ckpt_path):
@ -1362,17 +1387,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")

View File

@ -21,6 +21,7 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
from . import supported_models_base
@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
class Kandinsky5(supported_models_base.BASE):
unet_config = {
"image_model": "kandinsky5",
}
sampling_settings = {
"shift": 10.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
class Kandinsky5Image(Kandinsky5):
unet_config = {
"image_model": "kandinsky5",
"model_dim": 2560,
"visual_embed_dim": 64,
}
sampling_settings = {
"shift": 3.0,
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.1 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]

View File

@ -17,6 +17,7 @@
"""
import torch
import logging
from . import model_base
from . import utils
from . import latent_formats
@ -117,3 +118,7 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
def __getattr__(self, name):
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
return None

View File

@ -0,0 +1,68 @@
from comfy import sd1_clip
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from .llama import Qwen25_7BVLI
class Kandinsky5Tokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Kandinsky5TEModel(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
return cond, l_pooled, extra
def set_clip_options(self, options):
super().set_clip_options(options)
self.clip_l.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
self.clip_l.reset_clip_options()
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return super().load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class Kandinsky5TEModel_(Kandinsky5TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Kandinsky5TEModel_
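
The `te()` factory above closes over `dtype_llama` and `llama_quantization_metadata` and returns a subclass whose constructor injects them into `model_options` at instantiation time, so the loader can bake in the detection results once and hand back a ready-to-construct class. A minimal sketch of that closure pattern under invented names (not ComfyUI's API):

```python
def make_text_encoder_class(dtype_override=None, quant_metadata=None):
    class BaseTE:
        def __init__(self, dtype=None, model_options={}):
            self.dtype = dtype
            self.model_options = model_options

    class ConfiguredTE(BaseTE):
        def __init__(self, dtype=None, model_options={}):
            if quant_metadata is not None:
                # copy before mutating so the caller's dict is left untouched
                model_options = model_options.copy()
                model_options["llama_quantization_metadata"] = quant_metadata
            if dtype_override is not None:
                dtype = dtype_override
            super().__init__(dtype=dtype, model_options=model_options)

    return ConfiguredTE

TE = make_text_encoder_class(dtype_override="float16", quant_metadata={"scheme": "fp8_e4m3"})
print(TE().dtype, TE().model_options)
```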

View File

@@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.''' '''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor] time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.''' '''Used by WAN Phantom Subject.'''
time_dim_replace: NotRequired[torch.Tensor]
'''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]] CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList Type = CondList

View File

@@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod

View File

@@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode):
step=0.01, step=0.01,
tooltip="Duration of the empty audio clip in seconds", tooltip="Duration of the empty audio clip in seconds",
), ),
IO.Float.Input( IO.Int.Input(
"sample_rate", "sample_rate",
default=44100, default=44100,
tooltip="Sample rate of the empty audio clip.", tooltip="Sample rate of the empty audio clip.",
min=1,
max=192000,
), ),
IO.Float.Input( IO.Int.Input(
"channels", "channels",
default=2, default=2,
min=1, min=1,

View File

@@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
], ],
outputs=[ outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."), io.Model.Output(tooltip="The model with context windows applied during sampling."),
@@ -34,7 +37,8 @@
) )
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
model = model.clone() model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
context_overlap=context_overlap, context_overlap=context_overlap,
context_stride=context_stride, context_stride=context_stride,
closed_loop=closed_loop, closed_loop=closed_loop,
dim=dim) dim=dim,
freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows
)
# make memory usage calculation only take into account the context window latents # make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model) comfy.context_windows.create_prepare_sampling_wrapper(model)
if freenoise: # no other use for this wrapper at this time
comfy.context_windows.create_sampler_sample_wrapper(model)
return io.NodeOutput(model) return io.NodeOutput(model)
class WanContextWindowsManualNode(ContextWindowsManualNode): class WanContextWindowsManualNode(ContextWindowsManualNode):
@@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
] ]
return schema return schema
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
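
The WAN variant above maps the user-facing pixel-frame counts to latent frames with `((n - 1) // 4) + 1`, reflecting the 4x temporal downscaling of the WAN video VAE (the first frame occupies its own latent frame). A quick standalone check of that arithmetic:

```python
def frames_to_latent_frames(n: int) -> int:
    # same conversion as the node above
    return ((n - 1) // 4) + 1

for n in (1, 16, 17, 81):
    print(n, "->", frames_to_latent_frames(n))
# 1 -> 1, 16 -> 4, 17 -> 5, 81 -> 21
```
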
class ContextWindowsExtension(ComfyExtension): class ContextWindowsExtension(ComfyExtension):

View File

@@ -2,6 +2,8 @@
import torch import torch
import logging import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
def Fourier_filter(x, threshold, scale): def Fourier_filter(x, threshold, scale):
# FFT # FFT
@@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
return x_filtered.to(x.dtype) return x_filtered.to(x.dtype)
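
`Fourier_filter` (its body is elided by the hunk above) rescales the low-frequency band of a feature map's 2D FFT; this is what FreeU's `s1`/`s2` skip scales are applied through. The sketch below is an assumption-labeled illustration of that idea in plain PyTorch, not a copy of ComfyUI's exact implementation:

```python
import torch

def fourier_lowfreq_scale(x: torch.Tensor, threshold: int, scale: float) -> torch.Tensor:
    """Minimal FreeU-style low-frequency rescale for a (B, C, H, W) tensor.

    The box-shaped frequency window here is an assumption; the exact
    windowing in ComfyUI's Fourier_filter is not visible in this hunk.
    """
    x_freq = torch.fft.fftn(x.float(), dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones_like(x_freq)
    crow, ccol = H // 2, W // 2
    # scale only the centered (low-frequency) box of the spectrum
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale

    x_freq = torch.fft.ifftshift(x_freq * mask, dim=(-2, -1))
    return torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(x.dtype)

print(fourier_lowfreq_scale(torch.randn(1, 4, 32, 32), threshold=1, scale=0.9).shape)
```
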
class FreeU: class FreeU(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return IO.Schema(
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), node_id="FreeU",
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), category="model_patches/unet",
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), inputs=[
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), IO.Model.Input("model"),
}} IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
RETURN_TYPES = ("MODEL",) IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
FUNCTION = "patch" IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {} on_cpu_devices = {}
@@ -59,23 +66,31 @@ class FreeU:
m = model.clone() m = model.clone()
m.set_model_output_block_patch(output_block_patch) m.set_model_output_block_patch(output_block_patch)
return (m, ) return IO.NodeOutput(m)
class FreeU_V2: patch = execute # TODO: remove
class FreeU_V2(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return IO.Schema(
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), node_id="FreeU_V2",
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), category="model_patches/unet",
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), inputs=[
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), IO.Model.Input("model"),
}} IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
RETURN_TYPES = ("MODEL",) IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
FUNCTION = "patch" IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "model_patches/unet" @classmethod
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"] model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {} on_cpu_devices = {}
@@ -105,9 +120,19 @@ class FreeU_V2:
m = model.clone() m = model.clone()
m.set_model_output_block_patch(output_block_patch) m.set_model_output_block_patch(output_block_patch)
return (m, ) return IO.NodeOutput(m)
NODE_CLASS_MAPPINGS = { patch = execute # TODO: remove
"FreeU": FreeU,
"FreeU_V2": FreeU_V2,
} class FreelunchExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
FreeU,
FreeU_V2,
]
async def comfy_entrypoint() -> FreelunchExtension:
return FreelunchExtension()

View File

@@ -0,0 +1,136 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class Kandinsky5ImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Kandinsky5ImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond_latent_out = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded = vae.encode(start_image[:, :, :, :3])
cond_latent_out["samples"] = encoded
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
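
The empty latent built above uses 16 channels, `((length - 1) // 4) + 1` temporal frames, and an 8x spatial downscale. A quick worked check with the node's default width/height and length:

```python
def kandinsky5_empty_latent_shape(batch_size: int, length: int, height: int, width: int):
    # mirrors torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8])
    return (batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8)

print(kandinsky5_empty_latent_shape(1, 121, 512, 768))  # (1, 16, 31, 64, 96)
```
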
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
# normalization
normalized = (source - source_mean) / (source_std + 1e-8)
normalized = normalized * reference_std + reference_mean
return normalized
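
In other words, the start frames are standardized and then re-scaled with the reference statistics, but the reference mean/std are first clamped to a band around the source's own statistics, so the correction nudges the start frames rather than fully matching them. A small self-contained check of that behaviour with random tensors and the default clamp bands:

```python
import torch

torch.manual_seed(0)
source = torch.randn(1, 16, 4, 8, 8) * 2.0 + 1.0  # "start" frames: mean ~1.0, std ~2.0
reference = torch.randn(1, 16, 5, 8, 8) * 0.5     # "later" frames: mean ~0.0, std ~0.5

source_mean = source.mean(dim=(1, 3, 4), keepdim=True)
source_std = source.std(dim=(1, 3, 4), keepdim=True)
# reference stats are clamped to stay near the source stats
reference_mean = torch.clamp(reference.mean(), source_mean - 0.3, source_mean + 0.35)
reference_std = torch.clamp(reference.std(), source_std - 0.35, source_std + 0.5)

normalized = (source - source_mean) / (source_std + 1e-8) * reference_std + reference_mean
print(source.mean().item(), normalized.mean().item(), reference.mean().item())
```
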
class NormalizeVideoLatentStart(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="NormalizeVideoLatentStart",
category="conditioning/video_models",
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
inputs=[
io.Latent.Input("latent"),
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
if latent["samples"].shape[2] <= 1:
return io.NodeOutput(latent)
s = latent.copy()
samples = latent["samples"].clone()
first_frames = samples[:, :, :start_frame_count]
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
samples[:, :, :start_frame_count] = normalized_first_frames
s["samples"] = samples
return io.NodeOutput(s)
class CLIPTextEncodeKandinsky5(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeKandinsky5",
category="advanced/conditioning/kandinsky5",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
tokens = clip.tokenize(clip_l)
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class Kandinsky5Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Kandinsky5ImageToVideo,
NormalizeVideoLatentStart,
CLIPTextEncodeKandinsky5,
]
async def comfy_entrypoint() -> Kandinsky5Extension:
return Kandinsky5Extension()

View File

@@ -4,7 +4,7 @@ import torch
import nodes import nodes
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import logging
def reshape_latent_to(target_shape, latent, repeat_batch=True): def reshape_latent_to(target_shape, latent, repeat_batch=True):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
@@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
return luminance * sharpened return luminance * sharpened
return io.NodeOutput(sharpen) return io.NodeOutput(sharpen)
class ReplaceVideoLatentFrames(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ReplaceVideoLatentFrames",
category="latent/batch",
inputs=[
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
],
outputs=[
io.Latent.Output(),
],
)
@classmethod
def execute(cls, destination, index, source=None) -> io.NodeOutput:
if source is None:
return io.NodeOutput(destination)
dest_frames = destination["samples"].shape[2]
source_frames = source["samples"].shape[2]
if index < 0:
index = dest_frames + index
if index > dest_frames:
logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
return io.NodeOutput(destination)
if index + source_frames > dest_frames:
logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
return io.NodeOutput(destination)
s = source.copy()
s_source = source["samples"]
s_destination = destination["samples"].clone()
s_destination[:, :, index:index + s_source.shape[2]] = s_source
s["samples"] = s_destination
return io.NodeOutput(s)
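
The index handling above counts negative indices from the end and returns the destination untouched when the source would not fit. A standalone sketch of the same logic on raw (B, C, T, H, W) tensors:

```python
import torch

def replace_frames(destination: torch.Tensor, source: torch.Tensor, index: int) -> torch.Tensor:
    dest_frames = destination.shape[2]
    src_frames = source.shape[2]
    if index < 0:
        index = dest_frames + index
    if index > dest_frames or index + src_frames > dest_frames:
        return destination  # out of bounds: leave the destination unchanged
    out = destination.clone()
    out[:, :, index:index + src_frames] = source
    return out

dest = torch.zeros(1, 16, 31, 8, 8)
src = torch.ones(1, 16, 8, 8, 8)
print(replace_frames(dest, src, -8)[0, 0, :, 0, 0])  # the last 8 frames become 1.0
```
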
class LatentExtension(ComfyExtension): class LatentExtension(ComfyExtension):
@override @override
@@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
LatentApplyOperationCFG, LatentApplyOperationCFG,
LatentOperationTonemapReinhard, LatentOperationTonemapReinhard,
LatentOperationSharpen, LatentOperationSharpen,
ReplaceVideoLatentFrames
] ]

View File

@@ -3,11 +3,10 @@ import scipy.ndimage
import torch import torch
import comfy.utils import comfy.utils
import node_helpers import node_helpers
import folder_paths from typing_extensions import override
import random from comfy_api.latest import ComfyExtension, IO, UI
import nodes import nodes
from nodes import MAX_RESOLUTION
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
source = source.to(destination.device) source = source.to(destination.device)
@@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
destination[..., top:bottom, left:right] = source_portion + destination_portion destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination return destination
class LatentCompositeMasked: class LatentCompositeMasked(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="LatentCompositeMasked",
"destination": ("LATENT",), category="latent",
"source": ("LATENT",), inputs=[
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), IO.Latent.Input("destination"),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), IO.Latent.Input("source"),
"resize_source": ("BOOLEAN", {"default": False}), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
}, IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
"optional": { IO.Boolean.Input("resize_source", default=False),
"mask": ("MASK",), IO.Mask.Input("mask", optional=True),
} ],
} outputs=[IO.Latent.Output()],
RETURN_TYPES = ("LATENT",) )
FUNCTION = "composite"
CATEGORY = "latent" @classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
def composite(self, destination, source, x, y, resize_source, mask = None):
output = destination.copy() output = destination.copy()
destination = destination["samples"].clone() destination = destination["samples"].clone()
source = source["samples"] source = source["samples"]
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
return (output,) return IO.NodeOutput(output)
class ImageCompositeMasked: composite = execute # TODO: remove
class ImageCompositeMasked(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ImageCompositeMasked",
"destination": ("IMAGE",), category="image",
"source": ("IMAGE",), inputs=[
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Image.Input("destination"),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Image.Input("source"),
"resize_source": ("BOOLEAN", {"default": False}), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
}, IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
"optional": { IO.Boolean.Input("resize_source", default=False),
"mask": ("MASK",), IO.Mask.Input("mask", optional=True),
} ],
} outputs=[IO.Image.Output()],
RETURN_TYPES = ("IMAGE",) )
FUNCTION = "composite"
CATEGORY = "image" @classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
def composite(self, destination, source, x, y, resize_source, mask = None):
destination, source = node_helpers.image_alpha_fix(destination, source) destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1) destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
return (output,) return IO.NodeOutput(output)
class MaskToImage: composite = execute # TODO: remove
class MaskToImage(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="MaskToImage",
"mask": ("MASK",), display_name="Convert Mask to Image",
} category="mask",
} inputs=[
IO.Mask.Input("mask"),
],
outputs=[IO.Image.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, mask) -> IO.NodeOutput:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "mask_to_image"
def mask_to_image(self, mask):
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return (result,) return IO.NodeOutput(result)
class ImageToMask: mask_to_image = execute # TODO: remove
class ImageToMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ImageToMask",
"image": ("IMAGE",), display_name="Convert Image to Mask",
"channel": (["red", "green", "blue", "alpha"],), category="mask",
} inputs=[
} IO.Image.Input("image"),
IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, image, channel) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, image, channel):
channels = ["red", "green", "blue", "alpha"] channels = ["red", "green", "blue", "alpha"]
mask = image[:, :, :, channels.index(channel)] mask = image[:, :, :, channels.index(channel)]
return (mask,) return IO.NodeOutput(mask)
class ImageColorToMask: image_to_mask = execute # TODO: remove
class ImageColorToMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ImageColorToMask",
"image": ("IMAGE",), category="mask",
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), inputs=[
} IO.Image.Input("image"),
} IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, image, color) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, image, color):
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
mask = torch.where(temp == color, 1.0, 0).float() mask = torch.where(temp == color, 1.0, 0).float()
return (mask,) return IO.NodeOutput(mask)
class SolidMask: image_to_mask = execute # TODO: remove
class SolidMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="SolidMask",
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), category="mask",
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), inputs=[
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
} IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
} IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, value, width, height) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "solid"
def solid(self, value, width, height):
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
return (out,) return IO.NodeOutput(out)
class InvertMask: solid = execute # TODO: remove
class InvertMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="InvertMask",
"mask": ("MASK",), category="mask",
} inputs=[
} IO.Mask.Input("mask"),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, mask) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "invert"
def invert(self, mask):
out = 1.0 - mask out = 1.0 - mask
return (out,) return IO.NodeOutput(out)
class CropMask: invert = execute # TODO: remove
class CropMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="CropMask",
"mask": ("MASK",), category="mask",
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), inputs=[
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Mask.Input("mask"),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
} IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
} IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "crop"
def crop(self, mask, x, y, width, height):
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
out = mask[:, y:y + height, x:x + width] out = mask[:, y:y + height, x:x + width]
return (out,) return IO.NodeOutput(out)
class MaskComposite: crop = execute # TODO: remove
class MaskComposite(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="MaskComposite",
"destination": ("MASK",), category="mask",
"source": ("MASK",), inputs=[
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Mask.Input("destination"),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Mask.Input("source"),
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
} IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
} IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "combine"
def combine(self, destination, source, x, y, operation):
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
source = source.reshape((-1, source.shape[-2], source.shape[-1])) source = source.reshape((-1, source.shape[-2], source.shape[-1]))
@@ -267,28 +277,29 @@ class MaskComposite:
output = torch.clamp(output, 0.0, 1.0) output = torch.clamp(output, 0.0, 1.0)
return (output,) return IO.NodeOutput(output)
class FeatherMask: combine = execute # TODO: remove
class FeatherMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="FeatherMask",
"mask": ("MASK",), category="mask",
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), inputs=[
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Mask.Input("mask"),
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
} IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
} IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "feather"
def feather(self, mask, left, top, right, bottom):
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
left = min(left, output.shape[-1]) left = min(left, output.shape[-1])
@@ -312,26 +323,28 @@ class FeatherMask:
feather_rate = (y + 1) / bottom feather_rate = (y + 1) / bottom
output[:, -y, :] *= feather_rate output[:, -y, :] *= feather_rate
return (output,) return IO.NodeOutput(output)
class GrowMask: feather = execute # TODO: remove
class GrowMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="GrowMask",
"mask": ("MASK",), display_name="Grow Mask",
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), category="mask",
"tapered_corners": ("BOOLEAN", {"default": True}), inputs=[
}, IO.Mask.Input("mask"),
} IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("tapered_corners", default=True),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "expand_mask"
def expand_mask(self, mask, expand, tapered_corners):
c = 0 if tapered_corners else 1 c = 0 if tapered_corners else 1
kernel = np.array([[c, 1, c], kernel = np.array([[c, 1, c],
[1, 1, 1], [1, 1, 1],
@@ -347,69 +360,74 @@ class GrowMask:
output = scipy.ndimage.grey_dilation(output, footprint=kernel) output = scipy.ndimage.grey_dilation(output, footprint=kernel)
output = torch.from_numpy(output) output = torch.from_numpy(output)
out.append(output) out.append(output)
return (torch.stack(out, dim=0),) return IO.NodeOutput(torch.stack(out, dim=0))
class ThresholdMask: expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="ThresholdMask",
"mask": ("MASK",), category="mask",
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), inputs=[
} IO.Mask.Input("mask"),
} IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
],
outputs=[IO.Mask.Output()],
)
CATEGORY = "mask" @classmethod
def execute(cls, mask, value) -> IO.NodeOutput:
RETURN_TYPES = ("MASK",)
FUNCTION = "image_to_mask"
def image_to_mask(self, mask, value):
mask = (mask > value).float() mask = (mask > value).float()
return (mask,) return IO.NodeOutput(mask)
image_to_mask = execute # TODO: remove
# Mask Preview - original implement from # Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes # upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
class MaskPreview(nodes.SaveImage): class MaskPreview(IO.ComfyNode):
def __init__(self): @classmethod
self.output_dir = folder_paths.get_temp_directory() def define_schema(cls):
self.type = "temp" return IO.Schema(
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) node_id="MaskPreview",
self.compress_level = 4 display_name="Preview Mask",
category="mask",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
IO.Mask.Input("mask"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
return { return IO.NodeOutput(ui=UI.PreviewMask(mask))
"required": {"mask": ("MASK",), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
FUNCTION = "execute"
CATEGORY = "mask"
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
NODE_CLASS_MAPPINGS = { class MaskExtension(ComfyExtension):
"LatentCompositeMasked": LatentCompositeMasked, @override
"ImageCompositeMasked": ImageCompositeMasked, async def get_node_list(self) -> list[type[IO.ComfyNode]]:
"MaskToImage": MaskToImage, return [
"ImageToMask": ImageToMask, LatentCompositeMasked,
"ImageColorToMask": ImageColorToMask, ImageCompositeMasked,
"SolidMask": SolidMask, MaskToImage,
"InvertMask": InvertMask, ImageToMask,
"CropMask": CropMask, ImageColorToMask,
"MaskComposite": MaskComposite, SolidMask,
"FeatherMask": FeatherMask, InvertMask,
"GrowMask": GrowMask, CropMask,
"ThresholdMask": ThresholdMask, MaskComposite,
"MaskPreview": MaskPreview FeatherMask,
} GrowMask,
ThresholdMask,
MaskPreview,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"ImageToMask": "Convert Image to Mask", async def comfy_entrypoint() -> MaskExtension:
"MaskToImage": "Convert Mask to Image", return MaskExtension()
}
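
One detail worth calling out from the file above: `ImageColorToMask` compares each pixel against a single packed 0xRRGGBB integer, built by scaling the image to 0-255 and bit-shifting the channels together. A small standalone check of that packing:

```python
import torch

def pack_rgb(image: torch.Tensor) -> torch.Tensor:
    # image is (B, H, W, C) in 0..1, as in the node above
    temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
    return (temp[:, :, :, 0] << 16) + (temp[:, :, :, 1] << 8) + temp[:, :, :, 2]

img = torch.tensor([[[[0.2, 0.4, 0.8]]]])         # one pixel, RGB = (51, 102, 204)
packed = pack_rgb(img)
print(hex(packed.item()))                         # 0x3366cc
print(torch.where(packed == 0x3366CC, 1.0, 0.0))  # matching pixel -> mask value 1.0
```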

View File

@@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
return io.NodeOutput(m) return io.NodeOutput(m)
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"PatchModelAddDownscale": "",
}
class ModelDownscaleExtension(ComfyExtension): class ModelDownscaleExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@@ -63,9 +63,9 @@ def cuda_malloc_supported():
return True return True
if not args.cuda_malloc: version = ""
try:
version = "" try:
torch_spec = importlib.util.find_spec("torch") torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations: for folder in torch_spec.submodule_search_locations:
ver_file = os.path.join(folder, "version.py") ver_file = os.path.join(folder, "version.py")
@@ -74,7 +74,11 @@ if not args.cuda_malloc:
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
version = module.__version__ version = module.__version__
except:
pass
if not args.cuda_malloc:
try:
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported() args.cuda_malloc = cuda_malloc_supported()
@@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
env_var += ",backend:cudaMallocAsync" env_var += ",backend:cudaMallocAsync"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
def get_torch_version_noimport():
return str(version)
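
The refactor above moves the torch version probe out of the `cuda_malloc` gate so `get_torch_version_noimport()` can report it without importing torch. A self-contained sketch of that probe; the `os.path.isfile` guard is assumed, since that part of the function falls outside the hunk:

```python
import importlib.util
import os

def get_torch_version_noimport() -> str:
    version = ""
    try:
        torch_spec = importlib.util.find_spec("torch")
        for folder in torch_spec.submodule_search_locations:
            ver_file = os.path.join(folder, "version.py")
            if os.path.isfile(ver_file):  # assumed guard, not visible in the hunk above
                spec = importlib.util.spec_from_file_location("torch_version", ver_file)
                module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(module)
                version = module.__version__
    except Exception:
        pass
    return str(version)

print(get_torch_version_noimport())  # e.g. "2.4.0+cu124", or "" if torch is absent
```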

View File

@@ -167,6 +167,9 @@ if __name__ == "__main__":
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
import cuda_malloc import cuda_malloc
if "rocm" in cuda_malloc.get_torch_version_noimport():
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
if 'torch' in sys.modules: if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")

View File

@@ -1 +1 @@
comfyui_manager==4.0.3b3 comfyui_manager==4.0.3b4

View File

@@ -970,7 +970,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ), "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes():
"nodes_rope.py", "nodes_rope.py",
"nodes_logic.py", "nodes_logic.py",
"nodes_nop.py", "nodes_nop.py",
"nodes_kandinsky5.py",
] ]
import_failed = [] import_failed = []