Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-12-19 02:53:05 +08:00)

Merge branch 'master' into v3-improvements
This commit is contained in: commit cb7d2456fc

.github/workflows/test-ci.yml (vendored, 1 line changed)
@@ -5,6 +5,7 @@ on:
   push:
     branches:
       - master
+      - release/**
     paths-ignore:
       - 'app/**'
       - 'input/**'
.github/workflows/test-execution.yml (vendored, 4 lines changed)
@@ -2,9 +2,9 @@ name: Execution Tests

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:
.github/workflows/test-launch.yml (vendored, 4 lines changed)
@@ -2,9 +2,9 @@ name: Test server launches without errors

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:
.github/workflows/test-unit.yml (vendored, 4 lines changed)
@@ -2,9 +2,9 @@ name: Unit Tests

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
  pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:
.github/workflows/update-version.yml (vendored, 1 line changed)
@@ -6,6 +6,7 @@ on:
       - "pyproject.toml"
     branches:
       - master
+      - release/**

 jobs:
   update-version:
@@ -1618,6 +1618,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
             x = x + sde_noise * sigmas[i + 1] * s_noise
     return x


+@torch.no_grad()
+def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
+    """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
+    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
+
+
+@torch.no_grad()
+def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
+    """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
+    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
+
 @torch.no_grad()
 def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
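Note: both new samplers are thin wrappers that pin sample_seeds_2's knobs (r=1.0, solver_type="phi_2", with eta and s_noise zeroed for the deterministic variant). A minimal, hedged sketch of that equivalence, using a dummy denoiser in place of a real model; it assumes a ComfyUI checkout containing this commit, with these functions living in comfy.k_diffusion.sampling and sample_seeds_2 accepting a plain callable denoiser:

import torch
from comfy.k_diffusion import sampling  # assumed module location of the functions above

def dummy_denoiser(x, sigma, **extra_args):
    # Stand-in "model": always predicts an all-zero denoised image.
    return torch.zeros_like(x)

x = torch.randn(1, 4, 8, 8)
sigmas = torch.tensor([14.6, 7.0, 3.0, 1.0, 0.0])

a = sampling.sample_exp_heun_2_x0(dummy_denoiser, x.clone(), sigmas, disable=True)
b = sampling.sample_seeds_2(dummy_denoiser, x.clone(), sigmas, disable=True,
                            eta=0.0, s_noise=0.0, r=1.0, solver_type="phi_2")
assert torch.allclose(a, b)  # the wrapper only fixes parameters

The SDE variant differs only in keeping eta and s_noise at 1.0, so its output additionally depends on the noise sampler.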
@@ -1765,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
         # Predictor
         if sigmas[i + 1] == 0:
             # Denoising step
-            x = denoised
+            x_pred = denoised
         else:
             tau_t = tau_func(sigmas[i + 1])
             curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@@ -1786,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
             if tau_t > 0 and s_noise > 0:
                 noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                 x_pred = x_pred + noise
-    return x
+    return x_pred


 @torch.no_grad()
@@ -322,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module):
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        default_ref_method="index",
         image_model=None,
         final_layer=True,
         dtype=None,
@@ -334,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module):
         self.in_channels = in_channels
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.default_ref_method = default_ref_method

         self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))

@@ -361,6 +363,9 @@ class QwenImageTransformer2DModel(nn.Module):
             for _ in range(num_layers)
         ])

+        if self.default_ref_method == "index_timestep_zero":
+            self.register_buffer("__index_timestep_zero__", torch.tensor([]))
+
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
             self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
@@ -416,7 +421,7 @@ class QwenImageTransformer2DModel(nn.Module):
             h = 0
             w = 0
             index = 0
-            ref_method = kwargs.get("ref_latents_method", "index")
+            ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
             index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
             timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
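Note: registering the empty __index_timestep_zero__ buffer is what lets the new default survive a save/load round trip: the key lands in the state dict and is later read back as a capability flag by detect_unet_config (next section). A self-contained sketch of the pattern:

import torch

class TinyModel(torch.nn.Module):
    def __init__(self, default_ref_method: str = "index"):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        if default_ref_method == "index_timestep_zero":
            # Empty buffer: adds no weights, but its key is saved with the model.
            self.register_buffer("__index_timestep_zero__", torch.tensor([]))

state_dict = TinyModel("index_timestep_zero").state_dict()
detected = "index_timestep_zero" if "__index_timestep_zero__" in state_dict else "index"
print(detected)  # -> index_timestep_zero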
@@ -259,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["nerf_tile_size"] = 512
         dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
         dit_config["nerf_embedder_dtype"] = torch.float32
-        if "__x0__" in state_dict_keys: # x0 pred
+        if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
             dit_config["use_x0"] = True
         else:
             dit_config["use_x0"] = False
@@ -618,6 +618,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["image_model"] = "qwen_image"
         dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
         dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+        if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
+            dit_config["default_ref_method"] = "index_timestep_zero"
         return dit_config

     if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
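Note: the __x0__ fix matters because checkpoint keys normally carry a prefix (for example model.diffusion_model.), so the old, bare lookup could never match. A small illustration with an assumed prefix:

state_dict_keys = {"model.diffusion_model.__x0__", "model.diffusion_model.img_in.weight"}
key_prefix = "model.diffusion_model."

print("__x0__" in state_dict_keys)                       # False - the old, unprefixed check
print("{}__x0__".format(key_prefix) in state_dict_keys)  # True  - the fixed, prefixed check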
@@ -720,7 +720,7 @@ class Sampler:
         sigma = float(sigmas[0])
         return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma

-KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
+KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
                   "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
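Note: once the new names are in KSAMPLER_NAMES they can be selected anywhere a sampler name is accepted. A hedged sketch, assuming a ComfyUI checkout with this commit on the import path and that comfy.samplers.sampler_object resolves these names the same way it does the other k-diffusion samplers:

import comfy.samplers

for name in ("exp_heun_2_x0", "exp_heun_2_x0_sde"):
    assert name in comfy.samplers.KSAMPLER_NAMES
    sampler = comfy.samplers.sampler_object(name)  # same resolution path used by KSamplerSelect
    print(name, type(sampler).__name__)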
@@ -1594,12 +1594,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):

     @final
     @classmethod
-    def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
+    def PREPARE_CLASS_CLONE(cls, v3_data: V3Data | None) -> type[ComfyNode]:
         """Creates clone of real node class to prevent monkey-patching."""
         c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
         type_clone: type[ComfyNode] = shallow_clone_class(c_type)
         # set hidden
-        type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
+        type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None)
         return type_clone

     @final
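Note: PREPARE_CLASS_CLONE relies on shallow-cloning the node class so that per-execution state such as hidden never lands on the shared class object; the change above only makes v3_data optional. A hedged sketch of the cloning idea (the real shallow_clone_class may be implemented differently):

def shallow_clone_class(cls: type) -> type:
    # One plausible implementation: rebuild the class from its own namespace.
    namespace = {k: v for k, v in vars(cls).items() if k not in ("__dict__", "__weakref__")}
    return type(cls.__name__, cls.__bases__, namespace)

class MyNode:
    hidden = None

clone = shallow_clone_class(MyNode)
clone.hidden = {"unique_id": "123"}  # per-execution data stays on the clone
print(MyNode.hidden, clone.hidden)   # None {'unique_id': '123'}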
comfy_api_nodes/apis/openai_api.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from pydantic import BaseModel, Field
+
+
+class Datum2(BaseModel):
+    b64_json: str | None = Field(None, description="Base64 encoded image data")
+    revised_prompt: str | None = Field(None, description="Revised prompt")
+    url: str | None = Field(None, description="URL of the image")
+
+
+class InputTokensDetails(BaseModel):
+    image_tokens: int | None = None
+    text_tokens: int | None = None
+
+
+class Usage(BaseModel):
+    input_tokens: int | None = None
+    input_tokens_details: InputTokensDetails | None = None
+    output_tokens: int | None = None
+    total_tokens: int | None = None
+
+
+class OpenAIImageGenerationResponse(BaseModel):
+    data: list[Datum2] | None = None
+    usage: Usage | None = None
+
+
+class OpenAIImageEditRequest(BaseModel):
+    background: str | None = Field(None, description="Background transparency")
+    model: str = Field(...)
+    moderation: str | None = Field(None)
+    n: int | None = Field(None, description="The number of images to generate")
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
+    size: str | None = Field(None, description="Size of the output image")
+
+
+class OpenAIImageGenerationRequest(BaseModel):
+    background: str | None = Field(None, description="Background transparency")
+    model: str | None = Field(None)
+    moderation: str | None = Field(None)
+    n: int | None = Field(
+        None,
+        description="The number of images to generate.",
+    )
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="The quality of the generated image")
+    size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
+    style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
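Note: these pydantic models mirror the JSON shape returned by the image endpoints, so a response can be validated directly from a dict. A hedged example, assuming the new module is importable from a ComfyUI checkout containing this commit:

from comfy_api_nodes.apis.openai_api import OpenAIImageGenerationResponse

payload = {
    "data": [{"revised_prompt": "a cat", "url": "https://example.com/img.png"}],
    "usage": {
        "input_tokens": 100,
        "output_tokens": 1000,
        "input_tokens_details": {"image_tokens": 0, "text_tokens": 100},
        "total_tokens": 1100,
    },
}
resp = OpenAIImageGenerationResponse(**payload)
print(resp.data[0].url, resp.usage.output_tokens)  # https://example.com/img.png 1000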
@@ -1,46 +1,45 @@
-from io import BytesIO
+import base64
 import os
 from enum import Enum
-from inspect import cleandoc
+from io import BytesIO

 import numpy as np
 import torch
 from PIL import Image
-import folder_paths
-import base64
-from comfy_api.latest import IO, ComfyExtension
 from typing_extensions import override

+import folder_paths
+from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis import (
-    OpenAIImageGenerationRequest,
-    OpenAIImageEditRequest,
-    OpenAIImageGenerationResponse,
-    OpenAICreateResponse,
-    OpenAIResponse,
     CreateModelResponseProperties,
-    Item,
-    OutputContent,
-    InputImageContent,
     Detail,
-    InputTextContent,
-    InputMessage,
-    InputMessageContentList,
     InputContent,
     InputFileContent,
+    InputImageContent,
+    InputMessage,
+    InputMessageContentList,
+    InputTextContent,
+    Item,
+    OpenAICreateResponse,
+    OpenAIResponse,
+    OutputContent,
+)
+from comfy_api_nodes.apis.openai_api import (
+    OpenAIImageEditRequest,
+    OpenAIImageGenerationRequest,
+    OpenAIImageGenerationResponse,
 )

 from comfy_api_nodes.util import (
-    downscale_image_tensor,
-    download_url_to_bytesio,
-    validate_string,
-    tensor_to_base64_string,
     ApiEndpoint,
-    sync_op,
+    download_url_to_bytesio,
+    downscale_image_tensor,
     poll_op,
+    sync_op,
+    tensor_to_base64_string,
     text_filepath_to_data_uri,
+    validate_string,
 )


 RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
 STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"

@@ -98,9 +97,6 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten


 class OpenAIDalle2(IO.ComfyNode):
-    """
-    Generates images synchronously via OpenAI's DALL·E 2 endpoint.
-    """

     @classmethod
     def define_schema(cls):
@@ -108,7 +104,7 @@ class OpenAIDalle2(IO.ComfyNode):
             node_id="OpenAIDalle2",
             display_name="OpenAI DALL·E 2",
             category="api node/image/OpenAI",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -234,9 +230,6 @@ class OpenAIDalle2(IO.ComfyNode):


 class OpenAIDalle3(IO.ComfyNode):
-    """
-    Generates images synchronously via OpenAI's DALL·E 3 endpoint.
-    """

     @classmethod
     def define_schema(cls):
@@ -244,7 +237,7 @@ class OpenAIDalle3(IO.ComfyNode):
             node_id="OpenAIDalle3",
             display_name="OpenAI DALL·E 3",
             category="api node/image/OpenAI",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -326,10 +319,16 @@ class OpenAIDalle3(IO.ComfyNode):
         return IO.NodeOutput(await validate_and_cast_response(response))


+def calculate_tokens_price_image_1(response: OpenAIImageGenerationResponse) -> float | None:
+    # https://platform.openai.com/docs/pricing
+    return ((response.usage.input_tokens * 10.0) + (response.usage.output_tokens * 40.0)) / 1_000_000.0
+
+
+def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> float | None:
+    return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
+
+
 class OpenAIGPTImage1(IO.ComfyNode):
-    """
-    Generates images synchronously via OpenAI's GPT Image 1 endpoint.
-    """

     @classmethod
     def define_schema(cls):
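Note: the two pricing helpers encode per-million-token rates (10 USD in / 40 USD out for gpt-image-1, 8 / 32 for gpt-image-1.5, per the linked pricing page). The arithmetic, spelled out on sample token counts:

def price(input_tokens: int, output_tokens: int, in_rate: float, out_rate: float) -> float:
    # Rates are USD per one million tokens.
    return (input_tokens * in_rate + output_tokens * out_rate) / 1_000_000.0

print(price(500, 6000, 10.0, 40.0))  # gpt-image-1   -> 0.245
print(price(500, 6000, 8.0, 32.0))   # gpt-image-1.5 -> 0.196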
@@ -337,13 +336,13 @@ class OpenAIGPTImage1(IO.ComfyNode):
             node_id="OpenAIGPTImage1",
             display_name="OpenAI GPT Image 1",
             category="api node/image/OpenAI",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.",
             inputs=[
                 IO.String.Input(
                     "prompt",
                     default="",
                     multiline=True,
-                    tooltip="Text prompt for GPT Image 1",
+                    tooltip="Text prompt for GPT Image",
                 ),
                 IO.Int.Input(
                     "seed",
@@ -365,8 +364,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
                 ),
                 IO.Combo.Input(
                     "background",
-                    default="opaque",
-                    options=["opaque", "transparent"],
+                    default="auto",
+                    options=["auto", "opaque", "transparent"],
                     tooltip="Return image with or without background",
                     optional=True,
                 ),
@@ -397,6 +396,11 @@ class OpenAIGPTImage1(IO.ComfyNode):
                     tooltip="Optional mask for inpainting (white areas will be replaced)",
                     optional=True,
                 ),
+                IO.Combo.Input(
+                    "model",
+                    options=["gpt-image-1", "gpt-image-1.5"],
+                    optional=True,
+                ),
             ],
             outputs=[
                 IO.Image.Output(),
@@ -412,32 +416,34 @@ class OpenAIGPTImage1(IO.ComfyNode):
     @classmethod
     async def execute(
         cls,
-        prompt,
-        seed=0,
-        quality="low",
-        background="opaque",
-        image=None,
-        mask=None,
-        n=1,
-        size="1024x1024",
+        prompt: str,
+        seed: int = 0,
+        quality: str = "low",
+        background: str = "opaque",
+        image: Input.Image | None = None,
+        mask: Input.Image | None = None,
+        n: int = 1,
+        size: str = "1024x1024",
+        model: str = "gpt-image-1",
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        model = "gpt-image-1"
-        path = "/proxy/openai/images/generations"
-        content_type = "application/json"
-        request_class = OpenAIImageGenerationRequest
-        files = []
+        if mask is not None and image is None:
+            raise ValueError("Cannot use a mask without an input image")
+        if model == "gpt-image-1":
+            price_extractor = calculate_tokens_price_image_1
+        elif model == "gpt-image-1.5":
+            price_extractor = calculate_tokens_price_image_1_5
+        else:
+            raise ValueError(f"Unknown model: {model}")

         if image is not None:
-            path = "/proxy/openai/images/edits"
-            request_class = OpenAIImageEditRequest
-            content_type = "multipart/form-data"
+            files = []

             batch_size = image.shape[0]

             for i in range(batch_size):
-                single_image = image[i : i + 1]
-                scaled_image = downscale_image_tensor(single_image).squeeze()
+                single_image = image[i: i + 1]
+                scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze()

                 image_np = (scaled_image.numpy() * 255).astype(np.uint8)
                 img = Image.fromarray(image_np)
@@ -450,44 +456,59 @@ class OpenAIGPTImage1(IO.ComfyNode):
                 else:
                     files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png")))

             if mask is not None:
-                if image is None:
-                    raise Exception("Cannot use a mask without an input image")
                 if image.shape[0] != 1:
                     raise Exception("Cannot use a mask with multiple image")
                 if mask.shape[1:] != image.shape[1:-1]:
                     raise Exception("Mask and Image must be the same size")
-                batch, height, width = mask.shape
+                _, height, width = mask.shape
                 rgba_mask = torch.zeros(height, width, 4, device="cpu")
                 rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()

-                scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
+                scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze()

                 mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
                 mask_img = Image.fromarray(mask_np)
                 mask_img_byte_arr = BytesIO()
                 mask_img.save(mask_img_byte_arr, format="PNG")
                 mask_img_byte_arr.seek(0)
                 files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png")))

-        # Build the operation
-        response = await sync_op(
-            cls,
-            ApiEndpoint(path=path, method="POST"),
-            response_model=OpenAIImageGenerationResponse,
-            data=request_class(
-                model=model,
-                prompt=prompt,
-                quality=quality,
-                background=background,
-                n=n,
-                seed=seed,
-                size=size,
-            ),
-            files=files if files else None,
-            content_type=content_type,
-        )
+            response = await sync_op(
+                cls,
+                ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
+                response_model=OpenAIImageGenerationResponse,
+                data=OpenAIImageEditRequest(
+                    model=model,
+                    prompt=prompt,
+                    quality=quality,
+                    background=background,
+                    n=n,
+                    seed=seed,
+                    size=size,
+                    moderation="low",
+                ),
+                content_type="multipart/form-data",
+                files=files,
+                price_extractor=price_extractor,
+            )
+        else:
+            response = await sync_op(
+                cls,
+                ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
+                response_model=OpenAIImageGenerationResponse,
+                data=OpenAIImageGenerationRequest(
+                    model=model,
+                    prompt=prompt,
+                    quality=quality,
+                    background=background,
+                    n=n,
+                    seed=seed,
+                    size=size,
+                    moderation="low",
+                ),
+                price_extractor=price_extractor,
+            )
         return IO.NodeOutput(await validate_and_cast_response(response))

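Note: the mask handling above converts a ComfyUI mask (1 = area to repaint) into the alpha channel of an RGBA image, because the images/edits endpoint treats transparent pixels as the editable region. A small, hedged sketch of just that conversion:

import torch

mask = torch.zeros(1, 64, 64)
mask[:, 16:48, 16:48] = 1.0                 # region to repaint
_, height, width = mask.shape
rgba_mask = torch.zeros(height, width, 4)
rgba_mask[:, :, 3] = 1 - mask.squeeze()     # alpha 0 where the mask is 1
print(rgba_mask[32, 32, 3].item(), rgba_mask[0, 0, 3].item())  # 0.0 (editable) 1.0 (kept)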
@@ -1,7 +1,5 @@
 import re
-from typing import Optional

-import torch
 from pydantic import BaseModel, Field
 from typing_extensions import override

@@ -21,26 +19,26 @@ from comfy_api_nodes.util import (

 class Text2ImageInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)


 class Image2ImageInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
     images: list[str] = Field(..., min_length=1, max_length=2)


 class Text2VideoInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
-    audio_url: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
+    audio_url: str | None = Field(None)


 class Image2VideoInputField(BaseModel):
     prompt: str = Field(...)
-    negative_prompt: Optional[str] = Field(None)
+    negative_prompt: str | None = Field(None)
     img_url: str = Field(...)
-    audio_url: Optional[str] = Field(None)
+    audio_url: str | None = Field(None)


 class Txt2ImageParametersField(BaseModel):
@@ -52,7 +50,7 @@ class Txt2ImageParametersField(BaseModel):


 class Image2ImageParametersField(BaseModel):
-    size: Optional[str] = Field(None)
+    size: str | None = Field(None)
     n: int = Field(1, description="Number of images to generate.") # we support only value=1
     seed: int = Field(..., ge=0, le=2147483647)
     watermark: bool = Field(True)
@@ -61,19 +59,21 @@ class Image2ImageParametersField(BaseModel):
 class Text2VideoParametersField(BaseModel):
     size: str = Field(...)
     seed: int = Field(..., ge=0, le=2147483647)
-    duration: int = Field(5, ge=5, le=10)
+    duration: int = Field(5, ge=5, le=15)
     prompt_extend: bool = Field(True)
     watermark: bool = Field(True)
-    audio: bool = Field(False, description="Should be audio generated automatically")
+    audio: bool = Field(False, description="Whether to generate audio automatically.")
+    shot_type: str = Field("single")


 class Image2VideoParametersField(BaseModel):
     resolution: str = Field(...)
     seed: int = Field(..., ge=0, le=2147483647)
-    duration: int = Field(5, ge=5, le=10)
+    duration: int = Field(5, ge=5, le=15)
     prompt_extend: bool = Field(True)
     watermark: bool = Field(True)
-    audio: bool = Field(False, description="Should be audio generated automatically")
+    audio: bool = Field(False, description="Whether to generate audio automatically.")
+    shot_type: str = Field("single")


 class Text2ImageTaskCreationRequest(BaseModel):
@@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel):


 class TaskCreationResponse(BaseModel):
-    output: Optional[TaskCreationOutputField] = Field(None)
+    output: TaskCreationOutputField | None = Field(None)
     request_id: str = Field(...)
-    code: Optional[str] = Field(None, description="The error code of the failed request.")
-    message: Optional[str] = Field(None, description="Details of the failed request.")
+    code: str | None = Field(None, description="Error code for the failed request.")
+    message: str | None = Field(None, description="Details about the failed request.")


 class TaskResult(BaseModel):
-    url: Optional[str] = Field(None)
-    code: Optional[str] = Field(None)
-    message: Optional[str] = Field(None)
+    url: str | None = Field(None)
+    code: str | None = Field(None)
+    message: str | None = Field(None)


 class ImageTaskStatusOutputField(TaskCreationOutputField):
     task_id: str = Field(...)
     task_status: str = Field(...)
-    results: Optional[list[TaskResult]] = Field(None)
+    results: list[TaskResult] | None = Field(None)


 class VideoTaskStatusOutputField(TaskCreationOutputField):
     task_id: str = Field(...)
     task_status: str = Field(...)
-    video_url: Optional[str] = Field(None)
-    code: Optional[str] = Field(None)
-    message: Optional[str] = Field(None)
+    video_url: str | None = Field(None)
+    code: str | None = Field(None)
+    message: str | None = Field(None)


 class ImageTaskStatusResponse(BaseModel):
-    output: Optional[ImageTaskStatusOutputField] = Field(None)
+    output: ImageTaskStatusOutputField | None = Field(None)
     request_id: str = Field(...)


 class VideoTaskStatusResponse(BaseModel):
-    output: Optional[VideoTaskStatusOutputField] = Field(None)
+    output: VideoTaskStatusOutputField | None = Field(None)
     request_id: str = Field(...)

@@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode):
             node_id="WanTextToImageApi",
             display_name="Wan Text to Image",
             category="api node/image/Wan",
-            description="Generates image based on text prompt.",
+            description="Generates an image based on a text prompt.",
             inputs=[
                 IO.Combo.Input(
                     "model",
@@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode):
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 IO.Int.Input(
@@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
                     optional=True,
                 ),
             ],
@@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode):
                 ),
             )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
             display_name="Wan Image to Image",
             category="api node/image/Wan",
             description="Generates an image from one or two input images and a text prompt. "
-            "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
+            "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
             inputs=[
                 IO.Combo.Input(
                     "model",
@@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode):
                 ),
                 IO.Image.Input(
                     "image",
-                    tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
+                    tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
                 ),
                 IO.String.Input(
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 # redo this later as an optional combo of recommended resolutions
@@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
                     optional=True,
                 ),
             ],
@@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode):
     async def execute(
         cls,
         model: str,
-        image: torch.Tensor,
+        image: Input.Image,
         prompt: str,
         negative_prompt: str = "",
         # width: int = 1024,
@@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode):
     ):
         n_images = get_number_of_images(image)
         if n_images not in (1, 2):
-            raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
+            raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
         images = []
         for i in image:
             images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
@@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
             ),
         )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode):
             node_id="WanTextToVideoApi",
             display_name="Wan Text to Video",
             category="api node/video/Wan",
-            description="Generates video based on text prompt.",
+            description="Generates a video based on a text prompt.",
             inputs=[
                 IO.Combo.Input(
                     "model",
-                    options=["wan2.5-t2v-preview"],
-                    default="wan2.5-t2v-preview",
+                    options=["wan2.5-t2v-preview", "wan2.6-t2v"],
+                    default="wan2.6-t2v",
                     tooltip="Model to use.",
                 ),
                 IO.String.Input(
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 IO.Combo.Input(
@@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode):
                     "1080p: 4:3 (1632x1248)",
                     "1080p: 3:4 (1248x1632)",
                     ],
-                    default="480p: 1:1 (624x624)",
+                    default="720p: 1:1 (960x960)",
                     optional=True,
                 ),
                 IO.Int.Input(
                     "duration",
                     default=5,
                     min=5,
-                    max=10,
+                    max=15,
                     step=5,
                     display_mode=IO.NumberDisplay.number,
-                    tooltip="Available durations: 5 and 10 seconds",
+                    tooltip="A 15-second duration is available only for the Wan 2.6 model.",
                     optional=True,
                 ),
                 IO.Audio.Input(
                     "audio",
                     optional=True,
-                    tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
+                    tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
                 ),
                 IO.Int.Input(
                     "seed",
@@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode):
                     "generate_audio",
                     default=False,
                     optional=True,
-                    tooltip="If there is no audio input, generate audio automatically.",
+                    tooltip="If no audio input is provided, generate audio automatically.",
                 ),
                 IO.Boolean.Input(
                     "prompt_extend",
@@ -477,7 +477,15 @@ class WanTextToVideoApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
+                    optional=True,
+                ),
+                IO.Combo.Input(
+                    "shot_type",
+                    options=["single", "multi"],
+                    tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
+                    "single continuous shot or multiple shots with cuts. "
+                    "This parameter takes effect only when prompt_extend is True.",
                     optional=True,
                 ),
             ],
@@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode):
         model: str,
         prompt: str,
         negative_prompt: str = "",
-        size: str = "480p: 1:1 (624x624)",
+        size: str = "720p: 1:1 (960x960)",
         duration: int = 5,
-        audio: Optional[Input.Audio] = None,
+        audio: Input.Audio | None = None,
         seed: int = 0,
         generate_audio: bool = False,
         prompt_extend: bool = True,
         watermark: bool = True,
+        shot_type: str = "single",
     ):
+        if "480p" in size and model == "wan2.6-t2v":
+            raise ValueError("The Wan 2.6 model does not support 480p.")
+        if duration == 15 and model == "wan2.5-t2v-preview":
+            raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
         width, height = RES_IN_PARENS.search(size).groups()
         audio_url = None
         if audio is not None:
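Note: the size combo options carry the resolution in parentheses, and execute parses it back out with RES_IN_PARENS. The exact regex is not shown in this hunk; the sketch below assumes an equivalent pattern:

import re

RES_IN_PARENS = re.compile(r"\((\d+)x(\d+)\)")  # assumed equivalent of the module-level regex
width, height = RES_IN_PARENS.search("720p: 1:1 (960x960)").groups()
print(width, height)  # 960 960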
@@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode):
                     audio=generate_audio,
                     prompt_extend=prompt_extend,
                     watermark=watermark,
+                    shot_type=shot_type,
                 ),
             ),
         )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode):
             node_id="WanImageToVideoApi",
             display_name="Wan Image to Video",
             category="api node/video/Wan",
-            description="Generates video based on the first frame and text prompt.",
+            description="Generates a video from the first frame and a text prompt.",
             inputs=[
                 IO.Combo.Input(
                     "model",
-                    options=["wan2.5-i2v-preview"],
-                    default="wan2.5-i2v-preview",
+                    options=["wan2.5-i2v-preview", "wan2.6-i2v"],
+                    default="wan2.6-i2v",
                     tooltip="Model to use.",
                 ),
                 IO.Image.Input(
@@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode):
                     "prompt",
                     multiline=True,
                     default="",
-                    tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
+                    tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
                 ),
                 IO.String.Input(
                     "negative_prompt",
                     multiline=True,
                     default="",
-                    tooltip="Negative text prompt to guide what to avoid.",
+                    tooltip="Negative prompt describing what to avoid.",
                     optional=True,
                 ),
                 IO.Combo.Input(
@@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode):
                     "720P",
                     "1080P",
                     ],
-                    default="480P",
+                    default="720P",
                     optional=True,
                 ),
                 IO.Int.Input(
                     "duration",
                     default=5,
                     min=5,
-                    max=10,
+                    max=15,
                     step=5,
                     display_mode=IO.NumberDisplay.number,
-                    tooltip="Available durations: 5 and 10 seconds",
+                    tooltip="Duration 15 available only for WAN2.6 model.",
                     optional=True,
                 ),
                 IO.Audio.Input(
                     "audio",
                     optional=True,
-                    tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
+                    tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
                 ),
                 IO.Int.Input(
                     "seed",
@@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode):
                     "generate_audio",
                     default=False,
                     optional=True,
-                    tooltip="If there is no audio input, generate audio automatically.",
+                    tooltip="If no audio input is provided, generate audio automatically.",
                 ),
                 IO.Boolean.Input(
                     "prompt_extend",
@@ -624,7 +638,15 @@ class WanImageToVideoApi(IO.ComfyNode):
                 IO.Boolean.Input(
                     "watermark",
                     default=True,
-                    tooltip='Whether to add an "AI generated" watermark to the result.',
+                    tooltip="Whether to add an AI-generated watermark to the result.",
+                    optional=True,
+                ),
+                IO.Combo.Input(
+                    "shot_type",
+                    options=["single", "multi"],
+                    tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
+                    "single continuous shot or multiple shots with cuts. "
+                    "This parameter takes effect only when prompt_extend is True.",
                     optional=True,
                 ),
             ],
@@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode):
     async def execute(
         cls,
         model: str,
-        image: torch.Tensor,
+        image: Input.Image,
         prompt: str,
         negative_prompt: str = "",
-        resolution: str = "480P",
+        resolution: str = "720P",
         duration: int = 5,
-        audio: Optional[Input.Audio] = None,
+        audio: Input.Audio | None = None,
         seed: int = 0,
         generate_audio: bool = False,
         prompt_extend: bool = True,
         watermark: bool = True,
+        shot_type: str = "single",
     ):
         if get_number_of_images(image) != 1:
             raise ValueError("Exactly one input image is required.")
+        if "480P" in resolution and model == "wan2.6-i2v":
+            raise ValueError("The Wan 2.6 model does not support 480P.")
+        if duration == 15 and model == "wan2.5-i2v-preview":
+            raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
         image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
         audio_url = None
         if audio is not None:
@@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode):
                     audio=generate_audio,
                     prompt_extend=prompt_extend,
                     watermark=watermark,
+                    shot_type=shot_type,
                 ),
             ),
         )
         if not initial_response.output:
-            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+            raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
         response = await poll_op(
             cls,
             ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
@@ -129,7 +129,7 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
     return img_byte_arr


-def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
+def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
     """Downscale input image tensor to roughly the specified total pixels."""
     samples = image.movedim(-1, 1)
     total = int(total_pixels)
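Note: the GPT Image node now passes an explicit pixel budget (total_pixels=2048*2048) instead of relying on the 1536*1024 default shown here. A rough, hedged stand-in for what such a budgeted downscale does (not the helper's actual implementation):

import torch
import torch.nn.functional as F

def downscale_to_budget(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor:
    # image is laid out [B, H, W, C] as in ComfyUI; only ever downscale.
    b, h, w, c = image.shape
    scale = (total_pixels / (h * w)) ** 0.5
    if scale >= 1.0:
        return image
    samples = image.movedim(-1, 1)
    samples = F.interpolate(samples, size=(int(h * scale), int(w * scale)), mode="area")
    return samples.movedim(1, -1)

img = torch.rand(1, 3000, 3000, 3)
print(downscale_to_budget(img, total_pixels=2048 * 2048).shape)  # torch.Size([1, 2048, 2048, 3])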
@@ -671,7 +671,16 @@ class SamplerSEEDS2(io.ComfyNode):
                 io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
                 io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
             ],
-            outputs=[io.Sampler.Output()]
+            outputs=[io.Sampler.Output()],
+            description=(
+                "This sampler node can represent multiple samplers:\n\n"
+                "seeds_2\n"
+                "- default setting\n\n"
+                "exp_heun_2_x0\n"
+                "- solver_type=phi_2, r=1.0, eta=0.0\n\n"
+                "exp_heun_2_x0_sde\n"
+                "- solver_type=phi_2, r=1.0, eta=1.0, s_noise=1.0"
+            )
         )

     @classmethod
@ -313,22 +313,46 @@ class ZImageControlPatch:
         self.inpaint_image = inpaint_image
         self.mask = mask
         self.strength = strength
-        self.encoded_image = self.encode_latent_cond(image)
-        self.encoded_image_size = (image.shape[1], image.shape[2])
+        self.is_inpaint = self.model_patch.model.additional_in_dim > 0
+
+        skip_encoding = False
+        if self.image is not None and self.inpaint_image is not None:
+            if self.image.shape != self.inpaint_image.shape:
+                skip_encoding = True
+
+        if skip_encoding:
+            self.encoded_image = None
+        else:
+            self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
+        if self.image is None:
+            self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
+        else:
+            self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
         self.temp_data = None

-    def encode_latent_cond(self, control_image, inpaint_image=None):
-        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
-        if self.model_patch.model.additional_in_dim > 0:
-            if self.mask is None:
-                mask_ = torch.zeros_like(latent_image)[:, :1]
-            else:
-                mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
+    def encode_latent_cond(self, control_image=None, inpaint_image=None):
+        latent_image = None
+        if control_image is not None:
+            latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
+
+        if self.is_inpaint:
             if inpaint_image is None:
                 inpaint_image = torch.ones_like(control_image) * 0.5

+            if self.mask is not None:
+                mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
+                inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
+
             inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))

+            if self.mask is None:
+                mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
+            else:
+                mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
+
+            if latent_image is None:
+                latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
+
             return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
         else:
             return latent_image
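For readers tracing the new inpaint path: whenever the patched model expects extra input channels (self.is_inpaint), the reworked encode_latent_cond() returns a single channel-concatenated conditioning tensor. A brief sketch of that layout, inferred from the hunk above (illustrative, not an excerpt of the commit):

# Assumed layout of the conditioning returned above (concatenated on dim=1):
#   latent_image         - VAE latent of the control image, or of a flat 0.5 grey frame when none is given
#   mask_                - one-channel mask resized to latent resolution (zeros when no mask is set)
#   inpaint_image_latent - VAE latent of the inpaint reference; pixels where the rounded mask is 0 are replaced by 0.5 grey
cond = torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)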
@ -344,13 +368,18 @@ class ZImageControlPatch:
         block_type = kwargs.get("block_type", "")
         spacial_compression = self.vae.spacial_compression_encode()
         if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
-            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            image_scaled = None
+            if self.image is not None:
+                image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
+                self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
+
             inpaint_scaled = None
             if self.inpaint_image is not None:
                 inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
+                self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
+
             loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
-            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
-            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+            self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
             comfy.model_management.load_models_gpu(loaded_models)

         cnet_blocks = self.model_patch.model.n_control_layers
@ -391,7 +420,8 @@ class ZImageControlPatch:

     def to(self, device_or_dtype):
         if isinstance(device_or_dtype, torch.device):
-            self.encoded_image = self.encoded_image.to(device_or_dtype)
+            if self.encoded_image is not None:
+                self.encoded_image = self.encoded_image.to(device_or_dtype)
             self.temp_data = None
         return self

@ -414,9 +444,12 @@ class QwenImageDiffsynthControlnet:

     CATEGORY = "advanced/loaders/qwen"

-    def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
+    def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
         model_patched = model.clone()
-        image = image[:, :, :, :3]
+        if image is not None:
+            image = image[:, :, :, :3]
+        if inpaint_image is not None:
+            inpaint_image = inpaint_image[:, :, :, :3]
         if mask is not None:
             if mask.ndim == 3:
                 mask = mask.unsqueeze(1)
@ -425,13 +458,24 @@ class QwenImageDiffsynthControlnet:
             mask = 1.0 - mask

         if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
-            patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
+            patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
             model_patched.set_model_noise_refiner_patch(patch)
             model_patched.set_model_double_block_patch(patch)
         else:
             model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)

+class ZImageFunControlnet(QwenImageDiffsynthControlnet):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "model_patch": ("MODEL_PATCH",),
+                              "vae": ("VAE",),
+                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              },
+                "optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
+
+    CATEGORY = "advanced/loaders/zimage"

 class UsoStyleProjectorPatch:
     def __init__(self, model_patch, encoded_image):
@ -479,5 +523,6 @@ class USOStyleReference:
 NODE_CLASS_MAPPINGS = {
     "ModelPatchLoader": ModelPatchLoader,
     "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
+    "ZImageFunControlnet": ZImageFunControlnet,
     "USOStyleReference": USOStyleReference,
 }
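To make the newly registered node concrete, here is a hedged wiring sketch; it is not part of the diff and the variable names are placeholders. As the hunks above show, ZImageFunControlnet overrides only INPUT_TYPES and CATEGORY, so the inherited diffsynth_controlnet() seen earlier does the work, now with image, inpaint_image and mask all optional:

# Hypothetical usage, not in the commit; model/model_patch/vae/inpaint_image/mask are placeholders.
node = ZImageFunControlnet()
(patched_model,) = node.diffsynth_controlnet(
    model=model,                  # MODEL from a checkpoint loader
    model_patch=model_patch,      # MODEL_PATCH from ModelPatchLoader
    vae=vae,                      # VAE used to encode the control / inpaint images
    strength=1.0,
    inpaint_image=inpaint_image,  # optional IMAGE enabling the new inpaint path
    mask=mask,                    # optional MASK
)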
@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.4.0"
+__version__ = "0.5.0"
@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.4.0"
+version = "0.5.0"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.9"
@ -1,4 +1,4 @@
-comfyui-frontend-package==1.34.8
+comfyui-frontend-package==1.34.9
 comfyui-workflow-templates==0.7.59
 comfyui-embedded-docs==0.3.1
 torch