mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-24 01:12:37 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
8f5fe5dd7d
1
.github/workflows/test-ci.yml
vendored
1
.github/workflows/test-ci.yml
vendored
@ -5,6 +5,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
|
- release/**
|
||||||
paths-ignore:
|
paths-ignore:
|
||||||
- 'app/**'
|
- 'app/**'
|
||||||
- 'input/**'
|
- 'input/**'
|
||||||
|
|||||||
4
.github/workflows/test-execution.yml
vendored
4
.github/workflows/test-execution.yml
vendored
@ -2,9 +2,9 @@ name: Execution Tests
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main, master ]
|
branches: [ main, master, release/** ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main, master ]
|
branches: [ main, master, release/** ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
|||||||
4
.github/workflows/test-launch.yml
vendored
4
.github/workflows/test-launch.yml
vendored
@ -2,9 +2,9 @@ name: Test server launches without errors
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main, master ]
|
branches: [ main, master, release/** ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main, master ]
|
branches: [ main, master, release/** ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
|||||||
4
.github/workflows/test-unit.yml
vendored
4
.github/workflows/test-unit.yml
vendored
@ -2,9 +2,9 @@ name: Unit Tests
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ main, master ]
|
branches: [ main, master, release/** ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main, master ]
|
branches: [ main, master, release/** ]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
|||||||
1
.github/workflows/update-version.yml
vendored
1
.github/workflows/update-version.yml
vendored
@ -6,6 +6,7 @@ on:
|
|||||||
- "pyproject.toml"
|
- "pyproject.toml"
|
||||||
branches:
|
branches:
|
||||||
- master
|
- master
|
||||||
|
- release/**
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
update-version:
|
update-version:
|
||||||
|
|||||||
@ -634,8 +634,11 @@ class NextDiT(nn.Module):
|
|||||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||||
freqs_cis = freqs_cis.to(img.device)
|
freqs_cis = freqs_cis.to(img.device)
|
||||||
|
|
||||||
|
transformer_options["total_blocks"] = len(self.layers)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
img_input = img
|
img_input = img
|
||||||
for i, layer in enumerate(self.layers):
|
for i, layer in enumerate(self.layers):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||||
if "double_block" in patches:
|
if "double_block" in patches:
|
||||||
for p in patches["double_block"]:
|
for p in patches["double_block"]:
|
||||||
|
|||||||
@ -322,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
pooled_projection_dim: int = 768,
|
pooled_projection_dim: int = 768,
|
||||||
guidance_embeds: bool = False,
|
guidance_embeds: bool = False,
|
||||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||||
|
default_ref_method="index",
|
||||||
image_model=None,
|
image_model=None,
|
||||||
final_layer=True,
|
final_layer=True,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
@ -334,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels or in_channels
|
self.out_channels = out_channels or in_channels
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.default_ref_method = default_ref_method
|
||||||
|
|
||||||
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||||
|
|
||||||
@ -361,6 +363,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
for _ in range(num_layers)
|
for _ in range(num_layers)
|
||||||
])
|
])
|
||||||
|
|
||||||
|
if self.default_ref_method == "index_timestep_zero":
|
||||||
|
self.register_buffer("__index_timestep_zero__", torch.tensor([]))
|
||||||
|
|
||||||
if final_layer:
|
if final_layer:
|
||||||
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
|
||||||
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
|
||||||
@ -416,7 +421,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
h = 0
|
h = 0
|
||||||
w = 0
|
w = 0
|
||||||
index = 0
|
index = 0
|
||||||
ref_method = kwargs.get("ref_latents_method", "index")
|
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||||
timestep_zero = ref_method == "index_timestep_zero"
|
timestep_zero = ref_method == "index_timestep_zero"
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
|
|||||||
@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
|
|||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
|
|||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
|
|||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
|
|||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
|
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
|
||||||
return out
|
return out
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(x, e=e0, freqs=freqs, context=context)
|
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
|
||||||
if audio_emb is not None:
|
if audio_emb is not None:
|
||||||
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
|
||||||
# head
|
# head
|
||||||
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
|
|||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
|
|||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
transformer_options["total_blocks"] = len(self.blocks)
|
||||||
|
transformer_options["block_type"] = "double"
|
||||||
for i, block in enumerate(self.blocks):
|
for i, block in enumerate(self.blocks):
|
||||||
|
transformer_options["block_index"] = i
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
@ -259,7 +259,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_tile_size"] = 512
|
dit_config["nerf_tile_size"] = 512
|
||||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||||
if "__x0__" in state_dict_keys: # x0 pred
|
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||||
dit_config["use_x0"] = True
|
dit_config["use_x0"] = True
|
||||||
else:
|
else:
|
||||||
dit_config["use_x0"] = False
|
dit_config["use_x0"] = False
|
||||||
@ -618,6 +618,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["image_model"] = "qwen_image"
|
dit_config["image_model"] = "qwen_image"
|
||||||
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
dit_config["in_channels"] = state_dict['{}img_in.weight'.format(key_prefix)].shape[1]
|
||||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
|
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||||
|
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||||
|
|||||||
@ -28,6 +28,7 @@ from . import supported_models_base
|
|||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
|
|
||||||
from . import diffusers_convert
|
from . import diffusers_convert
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
class SD15(supported_models_base.BASE):
|
class SD15(supported_models_base.BASE):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
@ -1028,7 +1029,13 @@ class ZImage(Lumina2):
|
|||||||
|
|
||||||
memory_usage_factor = 2.0
|
memory_usage_factor = 2.0
|
||||||
|
|
||||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
if comfy.model_management.extended_fp16_support():
|
||||||
|
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
|
||||||
|
self.supported_inference_dtypes.insert(1, torch.float16)
|
||||||
|
|
||||||
def clip_target(self, state_dict={}):
|
def clip_target(self, state_dict={}):
|
||||||
pref = self.text_encoder_key_prefix[0]
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
|||||||
@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
|
|||||||
from pydantic import BaseModel, Field, RootModel
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
|
||||||
class TripoModelVersion(str, Enum):
|
class TripoModelVersion(str, Enum):
|
||||||
|
v3_0_20250812 = 'v3.0-20250812'
|
||||||
v2_5_20250123 = 'v2.5-20250123'
|
v2_5_20250123 = 'v2.5-20250123'
|
||||||
v2_0_20240919 = 'v2.0-20240919'
|
v2_0_20240919 = 'v2.0-20240919'
|
||||||
v1_4_20240625 = 'v1.4-20240625'
|
v1_4_20240625 = 'v1.4-20240625'
|
||||||
|
|
||||||
|
|
||||||
|
class TripoGeometryQuality(str, Enum):
|
||||||
|
standard = 'standard'
|
||||||
|
detailed = 'detailed'
|
||||||
|
|
||||||
|
|
||||||
class TripoTextureQuality(str, Enum):
|
class TripoTextureQuality(str, Enum):
|
||||||
standard = 'standard'
|
standard = 'standard'
|
||||||
detailed = 'detailed'
|
detailed = 'detailed'
|
||||||
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
|
|||||||
class TripoAnimation(str, Enum):
|
class TripoAnimation(str, Enum):
|
||||||
IDLE = "preset:idle"
|
IDLE = "preset:idle"
|
||||||
WALK = "preset:walk"
|
WALK = "preset:walk"
|
||||||
|
RUN = "preset:run"
|
||||||
|
DIVE = "preset:dive"
|
||||||
CLIMB = "preset:climb"
|
CLIMB = "preset:climb"
|
||||||
JUMP = "preset:jump"
|
JUMP = "preset:jump"
|
||||||
RUN = "preset:run"
|
|
||||||
SLASH = "preset:slash"
|
SLASH = "preset:slash"
|
||||||
SHOOT = "preset:shoot"
|
SHOOT = "preset:shoot"
|
||||||
HURT = "preset:hurt"
|
HURT = "preset:hurt"
|
||||||
FALL = "preset:fall"
|
FALL = "preset:fall"
|
||||||
TURN = "preset:turn"
|
TURN = "preset:turn"
|
||||||
|
QUADRUPED_WALK = "preset:quadruped:walk"
|
||||||
|
HEXAPOD_WALK = "preset:hexapod:walk"
|
||||||
|
OCTOPOD_WALK = "preset:octopod:walk"
|
||||||
|
SERPENTINE_MARCH = "preset:serpentine:march"
|
||||||
|
AQUATIC_MARCH = "preset:aquatic:march"
|
||||||
|
|
||||||
class TripoStylizeStyle(str, Enum):
|
class TripoStylizeStyle(str, Enum):
|
||||||
LEGO = "lego"
|
LEGO = "lego"
|
||||||
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
|
|||||||
BANNED = "banned"
|
BANNED = "banned"
|
||||||
EXPIRED = "expired"
|
EXPIRED = "expired"
|
||||||
|
|
||||||
|
class TripoFbxPreset(str, Enum):
|
||||||
|
BLENDER = "blender"
|
||||||
|
MIXAMO = "mixamo"
|
||||||
|
_3DSMAX = "3dsmax"
|
||||||
|
|
||||||
class TripoFileTokenReference(BaseModel):
|
class TripoFileTokenReference(BaseModel):
|
||||||
type: Optional[str] = Field(None, description='The type of the reference')
|
type: Optional[str] = Field(None, description='The type of the reference')
|
||||||
file_token: str
|
file_token: str
|
||||||
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
|
|||||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||||
|
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||||
style: Optional[TripoStyle] = None
|
style: Optional[TripoStyle] = None
|
||||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||||
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
|
|||||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||||
|
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||||
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
||||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||||
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
|
|||||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||||
|
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||||
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||||
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
||||||
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
|
|||||||
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
||||||
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
||||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
|
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
|
||||||
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
|
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
|
||||||
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
|
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
|
||||||
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
|
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
|
||||||
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
|
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
|
||||||
texture_size: Optional[int] = Field(4096, description='The size of the texture')
|
texture_size: Optional[int] = Field(None, description='The size of the texture')
|
||||||
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
||||||
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
|
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
|
||||||
|
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
|
||||||
|
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
|
||||||
|
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
|
||||||
|
bake: Optional[bool] = Field(None, description='Whether to bake the model')
|
||||||
|
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
|
||||||
|
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
|
||||||
|
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
|
||||||
|
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
|
||||||
|
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
|
||||||
|
|
||||||
|
|
||||||
class TripoTaskRequest(RootModel):
|
class TripoTaskRequest(RootModel):
|
||||||
root: Union[
|
root: Union[
|
||||||
|
|||||||
@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
IO.Int.Input("model_seed", default=42, optional=True),
|
IO.Int.Input("model_seed", default=42, optional=True),
|
||||||
IO.Int.Input("texture_seed", default=42, optional=True),
|
IO.Int.Input("texture_seed", default=42, optional=True),
|
||||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
|
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
model_seed: Optional[int] = None,
|
model_seed: Optional[int] = None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
|
geometry_quality: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
|||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit,
|
||||||
|
geometry_quality=geometry_quality,
|
||||||
auto_size=True,
|
auto_size=True,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
),
|
),
|
||||||
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
|
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
orientation=None,
|
orientation=None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
|
geometry_quality: Optional[str] = None,
|
||||||
texture_alignment: Optional[str] = None,
|
texture_alignment: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
|||||||
pbr=pbr,
|
pbr=pbr,
|
||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
orientation=orientation,
|
orientation=orientation,
|
||||||
|
geometry_quality=geometry_quality,
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
|
||||||
IO.Boolean.Input("quad", default=False, optional=True),
|
IO.Boolean.Input("quad", default=False, optional=True),
|
||||||
|
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.String.Output(display_name="model_file"),
|
IO.String.Output(display_name="model_file"),
|
||||||
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
model_seed: Optional[int] = None,
|
model_seed: Optional[int] = None,
|
||||||
texture_seed: Optional[int] = None,
|
texture_seed: Optional[int] = None,
|
||||||
texture_quality: Optional[str] = None,
|
texture_quality: Optional[str] = None,
|
||||||
|
geometry_quality: Optional[str] = None,
|
||||||
texture_alignment: Optional[str] = None,
|
texture_alignment: Optional[str] = None,
|
||||||
face_limit: Optional[int] = None,
|
face_limit: Optional[int] = None,
|
||||||
quad: Optional[bool] = None,
|
quad: Optional[bool] = None,
|
||||||
@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
|||||||
model_seed=model_seed,
|
model_seed=model_seed,
|
||||||
texture_seed=texture_seed,
|
texture_seed=texture_seed,
|
||||||
texture_quality=texture_quality,
|
texture_quality=texture_quality,
|
||||||
|
geometry_quality=geometry_quality,
|
||||||
texture_alignment=texture_alignment,
|
texture_alignment=texture_alignment,
|
||||||
face_limit=face_limit,
|
face_limit=face_limit,
|
||||||
quad=quad,
|
quad=quad,
|
||||||
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
|
|||||||
options=[
|
options=[
|
||||||
"preset:idle",
|
"preset:idle",
|
||||||
"preset:walk",
|
"preset:walk",
|
||||||
|
"preset:run",
|
||||||
|
"preset:dive",
|
||||||
"preset:climb",
|
"preset:climb",
|
||||||
"preset:jump",
|
"preset:jump",
|
||||||
"preset:slash",
|
"preset:slash",
|
||||||
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
|
|||||||
"preset:hurt",
|
"preset:hurt",
|
||||||
"preset:fall",
|
"preset:fall",
|
||||||
"preset:turn",
|
"preset:turn",
|
||||||
|
"preset:quadruped:walk",
|
||||||
|
"preset:hexapod:walk",
|
||||||
|
"preset:octopod:walk",
|
||||||
|
"preset:serpentine:march",
|
||||||
|
"preset:aquatic:march"
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
"face_limit",
|
"face_limit",
|
||||||
default=-1,
|
default=-1,
|
||||||
min=-1,
|
min=-1,
|
||||||
max=500000,
|
max=2000000,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
default="JPEG",
|
default="JPEG",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Boolean.Input("force_symmetry", default=False, optional=True),
|
||||||
|
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
|
||||||
|
IO.Float.Input(
|
||||||
|
"flatten_bottom_threshold",
|
||||||
|
default=0.0,
|
||||||
|
min=0.0,
|
||||||
|
max=1.0,
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
|
||||||
|
IO.Float.Input(
|
||||||
|
"scale_factor",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("with_animation", default=False, optional=True),
|
||||||
|
IO.Boolean.Input("pack_uv", default=False, optional=True),
|
||||||
|
IO.Boolean.Input("bake", default=False, optional=True),
|
||||||
|
IO.String.Input("part_names", default="", optional=True), # comma-separated list
|
||||||
|
IO.Combo.Input(
|
||||||
|
"fbx_preset",
|
||||||
|
options=["blender", "mixamo", "3dsmax"],
|
||||||
|
default="blender",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"export_orientation",
|
||||||
|
options=["align_image", "default"],
|
||||||
|
default="default",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Boolean.Input("animate_in_place", default=False, optional=True),
|
||||||
],
|
],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
hidden=[
|
hidden=[
|
||||||
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
original_model_task_id,
|
original_model_task_id,
|
||||||
format: str,
|
format: str,
|
||||||
quad: bool,
|
quad: bool,
|
||||||
|
force_symmetry: bool,
|
||||||
face_limit: int,
|
face_limit: int,
|
||||||
|
flatten_bottom: bool,
|
||||||
|
flatten_bottom_threshold: float,
|
||||||
texture_size: int,
|
texture_size: int,
|
||||||
texture_format: str,
|
texture_format: str,
|
||||||
|
pivot_to_center_bottom: bool,
|
||||||
|
scale_factor: float,
|
||||||
|
with_animation: bool,
|
||||||
|
pack_uv: bool,
|
||||||
|
bake: bool,
|
||||||
|
part_names: str,
|
||||||
|
fbx_preset: str,
|
||||||
|
export_vertex_colors: bool,
|
||||||
|
export_orientation: str,
|
||||||
|
animate_in_place: bool,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if not original_model_task_id:
|
if not original_model_task_id:
|
||||||
raise RuntimeError("original_model_task_id is required")
|
raise RuntimeError("original_model_task_id is required")
|
||||||
|
|
||||||
|
# Parse part_names from comma-separated string to list
|
||||||
|
part_names_list = None
|
||||||
|
if part_names and part_names.strip():
|
||||||
|
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
|
||||||
|
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||||
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
|
|||||||
original_model_task_id=original_model_task_id,
|
original_model_task_id=original_model_task_id,
|
||||||
format=format,
|
format=format,
|
||||||
quad=quad if quad else None,
|
quad=quad if quad else None,
|
||||||
|
force_symmetry=force_symmetry if force_symmetry else None,
|
||||||
face_limit=face_limit if face_limit != -1 else None,
|
face_limit=face_limit if face_limit != -1 else None,
|
||||||
|
flatten_bottom=flatten_bottom if flatten_bottom else None,
|
||||||
|
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
|
||||||
texture_size=texture_size if texture_size != 4096 else None,
|
texture_size=texture_size if texture_size != 4096 else None,
|
||||||
texture_format=texture_format if texture_format != "JPEG" else None,
|
texture_format=texture_format if texture_format != "JPEG" else None,
|
||||||
|
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
|
||||||
|
scale_factor=scale_factor if scale_factor != 1.0 else None,
|
||||||
|
with_animation=with_animation if with_animation else None,
|
||||||
|
pack_uv=pack_uv if pack_uv else None,
|
||||||
|
bake=bake if bake else None,
|
||||||
|
part_names=part_names_list,
|
||||||
|
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
|
||||||
|
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
|
||||||
|
export_orientation=export_orientation if export_orientation != "default" else None,
|
||||||
|
animate_in_place=animate_in_place if animate_in_place else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return await poll_until_finished(cls, response, average_duration=30)
|
return await poll_until_finished(cls, response, average_duration=30)
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
@ -21,26 +19,26 @@ from comfy_api_nodes.util import (
|
|||||||
|
|
||||||
class Text2ImageInputField(BaseModel):
|
class Text2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Image2ImageInputField(BaseModel):
|
class Image2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
images: list[str] = Field(..., min_length=1, max_length=2)
|
images: list[str] = Field(..., min_length=1, max_length=2)
|
||||||
|
|
||||||
|
|
||||||
class Text2VideoInputField(BaseModel):
|
class Text2VideoInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
audio_url: Optional[str] = Field(None)
|
audio_url: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Image2VideoInputField(BaseModel):
|
class Image2VideoInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: str | None = Field(None)
|
||||||
img_url: str = Field(...)
|
img_url: str = Field(...)
|
||||||
audio_url: Optional[str] = Field(None)
|
audio_url: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class Txt2ImageParametersField(BaseModel):
|
class Txt2ImageParametersField(BaseModel):
|
||||||
@ -52,7 +50,7 @@ class Txt2ImageParametersField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Image2ImageParametersField(BaseModel):
|
class Image2ImageParametersField(BaseModel):
|
||||||
size: Optional[str] = Field(None)
|
size: str | None = Field(None)
|
||||||
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
n: int = Field(1, description="Number of images to generate.") # we support only value=1
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(True)
|
||||||
@ -61,19 +59,21 @@ class Image2ImageParametersField(BaseModel):
|
|||||||
class Text2VideoParametersField(BaseModel):
|
class Text2VideoParametersField(BaseModel):
|
||||||
size: str = Field(...)
|
size: str = Field(...)
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=10)
|
duration: int = Field(5, ge=5, le=15)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(True)
|
||||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||||
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
|
|
||||||
class Image2VideoParametersField(BaseModel):
|
class Image2VideoParametersField(BaseModel):
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
seed: int = Field(..., ge=0, le=2147483647)
|
seed: int = Field(..., ge=0, le=2147483647)
|
||||||
duration: int = Field(5, ge=5, le=10)
|
duration: int = Field(5, ge=5, le=15)
|
||||||
prompt_extend: bool = Field(True)
|
prompt_extend: bool = Field(True)
|
||||||
watermark: bool = Field(True)
|
watermark: bool = Field(True)
|
||||||
audio: bool = Field(False, description="Should be audio generated automatically")
|
audio: bool = Field(False, description="Whether to generate audio automatically.")
|
||||||
|
shot_type: str = Field("single")
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageTaskCreationRequest(BaseModel):
|
class Text2ImageTaskCreationRequest(BaseModel):
|
||||||
@ -106,39 +106,39 @@ class TaskCreationOutputField(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TaskCreationResponse(BaseModel):
|
class TaskCreationResponse(BaseModel):
|
||||||
output: Optional[TaskCreationOutputField] = Field(None)
|
output: TaskCreationOutputField | None = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
code: Optional[str] = Field(None, description="The error code of the failed request.")
|
code: str | None = Field(None, description="Error code for the failed request.")
|
||||||
message: Optional[str] = Field(None, description="Details of the failed request.")
|
message: str | None = Field(None, description="Details about the failed request.")
|
||||||
|
|
||||||
|
|
||||||
class TaskResult(BaseModel):
|
class TaskResult(BaseModel):
|
||||||
url: Optional[str] = Field(None)
|
url: str | None = Field(None)
|
||||||
code: Optional[str] = Field(None)
|
code: str | None = Field(None)
|
||||||
message: Optional[str] = Field(None)
|
message: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
class ImageTaskStatusOutputField(TaskCreationOutputField):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
results: Optional[list[TaskResult]] = Field(None)
|
results: list[TaskResult] | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
class VideoTaskStatusOutputField(TaskCreationOutputField):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
task_status: str = Field(...)
|
task_status: str = Field(...)
|
||||||
video_url: Optional[str] = Field(None)
|
video_url: str | None = Field(None)
|
||||||
code: Optional[str] = Field(None)
|
code: str | None = Field(None)
|
||||||
message: Optional[str] = Field(None)
|
message: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageTaskStatusResponse(BaseModel):
|
class ImageTaskStatusResponse(BaseModel):
|
||||||
output: Optional[ImageTaskStatusOutputField] = Field(None)
|
output: ImageTaskStatusOutputField | None = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class VideoTaskStatusResponse(BaseModel):
|
class VideoTaskStatusResponse(BaseModel):
|
||||||
output: Optional[VideoTaskStatusOutputField] = Field(None)
|
output: VideoTaskStatusOutputField | None = Field(None)
|
||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
@ -152,7 +152,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
node_id="WanTextToImageApi",
|
node_id="WanTextToImageApi",
|
||||||
display_name="Wan Text to Image",
|
display_name="Wan Text to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates image based on text prompt.",
|
description="Generates an image based on a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
@ -164,13 +164,13 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -209,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -252,7 +252,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -272,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
display_name="Wan Image to Image",
|
display_name="Wan Image to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates an image from one or two input images and a text prompt. "
|
description="Generates an image from one or two input images and a text prompt. "
|
||||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
@ -282,19 +282,19 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
tooltip="Single-image editing or multi-image fusion. Maximum 2 images.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
# redo this later as an optional combo of recommended resolutions
|
# redo this later as an optional combo of recommended resolutions
|
||||||
@ -328,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -347,7 +347,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
image: torch.Tensor,
|
image: Input.Image,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
# width: int = 1024,
|
# width: int = 1024,
|
||||||
@ -357,7 +357,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
):
|
):
|
||||||
n_images = get_number_of_images(image)
|
n_images = get_number_of_images(image)
|
||||||
if n_images not in (1, 2):
|
if n_images not in (1, 2):
|
||||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
raise ValueError(f"Expected 1 or 2 input images, but got {n_images}.")
|
||||||
images = []
|
images = []
|
||||||
for i in image:
|
for i in image:
|
||||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||||
@ -376,7 +376,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -395,25 +395,25 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
node_id="WanTextToVideoApi",
|
node_id="WanTextToVideoApi",
|
||||||
display_name="Wan Text to Video",
|
display_name="Wan Text to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates video based on text prompt.",
|
description="Generates a video based on a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-t2v-preview"],
|
options=["wan2.5-t2v-preview", "wan2.6-t2v"],
|
||||||
default="wan2.5-t2v-preview",
|
default="wan2.6-t2v",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -433,23 +433,23 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
"1080p: 4:3 (1632x1248)",
|
"1080p: 4:3 (1632x1248)",
|
||||||
"1080p: 3:4 (1248x1632)",
|
"1080p: 3:4 (1248x1632)",
|
||||||
],
|
],
|
||||||
default="480p: 1:1 (624x624)",
|
default="720p: 1:1 (960x960)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=10,
|
max=15,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=IO.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Available durations: 5 and 10 seconds",
|
tooltip="A 15-second duration is available only for the Wan 2.6 model.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -466,7 +466,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If there is no audio input, generate audio automatically.",
|
tooltip="If no audio input is provided, generate audio automatically.",
|
||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
@ -477,7 +477,15 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"shot_type",
|
||||||
|
options=["single", "multi"],
|
||||||
|
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||||
|
"single continuous shot or multiple shots with cuts. "
|
||||||
|
"This parameter takes effect only when prompt_extend is True.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -498,14 +506,19 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
model: str,
|
model: str,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
size: str = "480p: 1:1 (624x624)",
|
size: str = "720p: 1:1 (960x960)",
|
||||||
duration: int = 5,
|
duration: int = 5,
|
||||||
audio: Optional[Input.Audio] = None,
|
audio: Input.Audio | None = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = True,
|
||||||
|
shot_type: str = "single",
|
||||||
):
|
):
|
||||||
|
if "480p" in size and model == "wan2.6-t2v":
|
||||||
|
raise ValueError("The Wan 2.6 model does not support 480p.")
|
||||||
|
if duration == 15 and model == "wan2.5-t2v-preview":
|
||||||
|
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
||||||
width, height = RES_IN_PARENS.search(size).groups()
|
width, height = RES_IN_PARENS.search(size).groups()
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
@ -526,11 +539,12 @@ class WanTextToVideoApi(IO.ComfyNode):
|
|||||||
audio=generate_audio,
|
audio=generate_audio,
|
||||||
prompt_extend=prompt_extend,
|
prompt_extend=prompt_extend,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
shot_type=shot_type,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
@ -549,12 +563,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
node_id="WanImageToVideoApi",
|
node_id="WanImageToVideoApi",
|
||||||
display_name="Wan Image to Video",
|
display_name="Wan Image to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates video based on the first frame and text prompt.",
|
description="Generates a video from the first frame and a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-i2v-preview"],
|
options=["wan2.5-i2v-preview", "wan2.6-i2v"],
|
||||||
default="wan2.5-i2v-preview",
|
default="wan2.6-i2v",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
IO.Image.Input(
|
IO.Image.Input(
|
||||||
@ -564,13 +578,13 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt describing the elements and visual features. Supports English and Chinese.",
|
||||||
),
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative prompt describing what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
@ -580,23 +594,23 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"720P",
|
"720P",
|
||||||
"1080P",
|
"1080P",
|
||||||
],
|
],
|
||||||
default="480P",
|
default="720P",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=10,
|
max=15,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=IO.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Available durations: 5 and 10 seconds",
|
tooltip="Duration 15 available only for WAN2.6 model.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
IO.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise or background music.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
@ -613,7 +627,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If there is no audio input, generate audio automatically.",
|
tooltip="If no audio input is provided, generate audio automatically.",
|
||||||
),
|
),
|
||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
@ -624,7 +638,15 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
IO.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
tooltip="Whether to add an AI-generated watermark to the result.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"shot_type",
|
||||||
|
options=["single", "multi"],
|
||||||
|
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
|
||||||
|
"single continuous shot or multiple shots with cuts. "
|
||||||
|
"This parameter takes effect only when prompt_extend is True.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@ -643,19 +665,24 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
model: str,
|
model: str,
|
||||||
image: torch.Tensor,
|
image: Input.Image,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
resolution: str = "480P",
|
resolution: str = "720P",
|
||||||
duration: int = 5,
|
duration: int = 5,
|
||||||
audio: Optional[Input.Audio] = None,
|
audio: Input.Audio | None = None,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = True,
|
||||||
|
shot_type: str = "single",
|
||||||
):
|
):
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
raise ValueError("Exactly one input image is required.")
|
raise ValueError("Exactly one input image is required.")
|
||||||
|
if "480P" in resolution and model == "wan2.6-i2v":
|
||||||
|
raise ValueError("The Wan 2.6 model does not support 480P.")
|
||||||
|
if duration == 15 and model == "wan2.5-i2v-preview":
|
||||||
|
raise ValueError("A 15-second duration is supported only by the Wan 2.6 model.")
|
||||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
@ -677,11 +704,12 @@ class WanImageToVideoApi(IO.ComfyNode):
|
|||||||
audio=generate_audio,
|
audio=generate_audio,
|
||||||
prompt_extend=prompt_extend,
|
prompt_extend=prompt_extend,
|
||||||
watermark=watermark,
|
watermark=watermark,
|
||||||
|
shot_type=shot_type,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if not initial_response.output:
|
if not initial_response.output:
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
response = await poll_op(
|
response = await poll_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
|
|||||||
@ -248,7 +248,10 @@ class ModelPatchLoader:
|
|||||||
config['n_control_layers'] = 15
|
config['n_control_layers'] = 15
|
||||||
config['additional_in_dim'] = 17
|
config['additional_in_dim'] = 17
|
||||||
config['refiner_control'] = True
|
config['refiner_control'] = True
|
||||||
config['broken'] = True
|
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
|
||||||
|
if ref_weight is not None:
|
||||||
|
if torch.count_nonzero(ref_weight) == 0:
|
||||||
|
config['broken'] = True
|
||||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||||
|
|
||||||
model.load_state_dict(sd)
|
model.load_state_dict(sd)
|
||||||
@ -310,22 +313,46 @@ class ZImageControlPatch:
|
|||||||
self.inpaint_image = inpaint_image
|
self.inpaint_image = inpaint_image
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
self.strength = strength
|
self.strength = strength
|
||||||
self.encoded_image = self.encode_latent_cond(image)
|
self.is_inpaint = self.model_patch.model.additional_in_dim > 0
|
||||||
self.encoded_image_size = (image.shape[1], image.shape[2])
|
|
||||||
|
skip_encoding = False
|
||||||
|
if self.image is not None and self.inpaint_image is not None:
|
||||||
|
if self.image.shape != self.inpaint_image.shape:
|
||||||
|
skip_encoding = True
|
||||||
|
|
||||||
|
if skip_encoding:
|
||||||
|
self.encoded_image = None
|
||||||
|
else:
|
||||||
|
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
|
||||||
|
if self.image is None:
|
||||||
|
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
|
||||||
|
else:
|
||||||
|
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
|
|
||||||
def encode_latent_cond(self, control_image, inpaint_image=None):
|
def encode_latent_cond(self, control_image=None, inpaint_image=None):
|
||||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
|
latent_image = None
|
||||||
if self.model_patch.model.additional_in_dim > 0:
|
if control_image is not None:
|
||||||
if self.mask is None:
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
|
||||||
mask_ = torch.zeros_like(latent_image)[:, :1]
|
|
||||||
else:
|
if self.is_inpaint:
|
||||||
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
|
|
||||||
if inpaint_image is None:
|
if inpaint_image is None:
|
||||||
inpaint_image = torch.ones_like(control_image) * 0.5
|
inpaint_image = torch.ones_like(control_image) * 0.5
|
||||||
|
|
||||||
|
if self.mask is not None:
|
||||||
|
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
|
||||||
|
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
|
||||||
|
|
||||||
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
|
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
|
||||||
|
|
||||||
|
if self.mask is None:
|
||||||
|
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
|
||||||
|
else:
|
||||||
|
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
|
||||||
|
|
||||||
|
if latent_image is None:
|
||||||
|
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
|
||||||
|
|
||||||
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
|
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
|
||||||
else:
|
else:
|
||||||
return latent_image
|
return latent_image
|
||||||
@ -341,13 +368,18 @@ class ZImageControlPatch:
|
|||||||
block_type = kwargs.get("block_type", "")
|
block_type = kwargs.get("block_type", "")
|
||||||
spacial_compression = self.vae.spacial_compression_encode()
|
spacial_compression = self.vae.spacial_compression_encode()
|
||||||
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
image_scaled = None
|
||||||
|
if self.image is not None:
|
||||||
|
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
||||||
|
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
|
||||||
|
|
||||||
inpaint_scaled = None
|
inpaint_scaled = None
|
||||||
if self.inpaint_image is not None:
|
if self.inpaint_image is not None:
|
||||||
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
||||||
|
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
|
||||||
|
|
||||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||||
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
|
self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
|
||||||
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
|
||||||
comfy.model_management.load_models_gpu(loaded_models)
|
comfy.model_management.load_models_gpu(loaded_models)
|
||||||
|
|
||||||
cnet_blocks = self.model_patch.model.n_control_layers
|
cnet_blocks = self.model_patch.model.n_control_layers
|
||||||
@ -388,7 +420,8 @@ class ZImageControlPatch:
|
|||||||
|
|
||||||
def to(self, device_or_dtype):
|
def to(self, device_or_dtype):
|
||||||
if isinstance(device_or_dtype, torch.device):
|
if isinstance(device_or_dtype, torch.device):
|
||||||
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
if self.encoded_image is not None:
|
||||||
|
self.encoded_image = self.encoded_image.to(device_or_dtype)
|
||||||
self.temp_data = None
|
self.temp_data = None
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@ -411,9 +444,12 @@ class QwenImageDiffsynthControlnet:
|
|||||||
|
|
||||||
CATEGORY = "advanced/loaders/qwen"
|
CATEGORY = "advanced/loaders/qwen"
|
||||||
|
|
||||||
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
|
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
|
||||||
model_patched = model.clone()
|
model_patched = model.clone()
|
||||||
image = image[:, :, :, :3]
|
if image is not None:
|
||||||
|
image = image[:, :, :, :3]
|
||||||
|
if inpaint_image is not None:
|
||||||
|
inpaint_image = inpaint_image[:, :, :, :3]
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
if mask.ndim == 3:
|
if mask.ndim == 3:
|
||||||
mask = mask.unsqueeze(1)
|
mask = mask.unsqueeze(1)
|
||||||
@ -422,13 +458,24 @@ class QwenImageDiffsynthControlnet:
|
|||||||
mask = 1.0 - mask
|
mask = 1.0 - mask
|
||||||
|
|
||||||
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||||
patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
|
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
|
||||||
model_patched.set_model_noise_refiner_patch(patch)
|
model_patched.set_model_noise_refiner_patch(patch)
|
||||||
model_patched.set_model_double_block_patch(patch)
|
model_patched.set_model_double_block_patch(patch)
|
||||||
else:
|
else:
|
||||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||||
return (model_patched,)
|
return (model_patched,)
|
||||||
|
|
||||||
|
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
return {"required": { "model": ("MODEL",),
|
||||||
|
"model_patch": ("MODEL_PATCH",),
|
||||||
|
"vae": ("VAE",),
|
||||||
|
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||||
|
},
|
||||||
|
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
|
||||||
|
|
||||||
|
CATEGORY = "advanced/loaders/zimage"
|
||||||
|
|
||||||
class UsoStyleProjectorPatch:
|
class UsoStyleProjectorPatch:
|
||||||
def __init__(self, model_patch, encoded_image):
|
def __init__(self, model_patch, encoded_image):
|
||||||
@ -476,5 +523,6 @@ class USOStyleReference:
|
|||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelPatchLoader": ModelPatchLoader,
|
"ModelPatchLoader": ModelPatchLoader,
|
||||||
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
|
||||||
|
"ZImageFunControlnet": ZImageFunControlnet,
|
||||||
"USOStyleReference": USOStyleReference,
|
"USOStyleReference": USOStyleReference,
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.34.8
|
comfyui-frontend-package==1.34.9
|
||||||
comfyui-workflow-templates==0.7.59
|
comfyui-workflow-templates==0.7.59
|
||||||
comfyui-embedded-docs==0.3.1
|
comfyui-embedded-docs==0.3.1
|
||||||
torch
|
torch
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user