mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-04 00:37:32 +08:00
Merge branch 'master' into partition-advanced-widgets
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
commit
0358446ef0
36
.github/workflows/release-webhook.yml
vendored
36
.github/workflows/release-webhook.yml
vendored
@ -7,6 +7,8 @@ on:
|
||||
jobs:
|
||||
send-webhook:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
steps:
|
||||
- name: Send release webhook
|
||||
env:
|
||||
@ -106,3 +108,37 @@ jobs:
|
||||
--fail --silent --show-error
|
||||
|
||||
echo "✅ Release webhook sent successfully"
|
||||
|
||||
- name: Send repository dispatch to desktop
|
||||
env:
|
||||
DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
|
||||
RELEASE_TAG: ${{ github.event.release.tag_name }}
|
||||
RELEASE_URL: ${{ github.event.release.html_url }}
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
if [ -z "${DISPATCH_TOKEN:-}" ]; then
|
||||
echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PAYLOAD="$(jq -n \
|
||||
--arg release_tag "$RELEASE_TAG" \
|
||||
--arg release_url "$RELEASE_URL" \
|
||||
'{
|
||||
event_type: "comfyui_release_published",
|
||||
client_payload: {
|
||||
release_tag: $release_tag,
|
||||
release_url: $release_url
|
||||
}
|
||||
}')"
|
||||
|
||||
curl -fsSL \
|
||||
-X POST \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
|
||||
https://api.github.com/repos/Comfy-Org/desktop/dispatches \
|
||||
-d "$PAYLOAD"
|
||||
|
||||
echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,7 +11,7 @@ extra_model_paths.yaml
|
||||
/.vs
|
||||
.vscode/
|
||||
.idea/
|
||||
venv/
|
||||
venv*/
|
||||
.venv/
|
||||
/web/extensions/*
|
||||
!/web/extensions/logging.js.example
|
||||
|
||||
@ -227,7 +227,7 @@ Put your VAE in: models/vae
|
||||
|
||||
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
|
||||
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
|
||||
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
|
||||
|
||||
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
|
||||
|
||||
|
||||
105
app/node_replace_manager.py
Normal file
105
app/node_replace_manager.py
Normal file
@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
if TYPE_CHECKING:
|
||||
from comfy_api.latest._io_public import NodeReplace
|
||||
|
||||
from comfy_execution.graph_utils import is_link
|
||||
import nodes
|
||||
|
||||
class NodeStruct(TypedDict):
|
||||
inputs: dict[str, str | int | float | bool | tuple[str, int]]
|
||||
class_type: str
|
||||
_meta: dict[str, str]
|
||||
|
||||
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
|
||||
new_node_struct = node_struct.copy()
|
||||
if empty_inputs:
|
||||
new_node_struct["inputs"] = {}
|
||||
else:
|
||||
new_node_struct["inputs"] = node_struct["inputs"].copy()
|
||||
new_node_struct["_meta"] = node_struct["_meta"].copy()
|
||||
return new_node_struct
|
||||
|
||||
|
||||
class NodeReplaceManager:
|
||||
"""Manages node replacement registrations."""
|
||||
|
||||
def __init__(self):
|
||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||
|
||||
def register(self, node_replace: NodeReplace):
|
||||
"""Register a node replacement mapping."""
|
||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
||||
|
||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||
"""Get replacements for an old node ID."""
|
||||
return self._replacements.get(old_node_id)
|
||||
|
||||
def has_replacement(self, old_node_id: str) -> bool:
|
||||
"""Check if a replacement exists for an old node ID."""
|
||||
return old_node_id in self._replacements
|
||||
|
||||
def apply_replacements(self, prompt: dict[str, NodeStruct]):
|
||||
connections: dict[str, list[tuple[str, str, int]]] = {}
|
||||
need_replacement: set[str] = set()
|
||||
for node_number, node_struct in prompt.items():
|
||||
class_type = node_struct["class_type"]
|
||||
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
|
||||
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
|
||||
need_replacement.add(node_number)
|
||||
# keep track of connections
|
||||
for input_id, input_value in node_struct["inputs"].items():
|
||||
if is_link(input_value):
|
||||
conn_number = input_value[0]
|
||||
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
|
||||
for node_number in need_replacement:
|
||||
node_struct = prompt[node_number]
|
||||
class_type = node_struct["class_type"]
|
||||
replacements = self.get_replacement(class_type)
|
||||
if replacements is None:
|
||||
continue
|
||||
# just use the first replacement
|
||||
replacement = replacements[0]
|
||||
new_node_id = replacement.new_node_id
|
||||
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
|
||||
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
|
||||
continue
|
||||
# first, replace node id (class_type)
|
||||
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
|
||||
new_node_struct["class_type"] = new_node_id
|
||||
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
|
||||
# second, replace inputs
|
||||
if replacement.input_mapping is not None:
|
||||
for input_map in replacement.input_mapping:
|
||||
if "set_value" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
|
||||
elif "old_id" in input_map:
|
||||
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
|
||||
# finalize input replacement
|
||||
prompt[node_number] = new_node_struct
|
||||
# third, replace outputs
|
||||
if replacement.output_mapping is not None:
|
||||
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
|
||||
if node_number in connections:
|
||||
for conns in connections[node_number]:
|
||||
conn_node_number, conn_input_id, old_output_idx = conns
|
||||
for output_map in replacement.output_mapping:
|
||||
if output_map["old_idx"] == old_output_idx:
|
||||
new_output_idx = output_map["new_idx"]
|
||||
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
|
||||
previous_input[1] = new_output_idx
|
||||
|
||||
def as_dict(self):
|
||||
"""Serialize all replacements to dict."""
|
||||
return {
|
||||
k: [v.as_dict() for v in v_list]
|
||||
for k, v_list in self._replacements.items()
|
||||
}
|
||||
|
||||
def add_routes(self, routes):
|
||||
@routes.get("/node_replacements")
|
||||
async def get_node_replacements(request):
|
||||
return web.json_response(self.as_dict())
|
||||
@ -1,13 +0,0 @@
|
||||
import pickle
|
||||
|
||||
load = pickle.load
|
||||
|
||||
class Empty:
|
||||
pass
|
||||
|
||||
class Unpickler(pickle.Unpickler):
|
||||
def find_class(self, module, name):
|
||||
#TODO: safe unpickle
|
||||
if module.startswith("pytorch_lightning"):
|
||||
return Empty
|
||||
return super().find_class(module, name)
|
||||
@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
||||
self.model_sampling_current = None
|
||||
super().cleanup()
|
||||
|
||||
|
||||
class QwenFunControlNet(ControlNet):
|
||||
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||
# unchanged while >1 grows more gently.
|
||||
original_strength = self.strength
|
||||
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||
try:
|
||||
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||
finally:
|
||||
self.strength = original_strength
|
||||
|
||||
def pre_run(self, model, percent_to_timestep_function):
|
||||
super().pre_run(model, percent_to_timestep_function)
|
||||
self.set_extra_arg("base_model", model.diffusion_model)
|
||||
|
||||
def copy(self):
|
||||
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||
c.control_model = self.control_model
|
||||
c.control_model_wrapped = self.control_model_wrapped
|
||||
self.copy_to(c)
|
||||
return c
|
||||
|
||||
class ControlLoraOps:
|
||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||
@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
|
||||
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
|
||||
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
sd = model_config.process_unet_state_dict(sd)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
return control
|
||||
|
||||
|
||||
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||
|
||||
in_features = sd["control_img_in.weight"].shape[1]
|
||||
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||
|
||||
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||
|
||||
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||
control_in_features=in_features,
|
||||
inner_dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
operations=operations,
|
||||
device=comfy.model_management.unet_offload_device(),
|
||||
dtype=unet_dtype,
|
||||
)
|
||||
model = controlnet_load_state_dict(model, sd)
|
||||
|
||||
latent_format = comfy.latent_formats.Wan21()
|
||||
control = QwenFunControlNet(
|
||||
model,
|
||||
compression_ratio=1,
|
||||
latent_format=latent_format,
|
||||
# Fun checkpoints already expect their own 33-channel context handling.
|
||||
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||
# and breaks the intended fallback packing path.
|
||||
concat_mask=False,
|
||||
load_device=load_device,
|
||||
manual_cast_dtype=manual_cast_dtype,
|
||||
extra_conds=[],
|
||||
)
|
||||
return control
|
||||
|
||||
def convert_mistoline(sd):
|
||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||
|
||||
@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||
|
||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
import math
|
||||
import time
|
||||
from functools import partial
|
||||
|
||||
from scipy import integrate
|
||||
import torch
|
||||
from torch import nn
|
||||
import torchsde
|
||||
from tqdm.auto import trange as trange_, tqdm
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from . import utils
|
||||
from . import deis
|
||||
@ -15,34 +14,7 @@ import comfy.model_patcher
|
||||
import comfy.model_sampling
|
||||
|
||||
import comfy.memory_management
|
||||
|
||||
|
||||
def trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
return trange_(*args, **kwargs)
|
||||
|
||||
pbar = trange_(*args, **kwargs, smoothing=1.0)
|
||||
pbar._i = 0
|
||||
pbar.set_postfix_str(" Model Initializing ... ")
|
||||
|
||||
_update = pbar.update
|
||||
|
||||
def warmup_update(n=1):
|
||||
pbar._i += 1
|
||||
if pbar._i == 1:
|
||||
pbar.i1_time = time.time()
|
||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
||||
elif pbar._i == 2:
|
||||
#bring forward the effective start time based the the diff between first and second iteration
|
||||
#to attempt to remove load overhead from the final step rate estimate.
|
||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
||||
pbar.set_postfix_str("")
|
||||
|
||||
_update(n)
|
||||
|
||||
pbar.update = warmup_update
|
||||
return pbar
|
||||
|
||||
from comfy.utils import model_trange as trange
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
@ -1110,7 +1110,7 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
|
||||
return encoder_hidden, encoder_mask, context_latents
|
||||
|
||||
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, **kwargs):
|
||||
def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs):
|
||||
text_attention_mask = None
|
||||
lyric_attention_mask = None
|
||||
refer_audio_order_mask = None
|
||||
@ -1140,6 +1140,9 @@ class AceStepConditionGenerationModel(nn.Module):
|
||||
src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes
|
||||
)
|
||||
|
||||
if replace_with_null_embeds:
|
||||
enc_hidden[:] = self.null_condition_emb.to(enc_hidden)
|
||||
|
||||
out = self.decoder(hidden_states=x,
|
||||
timestep=timestep,
|
||||
timestep_r=timestep,
|
||||
|
||||
@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
|
||||
if source_attention_mask.ndim == 2:
|
||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
x = self.in_proj(self.embed(target_input_ids))
|
||||
context = source_hidden_states
|
||||
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
|
||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||
position_embeddings = self.rotary_emb(x, position_ids)
|
||||
@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
|
||||
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids):
|
||||
def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
|
||||
if text_ids is not None:
|
||||
return self.llm_adapter(text_embeds, text_ids)
|
||||
out = self.llm_adapter(text_embeds, text_ids)
|
||||
if t5xxl_weights is not None:
|
||||
out = out * t5xxl_weights
|
||||
|
||||
if out.shape[1] < 512:
|
||||
out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
|
||||
return out
|
||||
else:
|
||||
return text_embeds
|
||||
|
||||
def forward(self, x, timesteps, context, **kwargs):
|
||||
t5xxl_ids = kwargs.pop("t5xxl_ids", None)
|
||||
if t5xxl_ids is not None:
|
||||
context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
|
||||
return super().forward(x, timesteps, context, **kwargs)
|
||||
|
||||
@ -3,7 +3,6 @@ from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
@ -29,7 +28,7 @@ class Approximator(nn.Module):
|
||||
super().__init__()
|
||||
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
|
||||
self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
|
||||
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
|
||||
|
||||
@property
|
||||
|
||||
@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
# running on sequences img
|
||||
@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
|
||||
@ -4,8 +4,6 @@ from functools import lru_cache
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.ldm.flux.layers import RMSNorm
|
||||
|
||||
|
||||
class NerfEmbedder(nn.Module):
|
||||
"""
|
||||
@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
|
||||
# We now need to generate parameters for 3 matrices.
|
||||
total_params = 3 * hidden_size_x**2 * mlp_ratio
|
||||
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
|
||||
self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
|
||||
@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
|
||||
class NerfFinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
|
||||
class NerfFinalLayerConv(nn.Module):
|
||||
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
|
||||
self.conv = operations.Conv2d(
|
||||
in_channels=hidden_size,
|
||||
out_channels=out_channels,
|
||||
|
||||
@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = operations.Linear(
|
||||
hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
|
||||
)
|
||||
@ -463,6 +463,8 @@ class Block(nn.Module):
|
||||
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[dict] = {},
|
||||
) -> torch.Tensor:
|
||||
residual_dtype = x_B_T_H_W_D.dtype
|
||||
compute_dtype = emb_B_T_D.dtype
|
||||
if extra_per_block_pos_emb is not None:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
|
||||
|
||||
@ -512,7 +514,7 @@ class Block(nn.Module):
|
||||
result_B_T_H_W_D = rearrange(
|
||||
self.self_attn(
|
||||
# normalized_x_B_T_HW_D,
|
||||
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
None,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@ -522,7 +524,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@ -536,7 +538,7 @@ class Block(nn.Module):
|
||||
)
|
||||
_result_B_T_H_W_D = rearrange(
|
||||
self.cross_attn(
|
||||
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
|
||||
rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
|
||||
crossattn_emb,
|
||||
rope_emb=rope_emb_L_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
@ -555,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@ -563,8 +565,8 @@ class Block(nn.Module):
|
||||
scale_mlp_B_T_1_1_D,
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
|
||||
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
|
||||
"transformer_options": kwargs.get("transformer_options", {}),
|
||||
}
|
||||
|
||||
# The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
|
||||
# in fp32, but run attention and MLP modules in fp16.
|
||||
# An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
|
||||
# quality degradation and visual artifacts.
|
||||
if x_B_T_H_W_D.dtype == torch.float16:
|
||||
x_B_T_H_W_D = x_B_T_H_W_D.float()
|
||||
|
||||
for block in self.blocks:
|
||||
x_B_T_H_W_D = block(
|
||||
x_B_T_H_W_D,
|
||||
@ -884,6 +894,6 @@ class MiniTrainDIT(nn.Module):
|
||||
**block_kwargs,
|
||||
)
|
||||
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D)
|
||||
x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]]
|
||||
return x_B_C_Tt_Hp_Wp
|
||||
|
||||
@ -5,9 +5,9 @@ import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .math import attention, rope
|
||||
import comfy.ops
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
# Fix import for some custom nodes, TODO: delete eventually.
|
||||
RMSNorm = None
|
||||
|
||||
class EmbedND(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||
@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
|
||||
|
||||
|
||||
class QKNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
|
||||
self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
|
||||
q = self.query_norm(q)
|
||||
@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
@ -224,32 +217,23 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
@ -328,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
@ -337,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
|
||||
@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
|
||||
def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return q_apply_rope1(x, freqs_cis)
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
def apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
return x_out.reshape(*x.shape).type_as(x)
|
||||
|
||||
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
apply_rope = _apply_rope
|
||||
apply_rope1 = _apply_rope1
|
||||
|
||||
@ -16,7 +16,6 @@ from .layers import (
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@ -81,7 +80,7 @@ class Flux(nn.Module):
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
@ -143,6 +142,7 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@ -232,6 +232,7 @@ class Flux(nn.Module):
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
transformer_options = transformer_options.copy()
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module):
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
attn_mask[:, 0, :txt.shape[1]] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module):
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module):
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
img[:, txt.shape[1]: img_len + txt.shape[1]] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
img = img[:, txt.shape[1]: img_len + txt.shape[1]]
|
||||
if ref_latent is not None:
|
||||
img = img[:, ref_latent.shape[1]:]
|
||||
|
||||
|
||||
@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
|
||||
return self.conv(x)
|
||||
|
||||
def interpolate_up(x, scale_factor):
|
||||
try:
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
orig_shape = list(x.shape)
|
||||
out_shape = orig_shape[:2]
|
||||
for i in range(len(orig_shape) - 2):
|
||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
||||
return out
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||
|
||||
@ -2,6 +2,196 @@ import torch
|
||||
import math
|
||||
|
||||
from .model import QwenImageTransformer2DModel
|
||||
from .model import QwenImageTransformerBlock
|
||||
|
||||
|
||||
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||
super().__init__(
|
||||
dim=dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
self.has_before_proj = has_before_proj
|
||||
if has_before_proj:
|
||||
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
|
||||
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_in_features=132,
|
||||
inner_dim=3072,
|
||||
num_attention_heads=24,
|
||||
attention_head_dim=128,
|
||||
num_control_blocks=5,
|
||||
main_model_double=60,
|
||||
injection_layers=(0, 12, 24, 36, 48),
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.main_model_double = main_model_double
|
||||
self.injection_layers = tuple(injection_layers)
|
||||
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||
# to the reference Gen2/VideoX implementation around strength=1.
|
||||
self.hint_scale = 1.0
|
||||
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||
|
||||
self.control_blocks = torch.nn.ModuleList([])
|
||||
for i in range(num_control_blocks):
|
||||
self.control_blocks.append(
|
||||
QwenImageFunControlBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
has_before_proj=(i == 0),
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
)
|
||||
|
||||
def _process_hint_tokens(self, hint):
|
||||
if hint is None:
|
||||
return None
|
||||
if hint.ndim == 4:
|
||||
hint = hint.unsqueeze(2)
|
||||
|
||||
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||
if hint.shape[1] == 16 and expected_c == 33:
|
||||
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||
zeros_inpaint = torch.zeros_like(hint)
|
||||
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||
|
||||
bs, c, t, h, w = hint.shape
|
||||
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||
orig_shape = hidden_states.shape
|
||||
hidden_states = hidden_states.view(
|
||||
orig_shape[0],
|
||||
orig_shape[1],
|
||||
orig_shape[-3],
|
||||
orig_shape[-2] // 2,
|
||||
2,
|
||||
orig_shape[-1] // 2,
|
||||
2,
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||
hidden_states = hidden_states.reshape(
|
||||
bs,
|
||||
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||
c * 4,
|
||||
)
|
||||
|
||||
expected_in = self.control_img_in.weight.shape[1]
|
||||
cur_in = hidden_states.shape[-1]
|
||||
if cur_in < expected_in:
|
||||
pad = torch.zeros(
|
||||
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||
elif cur_in > expected_in:
|
||||
hidden_states = hidden_states[:, :, :expected_in]
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
context,
|
||||
attention_mask=None,
|
||||
guidance: torch.Tensor = None,
|
||||
hint=None,
|
||||
transformer_options={},
|
||||
base_model=None,
|
||||
**kwargs,
|
||||
):
|
||||
if base_model is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||
|
||||
encoder_hidden_states_mask = attention_mask
|
||||
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||
encoder_hidden_states_mask = None
|
||||
|
||||
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||
hint_tokens = self._process_hint_tokens(hint)
|
||||
if hint_tokens is None:
|
||||
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||
|
||||
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||
hint_tokens = hint_tokens[:, :max_tokens]
|
||||
hidden_states = hidden_states[:, :max_tokens]
|
||||
img_ids = img_ids[:, :max_tokens]
|
||||
|
||||
txt_start = round(
|
||||
max(
|
||||
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||
)
|
||||
)
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
|
||||
hidden_states = base_model.img_in(hidden_states)
|
||||
encoder_hidden_states = base_model.txt_norm(context)
|
||||
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance * 1000
|
||||
|
||||
temb = (
|
||||
base_model.time_text_embed(timesteps, hidden_states)
|
||||
if guidance is None
|
||||
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||
)
|
||||
|
||||
c = self.control_img_in(hint_tokens)
|
||||
|
||||
for i, block in enumerate(self.control_blocks):
|
||||
if i == 0:
|
||||
c_in = block.before_proj(c) + hidden_states
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c, dim=0))
|
||||
c_in = all_c.pop(-1)
|
||||
|
||||
encoder_hidden_states, c_out = block(
|
||||
hidden_states=c_in,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||
all_c += [c_skip, c_out]
|
||||
c = torch.stack(all_c, dim=0)
|
||||
|
||||
hints = torch.unbind(c, dim=0)[:-1]
|
||||
|
||||
controlnet_block_samples = [None] * self.main_model_double
|
||||
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||
|
||||
return {"input": controlnet_block_samples}
|
||||
|
||||
|
||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||
|
||||
@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
|
||||
|
||||
return padded_tensor
|
||||
|
||||
def calculate_shape(patches, weight, key, original_weights=None):
|
||||
current_shape = weight.shape
|
||||
|
||||
for p in patches:
|
||||
v = p[1]
|
||||
offset = p[3]
|
||||
|
||||
# Offsets restore the old shape; lists force a diff without metadata
|
||||
if offset is not None or isinstance(v, list):
|
||||
continue
|
||||
|
||||
if isinstance(v, weight_adapter.WeightAdapterBase):
|
||||
adapter_shape = v.calculate_shape(key)
|
||||
if adapter_shape is not None:
|
||||
current_shape = adapter_shape
|
||||
continue
|
||||
|
||||
# Standard diff logic with padding
|
||||
if len(v) == 2:
|
||||
patch_type, patch_data = v[0], v[1]
|
||||
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
|
||||
current_shape = patch_data[0].shape
|
||||
|
||||
return current_shape
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
|
||||
@ -5,7 +5,7 @@ import comfy.utils
|
||||
def convert_lora_bfl_control(sd): #BFL loras for Flux
|
||||
sd_out = {}
|
||||
for k in sd:
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
|
||||
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
|
||||
sd_out[k_to] = sd[k]
|
||||
|
||||
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
|
||||
|
||||
@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
|
||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
dtype = self.get_dtype_inference()
|
||||
|
||||
xc = xc.to(dtype)
|
||||
device = xc.device
|
||||
@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
|
||||
def get_dtype(self):
|
||||
return self.diffusion_model.dtype
|
||||
|
||||
def get_dtype_inference(self):
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
return dtype
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None
|
||||
|
||||
@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
|
||||
input_shapes += shape
|
||||
|
||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||
dtype = self.get_dtype()
|
||||
if self.manual_cast_dtype is not None:
|
||||
dtype = self.manual_cast_dtype
|
||||
dtype = self.get_dtype_inference()
|
||||
#TODO: this needs to be tweaked
|
||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||
@ -1160,12 +1162,16 @@ class Anima(BaseModel):
|
||||
device = kwargs["device"]
|
||||
if cross_attn is not None:
|
||||
if t5xxl_ids is not None:
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
|
||||
if t5xxl_weights is not None:
|
||||
cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
|
||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||
|
||||
if torch.is_inference_mode_enabled(): # if not we are training
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
|
||||
else:
|
||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||
|
||||
if cross_attn.shape[1] < 512:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
@ -1552,6 +1558,8 @@ class ACEStep15(BaseModel):
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if torch.count_nonzero(cross_attn) == 0:
|
||||
out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True)
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
|
||||
@ -1575,6 +1583,10 @@ class ACEStep15(BaseModel):
|
||||
else:
|
||||
out['is_covers'] = comfy.conds.CONDConstant(False)
|
||||
|
||||
if refer_audio.shape[2] < noise.shape[2]:
|
||||
pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device)
|
||||
refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2)
|
||||
|
||||
out['refer_audio'] = comfy.conds.CONDRegular(refer_audio)
|
||||
return out
|
||||
|
||||
|
||||
@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def any_suffix_in(keys, prefix, main, suffix_list=[]):
|
||||
for x in suffix_list:
|
||||
if "{}{}{}".format(prefix, main, x) in keys:
|
||||
return True
|
||||
return False
|
||||
|
||||
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
|
||||
context_dim = None
|
||||
use_linear_in_transformer = False
|
||||
@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
dit_config["out_channels"] = 64
|
||||
@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["out_dim"] = 3072
|
||||
dit_config["hidden_dim"] = 5120
|
||||
dit_config["n_layers"] = 5
|
||||
if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
|
||||
|
||||
if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
|
||||
dit_config["image_model"] = "chroma_radiance"
|
||||
dit_config["in_channels"] = 3
|
||||
dit_config["out_channels"] = 3
|
||||
@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_depth"] = 4
|
||||
dit_config["nerf_max_freqs"] = 8
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
import psutil
|
||||
import logging
|
||||
from enum import Enum
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import threading
|
||||
import torch
|
||||
import sys
|
||||
@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
|
||||
# Training Related State
|
||||
in_training = False
|
||||
|
||||
|
||||
def get_supported_float8_types():
|
||||
float8_types = []
|
||||
try:
|
||||
@ -651,7 +656,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
soft_empty_cache()
|
||||
return unloaded_models
|
||||
|
||||
def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
cleanup_models_gc()
|
||||
global vram_state
|
||||
|
||||
@ -747,26 +752,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
|
||||
with torch.inference_mode():
|
||||
load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||
soft_empty_cache()
|
||||
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
|
||||
#Deliberately load models outside of the Aimdo mempool so they can be retained accross
|
||||
#nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
|
||||
#thread local. So exploit that to escape context
|
||||
if enables_dynamic_vram():
|
||||
t = threading.Thread(
|
||||
target=load_models_gpu_thread,
|
||||
args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||
)
|
||||
t.start()
|
||||
t.join()
|
||||
else:
|
||||
load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
|
||||
minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
|
||||
def load_model_gpu(model):
|
||||
return load_models_gpu([model])
|
||||
|
||||
@ -1226,21 +1211,20 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if dtype is None:
|
||||
dtype = weight._model_dtype
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = weight._v_tensor
|
||||
else:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
weight._v_tensor = v_tensor
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
|
||||
#a non comfy weight
|
||||
r.copy_(v_tensor)
|
||||
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
|
||||
return r
|
||||
return v_tensor.to(dtype=dtype)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||
#Offloaded casting could skip this, however it would make the quantizations
|
||||
|
||||
@ -19,7 +19,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
@ -317,7 +316,7 @@ class ModelPatcher:
|
||||
|
||||
n.object_patches = self.object_patches.copy()
|
||||
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||
n.model_options = copy.deepcopy(self.model_options)
|
||||
n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
@ -407,13 +406,16 @@ class ModelPatcher:
|
||||
def memory_required(self, input_shape):
|
||||
return self.model.memory_required(input_shape=input_shape)
|
||||
|
||||
def disable_model_cfg1_optimization(self):
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
|
||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||
else:
|
||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||
if disable_cfg1_optimization:
|
||||
self.model_options["disable_cfg1_optimization"] = True
|
||||
self.disable_model_cfg1_optimization()
|
||||
|
||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||
@ -680,18 +682,19 @@ class ModelPatcher:
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self, prio_comfy_cast_weights=False):
|
||||
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
params = []
|
||||
skip = False
|
||||
for name, param in m.named_parameters(recurse=False):
|
||||
params.append(name)
|
||||
default = False
|
||||
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||
for name, param in m.named_parameters(recurse=True):
|
||||
if name not in params:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
default = True # default random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
@ -1492,9 +1495,11 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if vbar is not None:
|
||||
vbar.prioritize()
|
||||
|
||||
#We have way more tools for acceleration on comfy weight offloading, so always
|
||||
#We force reserve VRAM for the non comfy-weight so we dont have to deal
|
||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||
#with a high layer rate (e.g. autoregressive LLMs).
|
||||
#prioritize the non-comfy weights (note the order reverse).
|
||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||
loading.sort(reverse=True)
|
||||
|
||||
for x in loading:
|
||||
@ -1512,8 +1517,10 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
return (False, 0)
|
||||
if key in self.patches:
|
||||
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
|
||||
return (True, 0)
|
||||
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
|
||||
num_patches += 1
|
||||
else:
|
||||
@ -1524,10 +1531,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
setattr(m, param_key + "_function", weight_function)
|
||||
geometry = weight
|
||||
if not isinstance(weight, QuantizedTensor):
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
|
||||
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
|
||||
weight._model_dtype = model_dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
return comfy.memory_management.vram_aligned_size(geometry)
|
||||
return (False, comfy.memory_management.vram_aligned_size(geometry))
|
||||
|
||||
def force_load_param(self, param_key, device_to):
|
||||
key = key_param_name_to_key(n, param_key)
|
||||
if key in self.backup:
|
||||
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.comfy_cast_weights = True
|
||||
@ -1535,13 +1548,19 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
m.seed_key = n
|
||||
set_dirty(m, dirty)
|
||||
|
||||
v_weight_size = 0
|
||||
v_weight_size += setup_param(self, m, n, "weight")
|
||||
v_weight_size += setup_param(self, m, n, "bias")
|
||||
force_load, v_weight_size = setup_param(self, m, n, "weight")
|
||||
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
|
||||
force_load = force_load or force_load_bias
|
||||
v_weight_size += v_weight_bias
|
||||
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
if force_load:
|
||||
logging.info(f"Module {n} has resizing Lora - force loading")
|
||||
force_load_param(self, "weight", device_to)
|
||||
force_load_param(self, "bias", device_to)
|
||||
else:
|
||||
if vbar is not None and not hasattr(m, "_v"):
|
||||
m._v = vbar.alloc(v_weight_size)
|
||||
allocated_size += v_weight_size
|
||||
|
||||
else:
|
||||
for param in params:
|
||||
@ -1550,13 +1569,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
weight.seed_key = key
|
||||
set_dirty(weight, dirty)
|
||||
geometry = weight
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
|
||||
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
|
||||
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
|
||||
weight_size = geometry.numel() * geometry.element_size()
|
||||
if vbar is not None and not hasattr(weight, "_v"):
|
||||
weight._v = vbar.alloc(weight_size)
|
||||
weight._model_dtype = model_dtype
|
||||
allocated_size += weight_size
|
||||
vbar.set_watermark_limit(allocated_size)
|
||||
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||
|
||||
@ -1577,7 +1599,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||
@ -1598,6 +1620,13 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
if unpatch_weights:
|
||||
self.partially_unload_ram(1e32)
|
||||
self.partially_unload(None, 1e32)
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
for k in keys:
|
||||
bk = self.backup[k]
|
||||
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
assert not force_patch_weights #See above
|
||||
|
||||
48
comfy/ops.py
48
comfy/ops.py
@ -21,7 +21,6 @@ import logging
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import json
|
||||
import comfy.memory_management
|
||||
import comfy.pinned_memory
|
||||
@ -80,17 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
if signature is not None:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
|
||||
if not resident:
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
@ -140,9 +143,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
@ -163,14 +170,14 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||
x = lowvram_fn(x)
|
||||
if (isinstance(orig, QuantizedTensor) and
|
||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
||||
(want_requant and len(fns) == 0 or update_weight)):
|
||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||
if orig.dtype == dtype and len(fns) == 0:
|
||||
if want_requant and len(fns) == 0:
|
||||
#The layer actually wants our freshly saved QT
|
||||
x = y
|
||||
else:
|
||||
y = x
|
||||
elif update_weight:
|
||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
||||
if update_weight:
|
||||
orig.copy_(y)
|
||||
for f in fns:
|
||||
@ -182,13 +189,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
s._v_signature=signature
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
@ -206,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@ -456,7 +462,7 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
||||
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
@ -468,8 +474,7 @@ class disable_weight_init:
|
||||
weight = None
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
@ -845,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
||||
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
@ -876,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||
|
||||
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
||||
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||
|
||||
# Reshape output back to 3D if input was 3D
|
||||
if reshaped_3d:
|
||||
|
||||
@ -1,57 +1,10 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
import logging
|
||||
|
||||
RMSNorm = None
|
||||
|
||||
try:
|
||||
rms_norm_torch = torch.nn.functional.rms_norm
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
except:
|
||||
rms_norm_torch = None
|
||||
logging.warning("Please update pytorch to use native RMSNorm")
|
||||
|
||||
RMSNorm = torch.nn.RMSNorm
|
||||
|
||||
def rms_norm(x, weight=None, eps=1e-6):
|
||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
||||
if weight is None:
|
||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
if weight is None:
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||
else:
|
||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
||||
if weight is None:
|
||||
return r
|
||||
else:
|
||||
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
if RMSNorm is None:
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=1e-6,
|
||||
elementwise_affine=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
# mypy error: incompatible types in assignment
|
||||
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
||||
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x):
|
||||
return rms_norm(x, self.weight, self.eps)
|
||||
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||
|
||||
@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
|
||||
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
|
||||
return memory_required, minimum_memory_required
|
||||
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
|
||||
_prepare_sampling,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
|
||||
)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
|
||||
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
|
||||
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
|
||||
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
|
||||
real_model: BaseModel = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += get_additional_models_from_model_options(model_options)
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
|
||||
if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
|
||||
memory_required = 1e20
|
||||
minimum_memory_required = None
|
||||
else:
|
||||
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
|
||||
memory_required += inference_memory
|
||||
minimum_memory_required += inference_memory
|
||||
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
|
||||
real_model = model.model
|
||||
|
||||
return real_model, conds, models
|
||||
|
||||
32
comfy/sd.py
32
comfy/sd.py
@ -423,6 +423,19 @@ class CLIP:
|
||||
def get_key_patches(self):
|
||||
return self.patcher.get_key_patches()
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
self.cond_stage_model.reset_clip_options()
|
||||
|
||||
if self.layer_idx is not None:
|
||||
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
@ -793,8 +806,6 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
model_management.archive_model_dtypes(self.first_stage_model)
|
||||
|
||||
if device is None:
|
||||
device = model_management.vae_device()
|
||||
self.device = device
|
||||
@ -803,6 +814,7 @@ class VAE:
|
||||
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
|
||||
self.vae_dtype = dtype
|
||||
self.first_stage_model.to(self.vae_dtype)
|
||||
model_management.archive_model_dtypes(self.first_stage_model)
|
||||
self.output_device = model_management.intermediate_device()
|
||||
|
||||
mp = comfy.model_patcher.CoreModelPatcher
|
||||
@ -1183,6 +1195,7 @@ class TEModel(Enum):
|
||||
JINA_CLIP_2 = 19
|
||||
QWEN3_8B = 20
|
||||
QWEN3_06B = 21
|
||||
GEMMA_3_4B_VISION = 22
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1211,7 +1224,10 @@ def detect_te_model(sd):
|
||||
if 'model.layers.47.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_12B
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B
|
||||
if 'vision_model.embeddings.patch_embedding.weight' in sd:
|
||||
return TEModel.GEMMA_3_4B_VISION
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
@ -1271,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
if "text_projection" in clip_data[i]:
|
||||
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
|
||||
if "lm_head.weight" in clip_data[i]:
|
||||
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
|
||||
|
||||
tokenizer_data = {}
|
||||
clip_target = EmptyClass()
|
||||
@ -1336,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B_VISION:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_12B:
|
||||
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
|
||||
@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
pad_token = self.special_tokens.get("pad", -1)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
cmp_token = pad_token
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
other_embeds = []
|
||||
eos = False
|
||||
index = 0
|
||||
left_pad = False
|
||||
for y in x:
|
||||
if isinstance(y, numbers.Integral):
|
||||
if eos:
|
||||
token = int(y)
|
||||
if index == 0 and token == pad_token:
|
||||
left_pad = True
|
||||
|
||||
if eos or (left_pad and token == pad_token):
|
||||
attention_mask.append(0)
|
||||
else:
|
||||
attention_mask.append(1)
|
||||
token = int(y)
|
||||
left_pad = False
|
||||
|
||||
tokens_temp += [token]
|
||||
if not eos and token == cmp_token:
|
||||
if not eos and token == cmp_token and not left_pad:
|
||||
if end_token is None:
|
||||
attention_mask[-1] = 0
|
||||
eos = True
|
||||
@ -301,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
|
||||
if isinstance(tokens, dict):
|
||||
tokens_only = next(iter(tokens.values())) # todo: get this better?
|
||||
else:
|
||||
tokens_only = tokens
|
||||
tokens_only = [[t[0] for t in b] for b in tokens_only]
|
||||
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
|
||||
|
||||
def parse_parentheses(string):
|
||||
result = []
|
||||
current_item = ""
|
||||
@ -656,6 +672,9 @@ class SDTokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
class SD1Tokenizer:
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
|
||||
if name is not None:
|
||||
@ -679,6 +698,9 @@ class SD1Tokenizer:
|
||||
def state_dict(self):
|
||||
return getattr(self, self.clip).state_dict()
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=True):
|
||||
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
class SD1CheckpointClipModel(SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
@ -715,3 +737,6 @@ class SD1ClipModel(torch.nn.Module):
|
||||
|
||||
def load_sd(self, sd):
|
||||
return getattr(self, self.clip).load_sd(sd)
|
||||
|
||||
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
|
||||
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
|
||||
|
||||
@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith("_norm.scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
|
||||
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
|
||||
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
@ -993,7 +1004,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
@ -1023,11 +1034,7 @@ class Anima(supported_models_base.BASE):
|
||||
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Anima(self, device=device)
|
||||
@ -1038,6 +1045,12 @@ class Anima(supported_models_base.BASE):
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect))
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs):
|
||||
self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
if dtype is torch.float16:
|
||||
self.memory_usage_factor *= 1.4
|
||||
return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs)
|
||||
|
||||
class CosmosI2VPredict2(CosmosT2IPredict2):
|
||||
unet_config = {
|
||||
"image_model": "cosmos_predict2",
|
||||
@ -1262,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
@ -1339,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
if key_out.endswith(".scale"):
|
||||
key_out = "{}.weight".format(key_out[:-len(".scale")])
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Chroma(self, device=device)
|
||||
|
||||
@ -10,12 +10,12 @@ import comfy.utils
|
||||
def sample_manual_loop_no_classes(
|
||||
model,
|
||||
ids=None,
|
||||
paddings=[],
|
||||
execution_dtype=None,
|
||||
cfg_scale: float = 2.0,
|
||||
temperature: float = 0.85,
|
||||
top_p: float = 0.9,
|
||||
top_k: int = None,
|
||||
min_p: float = 0.000,
|
||||
seed: int = 1,
|
||||
min_tokens: int = 1,
|
||||
max_new_tokens: int = 2048,
|
||||
@ -23,6 +23,8 @@ def sample_manual_loop_no_classes(
|
||||
audio_end_id: int = 215669,
|
||||
eos_token_id: int = 151645,
|
||||
):
|
||||
if ids is None:
|
||||
return []
|
||||
device = model.execution_device
|
||||
|
||||
if execution_dtype is None:
|
||||
@ -32,31 +34,34 @@ def sample_manual_loop_no_classes(
|
||||
execution_dtype = torch.float32
|
||||
|
||||
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
||||
for i, t in enumerate(paddings):
|
||||
attention_mask[i, :t] = 0
|
||||
attention_mask[i, t:] = 1
|
||||
embeds_batch = embeds.shape[0]
|
||||
|
||||
output_audio_codes = []
|
||||
past_key_values = []
|
||||
generator = torch.Generator(device=device)
|
||||
generator.manual_seed(seed)
|
||||
model_config = model.transformer.model.config
|
||||
past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
|
||||
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
|
||||
|
||||
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
|
||||
|
||||
for step in range(max_new_tokens):
|
||||
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
|
||||
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
|
||||
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
|
||||
past_key_values = outputs[2]
|
||||
|
||||
cond_logits = next_token_logits[0:1]
|
||||
uncond_logits = next_token_logits[1:2]
|
||||
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
||||
if cfg_scale != 1.0:
|
||||
cond_logits = next_token_logits[0:1]
|
||||
uncond_logits = next_token_logits[1:2]
|
||||
cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
|
||||
else:
|
||||
cfg_logits = next_token_logits[0:1]
|
||||
|
||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||
use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
|
||||
if use_eos_score:
|
||||
eos_score = cfg_logits[:, eos_token_id].clone()
|
||||
|
||||
remove_logit_value = torch.finfo(cfg_logits.dtype).min
|
||||
@ -64,7 +69,7 @@ def sample_manual_loop_no_classes(
|
||||
cfg_logits[:, :audio_start_id] = remove_logit_value
|
||||
cfg_logits[:, audio_end_id:] = remove_logit_value
|
||||
|
||||
if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
|
||||
if use_eos_score:
|
||||
cfg_logits[:, eos_token_id] = eos_score
|
||||
|
||||
if top_k is not None and top_k > 0:
|
||||
@ -72,6 +77,12 @@ def sample_manual_loop_no_classes(
|
||||
min_val = top_k_vals[..., -1, None]
|
||||
cfg_logits[cfg_logits < min_val] = remove_logit_value
|
||||
|
||||
if min_p is not None and min_p > 0:
|
||||
probs = torch.softmax(cfg_logits, dim=-1)
|
||||
p_max = probs.max(dim=-1, keepdim=True).values
|
||||
indices_to_remove = probs < (min_p * p_max)
|
||||
cfg_logits[indices_to_remove] = remove_logit_value
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
@ -93,8 +104,8 @@ def sample_manual_loop_no_classes(
|
||||
break
|
||||
|
||||
embed, _, _, _ = model.process_tokens([[token]], device)
|
||||
embeds = embed.repeat(2, 1, 1)
|
||||
attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
||||
embeds = embed.repeat(embeds_batch, 1, 1)
|
||||
attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
|
||||
|
||||
output_audio_codes.append(token - audio_start_id)
|
||||
progress_bar.update_absolute(step)
|
||||
@ -102,24 +113,29 @@ def sample_manual_loop_no_classes(
|
||||
return output_audio_codes
|
||||
|
||||
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
|
||||
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
|
||||
positive = [[token for token, _ in inner_list] for inner_list in positive]
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
positive = positive[0]
|
||||
negative = negative[0]
|
||||
|
||||
neg_pad = 0
|
||||
if len(negative) < len(positive):
|
||||
neg_pad = (len(positive) - len(negative))
|
||||
negative = [model.special_tokens["pad"]] * neg_pad + negative
|
||||
if cfg_scale != 1.0:
|
||||
negative = [[token for token, _ in inner_list] for inner_list in negative]
|
||||
negative = negative[0]
|
||||
|
||||
pos_pad = 0
|
||||
if len(negative) > len(positive):
|
||||
pos_pad = (len(negative) - len(positive))
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
neg_pad = 0
|
||||
if len(negative) < len(positive):
|
||||
neg_pad = (len(positive) - len(negative))
|
||||
negative = [model.special_tokens["pad"]] * neg_pad + negative
|
||||
|
||||
paddings = [pos_pad, neg_pad]
|
||||
return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
pos_pad = 0
|
||||
if len(negative) > len(positive):
|
||||
pos_pad = (len(negative) - len(positive))
|
||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||
|
||||
ids = [positive, negative]
|
||||
else:
|
||||
ids = [positive]
|
||||
|
||||
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||
|
||||
|
||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
|
||||
user_metas = {
|
||||
k: kwargs.pop(k)
|
||||
for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
|
||||
for k in ("bpm", "duration", "keyscale", "timesignature")
|
||||
if k in kwargs
|
||||
}
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
|
||||
user_metas["timesignature"] = timesignature[:-2]
|
||||
user_metas = {
|
||||
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
|
||||
for k, v in user_metas.items()
|
||||
@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml
|
||||
|
||||
def _metas_to_cap(self, **kwargs) -> str:
|
||||
use_keys = ("bpm", "duration", "keyscale", "timesignature")
|
||||
use_keys = ("bpm", "timesignature", "keyscale", "duration")
|
||||
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
|
||||
timesignature = user_metas.get("timesignature")
|
||||
if isinstance(timesignature, str) and timesignature.endswith("/4"):
|
||||
user_metas["timesignature"] = timesignature[:-2]
|
||||
duration = user_metas["duration"]
|
||||
if duration == "N/A":
|
||||
user_metas["duration"] = "30 seconds"
|
||||
@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
text = text.strip()
|
||||
text_negative = kwargs.get("caption_negative", text).strip()
|
||||
lyrics = kwargs.get("lyrics", "")
|
||||
lyrics_negative = kwargs.get("lyrics_negative", lyrics)
|
||||
duration = kwargs.get("duration", 120)
|
||||
if isinstance(duration, str):
|
||||
duration = float(duration.split(None, 1)[0])
|
||||
language = kwargs.get("language")
|
||||
seed = kwargs.get("seed", 0)
|
||||
|
||||
@ -170,28 +193,55 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
temperature = kwargs.get("temperature", 0.85)
|
||||
top_p = kwargs.get("top_p", 0.9)
|
||||
top_k = kwargs.get("top_k", 0.0)
|
||||
|
||||
min_p = kwargs.get("min_p", 0.000)
|
||||
|
||||
duration = math.ceil(duration)
|
||||
kwargs["duration"] = duration
|
||||
tokens_duration = duration * 5
|
||||
min_tokens = int(kwargs.get("min_tokens", tokens_duration))
|
||||
max_tokens = int(kwargs.get("max_tokens", tokens_duration))
|
||||
|
||||
cot_text = self._metas_to_cot(caption = text, **kwargs)
|
||||
metas_negative = {
|
||||
k.rsplit("_", 1)[0]: kwargs.pop(k)
|
||||
for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
|
||||
if k in kwargs
|
||||
}
|
||||
if not kwargs.get("use_negative_caption"):
|
||||
_ = metas_negative.pop("caption", None)
|
||||
|
||||
cot_text = self._metas_to_cot(caption=text, **kwargs)
|
||||
cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
|
||||
meta_cap = self._metas_to_cap(**kwargs)
|
||||
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
|
||||
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
|
||||
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
|
||||
qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
|
||||
|
||||
out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
|
||||
out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
|
||||
llm_prompts = {
|
||||
"lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
|
||||
"lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
|
||||
"lyrics": lyrics_template.format(language if language is not None else "", lyrics),
|
||||
"qwen3_06b": qwen3_06b_template.format(text, meta_cap),
|
||||
}
|
||||
|
||||
out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
|
||||
out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
|
||||
out["lm_metadata"] = {"min_tokens": duration * 5,
|
||||
out = {
|
||||
prompt_key: self.qwen3_06b.tokenize_with_weights(
|
||||
prompt,
|
||||
prompt_key == "qwen3_06b" and return_word_ids,
|
||||
disable_weights = True,
|
||||
**kwargs,
|
||||
)
|
||||
for prompt_key, prompt in llm_prompts.items()
|
||||
}
|
||||
out["lm_metadata"] = {"min_tokens": min_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"seed": seed,
|
||||
"generate_audio_codes": generate_audio_codes,
|
||||
"cfg_scale": cfg_scale,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
}
|
||||
return out
|
||||
|
||||
@ -252,7 +302,7 @@ class ACE15TEModel(torch.nn.Module):
|
||||
|
||||
lm_metadata = token_weight_pairs["lm_metadata"]
|
||||
if lm_metadata["generate_audio_codes"]:
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
|
||||
audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
|
||||
out["audio_codes"] = [audio_codes]
|
||||
|
||||
return base_out, None, out
|
||||
|
||||
@ -23,7 +23,7 @@ class AnimaTokenizer:
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = {}
|
||||
qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0
|
||||
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
return out
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ import torch.nn as nn
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any, Tuple
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
import comfy.utils
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
import comfy.model_management
|
||||
@ -313,6 +315,13 @@ class Gemma3_4B_Config:
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
|
||||
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
|
||||
vision_config = GEMMA3_VISION_CONFIG
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
@dataclass
|
||||
class Gemma3_12B_Config:
|
||||
vocab_size: int = 262208
|
||||
@ -336,7 +345,7 @@ class Gemma3_12B_Config:
|
||||
rope_scale = [8.0, 1.0]
|
||||
final_norm: bool = True
|
||||
lm_head: bool = False
|
||||
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
|
||||
vision_config = GEMMA3_VISION_CONFIG
|
||||
mm_tokens_per_image = 256
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
@ -355,13 +364,6 @@ class RMSNorm(nn.Module):
|
||||
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||
if not isinstance(theta, list):
|
||||
theta = [theta]
|
||||
@ -390,20 +392,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
||||
else:
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
out.append((cos, sin))
|
||||
sin_split = sin.shape[-1] // 2
|
||||
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||
|
||||
if len(out) == 1:
|
||||
return out[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
org_dtype = xq.dtype
|
||||
cos = freqs_cis[0]
|
||||
sin = freqs_cis[1]
|
||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
||||
nsin = freqs_cis[2]
|
||||
|
||||
q_embed = (xq * cos)
|
||||
q_split = q_embed.shape[-1] // 2
|
||||
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||
|
||||
k_embed = (xk * cos)
|
||||
k_split = k_embed.shape[-1] // 2
|
||||
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||
|
||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||
|
||||
|
||||
@ -438,8 +450,10 @@ class Attention(nn.Module):
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
):
|
||||
batch_size, seq_length, _ = hidden_states.shape
|
||||
|
||||
xq = self.q_proj(hidden_states)
|
||||
xk = self.k_proj(hidden_states)
|
||||
xv = self.v_proj(hidden_states)
|
||||
@ -474,6 +488,11 @@ class Attention(nn.Module):
|
||||
else:
|
||||
present_key_value = (xk, xv, index + num_tokens)
|
||||
|
||||
if sliding_window is not None and xk.shape[2] > sliding_window:
|
||||
xk = xk[:, :, -sliding_window:]
|
||||
xv = xv[:, :, -sliding_window:]
|
||||
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
|
||||
|
||||
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
|
||||
|
||||
@ -556,10 +575,12 @@ class TransformerBlockGemma2(nn.Module):
|
||||
optimized_attention=None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
):
|
||||
sliding_window = None
|
||||
if self.transformer_type == 'gemma3':
|
||||
if self.sliding_attention:
|
||||
sliding_window = self.sliding_attention
|
||||
if x.shape[1] > self.sliding_attention:
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
|
||||
sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
|
||||
sliding_mask.tril_(diagonal=-self.sliding_attention)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask + sliding_mask
|
||||
@ -578,6 +599,7 @@ class TransformerBlockGemma2(nn.Module):
|
||||
freqs_cis=freqs_cis,
|
||||
optimized_attention=optimized_attention,
|
||||
past_key_value=past_key_value,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
|
||||
x = self.post_attention_layernorm(x)
|
||||
@ -762,6 +784,104 @@ class BaseLlama:
|
||||
def forward(self, input_ids, *args, **kwargs):
|
||||
return self.model(input_ids, *args, **kwargs)
|
||||
|
||||
class BaseGenerate:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
if hasattr(self.model, "lm_head"):
|
||||
module = self.model.lm_head
|
||||
else:
|
||||
module = self.model.embed_tokens
|
||||
|
||||
offload_stream = None
|
||||
if module.comfy_cast_weights:
|
||||
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
|
||||
else:
|
||||
weight = self.model.embed_tokens.weight.to(x)
|
||||
|
||||
x = torch.nn.functional.linear(input, weight, None)
|
||||
|
||||
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
|
||||
return x
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
|
||||
device = embeds.device
|
||||
model_config = self.model.config
|
||||
|
||||
if execution_dtype is None:
|
||||
if comfy.model_management.should_use_bf16(device):
|
||||
execution_dtype = torch.bfloat16
|
||||
else:
|
||||
execution_dtype = torch.float32
|
||||
embeds = embeds.to(execution_dtype)
|
||||
|
||||
if embeds.ndim == 2:
|
||||
embeds = embeds.unsqueeze(0)
|
||||
|
||||
past_key_values = [] #kv_cache init
|
||||
max_cache_len = embeds.shape[1] + max_length
|
||||
for x in range(model_config.num_hidden_layers):
|
||||
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
|
||||
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
|
||||
|
||||
generated_token_ids = []
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# Generation loop
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
|
||||
token_id = next_token[0].item()
|
||||
generated_token_ids.append(token_id)
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
break
|
||||
|
||||
return generated_token_ids
|
||||
|
||||
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
|
||||
|
||||
if not do_sample or temperature == 0.0:
|
||||
return torch.argmax(logits, dim=-1, keepdim=True)
|
||||
|
||||
# Sampling mode
|
||||
if repetition_penalty != 1.0:
|
||||
for i in range(logits.shape[0]):
|
||||
for token_id in set(token_history):
|
||||
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
|
||||
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
if min_p > 0.0:
|
||||
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
|
||||
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
|
||||
min_threshold = min_p * top_probs
|
||||
indices_to_remove = probs_before_filter < min_threshold
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cumulative_probs > top_p
|
||||
sorted_indices_to_remove[..., 0] = False
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = torch.finfo(logits.dtype).min
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
|
||||
return torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
|
||||
class BaseQwen3:
|
||||
def logits(self, x):
|
||||
input = x[:, -1:]
|
||||
@ -868,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen25_7BVLI_Config(**config_dict)
|
||||
@ -878,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
# todo: should this be tied or not?
|
||||
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
|
||||
@ -920,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_4B_Config(**config_dict)
|
||||
@ -929,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Gemma3_12B(BaseLlama, torch.nn.Module):
|
||||
class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_4B_Vision_Config(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
|
||||
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
|
||||
self.image_size = config.vision_config["image_size"]
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
|
||||
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
|
||||
return None, None
|
||||
|
||||
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Gemma3_12B_Config(**config_dict)
|
||||
|
||||
@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import torch
|
||||
import comfy.utils
|
||||
import math
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs):
|
||||
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
|
||||
|
||||
|
||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
|
||||
class Gemma3_Tokenizer():
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
|
||||
self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
|
||||
self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
|
||||
|
||||
if image is None:
|
||||
images = []
|
||||
else:
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(896 * 896)
|
||||
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
|
||||
images = [s[:, :, :, :3]]
|
||||
|
||||
if text.startswith('<start_of_turn>'):
|
||||
skip_template = True
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is None:
|
||||
if len(images) > 0:
|
||||
llama_text = self.llama_template_images.format(text)
|
||||
else:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
|
||||
|
||||
if len(images) > 0:
|
||||
embed_count = 0
|
||||
for r in text_tokens:
|
||||
for i, token in enumerate(r):
|
||||
if token[0] == 262144 and embed_count < len(images):
|
||||
r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
|
||||
embed_count += 1
|
||||
return text_tokens
|
||||
|
||||
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
|
||||
|
||||
|
||||
class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
self.dtypes = set()
|
||||
self.dtypes.add(dtype)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
|
||||
text = llama_template.format(text)
|
||||
text_tokens = super().tokenize_with_weights(text, return_word_ids)
|
||||
embed_count = 0
|
||||
for k in text_tokens:
|
||||
tt = text_tokens[k]
|
||||
for r in tt:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 262144:
|
||||
if image_embeds is not None and embed_count < image_embeds.shape[0]:
|
||||
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return text_tokens
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||
|
||||
class LTXAVTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
@ -97,6 +137,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||
|
||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
@ -111,6 +152,9 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
return out.to(out_device), pooled
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
@ -138,6 +182,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
|
||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||
num_tokens = max(num_tokens, 64)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
@ -150,3 +195,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
dtype = dtype_llama
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return LTXAVTEModel_
|
||||
|
||||
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Gemma3_12BModel_(Gemma3_12BModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Gemma3_12BModel_
|
||||
|
||||
@ -1,23 +1,23 @@
|
||||
from comfy import sd1_clip
|
||||
from .spiece_tokenizer import SPieceTokenizer
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
from comfy.text_encoders.lt import Gemma3_Tokenizer
|
||||
import comfy.utils
|
||||
|
||||
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||
special_tokens = {"<end_of_turn>": 107}
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
|
||||
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
|
||||
class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
|
||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107])
|
||||
|
||||
class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
@ -40,6 +43,23 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])
|
||||
|
||||
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return embeds
|
||||
|
||||
class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
|
||||
model = Gemma2_2BModel
|
||||
elif model_type == "gemma3_4b":
|
||||
model = Gemma3_4BModel
|
||||
elif model_type == "gemma3_4b_vision":
|
||||
model = Gemma3_4B_Vision_Model
|
||||
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
|
||||
@ -6,9 +6,10 @@ class SPieceTokenizer:
|
||||
def from_pretrained(path, **kwargs):
|
||||
return SPieceTokenizer(path, **kwargs)
|
||||
|
||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
|
||||
def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
self.special_tokens = special_tokens
|
||||
import sentencepiece
|
||||
if torch.is_tensor(tokenizer_path):
|
||||
tokenizer_path = tokenizer_path.numpy().tobytes()
|
||||
@ -27,8 +28,32 @@ class SPieceTokenizer:
|
||||
return out
|
||||
|
||||
def __call__(self, string):
|
||||
if self.special_tokens is not None:
|
||||
import re
|
||||
special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
|
||||
if special_tokens_pattern and re.search(special_tokens_pattern, string):
|
||||
parts = re.split(f'({special_tokens_pattern})', string)
|
||||
result = []
|
||||
for part in parts:
|
||||
if not part:
|
||||
continue
|
||||
if part in self.special_tokens:
|
||||
result.append(self.special_tokens[part])
|
||||
else:
|
||||
encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
|
||||
result.extend(encoded)
|
||||
return {"input_ids": result}
|
||||
|
||||
out = self.tokenizer.encode(string)
|
||||
return {"input_ids": out}
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False):
|
||||
|
||||
if skip_special_tokens and self.special_tokens:
|
||||
special_token_ids = set(self.special_tokens.values())
|
||||
token_ids = [tid for tid in token_ids if tid not in special_token_ids]
|
||||
|
||||
return self.tokenizer.decode(token_ids)
|
||||
|
||||
def serialize_model(self):
|
||||
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
|
||||
|
||||
@ -20,13 +20,14 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import comfy.checkpoint_pickle
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import logging
|
||||
import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from tqdm.auto import trange
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
import json
|
||||
@ -37,26 +38,26 @@ import warnings
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
|
||||
ALWAYS_SAFE_LOAD = False
|
||||
if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
|
||||
|
||||
if True: # ckpt/pt file whitelist for safe loading of old sd files
|
||||
class ModelCheckpoint:
|
||||
pass
|
||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||
|
||||
def scalar(*args, **kwargs):
|
||||
from numpy.core.multiarray import scalar as sc
|
||||
return sc(*args, **kwargs)
|
||||
return None
|
||||
scalar.__module__ = "numpy.core.multiarray"
|
||||
|
||||
from numpy import dtype
|
||||
from numpy.dtypes import Float64DType
|
||||
from _codecs import encode
|
||||
|
||||
def encode(*args, **kwargs): # no longer necessary on newer torch
|
||||
return None
|
||||
encode.__module__ = "_codecs"
|
||||
|
||||
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
|
||||
|
||||
# Current as of safetensors 0.7.0
|
||||
_TYPES = {
|
||||
@ -139,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if MMAP_TORCH_FILES:
|
||||
torch_args["mmap"] = True
|
||||
|
||||
if safe_load or ALWAYS_SAFE_LOAD:
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
else:
|
||||
logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
|
||||
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
|
||||
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
|
||||
|
||||
if "state_dict" in pl_sd:
|
||||
sd = pl_sd["state_dict"]
|
||||
else:
|
||||
@ -674,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"ff_context.linear_in.bias": "txt_mlp.0.bias",
|
||||
"ff_context.linear_out.weight": "txt_mlp.2.weight",
|
||||
"ff_context.linear_out.bias": "txt_mlp.2.bias",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.weight",
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
@ -700,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "norm.query_norm.weight",
|
||||
"attn.norm_k.weight": "norm.key_norm.weight",
|
||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||
}
|
||||
@ -1155,6 +1153,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
|
||||
|
||||
def model_trange(*args, **kwargs):
|
||||
if comfy.memory_management.aimdo_allocator is None:
|
||||
return trange(*args, **kwargs)
|
||||
|
||||
pbar = trange(*args, **kwargs, smoothing=1.0)
|
||||
pbar._i = 0
|
||||
pbar.set_postfix_str(" Model Initializing ... ")
|
||||
|
||||
_update = pbar.update
|
||||
|
||||
def warmup_update(n=1):
|
||||
pbar._i += 1
|
||||
if pbar._i == 1:
|
||||
pbar.i1_time = time.time()
|
||||
pbar.set_postfix_str(" Model Initialization complete! ")
|
||||
elif pbar._i == 2:
|
||||
#bring forward the effective start time based the the diff between first and second iteration
|
||||
#to attempt to remove load overhead from the final step rate estimate.
|
||||
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
|
||||
pbar.set_postfix_str("")
|
||||
|
||||
_update(n)
|
||||
|
||||
pbar.update = warmup_update
|
||||
return pbar
|
||||
|
||||
PROGRESS_BAR_ENABLED = True
|
||||
def set_progress_bar_enabled(enabled):
|
||||
global PROGRESS_BAR_ENABLED
|
||||
@ -1376,3 +1400,29 @@ def string_to_seed(data):
|
||||
else:
|
||||
crc >>= 1
|
||||
return crc ^ 0xFFFFFFFF
|
||||
|
||||
def deepcopy_list_dict(obj, memo=None):
|
||||
if memo is None:
|
||||
memo = {}
|
||||
|
||||
obj_id = id(obj)
|
||||
if obj_id in memo:
|
||||
return memo[obj_id]
|
||||
|
||||
if isinstance(obj, dict):
|
||||
res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
res = [deepcopy_list_dict(i, memo) for i in obj]
|
||||
else:
|
||||
res = obj
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
|
||||
"""Normalize image embeddings to match text embedding scale"""
|
||||
for info in embeds_info:
|
||||
if info.get("type") == "image":
|
||||
start_idx = info["index"]
|
||||
end_idx = start_idx + info["size"]
|
||||
embeds[:, start_idx:end_idx, :] /= scale_factor
|
||||
|
||||
@ -49,6 +49,12 @@ class WeightAdapterBase:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def calculate_shape(
|
||||
self,
|
||||
key
|
||||
):
|
||||
return None
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
|
||||
@ -21,6 +21,7 @@ from typing import Optional, Union
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase
|
||||
from comfy.patcher_extension import PatcherInjection
|
||||
|
||||
@ -181,18 +182,21 @@ class BypassForwardHook:
|
||||
)
|
||||
return # Already injected
|
||||
|
||||
# Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
|
||||
device = None
|
||||
# Move adapter weights to compute device (GPU)
|
||||
# Use get_torch_device() instead of module.weight.device because
|
||||
# with offloading, module weights may be on CPU while compute happens on GPU
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
# Get dtype from module weight if available
|
||||
dtype = None
|
||||
if hasattr(self.module, "weight") and self.module.weight is not None:
|
||||
device = self.module.weight.device
|
||||
dtype = self.module.weight.dtype
|
||||
elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
|
||||
device = self.module.W_q.device
|
||||
dtype = self.module.W_q.dtype
|
||||
|
||||
if device is not None:
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
# Only use dtype if it's a standard float type, not quantized
|
||||
if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
|
||||
dtype = None
|
||||
|
||||
self._move_adapter_weights_to_device(device, dtype)
|
||||
|
||||
self.original_forward = self.module.forward
|
||||
self.module.forward = self._bypass_forward
|
||||
|
||||
@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
else:
|
||||
return None
|
||||
|
||||
def calculate_shape(
|
||||
self,
|
||||
key
|
||||
):
|
||||
reshape = self.weights[5]
|
||||
return tuple(reshape) if reshape is not None else None
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
weight,
|
||||
|
||||
@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
VERSION = "latest"
|
||||
STABLE = False
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||
"""Register a node replacement mapping."""
|
||||
from server import PromptServer
|
||||
PromptServer.instance.node_replace_manager.register(node_replace)
|
||||
|
||||
class Execution(ProxiedSingleton):
|
||||
async def set_progress(
|
||||
self,
|
||||
@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
execution: Execution
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
|
||||
@ -34,6 +34,21 @@ class VideoInput(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = False,
|
||||
) -> VideoInput | None:
|
||||
"""
|
||||
Create a new VideoInput which is trimmed to have the corresponding start_time and duration
|
||||
|
||||
Returns:
|
||||
A new VideoInput, or None if the result would have negative duration
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_stream_source(self) -> Union[str, io.BytesIO]:
|
||||
"""
|
||||
Get a streamable source for the video. This allows processing without
|
||||
|
||||
@ -6,6 +6,7 @@ from typing import Optional
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import itertools
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
|
||||
formats = container_format.split(",")
|
||||
return formats[0]
|
||||
|
||||
|
||||
def get_open_write_kwargs(
|
||||
dest: str | io.BytesIO, container_format: str, to_format: str | None
|
||||
) -> dict:
|
||||
@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
|
||||
Class representing video input from a file.
|
||||
"""
|
||||
|
||||
def __init__(self, file: str | io.BytesIO):
|
||||
def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
|
||||
"""
|
||||
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
|
||||
containing the file contents.
|
||||
"""
|
||||
self.__file = file
|
||||
self.__start_time = start_time
|
||||
self.__duration = duration
|
||||
|
||||
def get_stream_source(self) -> str | io.BytesIO:
|
||||
"""
|
||||
@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
|
||||
Returns:
|
||||
Duration in seconds
|
||||
"""
|
||||
raw_duration = self._get_raw_duration()
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
if self.__duration:
|
||||
return min(self.__duration, duration_from_start)
|
||||
return duration_from_start
|
||||
|
||||
def _get_raw_duration(self) -> float:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
|
||||
if video_stream and video_stream.average_rate:
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for packet in frame_iterator:
|
||||
frame_count += 1
|
||||
if frame_count > 0:
|
||||
return float(frame_count / video_stream.average_rate)
|
||||
|
||||
@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
# 1. Prefer the frames field if available and usable
|
||||
if (
|
||||
video_stream.frames
|
||||
and video_stream.frames > 0
|
||||
and not self.__start_time
|
||||
and not self.__duration
|
||||
):
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
raw_duration = float(video_stream.duration * video_stream.time_base)
|
||||
if self.__start_time < 0:
|
||||
duration_from_start = min(raw_duration, -self.__start_time)
|
||||
else:
|
||||
duration_from_start = raw_duration - self.__start_time
|
||||
duration_seconds = min(self.__duration, duration_from_start)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
frame_count = 1
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
frame_iterator = (
|
||||
container.decode(video_stream)
|
||||
if video_stream.codec.capabilities & 0x100
|
||||
else container.demux(video_stream)
|
||||
)
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= start_pts:
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
|
||||
for frame in frame_iterator:
|
||||
if frame.pts >= end_pts:
|
||||
break
|
||||
frame_count += 1
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
|
||||
return container.format.name
|
||||
|
||||
def get_components_internal(self, container: InputContainer) -> VideoComponents:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
if self.__start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + self.__start_time, 0)
|
||||
else:
|
||||
start_time = self.__start_time
|
||||
# Get video frames
|
||||
frames = []
|
||||
for frame in container.decode(video=0):
|
||||
start_pts = int(start_time / video_stream.time_base)
|
||||
end_pts = int((start_time + self.__duration) / video_stream.time_base)
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
for frame in container.decode(video_stream):
|
||||
if frame.pts < start_pts:
|
||||
continue
|
||||
if self.__duration and frame.pts >= end_pts:
|
||||
break
|
||||
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
|
||||
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
|
||||
frames.append(img)
|
||||
@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
|
||||
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
|
||||
|
||||
# Get frame rate
|
||||
video_stream = next(s for s in container.streams if s.type == 'video')
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
|
||||
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
|
||||
|
||||
# Get audio if available
|
||||
audio = None
|
||||
try:
|
||||
container.seek(0) # Reset the container to the beginning
|
||||
for stream in container.streams:
|
||||
if stream.type != 'audio':
|
||||
continue
|
||||
assert isinstance(stream, av.AudioStream)
|
||||
audio_frames = []
|
||||
for packet in container.demux(stream):
|
||||
for frame in packet.decode():
|
||||
assert isinstance(frame, av.AudioFrame)
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
|
||||
})
|
||||
except StopIteration:
|
||||
pass # No audio stream
|
||||
container.seek(start_pts, stream=video_stream)
|
||||
# Use last stream for consistency
|
||||
if len(container.streams.audio):
|
||||
audio_stream = container.streams.audio[-1]
|
||||
audio_frames = []
|
||||
resample = av.audio.resampler.AudioResampler(format='fltp').resample
|
||||
frames = itertools.chain.from_iterable(
|
||||
map(resample, container.decode(audio_stream))
|
||||
)
|
||||
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
if has_first_frame:
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
|
||||
if self.__duration:
|
||||
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
|
||||
|
||||
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
|
||||
audio = AudioInput({
|
||||
"waveform": audio_tensor,
|
||||
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
|
||||
})
|
||||
|
||||
metadata = container.metadata
|
||||
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
|
||||
@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
|
||||
path: str | io.BytesIO,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path,
|
||||
format=format,
|
||||
codec=codec,
|
||||
metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
if len(container.streams.video):
|
||||
return container.streams.video[0]
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def as_trimmed(
|
||||
self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
|
||||
) -> VideoInput | None:
|
||||
trimmed = VideoFromFile(
|
||||
self.get_stream_source(),
|
||||
start_time=start_time + self.__start_time,
|
||||
duration=duration,
|
||||
)
|
||||
if trimmed.get_duration() < duration and strict_duration:
|
||||
return None
|
||||
return trimmed
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
|
||||
return VideoComponents(
|
||||
images=self.__components.images,
|
||||
audio=self.__components.audio,
|
||||
frame_rate=self.__components.frame_rate
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def save_to(
|
||||
@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
):
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
|
||||
audio_stream: Optional[av.AudioStream] = None
|
||||
if self.__components.audio:
|
||||
audio_sample_rate = int(self.__components.audio['sample_rate'])
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate)
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
|
||||
audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
|
||||
output.mux(packet)
|
||||
|
||||
if audio_stream and self.__components.audio:
|
||||
waveform = self.__components.audio['waveform']
|
||||
waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
|
||||
frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
|
||||
frame.sample_rate = audio_sample_rate
|
||||
frame.pts = 0
|
||||
output.mux(audio_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output.mux(audio_stream.encode(None))
|
||||
|
||||
def as_trimmed(
|
||||
self,
|
||||
start_time: float | None = None,
|
||||
duration: float | None = None,
|
||||
strict_duration: bool = True,
|
||||
) -> VideoInput | None:
|
||||
if self.get_duration() < start_time + duration:
|
||||
return None
|
||||
#TODO Consider tracking duration and trimming at time of save?
|
||||
return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
|
||||
|
||||
@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
|
||||
slider = "slider"
|
||||
|
||||
|
||||
class ControlAfterGenerate(str, Enum):
|
||||
fixed = "fixed"
|
||||
increment = "increment"
|
||||
decrement = "decrement"
|
||||
randomize = "randomize"
|
||||
|
||||
class _ComfyType(ABC):
|
||||
Type = Any
|
||||
io_type: str = None
|
||||
@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
|
||||
class Input(WidgetInput):
|
||||
'''Integer input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
|
||||
tooltip: str=None,
|
||||
lazy: bool=None,
|
||||
default: str | int | Enum = None,
|
||||
control_after_generate: bool=None,
|
||||
control_after_generate: bool | ControlAfterGenerate=None,
|
||||
upload: UploadType=None,
|
||||
image_folder: FolderType=None,
|
||||
remote: RemoteOptions=None,
|
||||
@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
|
||||
Type = list[str]
|
||||
class Input(Combo.Input):
|
||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
||||
self.multiselect = True
|
||||
@ -1203,6 +1209,30 @@ class Color(ComfyTypeIO):
|
||||
def as_dict(self):
|
||||
return super().as_dict()
|
||||
|
||||
@comfytype(io_type="BOUNDING_BOX")
|
||||
class BoundingBox(ComfyTypeIO):
|
||||
class BoundingBoxDict(TypedDict):
|
||||
x: int
|
||||
y: int
|
||||
width: int
|
||||
height: int
|
||||
Type = BoundingBoxDict
|
||||
|
||||
class Input(WidgetInput):
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
|
||||
socketless: bool=True, default: dict=None, component: str=None):
|
||||
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
|
||||
self.component = component
|
||||
if default is None:
|
||||
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
|
||||
|
||||
def as_dict(self):
|
||||
d = super().as_dict()
|
||||
if self.component:
|
||||
d["component"] = self.component
|
||||
return d
|
||||
|
||||
|
||||
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
|
||||
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
|
||||
DYNAMIC_INPUT_LOOKUP[io_type] = func
|
||||
@ -1309,6 +1339,7 @@ class NodeInfoV1:
|
||||
api_node: bool=None
|
||||
price_badge: dict | None = None
|
||||
search_aliases: list[str]=None
|
||||
essentials_category: str=None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1430,6 +1461,8 @@ class Schema:
|
||||
"""Flags a node as expandable, allowing NodeOutput to include 'expand' property."""
|
||||
accept_all_inputs: bool=False
|
||||
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
|
||||
essentials_category: str | None = None
|
||||
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
|
||||
|
||||
def validate(self):
|
||||
'''Validate the schema:
|
||||
@ -1536,6 +1569,7 @@ class Schema:
|
||||
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
|
||||
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
|
||||
search_aliases=self.search_aliases if self.search_aliases else None,
|
||||
essentials_category=self.essentials_category,
|
||||
)
|
||||
return info
|
||||
|
||||
@ -2030,11 +2064,74 @@ class _UIOutput(ABC):
|
||||
...
|
||||
|
||||
|
||||
class InputMapOldId(TypedDict):
|
||||
"""Map an old node input to a new node input by ID."""
|
||||
new_id: str
|
||||
old_id: str
|
||||
|
||||
class InputMapSetValue(TypedDict):
|
||||
"""Set a specific value for a new node input."""
|
||||
new_id: str
|
||||
set_value: Any
|
||||
|
||||
InputMap = InputMapOldId | InputMapSetValue
|
||||
"""
|
||||
Input mapping for node replacement. Type is inferred by dictionary keys:
|
||||
- {"new_id": str, "old_id": str} - maps old input to new input
|
||||
- {"new_id": str, "set_value": Any} - sets a specific value for new input
|
||||
"""
|
||||
|
||||
class OutputMap(TypedDict):
|
||||
"""Map outputs of node replacement via indexes."""
|
||||
new_idx: int
|
||||
old_idx: int
|
||||
|
||||
class NodeReplace:
|
||||
"""
|
||||
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
|
||||
|
||||
Also supports assigning specific values to the input widgets of the new node.
|
||||
|
||||
Args:
|
||||
new_node_id: The class name of the new replacement node.
|
||||
old_node_id: The class name of the deprecated node.
|
||||
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
|
||||
connected. The workflow JSON stores widget values by their relative position index,
|
||||
not by ID. This list maps those positional indexes to input IDs, enabling the
|
||||
replacement system to correctly identify widget values during node migration.
|
||||
input_mapping: List of input mappings from old node to new node.
|
||||
output_mapping: List of output mappings from old node to new node.
|
||||
"""
|
||||
def __init__(self,
|
||||
new_node_id: str,
|
||||
old_node_id: str,
|
||||
old_widget_ids: list[str] | None=None,
|
||||
input_mapping: list[InputMap] | None=None,
|
||||
output_mapping: list[OutputMap] | None=None,
|
||||
):
|
||||
self.new_node_id = new_node_id
|
||||
self.old_node_id = old_node_id
|
||||
self.old_widget_ids = old_widget_ids
|
||||
self.input_mapping = input_mapping
|
||||
self.output_mapping = output_mapping
|
||||
|
||||
def as_dict(self):
|
||||
"""Create serializable representation of the node replacement."""
|
||||
return {
|
||||
"new_node_id": self.new_node_id,
|
||||
"old_node_id": self.old_node_id,
|
||||
"old_widget_ids": self.old_widget_ids,
|
||||
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
|
||||
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FolderType",
|
||||
"UploadType",
|
||||
"RemoteOptions",
|
||||
"NumberDisplay",
|
||||
"ControlAfterGenerate",
|
||||
|
||||
"comfytype",
|
||||
"Custom",
|
||||
@ -2121,4 +2218,6 @@ __all__ = [
|
||||
"ImageCompare",
|
||||
"PriceBadgeDepends",
|
||||
"PriceBadge",
|
||||
"BoundingBox",
|
||||
"NodeReplace",
|
||||
]
|
||||
|
||||
8
comfy_api_nodes/apis/__init__.py
generated
8
comfy_api_nodes/apis/__init__.py
generated
@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
|
||||
face = 'face'
|
||||
|
||||
|
||||
class KlingImageGenModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_5 = 'kling-v1-5'
|
||||
kling_v2 = 'kling-v2'
|
||||
|
||||
|
||||
class KlingImageGenerationsRequest(BaseModel):
|
||||
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
|
||||
callback_url: Optional[AnyUrl] = Field(
|
||||
@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
|
||||
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
|
||||
)
|
||||
image_reference: Optional[KlingImageGenImageReferenceType] = None
|
||||
model_name: Optional[KlingImageGenModelName] = 'kling-v1'
|
||||
model_name: str = Field(...)
|
||||
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative text prompt', max_length=200
|
||||
|
||||
@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class BriaRemoveBackgroundRequest(BaseModel):
|
||||
image: str = Field(...)
|
||||
sync: bool = Field(False)
|
||||
visual_input_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on input image moderation failure."
|
||||
)
|
||||
visual_output_content_moderation: bool = Field(
|
||||
False, description="If true, returns 422 on visual output moderation failure."
|
||||
)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaStatusResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
status_url: str = Field(...)
|
||||
warning: str | None = Field(None)
|
||||
|
||||
|
||||
class BriaResult(BaseModel):
|
||||
class BriaRemoveBackgroundResult(BaseModel):
|
||||
image_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaRemoveBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveBackgroundResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaImageEditResult(BaseModel):
|
||||
structured_prompt: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaResponse(BaseModel):
|
||||
class BriaImageEditResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaResult | None = Field(None)
|
||||
result: BriaImageEditResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaRemoveVideoBackgroundRequest(BaseModel):
|
||||
video: str = Field(...)
|
||||
background_color: str = Field(default="transparent", description="Background color for the output video.")
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaRemoveVideoBackgroundResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||
|
||||
@ -116,9 +116,15 @@ class GeminiGenerationConfig(BaseModel):
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
mimeType: str = Field("image/png")
|
||||
compressionQuality: int | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
|
||||
aspectRatio: str | None = Field(None)
|
||||
imageSize: str | None = Field(None)
|
||||
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
|
||||
@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
|
||||
|
||||
class To3DProTaskQueryRequest(BaseModel):
|
||||
JobId: str = Field(...)
|
||||
|
||||
|
||||
class To3DUVFileInput(BaseModel):
|
||||
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
|
||||
Url: str = Field(...)
|
||||
|
||||
|
||||
class To3DUVTaskRequest(BaseModel):
|
||||
File: To3DUVFileInput = Field(...)
|
||||
|
||||
|
||||
class TextureEditImageInfo(BaseModel):
|
||||
Url: str = Field(...)
|
||||
|
||||
|
||||
class TextureEditTaskRequest(BaseModel):
|
||||
File3D: To3DUVFileInput = Field(...)
|
||||
Image: TextureEditImageInfo | None = Field(None)
|
||||
Prompt: str | None = Field(None)
|
||||
EnablePBR: bool | None = Field(None)
|
||||
|
||||
@ -1,12 +1,22 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MultiPromptEntry(BaseModel):
|
||||
index: int = Field(...)
|
||||
prompt: str = Field(...)
|
||||
duration: str = Field(...)
|
||||
|
||||
|
||||
class OmniProText2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class OmniParamImage(BaseModel):
|
||||
@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str | None = Field(None, description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class OmniProReferences2VideoRequest(BaseModel):
|
||||
@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
|
||||
duration: str | None = Field(..., description="From 3 to 10.")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str | None = Field(None, description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusVideoResult(BaseModel):
|
||||
@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
|
||||
class TaskStatusResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
images: list[TaskStatusImageResult] | None = Field(None)
|
||||
series_images: list[TaskStatusImageResult] | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusResponseData(BaseModel):
|
||||
@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
|
||||
|
||||
|
||||
class OmniProImageRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-image-o1")
|
||||
resolution: str = Field(..., description="'1k' or '2k'")
|
||||
model_name: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
n: int | None = Field(1, le=9)
|
||||
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
|
||||
result_type: str | None = Field(None, description="Set to 'series' for series generation")
|
||||
series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
|
||||
|
||||
|
||||
class TextToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
model_name: str = Field(...)
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
duration: str = Field(...)
|
||||
prompt: str | None = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class ImageToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
model_name: str = Field(...)
|
||||
image: str = Field(...)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
image_tail: str | None = Field(None)
|
||||
duration: str = Field(...)
|
||||
prompt: str | None = Field(...)
|
||||
negative_prompt: str | None = Field(None)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
multi_shot: bool | None = Field(None)
|
||||
multi_prompt: list[MultiPromptEntry] | None = Field(None)
|
||||
shot_type: str | None = Field(None)
|
||||
|
||||
|
||||
class MotionControlRequest(BaseModel):
|
||||
|
||||
@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
|
||||
}
|
||||
|
||||
|
||||
class RecraftModel(str, Enum):
|
||||
recraftv3 = 'recraftv3'
|
||||
recraftv2 = 'recraftv2'
|
||||
|
||||
|
||||
class RecraftImageSize(str, Enum):
|
||||
res_1024x1024 = '1024x1024'
|
||||
res_1365x1024 = '1365x1024'
|
||||
@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
|
||||
res_1707x1024 = '1707x1024'
|
||||
|
||||
|
||||
RECRAFT_V4_SIZES = [
|
||||
"1024x1024",
|
||||
"1536x768",
|
||||
"768x1536",
|
||||
"1280x832",
|
||||
"832x1280",
|
||||
"1216x896",
|
||||
"896x1216",
|
||||
"1152x896",
|
||||
"896x1152",
|
||||
"832x1344",
|
||||
"1280x896",
|
||||
"896x1280",
|
||||
"1344x768",
|
||||
"768x1344",
|
||||
]
|
||||
|
||||
RECRAFT_V4_PRO_SIZES = [
|
||||
"2048x2048",
|
||||
"3072x1536",
|
||||
"1536x3072",
|
||||
"2560x1664",
|
||||
"1664x2560",
|
||||
"2432x1792",
|
||||
"1792x2432",
|
||||
"2304x1792",
|
||||
"1792x2304",
|
||||
"1664x2688",
|
||||
"1434x1024",
|
||||
"1024x1434",
|
||||
"2560x1792",
|
||||
"1792x2560",
|
||||
]
|
||||
|
||||
|
||||
class RecraftColorObject(BaseModel):
|
||||
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
||||
|
||||
@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
|
||||
|
||||
class RecraftImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||
n: int = Field(..., description='The number of images to generate')
|
||||
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
|
||||
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
||||
model: str = Field(...)
|
||||
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
|
||||
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||
random_seed: int | None = Field(None, description="Seed for video generation")
|
||||
# text_layout
|
||||
|
||||
|
||||
class RecraftReturnedObject(BaseModel):
|
||||
|
||||
@ -3,7 +3,11 @@ from typing_extensions import override
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bria import (
|
||||
BriaEditImageRequest,
|
||||
BriaResponse,
|
||||
BriaRemoveBackgroundRequest,
|
||||
BriaRemoveBackgroundResponse,
|
||||
BriaRemoveVideoBackgroundRequest,
|
||||
BriaRemoveVideoBackgroundResponse,
|
||||
BriaImageEditResponse,
|
||||
BriaStatusResponse,
|
||||
InputModerationSettings,
|
||||
)
|
||||
@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
convert_mask_to_image,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"moderation",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"prompt_content_moderation", default=False
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"visual_input_moderation", default=False
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"visual_output_moderation", default=True
|
||||
),
|
||||
IO.Boolean.Input("prompt_content_moderation", default=False),
|
||||
IO.Boolean.Input("visual_input_moderation", default=False),
|
||||
IO.Boolean.Input("visual_output_moderation", default=True),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
],
|
||||
tooltip="Moderation settings",
|
||||
),
|
||||
@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
mask: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if not prompt and not structured_prompt:
|
||||
raise ValueError(
|
||||
"One of prompt or structured_prompt is required to be non-empty."
|
||||
)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
|
||||
mask_url = None
|
||||
if mask is not None:
|
||||
mask_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
cls,
|
||||
convert_mask_to_image(mask),
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading mask",
|
||||
)
|
||||
)[0]
|
||||
mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
|
||||
data=BriaEditImageRequest(
|
||||
instruction=prompt if prompt else None,
|
||||
structured_instruction=structured_prompt if structured_prompt else None,
|
||||
images=await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading image",
|
||||
),
|
||||
images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
|
||||
mask=mask_url,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
guidance_scale=guidance_scale,
|
||||
seed=seed,
|
||||
model_version=model,
|
||||
steps_num=steps,
|
||||
prompt_content_moderation=moderation.get(
|
||||
"prompt_content_moderation", False
|
||||
),
|
||||
visual_input_content_moderation=moderation.get(
|
||||
"visual_input_moderation", False
|
||||
),
|
||||
visual_output_content_moderation=moderation.get(
|
||||
"visual_output_moderation", False
|
||||
),
|
||||
prompt_content_moderation=moderation.get("prompt_content_moderation", False),
|
||||
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
|
||||
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaResponse,
|
||||
response_model=BriaImageEditResponse,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_image_tensor(response.result.image_url),
|
||||
@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class BriaRemoveImageBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveImageBackground",
|
||||
display_name="Bria Remove Image Background",
|
||||
category="api node/image/Bria",
|
||||
description="Remove the background from an image using Bria RMBG 2.0.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.DynamicCombo.Input(
|
||||
"moderation",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("false", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"true",
|
||||
[
|
||||
IO.Boolean.Input("visual_input_moderation", default=False),
|
||||
IO.Boolean.Input("visual_output_moderation", default=True),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Moderation settings",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.018}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
moderation: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
|
||||
data=BriaRemoveBackgroundRequest(
|
||||
image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
|
||||
sync=False,
|
||||
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
|
||||
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
|
||||
|
||||
|
||||
class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveVideoBackground",
|
||||
display_name="Bria Remove Video Background",
|
||||
category="api node/video/Bria",
|
||||
description="Remove the background from a video using Bria. ",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input(
|
||||
"background_color",
|
||||
options=[
|
||||
"Black",
|
||||
"White",
|
||||
"Gray",
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
"Yellow",
|
||||
"Cyan",
|
||||
"Magenta",
|
||||
"Orange",
|
||||
],
|
||||
tooltip="Background color for the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
background_color: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
|
||||
data=BriaRemoveVideoBackgroundRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
background_color=background_color,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
BriaImageEditNode,
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
import base64
|
||||
import os
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
|
||||
@ -119,6 +120,13 @@ async def create_image_parts(
|
||||
return image_parts
|
||||
|
||||
|
||||
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
|
||||
"""Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
|
||||
if mime is None:
|
||||
return False
|
||||
return fnmatch(mime.value, pattern)
|
||||
|
||||
|
||||
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
for part in candidate.content.parts:
|
||||
if part_type == "text" and part.text:
|
||||
parts.append(part)
|
||||
elif part.inlineData and part.inlineData.mimeType == part_type:
|
||||
elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
|
||||
parts.append(part)
|
||||
elif part.fileData and part.fileData.mimeType == part_type:
|
||||
elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
|
||||
parts.append(part)
|
||||
|
||||
if not parts and blocked_reasons:
|
||||
@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
|
||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
parts = get_parts_by_type(response, "image/*")
|
||||
for part in parts:
|
||||
if part.inlineData:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
@ -629,7 +637,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
|
||||
if not aspect_ratio:
|
||||
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
|
||||
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||
image_config = GeminiImageConfig() if aspect_ratio == "auto" else GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||
|
||||
if images is not None:
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
@ -649,7 +657,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
|
||||
@ -1,31 +1,48 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.hunyuan3d import (
|
||||
Hunyuan3DViewImage,
|
||||
InputGenerateType,
|
||||
ResultFile3D,
|
||||
TextureEditTaskRequest,
|
||||
To3DProTaskCreateResponse,
|
||||
To3DProTaskQueryRequest,
|
||||
To3DProTaskRequest,
|
||||
To3DProTaskResultResponse,
|
||||
To3DUVFileInput,
|
||||
To3DUVTaskRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
|
||||
def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
||||
return (
|
||||
status == 400
|
||||
and isinstance(body, dict)
|
||||
and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(
|
||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||
) -> ResultFile3D | None:
|
||||
for i in response_objs:
|
||||
if i.Type.lower() == file_type.lower():
|
||||
return i
|
||||
if raise_if_not_found:
|
||||
raise ValueError(f"'{file_type}' file type is not found in the response.")
|
||||
return None
|
||||
|
||||
|
||||
@ -35,8 +52,9 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model (Pro)",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
category="api node/3d/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
@ -120,6 +138,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
@ -131,11 +150,14 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -145,8 +167,9 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model (Pro)",
|
||||
display_name="Hunyuan3D: Image(s) to Model",
|
||||
category="api node/3d/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
@ -268,6 +291,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
EnablePBR=generate_type.get("pbr", None),
|
||||
PolygonType=generate_type.get("polygon_type", None),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
@ -279,11 +303,257 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
|
||||
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
|
||||
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
|
||||
return IO.NodeOutput(
|
||||
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TencentModelTo3DUVNode",
|
||||
display_name="Hunyuan3D: Model to UV",
|
||||
category="api node/3d/Tencent",
|
||||
description="Perform UV unfolding on a 3D model to generate UV texture. "
|
||||
"Input model must have less than 30000 faces.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
|
||||
tooltip="Input 3D model (GLB, OBJ, or FBX)",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
|
||||
)
|
||||
|
||||
SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format not in cls.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported file format: '{file_format}'. "
|
||||
f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DUVTaskRequest(
|
||||
File=To3DUVFileInput(
|
||||
Type=file_format.upper(),
|
||||
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
|
||||
)
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
|
||||
)
|
||||
|
||||
|
||||
class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DTextureEditNode",
|
||||
display_name="Hunyuan3D: 3D Texture Edit",
|
||||
category="api node/3d/Tencent",
|
||||
description="After inputting the 3D model, perform 3D model texture redrawing.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DFBX, IO.File3DAny],
|
||||
tooltip="3D model in FBX format. Model should have less than 100000 faces.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.6}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format != "fbx":
|
||||
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
|
||||
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=TextureEditTaskRequest(
|
||||
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
||||
Prompt=prompt,
|
||||
EnablePBR=True,
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
)
|
||||
|
||||
|
||||
class Tencent3DPartNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DPartNode",
|
||||
display_name="Hunyuan3D: 3D Part",
|
||||
category="api node/3d/Tencent",
|
||||
description="Automatically perform component identification and generation based on the model structure.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DFBX, IO.File3DAny],
|
||||
tooltip="3D model in FBX format. Model should have less than 30000 faces.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_3d: Types.File3D,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
_ = seed
|
||||
file_format = model_3d.format.lower()
|
||||
if file_format != "fbx":
|
||||
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
|
||||
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
|
||||
response_model=To3DProTaskCreateResponse,
|
||||
data=To3DUVTaskRequest(
|
||||
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
|
||||
),
|
||||
is_rate_limited=_is_tencent_rate_limited,
|
||||
)
|
||||
if response.Error:
|
||||
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
|
||||
result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
|
||||
data=To3DProTaskQueryRequest(JobId=response.JobId),
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
)
|
||||
|
||||
|
||||
@ -293,6 +563,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
|
||||
return [
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
# TencentModelTo3DUVNode,
|
||||
# Tencent3DTextureEditNode,
|
||||
Tencent3DPartNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
)
|
||||
|
||||
_EUR_TO_USD = 1.19
|
||||
|
||||
|
||||
def _tier_price_eur(megapixels: float) -> float:
|
||||
"""Price in EUR for a single Magnific upscaling step based on input megapixels."""
|
||||
if megapixels <= 1.3:
|
||||
return 0.143
|
||||
if megapixels <= 3.0:
|
||||
return 0.286
|
||||
if megapixels <= 6.4:
|
||||
return 0.429
|
||||
return 1.716
|
||||
|
||||
|
||||
def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
|
||||
"""Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
|
||||
num_steps = int(math.log2(scale))
|
||||
total_eur = 0.0
|
||||
pixels = width * height
|
||||
for _ in range(num_steps):
|
||||
total_eur += _tier_price_eur(pixels / 1_000_000)
|
||||
pixels *= 4
|
||||
return round(total_eur * _EUR_TO_USD, 2)
|
||||
|
||||
|
||||
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -105,11 +129,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
$ad := widgets.auto_downscale;
|
||||
$mins := $ad
|
||||
? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
|
||||
: {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -170,6 +203,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
actual_scale = int(scale_factor.rstrip("x"))
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
|
||||
@ -191,6 +228,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@ -260,8 +298,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
|
||||
expr="""
|
||||
(
|
||||
$max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
|
||||
{"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
|
||||
$mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
|
||||
$maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
|
||||
{
|
||||
"type": "range_usd",
|
||||
"min_usd": $lookup($mins, widgets.scale_factor),
|
||||
"max_usd": $lookup($maxs, widgets.scale_factor),
|
||||
"format": { "approximate": true }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -324,6 +368,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
f"Use a smaller input image or lower scale factor."
|
||||
)
|
||||
|
||||
final_height, final_width = get_image_dimensions(image)
|
||||
price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
|
||||
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
|
||||
@ -342,6 +389,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
|
||||
response_model=TaskResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
price_extractor=lambda _: price_usd,
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=480,
|
||||
)
|
||||
@ -885,8 +933,8 @@ class MagnificExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
# MagnificImageUpscalerCreativeNode,
|
||||
# MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageUpscalerCreativeNode,
|
||||
MagnificImageUpscalerPreciseV2Node,
|
||||
MagnificImageStyleTransferNode,
|
||||
MagnificImageRelightNode,
|
||||
MagnificImageSkinEnhancerNode,
|
||||
|
||||
@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
default=80,
|
||||
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Number of denoising steps",
|
||||
@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
default=60,
|
||||
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
|
||||
max=100,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: int | None = 100,
|
||||
steps=33,
|
||||
steps=60,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=33,
|
||||
min=1,
|
||||
default=80,
|
||||
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
|
||||
max=100,
|
||||
step=1,
|
||||
tooltip="Inference steps",
|
||||
|
||||
@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
|
||||
o1 = "o1"
|
||||
o3 = "o3"
|
||||
o1_pro = "o1-pro"
|
||||
gpt_4o = "gpt-4o"
|
||||
gpt_4_1 = "gpt-4.1"
|
||||
gpt_4_1_mini = "gpt-4.1-mini"
|
||||
gpt_4_1_nano = "gpt-4.1-nano"
|
||||
@ -576,6 +575,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
node_id="OpenAIChatNode",
|
||||
display_name="OpenAI ChatGPT",
|
||||
category="api node/text/OpenAI",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses from an OpenAI model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -650,11 +650,6 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
"usd": [0.01, 0.04],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4o") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0025, 0.01],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
: $contains($m, "gpt-4.1-nano") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.0001, 0.0004],
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
@ -9,6 +8,8 @@ from typing_extensions import override
|
||||
from comfy.utils import ProgressBar
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.recraft import (
|
||||
RECRAFT_V4_PRO_SIZES,
|
||||
RECRAFT_V4_SIZES,
|
||||
RecraftColor,
|
||||
RecraftColorChain,
|
||||
RecraftControls,
|
||||
@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
|
||||
RecraftImageGenerationResponse,
|
||||
RecraftImageSize,
|
||||
RecraftIO,
|
||||
RecraftModel,
|
||||
RecraftStyle,
|
||||
RecraftStyleV3,
|
||||
get_v3_substyles,
|
||||
@ -39,7 +39,7 @@ async def handle_recraft_file_request(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
path: str,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
mask: torch.Tensor | None = None,
|
||||
total_pixels: int = 4096 * 4096,
|
||||
timeout: int = 1024,
|
||||
request=None,
|
||||
@ -73,11 +73,11 @@ async def handle_recraft_file_request(
|
||||
def recraft_multipart_parser(
|
||||
data,
|
||||
parent_key=None,
|
||||
formatter: Optional[type[callable]] = None,
|
||||
converted_to_check: Optional[list[list]] = None,
|
||||
formatter: type[callable] | None = None,
|
||||
converted_to_check: list[list] | None = None,
|
||||
is_list: bool = False,
|
||||
return_mode: str = "formdata", # "dict" | "formdata"
|
||||
) -> Union[dict, aiohttp.FormData]:
|
||||
) -> dict | aiohttp.FormData:
|
||||
"""
|
||||
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
|
||||
|
||||
@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
||||
display_name="Recraft Style - Infinite Style Library",
|
||||
category="api node/image/Recraft",
|
||||
description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
inputs=[
|
||||
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
||||
],
|
||||
@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
data=RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
size=size,
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
||||
request = RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
n=n,
|
||||
strength=round(strength, 2),
|
||||
style=recraft_style.style,
|
||||
@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
||||
request = RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
substyle=recraft_style.substyle,
|
||||
@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
||||
data=RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
size=size,
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
||||
request = RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model=RecraftModel.recraftv3,
|
||||
model="recraftv3",
|
||||
n=n,
|
||||
style=recraft_style.style,
|
||||
substyle=recraft_style.substyle,
|
||||
@ -963,6 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
|
||||
node_id="RecraftRemoveBackgroundNode",
|
||||
display_name="Recraft Remove Background",
|
||||
category="api node/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Remove background from image, and return processed image and mask.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -1078,6 +1079,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
||||
)
|
||||
|
||||
|
||||
class RecraftV4TextToImageNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToImageNode",
|
||||
display_name="Recraft V4 Text to Image",
|
||||
category="api node/image/Recraft",
|
||||
description="Generates images using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
tooltip="An optional text description of undesired elements on an image.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"recraftv4",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=RECRAFT_V4_SIZES,
|
||||
default="1024x1024",
|
||||
tooltip="The size of the generated image.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"recraftv4_pro",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=RECRAFT_V4_PRO_SIZES,
|
||||
default="2048x2048",
|
||||
tooltip="The size of the generated image.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"n",
|
||||
default=1,
|
||||
min=1,
|
||||
max=6,
|
||||
tooltip="The number of images to generate.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Custom(RecraftIO.CONTROLS).Input(
|
||||
"recraft_controls",
|
||||
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
model: dict,
|
||||
n: int,
|
||||
seed: int,
|
||||
recraft_controls: RecraftControls | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
|
||||
response_model=RecraftImageGenerationResponse,
|
||||
data=RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
model=model["model"],
|
||||
size=model["size"],
|
||||
n=n,
|
||||
controls=recraft_controls.create_api_model() if recraft_controls else None,
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
images = []
|
||||
for data in response.data:
|
||||
with handle_recraft_image_output():
|
||||
image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
|
||||
if len(image.shape) < 4:
|
||||
image = image.unsqueeze(0)
|
||||
images.append(image)
|
||||
return IO.NodeOutput(torch.cat(images, dim=0))
|
||||
|
||||
|
||||
class RecraftV4TextToVectorNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToVectorNode",
|
||||
display_name="Recraft V4 Text to Vector",
|
||||
category="api node/image/Recraft",
|
||||
description="Generates SVG using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
tooltip="An optional text description of undesired elements on an image.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"recraftv4",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=RECRAFT_V4_SIZES,
|
||||
default="1024x1024",
|
||||
tooltip="The size of the generated image.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"recraftv4_pro",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"size",
|
||||
options=RECRAFT_V4_PRO_SIZES,
|
||||
default="2048x2048",
|
||||
tooltip="The size of the generated image.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="The model to use for generation.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"n",
|
||||
default=1,
|
||||
min=1,
|
||||
max=6,
|
||||
tooltip="The number of images to generate.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Custom(RecraftIO.CONTROLS).Input(
|
||||
"recraft_controls",
|
||||
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.SVG.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
model: dict,
|
||||
n: int,
|
||||
seed: int,
|
||||
recraft_controls: RecraftControls | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
|
||||
response_model=RecraftImageGenerationResponse,
|
||||
data=RecraftImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
model=model["model"],
|
||||
size=model["size"],
|
||||
n=n,
|
||||
style="vector_illustration",
|
||||
substyle=None,
|
||||
controls=recraft_controls.create_api_model() if recraft_controls else None,
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
svg_data = []
|
||||
for data in response.data:
|
||||
svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
|
||||
return IO.NodeOutput(SVG(svg_data))
|
||||
|
||||
|
||||
class RecraftExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1098,6 +1345,8 @@ class RecraftExtension(ComfyExtension):
|
||||
RecraftCreateStyleNode,
|
||||
RecraftColorRGBNode,
|
||||
RecraftControlsNode,
|
||||
RecraftV4TextToImageNode,
|
||||
RecraftV4TextToVectorNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -505,6 +505,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.4}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -631,6 +631,7 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -54,6 +54,7 @@ async def execute_task(
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
|
||||
max_poll_attempts=max_poll_attempts,
|
||||
)
|
||||
if not response.creations:
|
||||
@ -1320,6 +1321,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-turbo",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "3:4", "4:3", "1:1"],
|
||||
tooltip="The aspect ratio of the output video.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
@ -1348,13 +1379,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
|
||||
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
||||
$d := $lookup(widgets, "model.duration");
|
||||
$contains(widgets.model, "turbo")
|
||||
? (
|
||||
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
: (
|
||||
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -1423,6 +1461,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-turbo",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
@ -1456,13 +1519,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
|
||||
$perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
|
||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
||||
$d := $lookup(widgets, "model.duration");
|
||||
$contains(widgets.model, "turbo")
|
||||
? (
|
||||
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
: (
|
||||
$rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -1495,6 +1565,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class Vidu3StartEndToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3StartEndToVideoNode",
|
||||
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-pro",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"viduq3-turbo",
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720p", "1080p"],
|
||||
tooltip="Resolution of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=1,
|
||||
max=16,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"audio",
|
||||
default=False,
|
||||
tooltip="When enabled, outputs video with sound "
|
||||
"(including dialogue and sound effects).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model to use for video generation.",
|
||||
),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input("end_frame"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Prompt description (max 2000 characters).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=1,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$d := $lookup(widgets, "model.duration");
|
||||
$contains(widgets.model, "turbo")
|
||||
? (
|
||||
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
: (
|
||||
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
|
||||
{"type":"usd","usd": $rate * $d}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: dict,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=2000)
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model=model["model"],
|
||||
prompt=prompt,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
resolution=model["resolution"],
|
||||
audio=model["audio"],
|
||||
images=[
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
],
|
||||
)
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
|
||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||
|
||||
|
||||
class ViduExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1511,6 +1720,7 @@ class ViduExtension(ComfyExtension):
|
||||
ViduMultiFrameVideoNode,
|
||||
Vidu3TextToVideoNode,
|
||||
Vidu3ImageToVideoNode,
|
||||
Vidu3StartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from .download_helpers import (
|
||||
download_url_to_video_output,
|
||||
)
|
||||
from .upload_helpers import (
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_file_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
@ -62,6 +63,7 @@ __all__ = [
|
||||
"sync_op",
|
||||
"sync_op_raw",
|
||||
# Upload helpers
|
||||
"upload_3d_model_to_comfyapi",
|
||||
"upload_audio_to_comfyapi",
|
||||
"upload_file_to_comfyapi",
|
||||
"upload_image_to_comfyapi",
|
||||
|
||||
@ -57,6 +57,7 @@ class _RequestConfig:
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
max_retries_on_rate_limit: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
@ -65,6 +66,7 @@ class _RequestConfig:
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -78,7 +80,7 @@ class _PollUIState:
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
@ -103,6 +105,8 @@ async def sync_op(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
@ -122,6 +126,8 @@ async def sync_op(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
@ -143,9 +149,9 @@ async def poll_op(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@ -194,6 +200,8 @@ async def sync_op_raw(
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
@ -222,6 +230,8 @@ async def sync_op_raw(
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@ -240,9 +250,9 @@ async def poll_op_raw(
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 160,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
max_retries_per_poll: int = 10,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
retry_backoff_per_poll: float = 1.4,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
rate_limit_attempts = 0
|
||||
rate_limit_delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
should_retry = False
|
||||
wait_time = 0.0
|
||||
retry_label = ""
|
||||
is_rl = resp.status == 429 or (
|
||||
cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
|
||||
)
|
||||
if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
|
||||
rate_limit_attempts += 1
|
||||
wait_time = min(rate_limit_delay, 30.0)
|
||||
rate_limit_delay *= cfg.retry_backoff
|
||||
retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
|
||||
should_retry = True
|
||||
elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
wait_time = delay
|
||||
delay *= cfg.retry_backoff
|
||||
retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
wait_time,
|
||||
retry_label,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
wait_time,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= cfg.max_retries:
|
||||
if (attempt - rate_limit_attempts) <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt,
|
||||
attempt - rate_limit_attempts,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
|
||||
@ -57,7 +57,7 @@ def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
total_pixels: int | None = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
mime_type: str | None = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
|
||||
@ -167,27 +167,25 @@ async def download_url_to_bytesio(
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
@ -8,7 +8,6 @@ from typing import Any
|
||||
|
||||
import folder_paths
|
||||
|
||||
# Get the logger instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -91,38 +90,41 @@ def log_request_response(
|
||||
Filenames are sanitized and length-limited for cross-platform safety.
|
||||
If we still fail to write, we fall back to appending into api.log.
|
||||
"""
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
log_dir = get_log_directory()
|
||||
filepath = _build_log_filepath(log_dir, operation_id, request_url)
|
||||
|
||||
log_content: list[str] = []
|
||||
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
|
||||
log_content.append(f"Operation ID: {operation_id}")
|
||||
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
|
||||
log_content.append(f"Method: {request_method}")
|
||||
log_content.append(f"URL: {request_url}")
|
||||
if request_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
|
||||
if request_params:
|
||||
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
|
||||
if request_data is not None:
|
||||
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
|
||||
|
||||
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
|
||||
if response_status_code is not None:
|
||||
log_content.append(f"Status Code: {response_status_code}")
|
||||
if response_headers:
|
||||
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
|
||||
if response_content is not None:
|
||||
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
|
||||
if error_message:
|
||||
log_content.append(f"Error:\n{error_message}")
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(log_content))
|
||||
logger.debug("API log saved to: %s", filepath)
|
||||
except Exception as e:
|
||||
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
|
||||
|
||||
|
||||
_3D_MIME_TYPES = {
|
||||
"glb": "model/gltf-binary",
|
||||
"obj": "model/obj",
|
||||
"fbx": "application/octet-stream",
|
||||
}
|
||||
|
||||
|
||||
async def upload_3d_model_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
model_3d: Types.File3D,
|
||||
file_format: str,
|
||||
) -> str:
|
||||
"""Uploads a 3D model file to ComfyUI API and returns its download URL."""
|
||||
return await upload_file_to_comfyapi(
|
||||
cls,
|
||||
model_3d.get_data(),
|
||||
f"{uuid.uuid4()}.{file_format}",
|
||||
_3D_MIME_TYPES.get(file_format, "application/octet-stream"),
|
||||
)
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
file_bytes_io: BytesIO,
|
||||
@ -255,17 +276,14 @@ async def upload_file(
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
@ -311,31 +329,27 @@ async def upload_file(
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
|
||||
@ -20,10 +20,60 @@ class JobStatus:
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
|
||||
|
||||
# 3D file extensions for preview fallback (no dedicated media_type exists)
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
|
||||
THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
|
||||
|
||||
|
||||
def has_3d_extension(filename: str) -> bool:
|
||||
lower = filename.lower()
|
||||
return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
|
||||
|
||||
|
||||
def normalize_output_item(item):
|
||||
"""Normalize a single output list item for the jobs API.
|
||||
|
||||
Returns the normalized item, or None to exclude it.
|
||||
String items with 3D extensions become {filename, type, subfolder} dicts.
|
||||
"""
|
||||
if item is None:
|
||||
return None
|
||||
if isinstance(item, str):
|
||||
if has_3d_extension(item):
|
||||
return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
|
||||
return None
|
||||
if isinstance(item, dict):
|
||||
return item
|
||||
return None
|
||||
|
||||
|
||||
def normalize_outputs(outputs: dict) -> dict:
|
||||
"""Normalize raw node outputs for the jobs API.
|
||||
|
||||
Transforms string 3D filenames into file output dicts and removes
|
||||
None items. All other items (non-3D strings, dicts, etc.) are
|
||||
preserved as-is.
|
||||
"""
|
||||
normalized = {}
|
||||
for node_id, node_outputs in outputs.items():
|
||||
if not isinstance(node_outputs, dict):
|
||||
normalized[node_id] = node_outputs
|
||||
continue
|
||||
normalized_node = {}
|
||||
for media_type, items in node_outputs.items():
|
||||
if media_type == 'animated' or not isinstance(items, list):
|
||||
normalized_node[media_type] = items
|
||||
continue
|
||||
normalized_items = []
|
||||
for item in items:
|
||||
if item is None:
|
||||
continue
|
||||
norm = normalize_output_item(item)
|
||||
normalized_items.append(norm if norm is not None else item)
|
||||
normalized_node[media_type] = normalized_items
|
||||
normalized[node_id] = normalized_node
|
||||
return normalized
|
||||
|
||||
|
||||
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
|
||||
@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
|
||||
Maintains backwards compatibility with existing logic.
|
||||
|
||||
Priority:
|
||||
1. media_type is 'images', 'video', or 'audio'
|
||||
1. media_type is 'images', 'video', 'audio', or '3d'
|
||||
2. format field starts with 'video/' or 'audio/'
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
|
||||
3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
|
||||
"""
|
||||
if media_type in PREVIEWABLE_MEDIA_TYPES:
|
||||
return True
|
||||
@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
|
||||
})
|
||||
|
||||
if include_outputs:
|
||||
job['outputs'] = outputs
|
||||
job['outputs'] = normalize_outputs(outputs)
|
||||
job['execution_status'] = status_info
|
||||
job['workflow'] = {
|
||||
'prompt': prompt,
|
||||
@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
|
||||
continue
|
||||
|
||||
for item in items:
|
||||
count += 1
|
||||
|
||||
if not isinstance(item, dict):
|
||||
normalized = normalize_output_item(item)
|
||||
if normalized is None:
|
||||
continue
|
||||
|
||||
if preview_output is None and is_previewable(media_type, item):
|
||||
count += 1
|
||||
|
||||
if preview_output is not None:
|
||||
continue
|
||||
|
||||
if isinstance(normalized, dict) and is_previewable(media_type, normalized):
|
||||
enriched = {
|
||||
**item,
|
||||
**normalized,
|
||||
'nodeId': node_id,
|
||||
'mediaType': media_type
|
||||
}
|
||||
if item.get('type') == 'output':
|
||||
if 'mediaType' not in normalized:
|
||||
enriched['mediaType'] = media_type
|
||||
if normalized.get('type') == 'output':
|
||||
preview_output = enriched
|
||||
elif fallback_preview is None:
|
||||
fallback_preview = enriched
|
||||
|
||||
@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
|
||||
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
|
||||
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
|
||||
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
|
||||
io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
|
||||
],
|
||||
outputs=[io.Conditioning.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
@ -159,6 +159,7 @@ class SaveAudio(IO.ComfyNode):
|
||||
search_aliases=["export flac"],
|
||||
display_name="Save Audio (FLAC)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
@ -300,6 +301,7 @@ class LoadAudio(IO.ComfyNode):
|
||||
search_aliases=["import audio", "open audio", "audio file"],
|
||||
display_name="Load Audio",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
|
||||
],
|
||||
@ -700,6 +702,67 @@ class EmptyAudio(IO.ComfyNode):
|
||||
create_empty_audio = execute # TODO: remove
|
||||
|
||||
|
||||
class AudioEqualizer3Band(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioEqualizer3Band",
|
||||
search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
|
||||
display_name="Audio Equalizer (3-Band)",
|
||||
category="audio",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
|
||||
IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
|
||||
IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
|
||||
IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
|
||||
IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
|
||||
IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
|
||||
IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
eq_waveform = waveform.clone()
|
||||
|
||||
# 1. Apply Low Shelf (Bass)
|
||||
if low_gain_dB != 0:
|
||||
eq_waveform = torchaudio.functional.bass_biquad(
|
||||
eq_waveform,
|
||||
sample_rate,
|
||||
gain=low_gain_dB,
|
||||
central_freq=float(low_freq),
|
||||
Q=0.707
|
||||
)
|
||||
|
||||
# 2. Apply Peaking EQ (Mids)
|
||||
if mid_gain_dB != 0:
|
||||
eq_waveform = torchaudio.functional.equalizer_biquad(
|
||||
eq_waveform,
|
||||
sample_rate,
|
||||
center_freq=float(mid_freq),
|
||||
gain=mid_gain_dB,
|
||||
Q=mid_q
|
||||
)
|
||||
|
||||
# 3. Apply High Shelf (Treble)
|
||||
if high_gain_dB != 0:
|
||||
eq_waveform = torchaudio.functional.treble_biquad(
|
||||
eq_waveform,
|
||||
sample_rate,
|
||||
gain=high_gain_dB,
|
||||
central_freq=float(high_freq),
|
||||
Q=0.707
|
||||
)
|
||||
|
||||
return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate})
|
||||
|
||||
|
||||
class AudioExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -722,6 +785,7 @@ class AudioExtension(ComfyExtension):
|
||||
AudioMerge,
|
||||
AudioAdjustVolume,
|
||||
EmptyAudio,
|
||||
AudioEqualizer3Band,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> AudioExtension:
|
||||
|
||||
@ -12,6 +12,7 @@ class Canny(io.ComfyNode):
|
||||
node_id="Canny",
|
||||
search_aliases=["edge detection", "outline", "contour detection", "line art"],
|
||||
category="image/preprocessors",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
|
||||
|
||||
@ -622,6 +622,7 @@ class SamplerSASolver(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSASolver",
|
||||
search_aliases=["sde"],
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -666,6 +667,7 @@ class SamplerSEEDS2(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
search_aliases=["sde", "exp heun"],
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
|
||||
@ -108,7 +108,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
x: torch.Tensor = _extract_tensor(args[0], easycache.output_channels)
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
# prepare next x_prev
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
|
||||
@ -621,6 +621,7 @@ class SaveGLB(IO.ComfyNode):
|
||||
display_name="Save 3D Model",
|
||||
search_aliases=["export 3d model", "save mesh"],
|
||||
category="3d",
|
||||
essentials_category="Basics",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
|
||||
@ -23,8 +23,10 @@ class ImageCrop(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ImageCrop",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
display_name="Image Crop (Deprecated)",
|
||||
category="image/transform",
|
||||
is_deprecated=True,
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
@ -47,6 +49,57 @@ class ImageCrop(IO.ComfyNode):
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageCropV2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCropV2",
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, crop_region) -> IO.NodeOutput:
|
||||
x = crop_region.get("x", 0)
|
||||
y = crop_region.get("y", 0)
|
||||
width = crop_region.get("width", 512)
|
||||
height = crop_region.get("height", 512)
|
||||
|
||||
x = min(x, image.shape[2] - 1)
|
||||
y = min(y, image.shape[1] - 1)
|
||||
to_x = width + x
|
||||
to_y = height + y
|
||||
img = image[:,y:to_y, x:to_x, :]
|
||||
return IO.NodeOutput(img, ui=UI.PreviewImage(img))
|
||||
|
||||
|
||||
class BoundingBox(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PrimitiveBoundingBox",
|
||||
display_name="Bounding Box",
|
||||
category="utils/primitive",
|
||||
inputs=[
|
||||
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
|
||||
],
|
||||
outputs=[IO.BoundingBox.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, x, y, width, height) -> IO.NodeOutput:
|
||||
return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
|
||||
|
||||
|
||||
class RepeatImageBatch(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -537,6 +590,7 @@ class ImageRotate(IO.ComfyNode):
|
||||
node_id="ImageRotate",
|
||||
search_aliases=["turn", "flip orientation"],
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
|
||||
@ -632,6 +686,8 @@ class ImagesExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ImageCrop,
|
||||
ImageCropV2,
|
||||
BoundingBox,
|
||||
RepeatImageBatch,
|
||||
ImageFromBatch,
|
||||
ImageAddNoise,
|
||||
|
||||
@ -391,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||
normalized_latent = latent / latent_vector_magnitude
|
||||
|
||||
mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True)
|
||||
dims = list(range(1, latent_vector_magnitude.ndim))
|
||||
mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True)
|
||||
|
||||
top = (std * 5 + mean) * multiplier
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@ class Load3D(IO.ComfyNode):
|
||||
node_id="Load3D",
|
||||
display_name="Load 3D & Animation",
|
||||
category="3d",
|
||||
essentials_category="Basics",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
|
||||
|
||||
@ -7,6 +7,7 @@ import logging
|
||||
from enum import Enum
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from tqdm.auto import trange
|
||||
|
||||
CLAMP_QUANTILE = 0.99
|
||||
|
||||
@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
|
||||
"full_diff": LORAType.FULL_DIFF}
|
||||
|
||||
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
|
||||
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
|
||||
comfy.model_management.load_models_gpu([model_diff])
|
||||
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
|
||||
|
||||
for k in sd:
|
||||
if k.endswith(".weight"):
|
||||
sd_keys = list(sd.keys())
|
||||
for index in trange(len(sd_keys), unit="weight"):
|
||||
k = sd_keys[index]
|
||||
op_keys = sd_keys[index].rsplit('.', 1)
|
||||
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
|
||||
continue
|
||||
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
|
||||
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
|
||||
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
|
||||
else:
|
||||
weight_diff = sd[k]
|
||||
|
||||
if op_keys[1] == "weight":
|
||||
if lora_type == LORAType.STANDARD:
|
||||
if weight_diff.ndim < 2:
|
||||
if bias_diff:
|
||||
@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
|
||||
elif lora_type == LORAType.FULL_DIFF:
|
||||
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
|
||||
|
||||
elif bias_diff and k.endswith(".bias"):
|
||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||
elif bias_diff and op_keys[1] == "bias":
|
||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
|
||||
return output_sd
|
||||
|
||||
class LoraSave(io.ComfyNode):
|
||||
|
||||
99
comfy_extras/nodes_nag.py
Normal file
99
comfy_extras/nodes_nag.py
Normal file
@ -0,0 +1,99 @@
|
||||
import torch
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class NAGuidance(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="NAGuidance",
|
||||
display_name="Normalized Attention Guidance",
|
||||
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply NAG to."),
|
||||
io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
|
||||
io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
|
||||
io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
|
||||
# io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
|
||||
# io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with NAG enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
|
||||
# sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
|
||||
# sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
|
||||
|
||||
def nag_attention_output_patch(out, extra_options):
|
||||
cond_or_uncond = extra_options.get("cond_or_uncond", None)
|
||||
if cond_or_uncond is None:
|
||||
return out
|
||||
|
||||
if not (1 in cond_or_uncond and 0 in cond_or_uncond):
|
||||
return out
|
||||
|
||||
# sigma = extra_options.get("sigmas", None)
|
||||
# if sigma is not None and len(sigma) > 0:
|
||||
# sigma = sigma[0].item()
|
||||
# if sigma > sigma_start or sigma < sigma_end:
|
||||
# return out
|
||||
|
||||
img_slice = extra_options.get("img_slice", None)
|
||||
|
||||
if img_slice is not None:
|
||||
orig_out = out
|
||||
out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
|
||||
|
||||
batch_size = out.shape[0]
|
||||
half_size = batch_size // len(cond_or_uncond)
|
||||
|
||||
ind_neg = cond_or_uncond.index(1)
|
||||
ind_pos = cond_or_uncond.index(0)
|
||||
z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
|
||||
z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
|
||||
|
||||
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
|
||||
|
||||
eps = 1e-6
|
||||
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
|
||||
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
|
||||
|
||||
ratio = norm_guided / norm_pos
|
||||
scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
|
||||
|
||||
guided_normalized = guided * scale_factor
|
||||
|
||||
z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
|
||||
|
||||
if img_slice is not None:
|
||||
orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
|
||||
orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
|
||||
return orig_out
|
||||
else:
|
||||
out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
|
||||
return out
|
||||
|
||||
m.set_model_attn1_output_patch(nag_attention_output_patch)
|
||||
m.disable_model_cfg1_optimization()
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class NagExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
NAGuidance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> NagExtension:
|
||||
return NagExtension()
|
||||
@ -77,6 +77,7 @@ class Blur(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ImageBlur",
|
||||
category="image/postprocessing",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("blur_radius", default=1, min=1, max=31, step=1),
|
||||
@ -655,6 +656,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
|
||||
batched = batch_masks(values)
|
||||
return io.NodeOutput(batched)
|
||||
|
||||
|
||||
class PostProcessingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
|
||||
103
comfy_extras/nodes_replacements.py
Normal file
103
comfy_extras/nodes_replacements.py
Normal file
@ -0,0 +1,103 @@
|
||||
from comfy_api.latest import ComfyExtension, io, ComfyAPI
|
||||
|
||||
api = ComfyAPI()
|
||||
|
||||
|
||||
async def register_replacements():
|
||||
"""Register all built-in node replacements."""
|
||||
await register_replacements_longeredge()
|
||||
await register_replacements_batchimages()
|
||||
await register_replacements_upscaleimage()
|
||||
await register_replacements_controlnet()
|
||||
await register_replacements_load3d()
|
||||
await register_replacements_preview3d()
|
||||
await register_replacements_svdimg2vid()
|
||||
await register_replacements_conditioningavg()
|
||||
|
||||
async def register_replacements_longeredge():
|
||||
# No dynamic inputs here
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="ImageScaleToMaxDimension",
|
||||
old_node_id="ResizeImagesByLongerEdge",
|
||||
old_widget_ids=["longer_edge"],
|
||||
input_mapping=[
|
||||
{"new_id": "image", "old_id": "images"},
|
||||
{"new_id": "largest_size", "old_id": "longer_edge"},
|
||||
{"new_id": "upscale_method", "set_value": "lanczos"},
|
||||
],
|
||||
# just to test the frontend output_mapping code, does nothing really here
|
||||
output_mapping=[{"new_idx": 0, "old_idx": 0}],
|
||||
))
|
||||
|
||||
async def register_replacements_batchimages():
|
||||
# BatchImages node uses Autogrow
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="BatchImagesNode",
|
||||
old_node_id="ImageBatch",
|
||||
input_mapping=[
|
||||
{"new_id": "images.image0", "old_id": "image1"},
|
||||
{"new_id": "images.image1", "old_id": "image2"},
|
||||
],
|
||||
))
|
||||
|
||||
async def register_replacements_upscaleimage():
|
||||
# ResizeImageMaskNode uses DynamicCombo
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="ResizeImageMaskNode",
|
||||
old_node_id="ImageScaleBy",
|
||||
old_widget_ids=["upscale_method", "scale_by"],
|
||||
input_mapping=[
|
||||
{"new_id": "input", "old_id": "image"},
|
||||
{"new_id": "resize_type", "set_value": "scale by multiplier"},
|
||||
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
|
||||
{"new_id": "scale_method", "old_id": "upscale_method"},
|
||||
],
|
||||
))
|
||||
|
||||
async def register_replacements_controlnet():
|
||||
# T2IAdapterLoader → ControlNetLoader
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="ControlNetLoader",
|
||||
old_node_id="T2IAdapterLoader",
|
||||
input_mapping=[
|
||||
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
|
||||
],
|
||||
))
|
||||
|
||||
async def register_replacements_load3d():
|
||||
# Load3DAnimation merged into Load3D
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="Load3D",
|
||||
old_node_id="Load3DAnimation",
|
||||
))
|
||||
|
||||
async def register_replacements_preview3d():
|
||||
# Preview3DAnimation merged into Preview3D
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="Preview3D",
|
||||
old_node_id="Preview3DAnimation",
|
||||
))
|
||||
|
||||
async def register_replacements_svdimg2vid():
|
||||
# Typo fix: SDV → SVD
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="SVD_img2vid_Conditioning",
|
||||
old_node_id="SDV_img2vid_Conditioning",
|
||||
))
|
||||
|
||||
async def register_replacements_conditioningavg():
|
||||
# Typo fix: trailing space in node name
|
||||
await api.node_replacement.register(io.NodeReplace(
|
||||
new_node_id="ConditioningAverage",
|
||||
old_node_id="ConditioningAverage ",
|
||||
))
|
||||
|
||||
class NodeReplacementsExtension(ComfyExtension):
|
||||
async def on_load(self) -> None:
|
||||
await register_replacements()
|
||||
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return []
|
||||
|
||||
async def comfy_entrypoint() -> NodeReplacementsExtension:
|
||||
return NodeReplacementsExtension()
|
||||
176
comfy_extras/nodes_textgen.py
Normal file
176
comfy_extras/nodes_textgen.py
Normal file
@ -0,0 +1,176 @@
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
class TextGenerate(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
# Define dynamic combo options for sampling mode
|
||||
sampling_options = [
|
||||
io.DynamicCombo.Option(
|
||||
key="on",
|
||||
inputs=[
|
||||
io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001),
|
||||
io.Int.Input("top_k", default=64, min=0, max=1000),
|
||||
io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
|
||||
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
|
||||
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
|
||||
]
|
||||
),
|
||||
io.DynamicCombo.Option(
|
||||
key="off",
|
||||
inputs=[]
|
||||
),
|
||||
]
|
||||
|
||||
return io.Schema(
|
||||
node_id="TextGenerate",
|
||||
category="textgen/",
|
||||
search_aliases=["LLM", "gemma"],
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
|
||||
io.Image.Input("image", optional=True),
|
||||
io.Int.Input("max_length", default=256, min=1, max=2048),
|
||||
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
|
||||
],
|
||||
outputs=[
|
||||
io.String.Output(display_name="generated_text"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
||||
|
||||
tokens = clip.tokenize(prompt, image=image, skip_template=False)
|
||||
|
||||
# Get sampling parameters from dynamic combo
|
||||
do_sample = sampling_mode.get("sampling_mode") == "on"
|
||||
temperature = sampling_mode.get("temperature", 1.0)
|
||||
top_k = sampling_mode.get("top_k", 50)
|
||||
top_p = sampling_mode.get("top_p", 1.0)
|
||||
min_p = sampling_mode.get("min_p", 0.0)
|
||||
seed = sampling_mode.get("seed", None)
|
||||
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
|
||||
|
||||
generated_ids = clip.generate(
|
||||
tokens,
|
||||
do_sample=do_sample,
|
||||
max_length=max_length,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
seed=seed
|
||||
)
|
||||
|
||||
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
|
||||
return io.NodeOutput(generated_text)
|
||||
|
||||
|
||||
LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||
#### Guidelines
|
||||
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
|
||||
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
|
||||
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
|
||||
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
|
||||
- Speech (only when requested):
|
||||
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
|
||||
- Specify language if not English and accent if relevant.
|
||||
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
|
||||
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
|
||||
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
|
||||
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
|
||||
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
|
||||
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
|
||||
|
||||
#### Important notes:
|
||||
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
|
||||
- Camera motion: DO NOT invent camera motion unless requested by the user.
|
||||
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
|
||||
- Format: DO NOT start your response with special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single continuous paragraph in natural language (English).
|
||||
- NO titles, headings, prefaces, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
|
||||
|
||||
#### Example
|
||||
Input: "A woman at a coffee shop talking on the phone"
|
||||
Output:
|
||||
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the space—espresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
|
||||
"""
|
||||
|
||||
LTX2_I2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
|
||||
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
|
||||
|
||||
#### Guidelines:
|
||||
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
|
||||
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
|
||||
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
|
||||
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
|
||||
- Chronological flow: Use temporal connectors ("as," "then," "while").
|
||||
- Audio layer: Describe complete soundscape throughout the prompt alongside actions—NOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
|
||||
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
|
||||
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
|
||||
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
|
||||
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
|
||||
|
||||
#### Important notes:
|
||||
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
|
||||
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
|
||||
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
|
||||
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
|
||||
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
|
||||
- Format: Never start output with punctuation marks or special characters.
|
||||
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
|
||||
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
|
||||
|
||||
#### Output Format (Strict):
|
||||
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
|
||||
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
|
||||
|
||||
#### Example output:
|
||||
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
|
||||
"""
|
||||
|
||||
class TextGenerateLTX2Prompt(TextGenerate):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
parent_schema = super().define_schema()
|
||||
return io.Schema(
|
||||
node_id="TextGenerateLTX2Prompt",
|
||||
category=parent_schema.category,
|
||||
inputs=parent_schema.inputs,
|
||||
outputs=parent_schema.outputs,
|
||||
search_aliases=["prompt enhance", "LLM", "gemma"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
|
||||
if image is None:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
else:
|
||||
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
|
||||
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
|
||||
|
||||
|
||||
class TextgenExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextGenerate,
|
||||
TextGenerateLTX2Prompt,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> TextgenExtension:
|
||||
return TextgenExtension()
|
||||
@ -4,6 +4,7 @@ import os
|
||||
import numpy as np
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from tqdm.auto import trange
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
"""
|
||||
CFGGuider with modifications for training specific logic
|
||||
"""
|
||||
|
||||
def __init__(self, *args, offloading=False, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.offloading = offloading
|
||||
|
||||
def outer_sample(
|
||||
self,
|
||||
noise,
|
||||
@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
|
||||
noise.shape,
|
||||
self.conds,
|
||||
self.model_options,
|
||||
force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
|
||||
force_full_load=not self.offloading,
|
||||
force_offload=self.offloading,
|
||||
)
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
if denoise_mask is not None:
|
||||
@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
|
||||
return result
|
||||
|
||||
|
||||
def patch(m):
|
||||
def find_modules_at_depth(
|
||||
model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
|
||||
) -> list[nn.Module]:
|
||||
"""
|
||||
Find modules at a specific depth level for gradient checkpointing.
|
||||
|
||||
Args:
|
||||
model: The model to search
|
||||
depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
|
||||
result: Accumulator for results
|
||||
current_depth: Current recursion depth
|
||||
name: Current module name for logging
|
||||
|
||||
Returns:
|
||||
List of modules at the target depth
|
||||
"""
|
||||
if result is None:
|
||||
result = []
|
||||
name = name or "root"
|
||||
|
||||
# Skip container modules (they don't have meaningful forward)
|
||||
is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
|
||||
has_forward = hasattr(model, "forward") and not is_container
|
||||
|
||||
if has_forward:
|
||||
current_depth += 1
|
||||
if current_depth == depth:
|
||||
result.append(model)
|
||||
logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
|
||||
return result
|
||||
|
||||
# Recurse into children
|
||||
for next_name, child in model.named_children():
|
||||
find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OffloadCheckpointFunction(torch.autograd.Function):
|
||||
"""
|
||||
Gradient checkpointing that works with weight offloading.
|
||||
|
||||
Forward: no_grad -> compute -> weights can be freed
|
||||
Backward: enable_grad -> recompute -> backward -> weights can be freed
|
||||
|
||||
For single input, single output modules (Linear, Conv*).
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, x: torch.Tensor, forward_fn):
|
||||
ctx.save_for_backward(x)
|
||||
ctx.forward_fn = forward_fn
|
||||
with torch.no_grad():
|
||||
return forward_fn(x)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out: torch.Tensor):
|
||||
x, = ctx.saved_tensors
|
||||
forward_fn = ctx.forward_fn
|
||||
|
||||
# Clear context early
|
||||
ctx.forward_fn = None
|
||||
|
||||
with torch.enable_grad():
|
||||
x_detached = x.detach().requires_grad_(True)
|
||||
y = forward_fn(x_detached)
|
||||
y.backward(grad_out)
|
||||
grad_x = x_detached.grad
|
||||
|
||||
# Explicit cleanup
|
||||
del y, x_detached, forward_fn
|
||||
|
||||
return grad_x, None
|
||||
|
||||
|
||||
def patch(m, offloading=False):
|
||||
if not hasattr(m, "forward"):
|
||||
return
|
||||
org_forward = m.forward
|
||||
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
# Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
|
||||
if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
||||
def checkpointing_fwd(x):
|
||||
return OffloadCheckpointFunction.apply(x, org_forward)
|
||||
# Branch 2: Others -> standard checkpoint
|
||||
else:
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
|
||||
|
||||
m.org_forward = org_forward
|
||||
m.forward = checkpointing_fwd
|
||||
@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
|
||||
default=True,
|
||||
tooltip="Use gradient checkpointing for training.",
|
||||
),
|
||||
io.Int.Input(
|
||||
"checkpoint_depth",
|
||||
default=1,
|
||||
min=1,
|
||||
max=5,
|
||||
tooltip="Depth level for gradient checkpointing.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
options=folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
checkpoint_depth,
|
||||
offloading,
|
||||
existing_lora,
|
||||
bucket_mode,
|
||||
bypass_mode,
|
||||
@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype = lora_dtype[0]
|
||||
algorithm = algorithm[0]
|
||||
gradient_checkpointing = gradient_checkpointing[0]
|
||||
offloading = offloading[0]
|
||||
checkpoint_depth = checkpoint_depth[0]
|
||||
existing_lora = existing_lora[0]
|
||||
bucket_mode = bucket_mode[0]
|
||||
bypass_mode = bypass_mode[0]
|
||||
@ -1019,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
if mp.is_dynamic():
|
||||
if not bypass_mode:
|
||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
||||
bypass_mode = True
|
||||
offloading = True
|
||||
elif offloading:
|
||||
if not bypass_mode:
|
||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
@ -1054,16 +1168,18 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup gradient checkpointing
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(
|
||||
mp.model.diffusion_model
|
||||
):
|
||||
patch(m)
|
||||
modules_to_patch = find_modules_at_depth(
|
||||
mp.model.diffusion_model, depth=checkpoint_depth
|
||||
)
|
||||
logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
|
||||
for m in modules_to_patch:
|
||||
patch(m, offloading=offloading)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
# With force_full_load=False we should be able to have offloading
|
||||
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
|
||||
comfy.model_management.load_models_gpu(
|
||||
[mp], memory_required=1e20, force_full_load=True
|
||||
[mp], memory_required=1e20, force_full_load=not offloading
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -1100,7 +1216,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
guider = TrainGuider(mp)
|
||||
guider = TrainGuider(mp, offloading=offloading)
|
||||
guider.set_conds(positive)
|
||||
|
||||
# Inject bypass hooks if bypass mode is enabled
|
||||
@ -1113,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Run training loop
|
||||
try:
|
||||
comfy.model_management.in_training = True
|
||||
_run_training_loop(
|
||||
guider,
|
||||
train_sampler,
|
||||
@ -1123,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
multi_res,
|
||||
)
|
||||
finally:
|
||||
comfy.model_management.in_training = False
|
||||
# Eject bypass hooks if they were injected
|
||||
if bypass_injections is not None:
|
||||
for injection in bypass_injections:
|
||||
@ -1132,19 +1250,20 @@ class TrainLoraNode(io.ComfyNode):
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
|
||||
# Finalize adapters
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
del adapter
|
||||
del all_weight_adapters
|
||||
|
||||
# mp in train node is highly specialized for training
|
||||
# use it in inference will result in bad behavior so we don't return it
|
||||
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):#
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
@ -1166,6 +1285,11 @@ class LoraModelLoader(io.ComfyNode):#
|
||||
max=100.0,
|
||||
tooltip="How strongly to modify the diffusion model. This value can be negative.",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"bypass",
|
||||
default=False,
|
||||
tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(
|
||||
@ -1175,13 +1299,18 @@ class LoraModelLoader(io.ComfyNode):#
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, lora, strength_model):
|
||||
def execute(cls, model, lora, strength_model, bypass=False):
|
||||
if strength_model == 0:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
if bypass:
|
||||
model_lora, _ = comfy.sd.load_bypass_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
else:
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(
|
||||
model, None, lora, strength_model, 0
|
||||
)
|
||||
return io.NodeOutput(model_lora)
|
||||
|
||||
|
||||
|
||||
@ -73,6 +73,7 @@ class SaveVideo(io.ComfyNode):
|
||||
search_aliases=["export video"],
|
||||
display_name="Save Video",
|
||||
category="image/video",
|
||||
essentials_category="Basics",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
@ -146,6 +147,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="image/video",
|
||||
essentials_category="Video Tools",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
@ -174,6 +176,7 @@ class LoadVideo(io.ComfyNode):
|
||||
search_aliases=["import video", "open video", "video file"],
|
||||
display_name="Load Video",
|
||||
category="image/video",
|
||||
essentials_category="Basics",
|
||||
inputs=[
|
||||
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
|
||||
],
|
||||
@ -202,6 +205,56 @@ class LoadVideo(io.ComfyNode):
|
||||
|
||||
return True
|
||||
|
||||
class VideoSlice(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Video Slice",
|
||||
display_name="Video Slice",
|
||||
search_aliases=[
|
||||
"trim video duration",
|
||||
"skip first frames",
|
||||
"frame load cap",
|
||||
"start time",
|
||||
],
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Video.Input("video"),
|
||||
io.Float.Input(
|
||||
"start_time",
|
||||
default=0.0,
|
||||
max=1e5,
|
||||
min=-1e5,
|
||||
step=0.001,
|
||||
tooltip="Start time in seconds",
|
||||
),
|
||||
io.Float.Input(
|
||||
"duration",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
step=0.001,
|
||||
tooltip="Duration in seconds, or 0 for unlimited duration",
|
||||
),
|
||||
io.Boolean.Input(
|
||||
"strict_duration",
|
||||
default=False,
|
||||
tooltip="If True, when the specified duration is not possible, an error will be raised.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
|
||||
trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
|
||||
if trimmed is not None:
|
||||
return io.NodeOutput(trimmed)
|
||||
raise ValueError(
|
||||
f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
|
||||
)
|
||||
|
||||
|
||||
class VideoExtension(ComfyExtension):
|
||||
@override
|
||||
@ -212,6 +265,7 @@ class VideoExtension(ComfyExtension):
|
||||
CreateVideo,
|
||||
GetVideoComponents,
|
||||
LoadVideo,
|
||||
VideoSlice,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> VideoExtension:
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.12.3"
|
||||
__version__ = "0.14.1"
|
||||
|
||||
@ -13,8 +13,11 @@ from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
from latent_preview import set_preview_method
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||
finally:
|
||||
if allocator is not None:
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.model_vbar.vbars_analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
torch.cuda.synchronize()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
|
||||
if has_pending_tasks:
|
||||
pending_async_nodes[unique_id] = output_data
|
||||
@ -618,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
comfy.model_management.unload_all_models()
|
||||
elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
|
||||
tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected."
|
||||
|
||||
error_details = {
|
||||
"node_id": real_node_id,
|
||||
|
||||
13
nodes.py
13
nodes.py
@ -8,6 +8,7 @@ import json
|
||||
import glob
|
||||
import hashlib
|
||||
import inspect
|
||||
|
||||
import traceback
|
||||
import math
|
||||
import time
|
||||
@ -69,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
ESSENTIALS_CATEGORY = "Basics"
|
||||
DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images."
|
||||
SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"]
|
||||
|
||||
@ -667,6 +669,8 @@ class CLIPSetLastLayer:
|
||||
return (clip,)
|
||||
|
||||
class LoraLoader:
|
||||
ESSENTIALS_CATEGORY = "Image Generation"
|
||||
|
||||
def __init__(self):
|
||||
self.loaded_lora = None
|
||||
|
||||
@ -1648,6 +1652,7 @@ class SaveImage:
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "image"
|
||||
ESSENTIALS_CATEGORY = "Basics"
|
||||
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
|
||||
SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"]
|
||||
|
||||
@ -1706,6 +1711,7 @@ class LoadImage:
|
||||
}
|
||||
|
||||
CATEGORY = "image"
|
||||
ESSENTIALS_CATEGORY = "Basics"
|
||||
SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"]
|
||||
|
||||
RETURN_TYPES = ("IMAGE", "MASK")
|
||||
@ -1863,6 +1869,7 @@ class ImageScale:
|
||||
FUNCTION = "upscale"
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"]
|
||||
|
||||
def upscale(self, image, upscale_method, width, height, crop):
|
||||
@ -1902,6 +1909,7 @@ class ImageScaleBy:
|
||||
|
||||
class ImageInvert:
|
||||
SEARCH_ALIASES = ["reverse colors"]
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -1918,6 +1926,7 @@ class ImageInvert:
|
||||
|
||||
class ImageBatch:
|
||||
SEARCH_ALIASES = ["combine images", "merge images", "stack images"]
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -2264,6 +2273,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
|
||||
if not isinstance(extension, ComfyExtension):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
|
||||
return False
|
||||
await extension.on_load()
|
||||
node_list = await extension.get_node_list()
|
||||
if not isinstance(node_list, list):
|
||||
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
|
||||
@ -2433,8 +2443,11 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
"nodes_lora_debug.py",
|
||||
"nodes_textgen.py",
|
||||
"nodes_color.py",
|
||||
"nodes_toolkit.py",
|
||||
"nodes_replacements.py",
|
||||
"nodes_nag.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.12.3"
|
||||
version = "0.14.1"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.38.13
|
||||
comfyui-workflow-templates==0.8.31
|
||||
comfyui-embedded-docs==0.4.0
|
||||
comfyui-frontend-package==1.39.14
|
||||
comfyui-workflow-templates==0.8.43
|
||||
comfyui-embedded-docs==0.4.1
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -22,7 +22,7 @@ alembic
|
||||
SQLAlchemy
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.1.7
|
||||
comfy-aimdo>=0.1.8
|
||||
requests
|
||||
|
||||
#non essential dependencies:
|
||||
|
||||
@ -40,6 +40,7 @@ from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from app.node_replace_manager import NodeReplaceManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -204,6 +205,7 @@ class PromptServer():
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.node_replace_manager = NodeReplaceManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
@ -687,6 +689,10 @@ class PromptServer():
|
||||
info['api_node'] = obj_class.API_NODE
|
||||
|
||||
info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', [])
|
||||
|
||||
if hasattr(obj_class, 'ESSENTIALS_CATEGORY'):
|
||||
info['essentials_category'] = obj_class.ESSENTIALS_CATEGORY
|
||||
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
@ -887,6 +893,8 @@ class PromptServer():
|
||||
if "partial_execution_targets" in json_data:
|
||||
partial_execution_targets = json_data["partial_execution_targets"]
|
||||
|
||||
self.node_replace_manager.apply_replacements(prompt)
|
||||
|
||||
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
|
||||
extra_data = {}
|
||||
if "extra_data" in json_data:
|
||||
@ -995,6 +1003,7 @@ class PromptServer():
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.node_replace_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
|
||||
@ -5,8 +5,11 @@ from comfy_execution.jobs import (
|
||||
is_previewable,
|
||||
normalize_queue_item,
|
||||
normalize_history_item,
|
||||
normalize_output_item,
|
||||
normalize_outputs,
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
has_3d_extension,
|
||||
)
|
||||
|
||||
|
||||
@ -35,8 +38,8 @@ class TestIsPreviewable:
|
||||
"""Unit tests for is_previewable()"""
|
||||
|
||||
def test_previewable_media_types(self):
|
||||
"""Images, video, audio media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio']:
|
||||
"""Images, video, audio, 3d media types should be previewable."""
|
||||
for media_type in ['images', 'video', 'audio', '3d']:
|
||||
assert is_previewable(media_type, {}) is True
|
||||
|
||||
def test_non_previewable_media_types(self):
|
||||
@ -46,7 +49,7 @@ class TestIsPreviewable:
|
||||
|
||||
def test_3d_extensions_previewable(self):
|
||||
"""3D file extensions should be previewable regardless of media_type."""
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
|
||||
item = {'filename': f'model{ext}'}
|
||||
assert is_previewable('files', item) is True
|
||||
|
||||
@ -160,7 +163,7 @@ class TestGetOutputsSummary:
|
||||
|
||||
def test_3d_files_previewable(self):
|
||||
"""3D file extensions should be previewable."""
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb']:
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
|
||||
outputs = {
|
||||
'node1': {
|
||||
'files': [{'filename': f'model{ext}', 'type': 'output'}]
|
||||
@ -192,6 +195,64 @@ class TestGetOutputsSummary:
|
||||
assert preview['mediaType'] == 'images'
|
||||
assert preview['subfolder'] == 'outputs'
|
||||
|
||||
def test_string_3d_filename_creates_preview(self):
|
||||
"""String items with 3D extensions should synthesize a preview (Preview3D node output).
|
||||
Only the .glb counts — nulls and non-file strings are excluded."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['preview3d_abc123.glb', None, None]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 1
|
||||
assert preview is not None
|
||||
assert preview['filename'] == 'preview3d_abc123.glb'
|
||||
assert preview['mediaType'] == '3d'
|
||||
assert preview['nodeId'] == 'node1'
|
||||
assert preview['type'] == 'output'
|
||||
|
||||
def test_string_non_3d_filename_no_preview(self):
|
||||
"""String items without 3D extensions should not create a preview."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['data.json', None]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert count == 0
|
||||
assert preview is None
|
||||
|
||||
def test_string_3d_filename_used_as_fallback(self):
|
||||
"""String 3D preview should be used when no dict items are previewable."""
|
||||
outputs = {
|
||||
'node1': {
|
||||
'latents': [{'filename': 'latent.safetensors'}],
|
||||
},
|
||||
'node2': {
|
||||
'result': ['model.glb', None]
|
||||
}
|
||||
}
|
||||
count, preview = get_outputs_summary(outputs)
|
||||
assert preview is not None
|
||||
assert preview['filename'] == 'model.glb'
|
||||
assert preview['mediaType'] == '3d'
|
||||
|
||||
|
||||
class TestHas3DExtension:
|
||||
"""Unit tests for has_3d_extension()"""
|
||||
|
||||
def test_recognized_extensions(self):
|
||||
for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
|
||||
assert has_3d_extension(f'model{ext}') is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert has_3d_extension('MODEL.GLB') is True
|
||||
assert has_3d_extension('Scene.GLTF') is True
|
||||
|
||||
def test_non_3d_extensions(self):
|
||||
for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
|
||||
assert has_3d_extension(name) is False
|
||||
|
||||
|
||||
class TestApplySorting:
|
||||
"""Unit tests for apply_sorting()"""
|
||||
@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
|
||||
'prompt': {'nodes': {'1': {}}},
|
||||
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
|
||||
}
|
||||
|
||||
def test_include_outputs_normalizes_3d_strings(self):
|
||||
"""Detail view should transform string 3D filenames into file output dicts."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-3d',
|
||||
{'nodes': {}},
|
||||
{'create_time': 1234567890},
|
||||
['node1'],
|
||||
),
|
||||
'status': {'status_str': 'success', 'completed': True, 'messages': []},
|
||||
'outputs': {
|
||||
'node1': {
|
||||
'result': ['preview3d_abc123.glb', None, None]
|
||||
}
|
||||
},
|
||||
}
|
||||
job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
|
||||
|
||||
assert job['outputs_count'] == 1
|
||||
result_items = job['outputs']['node1']['result']
|
||||
assert len(result_items) == 1
|
||||
assert result_items[0] == {
|
||||
'filename': 'preview3d_abc123.glb',
|
||||
'type': 'output',
|
||||
'subfolder': '',
|
||||
'mediaType': '3d',
|
||||
}
|
||||
|
||||
def test_include_outputs_preserves_dict_items(self):
|
||||
"""Detail view normalization should pass dict items through unchanged."""
|
||||
history_item = {
|
||||
'prompt': (
|
||||
5,
|
||||
'prompt-img',
|
||||
{'nodes': {}},
|
||||
{'create_time': 1234567890},
|
||||
['node1'],
|
||||
),
|
||||
'status': {'status_str': 'success', 'completed': True, 'messages': []},
|
||||
'outputs': {
|
||||
'node1': {
|
||||
'images': [
|
||||
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
job = normalize_history_item('prompt-img', history_item, include_outputs=True)
|
||||
|
||||
assert job['outputs_count'] == 1
|
||||
assert job['outputs']['node1']['images'] == [
|
||||
{'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
|
||||
]
|
||||
|
||||
|
||||
class TestNormalizeOutputItem:
|
||||
"""Unit tests for normalize_output_item()"""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert normalize_output_item(None) is None
|
||||
|
||||
def test_string_3d_extension_synthesizes_dict(self):
|
||||
result = normalize_output_item('model.glb')
|
||||
assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
|
||||
|
||||
def test_string_non_3d_extension_returns_none(self):
|
||||
assert normalize_output_item('data.json') is None
|
||||
|
||||
def test_string_no_extension_returns_none(self):
|
||||
assert normalize_output_item('camera_info_string') is None
|
||||
|
||||
def test_dict_passes_through(self):
|
||||
item = {'filename': 'test.png', 'type': 'output'}
|
||||
assert normalize_output_item(item) is item
|
||||
|
||||
def test_other_types_return_none(self):
|
||||
assert normalize_output_item(42) is None
|
||||
assert normalize_output_item(True) is None
|
||||
|
||||
|
||||
class TestNormalizeOutputs:
|
||||
"""Unit tests for normalize_outputs()"""
|
||||
|
||||
def test_empty_outputs(self):
|
||||
assert normalize_outputs({}) == {}
|
||||
|
||||
def test_dict_items_pass_through(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [{'filename': 'a.png', 'type': 'output'}],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == outputs
|
||||
|
||||
def test_3d_string_synthesized(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['model.glb', None, None],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == {
|
||||
'node1': {
|
||||
'result': [
|
||||
{'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
def test_animated_key_preserved(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'images': [{'filename': 'a.png', 'type': 'output'}],
|
||||
'animated': [True],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result['node1']['animated'] == [True]
|
||||
|
||||
def test_non_dict_node_outputs_preserved(self):
|
||||
outputs = {'node1': 'unexpected_value'}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == {'node1': 'unexpected_value'}
|
||||
|
||||
def test_none_items_filtered_but_other_types_preserved(self):
|
||||
outputs = {
|
||||
'node1': {
|
||||
'result': ['data.json', None, [1, 2, 3]],
|
||||
}
|
||||
}
|
||||
result = normalize_outputs(outputs)
|
||||
assert result == {
|
||||
'node1': {
|
||||
'result': ['data.json', [1, 2, 3]],
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user