diff --git a/.github/workflows/test-execution.yml b/.github/workflows/test-execution.yml
new file mode 100644
index 000000000..00ef07ebf
--- /dev/null
+++ b/.github/workflows/test-execution.yml
@@ -0,0 +1,30 @@
+name: Execution Tests
+
+on:
+  push:
+    branches: [ main, master ]
+  pull_request:
+    branches: [ main, master ]
+
+jobs:
+  test:
+    strategy:
+      matrix:
+        os: [ubuntu-latest, windows-latest, macos-latest]
+    runs-on: ${{ matrix.os }}
+    continue-on-error: true
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python
+      uses: actions/setup-python@v4
+      with:
+        python-version: '3.12'
+    - name: Install requirements
+      run: |
+        python -m pip install --upgrade pip
+        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
+        pip install -r requirements.txt
+        pip install -r tests-unit/requirements.txt
+    - name: Run Execution Tests
+      run: |
+        python -m pytest tests/execution -v --skip-timing-checks
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 638d01119..fa12cd875 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -143,8 +143,9 @@ class PerformanceFeature(enum.Enum):
     Fp16Accumulation = "fp16_accumulation"
     Fp8MatrixMultiplication = "fp8_matrix_mult"
     CublasOps = "cublas_ops"
+    AutoTune = "autotune"
 
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. 
Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7e47d8a55..7c0cadab5 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -61,8 +61,12 @@ class CLIPEncoder(torch.nn.Module): def forward(self, x, mask=None, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) + all_intermediate = None if intermediate_output is not None: - if intermediate_output < 0: + if intermediate_output == "all": + all_intermediate = [] + intermediate_output = None + elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output intermediate = None @@ -70,6 +74,12 @@ class CLIPEncoder(torch.nn.Module): x = l(x, mask, optimized_attention) if i == intermediate_output: intermediate = x.clone() + if all_intermediate is not None: + all_intermediate.append(x.unsqueeze(1).clone()) + + if all_intermediate is not None: + intermediate = torch.cat(all_intermediate, dim=1) + return x, intermediate class CLIPEmbeddings(torch.nn.Module): diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 00aab9164..447b1ce4a 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -50,7 +50,13 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model")) + model_type = config.get("model_type", "clip_vision_model") + model_class = IMAGE_ENCODERS.get(model_type) + if model_type == "siglip_vision_model": + self.return_all_hidden_states = True + else: + self.return_all_hidden_states = False + self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) @@ -68,12 +74,18 @@ class ClipVisionModel(): def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() - out = self.model(pixel_values=pixel_values, intermediate_output=-2) + out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) - outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + if self.return_all_hidden_states: + all_hs = out[1].to(comfy.model_management.intermediate_device()) + outputs["penultimate_hidden_states"] = all_hs[:, -2] + outputs["all_hidden_states"] = all_hs + else: + outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device()) + outputs["mm_projected"] = out[3] return outputs @@ -124,8 +136,12 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
"clip_vision_config_vitl_336.json") else: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json") - elif "embeddings.patch_embeddings.projection.weight" in sd: + + # Dinov2 + elif 'encoder.layer.39.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") + elif 'encoder.layer.23.layer_scale2.lambda1' in sd: + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") else: return None diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 54d83c069..11273e32f 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -288,7 +288,10 @@ class ControlNet(ControlBase): to_concat = [] for c in self.extra_concat_orig: c = c.to(self.cond_hint.device) - c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center") + c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center") + if c.ndim < self.cond_hint.ndim: + c = c.unsqueeze(2) + c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2) to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0])) self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1) @@ -628,11 +631,18 @@ def load_controlnet_flux_instantx(sd, model_options={}): def load_controlnet_qwen_instantx(sd, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) - control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1] + + extra_condition_channels = 0 + concat_mask = False + if control_latent_channels == 68: #inpaint controlnet + extra_condition_channels = control_latent_channels - 64 + concat_mask = True + control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) control_model = controlnet_load_state_dict(control_model, sd) latent_format = comfy.latent_formats.Wan21() extra_conds = [] - control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) + control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control def convert_mistoline(sd): diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index 976f98c65..9b6dace9d 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -31,6 +31,20 @@ class LayerScale(torch.nn.Module): def forward(self, x): return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype) +class Dinov2MLP(torch.nn.Module): + def __init__(self, hidden_size: int, dtype, device, operations): + super().__init__() + + mlp_ratio = 4 + hidden_features = int(hidden_size * mlp_ratio) + self.fc1 = operations.Linear(hidden_size, hidden_features, bias = True, device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, hidden_size, bias = 
True, device=device, dtype=dtype) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = torch.nn.functional.gelu(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state class SwiGLUFFN(torch.nn.Module): def __init__(self, dim, dtype, device, operations): @@ -50,12 +64,15 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): super().__init__() self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) - self.mlp = SwiGLUFFN(dim, dtype, device, operations) + if use_swiglu_ffn: + self.mlp = SwiGLUFFN(dim, dtype, device, operations) + else: + self.mlp = Dinov2MLP(dim, dtype, device, operations) self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) @@ -66,9 +83,10 @@ class Dino2Block(torch.nn.Module): class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)]) + self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + for _ in range(num_layers)]) def forward(self, x, intermediate_output=None): optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) @@ -78,8 +96,8 @@ class Dino2Encoder(torch.nn.Module): intermediate_output = len(self.layer) + intermediate_output intermediate = None - for i, l in enumerate(self.layer): - x = l(x, optimized_attention) + for i, layer in enumerate(self.layer): + x = layer(x, optimized_attention) if i == intermediate_output: intermediate = x.clone() return x, intermediate @@ -128,9 +146,10 @@ class Dinov2Model(torch.nn.Module): dim = config_dict["hidden_size"] heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] + use_swiglu_ffn = config_dict["use_swiglu_ffn"] self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations) + self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) def forward(self, pixel_values, attention_mask=None, intermediate_output=None): diff --git a/comfy/image_encoders/dino2_large.json b/comfy/image_encoders/dino2_large.json new file mode 100644 index 000000000..43fbb58ff --- /dev/null +++ b/comfy/image_encoders/dino2_large.json @@ -0,0 +1,22 @@ +{ + "hidden_size": 1024, + "use_mask_token": true, + "patch_size": 14, + "image_size": 518, + "num_channels": 3, + "num_attention_heads": 16, + "initializer_range": 0.02, + "attention_probs_dropout_prob": 0.0, + "hidden_dropout_prob": 0.0, + "hidden_act": "gelu", + "mlp_ratio": 4, + "model_type": "dinov2", + 
"num_hidden_layers": 24, + "layer_norm_eps": 1e-6, + "qkv_bias": true, + "use_swiglu_ffn": false, + "layerscale_value": 1.0, + "drop_path_rate": 0.0, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index fe6844b17..2d7e09838 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -171,6 +171,16 @@ def offset_first_sigma_for_snr(sigmas, model_sampling, percent_offset=1e-4): return sigmas +def ei_h_phi_1(h: torch.Tensor) -> torch.Tensor: + """Compute the result of h*phi_1(h) in exponential integrator methods.""" + return torch.expm1(h) + + +def ei_h_phi_2(h: torch.Tensor) -> torch.Tensor: + """Compute the result of h*phi_2(h) in exponential integrator methods.""" + return (torch.expm1(h) - h) / h + + @torch.no_grad() def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.): """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).""" @@ -1550,13 +1560,12 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None @torch.no_grad() def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. - arXiv: https://arxiv.org/abs/2305.14267 + arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1564,55 +1573,53 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non lambda_fn = partial(sigma_to_half_log_snr, model_sampling=model_sampling) sigmas = offset_first_sigma_for_snr(sigmas, model_sampling) + fac = 1 / (2 * r) + for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: x = denoised - else: - lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) - h = lambda_t - lambda_s - h_eta = h * (eta + 1) - lambda_s_1 = lambda_s + r * h - fac = 1 / (2 * r) - sigma_s_1 = sigma_fn(lambda_s_1) + continue - # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) - alpha_s_1 = sigma_s_1 * lambda_s_1.exp() - alpha_t = sigmas[i + 1] * lambda_t.exp() + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = torch.lerp(lambda_s, lambda_t, r) + sigma_s_1 = sigma_fn(lambda_s_1) - coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - # 0 < r < 1 - noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt() - noise_coeff_2 = (-r * h * eta).exp() * (-2 * (1 - r) * h * eta).expm1().neg().sqrt() - noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigmas[i + 1]) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_t = sigmas[i + 1] * lambda_t.exp() - # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s_1 * 
(noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r * h_eta) * denoised + if inject_noise: + sde_noise = (-2 * r * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + x_2 = x_2 + sde_noise * sigma_s_1 * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) - # Step 2 - denoised_d = (1 - fac) * denoised + fac * denoised_2 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_2 * denoised_d - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise + # Step 2 + denoised_d = torch.lerp(denoised, denoised_2, fac) + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d + if inject_noise: + segment_factor = (r - 1) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigmas[i + 1]) + x = x + sde_noise * sigmas[i + 1] * s_noise return x @torch.no_grad() def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3): """SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 3. - arXiv: https://arxiv.org/abs/2305.14267 + arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) """ extra_args = {} if extra_args is None else extra_args seed = extra_args.get("seed", None) noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler s_in = x.new_ones([x.shape[0]]) - inject_noise = eta > 0 and s_noise > 0 model_sampling = model.inner_model.model_patcher.get_model_object('model_sampling') @@ -1624,45 +1631,49 @@ def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=Non denoised = model(x, sigmas[i] * s_in, **extra_args) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: x = denoised - else: - lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) - h = lambda_t - lambda_s - h_eta = h * (eta + 1) - lambda_s_1 = lambda_s + r_1 * h - lambda_s_2 = lambda_s + r_2 * h - sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) + continue - # alpha_t = sigma_t * exp(log(alpha_t / sigma_t)) = sigma_t * exp(lambda_t) - alpha_s_1 = sigma_s_1 * lambda_s_1.exp() - alpha_s_2 = sigma_s_2 * lambda_s_2.exp() - alpha_t = sigmas[i + 1] * lambda_t.exp() + lambda_s, lambda_t = lambda_fn(sigmas[i]), lambda_fn(sigmas[i + 1]) + h = lambda_t - lambda_s + h_eta = h * (eta + 1) + lambda_s_1 = torch.lerp(lambda_s, lambda_t, r_1) + lambda_s_2 = torch.lerp(lambda_s, lambda_t, r_2) + sigma_s_1, sigma_s_2 = sigma_fn(lambda_s_1), sigma_fn(lambda_s_2) - coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1() - if inject_noise: - # 0 < r_1 < r_2 < 1 - noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt() - noise_coeff_2 = (-r_1 * h * eta).exp() * (-2 * (r_2 - r_1) * h * eta).expm1().neg().sqrt() - noise_coeff_3 = (-r_2 * h * eta).exp() * (-2 * (1 - r_2) * h * eta).expm1().neg().sqrt() - noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1]) + alpha_s_1 = sigma_s_1 * lambda_s_1.exp() + alpha_s_2 = sigma_s_2 * lambda_s_2.exp() + alpha_t = sigmas[i + 1] * 
lambda_t.exp() - # Step 1 - x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * coeff_1 * denoised - if inject_noise: - x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise - denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) + # Step 1 + x_2 = sigma_s_1 / sigmas[i] * (-r_1 * h * eta).exp() * x - alpha_s_1 * ei_h_phi_1(-r_1 * h_eta) * denoised + if inject_noise: + sde_noise = (-2 * r_1 * h * eta).expm1().neg().sqrt() * noise_sampler(sigmas[i], sigma_s_1) + x_2 = x_2 + sde_noise * sigma_s_1 * s_noise + denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) - # Step 2 - x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * coeff_2 * denoised + (r_2 / r_1) * alpha_s_2 * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised) - if inject_noise: - x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise - denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) + # Step 2 + a3_2 = r_2 / r_1 * ei_h_phi_2(-r_2 * h_eta) + a3_1 = ei_h_phi_1(-r_2 * h_eta) - a3_2 + x_3 = sigma_s_2 / sigmas[i] * (-r_2 * h * eta).exp() * x - alpha_s_2 * (a3_1 * denoised + a3_2 * denoised_2) + if inject_noise: + segment_factor = (r_1 - r_2) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_1, sigma_s_2) + x_3 = x_3 + sde_noise * sigma_s_2 * s_noise + denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args) - # Step 3 - x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * coeff_3 * denoised + (1. / r_2) * alpha_t * (coeff_3 / h_eta + 1) * (denoised_3 - denoised) - if inject_noise: - x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise + # Step 3 + b3 = ei_h_phi_2(-h_eta) / r_2 + b1 = ei_h_phi_1(-h_eta) - b3 + x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b3 * denoised_3) + if inject_noise: + segment_factor = (r_2 - 1) * h * eta + sde_noise = sde_noise * segment_factor.exp() + sde_noise = sde_noise + segment_factor.mul(2).expm1().neg().sqrt() * noise_sampler(sigma_s_2, sigmas[i + 1]) + x = x + sde_noise * sigmas[i + 1] * s_noise return x diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index caf4991fc..f975b5e11 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -533,11 +533,89 @@ class Wan22(Wan21): 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 ]).view(1, self.latent_channels, 1, 1, 1) +class HunyuanImage21(LatentFormat): + latent_channels = 64 + latent_dimensions = 2 + scale_factor = 0.75289 + + latent_rgb_factors = [ + [-0.0154, -0.0397, -0.0521], + [ 0.0005, 0.0093, 0.0006], + [-0.0805, -0.0773, -0.0586], + [-0.0494, -0.0487, -0.0498], + [-0.0212, -0.0076, -0.0261], + [-0.0179, -0.0417, -0.0505], + [ 0.0158, 0.0310, 0.0239], + [ 0.0409, 0.0516, 0.0201], + [ 0.0350, 0.0553, 0.0036], + [-0.0447, -0.0327, -0.0479], + [-0.0038, -0.0221, -0.0365], + [-0.0423, -0.0718, -0.0654], + [ 0.0039, 0.0368, 0.0104], + [ 0.0655, 0.0217, 0.0122], + [ 0.0490, 0.1638, 0.2053], + [ 0.0932, 0.0829, 0.0650], + [-0.0186, -0.0209, -0.0135], + [-0.0080, -0.0076, -0.0148], + [-0.0284, -0.0201, 0.0011], + [-0.0642, -0.0294, -0.0777], + [-0.0035, 0.0076, -0.0140], + [ 0.0519, 0.0731, 0.0887], + [-0.0102, 0.0095, 0.0704], + [ 0.0068, 0.0218, -0.0023], + [-0.0726, -0.0486, -0.0519], + [ 0.0260, 0.0295, 0.0263], + [ 0.0250, 0.0333, 0.0341], + [ 0.0168, -0.0120, -0.0174], + [ 0.0226, 0.1037, 0.0114], + [ 0.2577, 0.1906, 
0.1604], + [-0.0646, -0.0137, -0.0018], + [-0.0112, 0.0309, 0.0358], + [-0.0347, 0.0146, -0.0481], + [ 0.0234, 0.0179, 0.0201], + [ 0.0157, 0.0313, 0.0225], + [ 0.0423, 0.0675, 0.0524], + [-0.0031, 0.0027, -0.0255], + [ 0.0447, 0.0555, 0.0330], + [-0.0152, 0.0103, 0.0299], + [-0.0755, -0.0489, -0.0635], + [ 0.0853, 0.0788, 0.1017], + [-0.0272, -0.0294, -0.0471], + [ 0.0440, 0.0400, -0.0137], + [ 0.0335, 0.0317, -0.0036], + [-0.0344, -0.0621, -0.0984], + [-0.0127, -0.0630, -0.0620], + [-0.0648, 0.0360, 0.0924], + [-0.0781, -0.0801, -0.0409], + [ 0.0363, 0.0613, 0.0499], + [ 0.0238, 0.0034, 0.0041], + [-0.0135, 0.0258, 0.0310], + [ 0.0614, 0.1086, 0.0589], + [ 0.0428, 0.0350, 0.0205], + [ 0.0153, 0.0173, -0.0018], + [-0.0288, -0.0455, -0.0091], + [ 0.0344, 0.0109, -0.0157], + [-0.0205, -0.0247, -0.0187], + [ 0.0487, 0.0126, 0.0064], + [-0.0220, -0.0013, 0.0074], + [-0.0203, -0.0094, -0.0048], + [-0.0719, 0.0429, -0.0442], + [ 0.1042, 0.0497, 0.0356], + [-0.0659, -0.0578, -0.0280], + [-0.0060, -0.0322, -0.0234]] + + latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 scale_factor = 0.9990943042622529 +class Hunyuan3Dv2_1(LatentFormat): + scale_factor = 1.0039506158752403 + latent_channels = 64 + latent_dimensions = 1 + class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index 179c5b67e..d0d69bbdc 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -632,7 +632,7 @@ class ContinuousTransformer(nn.Module): # Attention layers if self.rotary_pos_emb is not None: - rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device) + rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device) else: rotary_pos_emb = None diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 1344c3a57..8ea7d4f57 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -106,6 +106,7 @@ class Flux(nn.Module): if y is None: y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + patches = transformer_options.get("patches", {}) patches_replace = transformer_options.get("patches_replace", {}) if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -117,9 +118,17 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:,:self.params.vec_in_dim]) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) txt = self.txt_in(txt) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) + img = out["img"] + txt = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + if img_ids is not None: ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) @@ -233,12 +242,18 @@ class Flux(nn.Module): h = 0 w = 0 index = 0 - index_ref_method = kwargs.get("ref_latents_method", "offset") == "index" + ref_latents_method = kwargs.get("ref_latents_method", "offset") for ref in ref_latents: - if index_ref_method: + if ref_latents_method == "index": index += 1 h_offset = 0 w_offset = 0 + elif ref_latents_method == "uxo": + index = 0 + h_offset = h_len * patch_size + h + w_offset = w_len * patch_size + w + h += ref.shape[-2] + w += ref.shape[-1] else: index = 1 
h_offset = 0 diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py index 6e8cbf1d9..760944827 100644 --- a/comfy/ldm/hunyuan3d/vae.py +++ b/comfy/ldm/hunyuan3d/vae.py @@ -4,81 +4,458 @@ import torch import torch.nn as nn import torch.nn.functional as F - - -from typing import Union, Tuple, List, Callable, Optional - import numpy as np -from einops import repeat, rearrange +import math from tqdm import tqdm + +from typing import Optional + import logging import comfy.ops ops = comfy.ops.disable_weight_init -def generate_dense_grid_points( - bbox_min: np.ndarray, - bbox_max: np.ndarray, - octree_resolution: int, - indexing: str = "ij", -): - length = bbox_max - bbox_min - num_cells = octree_resolution +def fps(src: torch.Tensor, batch: torch.Tensor, sampling_ratio: float, start_random: bool = True): - x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) - y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) - z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) - [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) - xyz = np.stack((xs, ys, zs), axis=-1) - grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] + # manually create the pointer vector + assert src.size(0) == batch.numel() - return xyz, grid_size, length + batch_size = int(batch.max()) + 1 + deg = src.new_zeros(batch_size, dtype = torch.long) + + deg.scatter_add_(0, batch, torch.ones_like(batch)) + + ptr_vec = deg.new_zeros(batch_size + 1) + torch.cumsum(deg, 0, out=ptr_vec[1:]) + + #return fps_sampling(src, ptr_vec, ratio) + sampled_indicies = [] + + for b in range(batch_size): + # start and the end of each batch + start, end = ptr_vec[b].item(), ptr_vec[b + 1].item() + # points from the point cloud + points = src[start:end] + + num_points = points.size(0) + num_samples = max(1, math.ceil(num_points * sampling_ratio)) + + selected = torch.zeros(num_samples, device = src.device, dtype = torch.long) + distances = torch.full((num_points,), float("inf"), device = src.device) + + # select a random start point + if start_random: + farthest = torch.randint(0, num_points, (1,), device = src.device) + else: + farthest = torch.tensor([0], device = src.device, dtype = torch.long) + + for i in range(num_samples): + selected[i] = farthest + centroid = points[farthest].squeeze(0) + dist = torch.norm(points - centroid, dim = 1) # compute euclidean distance + distances = torch.minimum(distances, dist) + farthest = torch.argmax(distances) + + sampled_indicies.append(torch.arange(start, end)[selected]) + + return torch.cat(sampled_indicies, dim = 0) +class PointCrossAttention(nn.Module): + def __init__(self, + num_latents: int, + downsample_ratio: float, + pc_size: int, + pc_sharpedge_size: int, + point_feats: int, + width: int, + heads: int, + layers: int, + fourier_embedder, + normal_pe: bool = False, + qkv_bias: bool = False, + use_ln_post: bool = True, + qk_norm: bool = True): + + super().__init__() + + self.fourier_embedder = fourier_embedder + + self.pc_size = pc_size + self.normal_pe = normal_pe + self.downsample_ratio = downsample_ratio + self.pc_sharpedge_size = pc_sharpedge_size + self.num_latents = num_latents + self.point_feats = point_feats + + self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width) + + self.cross_attn = ResidualCrossAttentionBlock( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm + ) + + self.self_attn = None + if layers > 0: + self.self_attn = 
Transformer( + width = width, + heads = heads, + qkv_bias = qkv_bias, + qk_norm = qk_norm, + layers = layers + ) + + if use_ln_post: + self.ln_post = nn.LayerNorm(width) + else: + self.ln_post = None + + def sample_points_and_latents(self, point_cloud: torch.Tensor, features: torch.Tensor): + + """ + Subsample points randomly from the point cloud (input_pc) + Further sample the subsampled points to get query_pc + take the fourier embeddings for both input and query pc + + Mental Note: FPS-sampled points (query_pc) act as latent tokens that attend to and learn from the broader context in input_pc. + Goal: get a smaller represenation (query_pc) to represent the entire scence structure by learning from a broader subset (input_pc). + More computationally efficient. + + Features are additional information for each point in the cloud + """ + + B, _, D = point_cloud.shape + + num_latents = int(self.num_latents) + + num_random_query = self.pc_size / (self.pc_size + self.pc_sharpedge_size) * num_latents + num_sharpedge_query = num_latents - num_random_query + + # Split random and sharpedge surface points + random_pc, sharpedge_pc = torch.split(point_cloud, [self.pc_size, self.pc_sharpedge_size], dim=1) + + # assert statements + assert random_pc.shape[1] <= self.pc_size, "Random surface points size must be less than or equal to pc_size" + assert sharpedge_pc.shape[1] <= self.pc_sharpedge_size, "Sharpedge surface points size must be less than or equal to pc_sharpedge_size" + + input_random_pc_size = int(num_random_query * self.downsample_ratio) + random_query_pc, random_input_pc, random_idx_pc, random_idx_query = \ + self.subsample(pc = random_pc, num_query = num_random_query, input_pc_size = input_random_pc_size) + + input_sharpedge_pc_size = int(num_sharpedge_query * self.downsample_ratio) + + if input_sharpedge_pc_size == 0: + sharpedge_input_pc = torch.zeros(B, 0, D, dtype = random_input_pc.dtype).to(point_cloud.device) + sharpedge_query_pc = torch.zeros(B, 0, D, dtype= random_query_pc.dtype).to(point_cloud.device) + + else: + sharpedge_query_pc, sharpedge_input_pc, sharpedge_idx_pc, sharpedge_idx_query = \ + self.subsample(pc = sharpedge_pc, num_query = num_sharpedge_query, input_pc_size = input_sharpedge_pc_size) + + # concat the random and sharpedges + query_pc = torch.cat([random_query_pc, sharpedge_query_pc], dim = 1) + input_pc = torch.cat([random_input_pc, sharpedge_input_pc], dim = 1) + + query = self.fourier_embedder(query_pc) + data = self.fourier_embedder(input_pc) + + if self.point_feats > 0: + random_surface_features, sharpedge_surface_features = torch.split(features, [self.pc_size, self.pc_sharpedge_size], dim = 1) + + input_random_surface_features, query_random_features = \ + self.handle_features(features = random_surface_features, idx_pc = random_idx_pc, batch_size = B, + input_pc_size = input_random_pc_size, idx_query = random_idx_query) + + if input_sharpedge_pc_size == 0: + input_sharpedge_surface_features = torch.zeros(B, 0, self.point_feats, + dtype = input_random_surface_features.dtype, device = point_cloud.device) + + query_sharpedge_features = torch.zeros(B, 0, self.point_feats, + dtype = query_random_features.dtype, device = point_cloud.device) + else: + + input_sharpedge_surface_features, query_sharpedge_features = \ + self.handle_features(idx_pc = sharpedge_idx_pc, features = sharpedge_surface_features, + batch_size = B, idx_query = sharpedge_idx_query, input_pc_size = input_sharpedge_pc_size) + + query_features = torch.cat([query_random_features, 
query_sharpedge_features], dim = 1) + input_features = torch.cat([input_random_surface_features, input_sharpedge_surface_features], dim = 1) + + if self.normal_pe: + # apply the fourier embeddings on the first 3 dims (xyz) + input_features_pe = self.fourier_embedder(input_features[..., :3]) + query_features_pe = self.fourier_embedder(query_features[..., :3]) + # replace the first 3 dims with the new PE ones + input_features = torch.cat([input_features_pe, input_features[..., :3]], dim = -1) + query_features = torch.cat([query_features_pe, query_features[..., :3]], dim = -1) + + # concat at the channels dim + query = torch.cat([query, query_features], dim = -1) + data = torch.cat([data, input_features], dim = -1) + + # don't return pc_info to avoid unnecessary memory usuage + return query.view(B, -1, query.shape[-1]), data.view(B, -1, data.shape[-1]) + + def forward(self, point_cloud: torch.Tensor, features: torch.Tensor): + + query, data = self.sample_points_and_latents(point_cloud = point_cloud, features = features) + + # apply projections + query = self.input_proj(query) + data = self.input_proj(data) + + # apply cross attention between query and data + latents = self.cross_attn(query, data) + + if self.self_attn is not None: + latents = self.self_attn(latents) + + if self.ln_post is not None: + latents = self.ln_post(latents) + + return latents -class VanillaVolumeDecoder: + def subsample(self, pc, num_query, input_pc_size: int): + + """ + num_query: number of points to keep after FPS + input_pc_size: number of points to select before FPS + """ + + B, _, D = pc.shape + query_ratio = num_query / input_pc_size + + # random subsampling of points inside the point cloud + idx_pc = torch.randperm(pc.shape[1], device = pc.device)[:input_pc_size] + input_pc = pc[:, idx_pc, :] + + # flatten to allow applying fps across the whole batch + flattent_input_pc = input_pc.view(B * input_pc_size, D) + + # construct a batch_down tensor to tell fps + # which points belong to which batch + N_down = int(flattent_input_pc.shape[0] / B) + batch_down = torch.arange(B).to(pc.device) + batch_down = torch.repeat_interleave(batch_down, N_down) + + idx_query = fps(flattent_input_pc, batch_down, sampling_ratio = query_ratio) + query_pc = flattent_input_pc[idx_query].view(B, -1, D) + + return query_pc, input_pc, idx_pc, idx_query + + def handle_features(self, features, idx_pc, input_pc_size, batch_size: int, idx_query): + + B = batch_size + + input_surface_features = features[:, idx_pc, :] + flattent_input_features = input_surface_features.view(B * input_pc_size, -1) + query_features = flattent_input_features[idx_query].view(B, -1, + flattent_input_features.shape[-1]) + + return input_surface_features, query_features + +def normalize_mesh(mesh, scale = 0.9999): + """Normalize mesh to fit in [-scale, scale]. 
Translate mesh so its center is [0,0,0]""" + + bbox = mesh.bounds + center = (bbox[1] + bbox[0]) / 2 + + max_extent = (bbox[1] - bbox[0]).max() + mesh.apply_translation(-center) + mesh.apply_scale((2 * scale) / max_extent) + + return mesh + +def sample_pointcloud(mesh, num = 200000): + """ Uniformly sample points from the surface of the mesh """ + + points, face_idx = mesh.sample(num, return_index = True) + normals = mesh.face_normals[face_idx] + return torch.from_numpy(points.astype(np.float32)), torch.from_numpy(normals.astype(np.float32)) + +def detect_sharp_edges(mesh, threshold=0.985): + """Return edge indices (a, b) that lie on sharp boundaries of the mesh.""" + + V, F = mesh.vertices, mesh.faces + VN, FN = mesh.vertex_normals, mesh.face_normals + + sharp_mask = np.ones(V.shape[0]) + for i in range(3): + indices = F[:, i] + alignment = np.einsum('ij,ij->i', VN[indices], FN) + dot_stack = np.stack((sharp_mask[indices], alignment), axis=-1) + sharp_mask[indices] = np.min(dot_stack, axis=-1) + + edge_a = np.concatenate([F[:, 0], F[:, 1], F[:, 2]]) + edge_b = np.concatenate([F[:, 1], F[:, 2], F[:, 0]]) + sharp_edges = (sharp_mask[edge_a] < threshold) & (sharp_mask[edge_b] < threshold) + + return edge_a[sharp_edges], edge_b[sharp_edges] + + +def sharp_sample_pointcloud(mesh, num = 16384): + """ Sample points preferentially from sharp edges in the mesh. """ + + edge_a, edge_b = detect_sharp_edges(mesh) + V, VN = mesh.vertices, mesh.vertex_normals + + va, vb = V[edge_a], V[edge_b] + na, nb = VN[edge_a], VN[edge_b] + + edge_lengths = np.linalg.norm(vb - va, axis=-1) + weights = edge_lengths / edge_lengths.sum() + + indices = np.searchsorted(np.cumsum(weights), np.random.rand(num)) + t = np.random.rand(num, 1) + + samples = t * va[indices] + (1 - t) * vb[indices] + normals = t * na[indices] + (1 - t) * nb[indices] + + return samples.astype(np.float32), normals.astype(np.float32) + +def load_surface_sharpedge(mesh, num_points=4096, num_sharp_points=4096, sharpedge_flag = True, device = "cuda"): + """Load a surface with optional sharp-edge annotations from a trimesh mesh.""" + + import trimesh + + try: + mesh_full = trimesh.util.concatenate(mesh.dump()) + except Exception: + mesh_full = trimesh.util.concatenate(mesh) + + mesh_full = normalize_mesh(mesh_full) + + faces = mesh_full.faces + vertices = mesh_full.vertices + origin_face_count = faces.shape[0] + + mesh_surface = trimesh.Trimesh(vertices=vertices, faces=faces[:origin_face_count]) + mesh_fill = trimesh.Trimesh(vertices=vertices, faces=faces[origin_face_count:]) + + area_surface = mesh_surface.area + area_fill = mesh_fill.area + total_area = area_surface + area_fill + + sample_num = 499712 // 2 + fill_ratio = area_fill / total_area if total_area > 0 else 0 + + num_fill = int(sample_num * fill_ratio) + num_surface = sample_num - num_fill + + surf_pts, surf_normals = sample_pointcloud(mesh_surface, num_surface) + fill_pts, fill_normals = (torch.zeros(0, 3), torch.zeros(0, 3)) if num_fill == 0 else sample_pointcloud(mesh_fill, num_fill) + + sharp_pts, sharp_normals = sharp_sample_pointcloud(mesh_surface, sample_num) + + def assemble_tensor(points, normals, label=None): + + data = torch.cat([points, normals], dim=1).half().to(device) + + if label is not None: + label_tensor = torch.full((data.shape[0], 1), float(label), dtype=torch.float16).to(device) + data = torch.cat([data, label_tensor], dim=1) + + return data + + surface = assemble_tensor(torch.cat([surf_pts.to(device), fill_pts.to(device)], dim=0), + torch.cat([surf_normals.to(device), 
fill_normals.to(device)], dim=0), + label = 0 if sharpedge_flag else None) + + sharp_surface = assemble_tensor(torch.from_numpy(sharp_pts), torch.from_numpy(sharp_normals), + label = 1 if sharpedge_flag else None) + + rng = np.random.default_rng() + + surface = surface[rng.choice(surface.shape[0], num_points, replace = False)] + sharp_surface = sharp_surface[rng.choice(sharp_surface.shape[0], num_sharp_points, replace = False)] + + full = torch.cat([surface, sharp_surface], dim = 0).unsqueeze(0) + + return full + +class SharpEdgeSurfaceLoader: + """ Load mesh surface and sharp edge samples. """ + + def __init__(self, num_uniform_points = 8192, num_sharp_points = 8192): + + self.num_uniform_points = num_uniform_points + self.num_sharp_points = num_sharp_points + self.total_points = num_uniform_points + num_sharp_points + + def __call__(self, mesh_input, device = "cuda"): + mesh = self._load_mesh(mesh_input) + return load_surface_sharpedge(mesh, self.num_uniform_points, self.num_sharp_points, device = device) + + @staticmethod + def _load_mesh(mesh_input): + import trimesh + + if isinstance(mesh_input, str): + mesh = trimesh.load(mesh_input, force="mesh", merge_primitives = True) + else: + mesh = mesh_input + + if isinstance(mesh, trimesh.Scene): + combined = None + for obj in mesh.geometry.values(): + combined = obj if combined is None else combined + obj + return combined + + return mesh + +class DiagonalGaussianDistribution: + def __init__(self, params: torch.Tensor, feature_dim: int = -1): + + # divide quant channels (8) into mean and log variance + self.mean, self.logvar = torch.chunk(params, 2, dim = feature_dim) + + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.std = torch.exp(0.5 * self.logvar) + + def sample(self): + + eps = torch.randn_like(self.std) + z = self.mean + eps * self.std + + return z + +################################################ +# Volume Decoder +################################################ + +class VanillaVolumeDecoder(): @torch.no_grad() - def __call__( - self, - latents: torch.FloatTensor, - geo_decoder: Callable, - bounds: Union[Tuple[float], List[float], float] = 1.01, - num_chunks: int = 10000, - octree_resolution: int = None, - enable_pbar: bool = True, - **kwargs, - ): - device = latents.device - dtype = latents.dtype - batch_size = latents.shape[0] + def __call__(self, latents: torch.Tensor, geo_decoder: callable, octree_resolution: int, bounds = 1.01, + num_chunks: int = 10_000, enable_pbar: bool = True, **kwargs): - # 1. 
generate query points if isinstance(bounds, float): bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] - bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6]) - xyz_samples, grid_size, length = generate_dense_grid_points( - bbox_min=bbox_min, - bbox_max=bbox_max, - octree_resolution=octree_resolution, - indexing="ij" - ) - xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3) + bbox_min, bbox_max = torch.tensor(bounds[:3]), torch.tensor(bounds[3:]) + + x = torch.linspace(bbox_min[0], bbox_max[0], int(octree_resolution) + 1, dtype = torch.float32) + y = torch.linspace(bbox_min[1], bbox_max[1], int(octree_resolution) + 1, dtype = torch.float32) + z = torch.linspace(bbox_min[2], bbox_max[2], int(octree_resolution) + 1, dtype = torch.float32) + + [xs, ys, zs] = torch.meshgrid(x, y, z, indexing = "ij") + xyz = torch.stack((xs, ys, zs), axis=-1).to(latents.device, dtype = latents.dtype).contiguous().reshape(-1, 3) + grid_size = [int(octree_resolution) + 1, int(octree_resolution) + 1, int(octree_resolution) + 1] - # 2. latents to 3d volume batch_logits = [] - for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding", + for start in tqdm(range(0, xyz.shape[0], num_chunks), desc="Volume Decoding", disable=not enable_pbar): - chunk_queries = xyz_samples[start: start + num_chunks, :] - chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size) - logits = geo_decoder(queries=chunk_queries, latents=latents) + + chunk_queries = xyz[start: start + num_chunks, :] + chunk_queries = chunk_queries.unsqueeze(0).repeat(latents.shape[0], 1, 1) + logits = geo_decoder(queries = chunk_queries, latents = latents) batch_logits.append(logits) - grid_logits = torch.cat(batch_logits, dim=1) - grid_logits = grid_logits.view((batch_size, *grid_size)).float() + grid_logits = torch.cat(batch_logits, dim = 1) + grid_logits = grid_logits.view((latents.shape[0], *grid_size)).float() return grid_logits - class FourierEmbedder(nn.Module): """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts each feature dimension of `x[..., i]` into: @@ -175,13 +552,11 @@ class FourierEmbedder(nn.Module): else: return x - class CrossAttentionProcessor: def __call__(self, attn, q, k, v): out = comfy.ops.scaled_dot_product_attention(q, k, v) return out - class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
""" @@ -232,38 +607,41 @@ class MLP(nn.Module): def forward(self, x): return self.drop_path(self.c_proj(self.gelu(self.c_fc(x)))) - class QKVMultiheadCrossAttention(nn.Module): def __init__( self, - *, heads: int, + n_data = None, width=None, qk_norm=False, norm_layer=ops.LayerNorm ): super().__init__() self.heads = heads + self.n_data = n_data self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() - self.attn_processor = CrossAttentionProcessor() - def forward(self, q, kv): + _, n_ctx, _ = q.shape bs, n_data, width = kv.shape + attn_ch = width // self.heads // 2 q = q.view(bs, n_ctx, self.heads, -1) + kv = kv.view(bs, n_data, self.heads, -1) k, v = torch.split(kv, attn_ch, dim=-1) q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) - out = self.attn_processor(self, q, k, v) - out = out.transpose(1, 2).reshape(bs, n_ctx, -1) - return out + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] + out = F.scaled_dot_product_attention(q, k, v) + + out = out.transpose(1, 2).reshape(bs, n_ctx, -1) + + return out class MultiheadCrossAttention(nn.Module): def __init__( @@ -306,7 +684,6 @@ class MultiheadCrossAttention(nn.Module): x = self.c_proj(x) return x - class ResidualCrossAttentionBlock(nn.Module): def __init__( self, @@ -366,7 +743,7 @@ class QKVMultiheadAttention(nn.Module): q = self.q_norm(q) k = self.k_norm(k) - q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v)) + q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)] out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1) return out @@ -383,8 +760,7 @@ class MultiheadAttention(nn.Module): drop_path_rate: float = 0.0 ): super().__init__() - self.width = width - self.heads = heads + self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias) self.c_proj = ops.Linear(width, width) self.attention = QKVMultiheadAttention( @@ -491,7 +867,7 @@ class CrossAttentionDecoder(nn.Module): self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width) if self.downsample_ratio != 1: self.latents_proj = ops.Linear(width * downsample_ratio, width) - if self.enable_ln_post == False: + if not self.enable_ln_post: qk_norm = False self.cross_attn_decoder = ResidualCrossAttentionBlock( width=width, @@ -522,28 +898,44 @@ class CrossAttentionDecoder(nn.Module): class ShapeVAE(nn.Module): def __init__( - self, - *, - embed_dim: int, - width: int, - heads: int, - num_decoder_layers: int, - geo_decoder_downsample_ratio: int = 1, - geo_decoder_mlp_expand_ratio: int = 4, - geo_decoder_ln_post: bool = True, - num_freqs: int = 8, - include_pi: bool = True, - qkv_bias: bool = True, - qk_norm: bool = False, - label_type: str = "binary", - drop_path_rate: float = 0.0, - scale_factor: float = 1.0, + self, + *, + num_latents: int = 4096, + embed_dim: int = 64, + width: int = 1024, + heads: int = 16, + num_decoder_layers: int = 16, + num_encoder_layers: int = 8, + pc_size: int = 81920, + pc_sharpedge_size: int = 0, + point_feats: int = 4, + downsample_ratio: int = 20, + geo_decoder_downsample_ratio: int = 1, + geo_decoder_mlp_expand_ratio: int = 4, + geo_decoder_ln_post: bool = True, + num_freqs: int = 8, + qkv_bias: bool = False, + qk_norm: bool = True, + drop_path_rate: float = 0.0, + include_pi: bool = False, + scale_factor: float = 1.0039506158752403, + label_type: str = "binary", ): 
super().__init__() self.geo_decoder_ln_post = geo_decoder_ln_post self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) + self.encoder = PointCrossAttention(layers = num_encoder_layers, + num_latents = num_latents, + downsample_ratio = downsample_ratio, + heads = heads, + pc_size = pc_size, + width = width, + point_feats = point_feats, + fourier_embedder = self.fourier_embedder, + pc_sharpedge_size = pc_sharpedge_size) + self.post_kl = ops.Linear(embed_dim, width) self.transformer = Transformer( @@ -583,5 +975,14 @@ class ShapeVAE(nn.Module): grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar) return grid_logits.movedim(-2, -1) - def encode(self, x): - return None + def encode(self, surface): + + pc, feats = surface[:, :, :3], surface[:, :, 3:] + latents = self.encoder(pc, feats) + + moments = self.pre_kl(latents) + posterior = DiagonalGaussianDistribution(moments, feature_dim = -1) + + latents = posterior.sample() + + return latents diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py new file mode 100644 index 000000000..d48d9d642 --- /dev/null +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -0,0 +1,659 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, operations, device, dtype): + super().__init__() + self.proj = operations.Linear(dim_in, dim_out, device = device, dtype = dtype) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + + if gate.device.type == "mps": + return F.gelu(gate.to(dtype = torch.float32)).to(dtype = gate.dtype) + + return F.gelu(gate) + + def forward(self, hidden_states): + + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + + return hidden_states + +class FeedForward(nn.Module): + + def __init__(self, dim: int, dim_out = None, mult: int = 4, + dropout: float = 0.0, inner_dim = None, operations = None, device = None, dtype = None): + + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + + dim_out = dim_out if dim_out is not None else dim + + act_fn = GELU(dim, inner_dim, operations = operations, device = device, dtype = dtype) + + self.net = nn.ModuleList([]) + self.net.append(act_fn) + + self.net.append(nn.Dropout(dropout)) + self.net.append(operations.Linear(inner_dim, dim_out, device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + +class AddAuxLoss(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, loss): + # do nothing in forward (no computation) + ctx.requires_aux_loss = loss.requires_grad + ctx.dtype = loss.dtype + + return x + + @staticmethod + def backward(ctx, grad_output): + # add the aux loss gradients + grad_loss = None + # put the aux grad the same as the main grad loss + # aux grad contributes equally + if ctx.requires_aux_loss: + grad_loss = torch.ones(1, dtype = ctx.dtype, device = grad_output.device) + + return grad_output, grad_loss + +class MoEGate(nn.Module): + + def __init__(self, embed_dim, num_experts=16, num_experts_per_tok=2, aux_loss_alpha=0.01, device = None, dtype = None): + + super().__init__() + self.top_k = num_experts_per_tok + self.n_routed_experts 
= num_experts + + self.alpha = aux_loss_alpha + + self.gating_dim = embed_dim + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), device = device, dtype = dtype)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + # flatten hidden states + hidden_states = hidden_states.view(-1, hidden_states.size(-1)) + + # get logits and pass it to softmax + logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), bias = None) + scores = logits.softmax(dim = -1) + + topk_weight, topk_idx = torch.topk(scores, k = self.top_k, dim = -1, sorted = False) + + if self.training and self.alpha > 0.0: + scores_for_aux = scores + + # used bincount instead of one hot encoding + counts = torch.bincount(topk_idx.view(-1), minlength = self.n_routed_experts).float() + ce = counts / topk_idx.numel() # normalized expert usage + + # mean expert score + Pi = scores_for_aux.mean(0) + + # expert balance loss + aux_loss = (Pi * ce * self.n_routed_experts).sum() * self.alpha + else: + aux_loss = None + + return topk_idx, topk_weight, aux_loss + +class MoEBlock(nn.Module): + def __init__(self, dim, num_experts: int = 6, moe_top_k: int = 2, dropout: float = 0.0, + ff_inner_dim: int = None, operations = None, device = None, dtype = None): + super().__init__() + + self.moe_top_k = moe_top_k + self.num_experts = num_experts + + self.experts = nn.ModuleList([ + FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + for _ in range(num_experts) + ]) + + self.gate = MoEGate(dim, num_experts = num_experts, num_experts_per_tok = moe_top_k, device = device, dtype = dtype) + self.shared_experts = FeedForward(dim, dropout = dropout, inner_dim = ff_inner_dim, operations = operations, device = device, dtype = dtype) + + def forward(self, hidden_states) -> torch.Tensor: + + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + + if self.training: + + hidden_states = hidden_states.repeat_interleave(self.moe_top_k, dim = 0) + y = torch.empty_like(hidden_states, dtype = hidden_states.dtype) + + for i, expert in enumerate(self.experts): + tmp = expert(hidden_states[flat_topk_idx == i]) + y[flat_topk_idx == i] = tmp.to(hidden_states.dtype) + + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim = 1) + y = y.view(*orig_shape) + + y = AddAuxLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, flat_expert_indices = flat_topk_idx,flat_expert_weights = topk_weight.view(-1, 1)).view(*orig_shape) + + y = y + self.shared_experts(identity) + + return y + + @torch.no_grad() + def moe_infer(self, x, flat_expert_indices, flat_expert_weights): + + expert_cache = torch.zeros_like(x) + idxs = flat_expert_indices.argsort() + + # no need for .numpy().cpu() here + tokens_per_expert = flat_expert_indices.bincount().cumsum(0) + token_idxs = idxs // self.moe_top_k + + for i, end_idx in enumerate(tokens_per_expert): + + start_idx = 0 if i == 0 else tokens_per_expert[i-1] + + if start_idx == end_idx: + continue + + expert = self.experts[i] + exp_token_idx = token_idxs[start_idx:end_idx] + + expert_tokens = x[exp_token_idx] + expert_out = expert(expert_tokens) + + expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]]) + + # use index_add_ with a 1-D index tensor directly 
avoids building a large [N, D] index map and extra memcopy required by scatter_reduce_ + # + avoid dtype conversion + expert_cache.index_add_(0, exp_token_idx, expert_out) + + return expert_cache + +class Timesteps(nn.Module): + def __init__(self, num_channels: int, downscale_freq_shift: float = 0.0, + scale: float = 1.0, max_period: int = 10000): + super().__init__() + + self.num_channels = num_channels + half_dim = num_channels // 2 + + # precompute the “inv_freq” vector once + exponent = -math.log(max_period) * torch.arange( + half_dim, dtype=torch.float32 + ) / (half_dim - downscale_freq_shift) + + inv_freq = torch.exp(exponent) + + # pad + if num_channels % 2 == 1: + # we’ll pad a zero at the end of the cos-half + inv_freq = torch.cat([inv_freq, inv_freq.new_zeros(1)]) + + # register to buffer so it moves with the device + self.register_buffer("inv_freq", inv_freq, persistent = False) + self.scale = scale + + def forward(self, timesteps: torch.Tensor): + + x = timesteps.float().unsqueeze(1) * self.inv_freq.to(timesteps.device).unsqueeze(0) + + + # fused CUDA kernels for sin and cos + sin_emb = x.sin() + cos_emb = x.cos() + + emb = torch.cat([sin_emb, cos_emb], dim = 1) + + # scale factor + if self.scale != 1.0: + emb = emb * self.scale + + # If we padded inv_freq for odd, emb is already wide enough; otherwise: + if emb.shape[1] > self.num_channels: + emb = emb[:, :self.num_channels] + + return emb + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size = 256, cond_proj_dim = None, operations = None, device = None, dtype = None): + super().__init__() + + self.mlp = nn.Sequential( + operations.Linear(hidden_size, frequency_embedding_size, bias=True, device = device, dtype = dtype), + nn.GELU(), + operations.Linear(frequency_embedding_size, hidden_size, bias=True, device = device, dtype = dtype), + ) + self.frequency_embedding_size = frequency_embedding_size + + if cond_proj_dim is not None: + self.cond_proj = operations.Linear(cond_proj_dim, frequency_embedding_size, bias=False, device = device, dtype = dtype) + + self.time_embed = Timesteps(hidden_size) + + def forward(self, timesteps, condition): + + timestep_embed = self.time_embed(timesteps).type(self.mlp[0].weight.dtype) + + if condition is not None: + cond_embed = self.cond_proj(condition) + timestep_embed = timestep_embed + cond_embed + + time_conditioned = self.mlp(timestep_embed) + + # for broadcasting with image tokens + return time_conditioned.unsqueeze(1) + +class MLP(nn.Module): + def __init__(self, *, width: int, operations = None, device = None, dtype = None): + super().__init__() + self.width = width + self.fc1 = operations.Linear(width, width * 4, device = device, dtype = dtype) + self.fc2 = operations.Linear(width * 4, width, device = device, dtype = dtype) + self.gelu = nn.GELU() + + def forward(self, x): + return self.fc2(self.gelu(self.fc1(x))) + +class CrossAttention(nn.Module): + def __init__( + self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + norm_layer=nn.LayerNorm, + use_fp16: bool = False, + operations = None, + dtype = None, + device = None, + **kwargs, + ): + super().__init__() + self.qdim = qdim + self.kdim = kdim + + self.num_heads = num_heads + self.head_dim = self.qdim // num_heads + + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(qdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(kdim, qdim, bias=qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(kdim, 
qdim, bias=qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(qdim, qdim, bias=True, device = device, dtype = dtype) + + def forward(self, x, y): + + b, s1, _ = x.shape + _, s2, _ = y.shape + + y = y.to(next(self.to_k.parameters()).dtype) + + q = self.to_q(x) + k = self.to_k(y) + v = self.to_v(y) + + kv = torch.cat((k, v), dim=-1) + split_size = kv.shape[-1] // self.num_heads // 2 + + kv = kv.view(1, -1, self.num_heads, split_size * 2) + k, v = torch.split(kv, split_size, dim=-1) + + q = q.view(b, s1, self.num_heads, self.head_dim) + k = k.view(b, s2, self.num_heads, self.head_dim) + v = v.reshape(b, s2, self.num_heads * self.head_dim) + + q = self.q_norm(q) + k = self.k_norm(k) + + x = optimized_attention( + q.reshape(b, s1, self.num_heads * self.head_dim), + k.reshape(b, s2, self.num_heads * self.head_dim), + v, + heads=self.num_heads, + ) + + out = self.out_proj(x) + + return out + +class Attention(nn.Module): + + def __init__( + self, + dim, + num_heads, + qkv_bias = True, + qk_norm = False, + norm_layer = nn.LayerNorm, + use_fp16: bool = False, + operations = None, + device = None, + dtype = None + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = self.dim // num_heads + self.scale = self.head_dim ** -0.5 + + self.to_q = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_k = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + self.to_v = operations.Linear(dim, dim, bias = qkv_bias, device = device, dtype = dtype) + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + if norm_layer == nn.LayerNorm: + norm_layer = operations.LayerNorm + else: + norm_layer = operations.RMSNorm + + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps = eps, device = device, dtype = dtype) if qk_norm else nn.Identity() + self.out_proj = operations.Linear(dim, dim, device = device, dtype = dtype) + + def forward(self, x): + B, N, _ = x.shape + + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + qkv_combined = torch.cat((query, key, value), dim=-1) + split_size = qkv_combined.shape[-1] // self.num_heads // 3 + + qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + query, key, value = torch.split(qkv, split_size, dim=-1) + + query = query.reshape(B, N, self.num_heads, self.head_dim) + key = key.reshape(B, N, self.num_heads, self.head_dim) + value = value.reshape(B, N, self.num_heads * self.head_dim) + + query = self.q_norm(query) + key = self.k_norm(key) + + x = optimized_attention( + query.reshape(B, N, self.num_heads * self.head_dim), + key.reshape(B, N, self.num_heads * self.head_dim), + value, + heads=self.num_heads, + ) + + x = self.out_proj(x) + return x + +class HunYuanDiTBlock(nn.Module): + def __init__( + self, + hidden_size, + c_emb_size, + num_heads, + text_states_dim=1024, + qk_norm=False, + norm_layer=nn.LayerNorm, + qk_norm_layer=True, 
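The MoEGate module earlier in this file routes every token to its top-k experts and, during training, adds a load-balancing term computed from a bincount of the selected expert indices. Below is a minimal, self-contained sketch of that routing and auxiliary loss in plain PyTorch; the expert count, token dimension, and alpha are toy values, not the model's real configuration.

import torch
import torch.nn.functional as F

def topk_gate_with_aux_loss(hidden_states, gate_weight, top_k=2, alpha=0.01):
    # Mirrors the MoEGate logic above: softmax router scores, top-k expert
    # selection, and a bincount-based load-balancing loss.
    n_experts = gate_weight.shape[0]
    flat = hidden_states.view(-1, hidden_states.size(-1))           # [tokens, dim]
    scores = F.linear(flat, gate_weight).softmax(dim=-1)            # [tokens, n_experts]
    topk_weight, topk_idx = torch.topk(scores, k=top_k, dim=-1, sorted=False)

    # Fraction of routing slots each expert received (bincount instead of one-hot).
    counts = torch.bincount(topk_idx.view(-1), minlength=n_experts).float()
    ce = counts / topk_idx.numel()
    # Mean router probability per expert; the product penalizes imbalance
    # (the loss is ~alpha when routing is perfectly uniform).
    pi = scores.mean(0)
    aux_loss = (pi * ce * n_experts).sum() * alpha
    return topk_idx, topk_weight, aux_loss

gate_w = torch.randn(8, 64)                       # 8 experts, 64-dim tokens (toy sizes)
idx, w, loss = topk_gate_with_aux_loss(torch.randn(2, 5, 64), gate_w)
print(idx.shape, w.shape, loss.item())            # torch.Size([10, 2]) torch.Size([10, 2]) ...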
+ qkv_bias=True, + skip_connection=True, + timested_modulate=False, + use_moe: bool = False, + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + operations = None, + device = None, dtype = None + ): + super().__init__() + + # eps can't be 1e-6 in fp16 mode because of numerical stability issues + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm1 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, + norm_layer=qk_norm_layer, use_fp16 = use_fp16, device = device, dtype = dtype, operations = operations) + + self.norm2 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + self.timested_modulate = timested_modulate + if self.timested_modulate: + self.default_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(c_emb_size, hidden_size, bias=True, device = device, dtype = dtype) + ) + + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_norm=qk_norm, norm_layer=qk_norm_layer, use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + + self.norm3 = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + + if skip_connection: + self.skip_norm = norm_layer(hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, device = device, dtype = dtype) + else: + self.skip_linear = None + + self.use_moe = use_moe + + if self.use_moe: + self.moe = MoEBlock( + hidden_size, + num_experts = num_experts, + moe_top_k = moe_top_k, + dropout = 0.0, + ff_inner_dim = int(hidden_size * 4.0), + device = device, dtype = dtype, + operations = operations + ) + else: + self.mlp = MLP(width=hidden_size, operations=operations, device = device, dtype = dtype) + + def forward(self, hidden_states, conditioning=None, text_states=None, skip_tensor=None): + + if self.skip_linear is not None: + combined = torch.cat([skip_tensor, hidden_states], dim=-1) + hidden_states = self.skip_linear(combined) + hidden_states = self.skip_norm(hidden_states) + + # self attention + if self.timested_modulate: + modulation_shift = self.default_modulation(conditioning).unsqueeze(dim=1) + hidden_states = hidden_states + modulation_shift + + self_attn_out = self.attn1(self.norm1(hidden_states)) + hidden_states = hidden_states + self_attn_out + + # cross attention + hidden_states = hidden_states + self.attn2(self.norm2(hidden_states), text_states) + + # MLP Layer + mlp_input = self.norm3(hidden_states) + + if self.use_moe: + hidden_states = hidden_states + self.moe(mlp_input) + else: + hidden_states = hidden_states + self.mlp(mlp_input) + + return hidden_states + +class FinalLayer(nn.Module): + + def __init__(self, final_hidden_size, out_channels, operations, use_fp16: bool = False, device = None, dtype = None): + super().__init__() + + if use_fp16: + eps = 1.0 / 65504 + else: + eps = 1e-6 + + self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine = True, eps = eps, device = device, dtype = dtype) + self.linear = operations.Linear(final_hidden_size, out_channels, bias = True, device = device, dtype = dtype) + + def forward(self, x): + x = self.norm_final(x) + x = x[:, 1:] + x = self.linear(x) + return x + +class HunYuanDiTPlain(nn.Module): + + # init with the defaults values from 
https://huggingface.co/tencent/Hunyuan3D-2.1/blob/main/hunyuan3d-dit-v2-1/config.yaml + def __init__( + self, + in_channels: int = 64, + hidden_size: int = 2048, + context_dim: int = 1024, + depth: int = 21, + num_heads: int = 16, + qk_norm: bool = True, + qkv_bias: bool = False, + num_moe_layers: int = 6, + guidance_cond_proj_dim = 2048, + norm_type = 'layer', + num_experts: int = 8, + moe_top_k: int = 2, + use_fp16: bool = False, + dtype = None, + device = None, + operations = None, + **kwargs + ): + + self.dtype = dtype + + super().__init__() + + self.depth = depth + + self.in_channels = in_channels + self.out_channels = in_channels + + self.num_heads = num_heads + self.hidden_size = hidden_size + + norm = operations.LayerNorm if norm_type == 'layer' else operations.RMSNorm + qk_norm = operations.RMSNorm + + self.context_dim = context_dim + self.guidance_cond_proj_dim = guidance_cond_proj_dim + + self.x_embedder = operations.Linear(in_channels, hidden_size, bias = True, device = device, dtype = dtype) + self.t_embedder = TimestepEmbedder(hidden_size, hidden_size * 4, cond_proj_dim = guidance_cond_proj_dim, device = device, dtype = dtype, operations = operations) + + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + text_states_dim=context_dim, + qk_norm=qk_norm, + norm_layer = norm, + qk_norm_layer = qk_norm, + skip_connection=layer > depth // 2, + qkv_bias=qkv_bias, + use_moe=True if depth - layer <= num_moe_layers else False, + num_experts=num_experts, + moe_top_k=moe_top_k, + use_fp16 = use_fp16, + device = device, dtype = dtype, operations = operations) + for layer in range(depth) + ]) + + self.depth = depth + + self.final_layer = FinalLayer(hidden_size, self.out_channels, use_fp16 = use_fp16, operations = operations, device = device, dtype = dtype) + + def forward(self, x, t, context, transformer_options = {}, **kwargs): + + x = x.movedim(-1, -2) + uncond_emb, cond_emb = context.chunk(2, dim = 0) + + context = torch.cat([cond_emb, uncond_emb], dim = 0) + main_condition = context + + t = 1.0 - t + + time_embedded = self.t_embedder(t, condition = kwargs.get('guidance_cond')) + + x = x.to(dtype = next(self.x_embedder.parameters()).dtype) + x_embedded = self.x_embedder(x) + + combined = torch.cat([time_embedded, x_embedded], dim=1) + + def block_wrap(args): + return block( + args["x"], + args["t"], + args["cond"], + skip_tensor=args.get("skip"),) + + skip_stack = [] + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + for idx, block in enumerate(self.blocks): + if idx <= self.depth // 2: + skip_input = None + else: + skip_input = skip_stack.pop() + + if ("block", idx) in blocks_replace: + + combined = blocks_replace[("block", idx)]( + { + "x": combined, + "t": time_embedded, + "cond": main_condition, + "skip": skip_input, + }, + {"original_block": block_wrap}, + ) + else: + combined = block(combined, time_embedded, main_condition, skip_tensor=skip_input) + + if idx < self.depth // 2: + skip_stack.append(combined) + + output = self.final_layer(combined) + output = output.movedim(-2, -1) * (-1.0) + + cond_emb, uncond_emb = output.chunk(2, dim = 0) + return torch.cat([uncond_emb, cond_emb]) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index da1011596..7732182a4 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -40,6 +40,8 @@ class HunyuanVideoParams: 
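HunYuanDiTPlain.forward above wires the transformer as a U-shaped stack: blocks in the first half push their outputs, and blocks in the second half (built with skip_connection=layer > depth // 2) pop them and fuse via skip_linear on a channel-wise concatenation. The sketch below shows only that control flow with stand-in modules; folding the fusion into a separate `fuse` layer is a simplification of the real block, which applies skip_linear and skip_norm internally.

import torch
import torch.nn as nn

def run_with_long_skips(x, blocks, fuse):
    # First-half outputs are saved; the second half consumes them in LIFO order,
    # mirroring the skip_stack push/pop in HunYuanDiTPlain.forward above.
    depth = len(blocks)
    skip_stack = []
    for idx, block in enumerate(blocks):
        skip = None if idx <= depth // 2 else skip_stack.pop()
        x = block(x) if skip is None else block(fuse(torch.cat([skip, x], dim=-1)))
        if idx < depth // 2:
            skip_stack.append(x)
    return x

hidden = 16
blocks = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(5)])  # stand-ins for HunYuanDiTBlock
fuse = nn.Linear(2 * hidden, hidden)                                   # stand-in for skip_linear
print(run_with_long_skips(torch.randn(2, 7, hidden), blocks, fuse).shape)  # torch.Size([2, 7, 16])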
patch_size: list qkv_bias: bool guidance_embed: bool + byt5: bool + meanflow: bool class SelfAttentionRef(nn.Module): @@ -161,6 +163,30 @@ class TokenRefiner(nn.Module): x = self.individual_token_refiner(x, c, mask) return x + +class ByT5Mapper(nn.Module): + def __init__(self, in_dim, out_dim, hidden_dim, out_dim1, use_res=False, dtype=None, device=None, operations=None): + super().__init__() + self.layernorm = operations.LayerNorm(in_dim, dtype=dtype, device=device) + self.fc1 = operations.Linear(in_dim, hidden_dim, dtype=dtype, device=device) + self.fc2 = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) + self.fc3 = operations.Linear(out_dim, out_dim1, dtype=dtype, device=device) + self.use_res = use_res + self.act_fn = nn.GELU() + + def forward(self, x): + if self.use_res: + res = x + x = self.layernorm(x) + x = self.fc1(x) + x = self.act_fn(x) + x = self.fc2(x) + x2 = self.act_fn(x) + x2 = self.fc3(x2) + if self.use_res: + x2 = x2 + res + return x2 + class HunyuanVideo(nn.Module): """ Transformer model for flow matching on sequences. @@ -185,9 +211,13 @@ class HunyuanVideo(nn.Module): self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations) + self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=len(self.patch_size) == 3, dtype=dtype, device=device, operations=operations) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) @@ -215,6 +245,23 @@ class HunyuanVideo(nn.Module): ] ) + if params.byt5: + self.byt5_in = ByT5Mapper( + in_dim=1472, + out_dim=2048, + hidden_dim=2048, + out_dim1=self.hidden_size, + use_res=False, + dtype=dtype, device=device, operations=operations + ) + else: + self.byt5_in = None + + if params.meanflow: + self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.time_r_in = None + if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) @@ -226,7 +273,8 @@ class HunyuanVideo(nn.Module): txt_ids: Tensor, txt_mask: Tensor, timesteps: Tensor, - y: Tensor, + y: Tensor = None, + txt_byt5=None, guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, @@ -240,6 +288,14 @@ class HunyuanVideo(nn.Module): img = self.img_in(img) vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype)) + if self.time_r_in is not None: + w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved + if len(w) > 0: + timesteps_r = transformer_options['sample_sigmas'][w[0] + 1] + timesteps_r = 
timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype) + vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype)) + vec = (vec + vec_r) / 2 + if ref_latent is not None: ref_latent_ids = self.img_ids(ref_latent) ref_latent = self.img_in(ref_latent) @@ -250,13 +306,17 @@ class HunyuanVideo(nn.Module): if guiding_frame_index is not None: token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0)) - vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) - vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + if self.vector_in is not None: + vec_ = self.vector_in(y[:, :self.params.vec_in_dim]) + vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1) + else: + vec = torch.cat([(token_replace_vec).unsqueeze(1), (vec).unsqueeze(1)], dim=1) frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2]) modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)] modulation_dims_txt = [(0, None, 1)] else: - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) modulation_dims = None modulation_dims_txt = None @@ -269,6 +329,12 @@ class HunyuanVideo(nn.Module): txt = self.txt_in(txt, timesteps, txt_mask) + if self.byt5_in is not None and txt_byt5 is not None: + txt_byt5 = self.byt5_in(txt_byt5) + txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt = torch.cat((txt, txt_byt5), dim=1) + txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -328,12 +394,16 @@ class HunyuanVideo(nn.Module): img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels) - shape = initial_shape[-3:] + shape = initial_shape[-len(self.patch_size):] for i in range(len(shape)): shape[i] = shape[i] // self.patch_size[i] img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size) - img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) - img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + if img.ndim == 8: + img = img.permute(0, 4, 1, 5, 2, 6, 3, 7) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4]) + else: + img = img.permute(0, 3, 1, 4, 2, 5) + img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3]) return img def img_ids(self, x): @@ -348,16 +418,30 @@ class HunyuanVideo(nn.Module): img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1) return repeat(img_ids, "t h w c -> b (t h w) c", b=bs) - def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + def img_ids_2d(self, x): + bs, c, h, w = x.shape + patch_size = self.patch_size + h_len = ((h + (patch_size[0] // 2)) // patch_size[0]) + w_len = ((w + (patch_size[1] // 2)) // patch_size[1]) + img_ids = torch.zeros((h_len, w_len, 2), device=x.device, dtype=x.dtype) + img_ids[:, :, 0] = img_ids[:, :, 0] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) + img_ids[:, :, 1] = img_ids[:, :, 
1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + return repeat(img_ids, "h w c -> b (h w) c", b=bs) + + def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): - bs, c, t, h, w = x.shape - img_ids = self.img_ids(x) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) + def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs): + bs = x.shape[0] + if len(self.patch_size) == 3: + img_ids = self.img_ids(x) + txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + else: + img_ids = self.img_ids_2d(x) + txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/vae.py b/comfy/ldm/hunyuan_video/vae.py new file mode 100644 index 000000000..40c12b183 --- /dev/null +++ b/comfy/ldm/hunyuan_video/vae.py @@ -0,0 +1,136 @@ +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class PixelShuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim >> 2, 3, 1, 1) + self.ratio = (in_dim << 2) // out_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h >> 1, w >> 1 + y = self.conv(x).view(b, -1, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, -1, h2, w2) + r = x.view(b, c, h2, 2, w2, 2).permute(0, 3, 5, 1, 2, 4).reshape(b, c << 2, h2, w2) + return y + r.view(b, y.shape[1], self.ratio, h2, w2).mean(2) + + +class PixelUnshuffle2D(nn.Module): + def __init__(self, in_dim, out_dim, op=ops.Conv2d): + super().__init__() + self.conv = op(in_dim, out_dim << 2, 3, 1, 1) + self.scale = (out_dim << 2) // in_dim + + def forward(self, x): + b, c, h, w = x.shape + h2, w2 = h << 1, w << 1 + y = self.conv(x).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + r = x.repeat_interleave(self.scale, 1).view(b, 2, 2, -1, h, w).permute(0, 3, 4, 1, 5, 2).reshape(b, -1, h2, w2) + return y + r + + +class Encoder(nn.Module): + def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks, + ffactor_spatial, 
downsample_match_channel=True, **_): + super().__init__() + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + self.conv_in = ops.Conv2d(in_channels, block_out_channels[0], 3, 1, 1) + + self.down = nn.ModuleList() + ch = block_out_channels[0] + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch + stage.downsample = PixelShuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.down.append(stage) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, z_channels << 1, 3, 1, 1) + + def forward(self, x): + x = self.conv_in(x) + + for stage in self.down: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'downsample'): + x = stage.downsample(x) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + b, c, h, w = x.shape + grp = c // (self.z_channels << 1) + skip = x.view(b, c // grp, grp, h, w).mean(2) + + return self.conv_out(F.silu(self.norm_out(x))) + skip + + +class Decoder(nn.Module): + def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks, + ffactor_spatial, upsample_match_channel=True, **_): + super().__init__() + block_out_channels = block_out_channels[::-1] + self.z_channels = z_channels + self.block_out_channels = block_out_channels + self.num_res_blocks = num_res_blocks + + ch = block_out_channels[0] + self.conv_in = ops.Conv2d(z_channels, ch, 3, 1, 1) + + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv2d) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=ops.Conv2d) + + self.up = nn.ModuleList() + depth = (ffactor_spatial >> 1).bit_length() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=ops.Conv2d) + for j in range(num_res_blocks + 1)]) + ch = tgt + if i < depth: + nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch + stage.upsample = PixelUnshuffle2D(ch, nxt, ops.Conv2d) + ch = nxt + self.up.append(stage) + + self.norm_out = ops.GroupNorm(32, ch, 1e-6, True) + self.conv_out = ops.Conv2d(ch, out_channels, 3, 1, 1) + + def forward(self, z): + x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + + for stage in self.up: + for blk in stage.block: + x = blk(x) + if hasattr(stage, 'upsample'): + x = stage.upsample(x) + + return self.conv_out(F.silu(self.norm_out(x))) diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1fd12b35a..8f598a848 100644 --- 
a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -145,7 +145,7 @@ class Downsample(nn.Module): class ResnetBlock(nn.Module): def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, - dropout, temb_channels=512, conv_op=ops.Conv2d): + dropout=0.0, temb_channels=512, conv_op=ops.Conv2d): super().__init__() self.in_channels = in_channels out_channels = in_channels if out_channels is None else out_channels @@ -183,7 +183,7 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb): + def forward(self, x, temb=None): h = x h = self.norm1(h) h = self.swish(h) diff --git a/comfy/lora.py b/comfy/lora.py index 00358884b..4a44f1318 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -260,6 +260,10 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + for k in sdk: + hidden_size = model.model_config.unet_config.get("hidden_size", 0) + if k.endswith(".weight") and ".linear1." in k: + key_map["{}".format(k.replace(".linear1.weight", ".linear1_qkv"))] = (k, (0, 0, hidden_size * 3)) if isinstance(model, comfy.model_base.GenmoMochi): for k in sdk: diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 3e00b63db..9d8d21efe 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -15,10 +15,29 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux def convert_lora_wan_fun(sd): #Wan Fun loras return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"}) +def convert_uso_lora(sd): + sd_out = {} + for k in sd: + tensor = sd[k] + k_to = "diffusion_model.{}".format(k.replace(".down.weight", ".lora_down.weight") + .replace(".up.weight", ".lora_up.weight") + .replace(".qkv_lora2.", ".txt_attn.qkv.") + .replace(".qkv_lora1.", ".img_attn.qkv.") + .replace(".proj_lora1.", ".img_attn.proj.") + .replace(".proj_lora2.", ".txt_attn.proj.") + .replace(".qkv_lora.", ".linear1_qkv.") + .replace(".proj_lora.", ".linear2.") + .replace(".processor.", ".") + ) + sd_out[k_to] = tensor + return sd_out + def convert_lora(sd): if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd: return convert_lora_bfl_control(sd) if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd: return convert_lora_wan_fun(sd) + if "single_blocks.37.processor.qkv_lora.up.weight" in sd and "double_blocks.18.processor.qkv_lora2.up.weight" in sd: + return convert_uso_lora(sd) return sd diff --git a/comfy/model_base.py b/comfy/model_base.py index 56a6798be..993ff65e6 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -16,6 +16,8 @@ along with this program. If not, see . 
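convert_uso_lora above (comfy/lora_convert.py) is a pure key-renaming pass over the LoRA state dict. Applying the same chained replacements to one of the keys the detection check looks for shows the resulting ComfyUI-style name:

# Sketch of the key remapping performed by convert_uso_lora (same replace chain).
def remap_uso_key(k: str) -> str:
    return "diffusion_model.{}".format(
        k.replace(".down.weight", ".lora_down.weight")
         .replace(".up.weight", ".lora_up.weight")
         .replace(".qkv_lora2.", ".txt_attn.qkv.")
         .replace(".qkv_lora1.", ".img_attn.qkv.")
         .replace(".proj_lora1.", ".img_attn.proj.")
         .replace(".proj_lora2.", ".txt_attn.proj.")
         .replace(".qkv_lora.", ".linear1_qkv.")
         .replace(".proj_lora.", ".linear2.")
         .replace(".processor.", ".")
    )

print(remap_uso_key("double_blocks.18.processor.qkv_lora2.up.weight"))
# -> diffusion_model.double_blocks.18.txt_attn.qkv.lora_up.weight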
""" +import comfy.ldm.hunyuan3dv2_1 +import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep @@ -1282,6 +1284,21 @@ class Hunyuan3Dv2(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class Hunyuan3Dv2_1(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3dv2_1.hunyuandit.HunYuanDiTPlain) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + guidance = kwargs.get("guidance", 5.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + return out + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) @@ -1391,3 +1408,27 @@ class QwenImage(BaseModel): if ref_latents is not None: out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out + +class HunyuanImage21(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) + + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 9f3ab64df..fe983cede 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -136,20 +136,40 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video dit_config = {} + in_w = state_dict['{}img_in.proj.weight'.format(key_prefix)] + out_w = state_dict['{}final_layer.linear.weight'.format(key_prefix)] dit_config["image_model"] = "hunyuan_video" - dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels - dit_config["patch_size"] = [1, 2, 2] - dit_config["out_channels"] = 16 - dit_config["vec_in_dim"] = 768 - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 + dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels + dit_config["patch_size"] = list(in_w.shape[2:]) + dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"]) + if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys): + dit_config["vec_in_dim"] = 768 + else: + 
dit_config["vec_in_dim"] = None + + if len(dit_config["patch_size"]) == 2: + dit_config["axes_dim"] = [64, 64] + else: + dit_config["axes_dim"] = [16, 56, 56] + + if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys): + dit_config["meanflow"] = True + else: + dit_config["meanflow"] = False + + dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1] + dit_config["hidden_size"] = in_w.shape[0] dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + dit_config["num_heads"] = in_w.shape[0] // 128 dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] dit_config["theta"] = 256 dit_config["qkv_bias"] = True + if '{}byt5_in.fc1.weight'.format(key_prefix) in state_dict: + dit_config["byt5"] = True + else: + dit_config["byt5"] = False + guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 return dit_config @@ -400,6 +420,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 + + dit_config = {} + dit_config["image_model"] = "hunyuan3d2_1" + dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] + dit_config["context_dim"] = 1024 + dit_config["hidden_size"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[0] + dit_config["mlp_ratio"] = 4.0 + dit_config["num_heads"] = 16 + dit_config["depth"] = count_blocks(state_dict_keys, f"{key_prefix}blocks.{{}}") + dit_config["qkv_bias"] = False + dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys + return dit_config + if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream dit_config = {} dit_config["image_model"] = "hidream" diff --git a/comfy/model_management.py b/comfy/model_management.py index 4ac04b8b9..3dfa122c8 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -23,6 +23,7 @@ from enum import Enum from comfy.cli_args import args, PerformanceFeature import torch import sys +import importlib import platform import weakref import gc @@ -313,6 +314,24 @@ def is_amd(): return True return False +def amd_min_version(device=None, min_rdna_version=0): + if not is_amd(): + return False + + if is_device_cpu(device): + return False + + arch = torch.cuda.get_device_properties(device).gcnArchName + if arch.startswith('gfx') and len(arch) == 7: + try: + cmp_rdna_version = int(arch[4]) + 2 + except: + cmp_rdna_version = 0 + if cmp_rdna_version >= min_rdna_version: + return True + + return False + MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): MIN_WEIGHT_MEMORY_RATIO = 0.0 @@ -345,12 +364,13 @@ try: logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 - ENABLE_PYTORCH_ATTENTION = True -# if 
torch_version_numeric >= (2, 8): -# if any((a in arch) for a in ["gfx1201"]): -# ENABLE_PYTORCH_ATTENTION = True + if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + ENABLE_PYTORCH_ATTENTION = True +# if torch_version_numeric >= (2, 8): +# if any((a in arch) for a in ["gfx1201"]): +# ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches SUPPORT_FP8_OPS = True @@ -933,7 +953,9 @@ def vae_dtype(device=None, allowed_dtypes=[]): # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3 - if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + # also a problem on RDNA4 except fp32 is also slow there. + # This is due to large bf16 convolutions being extremely slow. + if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device): return d return torch.float32 diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e3d0b4840..c64778da0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -513,6 +513,9 @@ class ModelPatcher: def set_model_double_block_patch(self, patch): self.set_model_patch(patch, "double_block") + def set_model_post_input_patch(self, patch): + self.set_model_patch(patch, "post_input") + def add_object_patch(self, name, obj): self.object_patches[name] = obj diff --git a/comfy/ops.py b/comfy/ops.py index 18e7db705..55e958adb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -52,6 +52,9 @@ except (ModuleNotFoundError, TypeError): cast_to = comfy.model_management.cast_to #TODO: remove once no more references +if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: + torch.backends.cudnn.benchmark = True + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) diff --git a/comfy/sd.py b/comfy/sd.py index bb5d61fb3..9dd9a74d4 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,6 +17,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.hunyuan_video.vae import yaml import math import os @@ -48,6 +49,7 @@ import comfy.text_encoders.hidream import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image import comfy.model_patcher import comfy.lora @@ -328,6 +330,19 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.conv_in.weight" in sd and sd['decoder.conv_in.weight'].shape[1] == 64: + ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} + self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + 
self.downscale_ratio = 32 + self.upscale_ratio = 32 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"}, + encoder_config={'target': "comfy.ldm.hunyuan_video.vae.Encoder", 'params': ddconfig}, + decoder_config={'target': "comfy.ldm.hunyuan_video.vae.Decoder", 'params': ddconfig}) + + self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) + elif "decoder.conv_in.weight" in sd: #default SD1.x/SD2.x VAE parameters ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -446,17 +461,29 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) + # Hunyuan 3d v2 2.0 & 2.1 elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: + self.latent_dim = 1 - ln_post = "geo_decoder.ln_post.weight" in sd - inner_size = sd["geo_decoder.output_proj.weight"].shape[1] - downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size - mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size - self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO - self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO - ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post} - self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig) + + def estimate_memory(shape, dtype, num_layers = 16, kv_cache_multiplier = 2): + batch, num_tokens, hidden_dim = shape + dtype_size = model_management.dtype_size(dtype) + + total_mem = batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers) + return total_mem + + # better memory estimations + self.memory_used_encode = lambda shape, dtype, num_layers = 8, kv_cache_multiplier = 0:\ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.memory_used_decode = lambda shape, dtype, num_layers = 16, kv_cache_multiplier = 2: \ + estimate_memory(shape, dtype, num_layers, kv_cache_multiplier) + + self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE() self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100) self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) @@ -773,6 +800,7 @@ class CLIPType(Enum): ACE = 16 OMNIGEN2 = 17 QWEN_IMAGE = 18 + HUNYUAN_IMAGE = 19 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -794,6 +822,7 @@ class TEModel(Enum): GEMMA_2_2B = 9 
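The estimate_memory helper added to the Hunyuan3D ShapeVAE branch above sizes activations as batch x tokens x hidden x dtype_size, scaled by (1 + kv_cache_multiplier x num_layers) to cover per-layer K/V caching during decode. A quick standalone check of that formula, using an arbitrary example shape and an fp16 element size of 2 bytes:

def estimate_memory(shape, dtype_size, num_layers=16, kv_cache_multiplier=2):
    # Same formula as the VAE decode estimate above: base activations plus a
    # K/V cache contribution per decoder layer.
    batch, num_tokens, hidden_dim = shape
    return batch * num_tokens * hidden_dim * dtype_size * (1 + kv_cache_multiplier * num_layers)

# e.g. one latent of 4096 tokens x 1024 channels in fp16 (2 bytes per element):
bytes_needed = estimate_memory((1, 4096, 1024), dtype_size=2)
print(bytes_needed / (1024 ** 2), "MiB")  # 1 * 4096 * 1024 * 2 * 33 bytes = 264.0 MiB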
QWEN25_3B = 10 QWEN25_7B = 11 + BYT5_SMALL_GLYPH = 12 def detect_te_model(sd): if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: @@ -811,6 +840,9 @@ def detect_te_model(sd): if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd: + weight = sd['encoder.block.0.layer.0.SelfAttention.k.weight'] + if weight.shape[0] == 384: + return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: return TEModel.GEMMA_2_2B @@ -925,8 +957,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.omnigen2.Omnigen2Tokenizer elif te_model == TEModel.QWEN25_7B: - clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer + if clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + else: + clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer else: # clip_l if clip_type == CLIPType.SD3: @@ -970,6 +1006,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer + elif clip_type == CLIPType.HUNYUAN_IMAGE: + clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 76260de00..aa953b462 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -20,6 +20,7 @@ import comfy.text_encoders.wan import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image +import comfy.text_encoders.hunyuan_image from . import supported_models_base from . 
import latent_formats @@ -1128,6 +1129,17 @@ class Hunyuan3Dv2(supported_models_base.BASE): def clip_target(self, state_dict={}): return None +class Hunyuan3Dv2_1(Hunyuan3Dv2): + unet_config = { + "image_model": "hunyuan3d2_1", + } + + latent_format = latent_formats.Hunyuan3Dv2_1 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Hunyuan3Dv2_1(self, device = device) + return out + class Hunyuan3Dv2mini(Hunyuan3Dv2): unet_config = { "image_model": "hunyuan3d2", @@ -1284,7 +1296,31 @@ class QwenImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.qwen_image.QwenImageTokenizer, comfy.text_encoders.qwen_image.te(**hunyuan_detect)) +class HunyuanImage21(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vec_in_dim": None, + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] + sampling_settings = { + "shift": 5.0, + } + + latent_format = latent_formats.HunyuanImage21 + + memory_usage_factor = 7.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanImage21(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage] models += [SVD_img2vid] diff --git a/comfy/text_encoders/byt5_config_small_glyph.json b/comfy/text_encoders/byt5_config_small_glyph.json new file mode 100644 index 000000000..0239c7164 --- /dev/null +++ b/comfy/text_encoders/byt5_config_small_glyph.json @@ -0,0 +1,22 @@ +{ + "d_ff": 3584, + "d_kv": 64, + "d_model": 1472, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "dense_act_fn": "gelu_pytorch_tanh", + "initializer_factor": 1.0, + "is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 
1e-06, + "model_type": "t5", + "num_decoder_layers": 4, + "num_heads": 6, + "num_layers": 12, + "output_past": true, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "vocab_size": 1510 +} diff --git a/comfy/text_encoders/byt5_tokenizer/added_tokens.json b/comfy/text_encoders/byt5_tokenizer/added_tokens.json new file mode 100644 index 000000000..93c190b56 --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/added_tokens.json @@ -0,0 +1,127 @@ +{ + "": 259, + "": 359, + "": 360, + "": 361, + "": 362, + "": 363, + "": 364, + "": 365, + "": 366, + "": 367, + "": 368, + "": 269, + "": 369, + "": 370, + "": 371, + "": 372, + "": 373, + "": 374, + "": 375, + "": 376, + "": 377, + "": 378, + "": 270, + "": 379, + "": 380, + "": 381, + "": 382, + "": 383, + "": 271, + "": 272, + "": 273, + "": 274, + "": 275, + "": 276, + "": 277, + "": 278, + "": 260, + "": 279, + "": 280, + "": 281, + "": 282, + "": 283, + "": 284, + "": 285, + "": 286, + "": 287, + "": 288, + "": 261, + "": 289, + "": 290, + "": 291, + "": 292, + "": 293, + "": 294, + "": 295, + "": 296, + "": 297, + "": 298, + "": 262, + "": 299, + "": 300, + "": 301, + "": 302, + "": 303, + "": 304, + "": 305, + "": 306, + "": 307, + "": 308, + "": 263, + "": 309, + "": 310, + "": 311, + "": 312, + "": 313, + "": 314, + "": 315, + "": 316, + "": 317, + "": 318, + "": 264, + "": 319, + "": 320, + "": 321, + "": 322, + "": 323, + "": 324, + "": 325, + "": 326, + "": 327, + "": 328, + "": 265, + "": 329, + "": 330, + "": 331, + "": 332, + "": 333, + "": 334, + "": 335, + "": 336, + "": 337, + "": 338, + "": 266, + "": 339, + "": 340, + "": 341, + "": 342, + "": 343, + "": 344, + "": 345, + "": 346, + "": 347, + "": 348, + "": 267, + "": 349, + "": 350, + "": 351, + "": 352, + "": 353, + "": 354, + "": 355, + "": 356, + "": 357, + "": 358, + "": 268 +} diff --git a/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json new file mode 100644 index 000000000..04fd58b5f --- /dev/null +++ b/comfy/text_encoders/byt5_tokenizer/special_tokens_map.json @@ -0,0 +1,150 @@ +{ + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "eos_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false + } +} diff --git a/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json new file mode 100644 index 000000000..5b1fe24c1 --- /dev/null +++ 
b/comfy/text_encoders/byt5_tokenizer/tokenizer_config.json @@ -0,0 +1,1163 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": true + }, + "259": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "260": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "261": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "262": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "263": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "264": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "265": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "266": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "267": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "268": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "269": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "270": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "271": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "272": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "273": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "274": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "275": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "276": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "277": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "278": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "279": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "280": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "281": { + "content": "", + "lstrip": false, + 
"normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "282": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "283": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "284": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "285": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "286": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "287": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "288": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "289": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "290": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "291": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "292": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "293": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "294": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "295": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "296": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "297": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "298": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "299": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "300": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "301": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "302": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "303": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "304": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "305": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "306": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "307": { + "content": "", + "lstrip": false, + "normalized": 
false, + "rstrip": false, + "single_word": false, + "special": true + }, + "308": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "309": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "310": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "311": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "312": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "313": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "314": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "315": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "316": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "317": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "318": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "319": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "320": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "321": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "322": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "323": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "324": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "325": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "326": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "327": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "328": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "329": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "330": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "331": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "332": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "333": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": 
false, + "single_word": false, + "special": true + }, + "334": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "335": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "336": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "337": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "338": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "339": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "340": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "341": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "342": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "343": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "344": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "345": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "346": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "347": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "348": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "349": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "350": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "351": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "352": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "353": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "354": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "355": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "356": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "357": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "358": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "359": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + 
"single_word": false, + "special": true + }, + "360": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "361": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "362": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "363": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "364": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "365": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "366": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "367": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "368": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "369": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "370": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "371": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "372": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "373": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "374": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "375": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "376": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "377": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "378": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "379": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "380": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "381": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "382": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "383": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "additional_special_tokens": [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + 
"", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "" + ], + "clean_up_tokenization_spaces": false, + "eos_token": "", + "extra_ids": 0, + "extra_special_tokens": {}, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "tokenizer_class": "ByT5Tokenizer", + "unk_token": "" +} diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py new file mode 100644 index 000000000..be396cae7 --- /dev/null +++ b/comfy/text_encoders/hunyuan_image.py @@ -0,0 +1,100 @@ +from comfy import sd1_clip +import comfy.text_encoders.llama +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from transformers import ByT5Tokenizer +import os +import re + +class ByT5SmallTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class HunyuanImageTokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + # self.llama_template_images = "{}" + self.byt5 = ByT5SmallTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + + # ByT5 processing for HunyuanImage + text_prompt_texts = [] + pattern_quote_single = r'\'(.*?)\'' + pattern_quote_double = r'\"(.*?)\"' + pattern_quote_chinese_single = r'‘(.*?)’' + pattern_quote_chinese_double = r'“(.*?)”' + + matches_quote_single = re.findall(pattern_quote_single, text) + matches_quote_double = re.findall(pattern_quote_double, text) + matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text) + matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text) + + text_prompt_texts.extend(matches_quote_single) + text_prompt_texts.extend(matches_quote_double) + text_prompt_texts.extend(matches_quote_chinese_single) + text_prompt_texts.extend(matches_quote_chinese_double) + + if len(text_prompt_texts) > 0: + out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". 
'.format(a), text_prompt_texts)), return_word_ids, **kwargs) + return out + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ByT5SmallModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "byt5_config_small_glyph.json") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) + + +class HunyuanImageTEModel(QwenImageTEModel): + def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + + if byt5: + self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options) + else: + self.byt5_small = None + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs) + if self.byt5_small is not None and "byt5" in token_weight_pairs: + out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"]) + extra["conditioning_byt5small"] = out[0] + return cond, p, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + if self.byt5_small is not None: + self.byt5_small.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + if self.byt5_small is not None: + self.byt5_small.reset_clip_options() + + def load_sd(self, sd): + if "encoder.block.0.layer.0.SelfAttention.o.weight" in sd: + return self.byt5_small.load_sd(sd) + else: + return super().load_sd(sd) + +def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): + class QwenImageTEModel_(HunyuanImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) + return QwenImageTEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 4c976058f..5e11956b5 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -128,11 +128,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N def apply_rope(xq, xk, freqs_cis): + org_dtype = xq.dtype cos = freqs_cis[0] sin = freqs_cis[1] q_embed = (xq * cos) + (rotate_half(xq) * sin) k_embed = (xk * cos) + (rotate_half(xk) * sin) - return q_embed, k_embed + return q_embed.to(org_dtype), 
k_embed.to(org_dtype) class Attention(nn.Module): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index e0ee943a7..f770109d5 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1190,13 +1190,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: - """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" + def validate_inputs(cls, **kwargs) -> bool | str: + """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS. + + If the function returns a string, it will be used as the validation error message for the node. + """ raise NotImplementedError @classmethod def fingerprint_inputs(cls, **kwargs) -> Any: - """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED.""" + """Optionally, define this function to fingerprint inputs; equivalent to V1's IS_CHANGED. + + If this function returns the same value as last run, the node will not be executed.""" raise NotImplementedError @classmethod diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index f953f86df..37438f835 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -518,6 +518,71 @@ async def upload_audio_to_comfyapi( return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) +def f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2 ** 15) + elif wav.dtype == torch.int32: + return wav.float() / (2 ** 31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. 
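+
+    A minimal usage sketch (illustrative only; `response_bytes` is an assumed variable,
+    not something provided by this module):
+
+        audio = audio_bytes_to_audio_input(response_bytes)
+        waveform, sample_rate = audio["waveform"], audio["sample_rate"]  # [1, C, T] float32, int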
+ """ + with av.open(io.BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = io.BytesIO() + output_container = av.open(output_buffer, mode='w', format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + def audio_to_base64_string( audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" ) -> str: diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 7a09df55b..78a23db30 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -951,7 +951,11 @@ class MagicPrompt2(str, Enum): class StyleType1(str, Enum): + AUTO = 'AUTO' GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' class ImagenImageGenerationInstance(BaseModel): @@ -2676,7 +2680,7 @@ class ReleaseNote(BaseModel): class RenderingSpeed(str, Enum): - BALANCED = 'BALANCED' + DEFAULT = 'DEFAULT' TURBO = 'TURBO' QUALITY = 'QUALITY' @@ -4918,6 +4922,14 @@ class IdeogramV3EditRequest(BaseModel): None, description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' 
+ ) class IdeogramV3Request(BaseModel): @@ -4951,6 +4963,14 @@ class IdeogramV3Request(BaseModel): style_type: Optional[StyleType1] = Field( None, description='The type of style to apply' ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) class ImagenGenerateImageResponse(BaseModel): diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability_api.py index 47c87daec..718360187 100644 --- a/comfy_api_nodes/apis/stability_api.py +++ b/comfy_api_nodes/apis/stability_api.py @@ -125,3 +125,25 @@ class StabilityResultsGetResponse(BaseModel): class StabilityAsyncResponse(BaseModel): id: Optional[str] = Field(None) + + +class StabilityTextToAudioRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + duration: int = Field(190, ge=1, le=190) + seed: int = Field(0, ge=0, le=4294967294) + steps: int = Field(8, ge=4, le=8) + output_format: str = Field("wav") + + +class StabilityAudioToAudioRequest(StabilityTextToAudioRequest): + strength: float = Field(0.01, ge=0.01, le=1.0) + + +class StabilityAudioInpaintRequest(StabilityTextToAudioRequest): + mask_start: int = Field(30, ge=0, le=190) + mask_end: int = Field(190, ge=0, le=190) + + +class StabilityAudioResponse(BaseModel): + audio: Optional[str] = Field(None) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py new file mode 100644 index 000000000..369a3a4fe --- /dev/null +++ b/comfy_api_nodes/nodes_bytedance.py @@ -0,0 +1,1217 @@ +import logging +import math +from enum import Enum +from typing import Literal, Optional, Type, Union +from typing_extensions import override + +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.util.validation_utils import ( + validate_image_aspect_ratio_range, + get_number_of_images, + validate_image_dimensions, +) +from comfy_api_nodes.apis.client import ( + ApiEndpoint, + EmptyRequest, + HttpMethod, + SynchronousOperation, + PollingOperation, + T, +) +from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + upload_images_to_comfyapi, + validate_string, + image_tensor_pair_to_batch, +) + + +BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" + +# Long-running tasks endpoints(e.g., video) +BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" +BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} + + +class Text2ImageModelName(str, Enum): + seedream_3 = "seedream-3-0-t2i-250415" + + +class Image2ImageModelName(str, Enum): + seededit_3 = "seededit-3-0-i2i-250628" + + +class Text2VideoModelName(str, Enum): + seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_lite = "seedance-1-0-lite-t2v-250428" + + +class Image2VideoModelName(str, Enum): 
+ """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" + seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_lite = "seedance-1-0-lite-i2v-250428" + + +class Text2ImageTaskCreationRequest(BaseModel): + model: Text2ImageModelName = Text2ImageModelName.seedream_3 + prompt: str = Field(...) + response_format: Optional[str] = Field("url") + size: Optional[str] = Field(None) + seed: Optional[int] = Field(0, ge=0, le=2147483647) + guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) + watermark: Optional[bool] = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: Image2ImageModelName = Image2ImageModelName.seededit_3 + prompt: str = Field(...) + response_format: Optional[str] = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: Optional[str] = Field("adaptive") + seed: Optional[int] = Field(..., ge=0, le=2147483647) + guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) + watermark: Optional[bool] = Field(True) + + +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field("seedream-4-0-250828") + prompt: str = Field(...) + response_format: str = Field("url") + image: Optional[list[str]] = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) + created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro + content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) 
+ error: Optional[TaskStatusError] = Field(None) + content: Optional[TaskStatusResult] = Field(None) + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + +# The time in this dictionary are given for 10 seconds duration. +VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, +} + + +def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: + if response.error: + error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" + logging.info(error_msg) + raise RuntimeError(error_msg) + logging.info("ByteDance task succeeded, image URL: %s", response.data[0]["url"]) + return response.data[0]["url"] + + +def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: + """Returns the video URL from the task status response if it exists.""" + if hasattr(response, "content") and response.content: + return response.content.video_url + return None + + +async def poll_until_finished( + auth_kwargs: dict[str, str], + task_id: str, + estimated_duration: Optional[int] = None, + node_id: Optional[str] = None, +) -> TaskStatusResponse: + """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" + return await PollingOperation( + poll_endpoint=ApiEndpoint( + path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + completed_statuses=[ + "succeeded", + ], + failed_statuses=[ + "cancelled", + "failed", + ], + status_extractor=lambda response: response.status, + auth_kwargs=auth_kwargs, + result_url_extractor=get_video_url_from_task_status, + estimated_duration=estimated_duration, + node_id=node_id, + ).execute() + + +class ByteDanceImageNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageNode", + display_name="ByteDance Image", + category="api node/image/ByteDance", + description="Generate images using ByteDance models via api based on prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Text2ImageModelName], + default=Text2ImageModelName.seedream_3.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + comfy_io.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS], + tooltip="Pick a recommended size. 
Select Custom to use the width and height below", + ), + comfy_io.Int.Input( + "width", + default=1024, + min=512, + max=2048, + step=64, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + ), + comfy_io.Int.Input( + "height", + default=1024, + min=512, + max=2048, + step=64, + tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation", + optional=True, + ), + comfy_io.Float.Input( + "guidance_scale", + default=2.5, + min=1.0, + max=10.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + size_preset: str, + width: int, + height: int, + seed: int, + guidance_scale: float, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (512 <= w <= 2048) or not (512 <= h <= 2048): + raise ValueError( + f"Custom size out of range: {w}x{h}. " + "Both width and height must be between 512 and 2048 pixels." 
+ ) + + payload = Text2ImageTaskCreationRequest( + model=model, + prompt=prompt, + size=f"{w}x{h}", + seed=seed, + guidance_scale=guidance_scale, + watermark=watermark, + ) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_IMAGE_ENDPOINT, + method=HttpMethod.POST, + request_model=Text2ImageTaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + + +class ByteDanceImageEditNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageEditNode", + display_name="ByteDance Image Edit", + category="api node/image/ByteDance", + description="Edit images using ByteDance models via api based on prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Image2ImageModelName], + default=Image2ImageModelName.seededit_3.value, + tooltip="Model name", + ), + comfy_io.Image.Input( + "image", + tooltip="The base image to edit", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation", + optional=True, + ), + comfy_io.Float.Input( + "guidance_scale", + default=5.5, + min=1.0, + max=10.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str, + seed: int, + guidance_scale: float, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + source_url = (await upload_images_to_comfyapi( + image, + max_images=1, + mime_type="image/png", + auth_kwargs=auth_kwargs, + ))[0] + payload = Image2ImageTaskCreationRequest( + model=model, + prompt=prompt, + image=source_url, + seed=seed, + guidance_scale=guidance_scale, + watermark=watermark, + ) + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_IMAGE_ENDPOINT, + method=HttpMethod.POST, + request_model=Image2ImageTaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + + +class ByteDanceSeedreamNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceSeedreamNode", + 
display_name="ByteDance Seedream 4", + category="api node/image/ByteDance", + description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=["seedream-4-0-250828"], + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for creating or editing an image.", + ), + comfy_io.Image.Input( + "image", + tooltip="Input image(s) for image-to-image generation. " + "List of 1-10 images for single or multi-reference generation.", + optional=True, + ), + comfy_io.Combo.Input( + "size_preset", + options=[label for label, _, _ in RECOMMENDED_PRESETS_SEEDREAM_4], + tooltip="Pick a recommended size. Select Custom to use the width and height below.", + ), + comfy_io.Int.Input( + "width", + default=2048, + min=1024, + max=4096, + step=64, + tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + comfy_io.Int.Input( + "height", + default=2048, + min=1024, + max=4096, + step=64, + tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", + optional=True, + ), + comfy_io.Combo.Input( + "sequential_image_generation", + options=["disabled", "auto"], + tooltip="Group image generation mode. " + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", + optional=True, + ), + comfy_io.Int.Input( + "max_images", + default=1, + min=1, + max=15, + step=1, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " + "Total images (input + generated) cannot exceed 15.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the image.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: torch.Tensor = None, + size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], + width: int = 2048, + height: int = 2048, + sequential_image_generation: str = "disabled", + max_images: int = 1, + seed: int = 0, + watermark: bool = True, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + w = h = None + for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: + if label == size_preset: + w, h = tw, th + break + + if w is None or h is None: + w, h = width, height + if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): + raise ValueError( + f"Custom size out of range: {w}x{h}. " + "Both width and height must be between 1024 and 4096 pixels." 
+ ) + n_input_images = get_number_of_images(image) if image is not None else 0 + if n_input_images > 10: + raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") + if sequential_image_generation == "auto" and n_input_images + max_images > 15: + raise ValueError( + "The maximum number of generated images plus the number of reference images cannot exceed 15." + ) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + reference_images_urls = [] + if n_input_images: + for i in image: + validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) + reference_images_urls = (await upload_images_to_comfyapi( + image, + max_images=n_input_images, + mime_type="image/png", + auth_kwargs=auth_kwargs, + )) + payload = Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, + ) + response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_IMAGE_ENDPOINT, + method=HttpMethod.POST, + request_model=Seedream4TaskCreationRequest, + response_model=ImageTaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + + if len(response.data) == 1: + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) + return comfy_io.NodeOutput( + torch.cat([await download_url_to_image_tensor(str(i["url"])) for i in response.data]) + ) + + +class ByteDanceTextToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceTextToVideoNode", + display_name="ByteDance Text to Video", + category="api node/video/ByteDance", + description="Generate video using ByteDance models via api based on prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Text2VideoModelName], + default=Text2VideoModelName.seedance_1_pro.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. 
The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + return await process_video_task( + request_model=Text2VideoTaskCreationRequest, + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceImageToVideoNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageToVideoNode", + display_name="ByteDance Image to Video", + category="api node/video/ByteDance", + description="Generate video using ByteDance models via api based on image and prompt", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[model.value for model in Image2VideoModelName], + default=Image2VideoModelName.seedance_1_pro.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Image.Input( + "image", + tooltip="First frame to be used for the video.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "camera_fixed", + default=False, + tooltip="Specifies whether to fix the camera. 
The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + image: torch.Tensor, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + return await process_video_task( + request_model=Image2VideoTaskCreationRequest, + payload=Image2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceFirstLastFrameNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceFirstLastFrameNode", + display_name="ByteDance First-Last-Frame to Video", + category="api node/video/ByteDance", + description="Generate video using prompt and first and last frames.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[Image2VideoModelName.seedance_1_lite.value], + default=Image2VideoModelName.seedance_1_lite.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Image.Input( + "first_frame", + tooltip="First frame to be used for the video.", + ), + comfy_io.Image.Input( + "last_frame", + tooltip="Last frame to be used for the video.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p", "1080p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "camera_fixed", + default=False, + 
tooltip="Specifies whether to fix the camera. The platform appends an instruction " + "to fix the camera to your prompt, but does not guarantee the actual effect.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + first_frame: torch.Tensor, + last_frame: torch.Tensor, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + camera_fixed: bool, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) + for i in (first_frame, last_frame): + validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + download_urls = await upload_images_to_comfyapi( + image_tensor_pair_to_batch(first_frame, last_frame), + max_images=2, + mime_type="image/png", + auth_kwargs=auth_kwargs, + ) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--camerafixed {str(camera_fixed).lower()} " + f"--watermark {str(watermark).lower()}" + ) + + return await process_video_task( + request_model=Image2VideoTaskCreationRequest, + payload=Image2VideoTaskCreationRequest( + model=model, + content=[ + TaskTextContent(text=prompt), + TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), + TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), + ], + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + +class ByteDanceImageReferenceNode(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="ByteDanceImageReferenceNode", + display_name="ByteDance Reference Images to Video", + category="api node/video/ByteDance", + description="Generate video using prompt and reference images.", + inputs=[ + comfy_io.Combo.Input( + "model", + options=[Image2VideoModelName.seedance_1_lite.value], + default=Image2VideoModelName.seedance_1_lite.value, + tooltip="Model name", + ), + comfy_io.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the video.", + ), + comfy_io.Image.Input( + "images", + tooltip="One to four images.", + ), + comfy_io.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=["adaptive", "16:9", "4:3", "1:1", "3:4", "9:16", "21:9"], + tooltip="The aspect ratio of the output video.", + ), + comfy_io.Int.Input( + "duration", + default=5, + min=3, + max=12, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=comfy_io.NumberDisplay.slider, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + 
display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to use for generation.", + optional=True, + ), + comfy_io.Boolean.Input( + "watermark", + default=True, + tooltip="Whether to add an \"AI generated\" watermark to the video.", + optional=True, + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + images: torch.Tensor, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + watermark: bool, + ) -> comfy_io.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) + for image in images: + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + image_urls = await upload_images_to_comfyapi( + images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs + ) + + prompt = ( + f"{prompt} " + f"--resolution {resolution} " + f"--ratio {aspect_ratio} " + f"--duration {duration} " + f"--seed {seed} " + f"--watermark {str(watermark).lower()}" + ) + x = [ + TaskTextContent(text=prompt), + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] + ] + return await process_video_task( + request_model=Image2VideoTaskCreationRequest, + payload=Image2VideoTaskCreationRequest( + model=model, + content=x, + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), + ) + + + async def process_video_task( + request_model: Type[T], + payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], + auth_kwargs: dict, + node_id: str, + estimated_duration: int | None, + ) -> comfy_io.NodeOutput: + initial_response = await SynchronousOperation( + endpoint=ApiEndpoint( + path=BYTEPLUS_TASK_ENDPOINT, + method=HttpMethod.POST, + request_model=request_model, + response_model=TaskCreationResponse, + ), + request=payload, + auth_kwargs=auth_kwargs, + ).execute() + response = await poll_until_finished( + auth_kwargs, + initial_response.id, + estimated_duration=estimated_duration, + node_id=node_id, + ) + return comfy_io.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) + + + def raise_if_text_params(prompt: str, text_params: list[str]) -> None: + for i in text_params: + if f"--{i} " in prompt: + raise ValueError( + f"--{i} is not allowed in the prompt; use the appropriate widget input to change this value." 
+ ) + + +class ByteDanceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + ByteDanceImageNode, + ByteDanceImageEditNode, + ByteDanceSeedreamNode, + ByteDanceTextToVideoNode, + ByteDanceImageToVideoNode, + ByteDanceFirstLastFrameNode, + ByteDanceImageReferenceNode, + ] + +async def comfy_entrypoint() -> ByteDanceExtension: + return ByteDanceExtension() diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index d28895f3e..2d1c32e4f 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -255,6 +255,7 @@ class IdeogramV1(comfy_io.ComfyNode): display_name="Ideogram V1", category="api node/image/Ideogram", description="Generates images using the Ideogram V1 model.", + is_api_node=True, inputs=[ comfy_io.String.Input( "prompt", @@ -383,6 +384,7 @@ class IdeogramV2(comfy_io.ComfyNode): display_name="Ideogram V2", category="api node/image/Ideogram", description="Generates images using the Ideogram V2 model.", + is_api_node=True, inputs=[ comfy_io.String.Input( "prompt", @@ -552,6 +554,7 @@ class IdeogramV3(comfy_io.ComfyNode): category="api node/image/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", + is_api_node=True, inputs=[ comfy_io.String.Input( "prompt", @@ -612,11 +615,21 @@ class IdeogramV3(comfy_io.ComfyNode): ), comfy_io.Combo.Input( "rendering_speed", - options=["BALANCED", "TURBO", "QUALITY"], - default="BALANCED", + options=["DEFAULT", "TURBO", "QUALITY"], + default="DEFAULT", tooltip="Controls the trade-off between generation speed and quality", optional=True, ), + comfy_io.Image.Input( + "character_image", + tooltip="Image to use as character reference.", + optional=True, + ), + comfy_io.Mask.Input( + "character_mask", + tooltip="Optional mask for character reference image.", + optional=True, + ), ], outputs=[ comfy_io.Image.Output(), @@ -639,12 +652,46 @@ class IdeogramV3(comfy_io.ComfyNode): magic_prompt_option="AUTO", seed=0, num_images=1, - rendering_speed="BALANCED", + rendering_speed="DEFAULT", + character_image=None, + character_mask=None, ): auth = { "auth_token": cls.hidden.auth_token_comfy_org, "comfy_api_key": cls.hidden.api_key_comfy_org, } + if rendering_speed == "BALANCED": # for backward compatibility + rendering_speed = "DEFAULT" + + character_img_binary = None + character_mask_binary = None + + if character_image is not None: + input_tensor = character_image.squeeze().cpu() + if character_mask is not None: + character_mask = resize_mask_to_image(character_mask, character_image, allow_gradient=False) + character_mask = 1.0 - character_mask + if character_mask.shape[1:] != character_image.shape[1:-1]: + raise Exception("Character mask and image must be the same size") + + mask_np = (character_mask.squeeze().cpu().numpy() * 255).astype(np.uint8) + mask_img = Image.fromarray(mask_np) + mask_byte_arr = BytesIO() + mask_img.save(mask_byte_arr, format="PNG") + mask_byte_arr.seek(0) + character_mask_binary = mask_byte_arr + character_mask_binary.name = "mask.png" + + img_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(img_np) + img_byte_arr = BytesIO() + img.save(img_byte_arr, format="PNG") + img_byte_arr.seek(0) + character_img_binary = img_byte_arr + character_img_binary.name = "image.png" + elif character_mask is not None: + raise Exception("Character mask requires character image to be present") + # 
Check if both image and mask are provided for editing mode if image is not None and mask is not None: # Edit mode @@ -693,6 +740,15 @@ class IdeogramV3(comfy_io.ComfyNode): if num_images > 1: edit_request.num_images = num_images + files = { + "image": img_binary, + "mask": mask_binary, + } + if character_img_binary: + files["character_reference_images"] = character_img_binary + if character_mask_binary: + files["character_mask_binary"] = character_mask_binary + # Execute the operation for edit mode operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -702,10 +758,7 @@ class IdeogramV3(comfy_io.ComfyNode): response_model=IdeogramGenerateResponse, ), request=edit_request, - files={ - "image": img_binary, - "mask": mask_binary, - }, + files=files, content_type="multipart/form-data", auth_kwargs=auth, ) @@ -739,6 +792,14 @@ class IdeogramV3(comfy_io.ComfyNode): if num_images > 1: gen_request.num_images = num_images + files = {} + if character_img_binary: + files["character_reference_images"] = character_img_binary + if character_mask_binary: + files["character_mask_binary"] = character_mask_binary + if files: + gen_request.style_type = "AUTO" + # Execute the operation for generation mode operation = SynchronousOperation( endpoint=ApiEndpoint( @@ -748,6 +809,8 @@ class IdeogramV3(comfy_io.ComfyNode): response_model=IdeogramGenerateResponse, ), request=gen_request, + files=files if files else None, + content_type="multipart/form-data", auth_kwargs=auth, ) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 98024a9fa..27b2bf748 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -12,6 +12,7 @@ User Guides: """ from typing import Union, Optional, Any +from typing_extensions import override from enum import Enum import torch @@ -46,9 +47,9 @@ from comfy_api_nodes.apinode_utils import ( validate_string, download_url_to_image_tensor, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input from comfy_api.input_impl import VideoFromFile -from comfy.comfy_types.node_typing import IO, ComfyNodeABC +from comfy_api.latest import ComfyExtension, io as comfy_io +from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -85,20 +86,11 @@ class RunwayGen3aAspectRatio(str, Enum): def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: """Returns the video URL from the task status response if it exists.""" - if response.output and len(response.output) > 0: + if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -# TODO: replace with updated image validation utils (upstream) -def validate_input_image(image: torch.Tensor) -> bool: - """ - Validate the input image is within the size limits for the Runway API. 
- See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons - """ - return image.shape[2] < 8000 and image.shape[1] < 8000 - - async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, TaskStatusResponse], @@ -134,458 +126,438 @@ def extract_progress_from_task_status( def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: """Returns the image URL from the task status response if it exists.""" - if response.output and len(response.output) > 0: + if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -class RunwayVideoGenNode(ComfyNodeABC): - """Runway Video Node Base.""" +async def get_response( + task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None +) -> TaskStatusResponse: + """Poll the task status until it is finished then get the response.""" + return await poll_until_finished( + auth_kwargs, + ApiEndpoint( + path=f"{PATH_GET_TASK_STATUS}/{task_id}", + method=HttpMethod.GET, + request_model=EmptyRequest, + response_model=TaskStatusResponse, + ), + estimated_duration=estimated_duration, + node_id=node_id, + ) - RETURN_TYPES = ("VIDEO",) - FUNCTION = "api_call" - CATEGORY = "api node/video/Runway" - API_NODE = True - def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool: - """ - Validate the task creation response from the Runway API matches - expected format. - """ - if not bool(response.id): - raise RunwayApiError("Invalid initial response from Runway API.") - return True +async def generate_video( + request: RunwayImageToVideoRequest, + auth_kwargs: dict[str, str], + node_id: Optional[str] = None, + estimated_duration: Optional[int] = None, +) -> VideoFromFile: + initial_operation = SynchronousOperation( + endpoint=ApiEndpoint( + path=PATH_IMAGE_TO_VIDEO, + method=HttpMethod.POST, + request_model=RunwayImageToVideoRequest, + response_model=RunwayImageToVideoResponse, + ), + request=request, + auth_kwargs=auth_kwargs, + ) - def validate_response(self, response: RunwayImageToVideoResponse) -> bool: - """ - Validate the successful task status response from the Runway API - matches expected format. - """ - if not response.output or len(response.output) == 0: - raise RunwayApiError( - "Runway task succeeded but no video data found in response." 
- ) - return True + initial_response = await initial_operation.execute() - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> RunwayImageToVideoResponse: - """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - node_id=node_id, + final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration) + if not final_response.output: + raise RunwayApiError("Runway task succeeded but no video data found in response.") + + video_url = get_video_url_from_task_status(final_response) + return await download_url_to_video_output(video_url) + + +class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode): + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayImageToVideoNodeGen3a", + display_name="Runway Image to Video (Gen3a Turbo)", + category="api node/video/Runway", + description="Generate a video from a single starting frame using Gen3a Turbo model. " + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", + ), + comfy_io.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", + ), + comfy_io.Combo.Input( + "duration", + options=[model.value for model in Duration], + ), + comfy_io.Combo.Input( + "ratio", + options=[model.value for model in RunwayGen3aAspectRatio], + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967295, + step=1, + control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed for generation", + ), + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def generate_video( - self, - request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, - ) -> tuple[VideoFromFile]: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, + @classmethod + async def execute( + cls, + prompt: str, + start_frame: torch.Tensor, + duration: str, + ratio: str, + seed: int, + ) -> comfy_io.NodeOutput: + validate_string(prompt, min_length=1) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + + download_urls = await upload_images_to_comfyapi( + start_frame, + max_images=1, + mime_type="image/png", auth_kwargs=auth_kwargs, ) - initial_response = await initial_operation.execute() - self.validate_task_created(initial_response) - task_id = initial_response.id - - final_response = await self.get_response(task_id, auth_kwargs, node_id) - 
self.validate_response(final_response) - - video_url = get_video_url_from_task_status(final_response) - return (await download_url_to_video_output(video_url),) + return comfy_io.NodeOutput( + await generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ) + ] + ), + ), + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + ) + ) -class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): - """Runway Image to Video Node using Gen3a Turbo model.""" - - DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo." +class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayImageToVideoNodeGen4", + display_name="Runway Image to Video (Gen4 Turbo)", + category="api node/video/Runway", + description="Generate a video from a single starting frame using Gen4 Turbo model. " + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, + comfy_io.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + comfy_io.Combo.Input( + "duration", + options=[model.value for model in Duration], ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, + comfy_io.Combo.Input( "ratio", - enum_type=RunwayGen3aAspectRatio, + options=[model.value for model in RunwayGen4TurboAspectRatio], ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, + comfy_io.Int.Input( "seed", + default=0, + min=0, + max=4294967295, + step=1, control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed for generation", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, start_frame: torch.Tensor, duration: str, ratio: str, seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs + ) -> comfy_io.NodeOutput: validate_string(prompt, min_length=1) - validate_input_image(start_frame) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, 
max_aspect_ratio=2.0) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } - # Upload image download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - return await self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + return comfy_io.NodeOutput( + await generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen4_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ) + ] + ), ), - ), - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + ) ) -class RunwayImageToVideoNodeGen4(RunwayVideoGenNode): - """Runway Image to Video Node using Gen4 Turbo model.""" - - DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video." +class RunwayFirstLastFrameNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayFirstLastFrameNode", + display_name="Runway First-Last-Frame to Video", + category="api node/video/Runway", + description="Upload first and last keyframes, draft a prompt, and generate a video. " + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. " + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, + comfy_io.Image.Input( + "start_frame", + tooltip="Start frame to be used for the video", ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration + comfy_io.Image.Input( + "end_frame", + tooltip="End frame to be used for the video. 
Supported for gen3a_turbo only.", ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, + comfy_io.Combo.Input( + "duration", + options=[model.value for model in Duration], + ), + comfy_io.Combo.Input( "ratio", - enum_type=RunwayGen4TurboAspectRatio, + options=[model.value for model in RunwayGen3aAspectRatio], ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, + comfy_io.Int.Input( "seed", + default=0, + min=0, + max=4294967295, + step=1, control_after_generate=True, + display_mode=comfy_io.NumberDisplay.number, + tooltip="Random seed for generation", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, - prompt: str, - start_frame: torch.Tensor, - duration: str, - ratio: str, - seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs - validate_string(prompt, min_length=1) - validate_input_image(start_frame) - - # Upload image - download_urls = await upload_images_to_comfyapi( - start_frame, - max_images=1, - mime_type="image/png", - auth_kwargs=kwargs, - ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload one or more images to comfy api.") - - return await self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen4_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] - ), - ), - auth_kwargs=kwargs, - node_id=unique_id, - ) - - -class RunwayFirstLastFrameNode(RunwayVideoGenNode): - """Runway First-Last Frame Node.""" - - DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> RunwayImageToVideoResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_FLF_SECONDS, - node_id=node_id, + ], + outputs=[ + comfy_io.Video.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True - ), - "start_frame": ( - IO.IMAGE, - {"tooltip": "Start frame to be used for the video"}, - ), - "end_frame": ( - IO.IMAGE, - { - "tooltip": "End frame to be used for the video. Supported for gen3a_turbo only." 
- }, - ), - "duration": model_field_to_node_input( - IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration - ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayImageToVideoRequest, - "ratio", - enum_type=RunwayGen3aAspectRatio, - ), - "seed": model_field_to_node_input( - IO.INT, - RunwayImageToVideoRequest, - "seed", - control_after_generate=True, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "unique_id": "UNIQUE_ID", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, + async def execute( + cls, prompt: str, start_frame: torch.Tensor, end_frame: torch.Tensor, duration: str, ratio: str, seed: int, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[VideoFromFile]: - # Validate inputs + ) -> comfy_io.NodeOutput: validate_string(prompt, min_length=1) - validate_input_image(start_frame) - validate_input_image(end_frame) + validate_image_dimensions(start_frame, max_width=7999, max_height=7999) + validate_image_dimensions(end_frame, max_width=7999, max_height=7999) + validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } - # Upload images stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return await self.generate_video( - RunwayImageToVideoRequest( - promptText=prompt, - seed=seed, - model=Model("gen3a_turbo"), - duration=Duration(duration), - ratio=AspectRatio(ratio), - promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), - ] + return comfy_io.NodeOutput( + await generate_video( + RunwayImageToVideoRequest( + promptText=prompt, + seed=seed, + model=Model("gen3a_turbo"), + duration=Duration(duration), + ratio=AspectRatio(ratio), + promptImage=RunwayPromptImageObject( + root=[ + RunwayPromptImageDetailedObject( + uri=str(download_urls[0]), position="first" + ), + RunwayPromptImageDetailedObject( + uri=str(download_urls[1]), position="last" + ), + ] + ), ), - ), - auth_kwargs=kwargs, - node_id=unique_id, + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=AVERAGE_DURATION_FLF_SECONDS, + ) ) -class RunwayTextToImageNode(ComfyNodeABC): - """Runway Text to Image Node.""" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "api_call" - CATEGORY = "api node/image/Runway" - API_NODE = True - DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation." +class RunwayTextToImageNode(comfy_io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": model_field_to_node_input( - IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True + def define_schema(cls): + return comfy_io.Schema( + node_id="RunwayTextToImageNode", + display_name="Runway Text to Image", + category="api node/image/Runway", + description="Generate an image from a text prompt using Runway's Gen 4 model. 
" + "You can also include reference image to guide the generation.", + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the generation", ), - "ratio": model_field_to_node_input( - IO.COMBO, - RunwayTextToImageRequest, + comfy_io.Combo.Input( "ratio", - enum_type=RunwayTextToImageAspectRatioEnum, + options=[model.value for model in RunwayTextToImageAspectRatioEnum], ), - }, - "optional": { - "reference_image": ( - IO.IMAGE, - {"tooltip": "Optional reference image to guide the generation"}, - ) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - def validate_task_created(self, response: RunwayTextToImageResponse) -> bool: - """ - Validate the task creation response from the Runway API matches - expected format. - """ - if not bool(response.id): - raise RunwayApiError("Invalid initial response from Runway API.") - return True - - def validate_response(self, response: TaskStatusResponse) -> bool: - """ - Validate the successful task status response from the Runway API - matches expected format. - """ - if not response.output or len(response.output) == 0: - raise RunwayApiError( - "Runway task succeeded but no image data found in response." - ) - return True - - async def get_response( - self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None - ) -> TaskStatusResponse: - """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - estimated_duration=AVERAGE_DURATION_T2I_SECONDS, - node_id=node_id, + comfy_io.Image.Input( + "reference_image", + tooltip="Optional reference image to guide the generation", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, ) - async def api_call( - self, + @classmethod + async def execute( + cls, prompt: str, ratio: str, reference_image: Optional[torch.Tensor] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[torch.Tensor]: - # Validate inputs + ) -> comfy_io.NodeOutput: validate_string(prompt, min_length=1) + auth_kwargs = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + # Prepare reference images if provided reference_images = None if reference_image is not None: - validate_input_image(reference_image) + validate_image_dimensions(reference_image, max_width=7999, max_height=7999) + validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) download_urls = await upload_images_to_comfyapi( reference_image, max_images=1, mime_type="image/png", - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) - if len(download_urls) != 1: - raise RunwayApiError("Failed to upload reference image to comfy api.") - reference_images = [ReferenceImage(uri=str(download_urls[0]))] - # Create request request = RunwayTextToImageRequest( promptText=prompt, model=Model4.gen4_image, @@ -593,7 +565,6 @@ class RunwayTextToImageNode(ComfyNodeABC): referenceImages=reference_images, ) - # Execute initial request initial_operation = SynchronousOperation( endpoint=ApiEndpoint( path=PATH_TEXT_TO_IMAGE, @@ -602,34 +573,33 @@ class 
RunwayTextToImageNode(ComfyNodeABC): response_model=RunwayTextToImageResponse, ), request=request, - auth_kwargs=kwargs, + auth_kwargs=auth_kwargs, ) initial_response = await initial_operation.execute() - self.validate_task_created(initial_response) - task_id = initial_response.id # Poll for completion - final_response = await self.get_response( - task_id, auth_kwargs=kwargs, node_id=unique_id + final_response = await get_response( + initial_response.id, + auth_kwargs=auth_kwargs, + node_id=cls.hidden.unique_id, + estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) - self.validate_response(final_response) + if not final_response.output: + raise RunwayApiError("Runway task succeeded but no image data found in response.") - # Download and return image - image_url = get_image_url_from_task_status(final_response) - return (await download_url_to_image_tensor(image_url),) + return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response))) -NODE_CLASS_MAPPINGS = { - "RunwayFirstLastFrameNode": RunwayFirstLastFrameNode, - "RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a, - "RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4, - "RunwayTextToImageNode": RunwayTextToImageNode, -} +class RunwayExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ + RunwayFirstLastFrameNode, + RunwayImageToVideoNodeGen3a, + RunwayImageToVideoNodeGen4, + RunwayTextToImageNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video", - "RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)", - "RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)", - "RunwayTextToImageNode": "Runway Text to Image", -} +async def comfy_entrypoint() -> RunwayExtension: + return RunwayExtension() diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 31309d831..5ba5ed986 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -1,5 +1,8 @@ from inspect import cleandoc -from comfy.comfy_types.node_typing import IO +from typing import Optional +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, Input, io as comfy_io from comfy_api_nodes.apis.stability_api import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, @@ -12,6 +15,10 @@ from comfy_api_nodes.apis.stability_api import ( Stability_SD3_5_Model, Stability_SD3_5_GenerationMode, get_stability_style_presets, + StabilityTextToAudioRequest, + StabilityAudioToAudioRequest, + StabilityAudioInpaintRequest, + StabilityAudioResponse, ) from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -24,7 +31,10 @@ from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, tensor_to_bytesio, validate_string, + audio_bytes_to_audio_input, + audio_input_to_mp3, ) +from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 @@ -46,87 +56,94 @@ def get_async_dummy_status(x: StabilityResultsGetResponse): return StabilityPollStatus.in_progress -class StabilityStableImageUltraNode: +class StabilityStableImageUltraNode(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and resolution. 
""" - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + - "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityStableImageUltraNode", + display_name="Stability AI Stable Image Ultra", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines" + "elements, colors, and subjects will lead to better results. " + "To control the weight of a given word use the format `(word:weight)`," + "where `word` is the word you'd like to control the weight of and `weight`" + "is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" + - "would convey a sky that was blue and green, but more green than blue." - }, + "would convey a sky that was blue and green, but more green than blue.", ), - "aspect_ratio": ([x.value for x in StabilityAspectRatio], - { - "default": StabilityAspectRatio.ratio_1_1, - "tooltip": "Aspect ratio of generated image.", - }, + comfy_io.Combo.Input( + "aspect_ratio", + options=[x.value for x in StabilityAspectRatio], + default=StabilityAspectRatio.ratio_1_1.value, + tooltip="Aspect ratio of generated image.", ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, + comfy_io.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", ), - }, - "optional": { - "image": (IO.IMAGE,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature." - }, + comfy_io.Image.Input( + "image", + optional=True, ), - "image_denoise": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", - }, + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="A blurb of text describing what you do not wish to see in the output image. 
This is an advanced feature.", + force_input=True, + optional=True, ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + comfy_io.Float.Input( + "image_denoise", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, - negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, - **kwargs): + @classmethod + async def execute( + cls, + prompt: str, + aspect_ratio: str, + style_preset: str, + seed: int, + image: Optional[torch.Tensor] = None, + negative_prompt: str = "", + image_denoise: Optional[float] = 0.5, + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) # prepare image binary if image present image_binary = None @@ -144,6 +161,11 @@ class StabilityStableImageUltraNode: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/generate/ultra", @@ -161,7 +183,7 @@ class StabilityStableImageUltraNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -171,95 +193,106 @@ class StabilityStableImageUltraNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityStableImageSD_3_5Node: +class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode): """ Generates images synchronously based on prompt and resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityStableImageSD_3_5Node", + display_name="Stability AI Stable Diffusion 3.5 Image", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. 
A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + comfy_io.Combo.Input( + "model", + options=[x.value for x in Stability_SD3_5_Model], + ), + comfy_io.Combo.Input( + "aspect_ratio", + options=[x.value for x in StabilityAspectRatio], + default=StabilityAspectRatio.ratio_1_1.value, + tooltip="Aspect ratio of generated image.", + ), + comfy_io.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", + ), + comfy_io.Float.Input( + "cfg_scale", + default=4.0, + min=1.0, + max=10.0, + step=0.1, + tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.Image.Input( + "image", + optional=True, + ), + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + comfy_io.Float.Input( + "image_denoise", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "model": ([x.value for x in Stability_SD3_5_Model],), - "aspect_ratio": ([x.value for x in StabilityAspectRatio], - { - "default": StabilityAspectRatio.ratio_1_1, - "tooltip": "Aspect ratio of generated image.", - }, - ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, - ), - "cfg_scale": ( - IO.FLOAT, - { - "default": 4.0, - "min": 1.0, - "max": 10.0, - "step": 0.1, - "tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "image": (IO.IMAGE,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." 
- }, - ), - "image_denoise": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, - negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, - **kwargs): + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + style_preset: str, + seed: int, + cfg_scale: float, + image: Optional[torch.Tensor] = None, + negative_prompt: str = "", + image_denoise: Optional[float] = 0.5, + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) # prepare image binary if image present image_binary = None @@ -280,6 +313,11 @@ class StabilityStableImageSD_3_5Node: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/generate/sd3", @@ -300,7 +338,7 @@ class StabilityStableImageSD_3_5Node: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -310,72 +348,75 @@ class StabilityStableImageSD_3_5Node: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityUpscaleConservativeNode: +class StabilityUpscaleConservativeNode(comfy_io.ComfyNode): """ Upscale image with minimal alterations to 4K resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityUpscaleConservativeNode", + display_name="Stability AI Upscale Conservative", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + comfy_io.Float.Input( + "creativity", + default=0.35, + min=0.2, + max=0.5, + step=0.01, + tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. 
This is an advanced feature.", + force_input=True, + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "creativity": ( - IO.FLOAT, - { - "default": 0.35, - "min": 0.2, - "max": 0.5, - "step": 0.01, - "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, - **kwargs): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + creativity: float, + seed: int, + negative_prompt: str = "", + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -386,6 +427,11 @@ class StabilityUpscaleConservativeNode: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/upscale/conservative", @@ -401,7 +447,7 @@ class StabilityUpscaleConservativeNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -411,77 +457,81 @@ class StabilityUpscaleConservativeNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityUpscaleCreativeNode: +class StabilityUpscaleCreativeNode(comfy_io.ComfyNode): """ Upscale image with minimal alterations to 4K resolution. """ - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityUpscaleCreativeNode", + display_name="Stability AI Upscale Creative", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + comfy_io.String.Input( + "prompt", + multiline=True, + default="", + tooltip="What you wish to see in the output image. 
A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.", + ), + comfy_io.Float.Input( + "creativity", + default=0.3, + min=0.1, + max=0.5, + step=0.01, + tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.", + ), + comfy_io.Combo.Input( + "style_preset", + options=get_stability_style_presets(), + tooltip="Optional desired style of generated image.", + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + comfy_io.String.Input( + "negative_prompt", + default="", + tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", + force_input=True, + optional=True, + ), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results." - }, - ), - "creativity": ( - IO.FLOAT, - { - "default": 0.3, - "min": 0.1, - "max": 0.5, - "step": 0.01, - "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.", - }, - ), - "style_preset": (get_stability_style_presets(), - { - "tooltip": "Optional desired style of generated image.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 4294967294, - "control_after_generate": True, - "tooltip": "The random seed used for creating the noise.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, - **kwargs): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + creativity: float, + style_preset: str, + seed: int, + negative_prompt: str = "", + ) -> comfy_io.NodeOutput: validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -494,6 +544,11 @@ class StabilityUpscaleCreativeNode: "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/upscale/creative", @@ -510,7 +565,7 @@ class StabilityUpscaleCreativeNode: ), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -525,7 +580,8 @@ class StabilityUpscaleCreativeNode: completed_statuses=[StabilityPollStatus.finished], failed_statuses=[StabilityPollStatus.failed], status_extractor=lambda x: get_async_dummy_status(x), - auth_kwargs=kwargs, + auth_kwargs=auth, + node_id=cls.hidden.unique_id, ) response_poll: StabilityResultsGetResponse = await operation.execute() @@ -535,41 +591,48 @@ class StabilityUpscaleCreativeNode: image_data = base64.b64decode(response_poll.result) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -class StabilityUpscaleFastNode: +class StabilityUpscaleFastNode(comfy_io.ComfyNode): """ Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images. 
""" - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Stability AI" + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityUpscaleFastNode", + display_name="Stability AI Upscale Fast", + category="api node/image/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Image.Input("image"), + ], + outputs=[ + comfy_io.Image.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call(self, image: torch.Tensor, **kwargs): + async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput: image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() files = { "image": image_binary } + auth = { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + } + operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/stability/v2beta/stable-image/upscale/fast", @@ -580,7 +643,7 @@ class StabilityUpscaleFastNode: request=EmptyRequest(), files=files, content_type="multipart/form-data", - auth_kwargs=kwargs, + auth_kwargs=auth, ) response_api = await operation.execute() @@ -590,24 +653,323 @@ class StabilityUpscaleFastNode: image_data = base64.b64decode(response_api.image) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) - return (returned_image,) + return comfy_io.NodeOutput(returned_image) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "StabilityStableImageUltraNode": StabilityStableImageUltraNode, - "StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node, - "StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode, - "StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode, - "StabilityUpscaleFastNode": StabilityUpscaleFastNode, -} +class StabilityTextToAudio(comfy_io.ComfyNode): + """Generates high-quality music and sound effects from text descriptions.""" -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "StabilityStableImageUltraNode": "Stability AI Stable Image Ultra", - "StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image", - "StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative", - "StabilityUpscaleCreativeNode": "Stability AI Upscale Creative", - "StabilityUpscaleFastNode": "Stability AI Upscale Fast", -} + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityTextToAudio", + display_name="Stability AI Text To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + 
control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", + method=HttpMethod.POST, + request_model=StabilityTextToAudioRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + auth_kwargs= { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioToAudio(comfy_io.ComfyNode): + """Transforms existing audio samples into new high-quality compositions using text instructions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityAudioToAudio", + display_name="Stability AI Audio To Audio", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + comfy_io.Float.Input( + "strength", + default=1, + min=0.01, + max=1.0, + step=0.01, + display_mode=comfy_io.NumberDisplay.slider, + tooltip="Parameter controls how much influence the audio parameter has on the generated audio.", + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float + ) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + validate_audio_duration(audio, 6, 190) + payload = StabilityAudioToAudioRequest( + prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength + ) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", + 
method=HttpMethod.POST, + request_model=StabilityAudioToAudioRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + auth_kwargs= { + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityAudioInpaint(comfy_io.ComfyNode): + """Transforms part of existing audio sample using text instructions.""" + + @classmethod + def define_schema(cls): + return comfy_io.Schema( + node_id="StabilityAudioInpaint", + display_name="Stability AI Audio Inpaint", + category="api node/audio/Stability AI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + comfy_io.Combo.Input( + "model", + options=["stable-audio-2.5"], + ), + comfy_io.String.Input("prompt", multiline=True, default=""), + comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."), + comfy_io.Int.Input( + "duration", + default=190, + min=1, + max=190, + step=1, + tooltip="Controls the duration in seconds of the generated audio.", + optional=True, + ), + comfy_io.Int.Input( + "seed", + default=0, + min=0, + max=4294967294, + step=1, + display_mode=comfy_io.NumberDisplay.number, + control_after_generate=True, + tooltip="The random seed used for generation.", + optional=True, + ), + comfy_io.Int.Input( + "steps", + default=8, + min=4, + max=8, + step=1, + tooltip="Controls the number of sampling steps.", + optional=True, + ), + comfy_io.Int.Input( + "mask_start", + default=30, + min=0, + max=190, + step=1, + optional=True, + ), + comfy_io.Int.Input( + "mask_end", + default=190, + min=0, + max=190, + step=1, + optional=True, + ), + ], + outputs=[ + comfy_io.Audio.Output(), + ], + hidden=[ + comfy_io.Hidden.auth_token_comfy_org, + comfy_io.Hidden.api_key_comfy_org, + comfy_io.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + audio: Input.Audio, + duration: int, + seed: int, + steps: int, + mask_start: int, + mask_end: int, + ) -> comfy_io.NodeOutput: + validate_string(prompt, max_length=10000) + if mask_end <= mask_start: + raise ValueError(f"Value of mask_end({mask_end}) should be greater than mask_start({mask_start})") + validate_audio_duration(audio, 6, 190) + + payload = StabilityAudioInpaintRequest( + prompt=prompt, + model=model, + duration=duration, + seed=seed, + steps=steps, + mask_start=mask_start, + mask_end=mask_end, + ) + operation = SynchronousOperation( + endpoint=ApiEndpoint( + path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", + method=HttpMethod.POST, + request_model=StabilityAudioInpaintRequest, + response_model=StabilityAudioResponse, + ), + request=payload, + content_type="multipart/form-data", + files={"audio": audio_input_to_mp3(audio)}, + auth_kwargs={ + "auth_token": cls.hidden.auth_token_comfy_org, + "comfy_api_key": cls.hidden.api_key_comfy_org, + }, + ) + response_api = await operation.execute() + if not response_api.audio: + raise ValueError("No audio file was received in response.") + return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) + + +class StabilityExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]: + return [ +
StabilityStableImageUltraNode, + StabilityStableImageSD_3_5Node, + StabilityUpscaleConservativeNode, + StabilityUpscaleCreativeNode, + StabilityUpscaleFastNode, + StabilityTextToAudio, + StabilityAudioToAudio, + StabilityAudioInpaint, + ] + + +async def comfy_entrypoint() -> StabilityExtension: + return StabilityExtension() diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index 606b794bf..ca913e9b3 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -2,7 +2,7 @@ import logging from typing import Optional import torch -from comfy_api.input.video_types import VideoInput +from comfy_api.latest import Input def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: @@ -101,7 +101,7 @@ def validate_aspect_ratio_closeness( def validate_video_dimensions( - video: VideoInput, + video: Input.Video, min_width: Optional[int] = None, max_width: Optional[int] = None, min_height: Optional[int] = None, @@ -126,7 +126,7 @@ def validate_video_dimensions( def validate_video_duration( - video: VideoInput, + video: Input.Video, min_duration: Optional[float] = None, max_duration: Optional[float] = None, ): @@ -151,3 +151,17 @@ def get_number_of_images(images): if isinstance(images, torch.Tensor): return images.shape[0] if images.ndim >= 4 else 1 return len(images) + + +def validate_audio_duration( + audio: Input.Audio, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, +) -> None: + sr = int(audio["sample_rate"]) + dur = int(audio["waveform"].shape[-1]) / sr + eps = 1.0 / sr + if min_duration is not None and dur + eps < min_duration: + raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") + if max_duration is not None and dur - eps > max_duration: + raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") diff --git a/comfy_execution/progress.py b/comfy_execution/progress.py index e8f5ede1e..f951a3350 100644 --- a/comfy_execution/progress.py +++ b/comfy_execution/progress.py @@ -181,8 +181,9 @@ class WebUIProgressHandler(ProgressHandler): } # Send a combined progress_state message with all node states + # Include client_id to ensure message is only sent to the initiating client self.server_instance.send_sync( - "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes} + "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}, self.server_instance.client_id ) @override diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index 8d856d0e8..edd5dadd4 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -1,6 +1,10 @@ #from: https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/howto.html import numpy as np import torch +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + def loglinear_interp(t_steps, num_steps): """ @@ -19,25 +23,30 @@ NOISE_LEVELS = {"SD1": [14.6146412293, 6.4745760956, 3.8636745985, 2.694615152 "SDXL":[14.6146412293, 6.3184485287, 3.7681790315, 2.1811480769, 1.3405244945, 0.8620721141, 0.5550693289, 0.3798540708, 0.2332364134, 0.1114188177, 0.0291671582], "SVD": [700.00, 54.5, 15.886, 7.977, 4.248, 1.789, 0.981, 0.403, 0.173, 0.034, 0.002]} -class AlignYourStepsScheduler: +class AlignYourStepsScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model_type": (["SD1", "SDXL", "SVD"], ), - "steps": ("INT", {"default": 
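For reference, the validate_audio_duration helper above expects the AUDIO dict layout used by these nodes (a "waveform" tensor whose last dimension is samples, plus a "sample_rate") and allows a one-sample tolerance at each bound; a minimal sketch, assuming a 10-second 48 kHz stereo clip:

import torch
from comfy_api_nodes.util.validation_utils import validate_audio_duration

audio = {"waveform": torch.zeros(1, 2, 48000 * 10), "sample_rate": 48000}  # ~10 s of stereo silence
validate_audio_duration(audio, min_duration=6, max_duration=190)   # passes
# validate_audio_duration(audio, min_duration=30)                  # would raise ValueError (clip is ~10 s)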
10, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" - - FUNCTION = "get_sigmas" + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="AlignYourStepsScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), + io.Int.Input("steps", default=10, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()], + ) def get_sigmas(self, model_type, steps, denoise): + # Deprecated: use the V3 schema's `execute` method instead of this. + return AlignYourStepsScheduler().execute(model_type, steps, denoise).result + + @classmethod + def execute(cls, model_type, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = round(steps * denoise) sigmas = NOISE_LEVELS[model_type][:] @@ -46,8 +55,15 @@ class AlignYourStepsScheduler: sigmas = sigmas[-(total_steps + 1):] sigmas[-1] = 0 - return (torch.FloatTensor(sigmas), ) + return io.NodeOutput(torch.FloatTensor(sigmas)) -NODE_CLASS_MAPPINGS = { - "AlignYourStepsScheduler": AlignYourStepsScheduler, -} + +class AlignYourStepsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + AlignYourStepsScheduler, + ] + +async def comfy_entrypoint() -> AlignYourStepsExtension: + return AlignYourStepsExtension() diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 9d2988f5f..c170e9fd9 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -162,7 +162,12 @@ def easycache_sample_wrapper(executor, *args, **kwargs): logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}") logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}") total_steps = len(args[3])-1 - logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).") + # catch division by zero for log statement; sucks to crash after all sampling is done + try: + speedup = total_steps/(total_steps-easycache.total_steps_skipped) + except ZeroDivisionError: + speedup = 1.0 + logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).") easycache.reset() guider.model_options = orig_model_options diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index c8db75bb3..25e029ffd 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -105,7 +105,7 @@ class FluxKontextMultiReferenceLatentMethod: def INPUT_TYPES(s): return {"required": { "conditioning": ("CONDITIONING", ), - "reference_latents_method": (("offset", "index"), ), + "reference_latents_method": (("offset", "index", "uxo/uno"), ), }} RETURN_TYPES = ("CONDITIONING",) @@ -115,6 +115,8 @@ class FluxKontextMultiReferenceLatentMethod: CATEGORY = "advanced/conditioning/flux" def append(self, conditioning, reference_latents_method): + if "uxo" in reference_latents_method or "uso" in reference_latents_method: + reference_latents_method = "uxo" c = node_helpers.conditioning_set_values(conditioning, 
{"reference_latents_method": reference_latents_method}) return (c, ) diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index d7278e7a7..ce031ceb2 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -113,6 +113,20 @@ class HunyuanImageToVideo: out_latent["samples"] = latent return (positive, out_latent) +class EmptyHunyuanImageLatent: + @classmethod + def INPUT_TYPES(s): + return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} + RETURN_TYPES = ("LATENT",) + FUNCTION = "generate" + + CATEGORY = "latent" + + def generate(self, width, height, batch_size=1): + latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device()) + return ({"samples":latent}, ) NODE_CLASS_MAPPINGS = { @@ -120,4 +134,5 @@ NODE_CLASS_MAPPINGS = { "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo, "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo, "HunyuanImageToVideo": HunyuanImageToVideo, + "EmptyHunyuanImageLatent": EmptyHunyuanImageLatent, } diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 51e45336a..f6e71e0a8 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -8,13 +8,16 @@ import folder_paths import comfy.model_management from comfy.cli_args import args - class EmptyLatentHunyuan3Dv2: @classmethod def INPUT_TYPES(s): - return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} + return { + "required": { + "resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), + } + } + RETURN_TYPES = ("LATENT",) FUNCTION = "generate" @@ -24,7 +27,6 @@ class EmptyLatentHunyuan3Dv2: latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) return ({"samples": latent, "type": "hunyuan3dv2"}, ) - class Hunyuan3Dv2Conditioning: @classmethod def INPUT_TYPES(s): @@ -81,7 +83,6 @@ class VOXEL: def __init__(self, data): self.data = data - class VAEDecodeHunyuan3D: @classmethod def INPUT_TYPES(s): @@ -99,7 +100,6 @@ class VAEDecodeHunyuan3D: voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) return (voxels, ) - def voxel_to_mesh(voxels, threshold=0.5, device=None): if device is None: device = torch.device("cpu") @@ -230,13 +230,9 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1] ], device=device) - corner_values = torch.zeros((cell_positions.shape[0], 8), device=device) - for c, (dz, dy, dx) in enumerate(corner_offsets): - corner_values[:, c] = padded[ - cell_positions[:, 0] + dz, - cell_positions[:, 1] + dy, - cell_positions[:, 2] + dx - ] + pos = cell_positions.unsqueeze(1) + corner_offsets.unsqueeze(0) + z_idx, y_idx, x_idx = pos.unbind(-1) + corner_values = padded[z_idx, y_idx, x_idx] corner_signs = corner_values > threshold has_inside = torch.any(corner_signs, dim=1) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index fba80e2ae..392aea32c 
100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -625,6 +625,37 @@ class ImageFlip: return (image,) +class ImageScaleToMaxDimension: + upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bicubic"] + + @classmethod + def INPUT_TYPES(s): + return {"required": {"image": ("IMAGE",), + "upscale_method": (s.upscale_methods,), + "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}} + RETURN_TYPES = ("IMAGE",) + FUNCTION = "upscale" + + CATEGORY = "image/upscaling" + + def upscale(self, image, upscale_method, largest_size): + height = image.shape[1] + width = image.shape[2] + + if height > width: + width = round((width / height) * largest_size) + height = largest_size + elif width > height: + height = round((height / width) * largest_size) + width = largest_size + else: + height = largest_size + width = largest_size + + samples = image.movedim(-1, 1) + s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled") + s = s.movedim(1, -1) + return (s,) NODE_CLASS_MAPPINGS = { "ImageCrop": ImageCrop, @@ -639,4 +670,5 @@ NODE_CLASS_MAPPINGS = { "GetImageSize": GetImageSize, "ImageRotate": ImageRotate, "ImageFlip": ImageFlip, + "ImageScaleToMaxDimension": ImageScaleToMaxDimension, } diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 65e766b52..783c59b6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -1,4 +1,5 @@ import torch +from torch import nn import folder_paths import comfy.utils import comfy.ops @@ -58,6 +59,136 @@ class QwenImageBlockWiseControlNet(torch.nn.Module): return self.controlnet_blocks[block_id](img, controlnet_conditioning) +class SigLIPMultiFeatProjModel(torch.nn.Module): + """ + SigLIP Multi-Feature Projection Model for processing style features from different layers + and projecting them into a unified hidden space.
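Worked example of the sizing rule in ImageScaleToMaxDimension above, assuming a 1024x768 input and the default largest_size of 512 (the longer side is scaled to largest_size and the shorter side follows the aspect ratio):

height, width, largest_size = 768, 1024, 512  # IMAGE tensors are [B, H, W, C]
if height > width:
    width, height = round((width / height) * largest_size), largest_size
elif width > height:
    height, width = round((height / width) * largest_size), largest_size
else:
    height = width = largest_size
assert (width, height) == (512, 384)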
+ + Args: + siglip_token_nums (int): Number of SigLIP tokens, default 729 + style_token_nums (int): Number of style tokens, default 64 + siglip_token_dims (int): Dimension of SigLIP tokens, default 1152 + hidden_size (int): Hidden layer size, default 3072 + context_layer_norm (bool): Whether to use context layer normalization, default True + """ + + def __init__( + self, + siglip_token_nums: int = 729, + style_token_nums: int = 64, + siglip_token_dims: int = 1152, + hidden_size: int = 3072, + context_layer_norm: bool = True, + device=None, dtype=None, operations=None + ): + super().__init__() + + # High-level feature processing (layer -2) + self.high_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.high_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.high_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + # Mid-level feature processing (layer -11) + self.mid_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.mid_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.mid_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + # Low-level feature processing (layer -20) + self.low_embedding_linear = nn.Sequential( + operations.Linear(siglip_token_nums, style_token_nums), + nn.SiLU() + ) + self.low_layer_norm = ( + operations.LayerNorm(siglip_token_dims) if context_layer_norm else nn.Identity() + ) + self.low_projection = operations.Linear(siglip_token_dims, hidden_size, bias=True) + + def forward(self, siglip_outputs): + """ + Forward pass function + + Args: + siglip_outputs: Output from SigLIP model, containing hidden_states + + Returns: + torch.Tensor: Concatenated multi-layer features with shape [bs, 3*style_token_nums, hidden_size] + """ + dtype = next(self.high_embedding_linear.parameters()).dtype + + # Process high-level features (layer -2) + high_embedding = self._process_layer_features( + siglip_outputs[2], + self.high_embedding_linear, + self.high_layer_norm, + self.high_projection, + dtype + ) + + # Process mid-level features (layer -11) + mid_embedding = self._process_layer_features( + siglip_outputs[1], + self.mid_embedding_linear, + self.mid_layer_norm, + self.mid_projection, + dtype + ) + + # Process low-level features (layer -20) + low_embedding = self._process_layer_features( + siglip_outputs[0], + self.low_embedding_linear, + self.low_layer_norm, + self.low_projection, + dtype + ) + + # Concatenate features from all layers + return torch.cat((high_embedding, mid_embedding, low_embedding), dim=1) + + def _process_layer_features( + self, + hidden_states: torch.Tensor, + embedding_linear: nn.Module, + layer_norm: nn.Module, + projection: nn.Module, + dtype: torch.dtype + ) -> torch.Tensor: + """ + Helper function to process features from a single layer + + Args: + hidden_states: Input hidden states [bs, seq_len, dim] + embedding_linear: Embedding linear layer + layer_norm: Layer normalization + projection: Projection layer + dtype: Target data type + + Returns: + torch.Tensor: Processed features [bs, style_token_nums, hidden_size] + """ + # Transform dimensions: [bs, seq_len, dim] -> [bs, dim, seq_len] -> [bs, dim, style_token_nums] -> [bs, style_token_nums, dim] + embedding = embedding_linear( + hidden_states.to(dtype).transpose(1, 2) + ).transpose(1, 2) + + # Apply layer
normalization + embedding = layer_norm(embedding) + + # Project to target hidden space + embedding = projection(embedding) + + return embedding + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -73,9 +204,14 @@ class ModelPatchLoader: model_patch_path = folder_paths.get_full_path_or_raise("model_patches", name) sd = comfy.utils.load_torch_file(model_patch_path, safe_load=True) dtype = comfy.utils.weight_dtype(sd) - # TODO: this node will work with more types of model patches - additional_in_dim = sd["img_in.weight"].shape[1] - 64 - model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + + if 'controlnet_blocks.0.y_rms.weight' in sd: + additional_in_dim = sd["img_in.weight"].shape[1] - 64 + model = QwenImageBlockWiseControlNet(additional_in_dim=additional_in_dim, device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'feature_embedder.mid_layer_norm.bias' in sd: + sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) + model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return (model,) @@ -157,7 +293,51 @@ class QwenImageDiffsynthControlnet: return (model_patched,) +class UsoStyleProjectorPatch: + def __init__(self, model_patch, encoded_image): + self.model_patch = model_patch + self.encoded_image = encoded_image + + def __call__(self, kwargs): + txt_ids = kwargs.get("txt_ids") + txt = kwargs.get("txt") + siglip_embedding = self.model_patch.model(self.encoded_image.to(txt.dtype)).to(txt.dtype) + txt = torch.cat([siglip_embedding, txt], dim=1) + kwargs['txt'] = txt + kwargs['txt_ids'] = torch.cat([torch.zeros(siglip_embedding.shape[0], siglip_embedding.shape[1], 3, dtype=txt_ids.dtype, device=txt_ids.device), txt_ids], dim=1) + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + +class USOStyleReference: + @classmethod + def INPUT_TYPES(s): + return {"required": {"model": ("MODEL",), + "model_patch": ("MODEL_PATCH",), + "clip_vision_output": ("CLIP_VISION_OUTPUT", ), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "apply_patch" + EXPERIMENTAL = True + + CATEGORY = "advanced/model_patches/flux" + + def apply_patch(self, model, model_patch, clip_vision_output): + encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) + model_patched = model.clone() + model_patched.set_model_post_input_patch(UsoStyleProjectorPatch(model_patch, encoded_image)) + return (model_patched,) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, + "USOStyleReference": USOStyleReference, } diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 1f93f87a7..5a1aeba80 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -1,98 +1,109 @@ -# Primitive nodes that are evaluated at backend. 
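Shape-wise, each branch of SigLIPMultiFeatProjModel above maps [batch, 729, 1152] SigLIP hidden states to [batch, 64, 3072] style tokens, and the three branches are concatenated to [batch, 192, 3072]. A sketch of one branch using plain torch.nn layers in place of the `operations` wrappers (that substitution is an assumption made purely for illustration):

import torch
from torch import nn

bs, tokens, dim, style_tokens, hidden = 2, 729, 1152, 64, 3072
hidden_states = torch.randn(bs, tokens, dim)
embedding_linear = nn.Sequential(nn.Linear(tokens, style_tokens), nn.SiLU())
layer_norm = nn.LayerNorm(dim)
projection = nn.Linear(dim, hidden)
# [bs, tokens, dim] -> [bs, dim, tokens] -> [bs, dim, style_tokens] -> [bs, style_tokens, dim]
embedding = embedding_linear(hidden_states.transpose(1, 2)).transpose(1, 2)
embedding = projection(layer_norm(embedding))
assert embedding.shape == (bs, style_tokens, hidden)  # three such branches concatenate to (bs, 192, hidden)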
-from __future__ import annotations - import sys +from typing_extensions import override -from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, IO +from comfy_api.latest import ComfyExtension, io -class String(ComfyNodeABC): +class String(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.STRING, {})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveString", + display_name="String", + category="utils/primitive", + inputs=[ + io.String.Input("value"), + ], + outputs=[io.String.Output()], + ) - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: str) -> tuple[str]: - return (value,) - - -class StringMultiline(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.STRING, {"multiline": True,},)}, - } - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: str) -> tuple[str]: - return (value,) + def execute(cls, value: str) -> io.NodeOutput: + return io.NodeOutput(value) -class Int(ComfyNodeABC): +class StringMultiline(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.INT, {"min": -sys.maxsize, "max": sys.maxsize, "control_after_generate": True})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveStringMultiline", + display_name="String (Multiline)", + category="utils/primitive", + inputs=[ + io.String.Input("value", multiline=True), + ], + outputs=[io.String.Output()], + ) - RETURN_TYPES = (IO.INT,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: int) -> tuple[int]: - return (value,) - - -class Float(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.FLOAT, {"min": -sys.maxsize, "max": sys.maxsize})}, - } - - RETURN_TYPES = (IO.FLOAT,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: float) -> tuple[float]: - return (value,) + def execute(cls, value: str) -> io.NodeOutput: + return io.NodeOutput(value) -class Boolean(ComfyNodeABC): +class Int(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": {"value": (IO.BOOLEAN, {})}, - } + def define_schema(cls): + return io.Schema( + node_id="PrimitiveInt", + display_name="Int", + category="utils/primitive", + inputs=[ + io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), + ], + outputs=[io.Int.Output()], + ) - RETURN_TYPES = (IO.BOOLEAN,) - FUNCTION = "execute" - CATEGORY = "utils/primitive" - - def execute(self, value: bool) -> tuple[bool]: - return (value,) + @classmethod + def execute(cls, value: int) -> io.NodeOutput: + return io.NodeOutput(value) -NODE_CLASS_MAPPINGS = { - "PrimitiveString": String, - "PrimitiveStringMultiline": StringMultiline, - "PrimitiveInt": Int, - "PrimitiveFloat": Float, - "PrimitiveBoolean": Boolean, -} +class Float(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PrimitiveFloat", + display_name="Float", + category="utils/primitive", + inputs=[ + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + ], + outputs=[io.Float.Output()], + ) -NODE_DISPLAY_NAME_MAPPINGS = { - "PrimitiveString": "String", - "PrimitiveStringMultiline": "String (Multiline)", - "PrimitiveInt": "Int", - "PrimitiveFloat": "Float", - "PrimitiveBoolean": "Boolean", -} + 
@classmethod + def execute(cls, value: float) -> io.NodeOutput: + return io.NodeOutput(value) + + +class Boolean(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="PrimitiveBoolean", + display_name="Boolean", + category="utils/primitive", + inputs=[ + io.Boolean.Input("value"), + ], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, value: bool) -> io.NodeOutput: + return io.NodeOutput(value) + + +class PrimitivesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + String, + StringMultiline, + Int, + Float, + Boolean, + ] + +async def comfy_entrypoint() -> PrimitivesExtension: + return PrimitivesExtension() diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 003403215..04c0b366a 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -17,55 +17,61 @@ """ import torch -import nodes +from typing_extensions import override + import comfy.utils +import nodes +from comfy_api.latest import ComfyExtension, io -class StableCascade_EmptyLatentImage: - def __init__(self, device="cpu"): - self.device = device +class StableCascade_EmptyLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_EmptyLatentImage", + category="latent/stable_cascade", + inputs=[ + io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}) - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = "latent/stable_cascade" - - def generate(self, width, height, compression, batch_size=1): + def execute(cls, width, height, compression, batch_size=1): c_latent = torch.zeros([batch_size, 16, height // compression, width // compression]) b_latent = torch.zeros([batch_size, 4, height // 4, width // 4]) - return ({ + return io.NodeOutput({ "samples": c_latent, }, { "samples": b_latent, }) -class StableCascade_StageC_VAEEncode: - def __init__(self, device="cpu"): - self.device = device + +class StableCascade_StageC_VAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_StageC_VAEEncode", + category="latent/stable_cascade", + inputs=[ + io.Image.Input("image"), + io.Vae.Input("vae"), + io.Int.Input("compression", default=42, min=4, max=128, step=1), + ], + outputs=[ + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - "compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}), - }} - RETURN_TYPES = ("LATENT", "LATENT") - RETURN_NAMES = ("stage_c", "stage_b") - FUNCTION = "generate" - - CATEGORY = 
"latent/stable_cascade" - - def generate(self, image, vae, compression): + def execute(cls, image, vae, compression): width = image.shape[-2] height = image.shape[-3] out_width = (width // compression) * vae.downscale_ratio @@ -75,51 +81,59 @@ class StableCascade_StageC_VAEEncode: c_latent = vae.encode(s[:,:,:,:3]) b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2]) - return ({ + return io.NodeOutput({ "samples": c_latent, }, { "samples": b_latent, }) -class StableCascade_StageB_Conditioning: + +class StableCascade_StageB_Conditioning(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "conditioning": ("CONDITIONING",), - "stage_c": ("LATENT",), - }} - RETURN_TYPES = ("CONDITIONING",) + def define_schema(cls): + return io.Schema( + node_id="StableCascade_StageB_Conditioning", + category="conditioning/stable_cascade", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("stage_c"), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) - FUNCTION = "set_prior" - - CATEGORY = "conditioning/stable_cascade" - - def set_prior(self, conditioning, stage_c): + @classmethod + def execute(cls, conditioning, stage_c): c = [] for t in conditioning: d = t[1].copy() - d['stable_cascade_prior'] = stage_c['samples'] + d["stable_cascade_prior"] = stage_c["samples"] n = [t[0], d] c.append(n) - return (c, ) + return io.NodeOutput(c) -class StableCascade_SuperResolutionControlnet: - def __init__(self, device="cpu"): - self.device = device + +class StableCascade_SuperResolutionControlnet(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="StableCascade_SuperResolutionControlnet", + category="_for_testing/stable_cascade", + is_experimental=True, + inputs=[ + io.Image.Input("image"), + io.Vae.Input("vae"), + ], + outputs=[ + io.Image.Output(display_name="controlnet_input"), + io.Latent.Output(display_name="stage_c"), + io.Latent.Output(display_name="stage_b"), + ], + ) @classmethod - def INPUT_TYPES(s): - return {"required": { - "image": ("IMAGE",), - "vae": ("VAE", ), - }} - RETURN_TYPES = ("IMAGE", "LATENT", "LATENT") - RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b") - FUNCTION = "generate" - - EXPERIMENTAL = True - CATEGORY = "_for_testing/stable_cascade" - - def generate(self, image, vae): + def execute(cls, image, vae): width = image.shape[-2] height = image.shape[-3] batch_size = image.shape[0] @@ -127,15 +141,22 @@ class StableCascade_SuperResolutionControlnet: c_latent = torch.zeros([batch_size, 16, height // 16, width // 16]) b_latent = torch.zeros([batch_size, 4, height // 2, width // 2]) - return (controlnet_input, { + return io.NodeOutput(controlnet_input, { "samples": c_latent, }, { "samples": b_latent, }) -NODE_CLASS_MAPPINGS = { - "StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage, - "StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning, - "StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode, - "StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet, -} + +class StableCascadeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + StableCascade_EmptyLatentImage, + StableCascade_StageB_Conditioning, + StableCascade_StageC_VAEEncode, + StableCascade_SuperResolutionControlnet, + ] + +async def comfy_entrypoint() -> StableCascadeExtension: + return StableCascadeExtension() diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 
969f888b9..69fabb12e 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -5,52 +5,49 @@ import av import torch import folder_paths import json -from typing import Optional, Literal +from typing import Optional +from typing_extensions import override from fractions import Fraction -from comfy.comfy_types import IO, FileLocator, ComfyNodeABC -from comfy_api.latest import Input, InputImpl, Types +from comfy_api.input import AudioInput, ImageInput, VideoInput +from comfy_api.input_impl import VideoFromComponents, VideoFromFile +from comfy_api.util import VideoCodec, VideoComponents, VideoContainer +from comfy_api.latest import ComfyExtension, io, ui from comfy.cli_args import args -class SaveWEBM: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveWEBM(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveWEBM", + category="image/video", + is_experimental=True, + inputs=[ + io.Image.Input("images"), + io.String.Input("filename_prefix", default="ComfyUI"), + io.Combo.Input("codec", options=["vp9", "av1"]), + io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), + io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "codec": (["vp9", "av1"],), - "fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/video" - - EXPERIMENTAL = True - - def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) + def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0] + ) file = f"{filename}_{counter:05}_.webm" container = av.open(os.path.join(full_output_folder, file), mode="w") - if prompt is not None: - container.metadata["prompt"] = json.dumps(prompt) + if cls.hidden.prompt is not None: + container.metadata["prompt"] = json.dumps(cls.hidden.prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - container.metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) @@ -69,63 +66,46 @@ class SaveWEBM: container.mux(stream.encode()) container.close() - results: 
list[FileLocator] = [{ - "filename": file, - "subfolder": subfolder, - "type": self.type - }] + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) - return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side - -class SaveVideo(ComfyNodeABC): - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type: Literal["output"] = "output" - self.prefix_append = "" +class SaveVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveVideo", + display_name="Save Video", + category="image/video", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + io.Video.Input("video", tooltip="The video to save."), + io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), + io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to save."}), - "filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}), - "format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}), - "codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}), - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - }, - } - - RETURN_TYPES = () - FUNCTION = "save_video" - - OUTPUT_NODE = True - - CATEGORY = "image/video" - DESCRIPTION = "Saves the input images to your ComfyUI output directory." 
- - def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append + def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, - self.output_dir, + folder_paths.get_output_directory(), width, height ) - results: list[FileLocator] = list() saved_metadata = None if not args.disable_metadata: metadata = {} - if extra_pnginfo is not None: - metadata.update(extra_pnginfo) - if prompt is not None: - metadata["prompt"] = prompt + if cls.hidden.extra_pnginfo is not None: + metadata.update(cls.hidden.extra_pnginfo) + if cls.hidden.prompt is not None: + metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), format=format, @@ -133,83 +113,82 @@ class SaveVideo(ComfyNodeABC): metadata=saved_metadata ) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) - return { "ui": { "images": results, "animated": (True,) } } -class CreateVideo(ComfyNodeABC): +class CreateVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "images": (IO.IMAGE, {"tooltip": "The images to create a video from."}), - "fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}), - }, - "optional": { - "audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CreateVideo", + display_name="Create Video", + category="image/video", + description="Create a video from images.", + inputs=[ + io.Image.Input("images", tooltip="The images to create a video from."), + io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), + io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + ], + outputs=[ + io.Video.Output(), + ], + ) - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "create_video" - - CATEGORY = "image/video" - DESCRIPTION = "Create a video from images." - - def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None): - return (InputImpl.VideoFromComponents( - Types.VideoComponents( - images=images, - audio=audio, - frame_rate=Fraction(fps), - ) - ),) - -class GetVideoComponents(ComfyNodeABC): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "video": (IO.VIDEO, {"tooltip": "The video to extract components from."}), - } - } - RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT) - RETURN_NAMES = ("images", "audio", "fps") - FUNCTION = "get_components" + def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: + return io.NodeOutput( + VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + ) - CATEGORY = "image/video" - DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate." 
+class GetVideoComponents(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="GetVideoComponents", + display_name="Get Video Components", + category="image/video", + description="Extracts all components from a video: frames, audio, and framerate.", + inputs=[ + io.Video.Input("video", tooltip="The video to extract components from."), + ], + outputs=[ + io.Image.Output(display_name="images"), + io.Audio.Output(display_name="audio"), + io.Float.Output(display_name="fps"), + ], + ) - def get_components(self, video: Input.Video): + @classmethod + def execute(cls, video: VideoInput) -> io.NodeOutput: components = video.get_components() - return (components.images, components.audio, float(components.frame_rate)) + return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) -class LoadVideo(ComfyNodeABC): +class LoadVideo(io.ComfyNode): @classmethod - def INPUT_TYPES(cls): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] files = folder_paths.filter_files_content_types(files, ["video"]) - return {"required": - {"file": (sorted(files), {"video_upload": True})}, - } - - CATEGORY = "image/video" - - RETURN_TYPES = (IO.VIDEO,) - FUNCTION = "load_video" - def load_video(self, file): - video_path = folder_paths.get_annotated_filepath(file) - return (InputImpl.VideoFromFile(video_path),) + return io.Schema( + node_id="LoadVideo", + display_name="Load Video", + category="image/video", + inputs=[ + io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), + ], + outputs=[ + io.Video.Output(), + ], + ) @classmethod - def IS_CHANGED(cls, file): + def execute(cls, file) -> io.NodeOutput: + video_path = folder_paths.get_annotated_filepath(file) + return io.NodeOutput(VideoFromFile(video_path)) + + @classmethod + def fingerprint_inputs(s, file): video_path = folder_paths.get_annotated_filepath(file) mod_time = os.path.getmtime(video_path) # Instead of hashing the file, we can just use the modification time to avoid @@ -217,24 +196,23 @@ class LoadVideo(ComfyNodeABC): return mod_time @classmethod - def VALIDATE_INPUTS(cls, file): + def validate_inputs(s, file): if not folder_paths.exists_annotated_filepath(file): return "Invalid video file: {}".format(file) return True -NODE_CLASS_MAPPINGS = { - "SaveWEBM": SaveWEBM, - "SaveVideo": SaveVideo, - "CreateVideo": CreateVideo, - "GetVideoComponents": GetVideoComponents, - "LoadVideo": LoadVideo, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SaveVideo": "Save Video", - "CreateVideo": "Create Video", - "GetVideoComponents": "Get Video Components", - "LoadVideo": "Load Video", -} +class VideoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SaveWEBM, + SaveVideo, + CreateVideo, + GetVideoComponents, + LoadVideo, + ] +async def comfy_entrypoint() -> VideoExtension: + return VideoExtension() diff --git a/comfyui_version.py b/comfyui_version.py index 36777e285..ee58205f5 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
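To illustrate the CreateVideo / GetVideoComponents pair converted above: both go through VideoComponents, so a round trip looks roughly like the sketch below (the 8-frame random tensor, the [frames, H, W, C] layout and the 24 fps value are assumptions for illustration):

import torch
from fractions import Fraction
from comfy_api.input_impl import VideoFromComponents
from comfy_api.util import VideoComponents

frames = torch.rand(8, 64, 64, 3)  # assumed layout: [frames, height, width, channels], values in 0..1
video = VideoFromComponents(VideoComponents(images=frames, audio=None, frame_rate=Fraction(24)))
components = video.get_components()
assert float(components.frame_rate) == 24.0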
-__version__ = "0.3.55" +__version__ = "0.3.59" diff --git a/main.py b/main.py index b23d50816..c33f0e17b 100644 --- a/main.py +++ b/main.py @@ -113,7 +113,6 @@ import gc if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' - logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) if __name__ == "__main__": if args.default_device is not None: diff --git a/middleware/__init__.py b/middleware/__init__.py new file mode 100644 index 000000000..2d7c7c3a9 --- /dev/null +++ b/middleware/__init__.py @@ -0,0 +1 @@ +"""Server middleware modules""" diff --git a/middleware/cache_middleware.py b/middleware/cache_middleware.py new file mode 100644 index 000000000..374ef7934 --- /dev/null +++ b/middleware/cache_middleware.py @@ -0,0 +1,52 @@ +"""Cache control middleware for ComfyUI server""" + +from aiohttp import web +from typing import Callable, Awaitable + +# Time in seconds +ONE_HOUR: int = 3600 +ONE_DAY: int = 86400 +IMG_EXTENSIONS = ( + ".jpg", + ".jpeg", + ".png", + ".ppm", + ".bmp", + ".pgm", + ".tif", + ".tiff", + ".webp", +) + + +@web.middleware +async def cache_control( + request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]] +) -> web.Response: + """Cache control middleware that sets appropriate cache headers based on file type and response status""" + response: web.Response = await handler(request) + + if ( + request.path.endswith(".js") + or request.path.endswith(".css") + or request.path.endswith("index.json") + ): + response.headers.setdefault("Cache-Control", "no-cache") + return response + + # Early return for non-image files - no cache headers needed + if not request.path.lower().endswith(IMG_EXTENSIONS): + return response + + # Handle image files + if response.status == 404: + response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}") + elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308): + # Success responses and permanent redirects - cache for 1 day + response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}") + elif response.status in (302, 303, 307): + # Temporary redirects - no cache + response.headers.setdefault("Cache-Control", "no-cache") + # Note: 304 Not Modified falls through - no cache headers set + + return response diff --git a/nodes.py b/nodes.py index 7205d1e94..0c6a219d0 100644 --- a/nodes.py +++ b/nodes.py @@ -925,7 +925,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -953,7 +953,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -963,7 +963,7 @@ class DualCLIPLoader: CATEGORY = "advanced/loaders" - 
DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama" + DESCRIPTION = "[Recipes]\n\nsdxl: clip-l, clip-g\nsd3: clip-l, clip-g / clip-l, t5 / clip-g, t5\nflux: clip-l, t5\nhidream: at least one of t5 or llama, recommended t5 and llama\nhunyuan_image: qwen2.5vl 7b and byt5 small" def load_clip(self, clip_name1, clip_name2, type, device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -2345,6 +2345,7 @@ async def init_builtin_api_nodes(): "nodes_veo2.py", "nodes_kling.py", "nodes_bfl.py", + "nodes_bytedance.py", "nodes_luma.py", "nodes_recraft.py", "nodes_pixverse.py", diff --git a/pyproject.toml b/pyproject.toml index 04514b4a8..a7fc1a5a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.3.55" +version = "0.3.59" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.9" diff --git a/requirements.txt b/requirements.txt index 7f64aacca..0e21967ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.25.11 -comfyui-workflow-templates==0.1.70 +comfyui-workflow-templates==0.1.81 comfyui-embedded-docs==0.2.6 torch torchsde diff --git a/script_examples/basic_api_example.py b/script_examples/basic_api_example.py index 9128420c4..7e20cc2c1 100644 --- a/script_examples/basic_api_example.py +++ b/script_examples/basic_api_example.py @@ -3,11 +3,7 @@ from urllib import request #This is the ComfyUI api prompt format. -#If you want it for a specific workflow you can "enable dev mode options" -#in the settings of the UI (gear beside the "Queue Size: ") this will enable -#a button on the UI to save workflows in api format. - -#keep in mind ComfyUI is pre alpha software so this format will change a bit. +#If you want it for a specific workflow you can "File -> Export (API)" in the interface. 
#this is the one for the default workflow prompt_text = """ diff --git a/server.py b/server.py index 8f9c88ebf..43816a8cd 100644 --- a/server.py +++ b/server.py @@ -39,20 +39,15 @@ from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes from protocol import BinaryEventTypes +# Import cache control middleware +from middleware.cache_middleware import cache_control + async def send_socket_catch_exception(function, message): try: await function(message) except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err: logging.warning("send error: {}".format(err)) -@web.middleware -async def cache_control(request: web.Request, handler): - response: web.Response = await handler(request) - if request.path.endswith('.js') or request.path.endswith('.css') or request.path.endswith('index.json'): - response.headers.setdefault('Cache-Control', 'no-cache') - return response - - @web.middleware async def compress_body(request: web.Request, handler): accept_encoding = request.headers.get("Accept-Encoding", "") @@ -729,7 +724,34 @@ class PromptServer(): @routes.post("/interrupt") async def post_interrupt(request): - nodes.interrupt_processing() + try: + json_data = await request.json() + except json.JSONDecodeError: + json_data = {} + + # Check if a specific prompt_id was provided for targeted interruption + prompt_id = json_data.get('prompt_id') + if prompt_id: + currently_running, _ = self.prompt_queue.get_current_queue() + + # Check if the prompt_id matches any currently running prompt + should_interrupt = False + for item in currently_running: + # item structure: (number, prompt_id, prompt, extra_data, outputs_to_execute) + if item[1] == prompt_id: + logging.info(f"Interrupting prompt {prompt_id}") + should_interrupt = True + break + + if should_interrupt: + nodes.interrupt_processing() + else: + logging.info(f"Prompt {prompt_id} is not currently running, skipping interrupt") + else: + # No prompt_id provided, do a global interrupt + logging.info("Global interrupt (no prompt_id specified)") + nodes.interrupt_processing() + return web.Response(status=200) @routes.post("/free") diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py new file mode 100644 index 000000000..8de59125a --- /dev/null +++ b/tests-unit/server_test/test_cache_control.py @@ -0,0 +1,255 @@ +"""Tests for server cache control middleware""" + +import pytest +from aiohttp import web +from aiohttp.test_utils import make_mocked_request +from typing import Dict, Any + +from middleware.cache_middleware import cache_control, ONE_HOUR, ONE_DAY, IMG_EXTENSIONS + +pytestmark = pytest.mark.asyncio # Apply asyncio mark to all tests + +# Test configuration data +CACHE_SCENARIOS = [ + # Image file scenarios + { + "name": "image_200_status", + "path": "/test.jpg", + "status": 200, + "expected_cache": f"public, max-age={ONE_DAY}", + "should_have_header": True, + }, + { + "name": "image_404_status", + "path": "/missing.jpg", + "status": 404, + "expected_cache": f"public, max-age={ONE_HOUR}", + "should_have_header": True, + }, + # JavaScript/CSS scenarios + { + "name": "js_no_cache", + "path": "/script.js", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "css_no_cache", + "path": "/styles.css", + "status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + { + "name": "index_json_no_cache", + "path": "/api/index.json", + 
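With the /interrupt change above, a client can target a single running prompt by posting its prompt_id, while an empty body keeps the old global interrupt. A hypothetical call in the same urllib style as script_examples/basic_api_example.py (the address, port and prompt_id value are assumptions):

import json
from urllib import request

def interrupt(prompt_id=None, server="http://127.0.0.1:8188"):
    data = json.dumps({"prompt_id": prompt_id} if prompt_id else {}).encode("utf-8")
    req = request.Request(server + "/interrupt", data=data, headers={"Content-Type": "application/json"})
    request.urlopen(req)

interrupt("example-prompt-id")  # only interrupts if that prompt is currently running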
"status": 200, + "expected_cache": "no-cache", + "should_have_header": True, + }, + # Non-matching files + { + "name": "html_no_header", + "path": "/index.html", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "txt_no_header", + "path": "/data.txt", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "api_endpoint_no_header", + "path": "/api/endpoint", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, + { + "name": "pdf_no_header", + "path": "/file.pdf", + "status": 200, + "expected_cache": None, + "should_have_header": False, + }, +] + +# Status code scenarios for images +IMAGE_STATUS_SCENARIOS = [ + # Success statuses get long cache + {"status": 200, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 201, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 202, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 204, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 206, "expected": f"public, max-age={ONE_DAY}"}, + # Permanent redirects get long cache + {"status": 301, "expected": f"public, max-age={ONE_DAY}"}, + {"status": 308, "expected": f"public, max-age={ONE_DAY}"}, + # Temporary redirects get no cache + {"status": 302, "expected": "no-cache"}, + {"status": 303, "expected": "no-cache"}, + {"status": 307, "expected": "no-cache"}, + # 404 gets short cache + {"status": 404, "expected": f"public, max-age={ONE_HOUR}"}, +] + +# Case sensitivity test paths +CASE_SENSITIVITY_PATHS = ["/image.JPG", "/photo.PNG", "/pic.JpEg"] + +# Edge case test paths +EDGE_CASE_PATHS = [ + { + "name": "query_strings_ignored", + "path": "/image.jpg?v=123&size=large", + "expected": f"public, max-age={ONE_DAY}", + }, + { + "name": "multiple_dots_in_path", + "path": "/image.min.jpg", + "expected": f"public, max-age={ONE_DAY}", + }, + { + "name": "nested_paths_with_images", + "path": "/static/images/photo.jpg", + "expected": f"public, max-age={ONE_DAY}", + }, +] + + +class TestCacheControl: + """Test cache control middleware functionality""" + + @pytest.fixture + def status_handler_factory(self): + """Create a factory for handlers that return specific status codes""" + + def factory(status: int, headers: Dict[str, str] = None): + async def handler(request): + return web.Response(status=status, headers=headers or {}) + + return handler + + return factory + + @pytest.fixture + def mock_handler(self, status_handler_factory): + """Create a mock handler that returns a response with 200 status""" + return status_handler_factory(200) + + @pytest.fixture + def handler_with_existing_cache(self, status_handler_factory): + """Create a handler that returns response with existing Cache-Control header""" + return status_handler_factory(200, {"Cache-Control": "max-age=3600"}) + + async def assert_cache_header( + self, + response: web.Response, + expected_cache: str = None, + should_have_header: bool = True, + ): + """Helper to assert cache control headers""" + if should_have_header: + assert "Cache-Control" in response.headers + if expected_cache: + assert response.headers["Cache-Control"] == expected_cache + else: + assert "Cache-Control" not in response.headers + + # Parameterized tests + @pytest.mark.parametrize("scenario", CACHE_SCENARIOS, ids=lambda x: x["name"]) + async def test_cache_control_scenarios( + self, scenario: Dict[str, Any], status_handler_factory + ): + """Test various cache control scenarios""" + handler = status_handler_factory(scenario["status"]) + request = 
make_mocked_request("GET", scenario["path"]) + response = await cache_control(request, handler) + + assert response.status == scenario["status"] + await self.assert_cache_header( + response, scenario["expected_cache"], scenario["should_have_header"] + ) + + @pytest.mark.parametrize("ext", IMG_EXTENSIONS) + async def test_all_image_extensions(self, ext: str, mock_handler): + """Test all defined image extensions are handled correctly""" + request = make_mocked_request("GET", f"/image{ext}") + response = await cache_control(request, mock_handler) + + assert response.status == 200 + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}" + + @pytest.mark.parametrize( + "status_scenario", IMAGE_STATUS_SCENARIOS, ids=lambda x: f"status_{x['status']}" + ) + async def test_image_status_codes( + self, status_scenario: Dict[str, Any], status_handler_factory + ): + """Test different status codes for image requests""" + handler = status_handler_factory(status_scenario["status"]) + request = make_mocked_request("GET", "/image.jpg") + response = await cache_control(request, handler) + + assert response.status == status_scenario["status"] + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == status_scenario["expected"] + + @pytest.mark.parametrize("path", CASE_SENSITIVITY_PATHS) + async def test_case_insensitive_image_extension(self, path: str, mock_handler): + """Test that image extensions are matched case-insensitively""" + request = make_mocked_request("GET", path) + response = await cache_control(request, mock_handler) + + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == f"public, max-age={ONE_DAY}" + + @pytest.mark.parametrize("edge_case", EDGE_CASE_PATHS, ids=lambda x: x["name"]) + async def test_edge_cases(self, edge_case: Dict[str, str], mock_handler): + """Test edge cases like query strings, nested paths, etc.""" + request = make_mocked_request("GET", edge_case["path"]) + response = await cache_control(request, mock_handler) + + assert "Cache-Control" in response.headers + assert response.headers["Cache-Control"] == edge_case["expected"] + + # Header preservation tests (special cases not covered by parameterization) + async def test_js_preserves_existing_headers(self, handler_with_existing_cache): + """Test that .js files preserve existing Cache-Control headers""" + request = make_mocked_request("GET", "/script.js") + response = await cache_control(request, handler_with_existing_cache) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "max-age=3600" + + async def test_css_preserves_existing_headers(self, handler_with_existing_cache): + """Test that .css files preserve existing Cache-Control headers""" + request = make_mocked_request("GET", "/styles.css") + response = await cache_control(request, handler_with_existing_cache) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "max-age=3600" + + async def test_image_preserves_existing_headers(self, status_handler_factory): + """Test that image cache headers preserve existing Cache-Control""" + handler = status_handler_factory(200, {"Cache-Control": "private, no-cache"}) + request = make_mocked_request("GET", "/image.jpg") + response = await cache_control(request, handler) + + # setdefault should preserve existing header + assert response.headers["Cache-Control"] == "private, no-cache" + + async def 
test_304_not_modified_inherits_cache(self, status_handler_factory): + """Test that 304 Not Modified doesn't set cache headers for images""" + handler = status_handler_factory(304, {"Cache-Control": "max-age=7200"}) + request = make_mocked_request("GET", "/not-modified.jpg") + response = await cache_control(request, handler) + + assert response.status == 304 + # Should preserve existing cache header, not override + assert response.headers["Cache-Control"] == "max-age=7200" diff --git a/tests/conftest.py b/tests/conftest.py index 4e30eb581..290e3a5c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ def pytest_addoption(parser): parser.addoption('--output_dir', action="store", default='tests/inference/samples', help='Output directory for generated images') parser.addoption("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)") parser.addoption("--port", type=int, default=8188, help="Set the listen port.") + parser.addoption("--skip-timing-checks", action="store_true", default=False, help="Skip timing-related assertions in tests (useful for CI environments with variable performance)") # This initializes args at the beginning of the test session @pytest.fixture(scope="session", autouse=True) @@ -19,6 +20,11 @@ def args_pytest(pytestconfig): return args +@pytest.fixture(scope="session") +def skip_timing_checks(pytestconfig): + """Fixture that returns whether timing checks should be skipped.""" + return pytestconfig.getoption("--skip-timing-checks") + def pytest_collection_modifyitems(items): # Modifies items so tests run in the correct order diff --git a/tests/inference/extra_model_paths.yaml b/tests/execution/extra_model_paths.yaml similarity index 100% rename from tests/inference/extra_model_paths.yaml rename to tests/execution/extra_model_paths.yaml diff --git a/tests/inference/test_async_nodes.py b/tests/execution/test_async_nodes.py similarity index 95% rename from tests/inference/test_async_nodes.py rename to tests/execution/test_async_nodes.py index f029953dd..c771b4b36 100644 --- a/tests/inference/test_async_nodes.py +++ b/tests/execution/test_async_nodes.py @@ -7,7 +7,7 @@ import subprocess from pytest import fixture from comfy_execution.graph_utils import GraphBuilder -from tests.inference.test_execution import ComfyClient, run_warmup +from tests.execution.test_execution import ComfyClient, run_warmup @pytest.mark.execution @@ -23,7 +23,7 @@ class TestAsyncNodes: '--output-directory', args_pytest["output_dir"], '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] use_lru, lru_size = request.param @@ -81,7 +81,7 @@ class TestAsyncNodes: assert len(result_images) == 1, "Should have 1 image" assert np.array(result_images[0]).min() == 0 and np.array(result_images[0]).max() == 0, "Image should be black" - def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder): + def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test that multiple async nodes execute in parallel.""" # Warmup execution to ensure server is fully initialized run_warmup(client) @@ -104,7 +104,8 @@ class TestAsyncNodes: elapsed_time = time.time() - 
start_time # Should take ~0.5s (max duration) not 1.2s (sum of durations) - assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" + if not skip_timing_checks: + assert elapsed_time < 0.8, f"Parallel execution took {elapsed_time}s, expected < 0.8s" # Verify all nodes executed assert result.did_run(sleep1) and result.did_run(sleep2) and result.did_run(sleep3) @@ -150,7 +151,7 @@ class TestAsyncNodes: with pytest.raises(urllib.error.HTTPError): client.run(g) - def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder): + def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test async nodes with lazy evaluation.""" # Warmup execution to ensure server is fully initialized run_warmup(client, prefix="warmup_lazy") @@ -173,7 +174,8 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time # Should only execute sleep1, not sleep2 - assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" + if not skip_timing_checks: + assert elapsed_time < 0.5, f"Should skip sleep2, took {elapsed_time}s" assert result.did_run(sleep1), "Sleep1 should have executed" assert not result.did_run(sleep2), "Sleep2 should have been skipped" @@ -310,7 +312,7 @@ class TestAsyncNodes: images = result.get_images(output) assert len(images) == 1, "Should have blocked second image" - def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder): + def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test that async nodes are properly cached.""" # Warmup execution to ensure server is fully initialized run_warmup(client, prefix="warmup_cache") @@ -330,9 +332,10 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time assert not result2.did_run(sleep_node), "Should be cached" - assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" + if not skip_timing_checks: + assert elapsed_time < 0.1, f"Cached run took {elapsed_time}s, should be instant" - def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder): + def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): """Test async nodes within dynamically generated prompts.""" # Warmup execution to ensure server is fully initialized run_warmup(client, prefix="warmup_dynamic") @@ -345,8 +348,8 @@ class TestAsyncNodes: dynamic_async = g.node("TestDynamicAsyncGeneration", image1=image1.out(0), image2=image2.out(0), - num_async_nodes=3, - sleep_duration=0.2) + num_async_nodes=5, + sleep_duration=0.4) g.node("SaveImage", images=dynamic_async.out(0)) start_time = time.time() @@ -354,7 +357,8 @@ class TestAsyncNodes: elapsed_time = time.time() - start_time # Should execute async nodes in parallel within dynamic prompt - assert elapsed_time < 0.5, f"Dynamic async execution took {elapsed_time}s" + if not skip_timing_checks: + assert elapsed_time < 1.0, f"Dynamic async execution took {elapsed_time}s" assert result.did_run(dynamic_async) def test_async_resource_cleanup(self, client: ComfyClient, builder: GraphBuilder): diff --git a/tests/inference/test_execution.py b/tests/execution/test_execution.py similarity index 98% rename from tests/inference/test_execution.py rename to tests/execution/test_execution.py index e7b29302e..8ea05fdd8 100644 --- a/tests/inference/test_execution.py +++ b/tests/execution/test_execution.py @@ -149,7 +149,7 @@ class TestExecution: '--output-directory', 
args_pytest["output_dir"], '--listen', args_pytest["listen"], '--port', str(args_pytest["port"]), - '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml', + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', '--cpu', ] use_lru, lru_size = request.param @@ -518,7 +518,7 @@ class TestExecution: assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" assert not result.did_run(test_node), "The execution should have been cached" - def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder): + def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized run_warmup(client) @@ -541,14 +541,15 @@ class TestExecution: # The test should take around 3.0 seconds (the longest sleep duration) # plus some overhead, but definitely less than the sum of all sleeps (9.0s) - assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" + if not skip_timing_checks: + assert elapsed_time < 8.9, f"Parallel execution took {elapsed_time}s, expected less than 8.9s" # Verify that all nodes executed assert result.did_run(sleep_node1), "Sleep node 1 should have run" assert result.did_run(sleep_node2), "Sleep node 2 should have run" assert result.did_run(sleep_node3), "Sleep node 3 should have run" - def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder): + def test_parallel_sleep_expansion(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized run_warmup(client) @@ -574,7 +575,9 @@ class TestExecution: # Similar to the previous test, expect parallel execution of the sleep nodes # which should complete in less than the sum of all sleeps - assert elapsed_time < 10.0, f"Expansion execution took {elapsed_time}s, expected less than 5.5s" + # Lots of leeway here since Windows CI is slow + if not skip_timing_checks: + assert elapsed_time < 13.0, f"Expansion execution took {elapsed_time}s" # Verify the parallel sleep node executed assert result.did_run(parallel_sleep), "ParallelSleep node should have run" diff --git a/tests/execution/test_progress_isolation.py b/tests/execution/test_progress_isolation.py new file mode 100644 index 000000000..93dc0d41b --- /dev/null +++ b/tests/execution/test_progress_isolation.py @@ -0,0 +1,233 @@ +"""Test that progress updates are properly isolated between WebSocket clients.""" + +import json +import pytest +import time +import threading +import uuid +import websocket +from typing import List, Dict, Any +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +class ProgressTracker: + """Tracks progress messages received by a WebSocket client.""" + + def __init__(self, client_id: str): + self.client_id = client_id + self.progress_messages: List[Dict[str, Any]] = [] + self.lock = threading.Lock() + + def add_message(self, message: Dict[str, Any]): + """Thread-safe addition of progress messages.""" + with self.lock: + self.progress_messages.append(message) + + def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]: + """Get all progress messages for a specific prompt_id.""" + with self.lock: + return [ + msg for msg in self.progress_messages + if msg.get('data', {}).get('prompt_id') == prompt_id + ] + + def has_cross_contamination(self, own_prompt_id: str) -> bool: + """Check if 
this client received progress for other prompts.""" + with self.lock: + for msg in self.progress_messages: + msg_prompt_id = msg.get('data', {}).get('prompt_id') + if msg_prompt_id and msg_prompt_id != own_prompt_id: + return True + return False + + +class IsolatedClient(ComfyClient): + """Extended ComfyClient that tracks all WebSocket messages.""" + + def __init__(self): + super().__init__() + self.progress_tracker = None + self.all_messages: List[Dict[str, Any]] = [] + + def connect(self, listen='127.0.0.1', port=8188, client_id=None): + """Connect with a specific client_id and set up message tracking.""" + if client_id is None: + client_id = str(uuid.uuid4()) + super().connect(listen, port, client_id) + self.progress_tracker = ProgressTracker(client_id) + + def listen_for_messages(self, duration: float = 5.0): + """Listen for WebSocket messages for a specified duration.""" + end_time = time.time() + duration + self.ws.settimeout(0.5) # Non-blocking with timeout + + while time.time() < end_time: + try: + out = self.ws.recv() + if isinstance(out, str): + message = json.loads(out) + self.all_messages.append(message) + + # Track progress_state messages + if message.get('type') == 'progress_state': + self.progress_tracker.add_message(message) + except websocket.WebSocketTimeoutException: + continue + except Exception: + # Log error silently in test context + break + + +@pytest.mark.execution +class TestProgressIsolation: + """Test suite for verifying progress update isolation between clients.""" + + @pytest.fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start the ComfyUI server for testing.""" + import subprocess + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + + def start_client_with_retry(self, listen: str, port: int, client_id: str = None): + """Start client with connection retries.""" + client = IsolatedClient() + # Connect to server (with retries) + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen, port, client_id) + return client + except ConnectionRefusedError as e: + print(e) # noqa: T201 + print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201 + raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") + + def test_progress_isolation_between_clients(self, args_pytest): + """Test that progress updates are isolated between different clients.""" + listen = args_pytest["listen"] + port = args_pytest["port"] + + # Create two separate clients with unique IDs + client_a_id = "client_a_" + str(uuid.uuid4()) + client_b_id = "client_b_" + str(uuid.uuid4()) + + try: + # Connect both clients with retries + client_a = self.start_client_with_retry(listen, port, client_a_id) + client_b = self.start_client_with_retry(listen, port, client_b_id) + + # Create simple workflows for both clients + graph_a = GraphBuilder(prefix="client_a") + image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + graph_a.node("PreviewImage", images=image_a.out(0)) + + graph_b = GraphBuilder(prefix="client_b") + image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + graph_b.node("PreviewImage", images=image_b.out(0)) + + # Submit workflows from both clients + prompt_a = graph_a.finalize() + prompt_b = graph_b.finalize() + + 
response_a = client_a.queue_prompt(prompt_a) + prompt_id_a = response_a['prompt_id'] + + response_b = client_b.queue_prompt(prompt_b) + prompt_id_b = response_b['prompt_id'] + + # Start threads to listen for messages on both clients + def listen_client_a(): + client_a.listen_for_messages(duration=10.0) + + def listen_client_b(): + client_b.listen_for_messages(duration=10.0) + + thread_a = threading.Thread(target=listen_client_a) + thread_b = threading.Thread(target=listen_client_b) + + thread_a.start() + thread_b.start() + + # Wait for threads to complete + thread_a.join() + thread_b.join() + + # Verify isolation + # Client A should only receive progress for prompt_id_a + assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \ + f"Client A received progress updates for other clients' workflows. " \ + f"Expected only {prompt_id_a}, but got messages for multiple prompts." + + # Client B should only receive progress for prompt_id_b + assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \ + f"Client B received progress updates for other clients' workflows. " \ + f"Expected only {prompt_id_b}, but got messages for multiple prompts." + + # Verify each client received their own progress updates + client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a) + client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b) + + assert len(client_a_messages) > 0, \ + "Client A did not receive any progress updates for its own workflow" + assert len(client_b_messages) > 0, \ + "Client B did not receive any progress updates for its own workflow" + + # Ensure no cross-contamination + client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b) + client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a) + + assert len(client_a_other) == 0, \ + f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow" + assert len(client_b_other) == 0, \ + f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow" + + finally: + # Clean up connections + if hasattr(client_a, 'ws'): + client_a.ws.close() + if hasattr(client_b, 'ws'): + client_b.ws.close() + + def test_progress_with_missing_client_id(self, args_pytest): + """Test that progress updates handle missing client_id gracefully.""" + listen = args_pytest["listen"] + port = args_pytest["port"] + + try: + # Connect client with retries + client = self.start_client_with_retry(listen, port) + + # Create a simple workflow + graph = GraphBuilder(prefix="test_missing_id") + image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1) + graph.node("PreviewImage", images=image.out(0)) + + # Submit workflow + prompt = graph.finalize() + response = client.queue_prompt(prompt) + prompt_id = response['prompt_id'] + + # Listen for messages + client.listen_for_messages(duration=5.0) + + # Should still receive progress updates for own workflow + messages = client.progress_tracker.get_messages_for_prompt(prompt_id) + assert len(messages) > 0, \ + "Client did not receive progress updates even though it initiated the workflow" + + finally: + if hasattr(client, 'ws'): + client.ws.close() + diff --git a/tests/inference/testing_nodes/testing-pack/__init__.py b/tests/execution/testing_nodes/testing-pack/__init__.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/__init__.py rename to tests/execution/testing_nodes/testing-pack/__init__.py 
diff --git a/tests/inference/testing_nodes/testing-pack/api_test_nodes.py b/tests/execution/testing_nodes/testing-pack/api_test_nodes.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/api_test_nodes.py rename to tests/execution/testing_nodes/testing-pack/api_test_nodes.py diff --git a/tests/inference/testing_nodes/testing-pack/async_test_nodes.py b/tests/execution/testing_nodes/testing-pack/async_test_nodes.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/async_test_nodes.py rename to tests/execution/testing_nodes/testing-pack/async_test_nodes.py diff --git a/tests/inference/testing_nodes/testing-pack/conditions.py b/tests/execution/testing_nodes/testing-pack/conditions.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/conditions.py rename to tests/execution/testing_nodes/testing-pack/conditions.py diff --git a/tests/inference/testing_nodes/testing-pack/flow_control.py b/tests/execution/testing_nodes/testing-pack/flow_control.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/flow_control.py rename to tests/execution/testing_nodes/testing-pack/flow_control.py diff --git a/tests/inference/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/specific_tests.py rename to tests/execution/testing_nodes/testing-pack/specific_tests.py diff --git a/tests/inference/testing_nodes/testing-pack/stubs.py b/tests/execution/testing_nodes/testing-pack/stubs.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/stubs.py rename to tests/execution/testing_nodes/testing-pack/stubs.py diff --git a/tests/inference/testing_nodes/testing-pack/tools.py b/tests/execution/testing_nodes/testing-pack/tools.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/tools.py rename to tests/execution/testing_nodes/testing-pack/tools.py diff --git a/tests/inference/testing_nodes/testing-pack/util.py b/tests/execution/testing_nodes/testing-pack/util.py similarity index 100% rename from tests/inference/testing_nodes/testing-pack/util.py rename to tests/execution/testing_nodes/testing-pack/util.py
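Note on the cache-control tests added in tests-unit/server_test/test_cache_control.py: they import cache_control, ONE_HOUR, ONE_DAY and IMG_EXTENSIONS from the new middleware.cache_middleware module, which is not shown in this excerpt of the patch. A minimal sketch consistent with those tests, with the constant values and the exact extension tuple being assumptions rather than the shipped code:

from aiohttp import web

ONE_HOUR = 60 * 60        # assumed value, in seconds
ONE_DAY = 24 * ONE_HOUR   # assumed value, in seconds

# Assumed list; the tests only require lowercase entries matched case-insensitively.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp', '.gif')

# Per IMAGE_STATUS_SCENARIOS: successes and permanent redirects cache for a day,
# temporary redirects are not cached, and 404s cache for an hour.
_LONG_CACHE = {200, 201, 202, 204, 206, 301, 308}
_TEMP_REDIRECT = {302, 303, 307}

@web.middleware
async def cache_control(request: web.Request, handler):
    response: web.Response = await handler(request)
    path = request.path  # aiohttp's .path excludes the query string

    if path.endswith('.js') or path.endswith('.css') or path.endswith('index.json'):
        response.headers.setdefault('Cache-Control', 'no-cache')
    elif path.lower().endswith(IMG_EXTENSIONS):
        # setdefault preserves any Cache-Control header the handler already set,
        # which is what the "preserves existing headers" tests rely on.
        if response.status in _LONG_CACHE:
            response.headers.setdefault('Cache-Control', f'public, max-age={ONE_DAY}')
        elif response.status in _TEMP_REDIRECT:
            response.headers.setdefault('Cache-Control', 'no-cache')
        elif response.status == 404:
            response.headers.setdefault('Cache-Control', f'public, max-age={ONE_HOUR}')
    return response

server.py imports this in place of the inline definition it removes, so .js/.css/index.json keep the previous no-cache behaviour while image responses gain status-dependent caching.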
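Note on the /interrupt change in server.py: the endpoint now accepts an optional JSON body with a prompt_id and only interrupts when that prompt is currently running; an empty or missing body keeps the old global-interrupt behaviour. A minimal client-side sketch, assuming the default listen address 127.0.0.1:8188 (the helper name interrupt_prompt is illustrative, not part of this patch):

import json
import urllib.request

def interrupt_prompt(prompt_id=None, host="127.0.0.1", port=8188):
    # With a prompt_id, only that prompt is interrupted, and only if it is
    # currently running; with no prompt_id the interrupt is global.
    body = json.dumps({"prompt_id": prompt_id} if prompt_id else {}).encode("utf-8")
    req = urllib.request.Request(
        f"http://{host}:{port}/interrupt",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req) as resp:
        return resp.status  # the server responds 200 whether or not anything was interrupted

# interrupt_prompt("<prompt_id returned by POST /prompt>")  # targeted interrupt
# interrupt_prompt()                                        # global interrupt, old behaviour

Posting a prompt_id that is not currently running is a deliberate no-op, so one client can cancel its own job without interrupting whatever another client happens to be running.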