mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 19:13:02 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
277e46473d
26
README.md
26
README.md
@ -320,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
|||||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||||
2. Launch ComfyUI by running `python main.py`
|
2. Launch ComfyUI by running `python main.py`
|
||||||
|
|
||||||
|
|
||||||
|
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||||
|
|
||||||
|
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
1. Install the manager dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -r manager_requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||||
|
```bash
|
||||||
|
python main.py --enable-manager
|
||||||
|
```
|
||||||
|
|
||||||
|
### Command Line Options
|
||||||
|
|
||||||
|
| Flag | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||||
|
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||||
|
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||||
|
|
||||||
|
|
||||||
# Running
|
# Running
|
||||||
|
|
||||||
```python main.py```
|
```python main.py```
|
||||||
|
|||||||
@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
|
|||||||
|
|
||||||
|
|
||||||
class IndexListContextWindow(ContextWindowABC):
|
class IndexListContextWindow(ContextWindowABC):
|
||||||
def __init__(self, index_list: list[int], dim: int=0):
|
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||||
self.index_list = index_list
|
self.index_list = index_list
|
||||||
self.context_length = len(index_list)
|
self.context_length = len(index_list)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
self.total_frames = total_frames
|
||||||
|
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||||
|
|
||||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = self.dim
|
dim = self.dim
|
||||||
if dim == 0 and full.shape[dim] == 1:
|
if dim == 0 and full.shape[dim] == 1:
|
||||||
return full
|
return full
|
||||||
idx = [slice(None)] * dim + [self.index_list]
|
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||||
return full[idx].to(device)
|
window = full[idx]
|
||||||
|
if retain_index_list:
|
||||||
|
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||||
|
window[idx] = full[idx]
|
||||||
|
return window.to(device)
|
||||||
|
|
||||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||||
if dim is None:
|
if dim is None:
|
||||||
dim = self.dim
|
dim = self.dim
|
||||||
idx = [slice(None)] * dim + [self.index_list]
|
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||||
full[idx] += to_add
|
full[idx] += to_add
|
||||||
return full
|
return full
|
||||||
|
|
||||||
|
def get_region_index(self, num_regions: int) -> int:
|
||||||
|
region_idx = int(self.center_ratio * num_regions)
|
||||||
|
return min(max(region_idx, 0), num_regions - 1)
|
||||||
|
|
||||||
|
|
||||||
class IndexListCallbacks:
|
class IndexListCallbacks:
|
||||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||||
@ -94,7 +104,8 @@ class ContextFuseMethod:
|
|||||||
|
|
||||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||||
class IndexListContextHandler(ContextHandlerABC):
|
class IndexListContextHandler(ContextHandlerABC):
|
||||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||||
|
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||||
self.context_schedule = context_schedule
|
self.context_schedule = context_schedule
|
||||||
self.fuse_method = fuse_method
|
self.fuse_method = fuse_method
|
||||||
self.context_length = context_length
|
self.context_length = context_length
|
||||||
@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
self.closed_loop = closed_loop
|
self.closed_loop = closed_loop
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self._step = 0
|
self._step = 0
|
||||||
|
self.freenoise = freenoise
|
||||||
|
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||||
|
self.split_conds_to_windows = split_conds_to_windows
|
||||||
|
|
||||||
self.callbacks = {}
|
self.callbacks = {}
|
||||||
|
|
||||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||||
if x_in.size(self.dim) > self.context_length:
|
if x_in.size(self.dim) > self.context_length:
|
||||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||||
|
if self.cond_retain_index_list:
|
||||||
|
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return None
|
return None
|
||||||
# reuse or resize cond items to match context requirements
|
# reuse or resize cond items to match context requirements
|
||||||
resized_cond = []
|
resized_cond = []
|
||||||
|
# if multiple conds, split based on primary region
|
||||||
|
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||||
|
region = window.get_region_index(len(cond_in))
|
||||||
|
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||||
|
cond_in = [cond_in[region]]
|
||||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||||
for actual_cond in cond_in:
|
for actual_cond in cond_in:
|
||||||
resized_actual_cond = actual_cond.copy()
|
resized_actual_cond = actual_cond.copy()
|
||||||
@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||||
for cond_key, cond_value in new_cond_item.items():
|
for cond_key, cond_value in new_cond_item.items():
|
||||||
if isinstance(cond_value, torch.Tensor):
|
if isinstance(cond_value, torch.Tensor):
|
||||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||||
|
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||||
|
# Handle audio_embed (temporal dim is 1)
|
||||||
|
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
|
audio_cond = cond_value.cond
|
||||||
|
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||||
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||||
# if has cond that is a Tensor, check if needs to be subset
|
# if has cond that is a Tensor, check if needs to be subset
|
||||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||||
|
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||||
elif cond_key == "num_video_frames": # for SVD
|
elif cond_key == "num_video_frames": # for SVD
|
||||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||||
new_cond_item[cond_key].cond = window.context_length
|
new_cond_item[cond_key].cond = window.context_length
|
||||||
@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return resized_cond
|
return resized_cond
|
||||||
|
|
||||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||||
matches = torch.nonzero(mask)
|
matches = torch.nonzero(mask)
|
||||||
if torch.numel(matches) == 0:
|
if torch.numel(matches) == 0:
|
||||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||||
@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||||
return context_windows
|
return context_windows
|
||||||
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
prev_weight = (bias_total / (bias_total + bias))
|
prev_weight = (bias_total / (bias_total + bias))
|
||||||
new_weight = (bias / (bias_total + bias))
|
new_weight = (bias / (bias_total + bias))
|
||||||
# account for dims of tensors
|
# account for dims of tensors
|
||||||
idx_window = [slice(None)] * self.dim + [idx]
|
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||||
pos_window = [slice(None)] * self.dim + [pos]
|
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||||
# apply new values
|
# apply new values
|
||||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||||
biases_final[i][idx] = bias_total + bias
|
biases_final[i][idx] = bias_total + bias
|
||||||
@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||||
|
model_options = extra_args.get("model_options", None)
|
||||||
|
if model_options is None:
|
||||||
|
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||||
|
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||||
|
if handler is None:
|
||||||
|
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||||
|
if not handler.freenoise:
|
||||||
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
|
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||||
|
|
||||||
|
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||||
|
model.add_wrapper_with_key(
|
||||||
|
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||||
|
"ContextWindows_sampler_sample",
|
||||||
|
_sampler_sample_wrapper
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||||
total_dims = len(x_in.shape)
|
total_dims = len(x_in.shape)
|
||||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||||
@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
|||||||
for i in range(len(window)):
|
for i in range(len(window)):
|
||||||
# 2) add end_delta to each val to slide windows to end
|
# 2) add end_delta to each val to slide windows to end
|
||||||
window[i] = window[i] + end_delta
|
window[i] = window[i] + end_delta
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||||
|
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||||
|
logging.info("Context windows: Applying FreeNoise")
|
||||||
|
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||||
|
latent_video_length = noise.shape[dim]
|
||||||
|
delta = context_length - context_overlap
|
||||||
|
|
||||||
|
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||||
|
place_idx = start_idx + context_length
|
||||||
|
|
||||||
|
actual_delta = min(delta, latent_video_length - place_idx)
|
||||||
|
if actual_delta <= 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||||
|
|
||||||
|
source_slice = [slice(None)] * noise.ndim
|
||||||
|
source_slice[dim] = list_idx
|
||||||
|
target_slice = [slice(None)] * noise.ndim
|
||||||
|
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||||
|
|
||||||
|
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||||
|
|
||||||
|
return noise
|
||||||
|
|||||||
413
comfy/ldm/kandinsky5/model.py
Normal file
413
comfy/ldm/kandinsky5/model.py
Normal file
@ -0,0 +1,413 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
|
|
||||||
|
def attention(q, k, v, heads, transformer_options={}):
|
||||||
|
return optimized_attention(
|
||||||
|
q.transpose(1, 2),
|
||||||
|
k.transpose(1, 2),
|
||||||
|
v.transpose(1, 2),
|
||||||
|
heads=heads,
|
||||||
|
skip_reshape=True,
|
||||||
|
transformer_options=transformer_options
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||||
|
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||||
|
|
||||||
|
def apply_gate_sum(x, out, gate):
|
||||||
|
return torch.addcmul(x, gate, out)
|
||||||
|
|
||||||
|
def get_shift_scale_gate(params):
|
||||||
|
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||||
|
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||||
|
|
||||||
|
def get_freqs(dim, max_period=10000.0):
|
||||||
|
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeEmbeddings(nn.Module):
|
||||||
|
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
assert model_dim % 2 == 0
|
||||||
|
self.model_dim = model_dim
|
||||||
|
self.max_period = max_period
|
||||||
|
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.activation = nn.SiLU()
|
||||||
|
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, timestep, dtype):
|
||||||
|
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||||
|
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||||
|
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||||
|
return time_embed
|
||||||
|
|
||||||
|
|
||||||
|
class TextEmbeddings(nn.Module):
|
||||||
|
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, text_embed):
|
||||||
|
text_embed = self.in_layer(text_embed)
|
||||||
|
return self.norm(text_embed).type_as(text_embed)
|
||||||
|
|
||||||
|
|
||||||
|
class VisualEmbeddings(nn.Module):
|
||||||
|
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||||
|
B, T, H, W, dim = x.shape
|
||||||
|
pt, ph, pw = self.patch_size
|
||||||
|
|
||||||
|
x = x.view(
|
||||||
|
B,
|
||||||
|
T // pt, pt,
|
||||||
|
H // ph, ph,
|
||||||
|
W // pw, pw,
|
||||||
|
dim,
|
||||||
|
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||||
|
|
||||||
|
return self.in_layer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Modulation(nn.Module):
|
||||||
|
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
self.activation = nn.SiLU()
|
||||||
|
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.out_layer(self.activation(x))
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
assert num_channels % head_dim == 0
|
||||||
|
self.num_heads = num_channels // head_dim
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.num_chunks = 2
|
||||||
|
|
||||||
|
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||||
|
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
return apply_rope1(norm_fn(result), freqs)
|
||||||
|
|
||||||
|
def _forward(self, x, freqs, transformer_options={}):
|
||||||
|
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||||
|
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||||
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||||
|
def process_chunks(proj_fn, norm_fn):
|
||||||
|
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||||
|
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||||
|
chunks = []
|
||||||
|
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||||
|
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||||
|
return torch.cat(chunks, dim=1)
|
||||||
|
|
||||||
|
q = process_chunks(self.to_query, self.query_norm)
|
||||||
|
k = process_chunks(self.to_key, self.key_norm)
|
||||||
|
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
def forward(self, x, freqs, transformer_options={}):
|
||||||
|
if x.shape[1] > 8192:
|
||||||
|
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||||
|
else:
|
||||||
|
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
|
||||||
|
class CrossAttention(SelfAttention):
|
||||||
|
def get_qkv(self, x, context):
|
||||||
|
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||||
|
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||||
|
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward(self, x, context, transformer_options={}):
|
||||||
|
q, k, v = self.get_qkv(x, context)
|
||||||
|
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||||
|
return self.out_layer(out)
|
||||||
|
|
||||||
|
|
||||||
|
class FeedForward(nn.Module):
|
||||||
|
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.activation = nn.GELU()
|
||||||
|
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.num_chunks = 4
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
return self.out_layer(self.activation(self.in_layer(x)))
|
||||||
|
|
||||||
|
def _forward_chunked(self, x):
|
||||||
|
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||||
|
output_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
output_chunks.append(self._forward(chunk))
|
||||||
|
return torch.cat(output_chunks, dim=1)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if x.shape[1] > 8192:
|
||||||
|
return self._forward_chunked(x)
|
||||||
|
else:
|
||||||
|
return self._forward(x)
|
||||||
|
|
||||||
|
|
||||||
|
class OutLayer(nn.Module):
|
||||||
|
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
|
||||||
|
def forward(self, visual_embed, time_embed):
|
||||||
|
B, T, H, W, _ = visual_embed.shape
|
||||||
|
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||||
|
scale = scale[:, None, None, None, :]
|
||||||
|
shift = shift[:, None, None, None, :]
|
||||||
|
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||||
|
x = self.out_layer(visual_embed)
|
||||||
|
|
||||||
|
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||||
|
x = x.view(
|
||||||
|
B, T, H, W,
|
||||||
|
out_dim,
|
||||||
|
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||||
|
)
|
||||||
|
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderBlock(nn.Module):
|
||||||
|
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
|
||||||
|
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||||
|
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||||
|
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||||
|
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||||
|
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||||
|
x = apply_gate_sum(x, out, gate)
|
||||||
|
|
||||||
|
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||||
|
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||||
|
out = self.feed_forward(out)
|
||||||
|
x = apply_gate_sum(x, out, gate)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderBlock(nn.Module):
|
||||||
|
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||||
|
super().__init__()
|
||||||
|
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
operations = operation_settings.get("operations")
|
||||||
|
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||||
|
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||||
|
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||||
|
# self attention
|
||||||
|
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||||
|
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||||
|
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||||
|
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||||
|
# cross attention
|
||||||
|
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||||
|
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||||
|
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||||
|
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||||
|
# feed forward
|
||||||
|
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||||
|
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||||
|
visual_out = self.feed_forward(visual_out)
|
||||||
|
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||||
|
return visual_embed
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||||
|
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||||
|
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||||
|
dtype=None, device=None, operations=None, **kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
head_dim = sum(axes_dims)
|
||||||
|
self.rope_scale_factor = rope_scale_factor
|
||||||
|
self.in_visual_dim = in_visual_dim
|
||||||
|
self.model_dim = model_dim
|
||||||
|
self.patch_size = patch_size
|
||||||
|
self.visual_embed_dim = visual_embed_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||||
|
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||||
|
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||||
|
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
self.text_transformer_blocks = nn.ModuleList(
|
||||||
|
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.visual_transformer_blocks = nn.ModuleList(
|
||||||
|
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||||
|
|
||||||
|
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||||
|
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||||
|
|
||||||
|
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||||
|
steps = seq_len if steps is None else steps
|
||||||
|
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||||
|
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||||
|
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||||
|
|
||||||
|
patch_size = self.patch_size
|
||||||
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||||
|
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||||
|
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||||
|
|
||||||
|
if steps_t is None:
|
||||||
|
steps_t = t_len
|
||||||
|
if steps_h is None:
|
||||||
|
steps_h = h_len
|
||||||
|
if steps_w is None:
|
||||||
|
steps_w = w_len
|
||||||
|
|
||||||
|
h_start = 0
|
||||||
|
w_start = 0
|
||||||
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
if rope_options is not None:
|
||||||
|
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||||
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||||
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||||
|
|
||||||
|
t_start += rope_options.get("shift_t", 0.0)
|
||||||
|
h_start += rope_options.get("shift_y", 0.0)
|
||||||
|
w_start += rope_options.get("shift_x", 0.0)
|
||||||
|
else:
|
||||||
|
rope_scale_factor = self.rope_scale_factor
|
||||||
|
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||||
|
if h * w >= 14080:
|
||||||
|
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||||
|
|
||||||
|
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||||
|
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||||
|
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||||
|
|
||||||
|
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||||
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||||
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||||
|
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||||
|
|
||||||
|
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||||
|
return freqs
|
||||||
|
|
||||||
|
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||||
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
context = self.text_embeddings(context)
|
||||||
|
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||||
|
|
||||||
|
for block in self.text_transformer_blocks:
|
||||||
|
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
visual_embed = self.visual_embeddings(x)
|
||||||
|
visual_shape = visual_embed.shape[:-1]
|
||||||
|
visual_embed = visual_embed.flatten(1, -2)
|
||||||
|
|
||||||
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
|
for i, block in enumerate(self.visual_transformer_blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
|
if ("double_block", i) in blocks_replace:
|
||||||
|
def block_wrap(args):
|
||||||
|
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||||
|
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||||
|
else:
|
||||||
|
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||||
|
return self.out_layer(visual_embed, time_embed)
|
||||||
|
|
||||||
|
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||||
|
original_dims = x.ndim
|
||||||
|
if original_dims == 4:
|
||||||
|
x = x.unsqueeze(2)
|
||||||
|
bs, c, t_len, h, w = x.shape
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
|
if time_dim_replace is not None:
|
||||||
|
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||||
|
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||||
|
|
||||||
|
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||||
|
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||||
|
if original_dims == 4:
|
||||||
|
out = out.squeeze(2)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||||
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
|
self._forward,
|
||||||
|
self,
|
||||||
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
|
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||||
@ -377,6 +377,7 @@ class NextDiT(nn.Module):
|
|||||||
z_image_modulation=False,
|
z_image_modulation=False,
|
||||||
time_scale=1.0,
|
time_scale=1.0,
|
||||||
pad_tokens_multiple=None,
|
pad_tokens_multiple=None,
|
||||||
|
clip_text_dim=None,
|
||||||
image_model=None,
|
image_model=None,
|
||||||
device=None,
|
device=None,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -447,6 +448,31 @@ class NextDiT(nn.Module):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.clip_text_pooled_proj = None
|
||||||
|
|
||||||
|
if clip_text_dim is not None:
|
||||||
|
self.clip_text_dim = clip_text_dim
|
||||||
|
self.clip_text_pooled_proj = nn.Sequential(
|
||||||
|
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
clip_text_dim,
|
||||||
|
clip_text_dim,
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.time_text_embed = nn.Sequential(
|
||||||
|
nn.SiLU(),
|
||||||
|
operation_settings.get("operations").Linear(
|
||||||
|
min(dim, 1024) + clip_text_dim,
|
||||||
|
min(dim, 1024),
|
||||||
|
bias=True,
|
||||||
|
device=operation_settings.get("device"),
|
||||||
|
dtype=operation_settings.get("dtype"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[
|
[
|
||||||
JointTransformerBlock(
|
JointTransformerBlock(
|
||||||
@ -585,6 +611,15 @@ class NextDiT(nn.Module):
|
|||||||
|
|
||||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||||
|
|
||||||
|
if self.clip_text_pooled_proj is not None:
|
||||||
|
pooled = kwargs.get("clip_text_pooled", None)
|
||||||
|
if pooled is not None:
|
||||||
|
pooled = self.clip_text_pooled_proj(pooled)
|
||||||
|
else:
|
||||||
|
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||||
|
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
x_is_tensor = isinstance(x, torch.Tensor)
|
x_is_tensor = isinstance(x, torch.Tensor)
|
||||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||||
|
|||||||
@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}):
|
|||||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||||
|
|
||||||
|
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||||
|
for k in sdk:
|
||||||
|
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||||
|
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||||
|
key_map["{}".format(key_lora)] = k
|
||||||
|
key_map["transformer.{}".format(key_lora)] = k
|
||||||
|
|
||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
|
|||||||
import comfy.ldm.ace.model
|
import comfy.ldm.ace.model
|
||||||
import comfy.ldm.omnigen.omnigen2
|
import comfy.ldm.omnigen.omnigen2
|
||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
|
import comfy.ldm.kandinsky5.model
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
@ -1109,6 +1110,10 @@ class Lumina2(BaseModel):
|
|||||||
if 'num_tokens' not in out:
|
if 'num_tokens' not in out:
|
||||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||||
|
|
||||||
|
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
||||||
|
if clip_text_pooled is not None:
|
||||||
|
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class WAN21(BaseModel):
|
class WAN21(BaseModel):
|
||||||
@ -1630,3 +1635,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
|||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class Kandinsky5(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
|
||||||
|
|
||||||
|
def encode_adm(self, **kwargs):
|
||||||
|
return kwargs["pooled_output"]
|
||||||
|
|
||||||
|
def concat_cond(self, **kwargs):
|
||||||
|
noise = kwargs.get("noise", None)
|
||||||
|
device = kwargs["device"]
|
||||||
|
image = torch.zeros_like(noise)
|
||||||
|
|
||||||
|
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||||
|
if mask is None:
|
||||||
|
mask = torch.zeros_like(noise)[:, :1]
|
||||||
|
else:
|
||||||
|
mask = 1.0 - mask
|
||||||
|
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||||
|
if mask.shape[-3] < noise.shape[-3]:
|
||||||
|
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||||
|
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||||
|
|
||||||
|
return torch.cat((image, mask), dim=1)
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
attention_mask = kwargs.get("attention_mask", None)
|
||||||
|
if attention_mask is not None:
|
||||||
|
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||||
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
|
if cross_attn is not None:
|
||||||
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||||
|
if time_dim_replace is not None:
|
||||||
|
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class Kandinsky5Image(Kandinsky5):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device=device)
|
||||||
|
|
||||||
|
def concat_cond(self, **kwargs):
|
||||||
|
return None
|
||||||
|
|||||||
@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["axes_lens"] = [300, 512, 512]
|
dit_config["axes_lens"] = [300, 512, 512]
|
||||||
dit_config["rope_theta"] = 10000.0
|
dit_config["rope_theta"] = 10000.0
|
||||||
dit_config["ffn_dim_multiplier"] = 4.0
|
dit_config["ffn_dim_multiplier"] = 4.0
|
||||||
|
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||||
|
if ctd_weight is not None:
|
||||||
|
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||||
elif dit_config["dim"] == 3840: # Z image
|
elif dit_config["dim"] == 3840: # Z image
|
||||||
dit_config["n_heads"] = 30
|
dit_config["n_heads"] = 30
|
||||||
dit_config["n_kv_heads"] = 30
|
dit_config["n_kv_heads"] = 30
|
||||||
@ -611,6 +614,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||||
|
dit_config = {}
|
||||||
|
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||||
|
dit_config["model_dim"] = model_dim
|
||||||
|
if model_dim in [4096, 2560]: # pro video and lite image
|
||||||
|
dit_config["axes_dims"] = (32, 48, 48)
|
||||||
|
if model_dim == 2560: # lite image
|
||||||
|
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||||
|
elif model_dim == 1792: # lite video
|
||||||
|
dit_config["axes_dims"] = (16, 24, 24)
|
||||||
|
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||||
|
dit_config["image_model"] = "kandinsky5"
|
||||||
|
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||||
|
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||||
|
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
|
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -1492,6 +1492,20 @@ def extended_fp16_support():
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
LORA_COMPUTE_DTYPES = {}
|
||||||
|
def lora_compute_dtype(device):
|
||||||
|
dtype = LORA_COMPUTE_DTYPES.get(device, None)
|
||||||
|
if dtype is not None:
|
||||||
|
return dtype
|
||||||
|
|
||||||
|
if should_use_fp16(device):
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
|
||||||
|
LORA_COMPUTE_DTYPES[device] = dtype
|
||||||
|
return dtype
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
|
|||||||
@ -614,10 +614,11 @@ class ModelPatcher:
|
|||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||||
|
|
||||||
|
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||||
else:
|
else:
|
||||||
temp_weight = weight.to(torch.float32, copy=True)
|
temp_weight = weight.to(temp_dtype, copy=True)
|
||||||
if convert_func is not None:
|
if convert_func is not None:
|
||||||
temp_weight = convert_func(temp_weight, inplace=True)
|
temp_weight = convert_func(temp_weight, inplace=True)
|
||||||
|
|
||||||
@ -761,6 +762,8 @@ class ModelPatcher:
|
|||||||
key = "{}.{}".format(n, param)
|
key = "{}.{}".format(n, param)
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
self.patch_weight_to_device(key, device_to=device_to)
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
if comfy.model_management.is_device_cuda(device_to):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
m.comfy_patched_weights = True
|
m.comfy_patched_weights = True
|
||||||
|
|||||||
@ -552,6 +552,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
|
|
||||||
weight_scale_key = f"{prefix}weight_scale"
|
weight_scale_key = f"{prefix}weight_scale"
|
||||||
scale = state_dict.pop(weight_scale_key, None)
|
scale = state_dict.pop(weight_scale_key, None)
|
||||||
|
if scale is not None:
|
||||||
|
scale = scale.to(device)
|
||||||
layout_params = {
|
layout_params = {
|
||||||
'scale': scale,
|
'scale': scale,
|
||||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||||
|
|||||||
55
comfy/sd.py
55
comfy/sd.py
@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image
|
|||||||
import comfy.text_encoders.hunyuan_image
|
import comfy.text_encoders.hunyuan_image
|
||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
import comfy.text_encoders.ovis
|
import comfy.text_encoders.ovis
|
||||||
|
import comfy.text_encoders.kandinsky5
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -98,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
|||||||
|
|
||||||
|
|
||||||
class CLIP:
|
class CLIP:
|
||||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||||
if no_init:
|
if no_init:
|
||||||
return
|
return
|
||||||
params = target.params.copy()
|
params = target.params.copy()
|
||||||
@ -129,6 +130,27 @@ class CLIP:
|
|||||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||||
self.patcher.is_clip = True
|
self.patcher.is_clip = True
|
||||||
self.apply_hooks_to_conds = None
|
self.apply_hooks_to_conds = None
|
||||||
|
if len(state_dict) > 0:
|
||||||
|
if isinstance(state_dict, list):
|
||||||
|
for c in state_dict:
|
||||||
|
m, u = self.load_sd(c)
|
||||||
|
if len(m) > 0:
|
||||||
|
logging.warning("clip missing: {}".format(m))
|
||||||
|
|
||||||
|
if len(u) > 0:
|
||||||
|
logging.debug("clip unexpected: {}".format(u))
|
||||||
|
else:
|
||||||
|
m, u = self.load_sd(state_dict, full_model=True)
|
||||||
|
if len(m) > 0:
|
||||||
|
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||||
|
if len(m_filter) > 0:
|
||||||
|
logging.warning("clip missing: {}".format(m))
|
||||||
|
else:
|
||||||
|
logging.debug("clip missing: {}".format(m))
|
||||||
|
|
||||||
|
if len(u) > 0:
|
||||||
|
logging.debug("clip unexpected {}:".format(u))
|
||||||
|
|
||||||
if params['device'] == load_device:
|
if params['device'] == load_device:
|
||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
@ -745,6 +767,8 @@ class VAE:
|
|||||||
self.throw_exception_if_invalid()
|
self.throw_exception_if_invalid()
|
||||||
pixel_samples = None
|
pixel_samples = None
|
||||||
do_tile = False
|
do_tile = False
|
||||||
|
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||||
|
samples_in = samples_in[:, :, 0]
|
||||||
try:
|
try:
|
||||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||||
@ -962,6 +986,8 @@ class CLIPType(Enum):
|
|||||||
HUNYUAN_IMAGE = 19
|
HUNYUAN_IMAGE = 19
|
||||||
HUNYUAN_VIDEO_15 = 20
|
HUNYUAN_VIDEO_15 = 20
|
||||||
OVIS = 21
|
OVIS = 21
|
||||||
|
KANDINSKY5 = 22
|
||||||
|
KANDINSKY5_IMAGE = 23
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@ -1210,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||||
|
elif clip_type == CLIPType.KANDINSKY5:
|
||||||
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
|
||||||
|
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||||
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||||
else:
|
else:
|
||||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||||
@ -1225,14 +1257,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
parameters += comfy.utils.calculate_parameters(c)
|
parameters += comfy.utils.calculate_parameters(c)
|
||||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||||
|
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
||||||
for c in clip_data:
|
|
||||||
m, u = clip.load_sd(c)
|
|
||||||
if len(m) > 0:
|
|
||||||
logging.warning("clip missing: {}".format(m))
|
|
||||||
|
|
||||||
if len(u) > 0:
|
|
||||||
logging.debug("clip unexpected: {}".format(u))
|
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
def load_gligen(ckpt_path):
|
def load_gligen(ckpt_path):
|
||||||
@ -1362,17 +1387,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
clip_sd = model_config.process_clip_state_dict(sd)
|
clip_sd = model_config.process_clip_state_dict(sd)
|
||||||
if len(clip_sd) > 0:
|
if len(clip_sd) > 0:
|
||||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
||||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
|
||||||
if len(m) > 0:
|
|
||||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
|
||||||
if len(m_filter) > 0:
|
|
||||||
logging.warning("clip missing: {}".format(m))
|
|
||||||
else:
|
|
||||||
logging.debug("clip missing: {}".format(m))
|
|
||||||
|
|
||||||
if len(u) > 0:
|
|
||||||
logging.debug("clip unexpected {}:".format(u))
|
|
||||||
else:
|
else:
|
||||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||||
|
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import comfy.text_encoders.ace
|
|||||||
import comfy.text_encoders.omnigen2
|
import comfy.text_encoders.omnigen2
|
||||||
import comfy.text_encoders.qwen_image
|
import comfy.text_encoders.qwen_image
|
||||||
import comfy.text_encoders.hunyuan_image
|
import comfy.text_encoders.hunyuan_image
|
||||||
|
import comfy.text_encoders.kandinsky5
|
||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
|||||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
|
|
||||||
|
|
||||||
|
class Kandinsky5(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "kandinsky5",
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 10.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.HunyuanVideo
|
||||||
|
|
||||||
|
memory_usage_factor = 1.1 #TODO
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.Kandinsky5(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5Image(Kandinsky5):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "kandinsky5",
|
||||||
|
"model_dim": 2560,
|
||||||
|
"visual_embed_dim": 64,
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 3.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
latent_format = latent_formats.Flux
|
||||||
|
memory_usage_factor = 1.1 #TODO
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.Kandinsky5Image(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import logging
|
||||||
from . import model_base
|
from . import model_base
|
||||||
from . import utils
|
from . import utils
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -117,3 +118,7 @@ class BASE:
|
|||||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||||
self.unet_config['dtype'] = dtype
|
self.unet_config['dtype'] = dtype
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
|
||||||
|
return None
|
||||||
|
|||||||
68
comfy/text_encoders/kandinsky5.py
Normal file
68
comfy/text_encoders/kandinsky5.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from comfy import sd1_clip
|
||||||
|
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||||
|
from .llama import Qwen25_7BVLI
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5Tokenizer(QwenImageTokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||||
|
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||||
|
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
|
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||||
|
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||||
|
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
|
||||||
|
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5TEModel(QwenImageTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||||
|
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs):
|
||||||
|
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
|
||||||
|
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
|
||||||
|
|
||||||
|
return cond, l_pooled, extra
|
||||||
|
|
||||||
|
def set_clip_options(self, options):
|
||||||
|
super().set_clip_options(options)
|
||||||
|
self.clip_l.set_clip_options(options)
|
||||||
|
|
||||||
|
def reset_clip_options(self):
|
||||||
|
super().reset_clip_options()
|
||||||
|
self.clip_l.reset_clip_options()
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||||
|
return self.clip_l.load_sd(sd)
|
||||||
|
else:
|
||||||
|
return super().load_sd(sd)
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class Kandinsky5TEModel_(Kandinsky5TEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
return Kandinsky5TEModel_
|
||||||
@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
|
|||||||
'''Used by WAN Camera.'''
|
'''Used by WAN Camera.'''
|
||||||
time_dim_concat: NotRequired[torch.Tensor]
|
time_dim_concat: NotRequired[torch.Tensor]
|
||||||
'''Used by WAN Phantom Subject.'''
|
'''Used by WAN Phantom Subject.'''
|
||||||
|
time_dim_replace: NotRequired[torch.Tensor]
|
||||||
|
'''Used by Kandinsky5 I2V.'''
|
||||||
|
|
||||||
CondList = list[tuple[torch.Tensor, PooledDict]]
|
CondList = list[tuple[torch.Tensor, PooledDict]]
|
||||||
Type = CondList
|
Type = CondList
|
||||||
|
|||||||
@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
|||||||
IO.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode):
|
|||||||
step=0.01,
|
step=0.01,
|
||||||
tooltip="Duration of the empty audio clip in seconds",
|
tooltip="Duration of the empty audio clip in seconds",
|
||||||
),
|
),
|
||||||
IO.Float.Input(
|
IO.Int.Input(
|
||||||
"sample_rate",
|
"sample_rate",
|
||||||
default=44100,
|
default=44100,
|
||||||
tooltip="Sample rate of the empty audio clip.",
|
tooltip="Sample rate of the empty audio clip.",
|
||||||
|
min=1,
|
||||||
|
max=192000,
|
||||||
),
|
),
|
||||||
IO.Float.Input(
|
IO.Int.Input(
|
||||||
"channels",
|
"channels",
|
||||||
default=2,
|
default=2,
|
||||||
min=1,
|
min=1,
|
||||||
|
|||||||
@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
|
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
|
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||||
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||||
@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
context_overlap=context_overlap,
|
context_overlap=context_overlap,
|
||||||
context_stride=context_stride,
|
context_stride=context_stride,
|
||||||
closed_loop=closed_loop,
|
closed_loop=closed_loop,
|
||||||
dim=dim)
|
dim=dim,
|
||||||
|
freenoise=freenoise,
|
||||||
|
cond_retain_index_list=cond_retain_index_list,
|
||||||
|
split_conds_to_windows=split_conds_to_windows
|
||||||
|
)
|
||||||
# make memory usage calculation only take into account the context window latents
|
# make memory usage calculation only take into account the context window latents
|
||||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||||
|
if freenoise: # no other use for this wrapper at this time
|
||||||
|
comfy.context_windows.create_sampler_sample_wrapper(model)
|
||||||
return io.NodeOutput(model)
|
return io.NodeOutput(model)
|
||||||
|
|
||||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||||
@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
|||||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
|
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
|
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
]
|
]
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||||
|
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||||
|
|
||||||
|
|
||||||
class ContextWindowsExtension(ComfyExtension):
|
class ContextWindowsExtension(ComfyExtension):
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
|
||||||
def Fourier_filter(x, threshold, scale):
|
def Fourier_filter(x, threshold, scale):
|
||||||
# FFT
|
# FFT
|
||||||
@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
|
|||||||
return x_filtered.to(x.dtype)
|
return x_filtered.to(x.dtype)
|
||||||
|
|
||||||
|
|
||||||
class FreeU:
|
class FreeU(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return IO.Schema(
|
||||||
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="FreeU",
|
||||||
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="model_patches/unet",
|
||||||
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
IO.Model.Input("model"),
|
||||||
}}
|
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||||
|
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "model_patches/unet"
|
@classmethod
|
||||||
|
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
||||||
def patch(self, model, b1, b2, s1, s2):
|
|
||||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||||
on_cpu_devices = {}
|
on_cpu_devices = {}
|
||||||
@ -59,23 +66,31 @@ class FreeU:
|
|||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_output_block_patch(output_block_patch)
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
return (m, )
|
return IO.NodeOutput(m)
|
||||||
|
|
||||||
class FreeU_V2:
|
patch = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class FreeU_V2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return IO.Schema(
|
||||||
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
|
node_id="FreeU_V2",
|
||||||
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
|
category="model_patches/unet",
|
||||||
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
inputs=[
|
||||||
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
IO.Model.Input("model"),
|
||||||
}}
|
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
|
||||||
RETURN_TYPES = ("MODEL",)
|
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
|
||||||
FUNCTION = "patch"
|
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||||
|
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "model_patches/unet"
|
@classmethod
|
||||||
|
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
||||||
def patch(self, model, b1, b2, s1, s2):
|
|
||||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||||
on_cpu_devices = {}
|
on_cpu_devices = {}
|
||||||
@ -105,9 +120,19 @@ class FreeU_V2:
|
|||||||
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
m.set_model_output_block_patch(output_block_patch)
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
return (m, )
|
return IO.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
patch = execute # TODO: remove
|
||||||
"FreeU": FreeU,
|
|
||||||
"FreeU_V2": FreeU_V2,
|
|
||||||
}
|
class FreelunchExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
FreeU,
|
||||||
|
FreeU_V2,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> FreelunchExtension:
|
||||||
|
return FreelunchExtension()
|
||||||
|
|||||||
136
comfy_extras/nodes_kandinsky5.py
Normal file
136
comfy_extras/nodes_kandinsky5.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
import nodes
|
||||||
|
import node_helpers
|
||||||
|
import torch
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.utils
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5ImageToVideo(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="Kandinsky5ImageToVideo",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("positive"),
|
||||||
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
|
||||||
|
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||||
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
|
cond_latent_out = {}
|
||||||
|
if start_image is not None:
|
||||||
|
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||||
|
encoded = vae.encode(start_image[:, :, :, :3])
|
||||||
|
cond_latent_out["samples"] = encoded
|
||||||
|
|
||||||
|
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||||
|
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||||
|
|
||||||
|
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||||
|
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||||
|
|
||||||
|
out_latent = {}
|
||||||
|
out_latent["samples"] = latent
|
||||||
|
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
|
||||||
|
|
||||||
|
|
||||||
|
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
|
||||||
|
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
|
||||||
|
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
|
||||||
|
|
||||||
|
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
|
||||||
|
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
|
||||||
|
|
||||||
|
# normalization
|
||||||
|
normalized = (source - source_mean) / (source_std + 1e-8)
|
||||||
|
normalized = normalized * reference_std + reference_mean
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
class NormalizeVideoLatentStart(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="NormalizeVideoLatentStart",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("latent"),
|
||||||
|
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
|
||||||
|
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
|
||||||
|
if latent["samples"].shape[2] <= 1:
|
||||||
|
return io.NodeOutput(latent)
|
||||||
|
s = latent.copy()
|
||||||
|
samples = latent["samples"].clone()
|
||||||
|
|
||||||
|
first_frames = samples[:, :, :start_frame_count]
|
||||||
|
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
|
||||||
|
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
|
||||||
|
|
||||||
|
samples[:, :, :start_frame_count] = normalized_first_frames
|
||||||
|
s["samples"] = samples
|
||||||
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPTextEncodeKandinsky5(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="CLIPTextEncodeKandinsky5",
|
||||||
|
category="advanced/conditioning/kandinsky5",
|
||||||
|
inputs=[
|
||||||
|
io.Clip.Input("clip"),
|
||||||
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
|
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
|
||||||
|
tokens = clip.tokenize(clip_l)
|
||||||
|
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
|
||||||
|
|
||||||
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
|
||||||
|
|
||||||
|
class Kandinsky5Extension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
Kandinsky5ImageToVideo,
|
||||||
|
NormalizeVideoLatentStart,
|
||||||
|
CLIPTextEncodeKandinsky5,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> Kandinsky5Extension:
|
||||||
|
return Kandinsky5Extension()
|
||||||
@ -4,7 +4,7 @@ import torch
|
|||||||
import nodes
|
import nodes
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import logging
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||||
if latent.shape[1:] != target_shape[1:]:
|
if latent.shape[1:] != target_shape[1:]:
|
||||||
@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
|
|||||||
return luminance * sharpened
|
return luminance * sharpened
|
||||||
return io.NodeOutput(sharpen)
|
return io.NodeOutput(sharpen)
|
||||||
|
|
||||||
|
class ReplaceVideoLatentFrames(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ReplaceVideoLatentFrames",
|
||||||
|
category="latent/batch",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
|
||||||
|
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
|
||||||
|
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, destination, index, source=None) -> io.NodeOutput:
|
||||||
|
if source is None:
|
||||||
|
return io.NodeOutput(destination)
|
||||||
|
dest_frames = destination["samples"].shape[2]
|
||||||
|
source_frames = source["samples"].shape[2]
|
||||||
|
if index < 0:
|
||||||
|
index = dest_frames + index
|
||||||
|
if index > dest_frames:
|
||||||
|
logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
|
||||||
|
return io.NodeOutput(destination)
|
||||||
|
if index + source_frames > dest_frames:
|
||||||
|
logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
|
||||||
|
return io.NodeOutput(destination)
|
||||||
|
s = source.copy()
|
||||||
|
s_source = source["samples"]
|
||||||
|
s_destination = destination["samples"].clone()
|
||||||
|
s_destination[:, :, index:index + s_source.shape[2]] = s_source
|
||||||
|
s["samples"] = s_destination
|
||||||
|
return io.NodeOutput(s)
|
||||||
|
|
||||||
class LatentExtension(ComfyExtension):
|
class LatentExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
|
|||||||
LatentApplyOperationCFG,
|
LatentApplyOperationCFG,
|
||||||
LatentOperationTonemapReinhard,
|
LatentOperationTonemapReinhard,
|
||||||
LatentOperationSharpen,
|
LatentOperationSharpen,
|
||||||
|
ReplaceVideoLatentFrames
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,11 +3,10 @@ import scipy.ndimage
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import node_helpers
|
import node_helpers
|
||||||
import folder_paths
|
from typing_extensions import override
|
||||||
import random
|
from comfy_api.latest import ComfyExtension, IO, UI
|
||||||
|
|
||||||
import nodes
|
import nodes
|
||||||
from nodes import MAX_RESOLUTION
|
|
||||||
|
|
||||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||||
source = source.to(destination.device)
|
source = source.to(destination.device)
|
||||||
@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
|||||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
class LatentCompositeMasked:
|
class LatentCompositeMasked(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="LatentCompositeMasked",
|
||||||
"destination": ("LATENT",),
|
category="latent",
|
||||||
"source": ("LATENT",),
|
inputs=[
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
IO.Latent.Input("destination"),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
IO.Latent.Input("source"),
|
||||||
"resize_source": ("BOOLEAN", {"default": False}),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
},
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||||
"optional": {
|
IO.Boolean.Input("resize_source", default=False),
|
||||||
"mask": ("MASK",),
|
IO.Mask.Input("mask", optional=True),
|
||||||
}
|
],
|
||||||
}
|
outputs=[IO.Latent.Output()],
|
||||||
RETURN_TYPES = ("LATENT",)
|
)
|
||||||
FUNCTION = "composite"
|
|
||||||
|
|
||||||
CATEGORY = "latent"
|
@classmethod
|
||||||
|
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
|
||||||
output = destination.copy()
|
output = destination.copy()
|
||||||
destination = destination["samples"].clone()
|
destination = destination["samples"].clone()
|
||||||
source = source["samples"]
|
source = source["samples"]
|
||||||
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
|
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class ImageCompositeMasked:
|
composite = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageCompositeMasked(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageCompositeMasked",
|
||||||
"destination": ("IMAGE",),
|
category="image",
|
||||||
"source": ("IMAGE",),
|
inputs=[
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Image.Input("destination"),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Image.Input("source"),
|
||||||
"resize_source": ("BOOLEAN", {"default": False}),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
},
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"optional": {
|
IO.Boolean.Input("resize_source", default=False),
|
||||||
"mask": ("MASK",),
|
IO.Mask.Input("mask", optional=True),
|
||||||
}
|
],
|
||||||
}
|
outputs=[IO.Image.Output()],
|
||||||
RETURN_TYPES = ("IMAGE",)
|
)
|
||||||
FUNCTION = "composite"
|
|
||||||
|
|
||||||
CATEGORY = "image"
|
@classmethod
|
||||||
|
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
|
||||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||||
destination = destination.clone().movedim(-1, 1)
|
destination = destination.clone().movedim(-1, 1)
|
||||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class MaskToImage:
|
composite = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class MaskToImage(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="MaskToImage",
|
||||||
"mask": ("MASK",),
|
display_name="Convert Mask to Image",
|
||||||
}
|
category="mask",
|
||||||
}
|
inputs=[
|
||||||
|
IO.Mask.Input("mask"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("IMAGE",)
|
|
||||||
FUNCTION = "mask_to_image"
|
|
||||||
|
|
||||||
def mask_to_image(self, mask):
|
|
||||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||||
return (result,)
|
return IO.NodeOutput(result)
|
||||||
|
|
||||||
class ImageToMask:
|
mask_to_image = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageToMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageToMask",
|
||||||
"image": ("IMAGE",),
|
display_name="Convert Image to Mask",
|
||||||
"channel": (["red", "green", "blue", "alpha"],),
|
category="mask",
|
||||||
}
|
inputs=[
|
||||||
}
|
IO.Image.Input("image"),
|
||||||
|
IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, image, channel) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
FUNCTION = "image_to_mask"
|
|
||||||
|
|
||||||
def image_to_mask(self, image, channel):
|
|
||||||
channels = ["red", "green", "blue", "alpha"]
|
channels = ["red", "green", "blue", "alpha"]
|
||||||
mask = image[:, :, :, channels.index(channel)]
|
mask = image[:, :, :, channels.index(channel)]
|
||||||
return (mask,)
|
return IO.NodeOutput(mask)
|
||||||
|
|
||||||
class ImageColorToMask:
|
image_to_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ImageColorToMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ImageColorToMask",
|
||||||
"image": ("IMAGE",),
|
category="mask",
|
||||||
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
|
inputs=[
|
||||||
}
|
IO.Image.Input("image"),
|
||||||
}
|
IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, image, color) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
FUNCTION = "image_to_mask"
|
|
||||||
|
|
||||||
def image_to_mask(self, image, color):
|
|
||||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||||
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
||||||
mask = torch.where(temp == color, 1.0, 0).float()
|
mask = torch.where(temp == color, 1.0, 0).float()
|
||||||
return (mask,)
|
return IO.NodeOutput(mask)
|
||||||
|
|
||||||
class SolidMask:
|
image_to_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class SolidMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="SolidMask",
|
||||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="mask",
|
||||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
}
|
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, value, width, height) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "solid"
|
|
||||||
|
|
||||||
def solid(self, value, width, height):
|
|
||||||
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
||||||
return (out,)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class InvertMask:
|
solid = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class InvertMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="InvertMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
}
|
inputs=[
|
||||||
}
|
IO.Mask.Input("mask"),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "invert"
|
|
||||||
|
|
||||||
def invert(self, mask):
|
|
||||||
out = 1.0 - mask
|
out = 1.0 - mask
|
||||||
return (out,)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class CropMask:
|
invert = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class CropMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="CropMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("mask"),
|
||||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "crop"
|
|
||||||
|
|
||||||
def crop(self, mask, x, y, width, height):
|
|
||||||
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||||
out = mask[:, y:y + height, x:x + width]
|
out = mask[:, y:y + height, x:x + width]
|
||||||
return (out,)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
class MaskComposite:
|
crop = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class MaskComposite(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="MaskComposite",
|
||||||
"destination": ("MASK",),
|
category="mask",
|
||||||
"source": ("MASK",),
|
inputs=[
|
||||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("destination"),
|
||||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("source"),
|
||||||
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "combine"
|
|
||||||
|
|
||||||
def combine(self, destination, source, x, y, operation):
|
|
||||||
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
||||||
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
||||||
|
|
||||||
@ -267,28 +277,29 @@ class MaskComposite:
|
|||||||
|
|
||||||
output = torch.clamp(output, 0.0, 1.0)
|
output = torch.clamp(output, 0.0, 1.0)
|
||||||
|
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class FeatherMask:
|
combine = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class FeatherMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="FeatherMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
inputs=[
|
||||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Mask.Input("mask"),
|
||||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
}
|
IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "feather"
|
|
||||||
|
|
||||||
def feather(self, mask, left, top, right, bottom):
|
|
||||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||||
|
|
||||||
left = min(left, output.shape[-1])
|
left = min(left, output.shape[-1])
|
||||||
@ -312,26 +323,28 @@ class FeatherMask:
|
|||||||
feather_rate = (y + 1) / bottom
|
feather_rate = (y + 1) / bottom
|
||||||
output[:, -y, :] *= feather_rate
|
output[:, -y, :] *= feather_rate
|
||||||
|
|
||||||
return (output,)
|
return IO.NodeOutput(output)
|
||||||
|
|
||||||
class GrowMask:
|
feather = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class GrowMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="GrowMask",
|
||||||
"mask": ("MASK",),
|
display_name="Grow Mask",
|
||||||
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
|
category="mask",
|
||||||
"tapered_corners": ("BOOLEAN", {"default": True}),
|
inputs=[
|
||||||
},
|
IO.Mask.Input("mask"),
|
||||||
}
|
IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
IO.Boolean.Input("tapered_corners", default=True),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
|
|
||||||
FUNCTION = "expand_mask"
|
|
||||||
|
|
||||||
def expand_mask(self, mask, expand, tapered_corners):
|
|
||||||
c = 0 if tapered_corners else 1
|
c = 0 if tapered_corners else 1
|
||||||
kernel = np.array([[c, 1, c],
|
kernel = np.array([[c, 1, c],
|
||||||
[1, 1, 1],
|
[1, 1, 1],
|
||||||
@ -347,69 +360,74 @@ class GrowMask:
|
|||||||
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
||||||
output = torch.from_numpy(output)
|
output = torch.from_numpy(output)
|
||||||
out.append(output)
|
out.append(output)
|
||||||
return (torch.stack(out, dim=0),)
|
return IO.NodeOutput(torch.stack(out, dim=0))
|
||||||
|
|
||||||
class ThresholdMask:
|
expand_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="ThresholdMask",
|
||||||
"mask": ("MASK",),
|
category="mask",
|
||||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
inputs=[
|
||||||
}
|
IO.Mask.Input("mask"),
|
||||||
}
|
IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[IO.Mask.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask"
|
@classmethod
|
||||||
|
def execute(cls, mask, value) -> IO.NodeOutput:
|
||||||
RETURN_TYPES = ("MASK",)
|
|
||||||
FUNCTION = "image_to_mask"
|
|
||||||
|
|
||||||
def image_to_mask(self, mask, value):
|
|
||||||
mask = (mask > value).float()
|
mask = (mask > value).float()
|
||||||
return (mask,)
|
return IO.NodeOutput(mask)
|
||||||
|
|
||||||
|
image_to_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
# Mask Preview - original implement from
|
# Mask Preview - original implement from
|
||||||
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||||
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
||||||
class MaskPreview(nodes.SaveImage):
|
class MaskPreview(IO.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_temp_directory()
|
def define_schema(cls):
|
||||||
self.type = "temp"
|
return IO.Schema(
|
||||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
node_id="MaskPreview",
|
||||||
self.compress_level = 4
|
display_name="Preview Mask",
|
||||||
|
category="mask",
|
||||||
|
description="Saves the input images to your ComfyUI output directory.",
|
||||||
|
inputs=[
|
||||||
|
IO.Mask.Input("mask"),
|
||||||
|
],
|
||||||
|
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||||
return {
|
return IO.NodeOutput(ui=UI.PreviewMask(mask))
|
||||||
"required": {"mask": ("MASK",), },
|
|
||||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
|
||||||
}
|
|
||||||
|
|
||||||
FUNCTION = "execute"
|
|
||||||
CATEGORY = "mask"
|
|
||||||
|
|
||||||
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
|
||||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
|
||||||
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class MaskExtension(ComfyExtension):
|
||||||
"LatentCompositeMasked": LatentCompositeMasked,
|
@override
|
||||||
"ImageCompositeMasked": ImageCompositeMasked,
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
"MaskToImage": MaskToImage,
|
return [
|
||||||
"ImageToMask": ImageToMask,
|
LatentCompositeMasked,
|
||||||
"ImageColorToMask": ImageColorToMask,
|
ImageCompositeMasked,
|
||||||
"SolidMask": SolidMask,
|
MaskToImage,
|
||||||
"InvertMask": InvertMask,
|
ImageToMask,
|
||||||
"CropMask": CropMask,
|
ImageColorToMask,
|
||||||
"MaskComposite": MaskComposite,
|
SolidMask,
|
||||||
"FeatherMask": FeatherMask,
|
InvertMask,
|
||||||
"GrowMask": GrowMask,
|
CropMask,
|
||||||
"ThresholdMask": ThresholdMask,
|
MaskComposite,
|
||||||
"MaskPreview": MaskPreview
|
FeatherMask,
|
||||||
}
|
GrowMask,
|
||||||
|
ThresholdMask,
|
||||||
|
MaskPreview,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"ImageToMask": "Convert Image to Mask",
|
async def comfy_entrypoint() -> MaskExtension:
|
||||||
"MaskToImage": "Convert Mask to Image",
|
return MaskExtension()
|
||||||
}
|
|
||||||
|
|||||||
@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
|
|||||||
return io.NodeOutput(m)
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
# Sampling
|
|
||||||
"PatchModelAddDownscale": "",
|
|
||||||
}
|
|
||||||
|
|
||||||
class ModelDownscaleExtension(ComfyExtension):
|
class ModelDownscaleExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
|||||||
@ -63,18 +63,22 @@ def cuda_malloc_supported():
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
version = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch_spec = importlib.util.find_spec("torch")
|
||||||
|
for folder in torch_spec.submodule_search_locations:
|
||||||
|
ver_file = os.path.join(folder, "version.py")
|
||||||
|
if os.path.isfile(ver_file):
|
||||||
|
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
version = module.__version__
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
if not args.cuda_malloc:
|
if not args.cuda_malloc:
|
||||||
try:
|
try:
|
||||||
version = ""
|
|
||||||
torch_spec = importlib.util.find_spec("torch")
|
|
||||||
for folder in torch_spec.submodule_search_locations:
|
|
||||||
ver_file = os.path.join(folder, "version.py")
|
|
||||||
if os.path.isfile(ver_file):
|
|
||||||
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
|
||||||
module = importlib.util.module_from_spec(spec)
|
|
||||||
spec.loader.exec_module(module)
|
|
||||||
version = module.__version__
|
|
||||||
|
|
||||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||||
args.cuda_malloc = cuda_malloc_supported()
|
args.cuda_malloc = cuda_malloc_supported()
|
||||||
@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
|
|||||||
env_var += ",backend:cudaMallocAsync"
|
env_var += ",backend:cudaMallocAsync"
|
||||||
|
|
||||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||||
|
|
||||||
|
def get_torch_version_noimport():
|
||||||
|
return str(version)
|
||||||
|
|||||||
3
main.py
3
main.py
@ -167,6 +167,9 @@ if __name__ == "__main__":
|
|||||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||||
|
|
||||||
import cuda_malloc
|
import cuda_malloc
|
||||||
|
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
||||||
|
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
|
||||||
|
|
||||||
|
|
||||||
if 'torch' in sys.modules:
|
if 'torch' in sys.modules:
|
||||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.0.3b3
|
comfyui_manager==4.0.3b4
|
||||||
|
|||||||
3
nodes.py
3
nodes.py
@ -970,7 +970,7 @@ class DualCLIPLoader:
|
|||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
|
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_rope.py",
|
"nodes_rope.py",
|
||||||
"nodes_logic.py",
|
"nodes_logic.py",
|
||||||
"nodes_nop.py",
|
"nodes_nop.py",
|
||||||
|
"nodes_kandinsky5.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user