Merge branch 'master' into worksplit-multigpu

This commit is contained in:
Jedrzej Kosinski 2025-09-11 20:59:47 -07:00
commit efcd8280d6
75 changed files with 7128 additions and 1305 deletions

30
.github/workflows/test-execution.yml vendored Normal file
View File

@ -0,0 +1,30 @@
name: Execution Tests
on:
push:
branches: [ main, master ]
pull_request:
branches: [ main, master ]
jobs:
test:
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
runs-on: ${{ matrix.os }}
continue-on-error: true
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.12'
- name: Install requirements
run: |
python -m pip install --upgrade pip
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install -r tests-unit/requirements.txt
- name: Run Execution Tests
run: |
python -m pytest tests/execution -v --skip-timing-checks

View File

@ -143,8 +143,9 @@ class PerformanceFeature(enum.Enum):
Fp16Accumulation = "fp16_accumulation"
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module):
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
all_intermediate = None
if intermediate_output is not None:
if intermediate_output < 0:
if intermediate_output == "all":
all_intermediate = []
intermediate_output = None
elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
intermediate = None
@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
if all_intermediate is not None:
all_intermediate.append(x.unsqueeze(1).clone())
if all_intermediate is not None:
intermediate = torch.cat(all_intermediate, dim=1)
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):

View File

@ -50,7 +50,13 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@ -68,12 +74,18 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
if self.return_all_hidden_states:
all_hs = out[1].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = all_hs[:, -2]
outputs["all_hidden_states"] = all_hs
else:
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
outputs["mm_projected"] = out[3]
return outputs
@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
# Dinov2
elif 'encoder.layer.39.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
else:
return None

View File

@ -288,7 +288,10 @@ class ControlNet(ControlBase):
to_concat = []
for c in self.extra_concat_orig:
c = c.to(self.cond_hint.device)
c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
if c.ndim < self.cond_hint.ndim:
c = c.unsqueeze(2)
c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
@ -628,11 +631,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):
def load_controlnet_qwen_instantx(sd, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
extra_condition_channels = 0
concat_mask = False
if control_latent_channels == 68: #inpaint controlnet
extra_condition_channels = control_latent_channels - 64
concat_mask = True
control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.Wan21()
extra_conds = []
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
def convert_mistoline(sd):

View File

@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module):
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class Dinov2MLP(torch.nn.Module):
def __init__(self, hidden_size: int, dtype, device, operations):
super().__init__()
mlp_ratio = 4
hidden_features = int(hidden_size * mlp_ratio)
self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, hidden_size, bias = True, device=device, dtype=dtype)
def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
hidden_state = self.fc1(hidden_state)
hidden_state = torch.nn.functional.gelu(hidden_state)
hidden_state = self.fc2(hidden_state)
return hidden_state
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
if use_swiglu_ffn:
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
else:
self.mlp = Dinov2MLP(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module):
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module):
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
for i, layer in enumerate(self.layer):
x = layer(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module):
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):

View File

@ -0,0 +1,22 @@
{
"hidden_size": 1024,
"use_mask_token": true,
"patch_size": 14,
"image_size": 518,
"num_channels": 3,
"num_attention_heads": 16,
"initializer_range": 0.02,
"attention_probs_dropout_prob": 0.0,
"hidden_dropout_prob": 0.0,
"hidden_act": "gelu",
"mlp_ratio": 4,
"model_type": "dinov2",
"num_hidden_layers": 24,
"layer_norm_eps": 1e-6,
"qkv_bias": true,
"use_swiglu_ffn": false,
"layerscale_value": 1.0,
"drop_path_rate": 0.0,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4):
return sigmas
def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_1(h) in exponential integrator methods."""
return torch.expm1(h)
def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor:
"""Compute the result of h*phi_2(h) in exponential integrator methods."""
return (torch.expm1(h) - h) / h
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
"""Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling)
sigmas = offset_first_sigma_for_snr(sigmas, model_sampling)
fac = 1 / (2 * r)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r * h
fac = 1 / (2 * r)
sigma_s_1 = sigma_fn(lambda_s_1)
continue
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r)
sigma_s_1 = sigma_fn(lambda_s_1)
coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r < 1
noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt()
noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised
if inject_noise:
sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
denoised_d = (1 - fac) * denoised + fac * denoised_2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
# Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x
@torch.no_grad()
def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
"""SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3.
arXiv: https://arxiv.org/abs/2305.14267
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
inject_noise = eta > 0 and s_noise > 0
model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling')
@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised = model(x, sigmas[i] * s_in, **extra_args)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
x = denoised
else:
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = lambda_s + r_1 * h
lambda_s_2 = lambda_s + r_2 * h
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
continue
# alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t)
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1])
h = lambda_t - lambda_s
h_eta = h * (eta + 1)
lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1)
lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2)
sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2)
coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
if inject_noise:
# 0 < r_1 < r_2 < 1
noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt()
noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt()
noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
alpha_s_1 = sigma_s_1 * lambda_s_1.exp()
alpha_s_2 = sigma_s_2 * lambda_s_2.exp()
alpha_t = sigmas[i + 1] * lambda_t.exp()
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised
if inject_noise:
x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 1
x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised
if inject_noise:
sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1)
x_2 = x_2 + sde_noise * sigma_s_1 * s_noise
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
if inject_noise:
x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 2
a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta)
a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2
x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2)
if inject_noise:
segment_factor = (r_1 - r_2) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2)
x_3 = x_3 + sde_noise * sigma_s_2 * s_noise
denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
# Step 3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
if inject_noise:
x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
# Step 3
b3 = ei_h_phi_2(-h_eta) / r_2
b1 = ei_h_phi_1(-h_eta) - b3
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3)
if inject_noise:
segment_factor = (r_2 - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()
sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1])
x = x + sde_noise * sigmas[i + 1] * s_noise
return x

View File

@ -533,11 +533,89 @@ class Wan22(Wan21):
0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744
]).view(1, self.latent_channels, 1, 1, 1)
class HunyuanImage21(LatentFormat):
latent_channels = 64
latent_dimensions = 2
scale_factor = 0.75289
latent_rgb_factors = [
[-0.0154, -0.0397, -0.0521],
[ 0.0005, 0.0093, 0.0006],
[-0.0805, -0.0773, -0.0586],
[-0.0494, -0.0487, -0.0498],
[-0.0212, -0.0076, -0.0261],
[-0.0179, -0.0417, -0.0505],
[ 0.0158, 0.0310, 0.0239],
[ 0.0409, 0.0516, 0.0201],
[ 0.0350, 0.0553, 0.0036],
[-0.0447, -0.0327, -0.0479],
[-0.0038, -0.0221, -0.0365],
[-0.0423, -0.0718, -0.0654],
[ 0.0039, 0.0368, 0.0104],
[ 0.0655, 0.0217, 0.0122],
[ 0.0490, 0.1638, 0.2053],
[ 0.0932, 0.0829, 0.0650],
[-0.0186, -0.0209, -0.0135],
[-0.0080, -0.0076, -0.0148],
[-0.0284, -0.0201, 0.0011],
[-0.0642, -0.0294, -0.0777],
[-0.0035, 0.0076, -0.0140],
[ 0.0519, 0.0731, 0.0887],
[-0.0102, 0.0095, 0.0704],
[ 0.0068, 0.0218, -0.0023],
[-0.0726, -0.0486, -0.0519],
[ 0.0260, 0.0295, 0.0263],
[ 0.0250, 0.0333, 0.0341],
[ 0.0168, -0.0120, -0.0174],
[ 0.0226, 0.1037, 0.0114],
[ 0.2577, 0.1906, 0.1604],
[-0.0646, -0.0137, -0.0018],
[-0.0112, 0.0309, 0.0358],
[-0.0347, 0.0146, -0.0481],
[ 0.0234, 0.0179, 0.0201],
[ 0.0157, 0.0313, 0.0225],
[ 0.0423, 0.0675, 0.0524],
[-0.0031, 0.0027, -0.0255],
[ 0.0447, 0.0555, 0.0330],
[-0.0152, 0.0103, 0.0299],
[-0.0755, -0.0489, -0.0635],
[ 0.0853, 0.0788, 0.1017],
[-0.0272, -0.0294, -0.0471],
[ 0.0440, 0.0400, -0.0137],
[ 0.0335, 0.0317, -0.0036],
[-0.0344, -0.0621, -0.0984],
[-0.0127, -0.0630, -0.0620],
[-0.0648, 0.0360, 0.0924],
[-0.0781, -0.0801, -0.0409],
[ 0.0363, 0.0613, 0.0499],
[ 0.0238, 0.0034, 0.0041],
[-0.0135, 0.0258, 0.0310],
[ 0.0614, 0.1086, 0.0589],
[ 0.0428, 0.0350, 0.0205],
[ 0.0153, 0.0173, -0.0018],
[-0.0288, -0.0455, -0.0091],
[ 0.0344, 0.0109, -0.0157],
[-0.0205, -0.0247, -0.0187],
[ 0.0487, 0.0126, 0.0064],
[-0.0220, -0.0013, 0.0074],
[-0.0203, -0.0094, -0.0048],
[-0.0719, 0.0429, -0.0442],
[ 0.1042, 0.0497, 0.0356],
[-0.0659, -0.0578, -0.0280],
[-0.0060, -0.0322, -0.0234]]
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1
scale_factor = 0.9990943042622529
class Hunyuan3Dv2_1(LatentFormat):
scale_factor = 1.0039506158752403
latent_channels = 64
latent_dimensions = 1
class Hunyuan3Dv2mini(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module):
# Attention layers
if self.rotary_pos_emb is not None:
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device)
else:
rotary_pos_emb = None

View File

@ -106,6 +106,7 @@ class Flux(nn.Module):
if y is None:
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@ -117,9 +118,17 @@ class Flux(nn.Module):
if guidance is not None:
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
txt = self.txt_in(txt)
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
@ -233,12 +242,18 @@ class Flux(nn.Module):
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "offset") == "index"
ref_latents_method = kwargs.get("ref_latents_method", "offset")
for ref in ref_latents:
if index_ref_method:
if ref_latents_method == "index":
index += 1
h_offset = 0
w_offset = 0
elif ref_latents_method == "uxo":
index = 0
h_offset = h_len * patch_size + h
w_offset = w_len * patch_size + w
h += ref.shape[-2]
w += ref.shape[-1]
else:
index = 1
h_offset = 0

View File

@ -4,81 +4,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Tuple, List, Callable, Optional
import numpy as np
from einops import repeat, rearrange
import math
from tqdm import tqdm
from typing import Optional
import logging
import comfy.ops
ops = comfy.ops.disable_weight_init
def generate_dense_grid_points(
bbox_min: np.ndarray,
bbox_max: np.ndarray,
octree_resolution: int,
indexing: str = "ij",
):
length = bbox_max - bbox_min
num_cells = octree_resolution
def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True):
x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
[xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
xyz = np.stack((xs, ys, zs), axis=-1)
grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
# manually create the pointer vector
assert src.size(0) == batch.numel()
return xyz, grid_size, length
batch_size = int(batch.max()) + 1
deg = src.new_zeros(batch_size, dtype = torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
ptr_vec = deg.new_zeros(batch_size + 1)
torch.cumsum(deg, 0, out=ptr_vec[1:])
#return fps_sampling(src, ptr_vec, ratio)
sampled_indicies = []
for b in range(batch_size):
# start and the end of each batch
start, end = ptr_vec[b].item(), ptr_vec[b + 1].item()
# points from the point cloud
points = src[start:end]
num_points = points.size(0)
num_samples = max(1, math.ceil(num_points * sampling_ratio))
selected = torch.zeros(num_samples, device = src.device, dtype = torch.long)
distances = torch.full((num_points,), float("inf"), device = src.device)
# select a random start point
if start_random:
farthest = torch.randint(0, num_points, (1,), device = src.device)
else:
farthest = torch.tensor([0], device = src.device, dtype = torch.long)
for i in range(num_samples):
selected[i] = farthest
centroid = points[farthest].squeeze(0)
dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance
distances = torch.minimum(distances, dist)
farthest = torch.argmax(distances)
sampled_indicies.append(torch.arange(start, end)[selected])
return torch.cat(sampled_indicies, dim = 0)
class PointCrossAttention(nn.Module):
def __init__(self,
num_latents: int,
downsample_ratio: float,
pc_size: int,
pc_sharpedge_size: int,
point_feats: int,
width: int,
heads: int,
layers: int,
fourier_embedder,
normal_pe: bool = False,
qkv_bias: bool = False,
use_ln_post: bool = True,
qk_norm: bool = True):
super().__init__()
self.fourier_embedder = fourier_embedder
self.pc_size = pc_size
self.normal_pe = normal_pe
self.downsample_ratio = downsample_ratio
self.pc_sharpedge_size = pc_sharpedge_size
self.num_latents = num_latents
self.point_feats = point_feats
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width)
self.cross_attn = ResidualCrossAttentionBlock(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm
)
self.self_attn = None
if layers > 0:
self.self_attn = Transformer(
width = width,
heads = heads,
qkv_bias = qkv_bias,
qk_norm = qk_norm,
layers = layers
)
if use_ln_post:
self.ln_post = nn.LayerNorm(width)
else:
self.ln_post = None
def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor):
"""
Subsample points randomly from the point cloud (input_pc)
Further sample the subsampled points to get query_pc
take the fourier embeddings for both input and query pc
Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc.
Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc).
More computationally efficient.
Features are additional information for each point in the cloud
"""
B, _, D = point_cloud.shape
num_latents = int(self.num_latents)
num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents
num_sharpedge_query = num_latents - num_random_query
# Split random and sharpedge surface points
random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1)
# assert statements
assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size"
assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size"
input_random_pc_size = int(num_random_query * self.downsample_ratio)
random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \
self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size)
input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio)
if input_sharpedge_pc_size == 0:
sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device)
sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device)
else:
sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \
self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size)
# concat the random and sharpedges
query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1)
input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1)
query = self.fourier_embedder(query_pc)
data = self.fourier_embedder(input_pc)
if self.point_feats > 0:
random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1)
input_random_surface_features, query_random_features = \
self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B,
input_pc_size = input_random_pc_size, idx_query = random_idx_query)
if input_sharpedge_pc_size == 0:
input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats,
dtype = input_random_surface_features.dtype, device = point_cloud.device)
query_sharpedge_features = torch.zeros(B, 0, self.point_feats,
dtype = query_random_features.dtype, device = point_cloud.device)
else:
input_sharpedge_surface_features, query_sharpedge_features = \
self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features,
batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size)
query_features = torch.cat([query_random_features, query_sharpedge_features], dim = 1)
input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1)
if self.normal_pe:
# apply the fourier embeddings on the first 3 dims (xyz)
input_features_pe = self.fourier_embedder(input_features[..., :3])
query_features_pe = self.fourier_embedder(query_features[..., :3])
# replace the first 3 dims with the new PE ones
input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1)
query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1)
# concat at the channels dim
query = torch.cat([query, query_features], dim = -1)
data = torch.cat([data, input_features], dim = -1)
# don't return pc_info to avoid unnecessary memory usuage
return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1])
def forward(self, point_cloud: torch.Tensor, features: torch.Tensor):
query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features)
# apply projections
query = self.input_proj(query)
data = self.input_proj(data)
# apply cross attention between query and data
latents = self.cross_attn(query, data)
if self.self_attn is not None:
latents = self.self_attn(latents)
if self.ln_post is not None:
latents = self.ln_post(latents)
return latents
class VanillaVolumeDecoder:
def subsample(self, pc, num_query, input_pc_size: int):
"""
num_query: number of points to keep after FPS
input_pc_size: number of points to select before FPS
"""
B, _, D = pc.shape
query_ratio = num_query / input_pc_size
# random subsampling of points inside the point cloud
idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size]
input_pc = pc[:, idx_pc, :]
# flatten to allow applying fps across the whole batch
flattent_input_pc = input_pc.view(B * input_pc_size, D)
# construct a batch_down tensor to tell fps
# which points belong to which batch
N_down = int(flattent_input_pc.shape[0] / B)
batch_down = torch.arange(B).to(pc.device)
batch_down = torch.repeat_interleave(batch_down, N_down)
idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio)
query_pc = flattent_input_pc[idx_query].view(B, -1, D)
return query_pc, input_pc, idx_pc, idx_query
def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query):
B = batch_size
input_surface_features = features[:, idx_pc, :]
flattent_input_features = input_surface_features.view(B * input_pc_size, -1)
query_features = flattent_input_features[idx_query].view(B, -1,
flattent_input_features.shape[-1])
return input_surface_features, query_features
def normalize_mesh(mesh, scale = 0.9999):
"""Normalize mesh to fit in [-scale, scale]. Translate mesh so its center is [0,0,0]"""
bbox = mesh.bounds
center = (bbox[1] + bbox[0]) / 2
max_extent = (bbox[1] - bbox[0]).max()
mesh.apply_translation(-center)
mesh.apply_scale((2 * scale) / max_extent)
return mesh
def sample_pointcloud(mesh, num = 200000):
""" Uniformly sample points from the surface of the mesh """
points, face_idx = mesh.sample(num, return_index = True)
normals = mesh.face_normals[face_idx]
return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32))
def detect_sharp_edges(mesh, threshold=0.985):
"""Return edge indices (a, b) that lie on sharp boundaries of the mesh."""
V, F = mesh.vertices, mesh.faces
VN, FN = mesh.vertex_normals, mesh.face_normals
sharp_mask = np.ones(V.shape[0])
for i in range(3):
indices = F[:, i]
alignment = np.einsum('ij,ij->i', VN[indices], FN)
dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1)
sharp_mask[indices] = np.min(dot_stack, axis=-1)
edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]])
edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]])
sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold)
return edge_a[sharp_edges], edge_b[sharp_edges]
def sharp_sample_pointcloud(mesh, num = 16384):
""" Sample points preferentially from sharp edges in the mesh. """
edge_a, edge_b = detect_sharp_edges(mesh)
V, VN = mesh.vertices, mesh.vertex_normals
va, vb = V[edge_a], V[edge_b]
na, nb = VN[edge_a], VN[edge_b]
edge_lengths = np.linalg.norm(vb - va, axis=-1)
weights = edge_lengths / edge_lengths.sum()
indices = np.searchsorted(np.cumsum(weights), np.random.rand(num))
t = np.random.rand(num, 1)
samples = t * va[indices] + (1 - t) * vb[indices]
normals = t * na[indices] + (1 - t) * nb[indices]
return samples.astype(np.float32), normals.astype(np.float32)
def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"):
"""Load a surface with optional sharp-edge annotations from a trimesh mesh."""
import trimesh
try:
mesh_full = trimesh.util.concatenate(mesh.dump())
except Exception:
mesh_full = trimesh.util.concatenate(mesh)
mesh_full = normalize_mesh(mesh_full)
faces = mesh_full.faces
vertices = mesh_full.vertices
origin_face_count = faces.shape[0]
mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count])
mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:])
area_surface = mesh_surface.area
area_fill = mesh_fill.area
total_area = area_surface + area_fill
sample_num = 499712 // 2
fill_ratio = area_fill / total_area if total_area > 0 else 0
num_fill = int(sample_num * fill_ratio)
num_surface = sample_num - num_fill
surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface)
fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill)
sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num)
def assemble_tensor(points, normals, label=None):
data = torch.cat([points, normals], dim=1).half().to(device)
if label is not None:
label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device)
data = torch.cat([data, label_tensor], dim=1)
return data
surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0),
torch.cat([surf_normals.to(device), fill_normals.to(device)], dim=0),
label = 0 if sharpedge_flag else None)
sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals),
label = 1 if sharpedge_flag else None)
rng = np.random.default_rng()
surface = surface[rng.choice(surface.shape[0], num_points, replace = False)]
sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)]
full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0)
return full
class SharpEdgeSurfaceLoader:
""" Load mesh surface and sharp edge samples. """
def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192):
self.num_uniform_points = num_uniform_points
self.num_sharp_points = num_sharp_points
self.total_points = num_uniform_points + num_sharp_points
def __call__(self, mesh_input, device = "cuda"):
mesh = self._load_mesh(mesh_input)
return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device)
@staticmethod
def _load_mesh(mesh_input):
import trimesh
if isinstance(mesh_input, str):
mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True)
else:
mesh = mesh_input
if isinstance(mesh, trimesh.Scene):
combined = None
for obj in mesh.geometry.values():
combined = obj if combined is None else combined + obj
return combined
return mesh
class DiagonalGaussianDistribution:
def __init__(self, params: torch.Tensor, feature_dim: int = -1):
# divide quant channels (8) into mean and log variance
self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
def sample(self):
eps = torch.randn_like(self.std)
z = self.mean + eps * self.std
return z
################################################
# Volume Decoder
################################################
class VanillaVolumeDecoder():
@torch.no_grad()
def __call__(
self,
latents: torch.FloatTensor,
geo_decoder: Callable,
bounds: Union[Tuple[float], List[float], float] = 1.01,
num_chunks: int = 10000,
octree_resolution: int = None,
enable_pbar: bool = True,
**kwargs,
):
device = latents.device
dtype = latents.dtype
batch_size = latents.shape[0]
def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01,
num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs):
# 1. generate query points
if isinstance(bounds, float):
bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
xyz_samples, grid_size, length = generate_dense_grid_points(
bbox_min=bbox_min,
bbox_max=bbox_max,
octree_resolution=octree_resolution,
indexing="ij"
)
xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:])
x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32)
y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32)
z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32)
[xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij")
xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3)
grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1]
# 2. latents to 3d volume
batch_logits = []
for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding",
disable=not enable_pbar):
chunk_queries = xyz_samples[start: start + num_chunks, :]
chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
logits = geo_decoder(queries=chunk_queries, latents=latents)
chunk_queries = xyz[start: start + num_chunks, :]
chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1)
logits = geo_decoder(queries = chunk_queries, latents = latents)
batch_logits.append(logits)
grid_logits = torch.cat(batch_logits, dim=1)
grid_logits = grid_logits.view((batch_size, *grid_size)).float()
grid_logits = torch.cat(batch_logits, dim = 1)
grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float()
return grid_logits
class FourierEmbedder(nn.Module):
"""The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
each feature dimension of `x[..., i]` into:
@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module):
else:
return x
class CrossAttentionProcessor:
def __call__(self, attn, q, k, v):
out = comfy.ops.scaled_dot_product_attention(q, k, v)
return out
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
@ -232,38 +607,41 @@ class MLP(nn.Module):
def forward(self, x):
return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
class QKVMultiheadCrossAttention(nn.Module):
def __init__(
self,
*,
heads: int,
n_data = None,
width=None,
qk_norm=False,
norm_layer=ops.LayerNorm
):
super().__init__()
self.heads = heads
self.n_data = n_data
self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
self.attn_processor = CrossAttentionProcessor()
def forward(self, q, kv):
_, n_ctx, _ = q.shape
bs, n_data, width = kv.shape
attn_ch = width // self.heads // 2
q = q.view(bs, n_ctx, self.heads, -1)
kv = kv.view(bs, n_data, self.heads, -1)
k, v = torch.split(kv, attn_ch, dim=-1)
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
out = self.attn_processor(self, q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v)
out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
return out
class MultiheadCrossAttention(nn.Module):
def __init__(
@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module):
x = self.c_proj(x)
return x
class ResidualCrossAttentionBlock(nn.Module):
def __init__(
self,
@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module):
q = self.q_norm(q)
k = self.k_norm(k)
q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]
out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
return out
@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module):
drop_path_rate: float = 0.0
):
super().__init__()
self.width = width
self.heads = heads
self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
self.c_proj = ops.Linear(width, width)
self.attention = QKVMultiheadAttention(
@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module):
self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
if self.downsample_ratio != 1:
self.latents_proj = ops.Linear(width * downsample_ratio, width)
if self.enable_ln_post == False:
if not self.enable_ln_post:
qk_norm = False
self.cross_attn_decoder = ResidualCrossAttentionBlock(
width=width,
@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module):
class ShapeVAE(nn.Module):
def __init__(
self,
*,
embed_dim: int,
width: int,
heads: int,
num_decoder_layers: int,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
include_pi: bool = True,
qkv_bias: bool = True,
qk_norm: bool = False,
label_type: str = "binary",
drop_path_rate: float = 0.0,
scale_factor: float = 1.0,
self,
*,
num_latents: int = 4096,
embed_dim: int = 64,
width: int = 1024,
heads: int = 16,
num_decoder_layers: int = 16,
num_encoder_layers: int = 8,
pc_size: int = 81920,
pc_sharpedge_size: int = 0,
point_feats: int = 4,
downsample_ratio: int = 20,
geo_decoder_downsample_ratio: int = 1,
geo_decoder_mlp_expand_ratio: int = 4,
geo_decoder_ln_post: bool = True,
num_freqs: int = 8,
qkv_bias: bool = False,
qk_norm: bool = True,
drop_path_rate: float = 0.0,
include_pi: bool = False,
scale_factor: float = 1.0039506158752403,
label_type: str = "binary",
):
super().__init__()
self.geo_decoder_ln_post = geo_decoder_ln_post
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
self.encoder = PointCrossAttention(layers = num_encoder_layers,
num_latents = num_latents,
downsample_ratio = downsample_ratio,
heads = heads,
pc_size = pc_size,
width = width,
point_feats = point_feats,
fourier_embedder = self.fourier_embedder,
pc_sharpedge_size = pc_sharpedge_size)
self.post_kl = ops.Linear(embed_dim, width)
self.transformer = Transformer(
@ -583,5 +975,14 @@ class ShapeVAE(nn.Module):
grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
return grid_logits.movedim(-2, -1)
def encode(self, x):
return None
def encode(self, surface):
pc, feats = surface[:, :, :3], surface[:, :, 3:]
latents = self.encoder(pc, feats)
moments = self.pre_kl(latents)
posterior = DiagonalGaussianDistribution(moments, feature_dim = -1)
latents = posterior.sample()
return latents

View File

@ -0,0 +1,659 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, operations, device, dtype):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type == "mps":
return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype)
return F.gelu(gate)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states = self.gelu(hidden_states)
return hidden_states
class FeedForward(nn.Module):
def __init__(self, dim: int, dim_out = None, mult: int = 4,
dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None):
super().__init__()
if inner_dim is None:
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype)
self.net = nn.ModuleList([])
self.net.append(act_fn)
self.net.append(nn.Dropout(dropout))
self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class AddAuxLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
# do nothing in forward (no computation)
ctx.requires_aux_loss = loss.requires_grad
ctx.dtype = loss.dtype
return x
@staticmethod
def backward(ctx, grad_output):
# add the aux loss gradients
grad_loss = None
# put the aux grad the same as the main grad loss
# aux grad contributes equally
if ctx.requires_aux_loss:
grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device)
return grad_output, grad_loss
class MoEGate(nn.Module):
def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None):
super().__init__()
self.top_k = num_experts_per_tok
self.n_routed_experts = num_experts
self.alpha = aux_loss_alpha
self.gating_dim = embed_dim
self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype))
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# flatten hidden states
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
# get logits and pass it to softmax
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None)
scores = logits.softmax(dim = -1)
topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False)
if self.training and self.alpha > 0.0:
scores_for_aux = scores
# used bincount instead of one hot encoding
counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float()
ce = counts / topk_idx.numel() # normalized expert usage
# mean expert score
Pi = scores_for_aux.mean(0)
# expert balance loss
aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha
else:
aux_loss = None
return topk_idx, topk_weight, aux_loss
class MoEBlock(nn.Module):
def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0,
ff_inner_dim: int = None, operations = None, device = None, dtype = None):
super().__init__()
self.moe_top_k = moe_top_k
self.num_experts = num_experts
self.experts = nn.ModuleList([
FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
for _ in range(num_experts)
])
self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype)
self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype)
def forward(self, hidden_states) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
flat_topk_idx = topk_idx.view(-1)
if self.training:
hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0)
y = torch.empty_like(hidden_states, dtype = hidden_states.dtype)
for i, expert in enumerate(self.experts):
tmp = expert(hidden_states[flat_topk_idx == i])
y[flat_topk_idx == i] = tmp.to(hidden_states.dtype)
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1)
y = y.view(*orig_shape)
y = AddAuxLoss.apply(y, aux_loss)
else:
y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape)
y = y + self.shared_experts(identity)
return y
@torch.no_grad()
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
expert_cache = torch.zeros_like(x)
idxs = flat_expert_indices.argsort()
# no need for .numpy().cpu() here
tokens_per_expert = flat_expert_indices.bincount().cumsum(0)
token_idxs = idxs // self.moe_top_k
for i, end_idx in enumerate(tokens_per_expert):
start_idx = 0 if i == 0 else tokens_per_expert[i-1]
if start_idx == end_idx:
continue
expert = self.experts[i]
exp_token_idx = token_idxs[start_idx:end_idx]
expert_tokens = x[exp_token_idx]
expert_out = expert(expert_tokens)
expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
# use index_add_ with a 1-D index tensor directly avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_
# + avoid dtype conversion
expert_cache.index_add_(0, exp_token_idx, expert_out)
return expert_cache
class Timesteps(nn.Module):
def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0,
scale: float = 1.0, max_period: int = 10000):
super().__init__()
self.num_channels = num_channels
half_dim = num_channels // 2
# precompute the “inv_freq” vector once
exponent = -math.log(max_period) * torch.arange(
half_dim, dtype=torch.float32
) / (half_dim - downscale_freq_shift)
inv_freq = torch.exp(exponent)
# pad
if num_channels % 2 == 1:
# well pad a zero at the end of the cos-half
inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)])
# register to buffer so it moves with the device
self.register_buffer("inv_freq", inv_freq, persistent = False)
self.scale = scale
def forward(self, timesteps: torch.Tensor):
x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0)
# fused CUDA kernels for sin and cos
sin_emb = x.sin()
cos_emb = x.cos()
emb = torch.cat([sin_emb, cos_emb], dim = 1)
# scale factor
if self.scale != 1.0:
emb = emb * self.scale
# If we padded inv_freq for odd, emb is already wide enough; otherwise:
if emb.shape[1] > self.num_channels:
emb = emb[:, :self.num_channels]
return emb
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype),
nn.GELU(),
operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype),
)
self.frequency_embedding_size = frequency_embedding_size
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype)
self.time_embed = Timesteps(hidden_size)
def forward(self, timesteps, condition):
timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype)
if condition is not None:
cond_embed = self.cond_proj(condition)
timestep_embed = timestep_embed + cond_embed
time_conditioned = self.mlp(timestep_embed)
# for broadcasting with image tokens
return time_conditioned.unsqueeze(1)
class MLP(nn.Module):
def __init__(self, *, width: int, operations = None, device = None, dtype = None):
super().__init__()
self.width = width
self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype)
self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype)
self.gelu = nn.GELU()
def forward(self, x):
return self.fc2(self.gelu(self.fc1(x)))
class CrossAttention(nn.Module):
def __init__(
self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
norm_layer=nn.LayerNorm,
use_fp16: bool = False,
operations = None,
dtype = None,
device = None,
**kwargs,
):
super().__init__()
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
self.head_dim = self.qdim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
self.to_v = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype)
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
if norm_layer == nn.LayerNorm:
norm_layer = operations.LayerNorm
else:
norm_layer = operations.RMSNorm
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype)
def forward(self, x, y):
b, s1, _ = x.shape
_, s2, _ = y.shape
y = y.to(next(self.to_k.parameters()).dtype)
q = self.to_q(x)
k = self.to_k(y)
v = self.to_v(y)
kv = torch.cat((k, v), dim=-1)
split_size = kv.shape[-1] // self.num_heads // 2
kv = kv.view(1, -1, self.num_heads, split_size * 2)
k, v = torch.split(kv, split_size, dim=-1)
q = q.view(b, s1, self.num_heads, self.head_dim)
k = k.view(b, s2, self.num_heads, self.head_dim)
v = v.reshape(b, s2, self.num_heads * self.head_dim)
q = self.q_norm(q)
k = self.k_norm(k)
x = optimized_attention(
q.reshape(b, s1, self.num_heads * self.head_dim),
k.reshape(b, s2, self.num_heads * self.head_dim),
v,
heads=self.num_heads,
)
out = self.out_proj(x)
return out
class Attention(nn.Module):
def __init__(
self,
dim,
num_heads,
qkv_bias = True,
qk_norm = False,
norm_layer = nn.LayerNorm,
use_fp16: bool = False,
operations = None,
device = None,
dtype = None
):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = self.dim // num_heads
self.scale = self.head_dim ** -0.5
self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype)
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
if norm_layer == nn.LayerNorm:
norm_layer = operations.LayerNorm
else:
norm_layer = operations.RMSNorm
self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity()
self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype)
def forward(self, x):
B, N, _ = x.shape
query = self.to_q(x)
key = self.to_k(x)
value = self.to_v(x)
qkv_combined = torch.cat((query, key, value), dim=-1)
split_size = qkv_combined.shape[-1] // self.num_heads // 3
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
query, key, value = torch.split(qkv, split_size, dim=-1)
query = query.reshape(B, N, self.num_heads, self.head_dim)
key = key.reshape(B, N, self.num_heads, self.head_dim)
value = value.reshape(B, N, self.num_heads * self.head_dim)
query = self.q_norm(query)
key = self.k_norm(key)
x = optimized_attention(
query.reshape(B, N, self.num_heads * self.head_dim),
key.reshape(B, N, self.num_heads * self.head_dim),
value,
heads=self.num_heads,
)
x = self.out_proj(x)
return x
class HunYuanDiTBlock(nn.Module):
def __init__(
self,
hidden_size,
c_emb_size,
num_heads,
text_states_dim=1024,
qk_norm=False,
norm_layer=nn.LayerNorm,
qk_norm_layer=True,
qkv_bias=True,
skip_connection=True,
timested_modulate=False,
use_moe: bool = False,
num_experts: int = 8,
moe_top_k: int = 2,
use_fp16: bool = False,
operations = None,
device = None, dtype = None
):
super().__init__()
# eps can't be 1e-6 in fp16 mode because of numerical stability issues
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm,
norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations)
self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.timested_modulate = timested_modulate
if self.timested_modulate:
self.default_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype)
)
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias,
qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16,
device = device, dtype = dtype, operations = operations)
self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
if skip_connection:
self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype)
else:
self.skip_linear = None
self.use_moe = use_moe
if self.use_moe:
self.moe = MoEBlock(
hidden_size,
num_experts = num_experts,
moe_top_k = moe_top_k,
dropout = 0.0,
ff_inner_dim = int(hidden_size * 4.0),
device = device, dtype = dtype,
operations = operations
)
else:
self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype)
def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None):
if self.skip_linear is not None:
combined = torch.cat([skip_tensor, hidden_states], dim=-1)
hidden_states = self.skip_linear(combined)
hidden_states = self.skip_norm(hidden_states)
# self attention
if self.timested_modulate:
modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1)
hidden_states = hidden_states + modulation_shift
self_attn_out = self.attn1(self.norm1(hidden_states))
hidden_states = hidden_states + self_attn_out
# cross attention
hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states)
# MLP Layer
mlp_input = self.norm3(hidden_states)
if self.use_moe:
hidden_states = hidden_states + self.moe(mlp_input)
else:
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class FinalLayer(nn.Module):
def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None):
super().__init__()
if use_fp16:
eps = 1.0 / 65504
else:
eps = 1e-6
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype)
self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype)
def forward(self, x):
x = self.norm_final(x)
x = x[:, 1:]
x = self.linear(x)
return x
class HunYuanDiTPlain(nn.Module):
# init with the defaults values from https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml
def __init__(
self,
in_channels: int = 64,
hidden_size: int = 2048,
context_dim: int = 1024,
depth: int = 21,
num_heads: int = 16,
qk_norm: bool = True,
qkv_bias: bool = False,
num_moe_layers: int = 6,
guidance_cond_proj_dim = 2048,
norm_type = 'layer',
num_experts: int = 8,
moe_top_k: int = 2,
use_fp16: bool = False,
dtype = None,
device = None,
operations = None,
**kwargs
):
self.dtype = dtype
super().__init__()
self.depth = depth
self.in_channels = in_channels
self.out_channels = in_channels
self.num_heads = num_heads
self.hidden_size = hidden_size
norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm
qk_norm = operations.RMSNorm
self.context_dim = context_dim
self.guidance_cond_proj_dim = guidance_cond_proj_dim
self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype)
self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations)
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
text_states_dim=context_dim,
qk_norm=qk_norm,
norm_layer = norm,
qk_norm_layer = qk_norm,
skip_connection=layer > depth // 2,
qkv_bias=qkv_bias,
use_moe=True if depth - layer <= num_moe_layers else False,
num_experts=num_experts,
moe_top_k=moe_top_k,
use_fp16 = use_fp16,
device = device, dtype = dtype, operations = operations)
for layer in range(depth)
])
self.depth = depth
self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype)
def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2)
uncond_emb, cond_emb = context.chunk(2, dim = 0)
context = torch.cat([cond_emb, uncond_emb], dim = 0)
main_condition = context
t = 1.0 - t
time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond'))
x = x.to(dtype = next(self.x_embedder.parameters()).dtype)
x_embedded = self.x_embedder(x)
combined = torch.cat([time_embedded, x_embedded], dim=1)
def block_wrap(args):
return block(
args["x"],
args["t"],
args["cond"],
skip_tensor=args.get("skip"),)
skip_stack = []
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for idx, block in enumerate(self.blocks):
if idx <= self.depth // 2:
skip_input = None
else:
skip_input = skip_stack.pop()
if ("block", idx) in blocks_replace:
combined = blocks_replace[("block", idx)](
{
"x": combined,
"t": time_embedded,
"cond": main_condition,
"skip": skip_input,
},
{"original_block": block_wrap},
)
else:
combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input)
if idx < self.depth // 2:
skip_stack.append(combined)
output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0)
cond_emb, uncond_emb = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb])

View File

@ -40,6 +40,8 @@ class HunyuanVideoParams:
patch_size: list
qkv_bias: bool
guidance_embed: bool
byt5: bool
meanflow: bool
class SelfAttentionRef(nn.Module):
@ -161,6 +163,30 @@ class TokenRefiner(nn.Module):
x = self.individual_token_refiner(x, c, mask)
return x
class ByT5Mapper(nn.Module):
def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None):
super().__init__()
self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device)
self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device)
self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device)
self.use_res = use_res
self.act_fn = nn.GELU()
def forward(self, x):
if self.use_res:
res = x
x = self.layernorm(x)
x = self.fc1(x)
x = self.act_fn(x)
x = self.fc2(x)
x2 = self.act_fn(x)
x2 = self.fc3(x2)
if self.use_res:
x2 = x2 + res
return x2
class HunyuanVideo(nn.Module):
"""
Transformer model for flow matching on sequences.
@ -185,9 +211,13 @@ class HunyuanVideo(nn.Module):
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
if params.vec_in_dim is not None:
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.vector_in = None
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
@ -215,6 +245,23 @@ class HunyuanVideo(nn.Module):
]
)
if params.byt5:
self.byt5_in = ByT5Mapper(
in_dim=1472,
out_dim=2048,
hidden_dim=2048,
out_dim1=self.hidden_size,
use_res=False,
dtype=dtype, device=device, operations=operations
)
else:
self.byt5_in = None
if params.meanflow:
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.time_r_in = None
if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
@ -226,7 +273,8 @@ class HunyuanVideo(nn.Module):
txt_ids: Tensor,
txt_mask: Tensor,
timesteps: Tensor,
y: Tensor,
y: Tensor = None,
txt_byt5=None,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
@ -240,6 +288,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if self.time_r_in is not None:
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
@ -250,13 +306,17 @@ class HunyuanVideo(nn.Module):
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
if self.vector_in is not None:
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
else:
vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1)
frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
modulation_dims_txt = [(0, None, 1)]
else:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.vector_in is not None:
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
modulation_dims = None
modulation_dims_txt = None
@ -269,6 +329,12 @@ class HunyuanVideo(nn.Module):
txt = self.txt_in(txt, timesteps, txt_mask)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt = torch.cat((txt, txt_byt5), dim=1)
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
ids = torch.cat((img_ids, txt_ids), dim=1)
pe = self.pe_embedder(ids)
@ -328,12 +394,16 @@ class HunyuanVideo(nn.Module):
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
shape = initial_shape[-len(self.patch_size):]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
if img.ndim == 8:
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
else:
img = img.permute(0, 3, 1, 4, 2, 5)
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3])
return img
def img_ids(self, x):
@ -348,16 +418,30 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def img_ids_2d(self, x):
bs, c, h, w = x.shape
patch_size = self.patch_size
h_len = ((h + (patch_size[0] // 2)) // patch_size[0])
w_len = ((w + (patch_size[1] // 2)) // patch_size[1])
img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype)
img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out

View File

@ -0,0 +1,136 @@
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
import comfy.ops
ops = comfy.ops.disable_weight_init
class PixelShuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim >> 2, 3, 1, 1)
self.ratio = (in_dim << 2) // out_dim
def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h >> 1, w >> 1
y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2)
r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2)
return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2)
class PixelUnshuffle2D(nn.Module):
def __init__(self, in_dim, out_dim, op=ops.Conv2d):
super().__init__()
self.conv = op(in_dim, out_dim << 2, 3, 1, 1)
self.scale = (out_dim << 2) // in_dim
def forward(self, x):
b, c, h, w = x.shape
h2, w2 = h << 1, w << 1
y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2)
return y + r
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1)
def forward(self, x):
x = self.conv_in(x)
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
b, c, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, h, w).mean(2)
return self.conv_out(F.silu(self.norm_out(x))) + skip
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
ch = block_out_channels[0]
self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=ops.Conv2d)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d)
ch = nxt
self.up.append(stage)
self.norm_out = ops.GroupNorm(32, ch, 1e-6, True)
self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1)
def forward(self, z):
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))

View File

@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512, conv_op=ops.Conv2d):
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -183,7 +183,7 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x, temb):
def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = self.swish(h)

View File

@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
for k in sdk:
hidden_size = model.model_config.unet_config.get("hidden_size", 0)
if k.endswith(".weight") and ".linear1." in k:
key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3))
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:

View File

@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_uso_lora(sd):
sd_out = {}
for k in sd:
tensor = sd[k]
k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight")
.replace(".up.weight", ".lora_up.weight")
.replace(".qkv_lora2.", ".txt_attn.qkv.")
.replace(".qkv_lora1.", ".img_attn.qkv.")
.replace(".proj_lora1.", ".img_attn.proj.")
.replace(".proj_lora2.", ".txt_attn.proj.")
.replace(".qkv_lora.", ".linear1_qkv.")
.replace(".proj_lora.", ".linear2.")
.replace(".processor.", ".")
)
sd_out[k_to] = tensor
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
return convert_lora_wan_fun(sd)
if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd:
return convert_uso_lora(sd)
return sd

View File

@ -16,6 +16,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class Hunyuan3Dv2_1(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guidance = kwargs.get("guidance", 5.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@ -1391,3 +1408,27 @@ class QwenImage(BaseModel):
if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out

View File

@ -136,20 +136,40 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)]
out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)]
dit_config["image_model"] = "hunyuan_video"
dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
else:
dit_config["vec_in_dim"] = None
if len(dit_config["patch_size"]) == 2:
dit_config["axes_dim"] = [64, 64]
else:
dit_config["axes_dim"] = [16, 56, 56]
if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["meanflow"] = True
else:
dit_config["meanflow"] = False
dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["num_heads"] = in_w.shape[0] // 128
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 256
dit_config["qkv_bias"] = True
if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict:
dit_config["byt5"] = True
else:
dit_config["byt5"] = False
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
@ -400,6 +420,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1
dit_config = {}
dit_config["image_model"] = "hunyuan3d2_1"
dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1]
dit_config["context_dim"] = 1024
dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0]
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 16
dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}")
dit_config["qkv_bias"] = False
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
dit_config = {}
dit_config["image_model"] = "hidream"

View File

@ -23,6 +23,7 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import importlib
import platform
import weakref
import gc
@ -313,6 +314,24 @@ def is_amd():
return True
return False
def amd_min_version(device=None, min_rdna_version=0):
if not is_amd():
return False
if is_device_cpu(device):
return False
arch = torch.cuda.get_device_properties(device).gcnArchName
if arch.startswith('gfx') and len(arch) == 7:
try:
cmp_rdna_version = int(arch[4]) + 2
except:
cmp_rdna_version = 0
if cmp_rdna_version >= min_rdna_version:
return True
return False
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
MIN_WEIGHT_MEMORY_RATIO = 0.0
@ -345,12 +364,13 @@ try:
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
# if torch_version_numeric >= (2, 8):
# if any((a in arch) for a in ["gfx1201"]):
# ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
SUPPORT_FP8_OPS = True
@ -933,7 +953,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
# also a problem on RDNA4 except fp32 is also slow there.
# This is due to large bf16 convolutions being extremely slow.
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
return d
return torch.float32

View File

@ -513,6 +513,9 @@ class ModelPatcher:
def set_model_double_block_patch(self, patch):
self.set_model_patch(patch, "double_block")
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def add_object_patch(self, name, obj):
self.object_patches[name] = obj

View File

@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError):
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

View File

@ -17,6 +17,7 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.hunyuan_video.vae
import yaml
import math
import os
@ -48,6 +49,7 @@ import comfy.text_encoders.hidream
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.model_patcher
import comfy.lora
@ -328,6 +330,19 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
self.downscale_ratio = 32
self.upscale_ratio = 32
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
elif "decoder.conv_in.weight" in sd:
#default SD1.x/SD2.x VAE parameters
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@ -446,17 +461,29 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
self.latent_dim = 1
ln_post = "geo_decoder.ln_post.weight" in sd
inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2):
batch, num_tokens, hidden_dim = shape
dtype_size = model_management.dtype_size(dtype)
total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)
return total_mem
# better memory estimations
self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \
estimate_memory(shape, dtype, num_layers, kv_cache_multiplier)
self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE()
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
@ -773,6 +800,7 @@ class CLIPType(Enum):
ACE = 16
OMNIGEN2 = 17
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -794,6 +822,7 @@ class TEModel(Enum):
GEMMA_2_2B = 9
QWEN25_3B = 10
QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -811,6 +840,9 @@ def detect_te_model(sd):
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight']
if weight.shape[0] == 384:
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
return TEModel.GEMMA_2_2B
@ -925,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer
elif te_model == TEModel.QWEN25_7B:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
if clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
@ -970,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@ -20,6 +20,7 @@ import comfy.text_encoders.wan
import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
from . import supported_models_base
from . import latent_formats
@ -1128,6 +1129,17 @@ class Hunyuan3Dv2(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return None
class Hunyuan3Dv2_1(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2_1",
}
latent_format = latent_formats.Hunyuan3Dv2_1
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Hunyuan3Dv2_1(self, device = device)
return out
class Hunyuan3Dv2mini(Hunyuan3Dv2):
unet_config = {
"image_model": "hunyuan3d2",
@ -1284,7 +1296,31 @@ class QwenImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect))
class HunyuanImage21(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"vec_in_dim": None,
}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
sampling_settings = {
"shift": 5.0,
}
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanImage21(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -0,0 +1,22 @@
{
"d_ff": 3584,
"d_kv": 64,
"d_model": 1472,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "t5",
"num_decoder_layers": 4,
"num_heads": 6,
"num_layers": 12,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 1510
}

View File

@ -0,0 +1,127 @@
{
"<extra_id_0>": 259,
"<extra_id_100>": 359,
"<extra_id_101>": 360,
"<extra_id_102>": 361,
"<extra_id_103>": 362,
"<extra_id_104>": 363,
"<extra_id_105>": 364,
"<extra_id_106>": 365,
"<extra_id_107>": 366,
"<extra_id_108>": 367,
"<extra_id_109>": 368,
"<extra_id_10>": 269,
"<extra_id_110>": 369,
"<extra_id_111>": 370,
"<extra_id_112>": 371,
"<extra_id_113>": 372,
"<extra_id_114>": 373,
"<extra_id_115>": 374,
"<extra_id_116>": 375,
"<extra_id_117>": 376,
"<extra_id_118>": 377,
"<extra_id_119>": 378,
"<extra_id_11>": 270,
"<extra_id_120>": 379,
"<extra_id_121>": 380,
"<extra_id_122>": 381,
"<extra_id_123>": 382,
"<extra_id_124>": 383,
"<extra_id_12>": 271,
"<extra_id_13>": 272,
"<extra_id_14>": 273,
"<extra_id_15>": 274,
"<extra_id_16>": 275,
"<extra_id_17>": 276,
"<extra_id_18>": 277,
"<extra_id_19>": 278,
"<extra_id_1>": 260,
"<extra_id_20>": 279,
"<extra_id_21>": 280,
"<extra_id_22>": 281,
"<extra_id_23>": 282,
"<extra_id_24>": 283,
"<extra_id_25>": 284,
"<extra_id_26>": 285,
"<extra_id_27>": 286,
"<extra_id_28>": 287,
"<extra_id_29>": 288,
"<extra_id_2>": 261,
"<extra_id_30>": 289,
"<extra_id_31>": 290,
"<extra_id_32>": 291,
"<extra_id_33>": 292,
"<extra_id_34>": 293,
"<extra_id_35>": 294,
"<extra_id_36>": 295,
"<extra_id_37>": 296,
"<extra_id_38>": 297,
"<extra_id_39>": 298,
"<extra_id_3>": 262,
"<extra_id_40>": 299,
"<extra_id_41>": 300,
"<extra_id_42>": 301,
"<extra_id_43>": 302,
"<extra_id_44>": 303,
"<extra_id_45>": 304,
"<extra_id_46>": 305,
"<extra_id_47>": 306,
"<extra_id_48>": 307,
"<extra_id_49>": 308,
"<extra_id_4>": 263,
"<extra_id_50>": 309,
"<extra_id_51>": 310,
"<extra_id_52>": 311,
"<extra_id_53>": 312,
"<extra_id_54>": 313,
"<extra_id_55>": 314,
"<extra_id_56>": 315,
"<extra_id_57>": 316,
"<extra_id_58>": 317,
"<extra_id_59>": 318,
"<extra_id_5>": 264,
"<extra_id_60>": 319,
"<extra_id_61>": 320,
"<extra_id_62>": 321,
"<extra_id_63>": 322,
"<extra_id_64>": 323,
"<extra_id_65>": 324,
"<extra_id_66>": 325,
"<extra_id_67>": 326,
"<extra_id_68>": 327,
"<extra_id_69>": 328,
"<extra_id_6>": 265,
"<extra_id_70>": 329,
"<extra_id_71>": 330,
"<extra_id_72>": 331,
"<extra_id_73>": 332,
"<extra_id_74>": 333,
"<extra_id_75>": 334,
"<extra_id_76>": 335,
"<extra_id_77>": 336,
"<extra_id_78>": 337,
"<extra_id_79>": 338,
"<extra_id_7>": 266,
"<extra_id_80>": 339,
"<extra_id_81>": 340,
"<extra_id_82>": 341,
"<extra_id_83>": 342,
"<extra_id_84>": 343,
"<extra_id_85>": 344,
"<extra_id_86>": 345,
"<extra_id_87>": 346,
"<extra_id_88>": 347,
"<extra_id_89>": 348,
"<extra_id_8>": 267,
"<extra_id_90>": 349,
"<extra_id_91>": 350,
"<extra_id_92>": 351,
"<extra_id_93>": 352,
"<extra_id_94>": 353,
"<extra_id_95>": 354,
"<extra_id_96>": 355,
"<extra_id_97>": 356,
"<extra_id_98>": 357,
"<extra_id_99>": 358,
"<extra_id_9>": 268
}

View File

@ -0,0 +1,150 @@
{
"additional_special_tokens": [
"<extra_id_0>",
"<extra_id_1>",
"<extra_id_2>",
"<extra_id_3>",
"<extra_id_4>",
"<extra_id_5>",
"<extra_id_6>",
"<extra_id_7>",
"<extra_id_8>",
"<extra_id_9>",
"<extra_id_10>",
"<extra_id_11>",
"<extra_id_12>",
"<extra_id_13>",
"<extra_id_14>",
"<extra_id_15>",
"<extra_id_16>",
"<extra_id_17>",
"<extra_id_18>",
"<extra_id_19>",
"<extra_id_20>",
"<extra_id_21>",
"<extra_id_22>",
"<extra_id_23>",
"<extra_id_24>",
"<extra_id_25>",
"<extra_id_26>",
"<extra_id_27>",
"<extra_id_28>",
"<extra_id_29>",
"<extra_id_30>",
"<extra_id_31>",
"<extra_id_32>",
"<extra_id_33>",
"<extra_id_34>",
"<extra_id_35>",
"<extra_id_36>",
"<extra_id_37>",
"<extra_id_38>",
"<extra_id_39>",
"<extra_id_40>",
"<extra_id_41>",
"<extra_id_42>",
"<extra_id_43>",
"<extra_id_44>",
"<extra_id_45>",
"<extra_id_46>",
"<extra_id_47>",
"<extra_id_48>",
"<extra_id_49>",
"<extra_id_50>",
"<extra_id_51>",
"<extra_id_52>",
"<extra_id_53>",
"<extra_id_54>",
"<extra_id_55>",
"<extra_id_56>",
"<extra_id_57>",
"<extra_id_58>",
"<extra_id_59>",
"<extra_id_60>",
"<extra_id_61>",
"<extra_id_62>",
"<extra_id_63>",
"<extra_id_64>",
"<extra_id_65>",
"<extra_id_66>",
"<extra_id_67>",
"<extra_id_68>",
"<extra_id_69>",
"<extra_id_70>",
"<extra_id_71>",
"<extra_id_72>",
"<extra_id_73>",
"<extra_id_74>",
"<extra_id_75>",
"<extra_id_76>",
"<extra_id_77>",
"<extra_id_78>",
"<extra_id_79>",
"<extra_id_80>",
"<extra_id_81>",
"<extra_id_82>",
"<extra_id_83>",
"<extra_id_84>",
"<extra_id_85>",
"<extra_id_86>",
"<extra_id_87>",
"<extra_id_88>",
"<extra_id_89>",
"<extra_id_90>",
"<extra_id_91>",
"<extra_id_92>",
"<extra_id_93>",
"<extra_id_94>",
"<extra_id_95>",
"<extra_id_96>",
"<extra_id_97>",
"<extra_id_98>",
"<extra_id_99>",
"<extra_id_100>",
"<extra_id_101>",
"<extra_id_102>",
"<extra_id_103>",
"<extra_id_104>",
"<extra_id_105>",
"<extra_id_106>",
"<extra_id_107>",
"<extra_id_108>",
"<extra_id_109>",
"<extra_id_110>",
"<extra_id_111>",
"<extra_id_112>",
"<extra_id_113>",
"<extra_id_114>",
"<extra_id_115>",
"<extra_id_116>",
"<extra_id_117>",
"<extra_id_118>",
"<extra_id_119>",
"<extra_id_120>",
"<extra_id_121>",
"<extra_id_122>",
"<extra_id_123>",
"<extra_id_124>"
],
"eos_token": {
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,100 @@
from comfy import sd1_clip
import comfy.text_encoders.llama
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from transformers import ByT5Tokenizer
import os
import re
class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data)
class HunyuanImageTokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
# self.llama_template_images = "{}"
self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
# ByT5 processing for HunyuanImage
text_prompt_texts = []
pattern_quote_single = r'\'(.*?)\''
pattern_quote_double = r'\"(.*?)\"'
pattern_quote_chinese_single = r'(.*?)'
pattern_quote_chinese_double = r'“(.*?)”'
matches_quote_single = re.findall(pattern_quote_single, text)
matches_quote_double = re.findall(pattern_quote_double, text)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
text_prompt_texts.extend(matches_quote_single)
text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)
if len(text_prompt_texts) > 0:
out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
return out
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class ByT5SmallModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class HunyuanImageTEModel(QwenImageTEModel):
def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
if byt5:
self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
else:
self.byt5_small = None
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs)
if self.byt5_small is not None and "byt5" in token_weight_pairs:
out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
extra["conditioning_byt5small"] = out[0]
return cond, p, extra
def set_clip_options(self, options):
super().set_clip_options(options)
if self.byt5_small is not None:
self.byt5_small.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
if self.byt5_small is not None:
self.byt5_small.reset_clip_options()
def load_sd(self, sd):
if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd:
return self.byt5_small.load_sd(sd)
else:
return super().load_sd(sd)
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
return QwenImageTEModel_

View File

@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
q_embed = (xq * cos) + (rotate_half(xq) * sin)
k_embed = (xk * cos) + (rotate_half(xk) * sin)
return q_embed, k_embed
return q_embed.to(org_dtype), k_embed.to(org_dtype)
class Attention(nn.Module):

View File

@ -1190,13 +1190,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
raise NotImplementedError
@classmethod
def validate_inputs(cls, **kwargs) -> bool:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.
If the function returns a string, it will be used as the validation error message for the node.
"""
raise NotImplementedError
@classmethod
def fingerprint_inputs(cls, **kwargs) -> Any:
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED."""
"""Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.
If this function returns the same value as last run, the node will not be executed."""
raise NotImplementedError
@classmethod

View File

@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi(
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:

View File

@ -951,7 +951,11 @@ class MagicPrompt2(str, Enum):
class StyleType1(str, Enum):
AUTO = 'AUTO'
GENERAL = 'GENERAL'
REALISTIC = 'REALISTIC'
DESIGN = 'DESIGN'
FICTION = 'FICTION'
class ImagenImageGenerationInstance(BaseModel):
@ -2676,7 +2680,7 @@ class ReleaseNote(BaseModel):
class RenderingSpeed(str, Enum):
BALANCED = 'BALANCED'
DEFAULT = 'DEFAULT'
TURBO = 'TURBO'
QUALITY = 'QUALITY'
@ -4918,6 +4922,14 @@ class IdeogramV3EditRequest(BaseModel):
None,
description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class IdeogramV3Request(BaseModel):
@ -4951,6 +4963,14 @@ class IdeogramV3Request(BaseModel):
style_type: Optional[StyleType1] = Field(
None, description='The type of style to apply'
)
character_reference_images: Optional[List[str]] = Field(
None,
description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.'
)
character_reference_images_mask: Optional[List[str]] = Field(
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class ImagenGenerateImageResponse(BaseModel):

View File

@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel):
class StabilityAsyncResponse(BaseModel):
id: Optional[str] = Field(None)
class StabilityTextToAudioRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
duration: int = Field(190, ge=1, le=190)
seed: int = Field(0, ge=0, le=4294967294)
steps: int = Field(8, ge=4, le=8)
output_format: str = Field("wav")
class StabilityAudioToAudioRequest(StabilityTextToAudioRequest):
strength: float = Field(0.01, ge=0.01, le=1.0)
class StabilityAudioInpaintRequest(StabilityTextToAudioRequest):
mask_start: int = Field(30, ge=0, le=190)
mask_end: int = Field(190, ge=0, le=190)
class StabilityAudioResponse(BaseModel):
audio: Optional[str] = Field(None)

File diff suppressed because it is too large Load Diff

View File

@ -255,6 +255,7 @@ class IdeogramV1(comfy_io.ComfyNode):
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@ -383,6 +384,7 @@ class IdeogramV2(comfy_io.ComfyNode):
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@ -552,6 +554,7 @@ class IdeogramV3(comfy_io.ComfyNode):
category="api node/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
comfy_io.String.Input(
"prompt",
@ -612,11 +615,21 @@ class IdeogramV3(comfy_io.ComfyNode):
),
comfy_io.Combo.Input(
"rendering_speed",
options=["BALANCED", "TURBO", "QUALITY"],
default="BALANCED",
options=["DEFAULT", "TURBO", "QUALITY"],
default="DEFAULT",
tooltip="Controls the trade-off between generation speed and quality",
optional=True,
),
comfy_io.Image.Input(
"character_image",
tooltip="Image to use as character reference.",
optional=True,
),
comfy_io.Mask.Input(
"character_mask",
tooltip="Optional mask for character reference image.",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
@ -639,12 +652,46 @@ class IdeogramV3(comfy_io.ComfyNode):
magic_prompt_option="AUTO",
seed=0,
num_images=1,
rendering_speed="BALANCED",
rendering_speed="DEFAULT",
character_image=None,
character_mask=None,
):
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
if rendering_speed == "BALANCED": # for backward compatibility
rendering_speed = "DEFAULT"
character_img_binary = None
character_mask_binary = None
if character_image is not None:
input_tensor = character_image.squeeze().cpu()
if character_mask is not None:
character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False)
character_mask = 1.0 - character_mask
if character_mask.shape[1:] != character_image.shape[1:-1]:
raise Exception("Character mask and image must be the same size")
mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
mask_img = Image.fromarray(mask_np)
mask_byte_arr = BytesIO()
mask_img.save(mask_byte_arr, format="PNG")
mask_byte_arr.seek(0)
character_mask_binary = mask_byte_arr
character_mask_binary.name = "mask.png"
img_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_np)
img_byte_arr = BytesIO()
img.save(img_byte_arr, format="PNG")
img_byte_arr.seek(0)
character_img_binary = img_byte_arr
character_img_binary.name = "image.png"
elif character_mask is not None:
raise Exception("Character mask requires character image to be present")
# Check if both image and mask are provided for editing mode
if image is not None and mask is not None:
# Edit mode
@ -693,6 +740,15 @@ class IdeogramV3(comfy_io.ComfyNode):
if num_images > 1:
edit_request.num_images = num_images
files = {
"image": img_binary,
"mask": mask_binary,
}
if character_img_binary:
files["character_reference_images"] = character_img_binary
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
# Execute the operation for edit mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -702,10 +758,7 @@ class IdeogramV3(comfy_io.ComfyNode):
response_model=IdeogramGenerateResponse,
),
request=edit_request,
files={
"image": img_binary,
"mask": mask_binary,
},
files=files,
content_type="multipart/form-data",
auth_kwargs=auth,
)
@ -739,6 +792,14 @@ class IdeogramV3(comfy_io.ComfyNode):
if num_images > 1:
gen_request.num_images = num_images
files = {}
if character_img_binary:
files["character_reference_images"] = character_img_binary
if character_mask_binary:
files["character_mask_binary"] = character_mask_binary
if files:
gen_request.style_type = "AUTO"
# Execute the operation for generation mode
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -748,6 +809,8 @@ class IdeogramV3(comfy_io.ComfyNode):
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None,
content_type="multipart/form-data",
auth_kwargs=auth,
)

View File

@ -12,6 +12,7 @@ User Guides:
"""
from typing import Union, Optional, Any
from typing_extensions import override
from enum import Enum
import torch
@ -46,9 +47,9 @@ from comfy_api_nodes.apinode_utils import (
validate_string,
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -85,20 +86,11 @@ class RunwayGen3aAspectRatio(str, Enum):
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
return None
# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
"""
Validate the input image is within the size limits for the Runway API.
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
"""
return image.shape[2] < 8000 and image.shape[1] < 8000
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
@ -134,458 +126,438 @@ def extract_progress_from_task_status(
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
return None
class RunwayVideoGenNode(ComfyNodeABC):
"""Runway Video Node Base."""
async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=estimated_duration,
node_id=node_id,
)
RETURN_TYPES = ("VIDEO",)
FUNCTION = "api_call"
CATEGORY = "api node/video/Runway"
API_NODE = True
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
async def generate_video(
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no video data found in response."
)
return True
initial_response = await initial_operation.execute()
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no video data found in response.")
video_url = get_video_url_from_task_status(final_response)
return await download_url_to_video_output(video_url)
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
@classmethod
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
comfy_io.Combo.Input(
"ratio",
options=[model.value for model in RunwayGen3aAspectRatio],
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
@classmethod
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = await self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (await download_url_to_video_output(video_url),)
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
)
)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen3a Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
comfy_io.Combo.Input(
"ratio",
enum_type=RunwayGen3aAspectRatio,
options=[model.value for model in RunwayGen4TurboAspectRatio],
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Upload image
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
),
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen4 Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
"This would give the generation more time to smoothly transition between the two inputs. "
"Before diving in, review these best practices to ensure that your input selections "
"will set your generation up for success: "
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
comfy_io.Image.Input(
"start_frame",
tooltip="Start frame to be used for the video",
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
comfy_io.Image.Input(
"end_frame",
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
comfy_io.Combo.Input(
"duration",
options=[model.value for model in Duration],
),
comfy_io.Combo.Input(
"ratio",
enum_type=RunwayGen4TurboAspectRatio,
options=[model.value for model in RunwayGen3aAspectRatio],
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=4294967295,
step=1,
control_after_generate=True,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed for generation",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
async def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = await upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
"""Runway First-Last Frame Node."""
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
],
outputs=[
comfy_io.Video.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"end_frame": (
IO.IMAGE,
{
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"unique_id": "UNIQUE_ID",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(
self,
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_input_image(end_frame)
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = await upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return await self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
return comfy_io.NodeOutput(
await generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
),
),
auth_kwargs=kwargs,
node_id=unique_id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
)
)
class RunwayTextToImageNode(ComfyNodeABC):
"""Runway Text to Image Node."""
RETURN_TYPES = ("IMAGE",)
FUNCTION = "api_call"
CATEGORY = "api node/image/Runway"
API_NODE = True
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
class RunwayTextToImageNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
def define_schema(cls):
return comfy_io.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="api node/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the generation",
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayTextToImageRequest,
comfy_io.Combo.Input(
"ratio",
enum_type=RunwayTextToImageAspectRatioEnum,
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
),
},
"optional": {
"reference_image": (
IO.IMAGE,
{"tooltip": "Optional reference image to guide the generation"},
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: TaskStatusResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no image data found in response."
)
return True
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
node_id=node_id,
comfy_io.Image.Input(
"reference_image",
tooltip="Optional reference image to guide the generation",
optional=True,
),
],
outputs=[
comfy_io.Image.Output(),
],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def api_call(
self,
@classmethod
async def execute(
cls,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# Validate inputs
) -> comfy_io.NodeOutput:
validate_string(prompt, min_length=1)
auth_kwargs = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
download_urls = await upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload reference image to comfy api.")
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
# Create request
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
@ -593,7 +565,6 @@ class RunwayTextToImageNode(ComfyNodeABC):
referenceImages=reference_images,
)
# Execute initial request
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
@ -602,34 +573,33 @@ class RunwayTextToImageNode(ComfyNodeABC):
response_model=RunwayTextToImageResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth_kwargs,
)
initial_response = await initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
initial_response.id,
auth_kwargs=auth_kwargs,
node_id=cls.hidden.unique_id,
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
)
self.validate_response(final_response)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no image data found in response.")
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (await download_url_to_image_tensor(image_url),)
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
NODE_CLASS_MAPPINGS = {
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
"RunwayTextToImageNode": RunwayTextToImageNode,
}
class RunwayExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
RunwayFirstLastFrameNode,
RunwayImageToVideoNodeGen3a,
RunwayImageToVideoNodeGen4,
RunwayTextToImageNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
"RunwayTextToImageNode": "Runway Text to Image",
}
async def comfy_entrypoint() -> RunwayExtension:
return RunwayExtension()

File diff suppressed because it is too large Load Diff

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness(
def validate_video_dimensions(
video: VideoInput,
video: Input.Video,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
@ -126,7 +126,7 @@ def validate_video_dimensions(
def validate_video_duration(
video: VideoInput,
video: Input.Video,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
):
@ -151,3 +151,17 @@ def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
return len(images)
def validate_audio_duration(
audio: Input.Audio,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
eps = 1.0 / sr
if min_duration is not None and dur + eps < min_duration:
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
if max_duration is not None and dur - eps > max_duration:
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")

View File

@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler):
}
# Send a combined progress_state message with all node states
# Include client_id to ensure message is only sent to the initiating client
self.server_instance.send_sync(
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
"progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id
)
@override

View File

@ -1,6 +1,10 @@
#from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html
import numpy as np
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
def loglinear_interp(t_steps, num_steps):
"""
@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152
"SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582],
"SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]}
class AlignYourStepsScheduler:
class AlignYourStepsScheduler(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required":
{"model_type": (["SD1", "SDXL", "SVD"], ),
"steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AlignYourStepsScheduler",
category="sampling/custom_sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
io.Int.Input("steps", default=10, min=1, max=10000),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
],
outputs=[io.Sigmas.Output()],
)
def get_sigmas(self, model_type, steps, denoise):
# Deprecated: use the V3 schema's `execute` method instead of this.
return AlignYourStepsScheduler().execute(model_type, steps, denoise).result
@classmethod
def execute(cls, model_type, steps, denoise) -> io.NodeOutput:
total_steps = steps
if denoise < 1.0:
if denoise <= 0.0:
return (torch.FloatTensor([]),)
return io.NodeOutput(torch.FloatTensor([]))
total_steps = round(steps * denoise)
sigmas = NOISE_LEVELS[model_type][:]
@ -46,8 +55,15 @@ class AlignYourStepsScheduler:
sigmas = sigmas[-(total_steps + 1):]
sigmas[-1] = 0
return (torch.FloatTensor(sigmas), )
return io.NodeOutput(torch.FloatTensor(sigmas))
NODE_CLASS_MAPPINGS = {
"AlignYourStepsScheduler": AlignYourStepsScheduler,
}
class AlignYourStepsExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
AlignYourStepsScheduler,
]
async def comfy_entrypoint() -> AlignYourStepsExtension:
return AlignYourStepsExtension()

View File

@ -162,7 +162,12 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
total_steps = len(args[3])-1
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
# catch division by zero for log statement; sucks to crash after all sampling is done
try:
speedup = total_steps/(total_steps-easycache.total_steps_skipped)
except ZeroDivisionError:
speedup = 1.0
logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
easycache.reset()
guider.model_options = orig_model_options

View File

@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"reference_latents_method": (("offset", "index"), ),
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
}}
RETURN_TYPES = ("CONDITIONING",)
@ -115,6 +115,8 @@ class FluxKontextMultiReferenceLatentMethod:
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, reference_latents_method):
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
reference_latents_method = "uxo"
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
return (c, )

View File

@ -113,6 +113,20 @@ class HunyuanImageToVideo:
out_latent["samples"] = latent
return (positive, out_latent)
class EmptyHunyuanImageLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent"
def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
NODE_CLASS_MAPPINGS = {
@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = {
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
}

View File

@ -8,13 +8,16 @@ import folder_paths
import comfy.model_management
from comfy.cli_args import args
class EmptyLatentHunyuan3Dv2:
@classmethod
def INPUT_TYPES(s):
return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
return {
"required": {
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2:
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
return ({"samples": latent, "type": "hunyuan3dv2"}, )
class Hunyuan3Dv2Conditioning:
@classmethod
def INPUT_TYPES(s):
@ -81,7 +83,6 @@ class VOXEL:
def __init__(self, data):
self.data = data
class VAEDecodeHunyuan3D:
@classmethod
def INPUT_TYPES(s):
@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D:
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
return (voxels, )
def voxel_to_mesh(voxels, threshold=0.5, device=None):
if device is None:
device = torch.device("cpu")
@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
[0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
], device=device)
corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
for c, (dz, dy, dx) in enumerate(corner_offsets):
corner_values[:, c] = padded[
cell_positions[:, 0] + dz,
cell_positions[:, 1] + dy,
cell_positions[:, 2] + dx
]
pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0)
z_idx, y_idx, x_idx = pos.unbind(-1)
corner_values = padded[z_idx, y_idx, x_idx]
corner_signs = corner_values > threshold
has_inside = torch.any(corner_signs, dim=1)

View File

@ -625,6 +625,37 @@ class ImageFlip:
return (image,)
class ImageScaleToMaxDimension:
upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"upscale_method": (s.upscale_methods,),
"largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "upscale"
CATEGORY = "image/upscaling"
def upscale(self, image, upscale_method, largest_size):
height = image.shape[1]
width = image.shape[2]
if height > width:
width = round((width / height) * largest_size)
height = largest_size
elif width > height:
height = round((height / width) * largest_size)
width = largest_size
else:
height = largest_size
width = largest_size
samples = image.movedim(-1, 1)
s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
s = s.movedim(1, -1)
return (s,)
NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = {
"GetImageSize": GetImageSize,
"ImageRotate": ImageRotate,
"ImageFlip": ImageFlip,
"ImageScaleToMaxDimension": ImageScaleToMaxDimension,
}

View File

@ -1,4 +1,5 @@
import torch
from torch import nn
import folder_paths
import comfy.utils
import comfy.ops
@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
class SigLIPMultiFeatProjModel(torch.nn.Module):
"""
SigLIP Multi-Feature Projection Model for processing style features from different layers
and projecting them into a unified hidden space.
Args:
siglip_token_nums (int): Number of SigLIP tokens, default 257
style_token_nums (int): Number of style tokens, default 256
siglip_token_dims (int): Dimension of SigLIP tokens, default 1536
hidden_size (int): Hidden layer size, default 3072
context_layer_norm (bool): Whether to use context layer normalization, default False
"""
def __init__(
self,
siglip_token_nums: int = 729,
style_token_nums: int = 64,
siglip_token_dims: int = 1152,
hidden_size: int = 3072,
context_layer_norm: bool = True,
device=None, dtype=None, operations=None
):
super().__init__()
# High-level feature processing (layer -2)
self.high_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.high_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Mid-level feature processing (layer -11)
self.mid_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.mid_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
# Low-level feature processing (layer -20)
self.low_embedding_linear = nn.Sequential(
operations.Linear(siglip_token_nums, style_token_nums),
nn.SiLU()
)
self.low_layer_norm = (
operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity()
)
self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True)
def forward(self, siglip_outputs):
"""
Forward pass function
Args:
siglip_outputs: Output from SigLIP model, containing hidden_states
Returns:
torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size]
"""
dtype = next(self.high_embedding_linear.parameters()).dtype
# Process high-level features (layer -2)
high_embedding = self._process_layer_features(
siglip_outputs[2],
self.high_embedding_linear,
self.high_layer_norm,
self.high_projection,
dtype
)
# Process mid-level features (layer -11)
mid_embedding = self._process_layer_features(
siglip_outputs[1],
self.mid_embedding_linear,
self.mid_layer_norm,
self.mid_projection,
dtype
)
# Process low-level features (layer -20)
low_embedding = self._process_layer_features(
siglip_outputs[0],
self.low_embedding_linear,
self.low_layer_norm,
self.low_projection,
dtype
)
# Concatenate features from all layersmodel_patch
return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1)
def _process_layer_features(
self,
hidden_states: torch.Tensor,
embedding_linear: nn.Module,
layer_norm: nn.Module,
projection: nn.Module,
dtype: torch.dtype
) -> torch.Tensor:
"""
Helper function to process features from a single layer
Args:
hidden_states: Input hidden states [bs, seq_len, dim]
embedding_linear: Embedding linear layer
layer_norm: Layer normalization
projection: Projection layer
dtype: Target data type
Returns:
torch.Tensor: Processed features [bs, style_token_nums, hidden_size]
"""
# Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim]
embedding = embedding_linear(
hidden_states.to(dtype).transpose(1, 2)
).transpose(1, 2)
# Apply layer normalization
embedding = layer_norm(embedding)
# Project to target hidden space
embedding = projection(embedding)
return embedding
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@ -73,9 +204,14 @@ class ModelPatchLoader:
model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name)
sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True)
dtype = comfy.utils.weight_dtype(sd)
# TODO: this node will work with more types of model patches
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
if 'controlnet_blocks.0.y_rms.weight' in sd:
additional_in_dim = sd["img_in.weight"].shape[1] - 64
model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
return (model,)
@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet:
return (model_patched,)
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
self.model_patch = model_patch
self.encoded_image = encoded_image
def __call__(self, kwargs):
txt_ids = kwargs.get("txt_ids")
txt = kwargs.get("txt")
siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype)
txt = torch.cat([siglip_embedding, txt], dim=1)
kwargs['txt'] = txt
kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1)
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
class USOStyleReference:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_patch"
EXPERIMENTAL = True
CATEGORY = "advanced/model_patches/flux"
def apply_patch(self, model, model_patch, clip_vision_output):
encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states))
model_patched = model.clone()
model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image))
return (model_patched,)
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@ -1,98 +1,109 @@
# Primitive nodes that are evaluated at backend.
from __future__ import annotations
import sys
from typing_extensions import override
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO
from comfy_api.latest import ComfyExtension, io
class String(ComfyNodeABC):
class String(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveString",
display_name="String",
category="utils/primitive",
inputs=[
io.String.Input("value"),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
class StringMultiline(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.STRING, {"multiline": True,},)},
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: str) -> tuple[str]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Int(ComfyNodeABC):
class StringMultiline(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveStringMultiline",
display_name="String (Multiline)",
category="utils/primitive",
inputs=[
io.String.Input("value", multiline=True),
],
outputs=[io.String.Output()],
)
RETURN_TYPES = (IO.INT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: int) -> tuple[int]:
return (value,)
class Float(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})},
}
RETURN_TYPES = (IO.FLOAT,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: float) -> tuple[float]:
return (value,)
def execute(cls, value: str) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(ComfyNodeABC):
class Int(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {"value": (IO.BOOLEAN, {})},
}
def define_schema(cls):
return io.Schema(
node_id="PrimitiveInt",
display_name="Int",
category="utils/primitive",
inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True),
],
outputs=[io.Int.Output()],
)
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/primitive"
def execute(self, value: bool) -> tuple[bool]:
return (value,)
@classmethod
def execute(cls, value: int) -> io.NodeOutput:
return io.NodeOutput(value)
NODE_CLASS_MAPPINGS = {
"PrimitiveString": String,
"PrimitiveStringMultiline": StringMultiline,
"PrimitiveInt": Int,
"PrimitiveFloat": Float,
"PrimitiveBoolean": Boolean,
}
class Float(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveFloat",
display_name="Float",
category="utils/primitive",
inputs=[
io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize),
],
outputs=[io.Float.Output()],
)
NODE_DISPLAY_NAME_MAPPINGS = {
"PrimitiveString": "String",
"PrimitiveStringMultiline": "String (Multiline)",
"PrimitiveInt": "Int",
"PrimitiveFloat": "Float",
"PrimitiveBoolean": "Boolean",
}
@classmethod
def execute(cls, value: float) -> io.NodeOutput:
return io.NodeOutput(value)
class Boolean(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="PrimitiveBoolean",
display_name="Boolean",
category="utils/primitive",
inputs=[
io.Boolean.Input("value"),
],
outputs=[io.Boolean.Output()],
)
@classmethod
def execute(cls, value: bool) -> io.NodeOutput:
return io.NodeOutput(value)
class PrimitivesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
String,
StringMultiline,
Int,
Float,
Boolean,
]
async def comfy_entrypoint() -> PrimitivesExtension:
return PrimitivesExtension()

View File

@ -17,55 +17,61 @@
"""
import torch
import nodes
from typing_extensions import override
import comfy.utils
import nodes
from comfy_api.latest import ComfyExtension, io
class StableCascade_EmptyLatentImage:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_EmptyLatentImage(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_EmptyLatentImage",
category="latent/stable_cascade",
inputs=[
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, width, height, compression, batch_size=1):
def execute(cls, width, height, compression, batch_size=1):
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageC_VAEEncode:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_StageC_VAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageC_VAEEncode",
category="latent/stable_cascade",
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
io.Int.Input("compression", default=42, min=4, max=128, step=1),
],
outputs=[
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
}}
RETURN_TYPES = ("LATENT", "LATENT")
RETURN_NAMES = ("stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "latent/stable_cascade"
def generate(self, image, vae, compression):
def execute(cls, image, vae, compression):
width = image.shape[-2]
height = image.shape[-3]
out_width = (width // compression) * vae.downscale_ratio
@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode:
c_latent = vae.encode(s[:,:,:,:3])
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
return ({
return io.NodeOutput({
"samples": c_latent,
}, {
"samples": b_latent,
})
class StableCascade_StageB_Conditioning:
class StableCascade_StageB_Conditioning(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "conditioning": ("CONDITIONING",),
"stage_c": ("LATENT",),
}}
RETURN_TYPES = ("CONDITIONING",)
def define_schema(cls):
return io.Schema(
node_id="StableCascade_StageB_Conditioning",
category="conditioning/stable_cascade",
inputs=[
io.Conditioning.Input("conditioning"),
io.Latent.Input("stage_c"),
],
outputs=[
io.Conditioning.Output(),
],
)
FUNCTION = "set_prior"
CATEGORY = "conditioning/stable_cascade"
def set_prior(self, conditioning, stage_c):
@classmethod
def execute(cls, conditioning, stage_c):
c = []
for t in conditioning:
d = t[1].copy()
d['stable_cascade_prior'] = stage_c['samples']
d["stable_cascade_prior"] = stage_c["samples"]
n = [t[0], d]
c.append(n)
return (c, )
return io.NodeOutput(c)
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),
io.Vae.Input("vae"),
],
outputs=[
io.Image.Output(display_name="controlnet_input"),
io.Latent.Output(display_name="stage_c"),
io.Latent.Output(display_name="stage_b"),
],
)
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
EXPERIMENTAL = True
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
def execute(cls, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet:
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
return io.NodeOutput(controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})
NODE_CLASS_MAPPINGS = {
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
}
class StableCascadeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
StableCascade_EmptyLatentImage,
StableCascade_StageB_Conditioning,
StableCascade_StageC_VAEEncode,
StableCascade_SuperResolutionControlnet,
]
async def comfy_entrypoint() -> StableCascadeExtension:
return StableCascadeExtension()

View File

@ -5,52 +5,49 @@ import av
import torch
import folder_paths
import json
from typing import Optional, Literal
from typing import Optional
from typing_extensions import override
from fractions import Fraction
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
from comfy_api.latest import Input, InputImpl, Types
from comfy_api.input import AudioInput, ImageInput, VideoInput
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
from comfy_api.latest import ComfyExtension, io, ui
from comfy.cli_args import args
class SaveWEBM:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
class SaveWEBM(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveWEBM",
category="image/video",
is_experimental=True,
inputs=[
io.Image.Input("images"),
io.String.Input("filename_prefix", default="ComfyUI"),
io.Combo.Input("codec", options=["vp9", "av1"]),
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"codec": (["vp9", "av1"],),
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_images"
OUTPUT_NODE = True
CATEGORY = "image/video"
EXPERIMENTAL = True
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
)
file = f"{filename}_{counter:05}_.webm"
container = av.open(os.path.join(full_output_folder, file), mode="w")
if prompt is not None:
container.metadata["prompt"] = json.dumps(prompt)
if cls.hidden.prompt is not None:
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
container.metadata[x] = json.dumps(extra_pnginfo[x])
if cls.hidden.extra_pnginfo is not None:
for x in cls.hidden.extra_pnginfo:
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
@ -69,63 +66,46 @@ class SaveWEBM:
container.mux(stream.encode())
container.close()
results: list[FileLocator] = [{
"filename": file,
"subfolder": subfolder,
"type": self.type
}]
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
class SaveVideo(ComfyNodeABC):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type: Literal["output"] = "output"
self.prefix_append = ""
class SaveVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveVideo",
display_name="Save Video",
category="image/video",
description="Saves the input images to your ComfyUI output directory.",
inputs=[
io.Video.Input("video", tooltip="The video to save."),
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
],
outputs=[],
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
},
"hidden": {
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO"
},
}
RETURN_TYPES = ()
FUNCTION = "save_video"
OUTPUT_NODE = True
CATEGORY = "image/video"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
self.output_dir,
folder_paths.get_output_directory(),
width,
height
)
results: list[FileLocator] = list()
saved_metadata = None
if not args.disable_metadata:
metadata = {}
if extra_pnginfo is not None:
metadata.update(extra_pnginfo)
if prompt is not None:
metadata["prompt"] = prompt
if cls.hidden.extra_pnginfo is not None:
metadata.update(cls.hidden.extra_pnginfo)
if cls.hidden.prompt is not None:
metadata["prompt"] = cls.hidden.prompt
if len(metadata) > 0:
saved_metadata = metadata
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
format=format,
@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC):
metadata=saved_metadata
)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
return { "ui": { "images": results, "animated": (True,) } }
class CreateVideo(ComfyNodeABC):
class CreateVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
},
"optional": {
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
}
}
def define_schema(cls):
return io.Schema(
node_id="CreateVideo",
display_name="Create Video",
category="image/video",
description="Create a video from images.",
inputs=[
io.Image.Input("images", tooltip="The images to create a video from."),
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
],
outputs=[
io.Video.Output(),
],
)
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "create_video"
CATEGORY = "image/video"
DESCRIPTION = "Create a video from images."
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
return (InputImpl.VideoFromComponents(
Types.VideoComponents(
images=images,
audio=audio,
frame_rate=Fraction(fps),
)
),)
class GetVideoComponents(ComfyNodeABC):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
}
}
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
RETURN_NAMES = ("images", "audio", "fps")
FUNCTION = "get_components"
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
return io.NodeOutput(
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
)
CATEGORY = "image/video"
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
class GetVideoComponents(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GetVideoComponents",
display_name="Get Video Components",
category="image/video",
description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[
io.Video.Input("video", tooltip="The video to extract components from."),
],
outputs=[
io.Image.Output(display_name="images"),
io.Audio.Output(display_name="audio"),
io.Float.Output(display_name="fps"),
],
)
def get_components(self, video: Input.Video):
@classmethod
def execute(cls, video: VideoInput) -> io.NodeOutput:
components = video.get_components()
return (components.images, components.audio, float(components.frame_rate))
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
class LoadVideo(ComfyNodeABC):
class LoadVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
files = folder_paths.filter_files_content_types(files, ["video"])
return {"required":
{"file": (sorted(files), {"video_upload": True})},
}
CATEGORY = "image/video"
RETURN_TYPES = (IO.VIDEO,)
FUNCTION = "load_video"
def load_video(self, file):
video_path = folder_paths.get_annotated_filepath(file)
return (InputImpl.VideoFromFile(video_path),)
return io.Schema(
node_id="LoadVideo",
display_name="Load Video",
category="image/video",
inputs=[
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
],
outputs=[
io.Video.Output(),
],
)
@classmethod
def IS_CHANGED(cls, file):
def execute(cls, file) -> io.NodeOutput:
video_path = folder_paths.get_annotated_filepath(file)
return io.NodeOutput(VideoFromFile(video_path))
@classmethod
def fingerprint_inputs(s, file):
video_path = folder_paths.get_annotated_filepath(file)
mod_time = os.path.getmtime(video_path)
# Instead of hashing the file, we can just use the modification time to avoid
@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC):
return mod_time
@classmethod
def VALIDATE_INPUTS(cls, file):
def validate_inputs(s, file):
if not folder_paths.exists_annotated_filepath(file):
return "Invalid video file: {}".format(file)
return True
NODE_CLASS_MAPPINGS = {
"SaveWEBM": SaveWEBM,
"SaveVideo": SaveVideo,
"CreateVideo": CreateVideo,
"GetVideoComponents": GetVideoComponents,
"LoadVideo": LoadVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveVideo": "Save Video",
"CreateVideo": "Create Video",
"GetVideoComponents": "Get Video Components",
"LoadVideo": "Load Video",
}
class VideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SaveWEBM,
SaveVideo,
CreateVideo,
GetVideoComponents,
LoadVideo,
]
async def comfy_entrypoint() -> VideoExtension:
return VideoExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.55"
__version__ = "0.3.59"

View File

@ -113,7 +113,6 @@ import gc
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
if __name__ == "__main__":
if args.default_device is not None:

1
middleware/__init__.py Normal file
View File

@ -0,0 +1 @@
"""Server middleware modules"""

View File

@ -0,0 +1,52 @@
"""Cache control middleware for ComfyUI server"""
from aiohttp import web
from typing import Callable, Awaitable
# Time in seconds
ONE_HOUR: int = 3600
ONE_DAY: int = 86400
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
)
@web.middleware
async def cache_control(
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> web.Response:
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
response: web.Response = await handler(request)
if (
request.path.endswith(".js")
or request.path.endswith(".css")
or request.path.endswith("index.json")
):
response.headers.setdefault("Cache-Control", "no-cache")
return response
# Early return for non-image files - no cache headers needed
if not request.path.lower().endswith(IMG_EXTENSIONS):
return response
# Handle image files
if response.status == 404:
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
# Success responses and permanent redirects - cache for 1 day
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
elif response.status in (302, 303, 307):
# Temporary redirects - no cache
response.headers.setdefault("Cache-Control", "no-cache")
# Note: 304 Not Modified falls through - no cache headers set
return response

View File

@ -925,7 +925,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -953,7 +953,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ),
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@ -963,7 +963,7 @@ class DualCLIPLoader:
CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama"
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small"
def load_clip(self, clip_name1, clip_name2, type, device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@ -2345,6 +2345,7 @@ async def init_builtin_api_nodes():
"nodes_veo2.py",
"nodes_kling.py",
"nodes_bfl.py",
"nodes_bytedance.py",
"nodes_luma.py",
"nodes_recraft.py",
"nodes_pixverse.py",

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.55"
version = "0.3.59"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.25.11
comfyui-workflow-templates==0.1.70
comfyui-workflow-templates==0.1.81
comfyui-embedded-docs==0.2.6
torch
torchsde

View File

@ -3,11 +3,7 @@ from urllib import request
#This is the ComfyUI api prompt format.
#If you want it for a specific workflow you can "enable dev mode options"
#in the settings of the UI (gear beside the "Queue Size: ") this will enable
#a button on the UI to save workflows in api format.
#keep in mind ComfyUI is pre alpha software so this format will change a bit.
#If you want it for a specific workflow you can "File -> Export (API)" in the interface.
#this is the one for the default workflow
prompt_text = """

View File

@ -39,20 +39,15 @@ from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
# Import cache control middleware
from middleware.cache_middleware import cache_control
async def send_socket_catch_exception(function, message):
try:
await function(message)
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
@web.middleware
async def cache_control(request: web.Request, handler):
response: web.Response = await handler(request)
if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'):
response.headers.setdefault('Cache-Control', 'no-cache')
return response
@web.middleware
async def compress_body(request: web.Request, handler):
accept_encoding = request.headers.get("Accept-Encoding", "")
@ -729,7 +724,34 @@ class PromptServer():
@routes.post("/interrupt")
async def post_interrupt(request):
nodes.interrupt_processing()
try:
json_data = await request.json()
except json.JSONDecodeError:
json_data = {}
# Check if a specific prompt_id was provided for targeted interruption
prompt_id = json_data.get('prompt_id')
if prompt_id:
currently_running, _ = self.prompt_queue.get_current_queue()
# Check if the prompt_id matches any currently running prompt
should_interrupt = False
for item in currently_running:
# item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute)
if item[1] == prompt_id:
logging.info(f"Interrupting prompt {prompt_id}")
should_interrupt = True
break
if should_interrupt:
nodes.interrupt_processing()
else:
logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt")
else:
# No prompt_id provided, do a global interrupt
logging.info("Global interrupt (no prompt_id specified)")
nodes.interrupt_processing()
return web.Response(status=200)
@routes.post("/free")

View File

@ -0,0 +1,255 @@
"""Tests for server cache control middleware"""
import pytest
from aiohttp import web
from aiohttp.test_utils import make_mocked_request
from typing import Dict, Any
from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS
pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests
# Test configuration data
CACHE_SCENARIOS = [
# Image file scenarios
{
"name": "image_200_status",
"path": "/test.jpg",
"status": 200,
"expected_cache": f"public, max-age={ONE_DAY}",
"should_have_header": True,
},
{
"name": "image_404_status",
"path": "/missing.jpg",
"status": 404,
"expected_cache": f"public, max-age={ONE_HOUR}",
"should_have_header": True,
},
# JavaScript/CSS scenarios
{
"name": "js_no_cache",
"path": "/script.js",
"status": 200,
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "css_no_cache",
"path": "/styles.css",
"status": 200,
"expected_cache": "no-cache",
"should_have_header": True,
},
{
"name": "index_json_no_cache",
"path": "/api/index.json",
"status": 200,
"expected_cache": "no-cache",
"should_have_header": True,
},
# Non-matching files
{
"name": "html_no_header",
"path": "/index.html",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
{
"name": "txt_no_header",
"path": "/data.txt",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
{
"name": "api_endpoint_no_header",
"path": "/api/endpoint",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
{
"name": "pdf_no_header",
"path": "/file.pdf",
"status": 200,
"expected_cache": None,
"should_have_header": False,
},
]
# Status code scenarios for images
IMAGE_STATUS_SCENARIOS = [
# Success statuses get long cache
{"status": 200, "expected": f"public, max-age={ONE_DAY}"},
{"status": 201, "expected": f"public, max-age={ONE_DAY}"},
{"status": 202, "expected": f"public, max-age={ONE_DAY}"},
{"status": 204, "expected": f"public, max-age={ONE_DAY}"},
{"status": 206, "expected": f"public, max-age={ONE_DAY}"},
# Permanent redirects get long cache
{"status": 301, "expected": f"public, max-age={ONE_DAY}"},
{"status": 308, "expected": f"public, max-age={ONE_DAY}"},
# Temporary redirects get no cache
{"status": 302, "expected": "no-cache"},
{"status": 303, "expected": "no-cache"},
{"status": 307, "expected": "no-cache"},
# 404 gets short cache
{"status": 404, "expected": f"public, max-age={ONE_HOUR}"},
]
# Case sensitivity test paths
CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"]
# Edge case test paths
EDGE_CASE_PATHS = [
{
"name": "query_strings_ignored",
"path": "/image.jpg?v=123&size=large",
"expected": f"public, max-age={ONE_DAY}",
},
{
"name": "multiple_dots_in_path",
"path": "/image.min.jpg",
"expected": f"public, max-age={ONE_DAY}",
},
{
"name": "nested_paths_with_images",
"path": "/static/images/photo.jpg",
"expected": f"public, max-age={ONE_DAY}",
},
]
class TestCacheControl:
"""Test cache control middleware functionality"""
@pytest.fixture
def status_handler_factory(self):
"""Create a factory for handlers that return specific status codes"""
def factory(status: int, headers: Dict[str, str] = None):
async def handler(request):
return web.Response(status=status, headers=headers or {})
return handler
return factory
@pytest.fixture
def mock_handler(self, status_handler_factory):
"""Create a mock handler that returns a response with 200 status"""
return status_handler_factory(200)
@pytest.fixture
def handler_with_existing_cache(self, status_handler_factory):
"""Create a handler that returns response with existing Cache-Control header"""
return status_handler_factory(200, {"Cache-Control": "max-age=3600"})
async def assert_cache_header(
self,
response: web.Response,
expected_cache: str = None,
should_have_header: bool = True,
):
"""Helper to assert cache control headers"""
if should_have_header:
assert "Cache-Control" in response.headers
if expected_cache:
assert response.headers["Cache-Control"] == expected_cache
else:
assert "Cache-Control" not in response.headers
# Parameterized tests
@pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"])
async def test_cache_control_scenarios(
self, scenario: Dict[str, Any], status_handler_factory
):
"""Test various cache control scenarios"""
handler = status_handler_factory(scenario["status"])
request = make_mocked_request("GET", scenario["path"])
response = await cache_control(request, handler)
assert response.status == scenario["status"]
await self.assert_cache_header(
response, scenario["expected_cache"], scenario["should_have_header"]
)
@pytest.mark.parametrize("ext", IMG_EXTENSIONS)
async def test_all_image_extensions(self, ext: str, mock_handler):
"""Test all defined image extensions are handled correctly"""
request = make_mocked_request("GET", f"/image{ext}")
response = await cache_control(request, mock_handler)
assert response.status == 200
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
@pytest.mark.parametrize(
"status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}"
)
async def test_image_status_codes(
self, status_scenario: Dict[str, Any], status_handler_factory
):
"""Test different status codes for image requests"""
handler = status_handler_factory(status_scenario["status"])
request = make_mocked_request("GET", "/image.jpg")
response = await cache_control(request, handler)
assert response.status == status_scenario["status"]
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == status_scenario["expected"]
@pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS)
async def test_case_insensitive_image_extension(self, path: str, mock_handler):
"""Test that image extensions are matched case-insensitively"""
request = make_mocked_request("GET", path)
response = await cache_control(request, mock_handler)
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}"
@pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"])
async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler):
"""Test edge cases like query strings, nested paths, etc."""
request = make_mocked_request("GET", edge_case["path"])
response = await cache_control(request, mock_handler)
assert "Cache-Control" in response.headers
assert response.headers["Cache-Control"] == edge_case["expected"]
# Header preservation tests (special cases not covered by parameterization)
async def test_js_preserves_existing_headers(self, handler_with_existing_cache):
"""Test that .js files preserve existing Cache-Control headers"""
request = make_mocked_request("GET", "/script.js")
response = await cache_control(request, handler_with_existing_cache)
# setdefault should preserve existing header
assert response.headers["Cache-Control"] == "max-age=3600"
async def test_css_preserves_existing_headers(self, handler_with_existing_cache):
"""Test that .css files preserve existing Cache-Control headers"""
request = make_mocked_request("GET", "/styles.css")
response = await cache_control(request, handler_with_existing_cache)
# setdefault should preserve existing header
assert response.headers["Cache-Control"] == "max-age=3600"
async def test_image_preserves_existing_headers(self, status_handler_factory):
"""Test that image cache headers preserve existing Cache-Control"""
handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"})
request = make_mocked_request("GET", "/image.jpg")
response = await cache_control(request, handler)
# setdefault should preserve existing header
assert response.headers["Cache-Control"] == "private, no-cache"
async def test_304_not_modified_inherits_cache(self, status_handler_factory):
"""Test that 304 Not Modified doesn't set cache headers for images"""
handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"})
request = make_mocked_request("GET", "/not-modified.jpg")
response = await cache_control(request, handler)
assert response.status == 304
# Should preserve existing cache header, not override
assert response.headers["Cache-Control"] == "max-age=7200"

View File

@ -6,6 +6,7 @@ def pytest_addoption(parser):
parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images')
parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.addoption("--port", type=int, default=8188, help="Set the listen port.")
parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)")
# This initializes args at the beginning of the test session
@pytest.fixture(scope="session", autouse=True)
@ -19,6 +20,11 @@ def args_pytest(pytestconfig):
return args
@pytest.fixture(scope="session")
def skip_timing_checks(pytestconfig):
"""Fixture that returns whether timing checks should be skipped."""
return pytestconfig.getoption("--skip-timing-checks")
def pytest_collection_modifyitems(items):
# Modifies items so tests run in the correct order

View File

@ -7,7 +7,7 @@ import subprocess
from pytest import fixture
from comfy_execution.graph_utils import GraphBuilder
from tests.inference.test_execution import ComfyClient, run_warmup
from tests.execution.test_execution import ComfyClient, run_warmup
@pytest.mark.execution
@ -23,7 +23,7 @@ class TestAsyncNodes:
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
@ -81,7 +81,7 @@ class TestAsyncNodes:
assert len(result_images) == 1, "Should have 1 image"
assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black"
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test that multiple async nodes execute in parallel."""
# Warmup execution to ensure server is fully initialized
run_warmup(client)
@ -104,7 +104,8 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
# Should take ~0.5s (max duration) not 1.2s (sum of durations)
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
if not skip_timing_checks:
assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s"
# Verify all nodes executed
assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3)
@ -150,7 +151,7 @@ class TestAsyncNodes:
with pytest.raises(urllib.error.HTTPError):
client.run(g)
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test async nodes with lazy evaluation."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_lazy")
@ -173,7 +174,8 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
# Should only execute sleep1, not sleep2
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
if not skip_timing_checks:
assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s"
assert result.did_run(sleep1), "Sleep1 should have executed"
assert not result.did_run(sleep2), "Sleep2 should have been skipped"
@ -310,7 +312,7 @@ class TestAsyncNodes:
images = result.get_images(output)
assert len(images) == 1, "Should have blocked second image"
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test that async nodes are properly cached."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_cache")
@ -330,9 +332,10 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
assert not result2.did_run(sleep_node), "Should be cached"
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
if not skip_timing_checks:
assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant"
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
"""Test async nodes within dynamically generated prompts."""
# Warmup execution to ensure server is fully initialized
run_warmup(client, prefix="warmup_dynamic")
@ -345,8 +348,8 @@ class TestAsyncNodes:
dynamic_async = g.node("TestDynamicAsyncGeneration",
image1=image1.out(0),
image2=image2.out(0),
num_async_nodes=3,
sleep_duration=0.2)
num_async_nodes=5,
sleep_duration=0.4)
g.node("SaveImage", images=dynamic_async.out(0))
start_time = time.time()
@ -354,7 +357,8 @@ class TestAsyncNodes:
elapsed_time = time.time() - start_time
# Should execute async nodes in parallel within dynamic prompt
assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s"
if not skip_timing_checks:
assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s"
assert result.did_run(dynamic_async)
def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder):

View File

@ -149,7 +149,7 @@ class TestExecution:
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
use_lru, lru_size = request.param
@ -518,7 +518,7 @@ class TestExecution:
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder):
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
@ -541,14 +541,15 @@ class TestExecution:
# The test should take around 3.0 seconds (the longest sleep duration)
# plus some overhead, but definitely less than the sum of all sleeps (9.0s)
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
if not skip_timing_checks:
assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s"
# Verify that all nodes executed
assert result.did_run(sleep_node1), "Sleep node 1 should have run"
assert result.did_run(sleep_node2), "Sleep node 2 should have run"
assert result.did_run(sleep_node3), "Sleep node 3 should have run"
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder):
def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
# Warmup execution to ensure server is fully initialized
run_warmup(client)
@ -574,7 +575,9 @@ class TestExecution:
# Similar to the previous test, expect parallel execution of the sleep nodes
# which should complete in less than the sum of all sleeps
assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s"
# Lots of leeway here since Windows CI is slow
if not skip_timing_checks:
assert elapsed_time < 13.0, f"Expansion execution took {elapsed_time}s"
# Verify the parallel sleep node executed
assert result.did_run(parallel_sleep), "ParallelSleep node should have run"

View File

@ -0,0 +1,233 @@
"""Test that progress updates are properly isolated between WebSocket clients."""
import json
import pytest
import time
import threading
import uuid
import websocket
from typing import List, Dict, Any
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient
class ProgressTracker:
"""Tracks progress messages received by a WebSocket client."""
def __init__(self, client_id: str):
self.client_id = client_id
self.progress_messages: List[Dict[str, Any]] = []
self.lock = threading.Lock()
def add_message(self, message: Dict[str, Any]):
"""Thread-safe addition of progress messages."""
with self.lock:
self.progress_messages.append(message)
def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
"""Get all progress messages for a specific prompt_id."""
with self.lock:
return [
msg for msg in self.progress_messages
if msg.get('data', {}).get('prompt_id') == prompt_id
]
def has_cross_contamination(self, own_prompt_id: str) -> bool:
"""Check if this client received progress for other prompts."""
with self.lock:
for msg in self.progress_messages:
msg_prompt_id = msg.get('data', {}).get('prompt_id')
if msg_prompt_id and msg_prompt_id != own_prompt_id:
return True
return False
class IsolatedClient(ComfyClient):
"""Extended ComfyClient that tracks all WebSocket messages."""
def __init__(self):
super().__init__()
self.progress_tracker = None
self.all_messages: List[Dict[str, Any]] = []
def connect(self, listen='127.0.0.1', port=8188, client_id=None):
"""Connect with a specific client_id and set up message tracking."""
if client_id is None:
client_id = str(uuid.uuid4())
super().connect(listen, port, client_id)
self.progress_tracker = ProgressTracker(client_id)
def listen_for_messages(self, duration: float = 5.0):
"""Listen for WebSocket messages for a specified duration."""
end_time = time.time() + duration
self.ws.settimeout(0.5) # Non-blocking with timeout
while time.time() < end_time:
try:
out = self.ws.recv()
if isinstance(out, str):
message = json.loads(out)
self.all_messages.append(message)
# Track progress_state messages
if message.get('type') == 'progress_state':
self.progress_tracker.add_message(message)
except websocket.WebSocketTimeoutException:
continue
except Exception:
# Log error silently in test context
break
@pytest.mark.execution
class TestProgressIsolation:
"""Test suite for verifying progress update isolation between clients."""
@pytest.fixture(scope="class", autouse=True)
def _server(self, args_pytest):
"""Start the ComfyUI server for testing."""
import subprocess
pargs = [
'python', 'main.py',
'--output-directory', args_pytest["output_dir"],
'--listen', args_pytest["listen"],
'--port', str(args_pytest["port"]),
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
'--cpu',
]
p = subprocess.Popen(pargs)
yield
p.kill()
def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
"""Start client with connection retries."""
client = IsolatedClient()
# Connect to server (with retries)
n_tries = 5
for i in range(n_tries):
time.sleep(4)
try:
client.connect(listen, port, client_id)
return client
except ConnectionRefusedError as e:
print(e) # noqa: T201
print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts")
def test_progress_isolation_between_clients(self, args_pytest):
"""Test that progress updates are isolated between different clients."""
listen = args_pytest["listen"]
port = args_pytest["port"]
# Create two separate clients with unique IDs
client_a_id = "client_a_" + str(uuid.uuid4())
client_b_id = "client_b_" + str(uuid.uuid4())
try:
# Connect both clients with retries
client_a = self.start_client_with_retry(listen, port, client_a_id)
client_b = self.start_client_with_retry(listen, port, client_b_id)
# Create simple workflows for both clients
graph_a = GraphBuilder(prefix="client_a")
image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
graph_a.node("PreviewImage", images=image_a.out(0))
graph_b = GraphBuilder(prefix="client_b")
image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
graph_b.node("PreviewImage", images=image_b.out(0))
# Submit workflows from both clients
prompt_a = graph_a.finalize()
prompt_b = graph_b.finalize()
response_a = client_a.queue_prompt(prompt_a)
prompt_id_a = response_a['prompt_id']
response_b = client_b.queue_prompt(prompt_b)
prompt_id_b = response_b['prompt_id']
# Start threads to listen for messages on both clients
def listen_client_a():
client_a.listen_for_messages(duration=10.0)
def listen_client_b():
client_b.listen_for_messages(duration=10.0)
thread_a = threading.Thread(target=listen_client_a)
thread_b = threading.Thread(target=listen_client_b)
thread_a.start()
thread_b.start()
# Wait for threads to complete
thread_a.join()
thread_b.join()
# Verify isolation
# Client A should only receive progress for prompt_id_a
assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
f"Client A received progress updates for other clients' workflows. " \
f"Expected only {prompt_id_a}, but got messages for multiple prompts."
# Client B should only receive progress for prompt_id_b
assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
f"Client B received progress updates for other clients' workflows. " \
f"Expected only {prompt_id_b}, but got messages for multiple prompts."
# Verify each client received their own progress updates
client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)
assert len(client_a_messages) > 0, \
"Client A did not receive any progress updates for its own workflow"
assert len(client_b_messages) > 0, \
"Client B did not receive any progress updates for its own workflow"
# Ensure no cross-contamination
client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)
assert len(client_a_other) == 0, \
f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
assert len(client_b_other) == 0, \
f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"
finally:
# Clean up connections
if hasattr(client_a, 'ws'):
client_a.ws.close()
if hasattr(client_b, 'ws'):
client_b.ws.close()
def test_progress_with_missing_client_id(self, args_pytest):
"""Test that progress updates handle missing client_id gracefully."""
listen = args_pytest["listen"]
port = args_pytest["port"]
try:
# Connect client with retries
client = self.start_client_with_retry(listen, port)
# Create a simple workflow
graph = GraphBuilder(prefix="test_missing_id")
image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
graph.node("PreviewImage", images=image.out(0))
# Submit workflow
prompt = graph.finalize()
response = client.queue_prompt(prompt)
prompt_id = response['prompt_id']
# Listen for messages
client.listen_for_messages(duration=5.0)
# Should still receive progress updates for own workflow
messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
assert len(messages) > 0, \
"Client did not receive progress updates even though it initiated the workflow"
finally:
if hasattr(client, 'ws'):
client.ws.close()