Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-05-03 13:52:31 +08:00

Commit 7d3b7ddee7: Merge branch 'master' into patch-1

.github/workflows/release-webhook.yml (vendored, 36 lines changed)
@@ -7,6 +7,8 @@ on:
 jobs:
   send-webhook:
     runs-on: ubuntu-latest
+    env:
+      DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
     steps:
       - name: Send release webhook
         env:
@@ -106,3 +108,37 @@ jobs:
           --fail --silent --show-error
 
          echo "✅ Release webhook sent successfully"
+
+      - name: Send repository dispatch to desktop
+        env:
+          DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
+          RELEASE_TAG: ${{ github.event.release.tag_name }}
+          RELEASE_URL: ${{ github.event.release.html_url }}
+        run: |
+          set -euo pipefail
+
+          if [ -z "${DISPATCH_TOKEN:-}" ]; then
+            echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
+            exit 1
+          fi
+
+          PAYLOAD="$(jq -n \
+            --arg release_tag "$RELEASE_TAG" \
+            --arg release_url "$RELEASE_URL" \
+            '{
+              event_type: "comfyui_release_published",
+              client_payload: {
+                release_tag: $release_tag,
+                release_url: $release_url
+              }
+            }')"
+
+          curl -fsSL \
+            -X POST \
+            -H "Accept: application/vnd.github+json" \
+            -H "Content-Type: application/json" \
+            -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+            https://api.github.com/repos/Comfy-Org/desktop/dispatches \
+            -d "$PAYLOAD"
+
+          echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
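The new step is a plain `curl` against GitHub's repository-dispatch endpoint. A minimal Python sketch of the same call, handy for testing the receiving side, assuming the `requests` package and placeholder token/tag values that are not part of this commit:

```python
# Mirrors the workflow step above: same endpoint, event_type and client_payload.
import requests

resp = requests.post(
    "https://api.github.com/repos/Comfy-Org/desktop/dispatches",
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": "Bearer YOUR_TOKEN",  # placeholder, not a real token
    },
    json={
        "event_type": "comfyui_release_published",
        "client_payload": {"release_tag": "v0.0.0", "release_url": "https://example.com"},
    },
    timeout=30,
)
resp.raise_for_status()  # GitHub answers 204 No Content on success
```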
@@ -227,7 +227,7 @@ Put your VAE in: models/vae
 
 AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
 
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
 
 This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
app/node_replace_manager.py (new file, 105 lines)

@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from aiohttp import web
+
+from typing import TYPE_CHECKING, TypedDict
+if TYPE_CHECKING:
+    from comfy_api.latest._io_public import NodeReplace
+
+from comfy_execution.graph_utils import is_link
+import nodes
+
+
+class NodeStruct(TypedDict):
+    inputs: dict[str, str | int | float | bool | tuple[str, int]]
+    class_type: str
+    _meta: dict[str, str]
+
+
+def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
+    new_node_struct = node_struct.copy()
+    if empty_inputs:
+        new_node_struct["inputs"] = {}
+    else:
+        new_node_struct["inputs"] = node_struct["inputs"].copy()
+    new_node_struct["_meta"] = node_struct["_meta"].copy()
+    return new_node_struct
+
+
+class NodeReplaceManager:
+    """Manages node replacement registrations."""
+
+    def __init__(self):
+        self._replacements: dict[str, list[NodeReplace]] = {}
+
+    def register(self, node_replace: NodeReplace):
+        """Register a node replacement mapping."""
+        self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
+
+    def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
+        """Get replacements for an old node ID."""
+        return self._replacements.get(old_node_id)
+
+    def has_replacement(self, old_node_id: str) -> bool:
+        """Check if a replacement exists for an old node ID."""
+        return old_node_id in self._replacements
+
+    def apply_replacements(self, prompt: dict[str, NodeStruct]):
+        connections: dict[str, list[tuple[str, str, int]]] = {}
+        need_replacement: set[str] = set()
+        for node_number, node_struct in prompt.items():
+            class_type = node_struct["class_type"]
+            # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
+            if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
+                need_replacement.add(node_number)
+            # keep track of connections
+            for input_id, input_value in node_struct["inputs"].items():
+                if is_link(input_value):
+                    conn_number = input_value[0]
+                    connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
+        for node_number in need_replacement:
+            node_struct = prompt[node_number]
+            class_type = node_struct["class_type"]
+            replacements = self.get_replacement(class_type)
+            if replacements is None:
+                continue
+            # just use the first replacement
+            replacement = replacements[0]
+            new_node_id = replacement.new_node_id
+            # if replacement is not a valid node, skip trying to replace it as it will only cause confusion
+            if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
+                continue
+            # first, replace node id (class_type)
+            new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
+            new_node_struct["class_type"] = new_node_id
+            # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
+            # second, replace inputs
+            if replacement.input_mapping is not None:
+                for input_map in replacement.input_mapping:
+                    if "set_value" in input_map:
+                        new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
+                    elif "old_id" in input_map:
+                        new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
+            # finalize input replacement
+            prompt[node_number] = new_node_struct
+            # third, replace outputs
+            if replacement.output_mapping is not None:
+                # re-mapping outputs requires changing the input values of nodes that receive connections from this one
+                if node_number in connections:
+                    for conns in connections[node_number]:
+                        conn_node_number, conn_input_id, old_output_idx = conns
+                        for output_map in replacement.output_mapping:
+                            if output_map["old_idx"] == old_output_idx:
+                                new_output_idx = output_map["new_idx"]
+                                previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
+                                previous_input[1] = new_output_idx
+
+    def as_dict(self):
+        """Serialize all replacements to dict."""
+        return {
+            k: [v.as_dict() for v in v_list]
+            for k, v_list in self._replacements.items()
+        }
+
+    def add_routes(self, routes):
+        @routes.get("/node_replacements")
+        async def get_node_replacements(request):
+            return web.json_response(self.as_dict())
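A hypothetical usage sketch: the `NodeReplace` class lives in `comfy_api.latest._io_public`, and the field names below are inferred from what `apply_replacements` reads, not a verified constructor signature.

```python
# Hypothetical sketch: register a mapping and rewrite an API-format prompt.
manager = NodeReplaceManager()
manager.register(NodeReplace(              # assumed constructor fields
    old_node_id="OldCheckpointLoader",
    new_node_id="CheckpointLoaderSimple",
    input_mapping=[{"new_id": "ckpt_name", "old_id": "model_name"}],
    output_mapping=[{"old_idx": 1, "new_idx": 0}],
))
# prompt maps node numbers to {"inputs": ..., "class_type": ..., "_meta": ...}
manager.apply_replacements(prompt)  # mutates prompt in place
```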
Deleted file (13 lines)

@@ -1,13 +0,0 @@
-import pickle
-
-load = pickle.load
-
-class Empty:
-    pass
-
-class Unpickler(pickle.Unpickler):
-    def find_class(self, module, name):
-        #TODO: safe unpickle
-        if module.startswith("pytorch_lightning"):
-            return Empty
-        return super().find_class(module, name)
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         super().cleanup()
 
+
+class QwenFunControlNet(ControlNet):
+    def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+        # Fun checkpoints are more sensitive to high strengths in the generic
+        # ControlNet merge path. Use a soft response curve so strength=1.0 stays
+        # unchanged while >1 grows more gently.
+        original_strength = self.strength
+        self.strength = math.sqrt(max(self.strength, 0.0))
+        try:
+            return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+        finally:
+            self.strength = original_strength
+
+    def pre_run(self, model, percent_to_timestep_function):
+        super().pre_run(model, percent_to_timestep_function)
+        self.set_extra_arg("base_model", model.diffusion_model)
+
+    def copy(self):
+        c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+
 class ControlLoraOps:
     class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
 def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
     control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    sd = model_config.process_unet_state_dict(sd)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
+
+def load_controlnet_qwen_fun(sd, model_options={}):
+    load_device = comfy.model_management.get_torch_device()
+    weight_dtype = comfy.utils.weight_dtype(sd)
+    unet_dtype = model_options.get("dtype", weight_dtype)
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    in_features = sd["control_img_in.weight"].shape[1]
+    inner_dim = sd["control_img_in.weight"].shape[0]
+
+    block_weight = sd["control_blocks.0.attn.to_q.weight"]
+    attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+    num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+
+    model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+        control_in_features=in_features,
+        inner_dim=inner_dim,
+        num_attention_heads=num_attention_heads,
+        attention_head_dim=attention_head_dim,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        operations=operations,
+        device=comfy.model_management.unet_offload_device(),
+        dtype=unet_dtype,
+    )
+    model = controlnet_load_state_dict(model, sd)
+
+    latent_format = comfy.latent_formats.Wan21()
+    control = QwenFunControlNet(
+        model,
+        compression_ratio=1,
+        latent_format=latent_format,
+        # Fun checkpoints already expect their own 33-channel context handling.
+        # Enabling generic concat_mask injects an extra mask channel at apply-time
+        # and breaks the intended fallback packing path.
+        concat_mask=False,
+        load_device=load_device,
+        manual_cast_dtype=manual_cast_dtype,
+        extra_conds=[],
+    )
+    return control
+
+
 def convert_mistoline(sd):
     return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
 
@@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
         return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+    elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+        return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
 
     elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
         return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
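The square-root response curve in `QwenFunControlNet.get_control` is easy to tabulate; strength 1.0 is a fixed point and larger values are compressed:

```python
import math

for s in (0.0, 0.5, 1.0, 2.0, 4.0):
    print(s, round(math.sqrt(max(s, 0.0)), 3))
# 0.0 -> 0.0, 0.5 -> 0.707, 1.0 -> 1.0, 2.0 -> 1.414, 4.0 -> 2.0
```

Likewise, `load_controlnet_qwen_fun` derives the head count from the state dict shapes: with the default geometry this works out to 3072 // 128 = 24 heads.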
@@ -1,12 +1,11 @@
 import math
-import time
 from functools import partial
 
 from scipy import integrate
 import torch
 from torch import nn
 import torchsde
-from tqdm.auto import trange as trange_, tqdm
+from tqdm.auto import tqdm
 
 from . import utils
 from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
 import comfy.model_sampling
 
 import comfy.memory_management
+from comfy.utils import model_trange as trange
 
-
-def trange(*args, **kwargs):
-    if comfy.memory_management.aimdo_allocator is None:
-        return trange_(*args, **kwargs)
-
-    pbar = trange_(*args, **kwargs, smoothing=1.0)
-    pbar._i = 0
-    pbar.set_postfix_str(" Model Initializing ... ")
-
-    _update = pbar.update
-
-    def warmup_update(n=1):
-        pbar._i += 1
-        if pbar._i == 1:
-            pbar.i1_time = time.time()
-            pbar.set_postfix_str(" Model Initialization complete! ")
-        elif pbar._i == 2:
-            #bring forward the effective start time based on the diff between first and second iteration
-            #to attempt to remove load overhead from the final step rate estimate.
-            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
-            pbar.set_postfix_str("")
-
-        _update(n)
-
-    pbar.update = warmup_update
-    return pbar
-
-
 def append_zero(x):
     return torch.cat([x, x.new_zeros([1])])
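The deleted warm-up wrapper (now centralized as `comfy.utils.model_trange`) corrected tqdm's rate estimate by pretending the first step cost as much as the second. A worked example of that arithmetic, with assumed timings:

```python
# Assume step 1 finished at t1 = 10.0 s (model load included) and step 2 at
# t2 = 11.0 s, i.e. a steady rate of 1 s/step.
t1, t2 = 10.0, 11.0
start_t = t1 - (t2 - t1)          # = 9.0: as if step 1 had also taken 1 s
elapsed_after_step_2 = t2 - start_t
assert elapsed_after_step_2 == 2.0  # 2 steps at 1 s/step, load time excluded
```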
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
         super().__init__(*args, **kwargs)
         self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
 
-    def preprocess_text_embeds(self, text_embeds, text_ids):
+    def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
         if text_ids is not None:
-            return self.llm_adapter(text_embeds, text_ids)
+            out = self.llm_adapter(text_embeds, text_ids)
+            if t5xxl_weights is not None:
+                out = out * t5xxl_weights
+
+            if out.shape[1] < 512:
+                out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+            return out
         else:
             return text_embeds
+
+    def forward(self, x, timesteps, context, **kwargs):
+        t5xxl_ids = kwargs.pop("t5xxl_ids", None)
+        if t5xxl_ids is not None:
+            context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+        return super().forward(x, timesteps, context, **kwargs)
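The padding call extends the sequence dimension up to 512 tokens; `torch.nn.functional.pad` consumes pairs from the last dimension backwards, so `(0, 0, 0, n)` leaves the feature dimension alone:

```python
import torch

out = torch.zeros(1, 77, 4096)            # (batch, tokens, features)
padded = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
assert padded.shape == (1, 512, 4096)     # zeros appended along the token dim
```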
@@ -3,7 +3,6 @@ from torch import Tensor, nn
 
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
-    RMSNorm,
     ModulationOut,
 )
 
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
         super().__init__()
         self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
         self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
-        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)])
+        self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range(n_layers)])
         self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
 
     @property
@@ -4,8 +4,6 @@ from functools import lru_cache
 import torch
 from torch import nn
 
-from comfy.ldm.flux.layers import RMSNorm
-
 
 class NerfEmbedder(nn.Module):
     """
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
         # We now need to generate parameters for 3 matrices.
         total_params = 3 * hidden_size_x**2 * mlp_ratio
         self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
-        self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
         self.mlp_ratio = mlp_ratio
 
 
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
 class NerfFinalLayer(nn.Module):
     def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
 class NerfFinalLayerConv(nn.Module):
     def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+        self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
         self.conv = operations.Conv2d(
             in_channels=hidden_size,
             out_channels=out_channels,
@@ -5,9 +5,9 @@ import torch
 from torch import Tensor, nn
 
 from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit
 
+# Fix import for some custom nodes, TODO: delete eventually.
+RMSNorm = None
 
 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
@@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
         operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
     )
 
-class RMSNorm(torch.nn.Module):
-    def __init__(self, dim: int, dtype=None, device=None, operations=None):
-        super().__init__()
-        self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
-    def forward(self, x: Tensor):
-        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
-
 class QKNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
         super().__init__()
-        self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
-        self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+        self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+        self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
 
     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
         q = self.query_norm(q)
@@ -169,7 +161,7 @@ class SiLUActivation(nn.Module):
 
 
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
 
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module):
 
         self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
 
-        self.flipped_img_txt = flipped_img_txt
-
     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
         if self.modulation:
             img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +214,17 @@ class DoubleStreamBlock(nn.Module):
         del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
-        if self.flipped_img_txt:
-            q = torch.cat((img_q, txt_q), dim=2)
-            del img_q, txt_q
-            k = torch.cat((img_k, txt_k), dim=2)
-            del img_k, txt_k
-            v = torch.cat((img_v, txt_v), dim=2)
-            del img_v, txt_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
-        else:
-            q = torch.cat((txt_q, img_q), dim=2)
-            del txt_q, img_q
-            k = torch.cat((txt_k, img_k), dim=2)
-            del txt_k, img_k
-            v = torch.cat((txt_v, img_v), dim=2)
-            del txt_v, img_v
-            # run actual attention
-            attn = attention(q, k, v,
-                             pe=pe, mask=attn_mask, transformer_options=transformer_options)
-            del q, k, v
-
-            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+        q = torch.cat((txt_q, img_q), dim=2)
+        del txt_q, img_q
+        k = torch.cat((txt_k, img_k), dim=2)
+        del txt_k, img_k
+        v = torch.cat((txt_v, img_v), dim=2)
+        del txt_v, img_v
+        # run actual attention
+        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
+
+        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
 
         # calculate the img bloks
         img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
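Deleting the local `RMSNorm` (which registered its parameter as `scale`) in favor of `operations.RMSNorm` changes state-dict key suffixes; assuming `operations.RMSNorm` follows `torch.nn.RMSNorm`, the parameter is named `weight`, which is why the detection code further down checks both suffixes:

```python
import torch

norm = torch.nn.RMSNorm(8)  # available in PyTorch 2.4+, elementwise_affine=True by default
assert "weight" in dict(norm.named_parameters())
assert "scale" not in dict(norm.named_parameters())
```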
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     return out.to(dtype=torch.float32, device=pos.device)
 
 
+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+    x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+
+    x_out = freqs_cis[..., 0] * x_[..., 0]
+    x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
+    return x_out.reshape(*x.shape).type_as(x)
+
+
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+    return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
+
 try:
     import comfy.quant_ops
-    apply_rope = comfy.quant_ops.ck.apply_rope
-    apply_rope1 = comfy.quant_ops.ck.apply_rope1
+    q_apply_rope = comfy.quant_ops.ck.apply_rope
+    q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+
+    def apply_rope(xq, xk, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope(xq, xk, freqs_cis)
+        else:
+            return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
+    def apply_rope1(x, freqs_cis):
+        if comfy.model_management.in_training:
+            return _apply_rope1(x, freqs_cis)
+        else:
+            return q_apply_rope1(x, freqs_cis)
 except:
     logging.warning("No comfy kitchen, using old apply_rope functions.")
-    def apply_rope1(x: Tensor, freqs_cis: Tensor):
-        x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-
-        x_out = freqs_cis[..., 0] * x_[..., 0]
-        x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-
-        return x_out.reshape(*x.shape).type_as(x)
-
-    def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
-        return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+    apply_rope = _apply_rope
+    apply_rope1 = _apply_rope1
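`_apply_rope1` treats the last dimension as (real, imaginary) pairs and fuses the second multiply-add with in-place `addcmul_`:

```python
import torch

f0, f1, x0, x1 = (torch.randn(4) for _ in range(4))
out = f0 * x0
out.addcmul_(f1, x1)                      # out += f1 * x1, no extra temporary
assert torch.allclose(out, f0 * x0 + f1 * x1)
```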
@@ -16,7 +16,6 @@ from .layers import (
     SingleStreamBlock,
     timestep_embedding,
     Modulation,
-    RMSNorm
 )
 
 @dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
 
         if params.txt_norm:
-            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
         else:
             self.txt_norm = None
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
                 self.num_heads,
                 mlp_ratio=params.mlp_ratio,
                 qkv_bias=params.qkv_bias,
-                flipped_img_txt=True,
                 dtype=dtype, device=device, operations=operations
             )
             for _ in range(params.depth)
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
             extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
             txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
 
-        ids = torch.cat((img_ids, txt_ids), dim=1)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
         img_len = img.shape[1]
         if txt_mask is not None:
             attn_mask_len = img_len + txt.shape[1]
             attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
-            attn_mask[:, 0, img_len:] = txt_mask
+            attn_mask[:, 0, :txt.shape[1]] = txt_mask
         else:
             attn_mask = None
 
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
                 if add is not None:
                     img += add
 
-        img = torch.cat((img, txt), 1)
+        img = torch.cat((txt, img), 1)
 
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
                 if i < len(control_o):
                     add = control_o[i]
                     if add is not None:
-                        img[:, : img_len] += add
+                        img[:, txt.shape[1]: img_len + txt.shape[1]] += add
 
-        img = img[:, : img_len]
+        img = img[:, txt.shape[1]: img_len + txt.shape[1]]
         if ref_latent is not None:
            img = img[:, ref_latent.shape[1]:]
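These hunks flip HunyuanVideo's packed sequence from img-first to txt-first, so every slice that extracts the image tokens shifts by `txt.shape[1]`:

```python
import torch

txt_len, img_len, dim = 3, 5, 2
txt, img = torch.zeros(1, txt_len, dim), torch.ones(1, img_len, dim)
seq = torch.cat((txt, img), 1)                    # new order: [txt | img]
img_back = seq[:, txt_len: img_len + txt_len]     # matches the new slicing
assert torch.equal(img_back, img)
```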
@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
         return self.conv(x)
 
 def interpolate_up(x, scale_factor):
-    try:
-        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
-    except: #operation not implemented for bf16
-        orig_shape = list(x.shape)
-        out_shape = orig_shape[:2]
-        for i in range(len(orig_shape) - 2):
-            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
-        out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
-        split = 8
-        l = out.shape[1] // split
-        for i in range(0, out.shape[1], l):
-            out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
-        return out
+    return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
 
 class Upsample(nn.Module):
     def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
@@ -2,6 +2,196 @@ import torch
 import math
 
 from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+        super().__init__(
+            dim=dim,
+            num_attention_heads=num_attention_heads,
+            attention_head_dim=attention_head_dim,
+            dtype=dtype,
+            device=device,
+            operations=operations,
+        )
+        self.has_before_proj = has_before_proj
+        if has_before_proj:
+            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
+class QwenImageFunControlNetModel(torch.nn.Module):
+    def __init__(
+        self,
+        control_in_features=132,
+        inner_dim=3072,
+        num_attention_heads=24,
+        attention_head_dim=128,
+        num_control_blocks=5,
+        main_model_double=60,
+        injection_layers=(0, 12, 24, 36, 48),
+        dtype=None,
+        device=None,
+        operations=None,
+    ):
+        super().__init__()
+        self.dtype = dtype
+        self.main_model_double = main_model_double
+        self.injection_layers = tuple(injection_layers)
+        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+        # to the reference Gen2/VideoX implementation around strength=1.
+        self.hint_scale = 1.0
+        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+        self.control_blocks = torch.nn.ModuleList([])
+        for i in range(num_control_blocks):
+            self.control_blocks.append(
+                QwenImageFunControlBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    has_before_proj=(i == 0),
+                    dtype=dtype,
+                    device=device,
+                    operations=operations,
+                )
+            )
+
+    def _process_hint_tokens(self, hint):
+        if hint is None:
+            return None
+        if hint.ndim == 4:
+            hint = hint.unsqueeze(2)
+
+        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+        # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+        # Default behavior (no inpaint input in stock Apply ControlNet) should use
+        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+        expected_c = self.control_img_in.weight.shape[1] // 4
+        if hint.shape[1] == 16 and expected_c == 33:
+            zeros_mask = torch.zeros_like(hint[:, :1])
+            zeros_inpaint = torch.zeros_like(hint)
+            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+        bs, c, t, h, w = hint.shape
+        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(
+            orig_shape[0],
+            orig_shape[1],
+            orig_shape[-3],
+            orig_shape[-2] // 2,
+            2,
+            orig_shape[-1] // 2,
+            2,
+        )
+        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+        hidden_states = hidden_states.reshape(
+            bs,
+            t * ((h + 1) // 2) * ((w + 1) // 2),
+            c * 4,
+        )
+
+        expected_in = self.control_img_in.weight.shape[1]
+        cur_in = hidden_states.shape[-1]
+        if cur_in < expected_in:
+            pad = torch.zeros(
+                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            hidden_states = torch.cat([hidden_states, pad], dim=-1)
+        elif cur_in > expected_in:
+            hidden_states = hidden_states[:, :, :expected_in]
+
+        return hidden_states
+
+    def forward(
+        self,
+        x,
+        timesteps,
+        context,
+        attention_mask=None,
+        guidance: torch.Tensor = None,
+        hint=None,
+        transformer_options={},
+        base_model=None,
+        **kwargs,
+    ):
+        if base_model is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+        encoder_hidden_states_mask = attention_mask
+        # Keep attention mask disabled inside Fun control blocks to mirror
+        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
+        encoder_hidden_states_mask = None
+
+        hidden_states, img_ids, _ = base_model.process_img(x)
+        hint_tokens = self._process_hint_tokens(hint)
+        if hint_tokens is None:
+            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+        if hint_tokens.shape[1] != hidden_states.shape[1]:
+            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+            hint_tokens = hint_tokens[:, :max_tokens]
+            hidden_states = hidden_states[:, :max_tokens]
+            img_ids = img_ids[:, :max_tokens]
+
+        txt_start = round(
+            max(
+                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+            )
+        )
+        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+        hidden_states = base_model.img_in(hidden_states)
+        encoder_hidden_states = base_model.txt_norm(context)
+        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+        if guidance is not None:
+            guidance = guidance * 1000
+
+        temb = (
+            base_model.time_text_embed(timesteps, hidden_states)
+            if guidance is None
+            else base_model.time_text_embed(timesteps, guidance, hidden_states)
+        )
+
+        c = self.control_img_in(hint_tokens)
+
+        for i, block in enumerate(self.control_blocks):
+            if i == 0:
+                c_in = block.before_proj(c) + hidden_states
+                all_c = []
+            else:
+                all_c = list(torch.unbind(c, dim=0))
+                c_in = all_c.pop(-1)
+
+            encoder_hidden_states, c_out = block(
+                hidden_states=c_in,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states_mask=encoder_hidden_states_mask,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                transformer_options=transformer_options,
+            )
+
+            c_skip = block.after_proj(c_out) * self.hint_scale
+            all_c += [c_skip, c_out]
+            c = torch.stack(all_c, dim=0)
+
+        hints = torch.unbind(c, dim=0)[:-1]
+
+        controlnet_block_samples = [None] * self.main_model_double
+        for local_idx, base_idx in enumerate(self.injection_layers):
+            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+                controlnet_block_samples[base_idx] = hints[local_idx]
+
+        return {"input": controlnet_block_samples}
+
+
 class QwenImageControlNetModel(QwenImageTransformer2DModel):
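The 2x2 packing in `_process_hint_tokens` is where the 132-feature default comes from: 33 channels times a 2x2 spatial patch. A toy shape check with assumed sizes:

```python
import torch

bs, c, t, h, w = 1, 33, 1, 8, 8
hint = torch.zeros(bs, c, t, h, w)
tokens = (hint.view(bs, c, t, h // 2, 2, w // 2, 2)
              .permute(0, 2, 3, 5, 1, 4, 6)
              .reshape(bs, t * (h // 2) * (w // 2), c * 4))
assert tokens.shape == (1, 16, 132)       # 33 * 4 = 132 = control_in_features
assert max(1, 3072 // max(1, 128)) == 24  # head count the loader derives
```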
@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
 
     return padded_tensor
 
+
+def calculate_shape(patches, weight, key, original_weights=None):
+    current_shape = weight.shape
+
+    for p in patches:
+        v = p[1]
+        offset = p[3]
+
+        # Offsets restore the old shape; lists force a diff without metadata
+        if offset is not None or isinstance(v, list):
+            continue
+
+        if isinstance(v, weight_adapter.WeightAdapterBase):
+            adapter_shape = v.calculate_shape(key)
+            if adapter_shape is not None:
+                current_shape = adapter_shape
+            continue
+
+        # Standard diff logic with padding
+        if len(v) == 2:
+            patch_type, patch_data = v[0], v[1]
+            if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
+                current_shape = patch_data[0].shape
+
+    return current_shape
+
+
 def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
     for p in patches:
         strength = p[0]
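A toy check of the "diff with padding" branch, assuming the positional patch layout used elsewhere in this module (`p[0]` strength, `p[1]` value, `p[3]` offset) and the module's own imports:

```python
import torch

weight = torch.zeros(4, 4)
grown = torch.zeros(8, 4)
patches = [(1.0, ("diff", [grown, {"pad_weight": True}]), None, None)]
assert calculate_shape(patches, weight, "some.key") == grown.shape
```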
@@ -5,7 +5,7 @@ import comfy.utils
 def convert_lora_bfl_control(sd): #BFL loras for Flux
     sd_out = {}
     for k in sd:
-        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+        k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
         sd_out[k_to] = sd[k]
 
     sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
         device = kwargs["device"]
         if cross_attn is not None:
             if t5xxl_ids is not None:
-                cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
                 if t5xxl_weights is not None:
-                    cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
-
-            if cross_attn.shape[1] < 512:
-                cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
+                    t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+                t5xxl_ids = t5xxl_ids.unsqueeze(0)
+
+                if torch.is_inference_mode_enabled(): # if not we are training
+                    cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
+                else:
+                    out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
+                    out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
+
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
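The training/inference split hinges on `torch.is_inference_mode_enabled()`, which simply reports whether the caller is inside an `inference_mode` context:

```python
import torch

with torch.inference_mode():
    assert torch.is_inference_mode_enabled()
assert not torch.is_inference_mode_enabled()
```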
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
         count += 1
     return count
 
+def any_suffix_in(keys, prefix, main, suffix_list=[]):
+    for x in suffix_list:
+        if "{}{}{}".format(prefix, main, x) in keys:
+            return True
+    return False
+
 def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
     context_dim = None
     use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["meanflow_sum"] = False
         return dit_config
 
-    if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
+    if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
         dit_config = {}
         if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
             dit_config["image_model"] = "flux2"
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
 
         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
-        if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+
+        if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
             dit_config["image_model"] = "chroma"
             dit_config["in_channels"] = 64
             dit_config["out_channels"] = 64
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["out_dim"] = 3072
             dit_config["hidden_dim"] = 5120
             dit_config["n_layers"] = 5
-            if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+
+            if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
                 dit_config["image_model"] = "chroma_radiance"
                 dit_config["in_channels"] = 3
                 dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
                 dit_config["nerf_depth"] = 4
                 dit_config["nerf_max_freqs"] = 8
                 dit_config["nerf_tile_size"] = 512
-                dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+                dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
                 dit_config["nerf_embedder_dtype"] = torch.float32
             if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
                 dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
         dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
-        dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+        dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
         if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
             dit_config["txt_ids_dims"] = [1, 2]
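`any_suffix_in` exists because the RMSNorm swap earlier in this commit means the same norm tensor can appear under either a `.scale` or a `.weight` suffix depending on checkpoint vintage:

```python
keys = {"model.diffusion_model.txt_norm.weight"}
assert any_suffix_in(keys, "model.diffusion_model.", "txt_norm.", ["weight", "scale"])
assert not any_suffix_in(keys, "model.diffusion_model.", "img_norm.", ["weight", "scale"])
```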
@@ -19,7 +19,7 @@
 import psutil
 import logging
 from enum import Enum
-from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+from comfy.cli_args import args, PerformanceFeature
 import threading
 import torch
 import sys
@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
 
 total_vram = 0
 
+
+# Training Related State
+in_training = False
+
+
 def get_supported_float8_types():
     float8_types = []
     try:
@@ -651,7 +656,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
     soft_empty_cache()
     return unloaded_models
 
-def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     cleanup_models_gc()
     global vram_state
 
@@ -747,26 +752,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
             current_loaded_models.insert(0, loaded_model)
     return
 
-def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
-    with torch.inference_mode():
-        load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
-    soft_empty_cache()
-
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
-    #Deliberately load models outside of the Aimdo mempool so they can be retained across
-    #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
-    #thread local. So exploit that to escape context
-    if enables_dynamic_vram():
-        t = threading.Thread(
-            target=load_models_gpu_thread,
-            args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
-        )
-        t.start()
-        t.join()
-    else:
-        load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
-                             minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
-
 def load_model_gpu(model):
     return load_models_gpu([model])
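The removed wrapper relied on CUDA memory-pool contexts being thread-local (as its comment notes PyTorch documents), so running the load in a throwaway thread escaped the Aimdo mempool. The bare pattern, for reference:

```python
import threading

def run_outside_current_mempool(fn, *args):
    # MemPool contexts are thread-local, so a fresh thread starts outside
    # whatever pool the caller currently has active.
    t = threading.Thread(target=fn, args=args)
    t.start()
    t.join()
```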
@@ -1226,21 +1211,20 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
     if dtype is None:
         dtype = weight._model_dtype

-    r = torch.empty_like(weight, dtype=dtype, device=device)

     signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
     if signature is not None:
-        raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
-        v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
-        if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+        if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+            v_tensor = weight._v_tensor
+        else:
+            raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+            v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+            weight._v_tensor = v_tensor
             weight._v_signature = signature
             #Send it over
             v_tensor.copy_(weight, non_blocking=non_blocking)
-        #always take a deep copy even if _v is good, as we have no reasonable point to unpin
-        #a non comfy weight
-        r.copy_(v_tensor)
-        comfy_aimdo.model_vbar.vbar_unpin(weight._v)
-        return r
+        return v_tensor.to(dtype=dtype)

+    r = torch.empty_like(weight, dtype=dtype, device=device)
     if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
         #Offloaded casting could skip this, however it would make the quantizations
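The reworked `cast_to` above caches the gathered tensor in `weight._v_tensor` and only re-copies when the vbar signature changes. A rough standalone sketch of that signature-guarded cache, with `data_ptr()`/`_version` standing in for the real `vbar_fault` signature (an assumption for illustration, not the ComfyUI API):

```python
import torch

class CachedView:
    def __init__(self, source: torch.Tensor):
        self.source = source
        self._cached = None
        self._signature = None

    def _sig(self):
        # Assumption: data_ptr()/_version stand in for the vbar signature.
        return (self.source.data_ptr(), self.source._version)

    def get(self, dtype: torch.dtype) -> torch.Tensor:
        sig = self._sig()
        if self._cached is not None and sig == self._signature:
            return self._cached.to(dtype=dtype)   # hit: skip the copy
        self._cached = torch.empty_like(self.source)
        self._cached.copy_(self.source)           # miss: re-gather and cache
        self._signature = sig
        return self._cached.to(dtype=dtype)
```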
@@ -19,7 +19,6 @@
 from __future__ import annotations

 import collections
-import copy
 import inspect
 import logging
 import math
@@ -317,7 +316,7 @@ class ModelPatcher:

         n.object_patches = self.object_patches.copy()
         n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
-        n.model_options = copy.deepcopy(self.model_options)
+        n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
         n.parent = self
@@ -680,18 +679,19 @@ class ModelPatcher:
         for key in list(self.pinned):
             self.unpin_weight(key)

-    def _load_list(self, prio_comfy_cast_weights=False):
+    def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
         loading = []
         for n, m in self.model.named_modules():
-            params = []
-            skip = False
-            for name, param in m.named_parameters(recurse=False):
-                params.append(name)
+            default = False
+            params = { name: param for name, param in m.named_parameters(recurse=False) }
             for name, param in m.named_parameters(recurse=True):
                 if name not in params:
-                    skip = True # skip random weights in non leaf modules
+                    default = True # default random weights in non leaf modules
                     break
-            if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+            if default and default_device is not None:
+                for param in params.values():
+                    param.data = param.data.to(device=default_device)
+            if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                 module_mem = comfy.model_management.module_size(m)
                 module_offload_mem = module_mem
                 if hasattr(m, "comfy_cast_weights"):
@@ -1492,9 +1492,11 @@ class ModelPatcherDynamic(ModelPatcher):
         if vbar is not None:
             vbar.prioritize()

-        #We have way more tools for acceleration on comfy weight offloading, so always
+        #We force reserve VRAM for the non comfy-weight so we dont have to deal
+        #with pin and unpin syncrhonization which can be expensive for small weights
+        #with a high layer rate (e.g. autoregressive LLMs).
         #prioritize the non-comfy weights (note the order reverse).
-        loading = self._load_list(prio_comfy_cast_weights=True)
+        loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
         loading.sort(reverse=True)

         for x in loading:
@@ -1512,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher):

             weight, _, _ = get_key_weight(self.model, key)
             if weight is None:
-                return 0
+                return (False, 0)
             if key in self.patches:
+                if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+                    return (True, 0)
                 setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
                 num_patches += 1
             else:
@@ -1524,10 +1528,16 @@ class ModelPatcherDynamic(ModelPatcher):
                 setattr(m, param_key + "_function", weight_function)
             geometry = weight
             if not isinstance(weight, QuantizedTensor):
-                model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+                model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
                 weight._model_dtype = model_dtype
                 geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
-            return comfy.memory_management.vram_aligned_size(geometry)
+            return (False, comfy.memory_management.vram_aligned_size(geometry))
+
+        def force_load_param(self, param_key, device_to):
+            key = key_param_name_to_key(n, param_key)
+            if key in self.backup:
+                comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+            self.patch_weight_to_device(key, device_to=device_to)

         if hasattr(m, "comfy_cast_weights"):
             m.comfy_cast_weights = True
@@ -1535,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher):
                 m.seed_key = n
                 set_dirty(m, dirty)

-                v_weight_size = 0
-                v_weight_size += setup_param(self, m, n, "weight")
-                v_weight_size += setup_param(self, m, n, "bias")
+                force_load, v_weight_size = setup_param(self, m, n, "weight")
+                force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+                force_load = force_load or force_load_bias
+                v_weight_size += v_weight_bias

-                if vbar is not None and not hasattr(m, "_v"):
-                    m._v = vbar.alloc(v_weight_size)
-                    allocated_size += v_weight_size
+                if force_load:
+                    logging.info(f"Module {n} has resizing Lora - force loading")
+                    force_load_param(self, "weight", device_to)
+                    force_load_param(self, "bias", device_to)
+                else:
+                    if vbar is not None and not hasattr(m, "_v"):
+                        m._v = vbar.alloc(v_weight_size)
+                        allocated_size += v_weight_size

             else:
                 for param in params:
@@ -1550,13 +1566,16 @@ class ModelPatcherDynamic(ModelPatcher):
                     weight.seed_key = key
                     set_dirty(weight, dirty)
                     geometry = weight
-                    model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+                    model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
                     geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
                     weight_size = geometry.numel() * geometry.element_size()
                     if vbar is not None and not hasattr(weight, "_v"):
                         weight._v = vbar.alloc(weight_size)
                         weight._model_dtype = model_dtype
                         allocated_size += weight_size
+        vbar.set_watermark_limit(allocated_size)
+
+        move_weight_functions(m, device_to)

         logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")

@@ -1577,7 +1596,7 @@ class ModelPatcherDynamic(ModelPatcher):
         return 0 if vbar is None else vbar.free_memory(memory_to_free)

     def partially_unload_ram(self, ram_to_unload):
-        loading = self._load_list(prio_comfy_cast_weights=True)
+        loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
         for x in loading:
             _, _, _, _, m, _ = x
             ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
@@ -1598,6 +1617,13 @@ class ModelPatcherDynamic(ModelPatcher):
         if unpatch_weights:
             self.partially_unload_ram(1e32)
             self.partially_unload(None, 1e32)
+            for m in self.model.modules():
+                move_weight_functions(m, device_to)
+
+            keys = list(self.backup.keys())
+            for k in keys:
+                bk = self.backup[k]
+                comfy.utils.set_attr_param(self.model, k, bk.weight)

     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
         assert not force_patch_weights #See above
comfy/ops.py | 25

@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
 def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
     offload_stream = None
     xfer_dest = None
-    cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])

     signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
-    if signature is not None:
-        xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
     resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+    if signature is not None:
+        if resident:
+            weight = s._v_weight
+            bias = s._v_bias
+        else:
+            xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)

     if not resident:
+        cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
         cast_dest = None

         xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
             post_cast.copy_(pre_cast)
             xfer_dest = cast_dest

         params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
         weight = params[0]
         bias = params[1]
+        if signature is not None:
+            s._v_weight = weight
+            s._v_bias = bias
+            s._v_signature=signature

     def post_cast(s, param_key, x, dtype, resident, update_weight):
         lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -169,8 +177,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
             if orig.dtype == dtype and len(fns) == 0:
                 #The layer actually wants our freshly saved QT
                 x = y
-            else:
-                y = x
+            elif update_weight:
+                y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
             if update_weight:
                 orig.copy_(y)
             for f in fns:
@@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
     weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
     if s.bias is not None:
         bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
-    s._v_signature=signature

     #FIXME: weird offload return protocol
     return weight, bias, (offload_stream, device if signature is not None else None, None)
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
     minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
     return memory_required, minimum_memory_required

-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     executor = comfy.patcher_extension.WrapperExecutor.new_executor(
         _prepare_sampling,
         comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
     )
-    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+    return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)

-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
     real_model: BaseModel = None
     models, inference_memory = get_additional_models(conds, model.model_dtype())
     models += get_additional_models_from_model_options(model_options)
     models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
-    memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
-    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+    if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
+        memory_required = 1e20
+        minimum_memory_required = None
+    else:
+        memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+        memory_required += inference_memory
+        minimum_memory_required += inference_memory
+    comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
     real_model = model.model

     return real_model, conds, models
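With `force_offload=True`, `_prepare_sampling` skips the estimate and requests an impossible budget (`1e20` bytes) so the loader falls into its partial, offloaded load path. A toy illustration of the pattern (all numbers made up):

```python
# Toy illustration only: an impossible budget forces the "partial load" branch.
FREE_VRAM = 8 * 1024**3  # assumed 8 GiB free

def plan_load(memory_required: float) -> str:
    return "full" if memory_required <= FREE_VRAM else "partial"

assert plan_load(2 * 1024**3) == "full"
assert plan_load(1e20) == "partial"   # what force_offload=True requests
```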
@@ -793,8 +793,6 @@ class VAE:
         self.first_stage_model = AutoencoderKL(**(config['params']))
         self.first_stage_model = self.first_stage_model.eval()

-        model_management.archive_model_dtypes(self.first_stage_model)
-
         if device is None:
             device = model_management.vae_device()
         self.device = device
@@ -803,6 +801,7 @@ class VAE:
             dtype = model_management.vae_dtype(self.device, self.working_dtypes)
         self.vae_dtype = dtype
         self.first_stage_model.to(self.vae_dtype)
+        model_management.archive_model_dtypes(self.first_stage_model)
         self.output_device = model_management.intermediate_device()

         mp = comfy.model_patcher.CoreModelPatcher
@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):

     def process_tokens(self, tokens, device):
         end_token = self.special_tokens.get("end", None)
+        pad_token = self.special_tokens.get("pad", -1)
         if end_token is None:
-            cmp_token = self.special_tokens.get("pad", -1)
+            cmp_token = pad_token
         else:
             cmp_token = end_token

@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
             other_embeds = []
             eos = False
             index = 0
+            left_pad = False
             for y in x:
                 if isinstance(y, numbers.Integral):
-                    if eos:
+                    token = int(y)
+                    if index == 0 and token == pad_token:
+                        left_pad = True
+
+                    if eos or (left_pad and token == pad_token):
                         attention_mask.append(0)
                     else:
                         attention_mask.append(1)
-                    token = int(y)
+                        left_pad = False

                     tokens_temp += [token]
-                    if not eos and token == cmp_token:
+                    if not eos and token == cmp_token and not left_pad:
                         if end_token is None:
                             attention_mask[-1] = 0
                         eos = True
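The `left_pad` logic above keeps attention off leading pad tokens so left-padded batches encode correctly. A self-contained sketch of just the leading-pad rule (the eos/trailing-pad masking from the hunk is omitted; names are illustrative):

```python
def build_attention_mask(tokens: list[int], pad_token: int) -> list[int]:
    mask = []
    left_pad = True   # still inside the leading pad run
    for t in tokens:
        if left_pad and t == pad_token:
            mask.append(0)        # leading pads are masked out
        else:
            left_pad = False      # first real token ends the pad run
            mask.append(1)
    return mask

assert build_attention_mask([0, 0, 5, 6], pad_token=0) == [0, 0, 1, 1]
```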
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith("_norm.scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]

@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
             key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
             key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
             key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
-            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
+            key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
-            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+            key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
             key_out = key_out.replace("_attn_proj.", "_attn.proj.")
             key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
             key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
             out_sd[key_out] = state_dict[k]
         return out_sd

@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):

     latent_format = latent_formats.Hunyuan3Dv2

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     def process_unet_state_dict_for_saving(self, state_dict):
         replace_prefix = {"": "model."}
         return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):

     supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

+    def process_unet_state_dict(self, state_dict):
+        out_sd = {}
+        for k in list(state_dict.keys()):
+            key_out = k
+            if key_out.endswith(".scale"):
+                key_out = "{}.weight".format(key_out[:-len(".scale")])
+            out_sd[key_out] = state_dict[k]
+        return out_sd
+
     def get_model(self, state_dict, prefix="", device=None):
         out = model_base.Chroma(self, device=device)
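The same `.scale` to `.weight` rename is added to Flux, HunyuanVideo, Hunyuan3Dv2 and Chroma. The repeated loop could be expressed once as a helper, shown here only as a sketch (this is not how the file actually factors it):

```python
def rename_scale_to_weight(state_dict: dict, suffix: str = ".scale") -> dict:
    out_sd = {}
    for k, v in state_dict.items():
        if k.endswith(suffix):
            k = k[:-len(suffix)] + ".weight"  # RMSNorm buffers: scale -> weight
        out_sd[k] = v
    return out_sd

sd = {"blocks.0.q_norm.scale": 1, "blocks.0.proj.weight": 2}
assert rename_scale_to_weight(sd) == {"blocks.0.q_norm.weight": 1, "blocks.0.proj.weight": 2}
```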
@@ -10,12 +10,12 @@ import comfy.utils
 def sample_manual_loop_no_classes(
     model,
     ids=None,
-    paddings=[],
     execution_dtype=None,
     cfg_scale: float = 2.0,
     temperature: float = 0.85,
     top_p: float = 0.9,
     top_k: int = None,
+    min_p: float = 0.000,
     seed: int = 1,
     min_tokens: int = 1,
     max_new_tokens: int = 2048,
@@ -23,6 +23,8 @@ def sample_manual_loop_no_classes(
     audio_end_id: int = 215669,
     eos_token_id: int = 151645,
 ):
+    if ids is None:
+        return []
     device = model.execution_device

     if execution_dtype is None:
@@ -32,31 +34,34 @@ def sample_manual_loop_no_classes(
             execution_dtype = torch.float32

     embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
-    for i, t in enumerate(paddings):
-        attention_mask[i, :t] = 0
-        attention_mask[i, t:] = 1
+    embeds_batch = embeds.shape[0]

     output_audio_codes = []
     past_key_values = []
     generator = torch.Generator(device=device)
     generator.manual_seed(seed)
     model_config = model.transformer.model.config
+    past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
+
     for x in range(model_config.num_hidden_layers):
-        past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+        past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))

     progress_bar = comfy.utils.ProgressBar(max_new_tokens)

-    for step in range(max_new_tokens):
+    for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
         outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
         next_token_logits = model.transformer.logits(outputs[0])[:, -1]
         past_key_values = outputs[2]

-        cond_logits = next_token_logits[0:1]
-        uncond_logits = next_token_logits[1:2]
-        cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+        if cfg_scale != 1.0:
+            cond_logits = next_token_logits[0:1]
+            uncond_logits = next_token_logits[1:2]
+            cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+        else:
+            cfg_logits = next_token_logits[0:1]

-        if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+        use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
+        if use_eos_score:
             eos_score = cfg_logits[:, eos_token_id].clone()

         remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -64,7 +69,7 @@ def sample_manual_loop_no_classes(
         cfg_logits[:, :audio_start_id] = remove_logit_value
         cfg_logits[:, audio_end_id:] = remove_logit_value

-        if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+        if use_eos_score:
             cfg_logits[:, eos_token_id] = eos_score

         if top_k is not None and top_k > 0:
@@ -72,6 +77,12 @@ def sample_manual_loop_no_classes(
             min_val = top_k_vals[..., -1, None]
             cfg_logits[cfg_logits < min_val] = remove_logit_value

+        if min_p is not None and min_p > 0:
+            probs = torch.softmax(cfg_logits, dim=-1)
+            p_max = probs.max(dim=-1, keepdim=True).values
+            indices_to_remove = probs < (min_p * p_max)
+            cfg_logits[indices_to_remove] = remove_logit_value
+
         if top_p is not None and top_p < 1.0:
             sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
             cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
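Min-p keeps only tokens whose probability is at least `min_p` times the top token's probability, so the filter adapts to how peaked the distribution is. A standalone version of the same math:

```python
import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    # Tokens with prob < min_p * max_prob are pushed to -inf before sampling.
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < threshold, torch.finfo(logits.dtype).min)

logits = torch.tensor([[2.0, 1.9, -3.0]])
print(min_p_filter(logits, min_p=0.5))  # only the third logit is masked out
```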
@@ -93,8 +104,8 @@ def sample_manual_loop_no_classes(
                 break

         embed, _, _, _ = model.process_tokens([[token]], device)
-        embeds = embed.repeat(2, 1, 1)
-        attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+        embeds = embed.repeat(embeds_batch, 1, 1)
+        attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)

         output_audio_codes.append(token - audio_start_id)
         progress_bar.update_absolute(step)
@@ -102,24 +113,29 @@ def sample_manual_loop_no_classes(
     return output_audio_codes


-def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
     positive = [[token for token, _ in inner_list] for inner_list in positive]
-    negative = [[token for token, _ in inner_list] for inner_list in negative]
     positive = positive[0]
-    negative = negative[0]

-    neg_pad = 0
-    if len(negative) < len(positive):
-        neg_pad = (len(positive) - len(negative))
-        negative = [model.special_tokens["pad"]] * neg_pad + negative
+    if cfg_scale != 1.0:
+        negative = [[token for token, _ in inner_list] for inner_list in negative]
+        negative = negative[0]

-    pos_pad = 0
-    if len(negative) > len(positive):
-        pos_pad = (len(negative) - len(positive))
-        positive = [model.special_tokens["pad"]] * pos_pad + positive
+        neg_pad = 0
+        if len(negative) < len(positive):
+            neg_pad = (len(positive) - len(negative))
+            negative = [model.special_tokens["pad"]] * neg_pad + negative

-    paddings = [pos_pad, neg_pad]
-    return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+        pos_pad = 0
+        if len(negative) > len(positive):
+            pos_pad = (len(negative) - len(positive))
+            positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+        ids = [positive, negative]
+    else:
+        ids = [positive]
+
+    return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)


 class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
     def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
         user_metas = {
             k: kwargs.pop(k)
-            for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+            for k in ("bpm", "duration", "keyscale", "timesignature")
             if k in kwargs
         }
         timesignature = user_metas.get("timesignature")
         if isinstance(timesignature, str) and timesignature.endswith("/4"):
-            user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+            user_metas["timesignature"] = timesignature[:-2]
         user_metas = {
             k: v if not isinstance(v, str) or not v.isdigit() else int(v)
             for k, v in user_metas.items()
@@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
         return f"<think>\n{meta_yaml}\n</think>" if not return_yaml else meta_yaml

     def _metas_to_cap(self, **kwargs) -> str:
-        use_keys = ("bpm", "duration", "keyscale", "timesignature")
+        use_keys = ("bpm", "timesignature", "keyscale", "duration")
         user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+        timesignature = user_metas.get("timesignature")
+        if isinstance(timesignature, str) and timesignature.endswith("/4"):
+            user_metas["timesignature"] = timesignature[:-2]
         duration = user_metas["duration"]
         if duration == "N/A":
             user_metas["duration"] = "30 seconds"
@@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
         return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)

     def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
-        out = {}
+        text = text.strip()
+        text_negative = kwargs.get("caption_negative", text).strip()
         lyrics = kwargs.get("lyrics", "")
+        lyrics_negative = kwargs.get("lyrics_negative", lyrics)
         duration = kwargs.get("duration", 120)
+        if isinstance(duration, str):
+            duration = float(duration.split(None, 1)[0])
         language = kwargs.get("language")
         seed = kwargs.get("seed", 0)

@@ -170,28 +193,55 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
         temperature = kwargs.get("temperature", 0.85)
         top_p = kwargs.get("top_p", 0.9)
         top_k = kwargs.get("top_k", 0.0)
+        min_p = kwargs.get("min_p", 0.000)

         duration = math.ceil(duration)
         kwargs["duration"] = duration
+        tokens_duration = duration * 5
+        min_tokens = int(kwargs.get("min_tokens", tokens_duration))
+        max_tokens = int(kwargs.get("max_tokens", tokens_duration))

-        cot_text = self._metas_to_cot(caption = text, **kwargs)
+        metas_negative = {
+            k.rsplit("_", 1)[0]: kwargs.pop(k)
+            for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
+            if k in kwargs
+        }
+        if not kwargs.get("use_negative_caption"):
+            _ = metas_negative.pop("caption", None)
+
+        cot_text = self._metas_to_cot(caption=text, **kwargs)
+        cot_text_negative = "<think>\n\n</think>" if not metas_negative else self._metas_to_cot(**metas_negative)
         meta_cap = self._metas_to_cap(**kwargs)

-        lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
+        lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
+        lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
+        qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"

-        out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
-        out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "<think>\n</think>"), disable_weights=True)
-        out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
-        out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
-        out["lm_metadata"] = {"min_tokens": duration * 5,
+        llm_prompts = {
+            "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
+            "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
+            "lyrics": lyrics_template.format(language if language is not None else "", lyrics),
+            "qwen3_06b": qwen3_06b_template.format(text, meta_cap),
+        }
+
+        out = {
+            prompt_key: self.qwen3_06b.tokenize_with_weights(
+                prompt,
+                prompt_key == "qwen3_06b" and return_word_ids,
+                disable_weights = True,
+                **kwargs,
+            )
+            for prompt_key, prompt in llm_prompts.items()
+        }
+        out["lm_metadata"] = {"min_tokens": min_tokens,
+                              "max_tokens": max_tokens,
                               "seed": seed,
                               "generate_audio_codes": generate_audio_codes,
                               "cfg_scale": cfg_scale,
                               "temperature": temperature,
                               "top_p": top_p,
                               "top_k": top_k,
+                              "min_p": min_p,
                               }
         return out

@@ -252,7 +302,7 @@ class ACE15TEModel(torch.nn.Module):

         lm_metadata = token_weight_pairs["lm_metadata"]
         if lm_metadata["generate_audio_codes"]:
-            audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+            audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
             out["audio_codes"] = [audio_codes]

         return base_out, None, out
@@ -355,13 +355,6 @@ class RMSNorm(nn.Module):



-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
 def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
     if not isinstance(theta, list):
         theta = [theta]
@@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
     else:
         cos = cos.unsqueeze(1)
         sin = sin.unsqueeze(1)
-    out.append((cos, sin))
+    sin_split = sin.shape[-1] // 2
+    out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))

     if len(out) == 1:
         return out[0]

     return out


 def apply_rope(xq, xk, freqs_cis):
     org_dtype = xq.dtype
     cos = freqs_cis[0]
     sin = freqs_cis[1]
-    q_embed = (xq * cos) + (rotate_half(xq) * sin)
-    k_embed = (xk * cos) + (rotate_half(xk) * sin)
+    nsin = freqs_cis[2]
+
+    q_embed = (xq * cos)
+    q_split = q_embed.shape[-1] // 2
+    q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
+    q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
+
+    k_embed = (xk * cos)
+    k_split = k_embed.shape[-1] // 2
+    k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
+    k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
+
     return q_embed.to(org_dtype), k_embed.to(org_dtype)


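The rewrite folds `rotate_half` into two in-place `addcmul_` calls on pre-split sin tables, avoiding the concat and the extra temporaries. A quick equivalence check, assuming the standard RoPE layout where the sin/cos tables repeat their half (illustrative only, not taken from the repo):

```python
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

d = 8
x = torch.randn(2, d)
ang = torch.rand(1, d // 2)
cos = torch.cat([ang.cos(), ang.cos()], dim=-1)   # repeated-half layout
sin = torch.cat([ang.sin(), ang.sin()], dim=-1)

ref = (x * cos) + (rotate_half(x) * sin)          # original formulation

half = d // 2
out = x * cos                                     # split-half formulation
out[..., :half].addcmul_(x[..., half:], -sin[..., half:])
out[..., half:].addcmul_(x[..., :half], sin[..., half:])

torch.testing.assert_close(ref, out)
```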
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
 class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
         token_weight_pairs = token_weight_pairs["gemma3_12b"]

         out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+        out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
         out_device = out.device
         if comfy.model_management.should_use_bf16(self.execution_device):
             out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):

         token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
         num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        num_tokens = max(num_tokens, 64)
         return num_tokens * constant * 1024 * 1024

 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
@@ -20,13 +20,14 @@
 import torch
 import math
 import struct
-import comfy.checkpoint_pickle
+import comfy.memory_management
 import safetensors.torch
 import numpy as np
 from PIL import Image
 import logging
 import itertools
 from torch.nn.functional import interpolate
+from tqdm.auto import trange
 from einops import rearrange
 from comfy.cli_args import args, enables_dynamic_vram
 import json
@@ -37,26 +38,26 @@ import warnings
 MMAP_TORCH_FILES = args.mmap_torch_files
 DISABLE_MMAP = args.disable_mmap

-ALWAYS_SAFE_LOAD = False
-if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+if True: # ckpt/pt file whitelist for safe loading of old sd files
     class ModelCheckpoint:
         pass
     ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"

     def scalar(*args, **kwargs):
-        from numpy.core.multiarray import scalar as sc
-        return sc(*args, **kwargs)
+        return None
     scalar.__module__ = "numpy.core.multiarray"

     from numpy import dtype
     from numpy.dtypes import Float64DType
-    from _codecs import encode
+
+    def encode(*args, **kwargs): # no longer necessary on newer torch
+        return None
+    encode.__module__ = "_codecs"

     torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
-    ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")
-else:
-    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")

 # Current as of safetensors 0.7.0
 _TYPES = {
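`torch.serialization.add_safe_globals` (PyTorch 2.4+) whitelists specific globals for `weights_only=True` unpickling, which is what lets the unsafe fallback path be deleted. A minimal standalone example of the mechanism (the class is a hypothetical stand-in, not one of the entries above):

```python
import torch

class LegacyCallback:   # hypothetical stand-in for an old framework class
    pass

# Whitelist the class so weights_only=True unpickling accepts it.
torch.serialization.add_safe_globals([LegacyCallback])
torch.save({"cb": LegacyCallback(), "w": torch.zeros(2)}, "demo.ckpt")
sd = torch.load("demo.ckpt", map_location="cpu", weights_only=True)
assert isinstance(sd["cb"], LegacyCallback)
```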
@@ -139,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
         if MMAP_TORCH_FILES:
             torch_args["mmap"] = True

-        if safe_load or ALWAYS_SAFE_LOAD:
-            pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
-        else:
-            logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
-            pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
+        pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
         if "state_dict" in pl_sd:
             sd = pl_sd["state_dict"]
         else:
@@ -674,10 +672,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         "ff_context.linear_in.bias": "txt_mlp.0.bias",
         "ff_context.linear_out.weight": "txt_mlp.2.weight",
         "ff_context.linear_out.bias": "txt_mlp.2.bias",
-        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
-        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
-        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
-        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+        "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+        "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
     }

     for k in block_map:
@@ -700,8 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         "norm.linear.bias": "modulation.lin.bias",
         "proj_out.weight": "linear2.weight",
         "proj_out.bias": "linear2.bias",
-        "attn.norm_q.weight": "norm.query_norm.scale",
-        "attn.norm_k.weight": "norm.key_norm.scale",
+        "attn.norm_q.weight": "norm.query_norm.weight",
+        "attn.norm_k.weight": "norm.key_norm.weight",
         "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
         "attn.to_out.weight": "linear2.weight", # Flux 2
     }
@@ -1155,6 +1153,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
 def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
     return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)

+def model_trange(*args, **kwargs):
+    if comfy.memory_management.aimdo_allocator is None:
+        return trange(*args, **kwargs)
+
+    pbar = trange(*args, **kwargs, smoothing=1.0)
+    pbar._i = 0
+    pbar.set_postfix_str(" Model Initializing ... ")
+
+    _update = pbar.update
+
+    def warmup_update(n=1):
+        pbar._i += 1
+        if pbar._i == 1:
+            pbar.i1_time = time.time()
+            pbar.set_postfix_str(" Model Initialization complete! ")
+        elif pbar._i == 2:
+            #bring forward the effective start time based the the diff between first and second iteration
+            #to attempt to remove load overhead from the final step rate estimate.
+            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
+            pbar.set_postfix_str("")
+
+        _update(n)
+
+    pbar.update = warmup_update
+    return pbar
+
 PROGRESS_BAR_ENABLED = True
 def set_progress_bar_enabled(enabled):
     global PROGRESS_BAR_ENABLED
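The start-time correction in `model_trange` assumes the second iteration reflects steady-state speed, then backdates tqdm's start so the lazy-load cost of the first step stops dragging down the rate estimate. A toy check of the arithmetic (all numbers invented):

```python
# Toy numbers: step 1 pays 9s of lazy model load plus 1s of compute,
# later steps take ~1s each.
i0, i1, i2 = 0.0, 10.0, 11.0   # wall-clock times after steps 0, 1, 2
steady = i2 - i1               # 1.0s per step once warm
start_t = i1 - steady          # pretend sampling started at t = 9.0
# After N warm steps, t ~= i1 + (N - 1) * steady, so tqdm's rate becomes
# N / (t - start_t) = N / (N * steady) = 1 step/s, the warm speed.
```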
@@ -1376,3 +1400,21 @@ def string_to_seed(data):
         else:
             crc >>= 1
     return crc ^ 0xFFFFFFFF
+
+def deepcopy_list_dict(obj, memo=None):
+    if memo is None:
+        memo = {}
+
+    obj_id = id(obj)
+    if obj_id in memo:
+        return memo[obj_id]
+
+    if isinstance(obj, dict):
+        res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
+    elif isinstance(obj, list):
+        res = [deepcopy_list_dict(i, memo) for i in obj]
+    else:
+        res = obj
+
+    memo[obj_id] = res
+    return res
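`deepcopy_list_dict` copies only `dict` and `list` containers; every other object (tensors, models, strings) stays shared by reference, and the memo preserves aliasing between repeated substructures. A small illustration, assuming the helper lands in `comfy.utils` (note the memo entry is written after recursion, so cyclic inputs are not supported):

```python
from comfy.utils import deepcopy_list_dict  # assumed location of the helper above

shared = {"cfg": 7.5}
prompt = {"a": shared, "b": shared, "steps": [1, 2, 3]}

clone = deepcopy_list_dict(prompt)
assert clone is not prompt and clone["steps"] is not prompt["steps"]  # containers copied
assert clone["a"] is clone["b"]        # aliasing preserved via the memo
clone["a"]["cfg"] = 8.0                # nested dicts are fresh copies
assert shared["cfg"] == 7.5
```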
@@ -49,6 +49,12 @@ class WeightAdapterBase:
         """
         raise NotImplementedError
 
+    def calculate_shape(
+        self,
+        key
+    ):
+        return None
+
     def calculate_weight(
         self,
         weight,
@@ -21,6 +21,7 @@ from typing import Optional, Union
 import torch
 import torch.nn as nn
 
+import comfy.model_management
 from .base import WeightAdapterBase, WeightAdapterTrainBase
 from comfy.patcher_extension import PatcherInjection
 
@@ -181,18 +182,21 @@ class BypassForwardHook:
             )
             return  # Already injected
 
-        # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
-        device = None
+        # Move adapter weights to compute device (GPU)
+        # Use get_torch_device() instead of module.weight.device because
+        # with offloading, module weights may be on CPU while compute happens on GPU
+        device = comfy.model_management.get_torch_device()
+
+        # Get dtype from module weight if available
         dtype = None
         if hasattr(self.module, "weight") and self.module.weight is not None:
-            device = self.module.weight.device
             dtype = self.module.weight.dtype
-        elif hasattr(self.module, "W_q"):  # Quantized layers might use different attr
-            device = self.module.W_q.device
-            dtype = self.module.W_q.dtype
 
-        if device is not None:
-            self._move_adapter_weights_to_device(device, dtype)
+        # Only use dtype if it's a standard float type, not quantized
+        if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+            dtype = None
+
+        self._move_adapter_weights_to_device(device, dtype)
 
         self.original_forward = self.module.forward
         self.module.forward = self._bypass_forward
@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
         else:
             return None
 
+    def calculate_shape(
+        self,
+        key
+    ):
+        reshape = self.weights[5]
+        return tuple(reshape) if reshape is not None else None
+
     def calculate_weight(
         self,
         weight,
@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
     "supports_preview_metadata": True,
    "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
     "extension": {"manager": {"supports_v4": True}},
+    "node_replacements": True,
 }
 
 
@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
     VERSION = "latest"
     STABLE = False
 
+    def __init__(self):
+        super().__init__()
+        self.node_replacement = self.NodeReplacement()
+        self.execution = self.Execution()
+
+    class NodeReplacement(ProxiedSingleton):
+        async def register(self, node_replace: io.NodeReplace) -> None:
+            """Register a node replacement mapping."""
+            from server import PromptServer
+            PromptServer.instance.node_replace_manager.register(node_replace)
+
     class Execution(ProxiedSingleton):
         async def set_progress(
             self,
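With the new `NodeReplacement` singleton, extensions can publish replacement mappings at load time. A hedged sketch of the call pattern (the instantiation style and node IDs below are illustrative, not taken from the source):

```python
# Illustrative only: register a replacement for a hypothetical deprecated node.
from comfy_api.latest import ComfyAPI_latest

api = ComfyAPI_latest()

async def on_load():
    from comfy_api.latest._io_public import NodeReplace
    await api.node_replacement.register(
        NodeReplace(
            new_node_id="ImageScaleV2",   # hypothetical
            old_node_id="ImageScale",     # hypothetical
            output_mapping=[{"new_idx": 0, "old_idx": 0}],
        )
    )
```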
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
                 image=to_display,
             )
 
-    execution: Execution
-
 class ComfyExtension(ABC):
     async def on_load(self) -> None:
         """
@@ -34,6 +34,21 @@ class VideoInput(ABC):
         """
         pass
 
+    @abstractmethod
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = False,
+    ) -> VideoInput | None:
+        """
+        Create a new VideoInput which is trimmed to have the corresponding start_time and duration
+
+        Returns:
+            A new VideoInput, or None if the result would have negative duration
+        """
+        pass
+
     def get_stream_source(self) -> Union[str, io.BytesIO]:
         """
         Get a streamable source for the video. This allows processing without
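The contract added above is deliberately forgiving: implementations return `None` instead of raising when the requested window cannot be satisfied. A hedged sketch of caller-side handling (`video` stands for any concrete `VideoInput`):

```python
# Sketch: request seconds 2..5; fall back to a best-effort trim if the
# source is too short for a strict window.
clip = video.as_trimmed(start_time=2.0, duration=3.0, strict_duration=True)
if clip is None:
    clip = video.as_trimmed(start_time=2.0, duration=3.0, strict_duration=False)
```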
@@ -6,6 +6,7 @@ from typing import Optional
 from .._input import AudioInput, VideoInput
 import av
 import io
+import itertools
 import json
 import numpy as np
 import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
     formats = container_format.split(",")
     return formats[0]
 
-
 def get_open_write_kwargs(
     dest: str | io.BytesIO, container_format: str, to_format: str | None
 ) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
     Class representing video input from a file.
     """
 
-    def __init__(self, file: str | io.BytesIO):
+    def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
         """
         Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
         containing the file contents.
         """
         self.__file = file
+        self.__start_time = start_time
+        self.__duration = duration
 
     def get_stream_source(self) -> str | io.BytesIO:
         """
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
         Returns:
             Duration in seconds
         """
+        raw_duration = self._get_raw_duration()
+        if self.__start_time < 0:
+            duration_from_start = min(raw_duration, -self.__start_time)
+        else:
+            duration_from_start = raw_duration - self.__start_time
+        if self.__duration:
+            return min(self.__duration, duration_from_start)
+        return duration_from_start
+
+    def _get_raw_duration(self) -> float:
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0)
             with av.open(self.__file, mode="r") as container:
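A negative `start_time` anchors the window to the end of the file. The arithmetic above, extracted into a pure function for illustration:

```python
def effective_duration(raw: float, start_time: float, duration: float) -> float:
    # mirrors VideoFromFile.get_duration() above, for illustration only
    if start_time < 0:
        from_start = min(raw, -start_time)   # window anchored to the end
    else:
        from_start = raw - start_time
    return min(duration, from_start) if duration else from_start

assert effective_duration(10.0, -4.0, 0) == 4.0   # last 4 s of a 10 s file
assert effective_duration(10.0, 2.5, 3.0) == 3.0  # capped by the requested duration
```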
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
             if video_stream and video_stream.average_rate:
                 frame_count = 0
                 container.seek(0)
-                for packet in container.demux(video_stream):
-                    for _ in packet.decode():
-                        frame_count += 1
+                frame_iterator = (
+                    container.decode(video_stream)
+                    if video_stream.codec.capabilities & 0x100
+                    else container.demux(video_stream)
+                )
+                for packet in frame_iterator:
+                    frame_count += 1
                 if frame_count > 0:
                     return float(frame_count / video_stream.average_rate)
 
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
 
         with av.open(self.__file, mode="r") as container:
             video_stream = self._get_first_video_stream(container)
-            # 1. Prefer the frames field if available
-            if video_stream.frames and video_stream.frames > 0:
+            # 1. Prefer the frames field if available and usable
+            if (
+                video_stream.frames
+                and video_stream.frames > 0
+                and not self.__start_time
+                and not self.__duration
+            ):
                 return int(video_stream.frames)
 
             # 2. Try to estimate from duration and average_rate using only metadata
-            if container.duration is not None and video_stream.average_rate:
-                duration_seconds = float(container.duration / av.time_base)
-                estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
-                if estimated_frames > 0:
-                    return estimated_frames
-
             if (
                 getattr(video_stream, "duration", None) is not None
                 and getattr(video_stream, "time_base", None) is not None
                 and video_stream.average_rate
             ):
-                duration_seconds = float(video_stream.duration * video_stream.time_base)
+                raw_duration = float(video_stream.duration * video_stream.time_base)
+                if self.__start_time < 0:
+                    duration_from_start = min(raw_duration, -self.__start_time)
+                else:
+                    duration_from_start = raw_duration - self.__start_time
+                duration_seconds = min(self.__duration, duration_from_start)
                 estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
                 if estimated_frames > 0:
                     return estimated_frames
 
             # 3. Last resort: decode frames and count them (streaming)
-            frame_count = 0
-            container.seek(0)
-            for packet in container.demux(video_stream):
-                for _ in packet.decode():
-                    frame_count += 1
-            if frame_count == 0:
-                raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+            if self.__start_time < 0:
+                start_time = max(self._get_raw_duration() + self.__start_time, 0)
+            else:
+                start_time = self.__start_time
+            frame_count = 1
+            start_pts = int(start_time / video_stream.time_base)
+            end_pts = int((start_time + self.__duration) / video_stream.time_base)
+            container.seek(start_pts, stream=video_stream)
+            frame_iterator = (
+                container.decode(video_stream)
+                if video_stream.codec.capabilities & 0x100
+                else container.demux(video_stream)
+            )
+            for frame in frame_iterator:
+                if frame.pts >= start_pts:
+                    break
+            else:
+                raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+            for frame in frame_iterator:
+                if frame.pts >= end_pts:
+                    break
+                frame_count += 1
             return frame_count
 
     def get_frame_rate(self) -> Fraction:
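Seeking works in stream presentation timestamps: seconds divided by the stream `time_base`. A worked example with a common 1/90000 time base (the value is illustrative; real streams vary):

```python
from fractions import Fraction

time_base = Fraction(1, 90000)      # illustrative video time base
start_pts = int(2.5 / time_base)    # 2.5 s -> 225000 pts
assert start_pts == 225000
```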
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
             return container.format.name
 
     def get_components_internal(self, container: InputContainer) -> VideoComponents:
+        video_stream = self._get_first_video_stream(container)
+        if self.__start_time < 0:
+            start_time = max(self._get_raw_duration() + self.__start_time, 0)
+        else:
+            start_time = self.__start_time
         # Get video frames
         frames = []
-        for frame in container.decode(video=0):
+        start_pts = int(start_time / video_stream.time_base)
+        end_pts = int((start_time + self.__duration) / video_stream.time_base)
+        container.seek(start_pts, stream=video_stream)
+        for frame in container.decode(video_stream):
+            if frame.pts < start_pts:
+                continue
+            if self.__duration and frame.pts >= end_pts:
+                break
             img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
             img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
             frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
             images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
 
             # Get frame rate
-            video_stream = next(s for s in container.streams if s.type == 'video')
-            frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+            frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
 
             # Get audio if available
             audio = None
-            try:
-                container.seek(0)  # Reset the container to the beginning
-                for stream in container.streams:
-                    if stream.type != 'audio':
-                        continue
-                    assert isinstance(stream, av.AudioStream)
-                    audio_frames = []
-                    for packet in container.demux(stream):
-                        for frame in packet.decode():
-                            assert isinstance(frame, av.AudioFrame)
-                            audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
-                    if len(audio_frames) > 0:
-                        audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
-                        audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
-                        audio = AudioInput({
-                            "waveform": audio_tensor,
-                            "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
-                        })
-            except StopIteration:
-                pass  # No audio stream
+            container.seek(start_pts, stream=video_stream)
+            # Use last stream for consistency
+            if len(container.streams.audio):
+                audio_stream = container.streams.audio[-1]
+                audio_frames = []
+                resample = av.audio.resampler.AudioResampler(format='fltp').resample
+                frames = itertools.chain.from_iterable(
+                    map(resample, container.decode(audio_stream))
+                )
+
+                has_first_frame = False
+                for frame in frames:
+                    offset_seconds = start_time - frame.pts * audio_stream.time_base
+                    to_skip = int(offset_seconds * audio_stream.sample_rate)
+                    if to_skip < frame.samples:
+                        has_first_frame = True
+                        break
+                if has_first_frame:
+                    audio_frames.append(frame.to_ndarray()[..., to_skip:])
+
+                for frame in frames:
+                    if frame.time > start_time + self.__duration:
+                        break
+                    audio_frames.append(frame.to_ndarray())  # shape: (channels, samples)
+                if len(audio_frames) > 0:
+                    audio_data = np.concatenate(audio_frames, axis=1)  # shape: (channels, total_samples)
+                    if self.__duration:
+                        audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+
+                    audio_tensor = torch.from_numpy(audio_data).unsqueeze(0)  # shape: (1, channels, total_samples)
+                    audio = AudioInput({
+                        "waveform": audio_tensor,
+                        "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+                    })
 
             metadata = container.metadata
             return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
         path: str | io.BytesIO,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
+        metadata: Optional[dict] = None,
     ):
         if isinstance(self.__file, io.BytesIO):
             self.__file.seek(0)  # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
             reuse_streams = False
         if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
             reuse_streams = False
+        if self.__start_time or self.__duration:
+            reuse_streams = False
 
         if not reuse_streams:
             components = self.get_components_internal(container)
             video = VideoFromComponents(components)
             return video.save_to(
-                path,
-                format=format,
-                codec=codec,
-                metadata=metadata
+                path, format=format, codec=codec, metadata=metadata
             )
 
         streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
                 output_container.mux(packet)
 
     def _get_first_video_stream(self, container: InputContainer):
-        video_stream = next((s for s in container.streams if s.type == "video"), None)
-        if video_stream is None:
-            raise ValueError(f"No video stream found in file '{self.__file}'")
-        return video_stream
+        if len(container.streams.video):
+            return container.streams.video[0]
+        raise ValueError(f"No video stream found in file '{self.__file}'")
+
+    def as_trimmed(
+        self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
+    ) -> VideoInput | None:
+        trimmed = VideoFromFile(
+            self.get_stream_source(),
+            start_time=start_time + self.__start_time,
+            duration=duration,
+        )
+        if trimmed.get_duration() < duration and strict_duration:
+            return None
+        return trimmed
 
 
 class VideoFromComponents(VideoInput):
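Because `VideoFromFile.as_trimmed` adds the parent's offset (`start_time=start_time + self.__start_time`), trims compose. A hedged sketch (the file name is illustrative):

```python
v = VideoFromFile("input.mp4")
head = v.as_trimmed(start_time=2.0, duration=6.0)      # seconds 2..8
mid = head.as_trimmed(start_time=1.0, duration=2.0) if head else None
# `mid`, when not None, covers seconds 3..5 of the original file
```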
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
         return VideoComponents(
             images=self.__components.images,
             audio=self.__components.audio,
-            frame_rate=self.__components.frame_rate
+            frame_rate=self.__components.frame_rate,
         )
 
     def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
         path: str,
         format: VideoContainer = VideoContainer.AUTO,
         codec: VideoCodec = VideoCodec.AUTO,
-        metadata: Optional[dict] = None
+        metadata: Optional[dict] = None,
     ):
         if format != VideoContainer.AUTO and format != VideoContainer.MP4:
             raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
         audio_stream: Optional[av.AudioStream] = None
         if self.__components.audio:
             audio_sample_rate = int(self.__components.audio['sample_rate'])
-            audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+            waveform = self.__components.audio['waveform']
+            waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+            layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+            audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
 
         # Encode video
         for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
                 output.mux(packet)
 
         if audio_stream and self.__components.audio:
-            waveform = self.__components.audio['waveform']
-            waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
-            frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+            frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
             frame.sample_rate = audio_sample_rate
             frame.pts = 0
             output.mux(audio_stream.encode(frame))
 
             # Flush encoder
             output.mux(audio_stream.encode(None))
+
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = True,
+    ) -> VideoInput | None:
+        if self.get_duration() < start_time + duration:
+            return None
+        # TODO: consider tracking duration and trimming at time of save?
+        return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
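The save path clips the waveform to the video's length: samples = ceil(sample_rate / frame_rate * num_frames). Worked example:

```python
import math

sample_rate, frame_rate, num_frames = 44100, 24, 48   # 2 s of 24 fps video
samples = math.ceil((sample_rate / frame_rate) * num_frames)
assert samples == 88200                               # exactly 2 s of audio
```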
@@ -2030,6 +2030,68 @@ class _UIOutput(ABC):
         ...
 
+
+class InputMapOldId(TypedDict):
+    """Map an old node input to a new node input by ID."""
+    new_id: str
+    old_id: str
+
+class InputMapSetValue(TypedDict):
+    """Set a specific value for a new node input."""
+    new_id: str
+    set_value: Any
+
+InputMap = InputMapOldId | InputMapSetValue
+"""
+Input mapping for node replacement. Type is inferred by dictionary keys:
+- {"new_id": str, "old_id": str} - maps old input to new input
+- {"new_id": str, "set_value": Any} - sets a specific value for new input
+"""
+
+class OutputMap(TypedDict):
+    """Map outputs of node replacement via indexes."""
+    new_idx: int
+    old_idx: int
+
+class NodeReplace:
+    """
+    Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
+
+    Also supports assigning specific values to the input widgets of the new node.
+
+    Args:
+        new_node_id: The class name of the new replacement node.
+        old_node_id: The class name of the deprecated node.
+        old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
+            connected. The workflow JSON stores widget values by their relative position index,
+            not by ID. This list maps those positional indexes to input IDs, enabling the
+            replacement system to correctly identify widget values during node migration.
+        input_mapping: List of input mappings from old node to new node.
+        output_mapping: List of output mappings from old node to new node.
+    """
+    def __init__(self,
+                 new_node_id: str,
+                 old_node_id: str,
+                 old_widget_ids: list[str] | None=None,
+                 input_mapping: list[InputMap] | None=None,
+                 output_mapping: list[OutputMap] | None=None,
+                 ):
+        self.new_node_id = new_node_id
+        self.old_node_id = old_node_id
+        self.old_widget_ids = old_widget_ids
+        self.input_mapping = input_mapping
+        self.output_mapping = output_mapping
+
+    def as_dict(self):
+        """Create serializable representation of the node replacement."""
+        return {
+            "new_node_id": self.new_node_id,
+            "old_node_id": self.old_node_id,
+            "old_widget_ids": self.old_widget_ids,
+            "input_mapping": list(self.input_mapping) if self.input_mapping else None,
+            "output_mapping": list(self.output_mapping) if self.output_mapping else None,
+        }
+
+
 __all__ = [
     "FolderType",
     "UploadType",
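A hedged example of describing a migration with the new types (the node and input IDs below are invented for illustration):

```python
replace = NodeReplace(
    new_node_id="LoadImageV2",           # hypothetical
    old_node_id="LoadImage",             # hypothetical
    old_widget_ids=["image", "upload"],  # positional widget order of the old node
    input_mapping=[
        {"new_id": "image", "old_id": "image"},          # InputMapOldId form
        {"new_id": "subfolder", "set_value": "inputs"},  # InputMapSetValue form
    ],
    output_mapping=[{"new_idx": 0, "old_idx": 0}],
)
payload = replace.as_dict()  # JSON-serializable dict for the frontend
```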
@@ -2121,4 +2183,5 @@ __all__ = [
     "ImageCompare",
     "PriceBadgeDepends",
     "PriceBadge",
+    "NodeReplace",
 ]
8
comfy_api_nodes/apis/__init__.py
generated
@@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
     face = 'face'
 
 
-class KlingImageGenModelName(str, Enum):
-    kling_v1 = 'kling-v1'
-    kling_v1_5 = 'kling-v1-5'
-    kling_v2 = 'kling-v2'
-
-
 class KlingImageGenerationsRequest(BaseModel):
     aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
     callback_url: Optional[AnyUrl] = Field(
@@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
         0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
     )
     image_reference: Optional[KlingImageGenImageReferenceType] = None
-    model_name: Optional[KlingImageGenModelName] = 'kling-v1'
+    model_name: str = Field(...)
     n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
     negative_prompt: Optional[str] = Field(
         None, description='Negative text prompt', max_length=200
@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
     )
 
 
+class BriaRemoveBackgroundRequest(BaseModel):
+    image: str = Field(...)
+    sync: bool = Field(False)
+    visual_input_content_moderation: bool = Field(
+        False, description="If true, returns 422 on input image moderation failure."
+    )
+    visual_output_content_moderation: bool = Field(
+        False, description="If true, returns 422 on visual output moderation failure."
+    )
+    seed: int = Field(...)
+
+
 class BriaStatusResponse(BaseModel):
     request_id: str = Field(...)
     status_url: str = Field(...)
     warning: str | None = Field(None)
 
 
-class BriaResult(BaseModel):
+class BriaRemoveBackgroundResult(BaseModel):
+    image_url: str = Field(...)
+
+
+class BriaRemoveBackgroundResponse(BaseModel):
+    status: str = Field(...)
+    result: BriaRemoveBackgroundResult | None = Field(None)
+
+
+class BriaImageEditResult(BaseModel):
     structured_prompt: str = Field(...)
     image_url: str = Field(...)
 
 
-class BriaResponse(BaseModel):
+class BriaImageEditResponse(BaseModel):
     status: str = Field(...)
-    result: BriaResult | None = Field(None)
+    result: BriaImageEditResult | None = Field(None)
+
+
+class BriaRemoveVideoBackgroundRequest(BaseModel):
+    video: str = Field(...)
+    background_color: str = Field(default="transparent", description="Background color for the output video.")
+    output_container_and_codec: str = Field(...)
+    preserve_audio: bool = Field(True)
+    seed: int = Field(...)
+
+
+class BriaRemoveVideoBackgroundResult(BaseModel):
+    video_url: str = Field(...)
+
+
+class BriaRemoveVideoBackgroundResponse(BaseModel):
+    status: str = Field(...)
+    result: BriaRemoveVideoBackgroundResult | None = Field(None)
@@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
 
 class To3DProTaskQueryRequest(BaseModel):
     JobId: str = Field(...)
+
+
+class To3DUVFileInput(BaseModel):
+    Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
+    Url: str = Field(...)
+
+
+class To3DUVTaskRequest(BaseModel):
+    File: To3DUVFileInput = Field(...)
+
+
+class TextureEditImageInfo(BaseModel):
+    Url: str = Field(...)
+
+
+class TextureEditTaskRequest(BaseModel):
+    File3D: To3DUVFileInput = Field(...)
+    Image: TextureEditImageInfo | None = Field(None)
+    Prompt: str | None = Field(None)
+    EnablePBR: bool | None = Field(None)
@@ -1,12 +1,22 @@
 from pydantic import BaseModel, Field
 
 
+class MultiPromptEntry(BaseModel):
+    index: int = Field(...)
+    prompt: str = Field(...)
+    duration: str = Field(...)
+
+
 class OmniProText2VideoRequest(BaseModel):
     model_name: str = Field(..., description="kling-video-o1")
     aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
     duration: str = Field(..., description="'5' or '10'")
     prompt: str = Field(...)
     mode: str = Field("pro")
+    multi_shot: bool | None = Field(None)
+    multi_prompt: list[MultiPromptEntry] | None = Field(None)
+    shot_type: str | None = Field(None)
+    sound: str = Field(..., description="'on' or 'off'")
 
 
 class OmniParamImage(BaseModel):
@@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
     duration: str = Field(..., description="'5' or '10'")
     prompt: str = Field(...)
     mode: str = Field("pro")
+    sound: str | None = Field(None, description="'on' or 'off'")
+    multi_shot: bool | None = Field(None)
+    multi_prompt: list[MultiPromptEntry] | None = Field(None)
+    shot_type: str | None = Field(None)
 
 
 class OmniProReferences2VideoRequest(BaseModel):
@@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
     duration: str | None = Field(..., description="From 3 to 10.")
     prompt: str = Field(...)
     mode: str = Field("pro")
+    sound: str | None = Field(None, description="'on' or 'off'")
+    multi_shot: bool | None = Field(None)
+    multi_prompt: list[MultiPromptEntry] | None = Field(None)
+    shot_type: str | None = Field(None)
 
 
 class TaskStatusVideoResult(BaseModel):
@@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
 class TaskStatusResults(BaseModel):
     videos: list[TaskStatusVideoResult] | None = Field(None)
     images: list[TaskStatusImageResult] | None = Field(None)
+    series_images: list[TaskStatusImageResult] | None = Field(None)
 
 
 class TaskStatusResponseData(BaseModel):
@@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
 
 
 class OmniProImageRequest(BaseModel):
-    model_name: str = Field(..., description="kling-image-o1")
-    resolution: str = Field(..., description="'1k' or '2k'")
+    model_name: str = Field(...)
+    resolution: str = Field(...)
     aspect_ratio: str | None = Field(...)
     prompt: str = Field(...)
     mode: str = Field("pro")
     n: int | None = Field(1, le=9)
     image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+    result_type: str | None = Field(None, description="Set to 'series' for series generation")
+    series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
 
 
 class TextToVideoWithAudioRequest(BaseModel):
-    model_name: str = Field(..., description="kling-v2-6")
+    model_name: str = Field(...)
     aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
-    duration: str = Field(..., description="'5' or '10'")
-    prompt: str = Field(...)
+    duration: str = Field(...)
+    prompt: str | None = Field(...)
+    negative_prompt: str | None = Field(None)
     mode: str = Field("pro")
     sound: str = Field(..., description="'on' or 'off'")
+    multi_shot: bool | None = Field(None)
+    multi_prompt: list[MultiPromptEntry] | None = Field(None)
+    shot_type: str | None = Field(None)
 
 
 class ImageToVideoWithAudioRequest(BaseModel):
-    model_name: str = Field(..., description="kling-v2-6")
+    model_name: str = Field(...)
     image: str = Field(...)
-    duration: str = Field(..., description="'5' or '10'")
-    prompt: str = Field(...)
+    image_tail: str | None = Field(None)
+    duration: str = Field(...)
+    prompt: str | None = Field(...)
+    negative_prompt: str | None = Field(None)
     mode: str = Field("pro")
     sound: str = Field(..., description="'on' or 'off'")
+    multi_shot: bool | None = Field(None)
+    multi_prompt: list[MultiPromptEntry] | None = Field(None)
+    shot_type: str | None = Field(None)
 
 
 class MotionControlRequest(BaseModel):
@@ -3,7 +3,11 @@ from typing_extensions import override
 from comfy_api.latest import IO, ComfyExtension, Input
 from comfy_api_nodes.apis.bria import (
     BriaEditImageRequest,
-    BriaResponse,
+    BriaRemoveBackgroundRequest,
+    BriaRemoveBackgroundResponse,
+    BriaRemoveVideoBackgroundRequest,
+    BriaRemoveVideoBackgroundResponse,
+    BriaImageEditResponse,
     BriaStatusResponse,
     InputModerationSettings,
 )
@@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
     ApiEndpoint,
     convert_mask_to_image,
     download_url_to_image_tensor,
-    get_number_of_images,
+    download_url_to_video_output,
     poll_op,
     sync_op,
-    upload_images_to_comfyapi,
+    upload_image_to_comfyapi,
+    upload_video_to_comfyapi,
+    validate_video_duration,
 )
 
 
@@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
                 IO.DynamicCombo.Input(
                     "moderation",
                     options=[
+                        IO.DynamicCombo.Option("false", []),
                         IO.DynamicCombo.Option(
                             "true",
                             [
-                                IO.Boolean.Input(
-                                    "prompt_content_moderation", default=False
-                                ),
-                                IO.Boolean.Input(
-                                    "visual_input_moderation", default=False
-                                ),
-                                IO.Boolean.Input(
-                                    "visual_output_moderation", default=True
-                                ),
+                                IO.Boolean.Input("prompt_content_moderation", default=False),
+                                IO.Boolean.Input("visual_input_moderation", default=False),
+                                IO.Boolean.Input("visual_output_moderation", default=True),
                             ],
                         ),
-                        IO.DynamicCombo.Option("false", []),
                     ],
                     tooltip="Moderation settings",
                 ),
@@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
         mask: Input.Image | None = None,
     ) -> IO.NodeOutput:
         if not prompt and not structured_prompt:
-            raise ValueError(
-                "One of prompt or structured_prompt is required to be non-empty."
-            )
-        if get_number_of_images(image) != 1:
-            raise ValueError("Exactly one input image is required.")
+            raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
         mask_url = None
         if mask is not None:
-            mask_url = (
-                await upload_images_to_comfyapi(
-                    cls,
-                    convert_mask_to_image(mask),
-                    max_images=1,
-                    mime_type="image/png",
-                    wait_label="Uploading mask",
-                )
-            )[0]
+            mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
         response = await sync_op(
             cls,
             ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
             data=BriaEditImageRequest(
                 instruction=prompt if prompt else None,
                 structured_instruction=structured_prompt if structured_prompt else None,
-                images=await upload_images_to_comfyapi(
-                    cls,
-                    image,
-                    max_images=1,
-                    mime_type="image/png",
-                    wait_label="Uploading image",
-                ),
+                images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
                 mask=mask_url,
                 negative_prompt=negative_prompt if negative_prompt else None,
                 guidance_scale=guidance_scale,
                 seed=seed,
                 model_version=model,
                 steps_num=steps,
-                prompt_content_moderation=moderation.get(
-                    "prompt_content_moderation", False
-                ),
-                visual_input_content_moderation=moderation.get(
-                    "visual_input_moderation", False
-                ),
-                visual_output_content_moderation=moderation.get(
-                    "visual_output_moderation", False
-                ),
+                prompt_content_moderation=moderation.get("prompt_content_moderation", False),
+                visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+                visual_output_content_moderation=moderation.get("visual_output_moderation", False),
             ),
             response_model=BriaStatusResponse,
         )
@@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
             cls,
             ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
             status_extractor=lambda r: r.status,
-            response_model=BriaResponse,
+            response_model=BriaImageEditResponse,
         )
         return IO.NodeOutput(
             await download_url_to_image_tensor(response.result.image_url),
@@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
         )
 
 
+class BriaRemoveImageBackground(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="BriaRemoveImageBackground",
+            display_name="Bria Remove Image Background",
+            category="api node/image/Bria",
+            description="Remove the background from an image using Bria RMBG 2.0.",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.DynamicCombo.Input(
+                    "moderation",
+                    options=[
+                        IO.DynamicCombo.Option("false", []),
+                        IO.DynamicCombo.Option(
+                            "true",
+                            [
+                                IO.Boolean.Input("visual_input_moderation", default=False),
+                                IO.Boolean.Input("visual_output_moderation", default=True),
+                            ],
+                        ),
+                    ],
+                    tooltip="Moderation settings",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=2147483647,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                    tooltip="Seed controls whether the node should re-run; "
+                    "results are non-deterministic regardless of seed.",
+                ),
+            ],
+            outputs=[IO.Image.Output()],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                expr="""{"type":"usd","usd":0.018}""",
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        image: Input.Image,
+        moderation: dict,
+        seed: int,
+    ) -> IO.NodeOutput:
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
+            data=BriaRemoveBackgroundRequest(
+                image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
+                sync=False,
+                visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+                visual_output_content_moderation=moderation.get("visual_output_moderation", False),
+                seed=seed,
+            ),
+            response_model=BriaStatusResponse,
+        )
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+            status_extractor=lambda r: r.status,
+            response_model=BriaRemoveBackgroundResponse,
+        )
+        return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
+
+
+class BriaRemoveVideoBackground(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="BriaRemoveVideoBackground",
+            display_name="Bria Remove Video Background",
+            category="api node/video/Bria",
+            description="Remove the background from a video using Bria.",
+            inputs=[
+                IO.Video.Input("video"),
+                IO.Combo.Input(
+                    "background_color",
+                    options=[
+                        "Black",
+                        "White",
+                        "Gray",
+                        "Red",
+                        "Green",
+                        "Blue",
+                        "Yellow",
+                        "Cyan",
+                        "Magenta",
+                        "Orange",
+                    ],
+                    tooltip="Background color for the output video.",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=2147483647,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                    tooltip="Seed controls whether the node should re-run; "
+                    "results are non-deterministic regardless of seed.",
+                ),
+            ],
+            outputs=[IO.Video.Output()],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        video: Input.Video,
+        background_color: str,
+        seed: int,
+    ) -> IO.NodeOutput:
+        validate_video_duration(video, max_duration=60.0)
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
+            data=BriaRemoveVideoBackgroundRequest(
+                video=await upload_video_to_comfyapi(cls, video),
+                background_color=background_color,
+                output_container_and_codec="mp4_h264",
+                seed=seed,
+            ),
+            response_model=BriaStatusResponse,
+        )
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+            status_extractor=lambda r: r.status,
+            response_model=BriaRemoveVideoBackgroundResponse,
+        )
+        return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
+
+
 class BriaExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
         return [
             BriaImageEditNode,
+            BriaRemoveImageBackground,
+            BriaRemoveVideoBackground,
         ]
 
 
@@ -1,31 +1,48 @@
 from typing_extensions import override
 
-from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api.latest import IO, ComfyExtension, Input, Types
 from comfy_api_nodes.apis.hunyuan3d import (
     Hunyuan3DViewImage,
     InputGenerateType,
     ResultFile3D,
+    TextureEditTaskRequest,
     To3DProTaskCreateResponse,
     To3DProTaskQueryRequest,
     To3DProTaskRequest,
     To3DProTaskResultResponse,
+    To3DUVFileInput,
+    To3DUVTaskRequest,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
     download_url_to_file_3d,
+    download_url_to_image_tensor,
     downscale_image_tensor_by_max_side,
     poll_op,
     sync_op,
+    upload_3d_model_to_comfyapi,
     upload_image_to_comfyapi,
     validate_image_dimensions,
     validate_string,
 )
 
 
-def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
+def _is_tencent_rate_limited(status: int, body: object) -> bool:
+    return (
+        status == 400
+        and isinstance(body, dict)
+        and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
+    )
+
+
+def get_file_from_response(
+    response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
+) -> ResultFile3D | None:
     for i in response_objs:
         if i.Type.lower() == file_type.lower():
             return i
+    if raise_if_not_found:
+        raise ValueError(f"'{file_type}' file type is not found in the response.")
     return None
 
 
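`_is_tencent_rate_limited` matches Tencent's enveloped error format rather than an HTTP 429. An example of a body it flags:

```python
body = {"Response": {"Error": {"Code": "RequestLimitExceeded", "Message": "..."}}}
assert _is_tencent_rate_limited(400, body)
assert not _is_tencent_rate_limited(400, {"Response": {}})
```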
@@ -35,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="TencentTextToModelNode",
-            display_name="Hunyuan3D: Text to Model (Pro)",
+            display_name="Hunyuan3D: Text to Model",
             category="api node/3d/Tencent",
             inputs=[
                 IO.Combo.Input(
@@ -120,6 +137,7 @@ class TencentTextToModelNode(IO.ComfyNode):
                 EnablePBR=generate_type.get("pbr", None),
                 PolygonType=generate_type.get("polygon_type", None),
             ),
+            is_rate_limited=_is_tencent_rate_limited,
         )
         if response.Error:
             raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -131,11 +149,14 @@ class TencentTextToModelNode(IO.ComfyNode):
             response_model=To3DProTaskResultResponse,
             status_extractor=lambda r: r.Status,
         )
-        glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
-        obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
-        file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
         return IO.NodeOutput(
-            file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+            f"{task_id}.glb",
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+            ),
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+            ),
         )
 
 
@@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="TencentImageToModelNode",
-            display_name="Hunyuan3D: Image(s) to Model (Pro)",
+            display_name="Hunyuan3D: Image(s) to Model",
             category="api node/3d/Tencent",
             inputs=[
                 IO.Combo.Input(
@@ -268,6 +289,7 @@ class TencentImageToModelNode(IO.ComfyNode):
                 EnablePBR=generate_type.get("pbr", None),
                 PolygonType=generate_type.get("polygon_type", None),
             ),
+            is_rate_limited=_is_tencent_rate_limited,
        )
         if response.Error:
             raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -279,11 +301,257 @@ class TencentImageToModelNode(IO.ComfyNode):
             response_model=To3DProTaskResultResponse,
             status_extractor=lambda r: r.Status,
         )
-        glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
-        obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
-        file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
         return IO.NodeOutput(
-            file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+            f"{task_id}.glb",
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+            ),
+            await download_url_to_file_3d(
+                get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+            ),
+        )
+
+
+class TencentModelTo3DUVNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="TencentModelTo3DUVNode",
+            display_name="Hunyuan3D: Model to UV",
+            category="api node/3d/Tencent",
+            description="Perform UV unfolding on a 3D model to generate UV texture. "
+            "Input model must have less than 30000 faces.",
+            inputs=[
+                IO.MultiType.Input(
+                    "model_3d",
+                    types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
+                    tooltip="Input 3D model (GLB, OBJ, or FBX)",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=1,
+                    min=0,
+                    max=2147483647,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                    tooltip="Seed controls whether the node should re-run; "
+                    "results are non-deterministic regardless of seed.",
+                ),
+            ],
+            outputs=[
+                IO.File3DOBJ.Output(display_name="OBJ"),
+                IO.File3DFBX.Output(display_name="FBX"),
+                IO.Image.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
+        )
+
+    SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
+
+    @classmethod
+    async def execute(
+        cls,
+        model_3d: Types.File3D,
+        seed: int,
+    ) -> IO.NodeOutput:
+        _ = seed
+        file_format = model_3d.format.lower()
+        if file_format not in cls.SUPPORTED_FORMATS:
+            raise ValueError(
+                f"Unsupported file format: '{file_format}'. "
+                f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
+            )
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
+            response_model=To3DProTaskCreateResponse,
+            data=To3DUVTaskRequest(
+                File=To3DUVFileInput(
+                    Type=file_format.upper(),
+                    Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
+                )
+            ),
+            is_rate_limited=_is_tencent_rate_limited,
+        )
+        if response.Error:
+            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+        result = await poll_op(
+            cls,
+            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
+            data=To3DProTaskQueryRequest(JobId=response.JobId),
+            response_model=To3DProTaskResultResponse,
+            status_extractor=lambda r: r.Status,
+        )
+        return IO.NodeOutput(
+            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
+            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+            await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
+        )
+
+
+class Tencent3DTextureEditNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="Tencent3DTextureEditNode",
+            display_name="Hunyuan3D: 3D Texture Edit",
+            category="api node/3d/Tencent",
+            description="After inputting the 3D model, perform 3D model texture redrawing.",
+            inputs=[
+                IO.MultiType.Input(
+                    "model_3d",
+                    types=[IO.File3DFBX, IO.File3DAny],
+                    tooltip="3D model in FBX format. Model should have less than 100000 faces.",
+                ),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                    tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=2147483647,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                    tooltip="Seed controls whether the node should re-run; "
+                    "results are non-deterministic regardless of seed.",
+                ),
+            ],
+            outputs=[
+                IO.File3DGLB.Output(display_name="GLB"),
+                IO.File3DFBX.Output(display_name="FBX"),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                expr="""{"type":"usd","usd": 0.6}""",
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_3d: Types.File3D,
+        prompt: str,
+        seed: int,
+    ) -> IO.NodeOutput:
+        _ = seed
+        file_format = model_3d.format.lower()
+        if file_format != "fbx":
+            raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+        validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
+        model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
+            response_model=To3DProTaskCreateResponse,
+            data=TextureEditTaskRequest(
+                File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+                Prompt=prompt,
+                EnablePBR=True,
+            ),
+            is_rate_limited=_is_tencent_rate_limited,
+        )
+        if response.Error:
+            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+
+        result = await poll_op(
+            cls,
+            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
+            data=To3DProTaskQueryRequest(JobId=response.JobId),
+            response_model=To3DProTaskResultResponse,
+            status_extractor=lambda r: r.Status,
+        )
+        return IO.NodeOutput(
+            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
+            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+        )
+
+
+class Tencent3DPartNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="Tencent3DPartNode",
+            display_name="Hunyuan3D: 3D Part",
+            category="api node/3d/Tencent",
+            description="Automatically perform component identification and generation based on the model structure.",
+            inputs=[
+                IO.MultiType.Input(
+                    "model_3d",
+                    types=[IO.File3DFBX, IO.File3DAny],
+                    tooltip="3D model in FBX format. Model should have less than 30000 faces.",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=2147483647,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                    tooltip="Seed controls whether the node should re-run; "
+                    "results are non-deterministic regardless of seed.",
+                ),
+            ],
+            outputs=[
+                IO.File3DFBX.Output(display_name="FBX"),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model_3d: Types.File3D,
+        seed: int,
+    ) -> IO.NodeOutput:
+        _ = seed
+        file_format = model_3d.format.lower()
+        if file_format != "fbx":
+            raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+        model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
+            response_model=To3DProTaskCreateResponse,
+            data=To3DUVTaskRequest(
+                File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+            ),
+            is_rate_limited=_is_tencent_rate_limited,
+        )
+        if response.Error:
+            raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+        result = await poll_op(
+            cls,
+            ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
+            data=To3DProTaskQueryRequest(JobId=response.JobId),
+            response_model=To3DProTaskResultResponse,
+            status_extractor=lambda r: r.Status,
+        )
+        return IO.NodeOutput(
+            await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
         )

@@ -293,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
         return [
             TencentTextToModelNode,
             TencentImageToModelNode,
+            # TencentModelTo3DUVNode,
+            # Tencent3DTextureEditNode,
+            Tencent3DPartNode,
         ]

File diff suppressed because it is too large
@@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
     validate_image_dimensions,
 )
+
+_EUR_TO_USD = 1.19
+
+
+def _tier_price_eur(megapixels: float) -> float:
+    """Price in EUR for a single Magnific upscaling step based on input megapixels."""
+    if megapixels <= 1.3:
+        return 0.143
+    if megapixels <= 3.0:
+        return 0.286
+    if megapixels <= 6.4:
+        return 0.429
+    return 1.716
+
+
+def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
+    """Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
+    num_steps = int(math.log2(scale))
+    total_eur = 0.0
+    pixels = width * height
+    for _ in range(num_steps):
+        total_eur += _tier_price_eur(pixels / 1_000_000)
+        pixels *= 4
+    return round(total_eur * _EUR_TO_USD, 2)
+
+
 class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
     @classmethod
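To sanity-check the new tier pricing, here is the helper restated as a standalone sketch with one worked example. The constants are copied from the hunk above; the 1024×1024 input is only an illustration.

```python
# Worked example of the tier pricing above (standalone restatement, not an
# import from the ComfyUI codebase): a 1024x1024 input upscaled 4x is two 2x steps.
# Step 1: 1024*1024 = 1.048576 MP -> tier 0.143 EUR
# Step 2: pixels quadruple to 4.194304 MP -> tier 0.429 EUR
# Total: (0.143 + 0.429) * 1.19 EUR->USD = 0.68068 -> rounds to 0.68 USD
import math

_EUR_TO_USD = 1.19

def _tier_price_eur(megapixels: float) -> float:
    if megapixels <= 1.3:
        return 0.143
    if megapixels <= 3.0:
        return 0.286
    if megapixels <= 6.4:
        return 0.429
    return 1.716

def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
    num_steps = int(math.log2(scale))  # 4x -> 2 doubling steps
    total_eur = 0.0
    pixels = width * height
    for _ in range(num_steps):
        total_eur += _tier_price_eur(pixels / 1_000_000)
        pixels *= 4  # each 2x step quadruples the pixel count
    return round(total_eur * _EUR_TO_USD, 2)

print(_calculate_magnific_upscale_price_usd(1024, 1024, 4))  # 0.68
```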
@@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
             ],
             is_api_node=True,
             price_badge=IO.PriceBadge(
-                depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
+                depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
                 expr="""
                 (
-                    $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
-                    {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+                    $ad := widgets.auto_downscale;
+                    $mins := $ad
+                        ? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
+                        : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+                    $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
+                    {
+                        "type": "range_usd",
+                        "min_usd": $lookup($mins, widgets.scale_factor),
+                        "max_usd": $lookup($maxs, widgets.scale_factor),
+                        "format": { "approximate": true }
+                    }
                 )
                 """,
             ),
@@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
                 f"Use a smaller input image or lower scale factor."
             )
+
+        final_height, final_width = get_image_dimensions(image)
+        actual_scale = int(scale_factor.rstrip("x"))
+        price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
+
         initial_res = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
             ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
             response_model=TaskResponse,
             status_extractor=lambda x: x.status,
+            price_extractor=lambda _: price_usd,
             poll_interval=10.0,
             max_poll_attempts=480,
         )
@@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
                 depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
                 expr="""
                 (
-                    $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
-                    {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+                    $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+                    $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
+                    {
+                        "type": "range_usd",
+                        "min_usd": $lookup($mins, widgets.scale_factor),
+                        "max_usd": $lookup($maxs, widgets.scale_factor),
+                        "format": { "approximate": true }
+                    }
                 )
                 """,
             ),
@@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
                 f"Use a smaller input image or lower scale factor."
             )
+
+        final_height, final_width = get_image_dimensions(image)
+        price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
+
         initial_res = await sync_op(
             cls,
             ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
             ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
             response_model=TaskResponse,
             status_extractor=lambda x: x.status,
+            price_extractor=lambda _: price_usd,
             poll_interval=10.0,
             max_poll_attempts=480,
         )
@@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
         return [
-            # MagnificImageUpscalerCreativeNode,
-            # MagnificImageUpscalerPreciseV2Node,
+            MagnificImageUpscalerCreativeNode,
+            MagnificImageUpscalerPreciseV2Node,
             MagnificImageStyleTransferNode,
             MagnificImageRelightNode,
             MagnificImageSkinEnhancerNode,
@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
             ),
             IO.Int.Input(
                 "steps",
-                default=33,
-                min=1,
+                default=80,
+                min=75,  # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
                 max=100,
                 step=1,
                 tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
             ),
             IO.Int.Input(
                 "steps",
-                default=33,
-                min=1,
+                default=60,
+                min=60,  # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
                 max=100,
                 step=1,
                 display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
         video: Input.Video | None = None,
         control_type: str = "Motion Transfer",
         motion_intensity: int | None = 100,
-        steps=33,
+        steps=60,
         prompt_adherence=4.5,
     ) -> IO.NodeOutput:
         validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
             ),
             IO.Int.Input(
                 "steps",
-                default=33,
-                min=1,
+                default=80,
+                min=75,  # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
                 max=100,
                 step=1,
                 tooltip="Inference steps",
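The new Moonvalley minimums encode a scheduler constraint: steps must cover the cooldown and warmup phases noted in the comments above. A hedged sketch of that check, with the phase values copied from those comments (the function and dict names are illustrative, not part of the codebase):

```python
# Illustrative validation only; the real constraint is enforced server-side.
MOONVALLEY_PHASES = {
    "img2video": {"cooldown_steps": 75, "warmup_steps": 0},     # min steps 75
    "video2video": {"cooldown_steps": 36, "warmup_steps": 24},  # min steps 60
    "txt2video": {"cooldown_steps": 75, "warmup_steps": 0},     # min steps 75
}

def validate_steps(mode: str, steps: int) -> None:
    phases = MOONVALLEY_PHASES[mode]
    minimum = phases["cooldown_steps"] + phases["warmup_steps"]
    if not (minimum <= steps <= 100):
        raise ValueError(f"{mode}: steps must be in [{minimum}, 100], got {steps}")

validate_steps("video2video", 60)  # ok: 36 + 24 == 60
```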
@@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
     o1 = "o1"
     o3 = "o3"
     o1_pro = "o1-pro"
-    gpt_4o = "gpt-4o"
     gpt_4_1 = "gpt-4.1"
     gpt_4_1_mini = "gpt-4.1-mini"
     gpt_4_1_nano = "gpt-4.1-nano"
@@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode):
                 "usd": [0.01, 0.04],
                 "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
             }
-            : $contains($m, "gpt-4o") ? {
-                "type": "list_usd",
-                "usd": [0.0025, 0.01],
-                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
-            }
             : $contains($m, "gpt-4.1-nano") ? {
                 "type": "list_usd",
                 "usd": [0.0001, 0.0004],
@@ -33,6 +33,7 @@ from .download_helpers import (
     download_url_to_video_output,
 )
 from .upload_helpers import (
+    upload_3d_model_to_comfyapi,
     upload_audio_to_comfyapi,
     upload_file_to_comfyapi,
     upload_image_to_comfyapi,
@@ -62,6 +63,7 @@ __all__ = [
     "sync_op",
     "sync_op_raw",
     # Upload helpers
+    "upload_3d_model_to_comfyapi",
    "upload_audio_to_comfyapi",
     "upload_file_to_comfyapi",
     "upload_image_to_comfyapi",
@@ -57,6 +57,7 @@ class _RequestConfig:
     files: dict[str, Any] | list[tuple[str, Any]] | None
     multipart_parser: Callable | None
     max_retries: int
+    max_retries_on_rate_limit: int
     retry_delay: float
     retry_backoff: float
     wait_label: str = "Waiting"
@@ -65,6 +66,7 @@ class _RequestConfig:
     final_label_on_success: str | None = "Completed"
     progress_origin_ts: float | None = None
     price_extractor: Callable[[dict[str, Any]], float | None] | None = None
+    is_rate_limited: Callable[[int, Any], bool] | None = None


 @dataclass
@@ -78,7 +80,7 @@ class _PollUIState:
     active_since: float | None = None  # start time of current active interval (None if queued)


-_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
+_RETRY_STATUS = {408, 500, 502, 503, 504}  # status 429 is handled separately
 COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
 FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
 QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -103,6 +105,8 @@ async def sync_op(
     final_label_on_success: str | None = "Completed",
     progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
+    max_retries_on_rate_limit: int = 16,
+    is_rate_limited: Callable[[int, Any], bool] | None = None,
 ) -> M:
     raw = await sync_op_raw(
         cls,
@@ -122,6 +126,8 @@ async def sync_op(
         final_label_on_success=final_label_on_success,
         progress_origin_ts=progress_origin_ts,
         monitor_progress=monitor_progress,
+        max_retries_on_rate_limit=max_retries_on_rate_limit,
+        is_rate_limited=is_rate_limited,
     )
     if not isinstance(raw, dict):
         raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -143,9 +149,9 @@ async def poll_op(
     poll_interval: float = 5.0,
     max_poll_attempts: int = 160,
     timeout_per_poll: float = 120.0,
-    max_retries_per_poll: int = 3,
+    max_retries_per_poll: int = 10,
     retry_delay_per_poll: float = 1.0,
-    retry_backoff_per_poll: float = 2.0,
+    retry_backoff_per_poll: float = 1.4,
     estimated_duration: int | None = None,
     cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
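With max_retries_per_poll raised to 10 and the backoff softened from 2.0 to 1.4, a failing poll now retries more times with gentler growth. A quick sketch of the cumulative wait, which is pure arithmetic on the defaults above:

```python
# Cumulative sleep across per-poll retries: delay starts at 1.0s and is
# multiplied by the backoff after each attempt (values from the defaults above).
def total_retry_wait(retries: int, delay: float, backoff: float) -> float:
    return sum(delay * backoff**i for i in range(retries))

print(round(total_retry_wait(3, 1.0, 2.0), 1))   # old defaults: 7.0s total
print(round(total_retry_wait(10, 1.0, 1.4), 1))  # new defaults: ~69.8s total
```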
@@ -194,6 +200,8 @@ async def sync_op_raw(
     final_label_on_success: str | None = "Completed",
     progress_origin_ts: float | None = None,
     monitor_progress: bool = True,
+    max_retries_on_rate_limit: int = 16,
+    is_rate_limited: Callable[[int, Any], bool] | None = None,
 ) -> dict[str, Any] | bytes:
     """
     Make a single network request.
@@ -222,6 +230,8 @@ async def sync_op_raw(
         final_label_on_success=final_label_on_success,
         progress_origin_ts=progress_origin_ts,
         price_extractor=price_extractor,
+        max_retries_on_rate_limit=max_retries_on_rate_limit,
+        is_rate_limited=is_rate_limited,
     )
     return await _request_base(cfg, expect_binary=as_binary)

@@ -240,9 +250,9 @@ async def poll_op_raw(
     poll_interval: float = 5.0,
     max_poll_attempts: int = 160,
     timeout_per_poll: float = 120.0,
-    max_retries_per_poll: int = 3,
+    max_retries_per_poll: int = 10,
     retry_delay_per_poll: float = 1.0,
-    retry_backoff_per_poll: float = 2.0,
+    retry_backoff_per_poll: float = 1.4,
     estimated_duration: int | None = None,
     cancel_endpoint: ApiEndpoint | None = None,
     cancel_timeout: float = 10.0,
@@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
     if status == 409:
         return "There is a problem with your account. Please contact support@comfy.org."
     if status == 429:
-        return "Rate Limit Exceeded: Please try again later."
+        return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
     try:
         if isinstance(body, dict):
             err = body.get("error")
@@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
     start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
     attempt = 0
     delay = cfg.retry_delay
+    rate_limit_attempts = 0
+    rate_limit_delay = cfg.retry_delay
     operation_succeeded: bool = False
     final_elapsed_seconds: int | None = None
     extracted_price: float | None = None
@@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
             payload_headers["Content-Type"] = "application/json"
             payload_kw["json"] = cfg.data or {}

-            try:
-                request_logger.log_request_response(
-                    operation_id=operation_id,
-                    request_method=method,
-                    request_url=url,
-                    request_headers=dict(payload_headers) if payload_headers else None,
-                    request_params=dict(params) if params else None,
-                    request_data=request_body_log,
-                )
-            except Exception as _log_e:
-                logging.debug("[DEBUG] request logging failed: %s", _log_e)
+            request_logger.log_request_response(
+                operation_id=operation_id,
+                request_method=method,
+                request_url=url,
+                request_headers=dict(payload_headers) if payload_headers else None,
+                request_params=dict(params) if params else None,
+                request_data=request_body_log,
+            )

             req_coro = sess.request(method, url, params=params, **payload_kw)
             req_task = asyncio.create_task(req_coro)
@@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                     body = await resp.json()
                 except (ContentTypeError, json.JSONDecodeError):
                     body = await resp.text()
-                if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
+                should_retry = False
+                wait_time = 0.0
+                retry_label = ""
+                is_rl = resp.status == 429 or (
+                    cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
+                )
+                if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
+                    rate_limit_attempts += 1
+                    wait_time = min(rate_limit_delay, 30.0)
+                    rate_limit_delay *= cfg.retry_backoff
+                    retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
+                    should_retry = True
+                elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
+                    wait_time = delay
+                    delay *= cfg.retry_backoff
+                    retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
+                    should_retry = True
+
+                if should_retry:
                     logging.warning(
-                        "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
+                        "HTTP %s %s -> %s. Waiting %.2fs (%s).",
                         method,
                         url,
                         resp.status,
-                        delay,
-                        attempt,
-                        cfg.max_retries,
+                        wait_time,
+                        retry_label,
                     )
-                    try:
-                        request_logger.log_request_response(
-                            operation_id=operation_id,
-                            request_method=method,
-                            request_url=url,
-                            response_status_code=resp.status,
-                            response_headers=dict(resp.headers),
-                            response_content=body,
-                            error_message=_friendly_http_message(resp.status, body),
-                        )
-                    except Exception as _log_e:
-                        logging.debug("[DEBUG] response logging failed: %s", _log_e)
-
-                    await sleep_with_interrupt(
-                        delay,
-                        cfg.node_cls,
-                        cfg.wait_label if cfg.monitor_progress else None,
-                        start_time if cfg.monitor_progress else None,
-                        cfg.estimated_total,
-                        display_callback=_display_time_progress if cfg.monitor_progress else None,
-                    )
-                    delay *= cfg.retry_backoff
-                    continue
-                msg = _friendly_http_message(resp.status, body)
-                try:
                     request_logger.log_request_response(
                         operation_id=operation_id,
                         request_method=method,
                         request_url=url,
@@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                         response_status_code=resp.status,
                         response_headers=dict(resp.headers),
                         response_content=body,
-                        error_message=msg,
+                        error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
                     )
-                except Exception as _log_e:
-                    logging.debug("[DEBUG] response logging failed: %s", _log_e)
+                    await sleep_with_interrupt(
+                        wait_time,
+                        cfg.node_cls,
+                        cfg.wait_label if cfg.monitor_progress else None,
+                        start_time if cfg.monitor_progress else None,
+                        cfg.estimated_total,
+                        display_callback=_display_time_progress if cfg.monitor_progress else None,
+                    )
+                    continue
+                msg = _friendly_http_message(resp.status, body)
+                request_logger.log_request_response(
+                    operation_id=operation_id,
+                    request_method=method,
+                    request_url=url,
+                    response_status_code=resp.status,
+                    response_headers=dict(resp.headers),
+                    response_content=body,
+                    error_message=msg,
+                )
                 raise Exception(msg)

                 if expect_binary:
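The reworked loop keeps two independent budgets: 429 responses (or provider-specific rate-limit responses detected via is_rate_limited) consume max_retries_on_rate_limit with the wait capped at 30s, while other transient statuses keep consuming max_retries. A minimal standalone sketch of that schedule, with names simplified (this is not the actual client, which also handles logging, progress UI, and interruption):

```python
# Standalone sketch of the dual retry budget above (simplified).
def retry_schedule(statuses, max_retries=3, max_rl_retries=16, delay=1.0, backoff=2.0):
    """Yield (status, wait) for each retried response, mirroring the new logic."""
    rl_attempts, rl_delay = 0, delay
    attempt = 0
    retryable = {408, 500, 502, 503, 504}  # 429 has its own budget now
    for status in statuses:
        attempt += 1
        if status == 429 and rl_attempts < max_rl_retries:
            rl_attempts += 1
            wait = min(rl_delay, 30.0)  # rate-limit waits are capped at 30s
            rl_delay *= backoff
        elif status in retryable and (attempt - rl_attempts) <= max_retries:
            wait = delay
            delay *= backoff
        else:
            raise RuntimeError(f"giving up on HTTP {status}")
        yield status, wait

# A burst of 429s no longer eats the transient-error budget:
for s, w in retry_schedule([429, 429, 429, 503]):
    print(s, w)  # 429 waits 1.0, 2.0, 4.0; the 503 still gets its first retry
```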
@@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                     bytes_payload = bytes(buff)
                     operation_succeeded = True
                     final_elapsed_seconds = int(time.monotonic() - start_time)
-                    try:
-                        request_logger.log_request_response(
-                            operation_id=operation_id,
-                            request_method=method,
-                            request_url=url,
-                            response_status_code=resp.status,
-                            response_headers=dict(resp.headers),
-                            response_content=bytes_payload,
-                        )
-                    except Exception as _log_e:
-                        logging.debug("[DEBUG] response logging failed: %s", _log_e)
+                    request_logger.log_request_response(
+                        operation_id=operation_id,
+                        request_method=method,
+                        request_url=url,
+                        response_status_code=resp.status,
+                        response_headers=dict(resp.headers),
+                        response_content=bytes_payload,
+                    )
                     return bytes_payload
                 else:
                     try:
@@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                     extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
                     operation_succeeded = True
                     final_elapsed_seconds = int(time.monotonic() - start_time)
-                    try:
-                        request_logger.log_request_response(
-                            operation_id=operation_id,
-                            request_method=method,
-                            request_url=url,
-                            response_status_code=resp.status,
-                            response_headers=dict(resp.headers),
-                            response_content=response_content_to_log,
-                        )
-                    except Exception as _log_e:
-                        logging.debug("[DEBUG] response logging failed: %s", _log_e)
+                    request_logger.log_request_response(
+                        operation_id=operation_id,
+                        request_method=method,
+                        request_url=url,
+                        response_status_code=resp.status,
+                        response_headers=dict(resp.headers),
+                        response_content=response_content_to_log,
+                    )
                     return payload

         except ProcessingInterrupted:
             logging.debug("Polling was interrupted by user")
             raise
         except (ClientError, OSError) as e:
-            if attempt <= cfg.max_retries:
+            if (attempt - rate_limit_attempts) <= cfg.max_retries:
                 logging.warning(
                     "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
                     method,
                     url,
                     delay,
-                    attempt,
+                    attempt - rate_limit_attempts,
                     cfg.max_retries,
                     str(e),
                 )
-                try:
-                    request_logger.log_request_response(
-                        operation_id=operation_id,
-                        request_method=method,
-                        request_url=url,
-                        request_headers=dict(payload_headers) if payload_headers else None,
-                        request_params=dict(params) if params else None,
-                        request_data=request_body_log,
-                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
-                    )
-                except Exception as _log_e:
-                    logging.debug("[DEBUG] request error logging failed: %s", _log_e)
+                request_logger.log_request_response(
+                    operation_id=operation_id,
+                    request_method=method,
+                    request_url=url,
+                    request_headers=dict(payload_headers) if payload_headers else None,
+                    request_params=dict(params) if params else None,
+                    request_data=request_body_log,
+                    error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+                )
                 await sleep_with_interrupt(
                     delay,
                     cfg.node_cls,
@@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                 continue
             diag = await _diagnose_connectivity()
             if not diag["internet_accessible"]:
-                try:
-                    request_logger.log_request_response(
-                        operation_id=operation_id,
-                        request_method=method,
-                        request_url=url,
-                        request_headers=dict(payload_headers) if payload_headers else None,
-                        request_params=dict(params) if params else None,
-                        request_data=request_body_log,
-                        error_message=f"LocalNetworkError: {str(e)}",
-                    )
-                except Exception as _log_e:
-                    logging.debug("[DEBUG] final error logging failed: %s", _log_e)
-                raise LocalNetworkError(
-                    "Unable to connect to the API server due to local network issues. "
-                    "Please check your internet connection and try again."
-                ) from e
-            try:
                 request_logger.log_request_response(
                     operation_id=operation_id,
                     request_method=method,
@@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
                     request_headers=dict(payload_headers) if payload_headers else None,
                     request_params=dict(params) if params else None,
                     request_data=request_body_log,
-                    error_message=f"ApiServerError: {str(e)}",
+                    error_message=f"LocalNetworkError: {str(e)}",
                 )
-            except Exception as _log_e:
-                logging.debug("[DEBUG] final error logging failed: %s", _log_e)
+                raise LocalNetworkError(
+                    "Unable to connect to the API server due to local network issues. "
+                    "Please check your internet connection and try again."
+                ) from e
+            request_logger.log_request_response(
+                operation_id=operation_id,
+                request_method=method,
+                request_url=url,
+                request_headers=dict(payload_headers) if payload_headers else None,
+                request_params=dict(params) if params else None,
+                request_data=request_body_log,
+                error_message=f"ApiServerError: {str(e)}",
+            )
             raise ApiServerError(
                 f"The API server at {default_base_url()} is currently unreachable. "
                 f"The service may be experiencing issues."
@@ -57,7 +57,7 @@ def tensor_to_bytesio(
     image: torch.Tensor,
     *,
     total_pixels: int | None = 2048 * 2048,
-    mime_type: str = "image/png",
+    mime_type: str | None = "image/png",
 ) -> BytesIO:
     """Converts a torch.Tensor image to a named BytesIO object.

@@ -167,27 +167,25 @@ async def download_url_to_bytesio(
                     with contextlib.suppress(Exception):
                         dest.seek(0)

-                    with contextlib.suppress(Exception):
-                        request_logger.log_request_response(
-                            operation_id=op_id,
-                            request_method="GET",
-                            request_url=url,
-                            response_status_code=resp.status,
-                            response_headers=dict(resp.headers),
-                            response_content=f"[streamed {written} bytes to dest]",
-                        )
+                    request_logger.log_request_response(
+                        operation_id=op_id,
+                        request_method="GET",
+                        request_url=url,
+                        response_status_code=resp.status,
+                        response_headers=dict(resp.headers),
+                        response_content=f"[streamed {written} bytes to dest]",
+                    )
                     return
         except asyncio.CancelledError:
             raise ProcessingInterrupted("Task cancelled") from None
         except (ClientError, OSError) as e:
             if attempt <= max_retries:
-                with contextlib.suppress(Exception):
-                    request_logger.log_request_response(
-                        operation_id=op_id,
-                        request_method="GET",
-                        request_url=url,
-                        error_message=f"{type(e).__name__}: {str(e)} (will retry)",
-                    )
+                request_logger.log_request_response(
+                    operation_id=op_id,
+                    request_method="GET",
+                    request_url=url,
+                    error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+                )
                 await sleep_with_interrupt(delay, cls, None, None, None)
                 delay *= retry_backoff
                 continue
@@ -8,7 +8,6 @@ from typing import Any

 import folder_paths

-# Get the logger instance
 logger = logging.getLogger(__name__)


@@ -91,38 +90,41 @@ def log_request_response(
     Filenames are sanitized and length-limited for cross-platform safety.
     If we still fail to write, we fall back to appending into api.log.
     """
-    log_dir = get_log_directory()
-    filepath = _build_log_filepath(log_dir, operation_id, request_url)
-
-    log_content: list[str] = []
-    log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
-    log_content.append(f"Operation ID: {operation_id}")
-    log_content.append("-" * 30 + " REQUEST " + "-" * 30)
-    log_content.append(f"Method: {request_method}")
-    log_content.append(f"URL: {request_url}")
-    if request_headers:
-        log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
-    if request_params:
-        log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
-    if request_data is not None:
-        log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
-
-    log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
-    if response_status_code is not None:
-        log_content.append(f"Status Code: {response_status_code}")
-    if response_headers:
-        log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
-    if response_content is not None:
-        log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
-    if error_message:
-        log_content.append(f"Error:\n{error_message}")
-
     try:
-        with open(filepath, "w", encoding="utf-8") as f:
-            f.write("\n".join(log_content))
-        logger.debug("API log saved to: %s", filepath)
-    except Exception as e:
-        logger.error("Error writing API log to %s: %s", filepath, str(e))
+        log_dir = get_log_directory()
+        filepath = _build_log_filepath(log_dir, operation_id, request_url)
+
+        log_content: list[str] = []
+        log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
+        log_content.append(f"Operation ID: {operation_id}")
+        log_content.append("-" * 30 + " REQUEST " + "-" * 30)
+        log_content.append(f"Method: {request_method}")
+        log_content.append(f"URL: {request_url}")
+        if request_headers:
+            log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
+        if request_params:
+            log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
+        if request_data is not None:
+            log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
+
+        log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
+        if response_status_code is not None:
+            log_content.append(f"Status Code: {response_status_code}")
+        if response_headers:
+            log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
+        if response_content is not None:
+            log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
+        if error_message:
+            log_content.append(f"Error:\n{error_message}")
+
+        try:
+            with open(filepath, "w", encoding="utf-8") as f:
+                f.write("\n".join(log_content))
+            logger.debug("API log saved to: %s", filepath)
+        except Exception as e:
+            logger.error("Error writing API log to %s: %s", filepath, str(e))
+    except Exception as _log_e:
+        logging.debug("[DEBUG] log_request_response failed: %s", _log_e)


 if __name__ == '__main__':
@@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
     return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)


+_3D_MIME_TYPES = {
+    "glb": "model/gltf-binary",
+    "obj": "model/obj",
+    "fbx": "application/octet-stream",
+}
+
+
+async def upload_3d_model_to_comfyapi(
+    cls: type[IO.ComfyNode],
+    model_3d: Types.File3D,
+    file_format: str,
+) -> str:
+    """Uploads a 3D model file to ComfyUI API and returns its download URL."""
+    return await upload_file_to_comfyapi(
+        cls,
+        model_3d.get_data(),
+        f"{uuid.uuid4()}.{file_format}",
+        _3D_MIME_TYPES.get(file_format, "application/octet-stream"),
+    )
+
+
 async def upload_file_to_comfyapi(
     cls: type[IO.ComfyNode],
     file_bytes_io: BytesIO,
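The helper above gives each upload a fresh UUID filename and falls back to a generic MIME type for formats outside the map. A hedged usage sketch; the wrapper function is hypothetical, while upload_3d_model_to_comfyapi and _3D_MIME_TYPES are the helpers from the hunk:

```python
# Hypothetical call site (a trimmed node execute); `cls` and `model_3d`
# would come from the node, the helper from comfy_api_nodes.util.
async def upload_for_node(cls, model_3d) -> str:
    file_format = model_3d.format.lower()
    # Formats outside the map still upload, just with the generic MIME type:
    # _3D_MIME_TYPES.get("usdz", "application/octet-stream") -> "application/octet-stream"
    return await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
```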
@@ -255,17 +276,14 @@ async def upload_file(
     monitor_task = asyncio.create_task(_monitor())
     sess: aiohttp.ClientSession | None = None
     try:
-        try:
-            request_logger.log_request_response(
-                operation_id=operation_id,
-                request_method="PUT",
-                request_url=upload_url,
-                request_headers=headers or None,
-                request_params=None,
-                request_data=f"[File data {len(data)} bytes]",
-            )
-        except Exception as e:
-            logging.debug("[DEBUG] upload request logging failed: %s", e)
+        request_logger.log_request_response(
+            operation_id=operation_id,
+            request_method="PUT",
+            request_url=upload_url,
+            request_headers=headers or None,
+            request_params=None,
+            request_data=f"[File data {len(data)} bytes]",
+        )

         sess = aiohttp.ClientSession(timeout=timeout)
         req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -311,31 +329,27 @@ async def upload_file(
                 delay *= retry_backoff
                 continue
             raise Exception(f"Failed to upload (HTTP {resp.status}).")
-        try:
-            request_logger.log_request_response(
-                operation_id=operation_id,
-                request_method="PUT",
-                request_url=upload_url,
-                response_status_code=resp.status,
-                response_headers=dict(resp.headers),
-                response_content="File uploaded successfully.",
-            )
-        except Exception as e:
-            logging.debug("[DEBUG] upload response logging failed: %s", e)
+        request_logger.log_request_response(
+            operation_id=operation_id,
+            request_method="PUT",
+            request_url=upload_url,
+            response_status_code=resp.status,
+            response_headers=dict(resp.headers),
+            response_content="File uploaded successfully.",
+        )
         return
     except asyncio.CancelledError:
         raise ProcessingInterrupted("Task cancelled") from None
     except (aiohttp.ClientError, OSError) as e:
         if attempt <= max_retries:
-            with contextlib.suppress(Exception):
-                request_logger.log_request_response(
-                    operation_id=operation_id,
-                    request_method="PUT",
-                    request_url=upload_url,
-                    request_headers=headers or None,
-                    request_data=f"[File data {len(data)} bytes]",
-                    error_message=f"{type(e).__name__}: {str(e)} (will retry)",
-                )
+            request_logger.log_request_response(
+                operation_id=operation_id,
+                request_method="PUT",
+                request_url=upload_url,
+                request_headers=headers or None,
+                request_data=f"[File data {len(data)} bytes]",
+                error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+            )
             await sleep_with_interrupt(
                 delay,
                 cls,
@@ -20,10 +20,60 @@ class JobStatus:


 # Media types that can be previewed in the frontend
-PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
+PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})

 # 3D file extensions for preview fallback (no dedicated media_type exists)
-THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
+THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
+
+
+def has_3d_extension(filename: str) -> bool:
+    lower = filename.lower()
+    return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
+
+
+def normalize_output_item(item):
+    """Normalize a single output list item for the jobs API.
+
+    Returns the normalized item, or None to exclude it.
+    String items with 3D extensions become {filename, type, subfolder} dicts.
+    """
+    if item is None:
+        return None
+    if isinstance(item, str):
+        if has_3d_extension(item):
+            return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+        return None
+    if isinstance(item, dict):
+        return item
+    return None
+
+
+def normalize_outputs(outputs: dict) -> dict:
+    """Normalize raw node outputs for the jobs API.
+
+    Transforms string 3D filenames into file output dicts and removes
+    None items. All other items (non-3D strings, dicts, etc.) are
+    preserved as-is.
+    """
+    normalized = {}
+    for node_id, node_outputs in outputs.items():
+        if not isinstance(node_outputs, dict):
+            normalized[node_id] = node_outputs
+            continue
+        normalized_node = {}
+        for media_type, items in node_outputs.items():
+            if media_type == 'animated' or not isinstance(items, list):
+                normalized_node[media_type] = items
+                continue
+            normalized_items = []
+            for item in items:
+                if item is None:
+                    continue
+                norm = normalize_output_item(item)
+                normalized_items.append(norm if norm is not None else item)
+            normalized_node[media_type] = normalized_items
+        normalized[node_id] = normalized_node
+    return normalized


 def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
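Concretely, the new normalization turns bare 3D filenames (such as those returned by the Tencent nodes above) into previewable file dicts while leaving everything else alone. A small sketch; the outputs payload is invented for illustration, only its shape mirrors ComfyUI history outputs:

```python
# Invented example payload, run against the normalize_outputs above.
outputs = {
    "12": {
        "result": ["abc123.glb", None, {"filename": "mesh.obj", "type": "output", "subfolder": ""}],
        "animated": (False,),  # non-list / 'animated' entries pass through untouched
    }
}
print(normalize_outputs(outputs))
# {'12': {'result': [
#     {'filename': 'abc123.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
#     {'filename': 'mesh.obj', 'type': 'output', 'subfolder': ''}],
#   'animated': (False,)}}
```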
@@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
     Maintains backwards compatibility with existing logic.

     Priority:
-    1. media_type is 'images', 'video', or 'audio'
+    1. media_type is 'images', 'video', 'audio', or '3d'
     2. format field starts with 'video/' or 'audio/'
-    3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
+    3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
     """
     if media_type in PREVIEWABLE_MEDIA_TYPES:
         return True
@@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
         })

     if include_outputs:
-        job['outputs'] = outputs
+        job['outputs'] = normalize_outputs(outputs)
         job['execution_status'] = status_info
         job['workflow'] = {
             'prompt': prompt,
@@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
                 continue

             for item in items:
-                count += 1
-
-                if not isinstance(item, dict):
+                normalized = normalize_output_item(item)
+                if normalized is None:
                     continue

-                if preview_output is None and is_previewable(media_type, item):
+                count += 1
+
+                if preview_output is not None:
+                    continue
+
+                if isinstance(normalized, dict) and is_previewable(media_type, normalized):
                     enriched = {
-                        **item,
+                        **normalized,
                         'nodeId': node_id,
-                        'mediaType': media_type
                     }
-                    if item.get('type') == 'output':
+                    if 'mediaType' not in normalized:
+                        enriched['mediaType'] = media_type
+                    if normalized.get('type') == 'output':
                         preview_output = enriched
                     elif fallback_preview is None:
                         fallback_preview = enriched
@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
                 io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
                 io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
                 io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
+                io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
             ],
             outputs=[io.Conditioning.Output()],
         )
 
     @classmethod
-    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
+    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
-        tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
+        tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
         conditioning = clip.encode_from_tokens_scheduled(tokens)
         return io.NodeOutput(conditioning)
 
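The new `min_p` input feeds the audio-code sampler. ComfyUI's tokenizer applies it internally; as a generic illustration of what min-p filtering conventionally does (not this project's exact code), tokens whose probability falls below `min_p` times the top token's probability are masked out before sampling:

```python
import torch

def min_p_filter(logits: torch.Tensor, min_p: float) -> torch.Tensor:
    # Generic min-p sketch: keep tokens with prob >= min_p * max_prob.
    if min_p <= 0.0:
        return logits  # 0 disables the filter, matching the input's default
    probs = torch.softmax(logits, dim=-1)
    top_prob = probs.max(dim=-1, keepdim=True).values
    return logits.masked_fill(probs < min_p * top_prob, float('-inf'))

logits = torch.randn(1, 32)
probs = torch.softmax(min_p_filter(logits, min_p=0.05), dim=-1)
next_token = torch.multinomial(probs, 1)
```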
@@ -7,6 +7,7 @@ import logging
 from enum import Enum
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, io
+from tqdm.auto import trange
 
 CLAMP_QUANTILE = 0.99
 
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
               "full_diff": LORAType.FULL_DIFF}
 
 def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
-    comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
+    comfy.model_management.load_models_gpu([model_diff])
     sd = model_diff.model_state_dict(filter_prefix=prefix_model)
 
-    for k in sd:
-        if k.endswith(".weight"):
+    sd_keys = list(sd.keys())
+    for index in trange(len(sd_keys), unit="weight"):
+        k = sd_keys[index]
+        op_keys = sd_keys[index].rsplit('.', 1)
+        if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
+            continue
+        op = comfy.utils.get_attr(model_diff.model, op_keys[0])
+        if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
+            weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
+        else:
             weight_diff = sd[k]
+
+        if op_keys[1] == "weight":
             if lora_type == LORAType.STANDARD:
                 if weight_diff.ndim < 2:
                     if bias_diff:

@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
             elif lora_type == LORAType.FULL_DIFF:
                 output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
 
-        elif bias_diff and k.endswith(".bias"):
-            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
+        elif bias_diff and op_keys[1] == "bias":
+            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
     return output_sd
 
 class LoraSave(io.ComfyNode):
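For context on what the STANDARD path does with each `weight_diff`: LoRA extraction factors the difference between two models' weights into a low-rank pair via SVD. A minimal sketch of that standard technique, including the kind of outlier clamping suggested by `CLAMP_QUANTILE` (the helper name `extract_lora` here is illustrative, not the node's exact function):

```python
import torch

CLAMP_QUANTILE = 0.99

def extract_lora(weight_diff: torch.Tensor, rank: int):
    # Illustrative sketch of standard SVD-based LoRA extraction.
    U, S, Vh = torch.linalg.svd(weight_diff.float(), full_matrices=False)
    U = U[:, :rank] * S[:rank]   # fold singular values into the "up" matrix
    Vh = Vh[:rank, :]
    # Clamp outliers so a few extreme values don't dominate the adapter.
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi = torch.quantile(dist.abs(), CLAMP_QUANTILE)
    return U.clamp(-hi, hi), Vh.clamp(-hi, hi)  # lora_up, lora_down

diff = torch.randn(64, 128)
up, down = extract_lora(diff, rank=8)
print((diff - up @ down).norm() / diff.norm())  # relative reconstruction error
```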
@@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
         batched = batch_masks(values)
         return io.NodeOutput(batched)
 
+
 class PostProcessingExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
103 comfy_extras/nodes_replacements.py Normal file
@@ -0,0 +1,103 @@
+from comfy_api.latest import ComfyExtension, io, ComfyAPI
+
+api = ComfyAPI()
+
+
+async def register_replacements():
+    """Register all built-in node replacements."""
+    await register_replacements_longeredge()
+    await register_replacements_batchimages()
+    await register_replacements_upscaleimage()
+    await register_replacements_controlnet()
+    await register_replacements_load3d()
+    await register_replacements_preview3d()
+    await register_replacements_svdimg2vid()
+    await register_replacements_conditioningavg()
+
+
+async def register_replacements_longeredge():
+    # No dynamic inputs here
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ImageScaleToMaxDimension",
+        old_node_id="ResizeImagesByLongerEdge",
+        old_widget_ids=["longer_edge"],
+        input_mapping=[
+            {"new_id": "image", "old_id": "images"},
+            {"new_id": "largest_size", "old_id": "longer_edge"},
+            {"new_id": "upscale_method", "set_value": "lanczos"},
+        ],
+        # just to test the frontend output_mapping code, does nothing really here
+        output_mapping=[{"new_idx": 0, "old_idx": 0}],
+    ))
+
+
+async def register_replacements_batchimages():
+    # BatchImages node uses Autogrow
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="BatchImagesNode",
+        old_node_id="ImageBatch",
+        input_mapping=[
+            {"new_id": "images.image0", "old_id": "image1"},
+            {"new_id": "images.image1", "old_id": "image2"},
+        ],
+    ))
+
+
+async def register_replacements_upscaleimage():
+    # ResizeImageMaskNode uses DynamicCombo
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ResizeImageMaskNode",
+        old_node_id="ImageScaleBy",
+        old_widget_ids=["upscale_method", "scale_by"],
+        input_mapping=[
+            {"new_id": "input", "old_id": "image"},
+            {"new_id": "resize_type", "set_value": "scale by multiplier"},
+            {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
+            {"new_id": "scale_method", "old_id": "upscale_method"},
+        ],
+    ))
+
+
+async def register_replacements_controlnet():
+    # T2IAdapterLoader → ControlNetLoader
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ControlNetLoader",
+        old_node_id="T2IAdapterLoader",
+        input_mapping=[
+            {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
+        ],
+    ))
+
+
+async def register_replacements_load3d():
+    # Load3DAnimation merged into Load3D
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="Load3D",
+        old_node_id="Load3DAnimation",
+    ))
+
+
+async def register_replacements_preview3d():
+    # Preview3DAnimation merged into Preview3D
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="Preview3D",
+        old_node_id="Preview3DAnimation",
+    ))
+
+
+async def register_replacements_svdimg2vid():
+    # Typo fix: SDV → SVD
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="SVD_img2vid_Conditioning",
+        old_node_id="SDV_img2vid_Conditioning",
+    ))
+
+
+async def register_replacements_conditioningavg():
+    # Typo fix: trailing space in node name
+    await api.node_replacement.register(io.NodeReplace(
+        new_node_id="ConditioningAverage",
+        old_node_id="ConditioningAverage ",
+    ))
+
+
+class NodeReplacementsExtension(ComfyExtension):
+    async def on_load(self) -> None:
+        await register_replacements()
+
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return []
+
+
+async def comfy_entrypoint() -> NodeReplacementsExtension:
+    return NodeReplacementsExtension()
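To make the mapping semantics concrete, here is a rough sketch of how one of these entries could be applied to a legacy API-format prompt node before validation. The helper `apply_replace` is hypothetical (the real logic lives in app/node_replace_manager.py); "set_value" entries write constants, "old_id" entries move existing inputs:

```python
# Hypothetical illustration of applying an input_mapping to one API-prompt node.
def apply_replace(node: dict, replace: dict) -> dict:
    old_inputs = node["inputs"]
    new_inputs = {}
    for m in replace["input_mapping"]:
        if "set_value" in m:
            new_inputs[m["new_id"]] = m["set_value"]
        elif m["old_id"] in old_inputs:
            new_inputs[m["new_id"]] = old_inputs[m["old_id"]]
    return {"class_type": replace["new_node_id"], "inputs": new_inputs}

legacy = {"class_type": "ImageScaleBy",
          "inputs": {"image": ["5", 0], "upscale_method": "bicubic", "scale_by": 2.0}}
replace = {"new_node_id": "ResizeImageMaskNode",
           "input_mapping": [
               {"new_id": "input", "old_id": "image"},
               {"new_id": "resize_type", "set_value": "scale by multiplier"},
               {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
               {"new_id": "scale_method", "old_id": "upscale_method"},
           ]}
print(apply_replace(legacy, replace))
```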
@@ -4,6 +4,7 @@ import os
 import numpy as np
 import safetensors
 import torch
+import torch.nn as nn
 import torch.utils.checkpoint
 from tqdm.auto import trange
 from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
     """
     CFGGuider with modifications for training specific logic
     """
 
+    def __init__(self, *args, offloading=False, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.offloading = offloading
+
     def outer_sample(
         self,
         noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
                 noise.shape,
                 self.conds,
                 self.model_options,
-                force_full_load=True,  # mirror behavior in TrainLoraNode.execute() to keep model loaded
+                force_full_load=not self.offloading,
+                force_offload=self.offloading,
             )
         )
+        torch.cuda.empty_cache()
         device = self.model_patcher.load_device
 
         if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
     return result
 
 
-def patch(m):
+def find_modules_at_depth(
+    model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+    """
+    Find modules at a specific depth level for gradient checkpointing.
+
+    Args:
+        model: The model to search
+        depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+        result: Accumulator for results
+        current_depth: Current recursion depth
+        name: Current module name for logging
+
+    Returns:
+        List of modules at the target depth
+    """
+    if result is None:
+        result = []
+    name = name or "root"
+
+    # Skip container modules (they don't have meaningful forward)
+    is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+    has_forward = hasattr(model, "forward") and not is_container
+
+    if has_forward:
+        current_depth += 1
+        if current_depth == depth:
+            result.append(model)
+            logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+            return result
+
+    # Recurse into children
+    for next_name, child in model.named_children():
+        find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+
+    return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+    """
+    Gradient checkpointing that works with weight offloading.
+
+    Forward: no_grad -> compute -> weights can be freed
+    Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+    For single input, single output modules (Linear, Conv*).
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor, forward_fn):
+        ctx.save_for_backward(x)
+        ctx.forward_fn = forward_fn
+        with torch.no_grad():
+            return forward_fn(x)
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor):
+        x, = ctx.saved_tensors
+        forward_fn = ctx.forward_fn
+
+        # Clear context early
+        ctx.forward_fn = None
+
+        with torch.enable_grad():
+            x_detached = x.detach().requires_grad_(True)
+            y = forward_fn(x_detached)
+            y.backward(grad_out)
+            grad_x = x_detached.grad
+
+        # Explicit cleanup
+        del y, x_detached, forward_fn
+
+        return grad_x, None
+
+
+def patch(m, offloading=False):
     if not hasattr(m, "forward"):
         return
     org_forward = m.forward
 
-    def fwd(args, kwargs):
-        return org_forward(*args, **kwargs)
+    # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
+    if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        def checkpointing_fwd(x):
+            return OffloadCheckpointFunction.apply(x, org_forward)
+    # Branch 2: Others -> standard checkpoint
+    else:
+        def fwd(args, kwargs):
+            return org_forward(*args, **kwargs)
 
         def checkpointing_fwd(*args, **kwargs):
             return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
 
     m.org_forward = org_forward
     m.forward = checkpointing_fwd
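A quick way to sanity-check the custom checkpoint outside ComfyUI is to compare its gradients against a plain forward/backward. A self-contained sketch re-stating the same idea (forward under no_grad, recompute under enable_grad in backward) on an nn.Linear:

```python
import torch
import torch.nn as nn

# Minimal re-check of the OffloadCheckpointFunction idea from the hunk above.
class OffloadCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, forward_fn):
        ctx.save_for_backward(x)
        ctx.forward_fn = forward_fn
        with torch.no_grad():          # no activation graph kept; weights may be freed
            return forward_fn(x)

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        with torch.enable_grad():      # recompute, then backprop through the recompute
            x_d = x.detach().requires_grad_(True)
            ctx.forward_fn(x_d).backward(grad_out)
        return x_d.grad, None

lin = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

y1 = OffloadCheckpoint.apply(x, lin)   # checkpointed path
y1.sum().backward()
g1, x.grad = x.grad.clone(), None

y2 = lin(x)                            # reference path
y2.sum().backward()
print(torch.allclose(g1, x.grad))      # True: gradients match
```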
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
                     default=True,
                     tooltip="Use gradient checkpointing for training.",
                 ),
+                io.Int.Input(
+                    "checkpoint_depth",
+                    default=1,
+                    min=1,
+                    max=5,
+                    tooltip="Depth level for gradient checkpointing.",
+                ),
+                io.Boolean.Input(
+                    "offloading",
+                    default=False,
+                    tooltip="Offload the Model to RAM. Requires Bypass Mode.",
+                ),
                 io.Combo.Input(
                     "existing_lora",
                     options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype,
         algorithm,
         gradient_checkpointing,
+        checkpoint_depth,
+        offloading,
         existing_lora,
         bucket_mode,
         bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = lora_dtype[0]
         algorithm = algorithm[0]
         gradient_checkpointing = gradient_checkpointing[0]
+        offloading = offloading[0]
+        checkpoint_depth = checkpoint_depth[0]
         existing_lora = existing_lora[0]
         bucket_mode = bucket_mode[0]
         bypass_mode = bypass_mode[0]
@@ -1019,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
         lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
         mp.set_model_compute_dtype(dtype)
 
+        if mp.is_dynamic():
+            if not bypass_mode:
+                logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
+            bypass_mode = True
+            offloading = True
+        elif offloading:
+            if not bypass_mode:
+                logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
+
         # Prepare latents and compute counts
         latents, num_images, multi_res = _prepare_latents_and_count(
             latents, dtype, bucket_mode
@@ -1054,16 +1168,18 @@ class TrainLoraNode(io.ComfyNode):
 
         # Setup gradient checkpointing
         if gradient_checkpointing:
-            for m in find_all_highest_child_module_with_forward(
-                mp.model.diffusion_model
-            ):
-                patch(m)
+            modules_to_patch = find_modules_at_depth(
+                mp.model.diffusion_model, depth=checkpoint_depth
+            )
+            logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+            for m in modules_to_patch:
+                patch(m, offloading=offloading)
 
         torch.cuda.empty_cache()
         # With force_full_load=False we should be able to have offloading
         # But for offloading in training we need custom AutoGrad hooks for fwd/bwd
         comfy.model_management.load_models_gpu(
-            [mp], memory_required=1e20, force_full_load=True
+            [mp], memory_required=1e20, force_full_load=not offloading
         )
         torch.cuda.empty_cache()
 
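The effect of `checkpoint_depth` is easiest to see on a toy module tree: depth 1 grabs whole blocks, depth 2 their sub-layers, so deeper settings checkpoint at a finer grain (less recompute per segment, more segments). A self-contained illustration with a toy model, not a diffusion net; `count_at_depth` mirrors the traversal of `find_modules_at_depth` above, minus logging:

```python
import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(16, 16)
        self.mlp = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))

    def forward(self, x):
        return x + self.mlp(self.attn(x))

model = nn.ModuleList([Block() for _ in range(3)])

def count_at_depth(module, depth, current=0):
    # Same traversal idea as find_modules_at_depth: containers are skipped,
    # every non-container module increments the depth counter.
    is_container = isinstance(module, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
    if not is_container:
        current += 1
        if current == depth:
            return 1
    return sum(count_at_depth(c, depth, current) for c in module.children())

print(count_at_depth(model, 1))  # 3: the Block instances (ModuleList is a container)
print(count_at_depth(model, 2))  # 12: attn plus the mlp's own layers, per block
```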
@@ -1100,7 +1216,7 @@ class TrainLoraNode(io.ComfyNode):
         )
 
         # Setup guider
-        guider = TrainGuider(mp)
+        guider = TrainGuider(mp, offloading=offloading)
         guider.set_conds(positive)
 
         # Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1229,7 @@ class TrainLoraNode(io.ComfyNode):
 
         # Run training loop
         try:
+            comfy.model_management.in_training = True
             _run_training_loop(
                 guider,
                 train_sampler,
@@ -1123,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
                 multi_res,
             )
         finally:
+            comfy.model_management.in_training = False
             # Eject bypass hooks if they were injected
             if bypass_injections is not None:
                 for injection in bypass_injections:
@@ -1132,19 +1250,20 @@ class TrainLoraNode(io.ComfyNode):
                     unpatch(m)
         del train_sampler, optimizer
 
-        # Finalize adapters
+        for param in lora_sd:
+            lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
+
         for adapter in all_weight_adapters:
             adapter.requires_grad_(False)
-        for param in lora_sd:
-            lora_sd[param] = lora_sd[param].to(lora_dtype)
+            del adapter
+        del all_weight_adapters
 
         # mp in train node is highly specialized for training
         # use it in inference will result in bad behavior so we don't return it
         return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
 
 
-class LoraModelLoader(io.ComfyNode):#
+class LoraModelLoader(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
@@ -1166,6 +1285,11 @@ class LoraModelLoader(io.ComfyNode):#
                     max=100.0,
                     tooltip="How strongly to modify the diffusion model. This value can be negative.",
                 ),
+                io.Boolean.Input(
+                    "bypass",
+                    default=False,
+                    tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+                ),
             ],
             outputs=[
                 io.Model.Output(
@@ -1175,13 +1299,18 @@ class LoraModelLoader(io.ComfyNode):#
         )
 
     @classmethod
-    def execute(cls, model, lora, strength_model):
+    def execute(cls, model, lora, strength_model, bypass=False):
         if strength_model == 0:
             return io.NodeOutput(model)
 
-        model_lora, _ = comfy.sd.load_lora_for_models(
-            model, None, lora, strength_model, 0
-        )
+        if bypass:
+            model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
+        else:
+            model_lora, _ = comfy.sd.load_lora_for_models(
+                model, None, lora, strength_model, 0
+            )
         return io.NodeOutput(model_lora)
 
 
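The bypass tooltip describes the standard low-rank bypass trick: instead of merging the scaled B·A product into the base weight, the adapter runs alongside the frozen layer at inference time, so base weights can stay offloaded or quantized. A generic sketch of the idea (illustrative, not comfy.sd's implementation):

```python
import torch
import torch.nn as nn

class BypassLoRALinear(nn.Module):
    # Generic bypass-LoRA sketch: base weights untouched, adapter added at runtime.
    def __init__(self, base: nn.Linear, rank: int = 8, strength: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)   # frozen; can stay offloaded/quantized
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.up.weight)           # start as a no-op
        self.strength = strength

    def forward(self, x):
        return self.base(x) + self.strength * self.up(self.down(x))

layer = BypassLoRALinear(nn.Linear(32, 32))
out = layer(torch.randn(2, 32))  # identical to the base layer until `up` trains away from zero
```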
@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
 
         return True
 
+
+class VideoSlice(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Video Slice",
+            display_name="Video Slice",
+            search_aliases=[
+                "trim video duration",
+                "skip first frames",
+                "frame load cap",
+                "start time",
+            ],
+            category="image/video",
+            inputs=[
+                io.Video.Input("video"),
+                io.Float.Input(
+                    "start_time",
+                    default=0.0,
+                    max=1e5,
+                    min=-1e5,
+                    step=0.001,
+                    tooltip="Start time in seconds",
+                ),
+                io.Float.Input(
+                    "duration",
+                    default=0.0,
+                    min=0.0,
+                    step=0.001,
+                    tooltip="Duration in seconds, or 0 for unlimited duration",
+                ),
+                io.Boolean.Input(
+                    "strict_duration",
+                    default=False,
+                    tooltip="If True, when the specified duration is not possible, an error will be raised.",
+                ),
+            ],
+            outputs=[
+                io.Video.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
+        trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
+        if trimmed is not None:
+            return io.NodeOutput(trimmed)
+        raise ValueError(
+            f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
+        )
+
+
 class VideoExtension(ComfyExtension):
     @override
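As a rough mental model for the slice semantics (the actual trimming happens inside the Video object's as_trimmed): start time and duration map onto frame indices via the source frame rate, with duration 0 meaning "to the end". A hypothetical helper, not the node's code:

```python
def slice_frame_range(fps: float, total_frames: int, start_time: float, duration: float):
    # Hypothetical illustration of start_time/duration -> frame indices.
    start = max(0, round(start_time * fps))
    if duration <= 0.0:                      # 0 = unlimited, run to the end
        end = total_frames
    else:
        end = min(total_frames, start + round(duration * fps))
    if start >= end:
        return None                          # nothing to slice -> caller raises
    return start, end

print(slice_frame_range(fps=24.0, total_frames=240, start_time=1.5, duration=2.0))  # (36, 84)
```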
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
             CreateVideo,
             GetVideoComponents,
             LoadVideo,
+            VideoSlice,
         ]
 
 async def comfy_entrypoint() -> VideoExtension:
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.12.3"
+__version__ = "0.13.0"
@@ -13,8 +13,11 @@ from contextlib import nullcontext
 
 import torch
 
+from comfy.cli_args import args
 import comfy.memory_management
 import comfy.model_management
+import comfy_aimdo.model_vbar
+
 from latent_preview import set_preview_method
 import nodes
 from comfy_execution.caching import (
@@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
         finally:
             if allocator is not None:
+                if args.verbose == "DEBUG":
+                    comfy_aimdo.model_vbar.vbars_analyze()
                 comfy.model_management.reset_cast_buffers()
-                torch.cuda.synchronize()
+                comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
 
         if has_pending_tasks:
             pending_async_nodes[unique_id] = output_data
@@ -618,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
         logging.error("Got an OOM, unloading all loaded models.")
         comfy.model_management.unload_all_models()
+    elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
+        tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected."
 
     error_details = {
         "node_id": real_node_id,
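The string match works because PyTorch's matmul error reliably carries that phrase, which is what a mismatched text-encoder output produces when it hits a sampler's projection layer. A minimal reproduction of the error text being keyed on:

```python
import torch

try:
    # Shapes chosen only to force the error (e.g. the wrong CLIP for a model).
    torch.mm(torch.randn(1, 768), torch.randn(1280, 320))
except RuntimeError as ex:
    # "mat1 and mat2 shapes cannot be multiplied (1x768 and 1280x320)"
    print("mat1 and mat2 shapes" in str(ex))  # True
```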
2 nodes.py
@@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
     if not isinstance(extension, ComfyExtension):
         logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
         return False
+    await extension.on_load()
     node_list = await extension.get_node_list()
     if not isinstance(node_list, list):
         logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
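With on_load now awaited before get_node_list, an extension can do async setup (like the replacement registrations above) before its nodes are collected. A minimal hypothetical custom-node extension using the hook, mirroring the NodeReplacementsExtension pattern from this diff:

```python
from comfy_api.latest import ComfyExtension, io

class MyExtension(ComfyExtension):
    async def on_load(self) -> None:
        # Runs once at load time, before get_node_list() is awaited.
        # Good place for async registrations or warm-up work.
        self.ready = True

    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return []  # node classes would go here

async def comfy_entrypoint() -> MyExtension:
    return MyExtension()
```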
@@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes():
         "nodes_lora_debug.py",
         "nodes_color.py",
         "nodes_toolkit.py",
+        "nodes_replacements.py",
     ]
 
     import_failed = []
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.12.3"
+version = "0.13.0"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.10"
@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.38.13
-comfyui-workflow-templates==0.8.31
+comfyui-frontend-package==1.38.14
+comfyui-workflow-templates==0.8.38
 comfyui-embedded-docs==0.4.1
 torch
 torchsde

@@ -22,7 +22,7 @@ alembic
 SQLAlchemy
 av>=14.2.0
 comfy-kitchen>=0.2.7
-comfy-aimdo>=0.1.7
+comfy-aimdo>=0.1.8
 requests
 
 #non essential dependencies:
@@ -40,6 +40,7 @@ from app.user_manager import UserManager
 from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
 from app.subgraph_manager import SubgraphManager
+from app.node_replace_manager import NodeReplaceManager
 from typing import Optional, Union
 from api_server.routes.internal.internal_routes import InternalRoutes
 from protocol import BinaryEventTypes

@@ -204,6 +205,7 @@ class PromptServer():
         self.model_file_manager = ModelFileManager()
         self.custom_node_manager = CustomNodeManager()
         self.subgraph_manager = SubgraphManager()
+        self.node_replace_manager = NodeReplaceManager()
         self.internal_routes = InternalRoutes(self)
         self.supports = ["custom_nodes_from_web"]
         self.prompt_queue = execution.PromptQueue(self)

@@ -887,6 +889,8 @@ class PromptServer():
         if "partial_execution_targets" in json_data:
             partial_execution_targets = json_data["partial_execution_targets"]
 
+        self.node_replace_manager.apply_replacements(prompt)
+
         valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
         extra_data = {}
         if "extra_data" in json_data:

@@ -995,6 +999,7 @@ class PromptServer():
         self.model_file_manager.add_routes(self.routes)
         self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
         self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
+        self.node_replace_manager.add_routes(self.routes)
         self.app.add_subapp('/internal', self.internal_routes.get_app())
 
         # Prefix every route with /api for easier matching for delegation.
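Because apply_replacements runs before validate_prompt, old API-format workflows keep executing after a node is renamed. An illustrative before/after for the SDV→SVD typo fix registered earlier (hypothetical prompt fragment, not captured server traffic):

```python
# A submitted prompt still referencing the retired node id ...
before = {"3": {"class_type": "SDV_img2vid_Conditioning", "inputs": {"width": 1024}}}

# ... is rewritten to the current id in place, before validation runs.
after = {"3": {"class_type": "SVD_img2vid_Conditioning", "inputs": {"width": 1024}}}
```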
@@ -5,8 +5,11 @@ from comfy_execution.jobs import (
     is_previewable,
     normalize_queue_item,
     normalize_history_item,
+    normalize_output_item,
+    normalize_outputs,
     get_outputs_summary,
     apply_sorting,
+    has_3d_extension,
 )
 
 
@@ -35,8 +38,8 @@ class TestIsPreviewable:
     """Unit tests for is_previewable()"""
 
     def test_previewable_media_types(self):
-        """Images, video, audio media types should be previewable."""
-        for media_type in ['images', 'video', 'audio']:
+        """Images, video, audio, 3d media types should be previewable."""
+        for media_type in ['images', 'video', 'audio', '3d']:
             assert is_previewable(media_type, {}) is True
 
     def test_non_previewable_media_types(self):
@@ -46,7 +49,7 @@ class TestIsPreviewable:
 
     def test_3d_extensions_previewable(self):
         """3D file extensions should be previewable regardless of media_type."""
-        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
             item = {'filename': f'model{ext}'}
             assert is_previewable('files', item) is True
 
@@ -160,7 +163,7 @@ class TestGetOutputsSummary:
 
     def test_3d_files_previewable(self):
         """3D file extensions should be previewable."""
-        for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
             outputs = {
                 'node1': {
                     'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -192,6 +195,64 @@ class TestGetOutputsSummary:
         assert preview['mediaType'] == 'images'
         assert preview['subfolder'] == 'outputs'
 
+    def test_string_3d_filename_creates_preview(self):
+        """String items with 3D extensions should synthesize a preview (Preview3D node output).
+        Only the .glb counts — nulls and non-file strings are excluded."""
+        outputs = {
+            'node1': {
+                'result': ['preview3d_abc123.glb', None, None]
+            }
+        }
+        count, preview = get_outputs_summary(outputs)
+        assert count == 1
+        assert preview is not None
+        assert preview['filename'] == 'preview3d_abc123.glb'
+        assert preview['mediaType'] == '3d'
+        assert preview['nodeId'] == 'node1'
+        assert preview['type'] == 'output'
+
+    def test_string_non_3d_filename_no_preview(self):
+        """String items without 3D extensions should not create a preview."""
+        outputs = {
+            'node1': {
+                'result': ['data.json', None]
+            }
+        }
+        count, preview = get_outputs_summary(outputs)
+        assert count == 0
+        assert preview is None
+
+    def test_string_3d_filename_used_as_fallback(self):
+        """String 3D preview should be used when no dict items are previewable."""
+        outputs = {
+            'node1': {
+                'latents': [{'filename': 'latent.safetensors'}],
+            },
+            'node2': {
+                'result': ['model.glb', None]
+            }
+        }
+        count, preview = get_outputs_summary(outputs)
+        assert preview is not None
+        assert preview['filename'] == 'model.glb'
+        assert preview['mediaType'] == '3d'
+
+
+class TestHas3DExtension:
+    """Unit tests for has_3d_extension()"""
+
+    def test_recognized_extensions(self):
+        for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
+            assert has_3d_extension(f'model{ext}') is True
+
+    def test_case_insensitive(self):
+        assert has_3d_extension('MODEL.GLB') is True
+        assert has_3d_extension('Scene.GLTF') is True
+
+    def test_non_3d_extensions(self):
+        for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
+            assert has_3d_extension(name) is False
+
+
 class TestApplySorting:
     """Unit tests for apply_sorting()"""
@@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
             'prompt': {'nodes': {'1': {}}},
             'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
         }
+
+    def test_include_outputs_normalizes_3d_strings(self):
+        """Detail view should transform string 3D filenames into file output dicts."""
+        history_item = {
+            'prompt': (
+                5,
+                'prompt-3d',
+                {'nodes': {}},
+                {'create_time': 1234567890},
+                ['node1'],
+            ),
+            'status': {'status_str': 'success', 'completed': True, 'messages': []},
+            'outputs': {
+                'node1': {
+                    'result': ['preview3d_abc123.glb', None, None]
+                }
+            },
+        }
+        job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
+
+        assert job['outputs_count'] == 1
+        result_items = job['outputs']['node1']['result']
+        assert len(result_items) == 1
+        assert result_items[0] == {
+            'filename': 'preview3d_abc123.glb',
+            'type': 'output',
+            'subfolder': '',
+            'mediaType': '3d',
+        }
+
+    def test_include_outputs_preserves_dict_items(self):
+        """Detail view normalization should pass dict items through unchanged."""
+        history_item = {
+            'prompt': (
+                5,
+                'prompt-img',
+                {'nodes': {}},
+                {'create_time': 1234567890},
+                ['node1'],
+            ),
+            'status': {'status_str': 'success', 'completed': True, 'messages': []},
+            'outputs': {
+                'node1': {
+                    'images': [
+                        {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+                    ]
+                }
+            },
+        }
+        job = normalize_history_item('prompt-img', history_item, include_outputs=True)
+
+        assert job['outputs_count'] == 1
+        assert job['outputs']['node1']['images'] == [
+            {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+        ]
+
+
+class TestNormalizeOutputItem:
+    """Unit tests for normalize_output_item()"""
+
+    def test_none_returns_none(self):
+        assert normalize_output_item(None) is None
+
+    def test_string_3d_extension_synthesizes_dict(self):
+        result = normalize_output_item('model.glb')
+        assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+
+    def test_string_non_3d_extension_returns_none(self):
+        assert normalize_output_item('data.json') is None
+
+    def test_string_no_extension_returns_none(self):
+        assert normalize_output_item('camera_info_string') is None
+
+    def test_dict_passes_through(self):
+        item = {'filename': 'test.png', 'type': 'output'}
+        assert normalize_output_item(item) is item
+
+    def test_other_types_return_none(self):
+        assert normalize_output_item(42) is None
+        assert normalize_output_item(True) is None
+
+
+class TestNormalizeOutputs:
+    """Unit tests for normalize_outputs()"""
+
+    def test_empty_outputs(self):
+        assert normalize_outputs({}) == {}
+
+    def test_dict_items_pass_through(self):
+        outputs = {
+            'node1': {
+                'images': [{'filename': 'a.png', 'type': 'output'}],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result == outputs
+
+    def test_3d_string_synthesized(self):
+        outputs = {
+            'node1': {
+                'result': ['model.glb', None, None],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result == {
+            'node1': {
+                'result': [
+                    {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
+                ],
+            }
+        }
+
+    def test_animated_key_preserved(self):
+        outputs = {
+            'node1': {
+                'images': [{'filename': 'a.png', 'type': 'output'}],
+                'animated': [True],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result['node1']['animated'] == [True]
+
+    def test_non_dict_node_outputs_preserved(self):
+        outputs = {'node1': 'unexpected_value'}
+        result = normalize_outputs(outputs)
+        assert result == {'node1': 'unexpected_value'}
+
+    def test_none_items_filtered_but_other_types_preserved(self):
+        outputs = {
+            'node1': {
+                'result': ['data.json', None, [1, 2, 3]],
+            }
+        }
+        result = normalize_outputs(outputs)
+        assert result == {
+            'node1': {
+                'result': ['data.json', [1, 2, 3]],
+            }
+        }