mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-22 20:40:49 +08:00
Fix EliGen RoPE shape to match apply_rope1() requirements
This commit is contained in:
commit
242037fa32
@ -0,0 +1,3 @@
|
||||
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
@ -1,2 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
|
||||
@ -1,2 +1,3 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||
pause
|
||||
|
||||
4
.github/workflows/release-stable-all.yml
vendored
4
.github/workflows/release-stable-all.yml
vendored
@ -18,9 +18,9 @@ jobs:
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu129"
|
||||
cache_tag: "cu130"
|
||||
python_minor: "13"
|
||||
python_patch: "6"
|
||||
python_patch: "9"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: ""
|
||||
test_release: true
|
||||
|
||||
@ -17,7 +17,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "129"
|
||||
default: "130"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -29,7 +29,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "6"
|
||||
default: "9"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
|
||||
@ -176,6 +176,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
@ -201,6 +203,8 @@ Python 3.14 will work if you comment out the `kornia` dependency in the requirem
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
### Instructions:
|
||||
|
||||
Git clone this repo.
|
||||
|
||||
Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
|
||||
|
||||
112
app/subgraph_manager.py
Normal file
112
app/subgraph_manager.py
Normal file
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypedDict
|
||||
import os
|
||||
import folder_paths
|
||||
import glob
|
||||
from aiohttp import web
|
||||
import hashlib
|
||||
|
||||
|
||||
class Source:
|
||||
custom_node = "custom_node"
|
||||
|
||||
class SubgraphEntry(TypedDict):
|
||||
source: str
|
||||
"""
|
||||
Source of subgraph - custom_nodes vs templates.
|
||||
"""
|
||||
path: str
|
||||
"""
|
||||
Relative path of the subgraph file.
|
||||
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||
"""
|
||||
name: str
|
||||
"""
|
||||
Name of subgraph file.
|
||||
"""
|
||||
info: CustomNodeSubgraphEntryInfo
|
||||
"""
|
||||
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||
"""
|
||||
data: str
|
||||
|
||||
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||
node_pack: str
|
||||
"""Node pack name."""
|
||||
|
||||
class SubgraphManager:
|
||||
def __init__(self):
|
||||
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||
|
||||
async def load_entry_data(self, entry: SubgraphEntry):
|
||||
with open(entry['path'], 'r') as f:
|
||||
entry['data'] = f.read()
|
||||
return entry
|
||||
|
||||
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||
if entry is None:
|
||||
return None
|
||||
entry = entry.copy()
|
||||
entry.pop('path', None)
|
||||
if remove_data:
|
||||
entry.pop('data', None)
|
||||
return entry
|
||||
|
||||
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||
entries = entries.copy()
|
||||
for key in list(entries.keys()):
|
||||
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||
return entries
|
||||
|
||||
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||
# if not forced to reload and cached, return cache
|
||||
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||
return self.cached_custom_node_subgraphs
|
||||
# Load subgraphs from custom nodes
|
||||
subfolder = "subgraphs"
|
||||
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||
|
||||
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||
matched_files = glob.glob(pattern)
|
||||
for file in matched_files:
|
||||
# replace backslashes with forward slashes
|
||||
file = file.replace('\\', '/')
|
||||
info: CustomNodeSubgraphEntryInfo = {
|
||||
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||
}
|
||||
source = Source.custom_node
|
||||
# hash source + path to make sure id will be as unique as possible, but
|
||||
# reproducible across backend reloads
|
||||
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||
entry: SubgraphEntry = {
|
||||
"source": Source.custom_node,
|
||||
"name": os.path.splitext(os.path.basename(file))[0],
|
||||
"path": file,
|
||||
"info": info,
|
||||
}
|
||||
subgraphs_dict[id] = entry
|
||||
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||
return subgraphs_dict
|
||||
|
||||
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||
if entry is not None and entry.get('data', None) is None:
|
||||
await self.load_entry_data(entry)
|
||||
return entry
|
||||
|
||||
def add_routes(self, routes, loadedModules):
|
||||
@routes.get("/global_subgraphs")
|
||||
async def get_global_subgraphs(request):
|
||||
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||
# that's the reasoning for the current implementation
|
||||
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||
|
||||
@routes.get("/global_subgraphs/{id}")
|
||||
async def get_global_subgraph(request):
|
||||
id = request.match_info.get("id", None)
|
||||
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||
return web.json_response(await self.sanitize_entry(subgraph))
|
||||
@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -144,8 +145,9 @@ class PerformanceFeature(enum.Enum):
|
||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||
CublasOps = "cublas_ops"
|
||||
AutoTune = "autotune"
|
||||
PinnedMem = "pinned_memory"
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||
|
||||
@ -310,11 +310,13 @@ class ControlLoraOps:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||
def __init__(
|
||||
@ -350,12 +352,13 @@ class ControlLoraOps:
|
||||
|
||||
|
||||
def forward(self, input):
|
||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||
if self.up is not None:
|
||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
else:
|
||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
class ControlLora(ControlNet):
|
||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||
|
||||
@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
|
||||
@ -7,15 +7,7 @@ import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q_shape = q.shape
|
||||
k_shape = k.shape
|
||||
|
||||
if pe is not None:
|
||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
||||
|
||||
q, k = apply_rope(q, k, pe)
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
@ -3,12 +3,11 @@ from torch import nn
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.modules.attention
|
||||
import comfy.ldm.common_dit
|
||||
from einops import rearrange
|
||||
import math
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
@ -238,20 +237,6 @@ class FeedForward(nn.Module):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
||||
cos_freqs = freqs_cis[0]
|
||||
sin_freqs = freqs_cis[1]
|
||||
|
||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
||||
t1, t2 = t_dup.unbind(dim=-1)
|
||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
||||
|
||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
|
||||
if pe is not None:
|
||||
q = apply_rotary_emb(q, pe)
|
||||
k = apply_rotary_emb(k, pe)
|
||||
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||
|
||||
if mask is None:
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
||||
x += self.ff(y) * gate_mlp
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
x.addcmul_(self.ff(y), gate_mlp)
|
||||
|
||||
return x
|
||||
|
||||
@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
||||
|
||||
|
||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||
dtype = torch.float32 #self.dtype
|
||||
dtype = torch.float32
|
||||
device = indices_grid.device
|
||||
|
||||
# Get fractional positions and compute frequency indices
|
||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||
|
||||
start = 1
|
||||
end = theta
|
||||
device = fractional_positions.device
|
||||
# Compute frequencies and apply cos/sin
|
||||
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
indices = theta ** (
|
||||
torch.linspace(
|
||||
math.log(start, theta),
|
||||
math.log(end, theta),
|
||||
dim // 6,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
)
|
||||
indices = indices.to(dtype=dtype)
|
||||
|
||||
indices = indices * math.pi / 2
|
||||
|
||||
freqs = (
|
||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
||||
.transpose(-1, -2)
|
||||
.flatten(2)
|
||||
)
|
||||
|
||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
# Pad if dim is not divisible by 6
|
||||
if dim % 6 != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
||||
padding_size = dim % 6
|
||||
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||
|
||||
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||
|
||||
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||
freqs_cis = torch.stack([
|
||||
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class LTXVModel(torch.nn.Module):
|
||||
@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||
x = self.norm_out(x)
|
||||
# Modulation
|
||||
x = x * (1 + scale) + shift
|
||||
x = torch.addcmul(x, x, scale).add_(shift)
|
||||
x = self.proj_out(x)
|
||||
|
||||
x = self.patchifier.unpatchify(
|
||||
|
||||
@ -522,7 +522,7 @@ class NextDiT(nn.Module):
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
@ -531,10 +531,22 @@ class NextDiT(nn.Module):
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
|
||||
class GELU(nn.Module):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||
@ -135,45 +135,34 @@ class Attention(nn.Module):
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
transformer_options={},
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = hidden_states.shape[0]
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
||||
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||
|
||||
img_query = self.norm_q(img_query)
|
||||
img_key = self.norm_k(img_key)
|
||||
txt_query = self.norm_added_q(txt_query)
|
||||
txt_key = self.norm_added_k(txt_key)
|
||||
|
||||
# Concatenate text and image streams
|
||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
||||
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
# Apply RoPE to concatenated queries and keys
|
||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
# Validate attention mask shape if provided
|
||||
if attention_mask is not None:
|
||||
expected_seq = joint_query.shape[1]
|
||||
if attention_mask.shape[-1] != expected_seq:
|
||||
raise ValueError(
|
||||
f"Attention mask shape mismatch: {attention_mask.shape} "
|
||||
f"doesn't match sequence length {expected_seq}"
|
||||
)
|
||||
|
||||
# Use ComfyUI's optimized attention
|
||||
joint_query = joint_query.flatten(start_dim=2)
|
||||
joint_key = joint_key.flatten(start_dim=2)
|
||||
joint_value = joint_value.flatten(start_dim=2)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attention_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
|
||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||
@ -427,7 +416,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
device=latents.device
|
||||
).reshape(1, -1, 1).repeat(1, 1, 3)
|
||||
|
||||
entity_rope = self.pe_embedder(entity_ids).squeeze(1).squeeze(0)
|
||||
entity_rope = self.pe_embedder(entity_ids) # Keep shape [1, 1, seq, dim, 2, 2]
|
||||
entity_txt_embs.append(entity_rope)
|
||||
|
||||
# Generate global text RoPE
|
||||
@ -436,9 +425,9 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
max_vid_index + global_seq_len,
|
||||
device=latents.device
|
||||
).reshape(1, -1, 1).repeat(1, 1, 3)
|
||||
global_rope = self.pe_embedder(global_ids).squeeze(1).squeeze(0)
|
||||
global_rope = self.pe_embedder(global_ids) # Keep shape [1, 1, seq, dim, 2, 2]
|
||||
|
||||
txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=0)
|
||||
txt_rotary_emb = torch.cat(entity_txt_embs + [global_rope], dim=2) # Concatenate on sequence dimension
|
||||
|
||||
h_coords = torch.arange(-(patch_h - patch_h // 2), patch_h // 2, device=latents.device)
|
||||
w_coords = torch.arange(-(patch_w - patch_w // 2), patch_w // 2, device=latents.device)
|
||||
@ -449,13 +438,13 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
img_ids[:, :, 2] = w_coords.unsqueeze(0)
|
||||
img_ids = img_ids.reshape(1, -1, 3)
|
||||
|
||||
img_rope = self.pe_embedder(img_ids).squeeze(1).squeeze(0)
|
||||
img_rope = self.pe_embedder(img_ids) # Keep shape [1, 1, seq, dim, 2, 2]
|
||||
|
||||
logging.debug(f"[EliGen Model] RoPE shapes - img: {img_rope.shape}, txt: {txt_rotary_emb.shape}")
|
||||
|
||||
# Concatenate text and image RoPE embeddings
|
||||
# Convert to latent dtype to match queries/keys
|
||||
image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=0).unsqueeze(1).to(dtype=latents.dtype)
|
||||
# Concatenate text and image RoPE embeddings on sequence dimension
|
||||
# Shape will be [1, 1, total_seq, dim, 2, 2] where total_seq = txt_seq + img_seq
|
||||
image_rotary_emb = torch.cat([txt_rotary_emb, img_rope], dim=2).to(dtype=latents.dtype)
|
||||
|
||||
# Prepare spatial masks
|
||||
repeat_dim = latents.shape[1]
|
||||
@ -684,7 +673,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
|
||||
@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
|
||||
# assert e[0].dtype == torch.float32
|
||||
|
||||
# self-attention
|
||||
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, transformer_options=transformer_options)
|
||||
@ -588,7 +589,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@ -601,10 +602,22 @@ class WanModel(torch.nn.Module):
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
@ -630,7 +643,7 @@ class WanModel(torch.nn.Module):
|
||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
|
||||
@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
@ -197,8 +197,14 @@ class BaseModel(torch.nn.Module):
|
||||
extra_conds[o] = extra
|
||||
|
||||
t = self.process_timestep(t, x=x, **extra_conds)
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
|
||||
|
||||
def process_timestep(self, timestep, **kwargs):
|
||||
return timestep
|
||||
@ -327,6 +333,14 @@ class BaseModel(torch.nn.Module):
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
"format_version": "1.0",
|
||||
"layers": self.model_config.layer_quant_config
|
||||
}
|
||||
unet_state_dict["_quantization_metadata"] = metadata
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
|
||||
@ -6,6 +6,20 @@ import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
|
||||
return model_config
|
||||
|
||||
def unet_prefix_from_state_dict(state_dict):
|
||||
|
||||
@ -89,6 +89,7 @@ if args.deterministic:
|
||||
|
||||
directml_enabled = False
|
||||
if args.directml is not None:
|
||||
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
|
||||
import torch_directml
|
||||
directml_enabled = True
|
||||
device_index = args.directml
|
||||
@ -330,15 +331,21 @@ except:
|
||||
|
||||
|
||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||
|
||||
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||
|
||||
try:
|
||||
if is_amd():
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||
|
||||
try:
|
||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||
except:
|
||||
rocm_version = (6, -1)
|
||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||
|
||||
logging.info("AMD arch: {}".format(arch))
|
||||
logging.info("ROCm version: {}".format(rocm_version))
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
@ -992,12 +999,6 @@ def device_supports_non_blocking(device):
|
||||
return False
|
||||
return True
|
||||
|
||||
def device_should_use_non_blocking(device):
|
||||
if not device_supports_non_blocking(device):
|
||||
return False
|
||||
return False
|
||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
||||
|
||||
def force_channels_last():
|
||||
if args.force_channels_last:
|
||||
return True
|
||||
@ -1012,6 +1013,16 @@ if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
def current_stream(device):
|
||||
if device is None:
|
||||
return None
|
||||
if is_device_cuda(device):
|
||||
return torch.cuda.current_stream()
|
||||
elif is_device_xpu(device):
|
||||
return torch.xpu.current_stream()
|
||||
else:
|
||||
return None
|
||||
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
@ -1020,21 +1031,17 @@ def get_offload_stream(device):
|
||||
|
||||
if device in STREAMS:
|
||||
ss = STREAMS[device]
|
||||
s = ss[stream_counter]
|
||||
#Sync the oldest stream in the queue with the current
|
||||
ss[stream_counter].wait_stream(current_stream(device))
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
if is_device_cuda(device):
|
||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
||||
elif is_device_xpu(device):
|
||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return ss[stream_counter]
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
elif is_device_xpu(device):
|
||||
@ -1043,18 +1050,14 @@ def get_offload_stream(device):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counter = (stream_counter + 1) % len(ss)
|
||||
stream_counters[device] = stream_counter
|
||||
return s
|
||||
return None
|
||||
|
||||
def sync_stream(device, stream):
|
||||
if stream is None:
|
||||
if stream is None or current_stream(device) is None:
|
||||
return
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.current_stream().wait_stream(stream)
|
||||
current_stream(device).wait_stream(stream)
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||
if device is None or weight.device == device:
|
||||
@ -1079,6 +1082,60 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
non_blocking = device_supports_non_blocking(device)
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
PINNED_MEMORY = {}
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
if PerformanceFeature.PinnedMem in args.fast:
|
||||
if WINDOWS:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||
else:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
else:
|
||||
MAX_PINNED_MEMORY = -1
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def unpin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||
if len(PINNED_MEMORY) == 0:
|
||||
TOTAL_PINNED_MEMORY = 0
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
@ -1331,7 +1388,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
|
||||
if is_amd():
|
||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
||||
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||
if manual_cast:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -238,6 +238,7 @@ class ModelPatcher:
|
||||
self.force_cast_weights = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self.parent = None
|
||||
self.pinned = set()
|
||||
|
||||
self.attachments: dict[str] = {}
|
||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||
@ -275,6 +276,9 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def loaded_size(self):
|
||||
return self.model.model_loaded_weight_memory
|
||||
|
||||
@ -294,6 +298,7 @@ class ModelPatcher:
|
||||
n.backup = self.backup
|
||||
n.object_patches_backup = self.object_patches_backup
|
||||
n.parent = self
|
||||
n.pinned = self.pinned
|
||||
|
||||
n.force_cast_weights = self.force_cast_weights
|
||||
|
||||
@ -450,6 +455,19 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
rope_options["scale_y"] = scale_y
|
||||
rope_options["scale_t"] = scale_t
|
||||
|
||||
rope_options["shift_x"] = shift_x
|
||||
rope_options["shift_y"] = shift_y
|
||||
rope_options["shift_t"] = shift_t
|
||||
|
||||
self.model_options["transformer_options"]["rope_options"] = rope_options
|
||||
|
||||
|
||||
def add_object_patch(self, name, obj):
|
||||
self.object_patches[name] = obj
|
||||
|
||||
@ -618,6 +636,21 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def pin_weight_to_device(self, key):
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
if comfy.model_management.pin_memory(weight):
|
||||
self.pinned.add(key)
|
||||
|
||||
def unpin_weight(self, key):
|
||||
if key in self.pinned:
|
||||
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||
comfy.model_management.unpin_memory(weight)
|
||||
self.pinned.remove(key)
|
||||
|
||||
def unpin_all_weights(self):
|
||||
for key in list(self.pinned):
|
||||
self.unpin_weight(key)
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@ -639,9 +672,11 @@ class ModelPatcher:
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
lowvram_mem_counter = 0
|
||||
loading = self._load_list()
|
||||
|
||||
load_completely = []
|
||||
offloaded = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
@ -658,6 +693,7 @@ class ModelPatcher:
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||
continue
|
||||
|
||||
@ -683,6 +719,7 @@ class ModelPatcher:
|
||||
patch_counter += 1
|
||||
|
||||
cast_weight = True
|
||||
offloaded.append((module_mem, n, m, params))
|
||||
else:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
@ -713,7 +750,9 @@ class ModelPatcher:
|
||||
continue
|
||||
|
||||
for param in params:
|
||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@ -721,11 +760,17 @@ class ModelPatcher:
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
for x in offloaded:
|
||||
n = x[1]
|
||||
params = x[3]
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
@ -762,6 +807,7 @@ class ModelPatcher:
|
||||
self.eject_model()
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.unpin_all_weights()
|
||||
if self.model.model_lowvram:
|
||||
for m in self.model.modules():
|
||||
move_weight_functions(m, device_to)
|
||||
@ -857,6 +903,9 @@ class ModelPatcher:
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
@ -1259,5 +1308,6 @@ class ModelPatcher:
|
||||
self.clear_cached_hook_weights()
|
||||
|
||||
def __del__(self):
|
||||
self.unpin_all_weights()
|
||||
self.detach(unpatch_all=False)
|
||||
|
||||
|
||||
91
comfy/nested_tensor.py
Normal file
91
comfy/nested_tensor.py
Normal file
@ -0,0 +1,91 @@
|
||||
import torch
|
||||
|
||||
class NestedTensor:
|
||||
def __init__(self, tensors):
|
||||
self.tensors = list(tensors)
|
||||
self.is_nested = True
|
||||
|
||||
def _copy(self):
|
||||
return NestedTensor(self.tensors)
|
||||
|
||||
def apply_operation(self, other, operation):
|
||||
o = self._copy()
|
||||
if isinstance(other, NestedTensor):
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other.tensors[i])
|
||||
else:
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = operation(t, other)
|
||||
return o
|
||||
|
||||
def __add__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x + y)
|
||||
|
||||
def __sub__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x - y)
|
||||
|
||||
def __mul__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x * y)
|
||||
|
||||
# def __itruediv__(self, b):
|
||||
# return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __truediv__(self, b):
|
||||
return self.apply_operation(b, lambda x, y: x / y)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
|
||||
|
||||
def unbind(self):
|
||||
return self.tensors
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
o = self._copy()
|
||||
for i, t in enumerate(o.tensors):
|
||||
o.tensors[i] = t.to(*args, **kwargs)
|
||||
return o
|
||||
|
||||
def new_ones(self, *args, **kwargs):
|
||||
return self.tensors[0].new_ones(*args, **kwargs)
|
||||
|
||||
def float(self):
|
||||
return self.to(dtype=torch.float)
|
||||
|
||||
def chunk(self, *args, **kwargs):
|
||||
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
|
||||
|
||||
def size(self):
|
||||
return self.tensors[0].size()
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.tensors[0].shape
|
||||
|
||||
@property
|
||||
def ndim(self):
|
||||
dims = 0
|
||||
for t in self.tensors:
|
||||
dims = max(t.ndim, dims)
|
||||
return dims
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.tensors[0].device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.tensors[0].dtype
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
return self.tensors[0].layout
|
||||
|
||||
|
||||
def cat_nested(tensors, *args, **kwargs):
|
||||
cated_tensors = []
|
||||
for i in range(len(tensors[0].tensors)):
|
||||
tens = []
|
||||
for j in range(len(tensors)):
|
||||
tens.append(tensors[j].tensors[i])
|
||||
cated_tensors.append(torch.cat(tens, *args, **kwargs))
|
||||
return NestedTensor(cated_tensors)
|
||||
299
comfy/ops.py
299
comfy/ops.py
@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||
|
||||
|
||||
try:
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
import inspect
|
||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||
@ -70,8 +70,11 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
@ -80,32 +83,58 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
bias = None
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
has_function = len(s.bias_function) > 0
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
|
||||
if has_function:
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
bias_has_function = len(s.bias_function) > 0
|
||||
|
||||
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
|
||||
if bias_has_function:
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
has_function = len(s.weight_function) > 0
|
||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
||||
if has_function:
|
||||
weight = weight.to(dtype=dtype)
|
||||
if weight_has_function:
|
||||
with wf_context:
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
return weight, bias
|
||||
if offloadable:
|
||||
return weight, bias, offload_stream
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
if weight is not None:
|
||||
device = weight.device
|
||||
else:
|
||||
if bias is None:
|
||||
return
|
||||
device = bias.device
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
|
||||
class CastWeightBiasOp:
|
||||
comfy_cast_weights = False
|
||||
@ -118,8 +147,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -133,8 +164,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -148,8 +181,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -172,8 +207,10 @@ class disable_weight_init:
|
||||
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return self._conv_forward(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._conv_forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -187,8 +224,10 @@ class disable_weight_init:
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -203,11 +242,14 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
bias = None
|
||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
offload_stream = None
|
||||
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -223,11 +265,15 @@ class disable_weight_init:
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if self.weight is not None:
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
else:
|
||||
weight = None
|
||||
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
bias = None
|
||||
offload_stream = None
|
||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -246,10 +292,12 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose2d(
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose2d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -268,10 +316,12 @@ class disable_weight_init:
|
||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||
num_spatial_dims, self.dilation)
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.conv_transpose1d(
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.conv_transpose1d(
|
||||
input, weight, bias, self.stride, self.padding,
|
||||
output_padding, self.groups, self.dilation)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -289,8 +339,11 @@ class disable_weight_init:
|
||||
output_dtype = out_dtype
|
||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||
out_dtype = None
|
||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
|
||||
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
@ -344,20 +397,18 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
"""
|
||||
Legacy FP8 linear function for backward compatibility.
|
||||
Uses QuantizedTensor subclass for dispatch.
|
||||
"""
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
tensor_2d = False
|
||||
if len(input.shape) == 2:
|
||||
tensor_2d = True
|
||||
input = input.unsqueeze(1)
|
||||
|
||||
input_shape = input.shape
|
||||
input_dtype = input.dtype
|
||||
if len(input.shape) == 3:
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||
w = w.t()
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
@ -369,23 +420,20 @@ def fp8_linear(self, input):
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
if tensor_2d:
|
||||
return o.reshape(input_shape[0], -1)
|
||||
|
||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||
uncast_bias_weight(self, w, bias, offload_stream)
|
||||
return o
|
||||
|
||||
return None
|
||||
|
||||
@ -405,8 +453,10 @@ class fp8_ops(manual_cast):
|
||||
except Exception as e:
|
||||
logging.info("Exception during fp8 op: {}".format(e))
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = torch.nn.functional.linear(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
@ -434,12 +484,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
@ -478,7 +530,130 @@ if CUBLAS_IS_AVAILABLE:
|
||||
def forward(self, *args, **kwargs):
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
||||
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor
|
||||
|
||||
QUANT_FORMAT_MIXINS = {
|
||||
"float8_e4m3fn": {
|
||||
"dtype": torch.float8_e4m3fn,
|
||||
"layout_type": "TensorCoreFP8Layout",
|
||||
"parameters": {
|
||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {}
|
||||
_compute_dtype = torch.bfloat16
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.tensor_class = None
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
mixin = QUANT_FORMAT_MIXINS[quant_format]
|
||||
self.layout_type = mixin["layout_type"]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name, param_value in mixin["parameters"].items():
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
|
||||
MixedPrecisionOps._compute_dtype = compute_dtype
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return MixedPrecisionOps
|
||||
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
|
||||
512
comfy/quant_ops.py
Normal file
512
comfy/quant_ops.py
Normal file
@ -0,0 +1,512 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
|
||||
|
||||
def register_layout_op(torch_op, layout_type):
|
||||
"""
|
||||
Decorator to register a layout-specific operation handler.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
Example:
|
||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||
def fp8_linear(func, args, kwargs):
|
||||
# FP8-specific linear implementation
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
if torch_op not in _LAYOUT_REGISTRY:
|
||||
_LAYOUT_REGISTRY[torch_op] = {}
|
||||
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def register_generic_util(torch_op):
|
||||
"""
|
||||
Decorator to register a generic utility that works for all layouts.
|
||||
Args:
|
||||
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||
|
||||
Example:
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
# Works for any layout
|
||||
...
|
||||
"""
|
||||
def decorator(handler_func):
|
||||
_GENERIC_UTILS[torch_op] = handler_func
|
||||
return handler_func
|
||||
return decorator
|
||||
|
||||
|
||||
def _get_layout_from_args(args):
|
||||
for arg in args:
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg._layout_type
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
for item in arg:
|
||||
if isinstance(item, QuantizedTensor):
|
||||
return item._layout_type
|
||||
return None
|
||||
|
||||
|
||||
def _move_layout_params_to_device(params, device):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.to(device=device)
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
def _copy_layout_params(params):
|
||||
new_params = {}
|
||||
for k, v in params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
new_params[k] = v.clone()
|
||||
else:
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
Base class for quantization layouts.
|
||||
|
||||
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||
|
||||
New quantization formats should subclass this and implement the required methods.
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||
|
||||
|
||||
class QuantizedTensor(torch.Tensor):
|
||||
"""
|
||||
Universal quantized tensor that works with any layout.
|
||||
|
||||
This tensor subclass uses a pluggable layout system to support multiple
|
||||
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||
|
||||
The layout_type determines format-specific behavior, while common operations
|
||||
(detach, clone, to) are handled generically.
|
||||
|
||||
Attributes:
|
||||
_qdata: The quantized tensor data
|
||||
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def __new__(cls, qdata, layout_type, layout_params):
|
||||
"""
|
||||
Create a quantized tensor.
|
||||
|
||||
Args:
|
||||
qdata: The quantized data tensor
|
||||
layout_type: Layout class (subclass of QuantizedLayout)
|
||||
layout_params: Dict with layout-specific parameters
|
||||
"""
|
||||
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||
|
||||
def __init__(self, qdata, layout_type, layout_params):
|
||||
self._qdata = qdata
|
||||
self._layout_type = layout_type
|
||||
self._layout_params = layout_params
|
||||
|
||||
def __repr__(self):
|
||||
layout_name = self._layout_type
|
||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||
|
||||
@property
|
||||
def layout_type(self):
|
||||
return self._layout_type
|
||||
|
||||
def __tensor_flatten__(self):
|
||||
"""
|
||||
Tensor flattening protocol for proper device movement.
|
||||
"""
|
||||
inner_tensors = ["_qdata"]
|
||||
ctx = {
|
||||
"layout_type": self._layout_type,
|
||||
}
|
||||
|
||||
tensor_params = {}
|
||||
non_tensor_params = {}
|
||||
for k, v in self._layout_params.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
tensor_params[k] = v
|
||||
else:
|
||||
non_tensor_params[k] = v
|
||||
|
||||
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||
ctx["non_tensor_params"] = non_tensor_params
|
||||
|
||||
for k, v in tensor_params.items():
|
||||
attr_name = f"_layout_param_{k}"
|
||||
object.__setattr__(self, attr_name, v)
|
||||
inner_tensors.append(attr_name)
|
||||
|
||||
return inner_tensors, ctx
|
||||
|
||||
@staticmethod
|
||||
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||
"""
|
||||
Tensor unflattening protocol for proper device movement.
|
||||
Reconstructs the QuantizedTensor after device movement.
|
||||
"""
|
||||
layout_type = ctx["layout_type"]
|
||||
layout_params = dict(ctx["non_tensor_params"])
|
||||
|
||||
for key in ctx["tensor_param_keys"]:
|
||||
attr_name = f"_layout_param_{key}"
|
||||
layout_params[key] = inner_tensors[attr_name]
|
||||
|
||||
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||
return cls(qdata, layout_type, layout_params)
|
||||
|
||||
def dequantize(self) -> torch.Tensor:
|
||||
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||
if func in _GENERIC_UTILS:
|
||||
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||
|
||||
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||
layout_type = _get_layout_from_args(args)
|
||||
if layout_type and func in _LAYOUT_REGISTRY:
|
||||
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||
if handler:
|
||||
return handler(func, args, kwargs)
|
||||
|
||||
# Step 3: Fallback to dequantization
|
||||
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||
return cls._dequant_and_fallback(func, args, kwargs)
|
||||
|
||||
@classmethod
|
||||
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||
def dequant_arg(arg):
|
||||
if isinstance(arg, QuantizedTensor):
|
||||
return arg.dequantize()
|
||||
elif isinstance(arg, (list, tuple)):
|
||||
return type(arg)(dequant_arg(a) for a in arg)
|
||||
return arg
|
||||
|
||||
new_args = dequant_arg(args)
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# ==============================================================================
|
||||
|
||||
def _create_transformed_qtensor(qt, transform_fn):
|
||||
new_data = transform_fn(qt._qdata)
|
||||
new_params = _copy_layout_params(qt._layout_params)
|
||||
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
f"but not supported. Ignoring layout."
|
||||
)
|
||||
|
||||
# Handle device transfer
|
||||
current_device = qt._qdata.device
|
||||
if target_device is not None:
|
||||
# Normalize device for comparison
|
||||
if isinstance(target_device, str):
|
||||
target_device = torch.device(target_device)
|
||||
if isinstance(current_device, str):
|
||||
current_device = torch.device(current_device)
|
||||
|
||||
if target_device != current_device:
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
|
||||
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||
return qt
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.detach.default)
|
||||
def generic_detach(func, args, kwargs):
|
||||
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.clone.default)
|
||||
def generic_clone(func, args, kwargs):
|
||||
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||
def generic_to_copy(func, args, kwargs):
|
||||
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
op_name="_to_copy"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||
def generic_to_dtype_layout(func, args, kwargs):
|
||||
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
return _handle_device_transfer(
|
||||
qt,
|
||||
target_device=kwargs.get('device', None),
|
||||
target_dtype=kwargs.get('dtype', None),
|
||||
target_layout=kwargs.get('layout', None),
|
||||
op_name="to"
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.copy_.default)
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
qt_dest._layout_params = _copy_layout_params(src._layout_params)
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
return qt_dest
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
class TensorCoreFP8Layout(QuantizedLayout):
|
||||
"""
|
||||
Storage format:
|
||||
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||
- scale: Scalar tensor (float32) for dequantization
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
|
||||
# lp_amax = torch.finfo(dtype).max
|
||||
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return qdata, layout_params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
return plain_tensor * scale
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
|
||||
|
||||
|
||||
LAYOUTS = {
|
||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||
}
|
||||
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||
def fp8_linear(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
bias = args[2] if len(args) > 2 else None
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
out_dtype = kwargs.get("out_dtype")
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
weight_t = plain_weight.t()
|
||||
|
||||
tensor_2d = False
|
||||
if len(plain_input.shape) == 2:
|
||||
tensor_2d = True
|
||||
plain_input = plain_input.unsqueeze(1)
|
||||
|
||||
input_shape = plain_input.shape
|
||||
if len(input_shape) != 3:
|
||||
return None
|
||||
|
||||
try:
|
||||
output = torch._scaled_mm(
|
||||
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||
weight_t,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
|
||||
if not tensor_2d:
|
||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||
|
||||
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||
output_scale = scale_a * scale_b
|
||||
output_params = {
|
||||
'scale': output_scale,
|
||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||
}
|
||||
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||
else:
|
||||
return output
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||
|
||||
# Case 2: DQ Fallback
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
input_tensor = input_tensor.dequantize()
|
||||
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
||||
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||
if out_dtype is None:
|
||||
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||
|
||||
output = torch._scaled_mm(
|
||||
plain_input.contiguous(),
|
||||
plain_weight,
|
||||
bias=bias,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
out_dtype=out_dtype,
|
||||
)
|
||||
|
||||
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||
def fp8_addmm(func, args, kwargs):
|
||||
input_tensor = args[1]
|
||||
weight = args[2]
|
||||
bias = args[0]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
if isinstance(args[2], QuantizedTensor):
|
||||
a[2] = args[2].dequantize()
|
||||
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||
def fp8_mm(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
weight = args[1]
|
||||
|
||||
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||
|
||||
a = list(args)
|
||||
if isinstance(args[0], QuantizedTensor):
|
||||
a[0] = args[0].dequantize()
|
||||
if isinstance(args[1], QuantizedTensor):
|
||||
a[1] = args[1].dequantize()
|
||||
return func(*a, **kwargs)
|
||||
|
||||
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||
def fp8_func(func, args, kwargs):
|
||||
input_tensor = args[0]
|
||||
if isinstance(input_tensor, QuantizedTensor):
|
||||
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||
ar = list(args)
|
||||
ar[0] = plain_input
|
||||
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||
return func(*args, **kwargs)
|
||||
@ -4,13 +4,9 @@ import comfy.samplers
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
import logging
|
||||
import comfy.nested_tensor
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||
if noise_inds is None:
|
||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||
|
||||
@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
if i in unique_inds:
|
||||
noises.append(noise)
|
||||
noises = [noises[i] for i in inverse]
|
||||
noises = torch.cat(noises, axis=0)
|
||||
return torch.cat(noises, axis=0)
|
||||
|
||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||
"""
|
||||
creates random noise given a latent image and a seed.
|
||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||
"""
|
||||
generator = torch.manual_seed(seed)
|
||||
|
||||
if latent_image.is_nested:
|
||||
tensors = latent_image.unbind()
|
||||
noises = []
|
||||
for t in tensors:
|
||||
noises.append(prepare_noise_inner(t, generator, noise_inds))
|
||||
noises = comfy.nested_tensor.NestedTensor(noises)
|
||||
else:
|
||||
noises = prepare_noise_inner(latent_image, generator, noise_inds)
|
||||
|
||||
return noises
|
||||
|
||||
def fix_empty_latent_channels(model, latent_image):
|
||||
if latent_image.is_nested:
|
||||
return latent_image
|
||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
|
||||
@ -782,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||
|
||||
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
|
||||
for k in conds:
|
||||
conds[k] = conds[k][:]
|
||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||
@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
||||
|
||||
if hasattr(model, 'extra_conds'):
|
||||
for k in conds:
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
|
||||
|
||||
#make sure each cond area has an opposite one with the same area
|
||||
for k in conds:
|
||||
@ -962,11 +962,11 @@ class CFGGuider:
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
|
||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
|
||||
|
||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||
@ -980,7 +980,7 @@ class CFGGuider:
|
||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
|
||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||
device = self.model_patcher.load_device
|
||||
|
||||
@ -994,7 +994,7 @@ class CFGGuider:
|
||||
|
||||
try:
|
||||
self.model_patcher.pre_run()
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
self.model_patcher.cleanup()
|
||||
|
||||
@ -1007,6 +1007,12 @@ class CFGGuider:
|
||||
if sigmas.shape[-1] == 0:
|
||||
return latent_image
|
||||
|
||||
if latent_image.is_nested:
|
||||
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
|
||||
noise, _ = comfy.utils.pack_latents(noise.unbind())
|
||||
else:
|
||||
latent_shapes = [latent_image.shape]
|
||||
|
||||
self.conds = {}
|
||||
for k in self.original_conds:
|
||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||
@ -1026,7 +1032,7 @@ class CFGGuider:
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||
)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||
finally:
|
||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||
self.model_options = orig_model_options
|
||||
@ -1034,6 +1040,9 @@ class CFGGuider:
|
||||
self.model_patcher.restore_hook_patches()
|
||||
|
||||
del self.conds
|
||||
|
||||
if len(latent_shapes) > 1:
|
||||
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
|
||||
return output
|
||||
|
||||
|
||||
|
||||
27
comfy/sd.py
27
comfy/sd.py
@ -143,6 +143,9 @@ class CLIP:
|
||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||
return n
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.patcher.get_ram_usage()
|
||||
|
||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||
|
||||
@ -293,6 +296,7 @@ class VAE:
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
self.disable_offload = False
|
||||
self.not_video = False
|
||||
self.size = None
|
||||
|
||||
self.downscale_index_formula = None
|
||||
self.upscale_index_formula = None
|
||||
@ -595,6 +599,16 @@ class VAE:
|
||||
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
self.model_size()
|
||||
|
||||
def model_size(self):
|
||||
if self.size is not None:
|
||||
return self.size
|
||||
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||
return self.size
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
def throw_exception_if_invalid(self):
|
||||
if self.first_stage_model is None:
|
||||
@ -1262,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return (model_patcher, clip, vae, clipvision)
|
||||
|
||||
|
||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
"""
|
||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||
|
||||
@ -1296,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
load_device = model_management.get_torch_device()
|
||||
model_config = model_detection.model_config_from_unet(sd, "")
|
||||
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||
|
||||
if model_config is not None:
|
||||
new_sd = sd
|
||||
@ -1330,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
if model_config.layer_quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
@ -1346,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
||||
|
||||
|
||||
def load_diffusion_model(unet_path, model_options={}):
|
||||
sd = comfy.utils.load_torch_file(unet_path)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
||||
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||
if model is None:
|
||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||
|
||||
@ -50,6 +50,7 @@ class BASE:
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -1106,3 +1106,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
||||
dim=1
|
||||
)
|
||||
return out
|
||||
|
||||
def pack_latents(latents):
|
||||
latent_shapes = []
|
||||
tensors = []
|
||||
for tensor in latents:
|
||||
latent_shapes.append(tensor.shape)
|
||||
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
|
||||
|
||||
latent = torch.cat(tensors, dim=-1)
|
||||
return latent, latent_shapes
|
||||
|
||||
def unpack_latents(combined_latent, latent_shapes):
|
||||
if len(latent_shapes) > 1:
|
||||
output_tensors = []
|
||||
for shape in latent_shapes:
|
||||
cut = math.prod(shape[1:])
|
||||
tens = combined_latent[:, :, :cut]
|
||||
combined_latent = combined_latent[:, :, cut:]
|
||||
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
|
||||
|
||||
@ -1,718 +0,0 @@
|
||||
from __future__ import annotations
|
||||
import aiohttp
|
||||
import io
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiClient,
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
UploadRequest,
|
||||
UploadResponse,
|
||||
)
|
||||
from server import PromptServer
|
||||
from comfy.cli_args import args
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import math
|
||||
import base64
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
import av
|
||||
|
||||
|
||||
async def download_url_to_video_output(
|
||||
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
||||
) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
||||
|
||||
Args:
|
||||
video_url: The URL of the video to download.
|
||||
|
||||
Returns:
|
||||
A Comfy node `VIDEO` output.
|
||||
"""
|
||||
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {video_url}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return VideoFromFile(video_io)
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
async def validate_and_cast_response(
|
||||
response, timeout: int = None, node_id: Union[str, None] = None
|
||||
) -> torch.Tensor:
|
||||
"""Validates and casts a response to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
response: The response to validate and cast.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
"""
|
||||
# validate raw JSON response
|
||||
data = response.data
|
||||
if not data or len(data) == 0:
|
||||
raise ValueError("No images returned from API endpoint")
|
||||
|
||||
# Initialize list to store image tensors
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
|
||||
# Process each image in the data array
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||||
for img_data in data:
|
||||
img_bytes: bytes
|
||||
if img_data.b64_json:
|
||||
img_bytes = base64.b64decode(img_data.b64_json)
|
||||
elif img_data.url:
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
||||
async with session.get(img_data.url) as resp:
|
||||
if resp.status != 200:
|
||||
raise ValueError("Failed to download generated image")
|
||||
img_bytes = await resp.read()
|
||||
else:
|
||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
||||
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
def validate_aspect_ratio(
|
||||
aspect_ratio: str,
|
||||
minimum_ratio: float,
|
||||
maximum_ratio: float,
|
||||
minimum_ratio_str: str,
|
||||
maximum_ratio_str: str,
|
||||
) -> float:
|
||||
"""Validates and casts an aspect ratio string to a float.
|
||||
|
||||
Args:
|
||||
aspect_ratio: The aspect ratio string to validate.
|
||||
minimum_ratio: The minimum aspect ratio.
|
||||
maximum_ratio: The maximum aspect ratio.
|
||||
minimum_ratio_str: The minimum aspect ratio string.
|
||||
maximum_ratio_str: The maximum aspect ratio string.
|
||||
|
||||
Returns:
|
||||
The validated and cast aspect ratio.
|
||||
|
||||
Raises:
|
||||
Exception: If the aspect ratio is not valid.
|
||||
"""
|
||||
# get ratio values
|
||||
numbers = aspect_ratio.split(":")
|
||||
if len(numbers) != 2:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
||||
)
|
||||
try:
|
||||
numerator = int(numbers[0])
|
||||
denominator = int(numbers[1])
|
||||
except ValueError as exc:
|
||||
raise TypeError(
|
||||
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
||||
) from exc
|
||||
calculated_ratio = numerator / denominator
|
||||
# if not close to minimum and maximum, check bounds
|
||||
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
||||
calculated_ratio, maximum_ratio
|
||||
):
|
||||
if calculated_ratio < minimum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
if calculated_ratio > maximum_ratio:
|
||||
raise TypeError(
|
||||
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
||||
)
|
||||
return aspect_ratio
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
||||
) -> BytesIO:
|
||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
||||
|
||||
Args:
|
||||
url: The URL to download.
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
BytesIO object containing the downloaded content.
|
||||
"""
|
||||
headers = {}
|
||||
if url.startswith("/proxy/"):
|
||||
url = str(args.comfy_api_base).rstrip("/") + url
|
||||
auth_token = auth_kwargs.get("auth_token")
|
||||
comfy_api_key = auth_kwargs.get("comfy_api_key")
|
||||
if auth_token:
|
||||
headers["Authorization"] = f"Bearer {auth_token}"
|
||||
elif comfy_api_key:
|
||||
headers["X-API-KEY"] = comfy_api_key
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
|
||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
||||
async with session.get(url, headers=headers) as resp:
|
||||
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
||||
return BytesIO(await resp.read())
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
image_bytesio = await download_url_to_bytesio(url, timeout)
|
||||
return bytesio_to_image_tensor(image_bytesio)
|
||||
|
||||
|
||||
def process_image_response(response_content: bytes | str) -> torch.Tensor:
|
||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
||||
return bytesio_to_image_tensor(BytesIO(response_content))
|
||||
|
||||
|
||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(
|
||||
input_tensor.unsqueeze(0), total_pixels=total_pixels
|
||||
).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = io.BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = (
|
||||
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
)
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single file to ComfyUI API and returns its download URL.
|
||||
|
||||
Args:
|
||||
file_bytes_io: BytesIO object containing the file data.
|
||||
filename: The filename of the file.
|
||||
upload_mime_type: MIME type of the file.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded file.
|
||||
"""
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/customers/storage",
|
||||
method=HttpMethod.POST,
|
||||
request_model=UploadRequest,
|
||||
response_model=UploadResponse,
|
||||
),
|
||||
request=request_object,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
response: UploadResponse = await operation.execute()
|
||||
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
|
||||
return response.download_url
|
||||
|
||||
|
||||
def video_to_base64_string(
|
||||
video: VideoInput,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = io.BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
video: VideoInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
Uses the specified container and codec for saving the video before upload.
|
||||
|
||||
Args:
|
||||
video: VideoInput object (Comfy VIDEO type).
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
container: The video container format to use (default: MP4).
|
||||
codec: The video codec to use (default: H264).
|
||||
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded video file.
|
||||
"""
|
||||
if max_duration is not None:
|
||||
try:
|
||||
actual_duration = video.duration_seconds
|
||||
if actual_duration is not None and actual_duration > max_duration:
|
||||
raise ValueError(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error getting video duration: %s", str(e))
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = io.BytesIO()
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||
the first item is taken.
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# If batch is > 1, take first item
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0]
|
||||
|
||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = io.BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
async def upload_audio_to_comfyapi(
|
||||
audio: AudioInput,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
Encodes the raw waveform into the specified format before uploading.
|
||||
|
||||
Args:
|
||||
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
|
||||
Returns:
|
||||
The download URL for the uploaded audio file.
|
||||
"""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
|
||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2 ** 15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2 ** 31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def audio_to_base64_string(
|
||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
||||
) -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
||||
audio_data_np, sample_rate, container_format, codec_name
|
||||
)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
async def upload_images_to_comfyapi(
|
||||
image: torch.Tensor,
|
||||
max_images=8,
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
mime_type: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
max_images: Maximum number of images to upload.
|
||||
auth_kwargs: Optional authentication token(s).
|
||||
mime_type: Optional MIME type for the image.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
|
||||
def resize_mask_to_image(
|
||||
mask: torch.Tensor,
|
||||
image: torch.Tensor,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled",
|
||||
allow_gradient=True,
|
||||
add_channel_dim=False,
|
||||
):
|
||||
"""
|
||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
||||
"""
|
||||
_, H, W, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1, 1)
|
||||
mask = common_upscale(
|
||||
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
|
||||
)
|
||||
mask = mask.movedim(1, -1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
|
||||
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
|
||||
|
||||
def image_tensor_pair_to_batch(
|
||||
image1: torch.Tensor, image2: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the smaller image is resized to
|
||||
match the larger image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
|
||||
|
||||
def get_size(path_or_object: Union[str, io.BytesIO]) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
|
||||
|
||||
|
||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
|
||||
|
||||
class BFLFluxCannyImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
||||
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxDepthImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='Text prompt for image generation')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
@ -160,15 +122,8 @@ class BFLStatus(str, Enum):
|
||||
error = "Error"
|
||||
|
||||
|
||||
class BFLFluxProStatusResponse(BaseModel):
|
||||
class BFLFluxStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(
|
||||
None, description="The result of the task (null if not completed)."
|
||||
)
|
||||
progress: confloat(ge=0.0, le=1.0) = Field(
|
||||
..., description="The progress of the task (0.0 to 1.0)."
|
||||
)
|
||||
details: Optional[Dict[str, Any]] = Field(
|
||||
None, description="Additional details about the task (null if not available)."
|
||||
)
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||
|
||||
120
comfy_api_nodes/apis/minimax_api.py
Normal file
120
comfy_api_nodes/apis/minimax_api.py
Normal file
@ -0,0 +1,120 @@
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MinimaxBaseResponse(BaseModel):
|
||||
status_code: int = Field(
|
||||
...,
|
||||
description='Status code. 0 indicates success, other values indicate errors.',
|
||||
)
|
||||
status_msg: str = Field(
|
||||
..., description='Specific error details or success message.'
|
||||
)
|
||||
|
||||
|
||||
class File(BaseModel):
|
||||
bytes: Optional[int] = Field(None, description='File size in bytes')
|
||||
created_at: Optional[int] = Field(
|
||||
None, description='Unix timestamp when the file was created, in seconds'
|
||||
)
|
||||
download_url: Optional[str] = Field(
|
||||
None, description='The URL to download the video'
|
||||
)
|
||||
backup_download_url: Optional[str] = Field(
|
||||
None, description='The backup URL to download the video'
|
||||
)
|
||||
|
||||
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
|
||||
filename: Optional[str] = Field(None, description='The name of the file')
|
||||
purpose: Optional[str] = Field(None, description='The purpose of using the file')
|
||||
|
||||
|
||||
class MinimaxFileRetrieveResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
file: File
|
||||
|
||||
|
||||
class MiniMaxModel(str, Enum):
|
||||
T2V_01_Director = 'T2V-01-Director'
|
||||
I2V_01_Director = 'I2V-01-Director'
|
||||
S2V_01 = 'S2V-01'
|
||||
I2V_01 = 'I2V-01'
|
||||
I2V_01_live = 'I2V-01-live'
|
||||
T2V_01 = 'T2V-01'
|
||||
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||
|
||||
|
||||
class Status6(str, Enum):
|
||||
Queueing = 'Queueing'
|
||||
Preparing = 'Preparing'
|
||||
Processing = 'Processing'
|
||||
Success = 'Success'
|
||||
Fail = 'Fail'
|
||||
|
||||
|
||||
class MinimaxTaskResultResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
file_id: Optional[str] = Field(
|
||||
None,
|
||||
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
|
||||
)
|
||||
status: Status6 = Field(
|
||||
...,
|
||||
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
|
||||
)
|
||||
task_id: str = Field(..., description='The task ID being queried.')
|
||||
|
||||
|
||||
class SubjectReferenceItem(BaseModel):
|
||||
image: Optional[str] = Field(
|
||||
None, description='URL or base64 encoding of the subject reference image.'
|
||||
)
|
||||
mask: Optional[str] = Field(
|
||||
None,
|
||||
description='URL or base64 encoding of the mask for the subject reference image.',
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationRequest(BaseModel):
|
||||
callback_url: Optional[str] = Field(
|
||||
None,
|
||||
description='Optional. URL to receive real-time status updates about the video generation task.',
|
||||
)
|
||||
first_frame_image: Optional[str] = Field(
|
||||
None,
|
||||
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||
)
|
||||
model: MiniMaxModel = Field(
|
||||
...,
|
||||
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||
)
|
||||
prompt: Optional[str] = Field(
|
||||
None,
|
||||
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
|
||||
max_length=2000,
|
||||
)
|
||||
prompt_optimizer: Optional[bool] = Field(
|
||||
True,
|
||||
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
|
||||
)
|
||||
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
|
||||
None,
|
||||
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||
)
|
||||
duration: Optional[int] = Field(
|
||||
None,
|
||||
description="The length of the output video in seconds."
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||
)
|
||||
|
||||
|
||||
class MinimaxVideoGenerationResponse(BaseModel):
|
||||
base_resp: MinimaxBaseResponse
|
||||
task_id: str = Field(
|
||||
..., description='The task ID for the asynchronous video generation task.'
|
||||
)
|
||||
@ -1,13 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api_nodes.apis import (
|
||||
TripoModelVersion,
|
||||
TripoTextureQuality,
|
||||
)
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
v1_4_20240625 = 'v1.4-20240625'
|
||||
|
||||
|
||||
class TripoTextureQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
|
||||
|
||||
class TripoStyle(str, Enum):
|
||||
PERSON_TO_CARTOON = "person:person2cartoon"
|
||||
ANIMAL_VENOM = "animal:venom"
|
||||
|
||||
111
comfy_api_nodes/apis/veo_api.py
Normal file
111
comfy_api_nodes/apis/veo_api.py
Normal file
@ -0,0 +1,111 @@
|
||||
from typing import Optional, Union
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Image2(BaseModel):
|
||||
bytesBase64Encoded: str
|
||||
gcsUri: Optional[str] = None
|
||||
mimeType: Optional[str] = None
|
||||
|
||||
|
||||
class Image3(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = None
|
||||
gcsUri: str
|
||||
mimeType: Optional[str] = None
|
||||
|
||||
|
||||
class Instance1(BaseModel):
|
||||
image: Optional[Union[Image2, Image3]] = Field(
|
||||
None, description='Optional image to guide video generation'
|
||||
)
|
||||
prompt: str = Field(..., description='Text description of the video')
|
||||
|
||||
|
||||
class PersonGeneration1(str, Enum):
|
||||
ALLOW = 'ALLOW'
|
||||
BLOCK = 'BLOCK'
|
||||
|
||||
|
||||
class Parameters1(BaseModel):
|
||||
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
|
||||
durationSeconds: Optional[int] = None
|
||||
enhancePrompt: Optional[bool] = None
|
||||
generateAudio: Optional[bool] = Field(
|
||||
None,
|
||||
description='Generate audio for the video. Only supported by veo 3 models.',
|
||||
)
|
||||
negativePrompt: Optional[str] = None
|
||||
personGeneration: Optional[PersonGeneration1] = None
|
||||
sampleCount: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
storageUri: Optional[str] = Field(
|
||||
None, description='Optional Cloud Storage URI to upload the video'
|
||||
)
|
||||
|
||||
|
||||
class VeoGenVidRequest(BaseModel):
|
||||
instances: Optional[list[Instance1]] = None
|
||||
parameters: Optional[Parameters1] = None
|
||||
|
||||
|
||||
class VeoGenVidResponse(BaseModel):
|
||||
name: str = Field(
|
||||
...,
|
||||
description='Operation resource name',
|
||||
examples=[
|
||||
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class VeoGenVidPollRequest(BaseModel):
|
||||
operationName: str = Field(
|
||||
...,
|
||||
description='Full operation name (from predict response)',
|
||||
examples=[
|
||||
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class Video(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = Field(
|
||||
None, description='Base64-encoded video content'
|
||||
)
|
||||
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
|
||||
mimeType: Optional[str] = Field(None, description='Video MIME type')
|
||||
|
||||
|
||||
class Error1(BaseModel):
|
||||
code: Optional[int] = Field(None, description='Error code')
|
||||
message: Optional[str] = Field(None, description='Error message')
|
||||
|
||||
|
||||
class Response1(BaseModel):
|
||||
field_type: Optional[str] = Field(
|
||||
None,
|
||||
alias='@type',
|
||||
examples=[
|
||||
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
|
||||
],
|
||||
)
|
||||
raiMediaFilteredCount: Optional[int] = Field(
|
||||
None, description='Count of media filtered by responsible AI policies'
|
||||
)
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
done: Optional[bool] = None
|
||||
error: Optional[Error1] = Field(
|
||||
None, description='Error details if operation failed'
|
||||
)
|
||||
name: Optional[str] = None
|
||||
response: Optional[Response1] = Field(
|
||||
None, description='The actual prediction response if done is true'
|
||||
)
|
||||
@ -1,146 +1,46 @@
|
||||
import asyncio
|
||||
import io
|
||||
from inspect import cleandoc
|
||||
from typing import Union, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLStatus,
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxCannyImageRequest,
|
||||
BFLFluxDepthImageRequest,
|
||||
BFLFluxProGenerateRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxProGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxStatusResponse,
|
||||
BFLStatus,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
validate_aspect_ratio,
|
||||
process_image_response,
|
||||
download_url_to_image_tensor,
|
||||
poll_op,
|
||||
resize_mask_to_image,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_aspect_ratio_string,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import aiohttp
|
||||
import torch
|
||||
import base64
|
||||
import time
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: torch.Tensor):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = torch.cat([mask]*3, dim=-1)
|
||||
mask = torch.cat([mask] * 3, dim=-1)
|
||||
return mask
|
||||
|
||||
|
||||
async def handle_bfl_synchronous_operation(
|
||||
operation: SynchronousOperation,
|
||||
timeout_bfl_calls=360,
|
||||
node_id: Union[str, None] = None,
|
||||
):
|
||||
response_api: BFLFluxProGenerateResponse = await operation.execute()
|
||||
return await _poll_until_generated(
|
||||
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
|
||||
)
|
||||
|
||||
|
||||
async def _poll_until_generated(
|
||||
polling_url: str, timeout=360, node_id: Union[str, None] = None
|
||||
):
|
||||
# used bfl-comfy-nodes to verify code implementation:
|
||||
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
|
||||
start_time = time.time()
|
||||
retries_404 = 0
|
||||
max_retries_404 = 5
|
||||
retry_404_seconds = 2
|
||||
retry_202_seconds = 2
|
||||
retry_pending_seconds = 1
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
|
||||
while True:
|
||||
if node_id:
|
||||
time_elapsed = time.time() - start_time
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generating ({time_elapsed:.0f}s)", node_id
|
||||
)
|
||||
|
||||
async with session.get(polling_url) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if result["status"] == BFLStatus.ready:
|
||||
img_url = result["result"]["sample"]
|
||||
if node_id:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Result URL: {img_url}", node_id
|
||||
)
|
||||
async with session.get(img_url) as img_resp:
|
||||
return process_image_response(await img_resp.content.read())
|
||||
elif result["status"] in [
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
]:
|
||||
status = result["status"]
|
||||
raise Exception(
|
||||
f"BFL API did not return an image due to: {status}."
|
||||
)
|
||||
elif result["status"] == BFLStatus.error:
|
||||
raise Exception(f"BFL API encountered an error: {result}.")
|
||||
elif result["status"] == BFLStatus.pending:
|
||||
await asyncio.sleep(retry_pending_seconds)
|
||||
continue
|
||||
elif response.status == 404:
|
||||
if retries_404 < max_retries_404:
|
||||
retries_404 += 1
|
||||
await asyncio.sleep(retry_404_seconds)
|
||||
continue
|
||||
raise Exception(
|
||||
f"BFL API could not find task after {max_retries_404} tries."
|
||||
)
|
||||
elif response.status == 202:
|
||||
await asyncio.sleep(retry_202_seconds)
|
||||
elif time.time() - start_time > timeout:
|
||||
raise Exception(
|
||||
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
|
||||
)
|
||||
else:
|
||||
raise Exception(f"BFL API encountered an error: {response.json()}")
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
|
||||
# remove batch dimension if present
|
||||
if len(scaled_image.shape) > 3:
|
||||
scaled_image = scaled_image[0]
|
||||
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
img.save(img_byte_arr, format="PNG")
|
||||
return base64.b64encode(img_byte_arr.getvalue()).decode()
|
||||
|
||||
|
||||
class FluxProUltraImageNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
@ -158,7 +58,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -203,16 +105,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, aspect_ratio: str):
|
||||
try:
|
||||
validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
@ -220,49 +113,44 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
cls,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
prompt_upsampling=False,
|
||||
raw=False,
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
image_prompt_strength=0.1,
|
||||
prompt_upsampling: bool = False,
|
||||
raw: bool = False,
|
||||
seed: int = 0,
|
||||
image_prompt: Optional[torch.Tensor] = None,
|
||||
image_prompt_strength: float = 0.1,
|
||||
) -> IO.NodeOutput:
|
||||
if image_prompt is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1-ultra/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProUltraGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxProUltraGenerateRequest(
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxProUltraGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
seed=seed,
|
||||
aspect_ratio=validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
),
|
||||
aspect_ratio=aspect_ratio,
|
||||
raw=raw,
|
||||
image_prompt=(
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
),
|
||||
image_prompt_strength=(
|
||||
None if image_prompt is None else round(image_prompt_strength, 2)
|
||||
),
|
||||
image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
|
||||
image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxKontextProImageNode(IO.ComfyNode):
|
||||
@ -270,11 +158,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
MINIMUM_RATIO = 1 / 4
|
||||
MAXIMUM_RATIO = 4 / 1
|
||||
MINIMUM_RATIO_STR = "1:4"
|
||||
MAXIMUM_RATIO_STR = "4:1"
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
@ -347,46 +230,43 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: Optional[torch.Tensor]=None,
|
||||
input_image: Optional[torch.Tensor] = None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
) -> IO.NodeOutput:
|
||||
aspect_ratio = validate_aspect_ratio(
|
||||
aspect_ratio,
|
||||
minimum_ratio=cls.MINIMUM_RATIO,
|
||||
maximum_ratio=cls.MAXIMUM_RATIO,
|
||||
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
|
||||
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
|
||||
)
|
||||
validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
|
||||
if input_image is None:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=cls.BFL_PATH,
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxKontextProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxKontextProGenerateRequest(
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=cls.BFL_PATH, method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxKontextProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
guidance=round(guidance, 1),
|
||||
steps=steps,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio,
|
||||
input_image=(
|
||||
input_image
|
||||
if input_image is None
|
||||
else convert_image_to_base64(input_image)
|
||||
)
|
||||
input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)),
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
@ -422,7 +302,9 @@ class FluxProImageNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
@ -481,20 +363,15 @@ class FluxProImageNode(IO.ComfyNode):
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
) -> IO.NodeOutput:
|
||||
image_prompt = (
|
||||
image_prompt
|
||||
if image_prompt is None
|
||||
else convert_image_to_base64(image_prompt)
|
||||
)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxProGenerateRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
method="POST",
|
||||
),
|
||||
request=BFLFluxProGenerateRequest(
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
width=width,
|
||||
@ -502,13 +379,23 @@ class FluxProImageNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxProExpandNode(IO.ComfyNode):
|
||||
@ -534,7 +421,9 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"top",
|
||||
@ -610,16 +499,11 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
guidance: float,
|
||||
seed=0,
|
||||
) -> IO.NodeOutput:
|
||||
image = convert_image_to_base64(image)
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-expand/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxExpandImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxExpandImageRequest(
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxExpandImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
top=top,
|
||||
@ -629,16 +513,25 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
image=tensor_to_base64_string(image),
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxProFillNode(IO.ComfyNode):
|
||||
@ -665,7 +558,9 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance",
|
||||
@ -712,272 +607,37 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
# prepare mask
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = convert_image_to_base64(convert_mask_to_image(mask))
|
||||
# make sure image will have alpha channel removed
|
||||
image = convert_image_to_base64(image[:, :, :, :3])
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-fill/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxFillImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxFillImageRequest(
|
||||
mask = tensor_to_base64_string(convert_mask_to_image(mask))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxFillImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
image=image,
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
mask=mask,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
|
||||
|
||||
class FluxProCannyNode(IO.ComfyNode):
|
||||
"""
|
||||
Generate image using a control image (canny).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxProCannyNode",
|
||||
display_name="Flux.1 Canny Control Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("control_image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"canny_low_threshold",
|
||||
default=0.1,
|
||||
min=0.01,
|
||||
max=0.99,
|
||||
step=0.01,
|
||||
tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"canny_high_threshold",
|
||||
default=0.4,
|
||||
min=0.01,
|
||||
max=0.99,
|
||||
step=0.01,
|
||||
tooltip="High threshold for Canny edge detection; ignored if skip_processing is True",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"skip_preprocessing",
|
||||
default=False,
|
||||
tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance",
|
||||
default=30,
|
||||
min=1,
|
||||
max=100,
|
||||
tooltip="Guidance strength for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=50,
|
||||
min=15,
|
||||
max=50,
|
||||
tooltip="Number of steps for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
queued_statuses=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
canny_low_threshold: float,
|
||||
canny_high_threshold: float,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
) -> IO.NodeOutput:
|
||||
control_image = convert_image_to_base64(control_image[:, :, :, :3])
|
||||
preprocessed_image = None
|
||||
|
||||
# scale canny threshold between 0-500, to match BFL's API
|
||||
def scale_value(value: float, min_val=0, max_val=500):
|
||||
return min_val + value * (max_val - min_val)
|
||||
canny_low_threshold = int(round(scale_value(canny_low_threshold)))
|
||||
canny_high_threshold = int(round(scale_value(canny_high_threshold)))
|
||||
|
||||
|
||||
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
canny_low_threshold = None
|
||||
canny_high_threshold = None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-canny/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxCannyImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxCannyImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
canny_low_threshold=canny_low_threshold,
|
||||
canny_high_threshold=canny_high_threshold,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
|
||||
|
||||
class FluxProDepthNode(IO.ComfyNode):
|
||||
"""
|
||||
Generate image using a control image (depth).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxProDepthNode",
|
||||
display_name="Flux.1 Depth Control Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("control_image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"skip_preprocessing",
|
||||
default=False,
|
||||
tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"guidance",
|
||||
default=15,
|
||||
min=1,
|
||||
max=100,
|
||||
tooltip="Guidance strength for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"steps",
|
||||
default=50,
|
||||
min=15,
|
||||
max=50,
|
||||
tooltip="Number of steps for the image generation process",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
control_image: torch.Tensor,
|
||||
prompt: str,
|
||||
prompt_upsampling: bool,
|
||||
skip_preprocessing: bool,
|
||||
steps: int,
|
||||
guidance: float,
|
||||
seed=0,
|
||||
) -> IO.NodeOutput:
|
||||
control_image = convert_image_to_base64(control_image[:,:,:,:3])
|
||||
preprocessed_image = None
|
||||
|
||||
if skip_preprocessing:
|
||||
preprocessed_image = control_image
|
||||
control_image = None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.0-depth/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=BFLFluxDepthImageRequest,
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
),
|
||||
request=BFLFluxDepthImageRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
steps=steps,
|
||||
guidance=guidance,
|
||||
seed=seed,
|
||||
control_image=control_image,
|
||||
preprocessed_image=preprocessed_image,
|
||||
),
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
|
||||
return IO.NodeOutput(output_image)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class BFLExtension(ComfyExtension):
|
||||
@ -990,8 +650,6 @@ class BFLExtension(ComfyExtension):
|
||||
FluxKontextMaxImageNode,
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
FluxProCannyNode,
|
||||
FluxProDepthNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,35 +1,27 @@
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
from typing_extensions import override
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_image_aspect_ratio_range,
|
||||
get_number_of_images,
|
||||
validate_image_dimensions,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
EmptyRequest,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
T,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
|
||||
|
||||
# Long-running tasks endpoints(e.g., video)
|
||||
@ -46,13 +38,14 @@ class Image2ImageModelName(str, Enum):
|
||||
|
||||
|
||||
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
|
||||
|
||||
|
||||
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
|
||||
|
||||
|
||||
@ -208,35 +201,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
task_id: str,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
completed_statuses=[
|
||||
"succeeded",
|
||||
],
|
||||
failed_statuses=[
|
||||
"cancelled",
|
||||
"failed",
|
||||
],
|
||||
status_extractor=lambda response: response.status,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
class ByteDanceImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -303,7 +267,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image",
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -341,8 +305,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
w, h = width, height
|
||||
if not (512 <= w <= 2048) or not (512 <= h <= 2048):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. "
|
||||
"Both width and height must be between 512 and 2048 pixels."
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels."
|
||||
)
|
||||
|
||||
payload = Text2ImageTaskCreationRequest(
|
||||
@ -353,20 +316,12 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
guidance_scale=guidance_scale,
|
||||
watermark=watermark,
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_IMAGE_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Text2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
|
||||
|
||||
@ -420,7 +375,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image",
|
||||
tooltip='Whether to add an "AI generated" watermark to the image',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -448,17 +403,8 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
source_url = (await upload_images_to_comfyapi(
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
))[0]
|
||||
validate_image_aspect_ratio(image, (1, 3), (3, 1))
|
||||
source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
@ -467,16 +413,12 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
guidance_scale=guidance_scale,
|
||||
watermark=watermark,
|
||||
)
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_IMAGE_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Image2ImageTaskCreationRequest,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
|
||||
|
||||
@ -504,7 +446,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Input image(s) for image-to-image generation. "
|
||||
"List of 1-10 images for single or multi-reference generation.",
|
||||
"List of 1-10 images for single or multi-reference generation.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@ -534,9 +476,9 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
"sequential_image_generation",
|
||||
options=["disabled", "auto"],
|
||||
tooltip="Group image generation mode. "
|
||||
"'disabled' generates a single image. "
|
||||
"'auto' lets the model decide whether to generate multiple related images "
|
||||
"(e.g., story scenes, character variations).",
|
||||
"'disabled' generates a single image. "
|
||||
"'auto' lets the model decide whether to generate multiple related images "
|
||||
"(e.g., story scenes, character variations).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -547,7 +489,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Maximum number of images to generate when sequential_image_generation='auto'. "
|
||||
"Total images (input + generated) cannot exceed 15.",
|
||||
"Total images (input + generated) cannot exceed 15.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -564,7 +506,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the image.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the image.',
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
@ -611,8 +553,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
w, h = width, height
|
||||
if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. "
|
||||
"Both width and height must be between 1024 and 4096 pixels."
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
@ -621,41 +562,31 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
raise ValueError(
|
||||
"The maximum number of generated images plus the number of reference images cannot exceed 15."
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
reference_images_urls = []
|
||||
if n_input_images:
|
||||
for i in image:
|
||||
validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
|
||||
reference_images_urls = (await upload_images_to_comfyapi(
|
||||
validate_image_aspect_ratio(i, (1, 3), (3, 1))
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=n_input_images,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
))
|
||||
payload = Seedream4TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
image=reference_images_urls,
|
||||
size=f"{w}x{h}",
|
||||
seed=seed,
|
||||
sequential_image_generation=sequential_image_generation,
|
||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||
watermark=watermark,
|
||||
)
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_IMAGE_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=Seedream4TaskCreationRequest,
|
||||
response_model=ImageTaskCreationResponse,
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
|
||||
response_model=ImageTaskCreationResponse,
|
||||
data=Seedream4TaskCreationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
image=reference_images_urls,
|
||||
size=f"{w}x{h}",
|
||||
seed=seed,
|
||||
sequential_image_generation=sequential_image_generation,
|
||||
sequential_image_generation_options=Seedream4Options(max_images=max_images),
|
||||
watermark=watermark,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
)
|
||||
if len(response.data) == 1:
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
|
||||
urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
|
||||
@ -719,13 +650,13 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
"camera_fixed",
|
||||
default=False,
|
||||
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -764,19 +695,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
f"--camerafixed {str(camera_fixed).lower()} "
|
||||
f"--watermark {str(watermark).lower()}"
|
||||
)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
return await process_video_task(
|
||||
request_model=Text2VideoTaskCreationRequest,
|
||||
payload=Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[TaskTextContent(text=prompt)],
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@ -840,13 +761,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
"camera_fixed",
|
||||
default=False,
|
||||
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -877,15 +798,9 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0]
|
||||
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||
prompt = (
|
||||
f"{prompt} "
|
||||
f"--resolution {resolution} "
|
||||
@ -897,13 +812,11 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
return await process_video_task(
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))],
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@ -971,13 +884,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
"camera_fixed",
|
||||
default=False,
|
||||
tooltip="Specifies whether to fix the camera. The platform appends an instruction "
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
"to fix the camera to your prompt, but does not guarantee the actual effect.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -1010,18 +923,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
|
||||
for i in (first_frame, last_frame):
|
||||
validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image_tensor_pair_to_batch(first_frame, last_frame),
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
prompt = (
|
||||
@ -1035,7 +943,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
return await process_video_task(
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=[
|
||||
@ -1044,8 +952,6 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"),
|
||||
],
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
@ -1108,7 +1014,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the video.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the video.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -1139,17 +1045,9 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
|
||||
for image in images:
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
image_urls = await upload_images_to_comfyapi(
|
||||
images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs
|
||||
)
|
||||
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
|
||||
image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
|
||||
prompt = (
|
||||
f"{prompt} "
|
||||
f"--resolution {resolution} "
|
||||
@ -1160,42 +1058,32 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
)
|
||||
x = [
|
||||
TaskTextContent(text=prompt),
|
||||
*[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls]
|
||||
*[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls],
|
||||
]
|
||||
return await process_video_task(
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
payload=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
content=x,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||
)
|
||||
|
||||
|
||||
async def process_video_task(
|
||||
request_model: Type[T],
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
auth_kwargs: dict,
|
||||
node_id: str,
|
||||
estimated_duration: Optional[int],
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=BYTEPLUS_TASK_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_model,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
initial_response.id,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||
data=payload,
|
||||
response_model=TaskCreationResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
||||
|
||||
@ -1221,5 +1109,6 @@ class ByteDanceExtension(ComfyExtension):
|
||||
ByteDanceImageReferenceNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ByteDanceExtension:
|
||||
return ByteDanceExtension()
|
||||
|
||||
@ -2,45 +2,47 @@
|
||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import base64
|
||||
from io import BytesIO
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Optional, Literal
|
||||
from io import BytesIO
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
||||
from server import PromptServer
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api_nodes.apis import (
|
||||
GeminiContent,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiInlineData,
|
||||
GeminiPart,
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
)
|
||||
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiImageConfig,
|
||||
GeminiImageGenerateContentRequest,
|
||||
GeminiImageGenerationConfig,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
validate_string,
|
||||
audio_to_base64_string,
|
||||
video_to_base64_string,
|
||||
tensor_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from comfy_api.util import VideoContainer, VideoCodec
|
||||
|
||||
from server import PromptServer
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
@ -66,50 +68,6 @@ class GeminiImageModel(str, Enum):
|
||||
gemini_2_5_flash_image = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
def get_gemini_endpoint(
|
||||
model: GeminiModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
|
||||
def get_gemini_image_endpoint(
|
||||
model: GeminiImageModel,
|
||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
||||
"""
|
||||
Get the API endpoint for a given Gemini model.
|
||||
|
||||
Args:
|
||||
model: The Gemini model to use, either as enum or string value.
|
||||
|
||||
Returns:
|
||||
ApiEndpoint configured for the specific Gemini model.
|
||||
"""
|
||||
if isinstance(model, str):
|
||||
model = GeminiImageModel(model)
|
||||
return ApiEndpoint(
|
||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
||||
method=HttpMethod.POST,
|
||||
request_model=GeminiImageGenerateContentRequest,
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
|
||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
@ -122,9 +80,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(
|
||||
image_input[image_index].unsqueeze(0)
|
||||
)
|
||||
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@ -136,37 +92,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
return image_parts
|
||||
|
||||
|
||||
def create_text_part(text: str) -> GeminiPart:
|
||||
"""
|
||||
Create a text part for the Gemini API request.
|
||||
|
||||
Args:
|
||||
text: The text content to include in the request.
|
||||
|
||||
Returns:
|
||||
A GeminiPart object with the text content.
|
||||
"""
|
||||
return GeminiPart(text=text)
|
||||
|
||||
|
||||
def get_parts_from_response(
|
||||
response: GeminiGenerateContentResponse
|
||||
) -> list[GeminiPart]:
|
||||
"""
|
||||
Extract all parts from the Gemini API response.
|
||||
|
||||
Args:
|
||||
response: The API response from Gemini.
|
||||
|
||||
Returns:
|
||||
List of response parts from the first candidate.
|
||||
"""
|
||||
return response.candidates[0].content.parts
|
||||
|
||||
|
||||
def get_parts_by_type(
|
||||
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
||||
) -> list[GeminiPart]:
|
||||
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||
"""
|
||||
Filter response parts by their type.
|
||||
|
||||
@ -178,14 +104,10 @@ def get_parts_by_type(
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
parts = []
|
||||
for part in get_parts_from_response(response):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
parts.append(part)
|
||||
elif (
|
||||
hasattr(part, "inlineData")
|
||||
and part.inlineData
|
||||
and part.inlineData.mimeType == part_type
|
||||
):
|
||||
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
|
||||
parts.append(part)
|
||||
# Skip parts that don't match the requested type
|
||||
return parts
|
||||
@ -213,11 +135,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
|
||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||
image_tensors.append(returned_image)
|
||||
if len(image_tensors) == 0:
|
||||
return torch.zeros((1,1024,1024,4))
|
||||
return torch.zeros((1, 1024, 1024, 4))
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
class GeminiNode(ComfyNodeABC):
|
||||
class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
|
||||
@ -228,96 +150,79 @@ class GeminiNode(ComfyNodeABC):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
|
||||
},
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNode",
|
||||
display_name="Google Gemini",
|
||||
category="api node/text/Gemini",
|
||||
description="Generate text responses with Google's Gemini AI model. "
|
||||
"You can provide multiple types of inputs (text, images, audio, video) "
|
||||
"as context for generating more relevant and meaningful responses.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text inputs to the model, used to generate a response. "
|
||||
"You can include detailed instructions, questions, or context for the model.",
|
||||
),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiModel],
|
||||
"default": GeminiModel.gemini_2_5_pro.value,
|
||||
},
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=GeminiModel,
|
||||
default=GeminiModel.gemini_2_5_pro,
|
||||
tooltip="The Gemini model to use for generating responses.",
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
optional=True,
|
||||
tooltip="Optional image(s) to use as context for the model. "
|
||||
"To include multiple images, you can use the Batch Images node.",
|
||||
),
|
||||
"audio": (
|
||||
IO.AUDIO,
|
||||
{
|
||||
"tooltip": "Optional audio to use as context for the model.",
|
||||
"default": None,
|
||||
},
|
||||
IO.Audio.Input(
|
||||
"audio",
|
||||
optional=True,
|
||||
tooltip="Optional audio to use as context for the model.",
|
||||
),
|
||||
"video": (
|
||||
IO.VIDEO,
|
||||
{
|
||||
"tooltip": "Optional video to use as context for the model.",
|
||||
"default": None,
|
||||
},
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
optional=True,
|
||||
tooltip="Optional video to use as context for the model.",
|
||||
),
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
|
||||
RETURN_TYPES = ("STRING",)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
API_NODE = True
|
||||
|
||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert video input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
video_input: Video tensor from ComfyUI.
|
||||
**kwargs: Additional arguments to pass to the conversion function.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded video.
|
||||
"""
|
||||
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input,
|
||||
container_format=VideoContainer.MP4,
|
||||
codec=VideoCodec.H264
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@ -327,7 +232,8 @@ class GeminiNode(ComfyNodeABC):
|
||||
)
|
||||
]
|
||||
|
||||
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
|
||||
@classmethod
|
||||
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert audio input to Gemini API compatible parts.
|
||||
|
||||
@ -340,10 +246,10 @@ class GeminiNode(ComfyNodeABC):
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = {
|
||||
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
"sample_rate": audio_input["sample_rate"],
|
||||
}
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
@ -360,38 +266,38 @@ class GeminiNode(ComfyNodeABC):
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: GeminiModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
audio: Optional[IO.AUDIO] = None,
|
||||
video: Optional[IO.VIDEO] = None,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
video: Optional[Input.Video] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str]:
|
||||
# Validate inputs
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
# Create parts list with text prompt as the first part
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
if audio is not None:
|
||||
parts.extend(self.create_audio_parts(audio))
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
if video is not None:
|
||||
parts.extend(self.create_video_parts(video))
|
||||
parts.extend(cls.create_video_parts(video))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_endpoint(model),
|
||||
request=GeminiGenerateContentRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
@ -399,15 +305,15 @@ class GeminiNode(ComfyNodeABC):
|
||||
)
|
||||
]
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
# Get result output
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
@ -427,10 +333,10 @@ class GeminiNode(ComfyNodeABC):
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return (output_text or "Empty response from Gemini model...",)
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
class GeminiInputFiles(ComfyNodeABC):
|
||||
class GeminiInputFiles(IO.ComfyNode):
|
||||
"""
|
||||
Loads and formats input files for use with the Gemini API.
|
||||
|
||||
@ -441,7 +347,7 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
def define_schema(cls):
|
||||
"""
|
||||
For details about the supported file input types, see:
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
@ -456,39 +362,37 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
]
|
||||
input_files = sorted(input_files, key=lambda x: x.name)
|
||||
input_files = [f.name for f in input_files]
|
||||
return {
|
||||
"required": {
|
||||
"file": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||
"options": input_files,
|
||||
"default": input_files[0] if input_files else None,
|
||||
},
|
||||
return IO.Schema(
|
||||
node_id="GeminiInputFiles",
|
||||
display_name="Gemini Input Files",
|
||||
category="api node/text/Gemini",
|
||||
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
|
||||
"The files will be read by the Gemini model when generating a response. "
|
||||
"The contents of the text file count toward the token limit. "
|
||||
"🛈 TIP: Can be chained together with other Gemini Input File nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"file",
|
||||
options=input_files,
|
||||
default=input_files[0] if input_files else None,
|
||||
tooltip="Input files to include as context for the model. "
|
||||
"Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"GEMINI_INPUT_FILES": (
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
||||
"default": None,
|
||||
},
|
||||
optional=True,
|
||||
tooltip="An optional additional file(s) to batch together with the file loaded from this node. "
|
||||
"Allows chaining of input files so that a single message can include multiple input files.",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
|
||||
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
|
||||
FUNCTION = "prepare_files"
|
||||
CATEGORY = "api node/text/Gemini"
|
||||
|
||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
||||
mime_type = (
|
||||
GeminiMimeType.application_pdf
|
||||
if file_path.endswith(".pdf")
|
||||
else GeminiMimeType.text_plain
|
||||
],
|
||||
outputs=[
|
||||
IO.Custom("GEMINI_INPUT_FILES").Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_file_part(cls, file_path: str) -> GeminiPart:
|
||||
mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain
|
||||
# Use base64 string directly, not the data URI
|
||||
with open(file_path, "rb") as f:
|
||||
file_content = f.read()
|
||||
@ -501,120 +405,95 @@ class GeminiInputFiles(ComfyNodeABC):
|
||||
)
|
||||
)
|
||||
|
||||
def prepare_files(
|
||||
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
|
||||
) -> tuple[list[GeminiPart]]:
|
||||
"""
|
||||
Loads and formats input files for Gemini API.
|
||||
"""
|
||||
file_path = folder_paths.get_annotated_filepath(file)
|
||||
input_file_content = self.create_file_part(file_path)
|
||||
files = [input_file_content] + GEMINI_INPUT_FILES
|
||||
return (files,)
|
||||
|
||||
|
||||
class GeminiImage(ComfyNodeABC):
|
||||
"""
|
||||
Node to generate text and image responses from a Gemini model.
|
||||
|
||||
This node allows users to interact with Google's Gemini AI models, providing
|
||||
multimodal inputs (text, images, files) to generate coherent
|
||||
text and image responses. The node works with the latest Gemini models, handling the
|
||||
API communication and response parsing.
|
||||
"""
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
||||
return {
|
||||
"required": {
|
||||
"prompt": (
|
||||
IO.STRING,
|
||||
{
|
||||
"multiline": True,
|
||||
"default": "",
|
||||
"tooltip": "Text prompt for generation",
|
||||
},
|
||||
),
|
||||
"model": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "The Gemini model to use for generating responses.",
|
||||
"options": [model.value for model in GeminiImageModel],
|
||||
"default": GeminiImageModel.gemini_2_5_flash_image.value,
|
||||
},
|
||||
),
|
||||
"seed": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 42,
|
||||
"min": 0,
|
||||
"max": 0xFFFFFFFFFFFFFFFF,
|
||||
"control_after_generate": True,
|
||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"optional": {
|
||||
"images": (
|
||||
IO.IMAGE,
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
||||
},
|
||||
),
|
||||
"files": (
|
||||
"GEMINI_INPUT_FILES",
|
||||
{
|
||||
"default": None,
|
||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
},
|
||||
),
|
||||
# TODO: later we can add this parameter later
|
||||
# "n": (
|
||||
# IO.INT,
|
||||
# {
|
||||
# "default": 1,
|
||||
# "min": 1,
|
||||
# "max": 8,
|
||||
# "step": 1,
|
||||
# "display": "number",
|
||||
# "tooltip": "How many images to generate",
|
||||
# },
|
||||
# ),
|
||||
"aspect_ratio": (
|
||||
IO.COMBO,
|
||||
{
|
||||
"tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.",
|
||||
"options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
"default": "auto",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
|
||||
"""Loads and formats input files for Gemini API."""
|
||||
if GEMINI_INPUT_FILES is None:
|
||||
GEMINI_INPUT_FILES = []
|
||||
file_path = folder_paths.get_annotated_filepath(file)
|
||||
input_file_content = cls.create_file_part(file_path)
|
||||
return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES)
|
||||
|
||||
RETURN_TYPES = (IO.IMAGE, IO.STRING)
|
||||
FUNCTION = "api_call"
|
||||
CATEGORY = "api node/image/Gemini"
|
||||
DESCRIPTION = "Edit images synchronously via Google API."
|
||||
API_NODE = True
|
||||
|
||||
async def api_call(
|
||||
self,
|
||||
class GeminiImage(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImageNode",
|
||||
display_name="Google Gemini Image",
|
||||
category="api node/image/Gemini",
|
||||
description="Edit images synchronously via Google API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt for generation",
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=GeminiImageModel,
|
||||
default=GeminiImageModel.gemini_2_5_flash_image,
|
||||
tooltip="The Gemini model to use for generating responses.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
optional=True,
|
||||
tooltip="Optional image(s) to use as context for the model. "
|
||||
"To include multiple images, you can use the Batch Images node.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
default="auto",
|
||||
tooltip="Defaults to matching the output image size to that of your input image, "
|
||||
"or otherwise generates 1:1 squares.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: GeminiImageModel,
|
||||
images: Optional[IO.IMAGE] = None,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
n=1,
|
||||
aspect_ratio: str = "auto",
|
||||
unique_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
|
||||
if not aspect_ratio:
|
||||
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
|
||||
@ -626,29 +505,27 @@ class GeminiImage(ComfyNodeABC):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
response = await SynchronousOperation(
|
||||
endpoint=get_gemini_image_endpoint(model),
|
||||
request=GeminiImageGenerateContentRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
GeminiContent(role="user", parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT","IMAGE"],
|
||||
responseModalities=["TEXT", "IMAGE"],
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
)
|
||||
),
|
||||
),
|
||||
auth_kwargs=kwargs,
|
||||
).execute()
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
)
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if unique_id and output_text:
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": unique_id,
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
@ -669,17 +546,18 @@ class GeminiImage(ComfyNodeABC):
|
||||
)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return (output_image, output_text,)
|
||||
return IO.NodeOutput(output_image, output_text)
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"GeminiNode": GeminiNode,
|
||||
"GeminiImageNode": GeminiImage,
|
||||
"GeminiInputFiles": GeminiInputFiles,
|
||||
}
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiImage,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"GeminiNode": "Google Gemini",
|
||||
"GeminiImageNode": "Google Gemini Image",
|
||||
"GeminiInputFiles": "Gemini Input Files",
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> GeminiExtension:
|
||||
return GeminiExtension()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from io import BytesIO
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_as_bytesio,
|
||||
resize_mask_to_image,
|
||||
sync_op,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
"Auto":"AUTO",
|
||||
@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
|
||||
|
||||
for image_url in image_urls:
|
||||
# Using functions from apinode_utils.py to handle downloading and processing
|
||||
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
||||
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
|
||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||
image_tensors.append(img_tensor)
|
||||
|
||||
@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
|
||||
return stacked_tensors
|
||||
|
||||
|
||||
def display_image_urls_on_node(image_urls, node_id):
|
||||
if node_id and image_urls:
|
||||
if len(image_urls) == 1:
|
||||
PromptServer.instance.send_progress_text(
|
||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
||||
)
|
||||
else:
|
||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
||||
)
|
||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
||||
|
||||
|
||||
class IdeogramV1(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
|
||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||
model = "V_1_TURBO" if turbo else "V_1"
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
num_images=num_images,
|
||||
seed=seed,
|
||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
|
||||
else:
|
||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/ideogram/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramGenerateRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=IdeogramGenerateRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramGenerateRequest(
|
||||
image_request=ImageRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
|
||||
seed=seed,
|
||||
aspect_ratio=final_aspect_ratio,
|
||||
resolution=final_resolution,
|
||||
magic_prompt_option=(
|
||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
||||
),
|
||||
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||
style_type=style_type if style_type != "NONE" else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
color_palette=color_palette if color_palette else None,
|
||||
)
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
|
||||
character_image=None,
|
||||
character_mask=None,
|
||||
):
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
if rendering_speed == "BALANCED": # for backward compatibility
|
||||
rendering_speed = "DEFAULT"
|
||||
|
||||
@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
|
||||
|
||||
# Check if both image and mask are provided for editing mode
|
||||
if image is not None and mask is not None:
|
||||
# Edit mode
|
||||
path = "/proxy/ideogram/ideogram-v3/edit"
|
||||
|
||||
# Process image and mask
|
||||
input_tensor = image.squeeze().cpu()
|
||||
# Resize mask to match image dimension
|
||||
@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
|
||||
if character_mask_binary:
|
||||
files["character_mask_binary"] = character_mask_binary
|
||||
|
||||
# Execute the operation for edit mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3EditRequest,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=edit_request,
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=edit_request,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
elif image is not None or mask is not None:
|
||||
# If only one of image or mask is provided, raise an error
|
||||
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||
else:
|
||||
# Generation mode
|
||||
path = "/proxy/ideogram/ideogram-v3/generate"
|
||||
|
||||
# Create generation request
|
||||
gen_request = IdeogramV3Request(
|
||||
prompt=prompt,
|
||||
@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
|
||||
if files:
|
||||
gen_request.style_type = "AUTO"
|
||||
|
||||
# Execute the operation for generation mode
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=path,
|
||||
method=HttpMethod.POST,
|
||||
request_model=IdeogramV3Request,
|
||||
response_model=IdeogramGenerateResponse,
|
||||
),
|
||||
request=gen_request,
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=gen_request,
|
||||
files=files if files else None,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
# Execute the operation and process response
|
||||
response = await operation.execute()
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
|
||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
|
||||
IdeogramV3,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> IdeogramExtension:
|
||||
return IdeogramExtension()
|
||||
|
||||
@ -5,8 +5,7 @@ For source of truth on the allowed permutations of request fields, please refere
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypeVar, Any
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, TypeVar
|
||||
import math
|
||||
import logging
|
||||
|
||||
@ -15,7 +14,6 @@ from typing_extensions import override
|
||||
import torch
|
||||
|
||||
from comfy_api_nodes.apis import (
|
||||
KlingTaskStatus,
|
||||
KlingCameraControl,
|
||||
KlingCameraConfig,
|
||||
KlingCameraControlType,
|
||||
@ -52,26 +50,20 @@ from comfy_api_nodes.apis import (
|
||||
KlingCharacterEffectModelName,
|
||||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
tensor_to_base64_string,
|
||||
download_url_to_video_output,
|
||||
upload_video_to_comfyapi,
|
||||
upload_audio_to_comfyapi,
|
||||
download_url_to_image_tensor,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
upload_audio_to_comfyapi,
|
||||
download_url_to_image_tensor,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
sync_op,
|
||||
ApiEndpoint,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
@ -214,34 +206,6 @@ VOICES_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
KlingTaskStatus.succeed.value,
|
||||
],
|
||||
failed_statuses=[KlingTaskStatus.failed.value],
|
||||
status_extractor=lambda response: (
|
||||
response.data.task_status.value
|
||||
if response.data and response.data.task_status
|
||||
else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
def is_valid_camera_control_configs(configs: list[float]) -> bool:
|
||||
"""Verifies that at least one camera control configuration is non-zero."""
|
||||
return any(not math.isclose(value, 0.0) for value in configs)
|
||||
@ -318,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
|
||||
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
|
||||
"""
|
||||
validate_image_dimensions(image, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
|
||||
validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))
|
||||
|
||||
|
||||
def get_video_from_response(response) -> KlingVideoResult:
|
||||
@ -377,8 +341,7 @@ async def image_result_to_node_output(
|
||||
|
||||
|
||||
async def execute_text2video(
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
cfg_scale: float,
|
||||
@ -389,14 +352,11 @@ async def execute_text2video(
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingText2VideoRequest,
|
||||
response_model=KlingText2VideoResponse,
|
||||
),
|
||||
request=KlingText2VideoRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=KlingText2VideoResponse,
|
||||
data=KlingText2VideoRequest(
|
||||
prompt=prompt if prompt else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
@ -406,24 +366,17 @@ async def execute_text2video(
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
|
||||
task_id = task_creation_response.data.task_id
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingText2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingText2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
node_id=node_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@ -432,8 +385,7 @@ async def execute_text2video(
|
||||
|
||||
|
||||
async def execute_image2video(
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
start_frame: torch.Tensor,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
@ -455,14 +407,11 @@ async def execute_image2video(
|
||||
if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value:
|
||||
model_mode = "pro" # October 5: currently "std" mode is not supported for this model
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingImage2VideoRequest,
|
||||
response_model=KlingImage2VideoResponse,
|
||||
),
|
||||
request=KlingImage2VideoRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
data=KlingImage2VideoRequest(
|
||||
model_name=KlingVideoGenModelName(model_name),
|
||||
image=tensor_to_base64_string(start_frame),
|
||||
image_tail=(
|
||||
@ -477,24 +426,17 @@ async def execute_image2video(
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
camera_control=camera_control,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=KlingImage2VideoRequest,
|
||||
response_model=KlingImage2VideoResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
node_id=node_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@ -503,8 +445,7 @@ async def execute_image2video(
|
||||
|
||||
|
||||
async def execute_video_effect(
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
dual_character: bool,
|
||||
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
|
||||
model_name: str,
|
||||
@ -530,35 +471,25 @@ async def execute_video_effect(
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_VIDEO_EFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingVideoEffectsRequest,
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
),
|
||||
request=KlingVideoEffectsRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"),
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
data=KlingVideoEffectsRequest(
|
||||
effect_scene=effect_scene,
|
||||
input=request_input_field,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"),
|
||||
response_model=KlingVideoEffectsResponse,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
|
||||
node_id=node_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@ -567,8 +498,7 @@ async def execute_video_effect(
|
||||
|
||||
|
||||
async def execute_lipsync(
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
video: VideoInput,
|
||||
audio: Optional[AudioInput] = None,
|
||||
voice_language: Optional[str] = None,
|
||||
@ -583,24 +513,21 @@ async def execute_lipsync(
|
||||
validate_video_duration(video, 2, 10)
|
||||
|
||||
# Upload video to Comfy API and get download URL
|
||||
video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs)
|
||||
video_url = await upload_video_to_comfyapi(cls, video)
|
||||
logging.info("Uploaded video to Comfy API. URL: %s", video_url)
|
||||
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs)
|
||||
audio_url = await upload_audio_to_comfyapi(cls, audio)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_LIP_SYNC,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingLipSyncRequest,
|
||||
response_model=KlingLipSyncResponse,
|
||||
),
|
||||
request=KlingLipSyncRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(PATH_LIP_SYNC, "POST"),
|
||||
response_model=KlingLipSyncResponse,
|
||||
data=KlingLipSyncRequest(
|
||||
input=KlingLipSyncInputObject(
|
||||
video_url=video_url,
|
||||
mode=model_mode,
|
||||
@ -612,24 +539,17 @@ async def execute_lipsync(
|
||||
voice_id=voice_id,
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_LIP_SYNC}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingLipSyncResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"),
|
||||
response_model=KlingLipSyncResponse,
|
||||
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
|
||||
node_id=node_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@ -807,11 +727,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
model_mode, duration, model_name = MODE_TEXT2VIDEO[mode]
|
||||
return await execute_text2video(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
cfg_scale=cfg_scale,
|
||||
@ -872,11 +788,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_text2video(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
model_name=KlingVideoGenModelName.kling_v1,
|
||||
cfg_scale=cfg_scale,
|
||||
model_mode=KlingVideoGenMode.std,
|
||||
@ -944,11 +856,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_image2video(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
start_frame=start_frame,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
@ -1017,11 +925,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
|
||||
camera_control: KlingCameraControl,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_image2video(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
model_name=KlingVideoGenModelName.kling_v1_5,
|
||||
start_frame=start_frame,
|
||||
cfg_scale=cfg_scale,
|
||||
@ -1097,11 +1001,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
mode, duration, model_name = MODE_START_END_FRAME[mode]
|
||||
return await execute_image2video(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
model_name=model_name,
|
||||
@ -1162,41 +1062,27 @@ class KlingVideoExtendNode(IO.ComfyNode):
|
||||
video_id: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_VIDEO_EXTEND,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingVideoExtendRequest,
|
||||
response_model=KlingVideoExtendResponse,
|
||||
),
|
||||
request=KlingVideoExtendRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"),
|
||||
response_model=KlingVideoExtendResponse,
|
||||
data=KlingVideoExtendRequest(
|
||||
prompt=prompt if prompt else None,
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
cfg_scale=cfg_scale,
|
||||
video_id=video_id,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_until_finished(
|
||||
auth,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_EXTEND}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVideoExtendResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"),
|
||||
response_model=KlingVideoExtendResponse,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
@ -1259,11 +1145,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
|
||||
duration: KlingVideoGenDuration,
|
||||
) -> IO.NodeOutput:
|
||||
video, _, duration = await execute_video_effect(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
dual_character=True,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@ -1324,11 +1206,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(
|
||||
*(
|
||||
await execute_video_effect(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
dual_character=False,
|
||||
effect_scene=effect_scene,
|
||||
model_name=model_name,
|
||||
@ -1379,11 +1257,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
voice_language: str,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_lipsync(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
video=video,
|
||||
audio=audio,
|
||||
voice_language=voice_language,
|
||||
@ -1445,11 +1319,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
voice_id, voice_language = VOICES_CONFIG[voice]
|
||||
return await execute_lipsync(
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
video=video,
|
||||
text=text,
|
||||
voice_language=voice_language,
|
||||
@ -1496,40 +1366,26 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
|
||||
cloth_image: torch.Tensor,
|
||||
model_name: KlingVirtualTryOnModelName,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_VIRTUAL_TRY_ON,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingVirtualTryOnRequest,
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
),
|
||||
request=KlingVirtualTryOnRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"),
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
data=KlingVirtualTryOnRequest(
|
||||
human_image=tensor_to_base64_string(human_image),
|
||||
cloth_image=tensor_to_base64_string(cloth_image),
|
||||
model_name=model_name,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_until_finished(
|
||||
auth,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
),
|
||||
result_url_extractor=get_images_urls_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"),
|
||||
response_model=KlingVirtualTryOnResponse,
|
||||
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
@ -1625,18 +1481,11 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
else:
|
||||
image = tensor_to_base64_string(image)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_GENERATIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=KlingImageGenerationsRequest,
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
),
|
||||
request=KlingImageGenerationsRequest(
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
data=KlingImageGenerationsRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
@ -1647,24 +1496,17 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
n=n,
|
||||
aspect_ratio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_until_finished(
|
||||
auth,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
),
|
||||
result_url_extractor=get_images_urls_from_response,
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"),
|
||||
response_model=KlingImageGenerationsResponse,
|
||||
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_image_result_response(final_response)
|
||||
|
||||
|
||||
199
comfy_api_nodes/nodes_ltxv.py
Normal file
199
comfy_api_nodes/nodes_ltxv.py
Normal file
@ -0,0 +1,199 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
sync_op_raw,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
MODELS_MAP = {
|
||||
"LTX-2 (Pro)": "ltx-2-pro",
|
||||
"LTX-2 (Fast)": "ltx-2-fast",
|
||||
}
|
||||
|
||||
|
||||
class ExecuteTaskRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||
raise ValueError(
|
||||
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||
)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="api node/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"1920x1080",
|
||||
"2560x1440",
|
||||
"3840x2160",
|
||||
],
|
||||
),
|
||||
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
resolution: str,
|
||||
fps: int = 25,
|
||||
generate_audio: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=10000)
|
||||
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||
raise ValueError(
|
||||
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||
)
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Currently only one input image is supported.")
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
|
||||
data=ExecuteTaskRequest(
|
||||
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
|
||||
prompt=prompt,
|
||||
model=MODELS_MAP[model],
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
fps=fps,
|
||||
generate_audio=generate_audio,
|
||||
),
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TextToVideoNode,
|
||||
ImageToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LtxvApiExtension:
|
||||
return LtxvApiExtension()
|
||||
@ -1,69 +1,51 @@
|
||||
from __future__ import annotations
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.luma_api import (
|
||||
LumaImageModel,
|
||||
LumaVideoModel,
|
||||
LumaVideoOutputResolution,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaAspectRatio,
|
||||
LumaState,
|
||||
LumaImageGenerationRequest,
|
||||
LumaGenerationRequest,
|
||||
LumaGeneration,
|
||||
LumaCharacterRef,
|
||||
LumaModifyImageRef,
|
||||
LumaConceptChain,
|
||||
LumaGeneration,
|
||||
LumaGenerationRequest,
|
||||
LumaImageGenerationRequest,
|
||||
LumaImageIdentity,
|
||||
LumaImageModel,
|
||||
LumaImageReference,
|
||||
LumaIO,
|
||||
LumaKeyframes,
|
||||
LumaModifyImageRef,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaImageReference,
|
||||
LumaKeyframes,
|
||||
LumaConceptChain,
|
||||
LumaIO,
|
||||
LumaVideoModel,
|
||||
LumaVideoModelOutputDuration,
|
||||
LumaVideoOutputResolution,
|
||||
get_luma_concepts,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
process_image_response,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
|
||||
LUMA_T2V_AVERAGE_DURATION = 105
|
||||
LUMA_I2V_AVERAGE_DURATION = 100
|
||||
|
||||
def image_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
|
||||
|
||||
def video_result_url_extractor(response: LumaGeneration):
|
||||
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
|
||||
|
||||
class LumaReferenceNode(IO.ComfyNode):
|
||||
"""
|
||||
Holds an image and weight for use with Luma Generate Image node.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
||||
) -> IO.NodeOutput:
|
||||
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
|
||||
if luma_ref is not None:
|
||||
luma_ref = luma_ref.clone()
|
||||
else:
|
||||
@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class LumaConceptsNode(IO.ComfyNode):
|
||||
"""
|
||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"concept1",
|
||||
@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class LumaImageGenerationNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
seed,
|
||||
style_image_weight: float,
|
||||
image_luma_ref: LumaReferenceChain = None,
|
||||
style_image: torch.Tensor = None,
|
||||
character_image: torch.Tensor = None,
|
||||
image_luma_ref: Optional[LumaReferenceChain] = None,
|
||||
style_image: Optional[torch.Tensor] = None,
|
||||
character_image: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# handle image_luma_ref
|
||||
api_image_ref = None
|
||||
if image_luma_ref is not None:
|
||||
api_image_ref = await cls._convert_luma_refs(
|
||||
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
|
||||
# handle style_luma_ref
|
||||
api_style_ref = None
|
||||
if style_image is not None:
|
||||
api_style_ref = await cls._convert_style_image(
|
||||
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
|
||||
# handle character_ref images
|
||||
character_ref = None
|
||||
if character_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
character_image, max_images=4, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
character_ref = LumaCharacterRef(
|
||||
identity0=LumaImageIdentity(images=download_urls)
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
|
||||
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
style_ref=api_style_ref,
|
||||
character_ref=character_ref,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return IO.NodeOutput(img)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||
|
||||
@classmethod
|
||||
async def _convert_luma_refs(
|
||||
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
|
||||
luma_urls = []
|
||||
ref_count = 0
|
||||
for ref in luma_ref.refs:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
|
||||
luma_urls.append(download_urls[0])
|
||||
ref_count += 1
|
||||
if ref_count >= max_refs:
|
||||
@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||
|
||||
@classmethod
|
||||
async def _convert_style_image(
|
||||
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
||||
):
|
||||
chain = LumaReferenceChain(
|
||||
first_ref=LumaReference(image=style_image, weight=weight)
|
||||
)
|
||||
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
||||
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
|
||||
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
|
||||
return await cls._convert_luma_refs(chain, max_refs=1)
|
||||
|
||||
|
||||
class LumaImageModifyNode(IO.ComfyNode):
|
||||
"""
|
||||
Modifies images synchronously based on prompt and aspect ratio.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="api node/image/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode):
|
||||
image_weight: float,
|
||||
seed,
|
||||
) -> IO.NodeOutput:
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# first, upload image
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
|
||||
image_url = download_urls[0]
|
||||
# next, make Luma call with download url provided
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations/image",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaImageGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaImageGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaImageGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
modify_image_ref=LumaModifyImageRef(
|
||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
||||
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=image_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.image) as img_response:
|
||||
img = process_image_response(await img_response.content.read())
|
||||
return IO.NodeOutput(img)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||
|
||||
|
||||
class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on prompt and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
duration: str,
|
||||
loop: bool,
|
||||
seed,
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
luma_concepts: Optional[LumaConceptChain] = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
resolution=resolution,
|
||||
@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
loop=loop,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||
|
||||
|
||||
class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on prompt, input images, and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="api node/video/Luma",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
"luma_concepts",
|
||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||
optional=True,
|
||||
)
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
luma_concepts: LumaConceptChain = None,
|
||||
) -> IO.NodeOutput:
|
||||
if first_image is None and last_image is None:
|
||||
raise Exception(
|
||||
"At least one of first_image and last_image requires an input."
|
||||
)
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
|
||||
raise Exception("At least one of first_image and last_image requires an input.")
|
||||
keyframes = await cls._convert_to_keyframes(first_image, last_image)
|
||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/luma/generations",
|
||||
method=HttpMethod.POST,
|
||||
request_model=LumaGenerationRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
request=LumaGenerationRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||
response_model=LumaGeneration,
|
||||
data=LumaGenerationRequest(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||
@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
keyframes=keyframes,
|
||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_api: LumaGeneration = await operation.execute()
|
||||
|
||||
if cls.hidden.unique_id:
|
||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/luma/generations/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=LumaGeneration,
|
||||
),
|
||||
completed_statuses=[LumaState.completed],
|
||||
failed_statuses=[LumaState.failed],
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||
response_model=LumaGeneration,
|
||||
status_extractor=lambda x: x.state,
|
||||
result_url_extractor=video_result_url_extractor,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.assets.video) as vid_response:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||
|
||||
@classmethod
|
||||
async def _convert_to_keyframes(
|
||||
cls,
|
||||
first_image: torch.Tensor = None,
|
||||
last_image: torch.Tensor = None,
|
||||
auth_kwargs: Optional[dict[str,str]] = None,
|
||||
):
|
||||
if first_image is None and last_image is None:
|
||||
return None
|
||||
frame0 = None
|
||||
frame1 = None
|
||||
if first_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
|
||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||
if last_image is not None:
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
||||
)
|
||||
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
|
||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||
|
||||
|
||||
@ -1,71 +1,57 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
import logging
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.minimax_api import (
|
||||
MinimaxFileRetrieveResponse,
|
||||
MiniMaxModel,
|
||||
MinimaxTaskResultResponse,
|
||||
MinimaxVideoGenerationRequest,
|
||||
MinimaxVideoGenerationResponse,
|
||||
MinimaxFileRetrieveResponse,
|
||||
MinimaxTaskResultResponse,
|
||||
SubjectReferenceItem,
|
||||
MiniMaxModel,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_bytesio,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
|
||||
I2V_AVERAGE_DURATION = 114
|
||||
T2V_AVERAGE_DURATION = 234
|
||||
|
||||
|
||||
async def _generate_mm_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
*,
|
||||
auth: dict[str, str],
|
||||
node_id: str,
|
||||
prompt_text: str,
|
||||
seed: int,
|
||||
model: str,
|
||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||
average_duration: Optional[int] = None,
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||
|
||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||
subject_reference = None
|
||||
if subject is not None:
|
||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
|
||||
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
|
||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
data=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
@ -73,81 +59,50 @@ async def _generate_mm_video(
|
||||
subject_reference=subject_reference,
|
||||
prompt_optimizer=None,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
task_result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=auth,
|
||||
file_result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info("Generated video URL: %s", file_url)
|
||||
if node_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, node_id)
|
||||
|
||||
# Download and return as VideoFromFile
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return IO.NodeOutput(VideoFromFile(video_io))
|
||||
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||
if file_result.file.backup_download_url:
|
||||
try:
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||
|
||||
|
||||
class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxTextToVideoNode",
|
||||
display_name="MiniMax Text to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxImageToVideoNode",
|
||||
display_name="MiniMax Image to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxSubjectToVideoNode",
|
||||
display_name="MiniMax Subject to Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"subject",
|
||||
@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
return await _generate_mm_video(
|
||||
auth={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
node_id=cls.hidden.unique_id,
|
||||
cls,
|
||||
prompt_text=prompt_text,
|
||||
seed=seed,
|
||||
model=model,
|
||||
@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
|
||||
|
||||
class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="MinimaxHailuoVideoNode",
|
||||
display_name="MiniMax Hailuo Video",
|
||||
category="api node/video/MiniMax",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt_text",
|
||||
@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
resolution: str = "768P",
|
||||
model: str = "MiniMax-Hailuo-02",
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
if first_frame_image is None:
|
||||
validate_string(prompt_text, field_name="prompt_text")
|
||||
|
||||
@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
# upload image, if passed in
|
||||
image_url = None
|
||||
if first_frame_image is not None:
|
||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
|
||||
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
|
||||
|
||||
video_generate_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/video_generation",
|
||||
method=HttpMethod.POST,
|
||||
request_model=MinimaxVideoGenerationRequest,
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
),
|
||||
request=MinimaxVideoGenerationRequest(
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||
response_model=MinimaxVideoGenerationResponse,
|
||||
data=MinimaxVideoGenerationRequest(
|
||||
model=MiniMaxModel(model),
|
||||
prompt=prompt_text,
|
||||
callback_url=None,
|
||||
@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
duration=duration,
|
||||
resolution=resolution,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response = await video_generate_operation.execute()
|
||||
|
||||
task_id = response.task_id
|
||||
if not task_id:
|
||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||
|
||||
average_duration = 120 if resolution == "768P" else 240
|
||||
video_generate_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/query/video_generation",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
query_params={"task_id": task_id},
|
||||
),
|
||||
completed_statuses=["Success"],
|
||||
failed_statuses=["Fail"],
|
||||
task_result = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||
response_model=MinimaxTaskResultResponse,
|
||||
status_extractor=lambda x: x.status.value,
|
||||
estimated_duration=average_duration,
|
||||
node_id=cls.hidden.unique_id,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_result = await video_generate_operation.execute()
|
||||
|
||||
file_id = task_result.file_id
|
||||
if file_id is None:
|
||||
raise Exception("Request was not successful. Missing file ID.")
|
||||
file_retrieve_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/minimax/files/retrieve",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
query_params={"file_id": int(file_id)},
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
auth_kwargs=auth,
|
||||
file_result = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||
response_model=MinimaxFileRetrieveResponse,
|
||||
)
|
||||
file_result = await file_retrieve_operation.execute()
|
||||
|
||||
file_url = file_result.file.download_url
|
||||
if file_url is None:
|
||||
raise Exception(
|
||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
||||
)
|
||||
logging.info("Generated video URL: %s", file_url)
|
||||
if cls.hidden.unique_id:
|
||||
if hasattr(file_result.file, "backup_download_url"):
|
||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
||||
else:
|
||||
message = f"Result URL: {file_url}"
|
||||
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
|
||||
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||
|
||||
video_io = await download_url_to_bytesio(file_url)
|
||||
if video_io is None:
|
||||
error_msg = f"Failed to download video from {file_url}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
return IO.NodeOutput(VideoFromFile(video_io))
|
||||
if file_result.file.backup_download_url:
|
||||
try:
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||
|
||||
|
||||
class MinimaxExtension(ComfyExtension):
|
||||
|
||||
@ -1,35 +1,31 @@
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
MoonvalleyTextToVideoRequest,
|
||||
MoonvalleyVideoToVideoInferenceParams,
|
||||
MoonvalleyVideoToVideoRequest,
|
||||
MoonvalleyPromptResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
trim_video,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import ComfyExtension, InputImpl, IO
|
||||
import av
|
||||
import io
|
||||
|
||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
|
||||
@ -51,13 +47,6 @@ MAX_VID_HEIGHT = 10000
|
||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||
|
||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class MoonvalleyApiError(Exception):
|
||||
"""Base exception for Moonvalley API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||
@ -69,64 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
if not is_valid_task_creation_response(response):
|
||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||
logging.error(error_msg)
|
||||
raise MoonvalleyApiError(error_msg)
|
||||
|
||||
|
||||
def get_video_from_response(response):
|
||||
video = response.output_url
|
||||
logging.info(
|
||||
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
|
||||
)
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
"""Returns the first video url from the Moonvalley video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
if response:
|
||||
return str(get_video_from_response(response))
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
"completed",
|
||||
],
|
||||
max_poll_attempts=240, # 64 minutes with 16s interval
|
||||
poll_interval=16.0,
|
||||
failed_statuses=["error"],
|
||||
status_extractor=lambda response: (
|
||||
response.status if response and response.status else None
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
node_id=node_id,
|
||||
).execute()
|
||||
|
||||
|
||||
def validate_prompts(
|
||||
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
|
||||
):
|
||||
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
|
||||
if not prompt:
|
||||
raise ValueError("Positive prompt is empty")
|
||||
if len(prompt) > max_length:
|
||||
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
|
||||
if negative_prompt and len(negative_prompt) > max_length:
|
||||
raise ValueError(
|
||||
f"Negative prompt is too long: {len(negative_prompt)} characters"
|
||||
)
|
||||
return True
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
@ -170,12 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
}
|
||||
|
||||
if (width, height) not in supported_resolutions:
|
||||
supported_list = ", ".join(
|
||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
||||
)
|
||||
raise ValueError(
|
||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
||||
)
|
||||
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
@ -188,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_minimum_duration(duration: float) -> None:
|
||||
"""Ensures video is at least 5 seconds long."""
|
||||
if duration < 5:
|
||||
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
@ -198,123 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
return video
|
||||
|
||||
|
||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream(
|
||||
"h264", rate=stream.average_rate
|
||||
)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info(
|
||||
"Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
|
||||
)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream(
|
||||
"aac", rate=stream.sample_rate
|
||||
)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (
|
||||
estimated_frames // 16
|
||||
) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def parse_width_height_from_res(resolution: str):
|
||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||
res_map = {
|
||||
@ -338,19 +149,14 @@ def parse_control_parameter(value):
|
||||
return control_map.get(value, control_map["Motion Transfer"])
|
||||
|
||||
|
||||
async def get_response(
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
||||
) -> MoonvalleyPromptResponse:
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=node_id,
|
||||
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
status_extractor=lambda r: (r.status if r and r.status else None),
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
|
||||
|
||||
@ -444,14 +250,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
steps: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@ -464,33 +266,17 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
|
||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||
mime_type = "image/png"
|
||||
|
||||
image_url = (
|
||||
await upload_images_to_comfyapi(
|
||||
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
|
||||
)
|
||||
)[0]
|
||||
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_IMG2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(
|
||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
|
||||
@ -582,15 +368,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
validated_video = validate_video_to_video_input(video)
|
||||
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
|
||||
|
||||
validate_prompts(prompt, negative_prompt)
|
||||
video_url = await upload_video_to_comfyapi(cls, validated_video)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
|
||||
# Only include motion_intensity for Motion Transfer
|
||||
control_params = {}
|
||||
@ -605,35 +386,20 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
guidance_scale=prompt_adherence,
|
||||
)
|
||||
|
||||
control = parse_control_parameter(control_type)
|
||||
|
||||
request = MoonvalleyVideoToVideoRequest(
|
||||
control_type=control,
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_VIDEO2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyVideoToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyVideoToVideoRequest(
|
||||
control_type=parse_control_parameter(control_type),
|
||||
video_url=video_url,
|
||||
prompt_text=prompt,
|
||||
inference_params=inference_params,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
task_creation_response = await initial_operation.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
@ -720,14 +486,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
steps: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||
width_height = parse_width_height_from_res(resolution)
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||
negative_prompt=negative_prompt,
|
||||
steps=steps,
|
||||
@ -737,30 +499,16 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||
width=width_height["width"],
|
||||
height=width_height["height"],
|
||||
)
|
||||
request = MoonvalleyTextToVideoRequest(
|
||||
prompt_text=prompt, inference_params=inference_params
|
||||
)
|
||||
|
||||
init_op = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=API_TXT2VIDEO_ENDPOINT,
|
||||
method=HttpMethod.POST,
|
||||
request_model=MoonvalleyTextToVideoRequest,
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth,
|
||||
task_creation_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
|
||||
response_model=MoonvalleyPromptResponse,
|
||||
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
|
||||
)
|
||||
task_creation_response = await init_op.execute()
|
||||
validate_task_creation_response(task_creation_response)
|
||||
task_id = task_creation_response.id
|
||||
|
||||
final_response = await get_response(
|
||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
||||
)
|
||||
|
||||
video = await download_url_to_video_output(final_response.output_url)
|
||||
return IO.NodeOutput(video)
|
||||
final_response = await get_response(cls, task_creation_response.id)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||
|
||||
|
||||
class MoonvalleyExtension(ComfyExtension):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -7,28 +7,23 @@ from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Optional, TypeVar
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||
from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api_nodes.apis import pika_defs
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
EmptyRequest,
|
||||
HttpMethod,
|
||||
PollingOperation,
|
||||
SynchronousOperation,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
@ -44,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
async def execute_task(
|
||||
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
|
||||
auth_kwargs: Optional[dict[str, str]] = None,
|
||||
node_id: Optional[str] = None,
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
task_id = (await initial_operation.execute()).video_id
|
||||
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
),
|
||||
completed_statuses=["finished"],
|
||||
failed_statuses=["failed", "cancelled"],
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
|
||||
node_id=node_id,
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
).execute()
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
@ -128,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
@ -187,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@ -206,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
@ -313,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASCENES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
@ -387,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKADDITIONS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
@ -476,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKASWAPS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_request_data,
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
@ -532,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFFECTS,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
@ -551,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
@ -596,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_PIKAFRAMES,
|
||||
method=HttpMethod.POST,
|
||||
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
),
|
||||
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
@ -616,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
from io import BytesIO
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.pixverse_api import (
|
||||
PixverseTextVideoRequest,
|
||||
PixverseImageVideoRequest,
|
||||
@ -17,59 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
|
||||
PixverseIO,
|
||||
pixverse_templates,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
import torch
|
||||
import aiohttp
|
||||
|
||||
|
||||
AVERAGE_DURATION_T2V = 32
|
||||
AVERAGE_DURATION_I2V = 30
|
||||
AVERAGE_DURATION_T2T = 52
|
||||
|
||||
|
||||
def get_video_url_from_response(
|
||||
response: PixverseGenerationStatusResponse,
|
||||
) -> Optional[str]:
|
||||
if response.Resp is None or response.Resp.url is None:
|
||||
return None
|
||||
return str(response.Resp.url)
|
||||
|
||||
|
||||
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
||||
files = {"image": tensor_to_bytesio(image)}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/image/upload",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseImageUploadResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
files=files,
|
||||
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
|
||||
response_upload = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
|
||||
response_model=PixverseImageUploadResponse,
|
||||
files={"image": tensor_to_bytesio(image)},
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
response_upload: PixverseImageUploadResponse = await operation.execute()
|
||||
|
||||
if response_upload.Resp is None:
|
||||
raise Exception(
|
||||
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
|
||||
)
|
||||
|
||||
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
|
||||
return response_upload.Resp.img_id
|
||||
|
||||
|
||||
@ -95,22 +65,17 @@ class PixverseTemplateNode(IO.ComfyNode):
|
||||
template_id = pixverse_templates.get(template, None)
|
||||
if template_id is None:
|
||||
raise Exception(f"Template '{template}' is not recognized.")
|
||||
# just return the integer
|
||||
return IO.NodeOutput(template_id)
|
||||
|
||||
|
||||
class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseTextToVideoNode",
|
||||
display_name="PixVerse Text to Video",
|
||||
category="api node/video/PixVerse",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -177,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
negative_prompt: str = None,
|
||||
pixverse_template: int = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
if quality == PixverseQuality.res_1080p:
|
||||
@ -186,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/text/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTextVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTextVideoRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
|
||||
response_model=PixverseVideoResponse,
|
||||
data=PixverseTextVideoRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
quality=quality,
|
||||
@ -207,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
@ -228,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseImageToVideoNode",
|
||||
display_name="PixVerse Image to Video",
|
||||
category="api node/video/PixVerse",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
@ -316,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
pixverse_template: int = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
|
||||
img_id = await upload_image_to_pixverse(cls, image)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@ -330,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/img/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseImageVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseImageVideoRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
|
||||
response_model=PixverseVideoResponse,
|
||||
data=PixverseImageVideoRequest(
|
||||
img_id=img_id,
|
||||
prompt=prompt,
|
||||
quality=quality,
|
||||
@ -347,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
template_id=pixverse_template,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
@ -368,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates videos based on prompt and output_size.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PixverseTransitionVideoNode",
|
||||
display_name="PixVerse Transition Video",
|
||||
category="api node/video/PixVerse",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input("last_frame"),
|
||||
@ -452,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
negative_prompt: str = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
|
||||
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
|
||||
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
|
||||
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
|
||||
|
||||
# 1080p is limited to 5 seconds duration
|
||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||
@ -467,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
elif duration_seconds != PixverseDuration.dur_5:
|
||||
motion_mode = PixverseMotionMode.normal
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/pixverse/video/transition/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=PixverseTransitionVideoRequest,
|
||||
response_model=PixverseVideoResponse,
|
||||
),
|
||||
request=PixverseTransitionVideoRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
|
||||
response_model=PixverseVideoResponse,
|
||||
data=PixverseTransitionVideoRequest(
|
||||
first_frame_img=first_frame_id,
|
||||
last_frame_img=last_frame_id,
|
||||
prompt=prompt,
|
||||
@ -484,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
seed=seed,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.Resp is None:
|
||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||
response_model=PixverseGenerationStatusResponse,
|
||||
completed_statuses=[PixverseStatus.successful],
|
||||
failed_statuses=[
|
||||
PixverseStatus.contents_moderation,
|
||||
@ -505,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
PixverseStatus.deleted,
|
||||
],
|
||||
status_extractor=lambda x: x.Resp.status,
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
estimated_duration=AVERAGE_DURATION_T2V,
|
||||
)
|
||||
response_poll = await operation.execute()
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(response_poll.Resp.url) as vid_response:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||
|
||||
|
||||
class PixVerseExtension(ComfyExtension):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] =
|
||||
|
||||
|
||||
async def download_files(url_list, task_uuid):
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
|
||||
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
model_file_path = None
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for i in url_list.list:
|
||||
url = i.url
|
||||
file_name = i.name
|
||||
file_path = os.path.join(save_path, file_name)
|
||||
file_path = os.path.join(save_path, i.name)
|
||||
if file_path.endswith(".glb"):
|
||||
model_file_path = file_path
|
||||
model_file_path = os.path.join(result_folder_name, i.name)
|
||||
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
|
||||
max_retries = 5
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
async with session.get(i.url) as resp:
|
||||
resp.raise_for_status()
|
||||
with open(file_path, "wb") as f:
|
||||
async for chunk in resp.content.iter_chunked(32 * 1024):
|
||||
|
||||
@ -11,7 +11,7 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional, Any
|
||||
from typing import Union, Optional
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
@ -21,7 +21,6 @@ from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||
RunwayTaskStatusEnum as TaskStatus,
|
||||
RunwayModelEnum as Model,
|
||||
RunwayDurationEnum as Duration,
|
||||
RunwayAspectRatioEnum as AspectRatio,
|
||||
@ -33,23 +32,20 @@ from comfy_api_nodes.apis import (
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
upload_images_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
return None
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> TaskStatusResponse:
|
||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[
|
||||
TaskStatus.SUCCEEDED.value,
|
||||
],
|
||||
failed_statuses=[
|
||||
TaskStatus.FAILED.value,
|
||||
TaskStatus.CANCELLED.value,
|
||||
],
|
||||
status_extractor=lambda response: response.status.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=get_video_url_from_task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
).execute()
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
|
||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
)
|
||||
|
||||
|
||||
async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
auth_kwargs: dict[str, str],
|
||||
node_id: Optional[str] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_IMAGE_TO_VIDEO,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayImageToVideoRequest,
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayImageToVideoResponse,
|
||||
data=request,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
@ -184,9 +145,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -239,22 +200,18 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@ -262,15 +219,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
)
|
||||
|
||||
@ -284,9 +235,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -339,22 +290,18 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
start_frame,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@ -362,15 +309,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
duration=Duration(duration),
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
)
|
||||
]
|
||||
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
@ -385,12 +326,12 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||
"Before diving in, review these best practices to ensure that your input selections "
|
||||
"will set your generation up for success: "
|
||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -449,26 +390,22 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
validate_string(prompt, min_length=1)
|
||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
|
||||
|
||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
stacked_input_images,
|
||||
max_images=2,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
cls,
|
||||
RunwayImageToVideoRequest(
|
||||
promptText=prompt,
|
||||
seed=seed,
|
||||
@ -477,17 +414,11 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
ratio=AspectRatio(ratio),
|
||||
promptImage=RunwayPromptImageObject(
|
||||
root=[
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[0]), position="first"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(
|
||||
uri=str(download_urls[1]), position="last"
|
||||
),
|
||||
RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
|
||||
RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
|
||||
]
|
||||
),
|
||||
),
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||
)
|
||||
)
|
||||
@ -502,7 +433,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -540,49 +471,34 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
auth_kwargs = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
# Prepare reference images if provided
|
||||
reference_images = None
|
||||
if reference_image is not None:
|
||||
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
||||
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
||||
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
|
||||
download_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
reference_image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||
|
||||
request = RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
ratio=ratio,
|
||||
referenceImages=reference_images,
|
||||
)
|
||||
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=PATH_TEXT_TO_IMAGE,
|
||||
method=HttpMethod.POST,
|
||||
request_model=RunwayTextToImageRequest,
|
||||
response_model=RunwayTextToImageResponse,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
|
||||
response_model=RunwayTextToImageResponse,
|
||||
data=RunwayTextToImageRequest(
|
||||
promptText=prompt,
|
||||
model=Model4.gen4_image,
|
||||
ratio=ratio,
|
||||
referenceImages=reference_images,
|
||||
),
|
||||
request=request,
|
||||
auth_kwargs=auth_kwargs,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
|
||||
# Poll for completion
|
||||
final_response = await get_response(
|
||||
cls,
|
||||
initial_response.id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayTextToImageNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> RunwayExtension:
|
||||
return RunwayExtension()
|
||||
|
||||
@ -1,23 +1,20 @@
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import get_number_of_images
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_bytesio,
|
||||
)
|
||||
|
||||
|
||||
class Sora2GenerationRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
model: str = Field(...)
|
||||
@ -80,7 +77,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
control_after_generate=True,
|
||||
optional=True,
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -111,55 +108,34 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Currently only one input image is supported.")
|
||||
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload = Sora2GenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
seconds=str(duration),
|
||||
size=size,
|
||||
)
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/openai/v1/videos",
|
||||
method=HttpMethod.POST,
|
||||
request_model=Sora2GenerationRequest,
|
||||
response_model=Sora2GenerationResponse
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
|
||||
data=Sora2GenerationRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
seconds=str(duration),
|
||||
size=size,
|
||||
),
|
||||
request=payload,
|
||||
files=files_input,
|
||||
auth_kwargs=auth,
|
||||
response_model=Sora2GenerationResponse,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
initial_response = await initial_operation.execute()
|
||||
if initial_response.error:
|
||||
raise Exception(initial_response.error.message)
|
||||
raise Exception(initial_response.error["message"])
|
||||
|
||||
model_time_multiplier = 1 if model == "sora-2" else 2
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/openai/v1/videos/{initial_response.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=Sora2GenerationResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=["failed"],
|
||||
await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
|
||||
response_model=Sora2GenerationResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
auth_kwargs=auth,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=45 * (duration / 4) * model_time_multiplier,
|
||||
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
|
||||
)
|
||||
await poll_operation.execute()
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_video_output(
|
||||
f"/proxy/openai/v1/videos/{initial_response.id}/content",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import (
|
||||
StabilityAudioInpaintRequest,
|
||||
StabilityAudioResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api_nodes.util import (
|
||||
validate_audio_duration,
|
||||
validate_string,
|
||||
audio_input_to_mp3,
|
||||
bytesio_to_image_tensor,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
sync_op,
|
||||
poll_op,
|
||||
ApiEndpoint,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import validate_audio_duration
|
||||
|
||||
import torch
|
||||
import base64
|
||||
@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStableUltraRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStableUltraRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStableUltraRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||
@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityStable3_5Request,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityStable3_5Request(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityStable3_5Request(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||
@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleConservativeRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=StabilityUpscaleConservativeRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
data=StabilityUpscaleConservativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||
@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityUpscaleCreativeRequest,
|
||||
response_model=StabilityAsyncResponse,
|
||||
),
|
||||
request=StabilityUpscaleCreativeRequest(
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||
response_model=StabilityAsyncResponse,
|
||||
data=StabilityUpscaleCreativeRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
creativity=round(creativity,2),
|
||||
@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
),
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityResultsGetResponse,
|
||||
),
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||
response_model=StabilityResultsGetResponse,
|
||||
poll_interval=3,
|
||||
completed_statuses=[StabilityPollStatus.finished],
|
||||
failed_statuses=[StabilityPollStatus.failed],
|
||||
status_extractor=lambda x: get_async_dummy_status(x),
|
||||
auth_kwargs=auth,
|
||||
node_id=cls.hidden.unique_id,
|
||||
)
|
||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
||||
|
||||
if response_poll.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||
@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
"image": image_binary
|
||||
}
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
||||
method=HttpMethod.POST,
|
||||
request_model=EmptyRequest,
|
||||
response_model=StabilityStableUltraResponse,
|
||||
),
|
||||
request=EmptyRequest(),
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||
response_model=StabilityStableUltraResponse,
|
||||
files=files,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
|
||||
if response_api.finish_reason != "SUCCESS":
|
||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||
@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||
validate_string(prompt, max_length=10000)
|
||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityTextToAudioRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
auth_kwargs= {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
|
||||
payload = StabilityAudioToAudioRequest(
|
||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityAudioToAudioRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
auth_kwargs= {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
|
||||
mask_start=mask_start,
|
||||
mask_end=mask_end,
|
||||
)
|
||||
operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
|
||||
method=HttpMethod.POST,
|
||||
request_model=StabilityAudioInpaintRequest,
|
||||
response_model=StabilityAudioResponse,
|
||||
),
|
||||
request=payload,
|
||||
response_api = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||
response_model=StabilityAudioResponse,
|
||||
data=payload,
|
||||
content_type="multipart/form-data",
|
||||
files={"audio": audio_input_to_mp3(audio)},
|
||||
auth_kwargs={
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
)
|
||||
response_api = await operation.execute()
|
||||
if not response_api.audio:
|
||||
raise ValueError("No audio file was received in response.")
|
||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,28 +1,21 @@
|
||||
import logging
|
||||
import base64
|
||||
import aiohttp
|
||||
import torch
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api_nodes.apis import (
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
)
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
downscale_image_tensor,
|
||||
download_url_to_video_output,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
)
|
||||
|
||||
@ -35,28 +28,6 @@ MODELS_MAP = {
|
||||
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
|
||||
}
|
||||
|
||||
def convert_image_to_base64(image: torch.Tensor):
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
||||
return tensor_to_base64_string(scaled_image)
|
||||
|
||||
|
||||
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
video = poll_response.response.videos[0]
|
||||
else:
|
||||
return None
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return str(video.gcsUri)
|
||||
return None
|
||||
|
||||
|
||||
class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
"""
|
||||
@ -169,18 +140,13 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
# Prepare the instances for the request
|
||||
instances = []
|
||||
|
||||
instance = {
|
||||
"prompt": prompt
|
||||
}
|
||||
instance = {"prompt": prompt}
|
||||
|
||||
# Add image if provided
|
||||
if image is not None:
|
||||
image_base64 = convert_image_to_base64(image)
|
||||
image_base64 = tensor_to_base64_string(image)
|
||||
if image_base64:
|
||||
instance["image"] = {
|
||||
"bytesBase64Encoded": image_base64,
|
||||
"mimeType": "image/png"
|
||||
}
|
||||
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
|
||||
|
||||
instances.append(instance)
|
||||
|
||||
@ -198,119 +164,77 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
if seed > 0:
|
||||
parameters["seed"] = seed
|
||||
# Only add generateAudio for Veo 3 models
|
||||
if "veo-3.0" in model:
|
||||
if model.find("veo-2.0") == -1:
|
||||
parameters["generateAudio"] = generate_audio
|
||||
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
# Initial request to start video generation
|
||||
initial_operation = SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=f"/proxy/veo/{model}/generate",
|
||||
method=HttpMethod.POST,
|
||||
request_model=VeoGenVidRequest,
|
||||
response_model=VeoGenVidResponse
|
||||
),
|
||||
request=VeoGenVidRequest(
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=instances,
|
||||
parameters=parameters
|
||||
parameters=parameters,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
|
||||
initial_response = await initial_operation.execute()
|
||||
operation_name = initial_response.name
|
||||
|
||||
logging.info("Veo generation started with operation name: %s", operation_name)
|
||||
|
||||
# Define status extractor function
|
||||
def status_extractor(response):
|
||||
# Only return "completed" if the operation is done, regardless of success or failure
|
||||
# We'll check for errors after polling completes
|
||||
return "completed" if response.done else "pending"
|
||||
|
||||
# Define progress extractor function
|
||||
def progress_extractor(response):
|
||||
# Could be enhanced if the API provides progress information
|
||||
return None
|
||||
|
||||
# Define the polling operation
|
||||
poll_operation = PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/veo/{model}/poll",
|
||||
method=HttpMethod.POST,
|
||||
request_model=VeoGenVidPollRequest,
|
||||
response_model=VeoGenVidPollResponse
|
||||
),
|
||||
completed_statuses=["completed"],
|
||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=status_extractor,
|
||||
progress_extractor=progress_extractor,
|
||||
request=VeoGenVidPollRequest(
|
||||
operationName=operation_name
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
),
|
||||
auth_kwargs=auth,
|
||||
poll_interval=5.0,
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
node_id=cls.hidden.unique_id,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
# Execute the polling operation
|
||||
poll_response = await poll_operation.execute()
|
||||
|
||||
# Now check for errors in the final response
|
||||
# Check for error in poll response
|
||||
if hasattr(poll_response, 'error') and poll_response.error:
|
||||
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
|
||||
# Check for RAI filtered content
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
|
||||
poll_response.response.raiMediaFilteredCount > 0):
|
||||
if (
|
||||
hasattr(poll_response.response, "raiMediaFilteredCount")
|
||||
and poll_response.response.raiMediaFilteredCount > 0
|
||||
):
|
||||
|
||||
# Extract reason message if available
|
||||
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
|
||||
poll_response.response.raiMediaFilteredReasons):
|
||||
if (
|
||||
hasattr(poll_response.response, "raiMediaFilteredReasons")
|
||||
and poll_response.response.raiMediaFilteredReasons
|
||||
):
|
||||
reason = poll_response.response.raiMediaFilteredReasons[0]
|
||||
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
else:
|
||||
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||
|
||||
logging.error(error_message)
|
||||
raise Exception(error_message)
|
||||
|
||||
# Extract video data
|
||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
||||
if (
|
||||
poll_response.response
|
||||
and hasattr(poll_response.response, "videos")
|
||||
and poll_response.response.videos
|
||||
and len(poll_response.response.videos) > 0
|
||||
):
|
||||
video = poll_response.response.videos[0]
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
|
||||
# Decode base64 string to bytes
|
||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
||||
# Download from URL
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(video.gcsUri) as video_response:
|
||||
video_data = await video_response.content.read()
|
||||
else:
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
else:
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
|
||||
if not video_data:
|
||||
raise Exception("No video data was returned")
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
|
||||
logging.info("Video generation completed successfully")
|
||||
|
||||
# Convert video data to BytesIO object
|
||||
video_io = BytesIO(video_data)
|
||||
|
||||
# Return VideoFromFile object
|
||||
return IO.NodeOutput(VideoFromFile(video_io))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
@ -394,7 +318,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[
|
||||
"veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001"
|
||||
"veo-3.1-generate",
|
||||
"veo-3.1-fast-generate",
|
||||
"veo-3.0-generate-001",
|
||||
"veo-3.0-fast-generate-001",
|
||||
],
|
||||
default="veo-3.0-generate-001",
|
||||
tooltip="Veo 3 model to use for video generation",
|
||||
@ -427,5 +354,6 @@ class VeoExtension(ComfyExtension):
|
||||
Veo3VideoGenerationNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> VeoExtension:
|
||||
return VeoExtension()
|
||||
|
||||
@ -1,27 +1,23 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Literal, TypeVar
|
||||
from typing_extensions import override
|
||||
from typing import Literal, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api_nodes.util.validation_utils import (
|
||||
validate_aspect_ratio_closeness,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio_range,
|
||||
get_number_of_images,
|
||||
)
|
||||
from comfy_api_nodes.apis.client import (
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
)
|
||||
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
|
||||
|
||||
|
||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
||||
@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class VideoModelName(str, Enum):
|
||||
vidu_q1 = 'viduq1'
|
||||
vidu_q1 = "viduq1"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel):
|
||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
created = "created"
|
||||
queueing = "queueing"
|
||||
processing = "processing"
|
||||
success = "success"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
task_id: str = Field(...)
|
||||
state: TaskStatus = Field(...)
|
||||
state: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
code: Optional[int] = Field(None, description="Error code")
|
||||
|
||||
@ -85,32 +74,11 @@ class TaskResult(BaseModel):
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
state: TaskStatus = Field(...)
|
||||
state: str = Field(...)
|
||||
err_code: Optional[str] = Field(None)
|
||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||
|
||||
|
||||
async def poll_until_finished(
|
||||
auth_kwargs: dict[str, str],
|
||||
api_endpoint: ApiEndpoint[Any, R],
|
||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
||||
estimated_duration: Optional[int] = None,
|
||||
node_id: Optional[str] = None,
|
||||
) -> R:
|
||||
return await PollingOperation(
|
||||
poll_endpoint=api_endpoint,
|
||||
completed_statuses=[TaskStatus.success.value],
|
||||
failed_statuses=[TaskStatus.failed.value],
|
||||
status_extractor=lambda response: response.state.value,
|
||||
auth_kwargs=auth_kwargs,
|
||||
result_url_extractor=result_url_extractor,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
poll_interval=16.0,
|
||||
max_poll_attempts=256,
|
||||
).execute()
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
if response.creations:
|
||||
return response.creations[0].url
|
||||
@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult:
|
||||
|
||||
|
||||
async def execute_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
vidu_endpoint: str,
|
||||
auth_kwargs: Optional[dict[str, str]],
|
||||
payload: TaskCreationRequest,
|
||||
estimated_duration: int,
|
||||
node_id: str,
|
||||
) -> R:
|
||||
response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=vidu_endpoint,
|
||||
method=HttpMethod.POST,
|
||||
request_model=TaskCreationRequest,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
if response.state == TaskStatus.failed:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=payload,
|
||||
)
|
||||
if response.state == "failed":
|
||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
return await poll_until_finished(
|
||||
auth_kwargs,
|
||||
ApiEndpoint(
|
||||
path=VIDU_GET_GENERATION_STATUS % response.task_id,
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=TaskStatusResponse,
|
||||
),
|
||||
result_url_extractor=get_video_url_from_response,
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.state,
|
||||
estimated_duration=estimated_duration,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
|
||||
@ -258,11 +216,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
|
||||
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@ -353,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) > 1:
|
||||
raise ValueError("Only one input image is allowed.")
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
@ -362,17 +316,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
max_images=1,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@ -473,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
if a > 7:
|
||||
raise ValueError("Too many images, maximum allowed is 7.")
|
||||
for image in images:
|
||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
||||
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
@ -484,17 +434,13 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=7,
|
||||
mime_type="image/png",
|
||||
auth_kwargs=auth,
|
||||
)
|
||||
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
||||
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@ -587,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
resolution: str,
|
||||
movement_amplitude: str,
|
||||
) -> IO.NodeOutput:
|
||||
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||
payload = TaskCreationRequest(
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
@ -596,15 +542,11 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
resolution=resolution,
|
||||
movement_amplitude=movement_amplitude,
|
||||
)
|
||||
auth = {
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
}
|
||||
payload.images = [
|
||||
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
|
||||
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||
for frame in (first_frame, end_frame)
|
||||
]
|
||||
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
|
||||
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||
|
||||
|
||||
@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension):
|
||||
ViduStartEndToVideoNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ViduExtension:
|
||||
return ViduExtension()
|
||||
|
||||
@ -1,28 +1,24 @@
|
||||
import re
|
||||
from typing import Optional, Type, Union
|
||||
from typing_extensions import override
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from comfy_api.latest import ComfyExtension, Input, IO
|
||||
from comfy_api_nodes.apis.client import (
|
||||
ApiEndpoint,
|
||||
HttpMethod,
|
||||
SynchronousOperation,
|
||||
PollingOperation,
|
||||
EmptyRequest,
|
||||
R,
|
||||
T,
|
||||
)
|
||||
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api_nodes.apinode_utils import (
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
audio_to_base64_string,
|
||||
validate_audio_duration,
|
||||
)
|
||||
|
||||
|
||||
class Text2ImageInputField(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
negative_prompt: Optional[str] = Field(None)
|
||||
@ -146,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel):
|
||||
request_id: str = Field(...)
|
||||
|
||||
|
||||
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
|
||||
|
||||
|
||||
async def process_task(
|
||||
auth_kwargs: dict[str, str],
|
||||
url: str,
|
||||
request_model: Type[T],
|
||||
response_model: Type[R],
|
||||
payload: Union[
|
||||
Text2ImageTaskCreationRequest,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
],
|
||||
node_id: str,
|
||||
estimated_duration: int,
|
||||
poll_interval: int,
|
||||
) -> Type[R]:
|
||||
initial_response = await SynchronousOperation(
|
||||
endpoint=ApiEndpoint(
|
||||
path=url,
|
||||
method=HttpMethod.POST,
|
||||
request_model=request_model,
|
||||
response_model=TaskCreationResponse,
|
||||
),
|
||||
request=payload,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
|
||||
return await PollingOperation(
|
||||
poll_endpoint=ApiEndpoint(
|
||||
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
|
||||
method=HttpMethod.GET,
|
||||
request_model=EmptyRequest,
|
||||
response_model=response_model,
|
||||
),
|
||||
completed_statuses=["SUCCEEDED"],
|
||||
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=estimated_duration,
|
||||
poll_interval=poll_interval,
|
||||
node_id=node_id,
|
||||
auth_kwargs=auth_kwargs,
|
||||
).execute()
|
||||
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
|
||||
|
||||
|
||||
class WanTextToImageApi(IO.ComfyNode):
|
||||
@ -259,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -286,26 +236,28 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
prompt_extend: bool = True,
|
||||
watermark: bool = True,
|
||||
):
|
||||
payload = Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Text2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||
parameters=Txt2ImageParametersField(
|
||||
size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
|
||||
request_model=Text2ImageTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=ImageTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=9,
|
||||
poll_interval=3,
|
||||
)
|
||||
@ -320,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
@ -376,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -408,28 +360,30 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||
images = []
|
||||
for i in image:
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
|
||||
payload = Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Image2ImageTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||
parameters=Image2ImageParametersField(
|
||||
# size=f"{width}*{height}",
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
|
||||
request_model=Image2ImageTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=ImageTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=42,
|
||||
poll_interval=3,
|
||||
poll_interval=4,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||
|
||||
@ -523,7 +477,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -557,28 +511,31 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
payload = Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Text2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||
parameters=Text2VideoParametersField(
|
||||
size=f"{width}*{height}",
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Text2VideoTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=VideoTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
@ -667,7 +624,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
IO.Boolean.Input(
|
||||
"watermark",
|
||||
default=True,
|
||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
||||
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
@ -699,35 +656,37 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
):
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Exactly one input image is required.")
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
|
||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||
audio_url = None
|
||||
if audio is not None:
|
||||
validate_audio_duration(audio, 3.0, 29.0)
|
||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||
payload = Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||
response_model=TaskCreationResponse,
|
||||
data=Image2VideoTaskCreationRequest(
|
||||
model=model,
|
||||
input=Image2VideoInputField(
|
||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||
),
|
||||
parameters=Image2VideoParametersField(
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
seed=seed,
|
||||
audio=generate_audio,
|
||||
prompt_extend=prompt_extend,
|
||||
watermark=watermark,
|
||||
),
|
||||
),
|
||||
)
|
||||
response = await process_task(
|
||||
{
|
||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
||||
},
|
||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
||||
request_model=Image2VideoTaskCreationRequest,
|
||||
if not initial_response.output:
|
||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||
response_model=VideoTaskStatusResponse,
|
||||
payload=payload,
|
||||
node_id=cls.hidden.unique_id,
|
||||
status_extractor=lambda x: x.output.task_status,
|
||||
estimated_duration=120 * int(duration / 5),
|
||||
poll_interval=6,
|
||||
)
|
||||
|
||||
@ -0,0 +1,97 @@
|
||||
from ._helpers import get_fs_object_size
|
||||
from .client import (
|
||||
ApiEndpoint,
|
||||
poll_op,
|
||||
poll_op_raw,
|
||||
sync_op,
|
||||
sync_op_raw,
|
||||
)
|
||||
from .conversions import (
|
||||
audio_bytes_to_audio_input,
|
||||
audio_input_to_mp3,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
downscale_image_tensor,
|
||||
image_tensor_pair_to_batch,
|
||||
pil_to_bytesio,
|
||||
resize_mask_to_image,
|
||||
tensor_to_base64_string,
|
||||
tensor_to_bytesio,
|
||||
tensor_to_pil,
|
||||
text_filepath_to_base64_string,
|
||||
text_filepath_to_data_uri,
|
||||
trim_video,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from .download_helpers import (
|
||||
download_url_as_bytesio,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
)
|
||||
from .upload_helpers import (
|
||||
upload_audio_to_comfyapi,
|
||||
upload_file_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
from .validation_utils import (
|
||||
get_number_of_images,
|
||||
validate_aspect_ratio_string,
|
||||
validate_audio_duration,
|
||||
validate_container_format_is_mp4,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_images_aspect_ratio_closeness,
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# API client
|
||||
"ApiEndpoint",
|
||||
"poll_op",
|
||||
"poll_op_raw",
|
||||
"sync_op",
|
||||
"sync_op_raw",
|
||||
# Upload helpers
|
||||
"upload_audio_to_comfyapi",
|
||||
"upload_file_to_comfyapi",
|
||||
"upload_images_to_comfyapi",
|
||||
"upload_video_to_comfyapi",
|
||||
# Download helpers
|
||||
"download_url_as_bytesio",
|
||||
"download_url_to_bytesio",
|
||||
"download_url_to_image_tensor",
|
||||
"download_url_to_video_output",
|
||||
# Conversions
|
||||
"audio_bytes_to_audio_input",
|
||||
"audio_input_to_mp3",
|
||||
"audio_to_base64_string",
|
||||
"bytesio_to_image_tensor",
|
||||
"downscale_image_tensor",
|
||||
"image_tensor_pair_to_batch",
|
||||
"pil_to_bytesio",
|
||||
"resize_mask_to_image",
|
||||
"tensor_to_base64_string",
|
||||
"tensor_to_bytesio",
|
||||
"tensor_to_pil",
|
||||
"text_filepath_to_base64_string",
|
||||
"text_filepath_to_data_uri",
|
||||
"trim_video",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
"get_number_of_images",
|
||||
"validate_aspect_ratio_string",
|
||||
"validate_audio_duration",
|
||||
"validate_container_format_is_mp4",
|
||||
"validate_image_aspect_ratio",
|
||||
"validate_image_dimensions",
|
||||
"validate_images_aspect_ratio_closeness",
|
||||
"validate_string",
|
||||
"validate_video_dimensions",
|
||||
"validate_video_duration",
|
||||
# Misc functions
|
||||
"get_fs_object_size",
|
||||
]
|
||||
71
comfy_api_nodes/util/_helpers.py
Normal file
71
comfy_api_nodes/util/_helpers.py
Normal file
@ -0,0 +1,71 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
from .common_exceptions import ProcessingInterrupted
|
||||
|
||||
|
||||
def is_processing_interrupted() -> bool:
|
||||
"""Return True if user/runtime requested interruption."""
|
||||
return processing_interrupted()
|
||||
|
||||
|
||||
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
|
||||
return node_cls.hidden.unique_id
|
||||
|
||||
|
||||
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
if node_cls.hidden.auth_token_comfy_org:
|
||||
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
|
||||
if node_cls.hidden.api_key_comfy_org:
|
||||
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
|
||||
return {}
|
||||
|
||||
|
||||
def default_base_url() -> str:
|
||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||
|
||||
|
||||
async def sleep_with_interrupt(
|
||||
seconds: float,
|
||||
node_cls: Optional[type[IO.ComfyNode]],
|
||||
label: Optional[str] = None,
|
||||
start_ts: Optional[float] = None,
|
||||
estimated_total: Optional[int] = None,
|
||||
*,
|
||||
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
|
||||
):
|
||||
"""
|
||||
Sleep in 1s slices while:
|
||||
- Checking for interruption (raises ProcessingInterrupted).
|
||||
- Optionally emitting time progress via display_callback (if provided).
|
||||
"""
|
||||
end = time.monotonic() + seconds
|
||||
while True:
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
now = time.monotonic()
|
||||
if start_ts is not None and label and display_callback:
|
||||
with contextlib.suppress(Exception):
|
||||
display_callback(node_cls, label, int(now - start_ts), estimated_total)
|
||||
if now >= end:
|
||||
break
|
||||
await asyncio.sleep(min(1.0, end - now))
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
|
||||
936
comfy_api_nodes/util/client.py
Normal file
936
comfy_api_nodes/util/client.py
Normal file
@ -0,0 +1,936 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
from pydantic import BaseModel
|
||||
|
||||
from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from comfy_api_nodes.apis import request_logger
|
||||
from server import PromptServer
|
||||
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
)
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
class ApiEndpoint:
|
||||
def __init__(
|
||||
self,
|
||||
path: str,
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
|
||||
*,
|
||||
query_params: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
):
|
||||
self.path = path
|
||||
self.method = method
|
||||
self.query_params = query_params or {}
|
||||
self.headers = headers or {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RequestConfig:
|
||||
node_cls: type[IO.ComfyNode]
|
||||
endpoint: ApiEndpoint
|
||||
timeout: float
|
||||
content_type: str
|
||||
data: Optional[dict[str, Any]]
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
|
||||
multipart_parser: Optional[Callable]
|
||||
max_retries: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
monitor_progress: bool = True
|
||||
estimated_total: Optional[int] = None
|
||||
final_label_on_success: Optional[str] = "Completed"
|
||||
progress_origin_ts: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PollUIState:
|
||||
started: float
|
||||
status_label: str = "Queued"
|
||||
is_queued: bool = True
|
||||
price: Optional[float] = None
|
||||
estimated_duration: Optional[int] = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
data: Optional[BaseModel] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
endpoint,
|
||||
data=data,
|
||||
files=files,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
multipart_parser=multipart_parser,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
retry_backoff=retry_backoff,
|
||||
wait_label=wait_label,
|
||||
estimated_duration=estimated_duration,
|
||||
as_binary=False,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
monitor_progress=monitor_progress,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
async def poll_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
status_extractor: Callable[[M], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
cls,
|
||||
poll_endpoint=poll_endpoint,
|
||||
status_extractor=_wrap_model_extractor(response_model, status_extractor),
|
||||
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
|
||||
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||
completed_statuses=completed_statuses,
|
||||
failed_statuses=failed_statuses,
|
||||
queued_statuses=queued_statuses,
|
||||
data=data,
|
||||
poll_interval=poll_interval,
|
||||
max_poll_attempts=max_poll_attempts,
|
||||
timeout_per_poll=timeout_per_poll,
|
||||
max_retries_per_poll=max_retries_per_poll,
|
||||
retry_delay_per_poll=retry_delay_per_poll,
|
||||
retry_backoff_per_poll=retry_backoff_per_poll,
|
||||
estimated_duration=estimated_duration,
|
||||
cancel_endpoint=cancel_endpoint,
|
||||
cancel_timeout=cancel_timeout,
|
||||
)
|
||||
if not isinstance(raw, dict):
|
||||
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||
return _validate_or_raise(response_model, raw)
|
||||
|
||||
|
||||
async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
as_binary: bool = False,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> Union[dict[str, Any], bytes]:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
- If as_binary=True: returns bytes.
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
for k, v in list(data.items()):
|
||||
if isinstance(v, Enum):
|
||||
data[k] = v.value
|
||||
cfg = _RequestConfig(
|
||||
node_cls=cls,
|
||||
endpoint=endpoint,
|
||||
timeout=timeout,
|
||||
content_type=content_type,
|
||||
data=data,
|
||||
files=files,
|
||||
multipart_parser=multipart_parser,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
retry_backoff=retry_backoff,
|
||||
wait_label=wait_label,
|
||||
monitor_progress=monitor_progress,
|
||||
estimated_total=estimated_duration,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
|
||||
async def poll_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||
checks interruption every second, and calls Cancel endpoint (if provided) on interruption.
|
||||
|
||||
Uses default complete, failed and queued states assumption.
|
||||
|
||||
Returns the final JSON response from the poll endpoint.
|
||||
"""
|
||||
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
||||
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
||||
started = time.monotonic()
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||
last_progress: Optional[int] = None
|
||||
|
||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||
stop_ticker = asyncio.Event()
|
||||
|
||||
async def _ticker():
|
||||
"""Emit a UI update every second while polling is in progress."""
|
||||
try:
|
||||
while not stop_ticker.is_set():
|
||||
if is_processing_interrupted():
|
||||
break
|
||||
now = time.monotonic()
|
||||
proc_elapsed = state.base_processing_elapsed + (
|
||||
(now - state.active_since) if state.active_since is not None else 0.0
|
||||
)
|
||||
_display_time_progress(
|
||||
cls,
|
||||
status=state.status_label,
|
||||
elapsed_seconds=int(now - state.started),
|
||||
estimated_total=state.estimated_duration,
|
||||
price=state.price,
|
||||
is_queued=state.is_queued,
|
||||
processing_elapsed_seconds=int(proc_elapsed),
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except Exception as exc:
|
||||
logging.debug("Polling ticker exited: %s", exc)
|
||||
|
||||
ticker_task = asyncio.create_task(_ticker())
|
||||
try:
|
||||
while consumed_attempts < max_poll_attempts:
|
||||
try:
|
||||
resp_json = await sync_op_raw(
|
||||
cls,
|
||||
poll_endpoint,
|
||||
data=data,
|
||||
timeout=timeout_per_poll,
|
||||
max_retries=max_retries_per_poll,
|
||||
retry_delay=retry_delay_per_poll,
|
||||
retry_backoff=retry_backoff_per_poll,
|
||||
wait_label="Checking",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
if not isinstance(resp_json, dict):
|
||||
raise Exception("Polling endpoint returned non-JSON response.")
|
||||
except ProcessingInterrupted:
|
||||
if cancel_endpoint:
|
||||
with contextlib.suppress(Exception):
|
||||
await sync_op_raw(
|
||||
cls,
|
||||
cancel_endpoint,
|
||||
timeout=cancel_timeout,
|
||||
max_retries=0,
|
||||
wait_label="Cancelling task",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
raise
|
||||
|
||||
try:
|
||||
status = _normalize_status_value(status_extractor(resp_json))
|
||||
except Exception as e:
|
||||
logging.error("Status extraction failed: %s", e)
|
||||
status = None
|
||||
|
||||
if price_extractor:
|
||||
new_price = price_extractor(resp_json)
|
||||
if new_price is not None:
|
||||
state.price = new_price
|
||||
|
||||
if progress_extractor:
|
||||
new_progress = progress_extractor(resp_json)
|
||||
if new_progress is not None and last_progress != new_progress:
|
||||
progress_bar.update_absolute(new_progress, total=100)
|
||||
last_progress = new_progress
|
||||
|
||||
now_ts = time.monotonic()
|
||||
is_queued = status in queued_states
|
||||
|
||||
if is_queued:
|
||||
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
else:
|
||||
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
||||
state.active_since = now_ts
|
||||
|
||||
state.is_queued = is_queued
|
||||
state.status_label = status or ("Queued" if is_queued else "Processing")
|
||||
if status in completed_states:
|
||||
if state.active_since is not None:
|
||||
state.base_processing_elapsed += now_ts - state.active_since
|
||||
state.active_since = None
|
||||
stop_ticker.set()
|
||||
with contextlib.suppress(Exception):
|
||||
await ticker_task
|
||||
|
||||
if progress_bar and last_progress != 100:
|
||||
progress_bar.update_absolute(100, total=100)
|
||||
|
||||
_display_time_progress(
|
||||
cls,
|
||||
status=status if status else "Completed",
|
||||
elapsed_seconds=int(now_ts - started),
|
||||
estimated_total=estimated_duration,
|
||||
price=state.price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||
)
|
||||
return resp_json
|
||||
|
||||
if status in failed_states:
|
||||
msg = f"Task failed: {json.dumps(resp_json)}"
|
||||
logging.error(msg)
|
||||
raise Exception(msg)
|
||||
|
||||
try:
|
||||
await sleep_with_interrupt(poll_interval, cls, None, None, None)
|
||||
except ProcessingInterrupted:
|
||||
if cancel_endpoint:
|
||||
with contextlib.suppress(Exception):
|
||||
await sync_op_raw(
|
||||
cls,
|
||||
cancel_endpoint,
|
||||
timeout=cancel_timeout,
|
||||
max_retries=0,
|
||||
wait_label="Cancelling task",
|
||||
estimated_duration=None,
|
||||
as_binary=False,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
raise
|
||||
if not is_queued:
|
||||
consumed_attempts += 1
|
||||
|
||||
raise Exception(
|
||||
f"Polling timed out after {max_poll_attempts} non-queued attempts "
|
||||
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
|
||||
)
|
||||
except ProcessingInterrupted:
|
||||
raise
|
||||
except (LocalNetworkError, ApiServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise Exception(f"Polling aborted due to error: {e}") from e
|
||||
finally:
|
||||
stop_ticker.set()
|
||||
with contextlib.suppress(Exception):
|
||||
await ticker_task
|
||||
|
||||
|
||||
def _display_text(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
text: Optional[str],
|
||||
*,
|
||||
status: Optional[Union[str, int]] = None,
|
||||
price: Optional[float] = None,
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
if price is not None:
|
||||
display_lines.append(f"Price: ${float(price):,.4f}")
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
||||
|
||||
|
||||
def _display_time_progress(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
status: Optional[Union[str, int]],
|
||||
elapsed_seconds: int,
|
||||
estimated_total: Optional[int] = None,
|
||||
*,
|
||||
price: Optional[float] = None,
|
||||
is_queued: Optional[bool] = None,
|
||||
processing_elapsed_seconds: Optional[int] = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
remaining = max(0, int(estimated_total) - int(pe))
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||
else:
|
||||
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||
_display_text(node_cls, time_line, status=status, price=price)
|
||||
|
||||
|
||||
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
||||
results = {
|
||||
"internet_accessible": False,
|
||||
"api_accessible": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get("https://www.google.com") as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
if not results["internet_accessible"]:
|
||||
return results
|
||||
|
||||
parsed = urlparse(default_base_url())
|
||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get(health_url) as resp:
|
||||
results["api_accessible"] = resp.status < 500
|
||||
return results
|
||||
|
||||
|
||||
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
"""Normalize (filename, value, content_type)."""
|
||||
if len(t) == 2:
|
||||
return t[0], t[1], "application/octet-stream"
|
||||
if len(t) == 3:
|
||||
return t[0], t[1], t[2]
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
if v is not None:
|
||||
params[k] = v
|
||||
return params
|
||||
|
||||
|
||||
def _friendly_http_message(status: int, body: Any) -> str:
|
||||
if status == 401:
|
||||
return "Unauthorized: Please login first to use this node."
|
||||
if status == 402:
|
||||
return "Payment Required: Please add credits to your account to use this node."
|
||||
if status == 409:
|
||||
return "There is a problem with your account. Please contact support@comfy.org."
|
||||
if status == 429:
|
||||
return "Rate Limit Exceeded: Please try again later."
|
||||
try:
|
||||
if isinstance(body, dict):
|
||||
err = body.get("error")
|
||||
if isinstance(err, dict):
|
||||
msg = err.get("message")
|
||||
typ = err.get("type")
|
||||
if msg and typ:
|
||||
return f"API Error: {msg} (Type: {typ})"
|
||||
if msg:
|
||||
return f"API Error: {msg}"
|
||||
return f"API Error: {json.dumps(body)}"
|
||||
else:
|
||||
txt = str(body)
|
||||
if len(txt) <= 200:
|
||||
return f"API Error (raw): {txt}"
|
||||
return f"API Error (status {status})"
|
||||
except Exception:
|
||||
return f"HTTP {status}: Unknown error"
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||||
slug = path.strip("/").replace("/", "_") or "op"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def _snapshot_request_body_for_logging(
|
||||
content_type: str,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]],
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||||
) -> Optional[Union[dict[str, Any], str]]:
|
||||
if method.upper() == "GET":
|
||||
return None
|
||||
if content_type == "multipart/form-data":
|
||||
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
|
||||
file_fields: list[dict[str, str]] = []
|
||||
if files:
|
||||
file_iter = files if isinstance(files, list) else list(files.items())
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
if isinstance(file_obj, tuple):
|
||||
filename = file_obj[0]
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||||
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||||
if content_type == "application/x-www-form-urlencoded":
|
||||
return data or {}
|
||||
return data or {}
|
||||
|
||||
|
||||
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||||
url = cfg.endpoint.path
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
|
||||
method = cfg.endpoint.method
|
||||
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||
|
||||
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||
"""Every second: update elapsed time and signal interruption."""
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||||
)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return # normal shutdown
|
||||
|
||||
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: Optional[int] = None
|
||||
while True:
|
||||
attempt += 1
|
||||
stop_event = asyncio.Event()
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||
if method == "GET":
|
||||
payload_headers.pop("Content-Type", None)
|
||||
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||||
try:
|
||||
if cfg.monitor_progress:
|
||||
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
|
||||
if cfg.content_type == "multipart/form-data" and method != "GET":
|
||||
# aiohttp will set Content-Type boundary; remove any fixed Content-Type
|
||||
payload_headers.pop("Content-Type", None)
|
||||
if cfg.multipart_parser and cfg.data:
|
||||
form = cfg.multipart_parser(cfg.data)
|
||||
if not isinstance(form, aiohttp.FormData):
|
||||
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||||
else:
|
||||
form = aiohttp.FormData(default_to_multipart=True)
|
||||
if cfg.data:
|
||||
for k, v in cfg.data.items():
|
||||
if v is None:
|
||||
continue
|
||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||
if cfg.files:
|
||||
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||
for field_name, file_obj in file_iter:
|
||||
if file_obj is None:
|
||||
continue
|
||||
if isinstance(file_obj, tuple):
|
||||
filename, file_value, content_type = _unpack_tuple(file_obj)
|
||||
else:
|
||||
filename = getattr(file_obj, "name", field_name)
|
||||
file_value = file_obj
|
||||
content_type = "application/octet-stream"
|
||||
# Attempt to rewind BytesIO for retries
|
||||
if isinstance(file_value, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file_value.seek(0)
|
||||
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||||
payload_kw["data"] = form
|
||||
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||||
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||
payload_kw["data"] = cfg.data or {}
|
||||
elif method != "GET":
|
||||
payload_headers["Content-Type"] = "application/json"
|
||||
payload_kw["json"] = cfg.data or {}
|
||||
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||
|
||||
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||
req_task = asyncio.create_task(req_coro)
|
||||
|
||||
# Race: request vs. monitor (interruption)
|
||||
tasks = {req_task}
|
||||
if monitor_task:
|
||||
tasks.add(monitor_task)
|
||||
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task and monitor_task in done:
|
||||
# Interrupted – cancel the request and abort
|
||||
if req_task in pending:
|
||||
req_task.cancel()
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
# Otherwise, request finished
|
||||
resp = await req_task
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
body = await resp.text()
|
||||
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||
method,
|
||||
url,
|
||||
resp.status,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=_friendly_http_message(resp.status, body),
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
msg = _friendly_http_message(resp.status, body)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
raise Exception(msg)
|
||||
|
||||
if expect_binary:
|
||||
buff = bytearray()
|
||||
last_tick = time.monotonic()
|
||||
async for chunk in resp.content.iter_chunked(64 * 1024):
|
||||
buff.extend(chunk)
|
||||
now = time.monotonic()
|
||||
if now - last_tick >= 1.0:
|
||||
last_tick = now
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
if cfg.monitor_progress:
|
||||
_display_time_progress(
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return bytes_payload
|
||||
else:
|
||||
try:
|
||||
payload = await resp.json()
|
||||
response_content_to_log: Any = payload
|
||||
except (ContentTypeError, json.JSONDecodeError):
|
||||
text = await resp.text()
|
||||
try:
|
||||
payload = json.loads(text) if text else {}
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=response_content_to_log,
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||
return payload
|
||||
|
||||
except ProcessingInterrupted:
|
||||
logging.debug("Polling was interrupted by user")
|
||||
raise
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= cfg.max_retries:
|
||||
logging.warning(
|
||||
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||
method,
|
||||
url,
|
||||
delay,
|
||||
attempt,
|
||||
cfg.max_retries,
|
||||
str(e),
|
||||
)
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cfg.node_cls,
|
||||
cfg.wait_label if cfg.monitor_progress else None,
|
||||
start_time if cfg.monitor_progress else None,
|
||||
cfg.estimated_total,
|
||||
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||
)
|
||||
delay *= cfg.retry_backoff
|
||||
continue
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"LocalNetworkError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the API server due to local network issues. "
|
||||
"Please check your internet connection and try again."
|
||||
) from e
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
request_headers=dict(payload_headers) if payload_headers else None,
|
||||
request_params=dict(params) if params else None,
|
||||
request_data=request_body_log,
|
||||
error_message=f"ApiServerError: {str(e)}",
|
||||
)
|
||||
except Exception as _log_e:
|
||||
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||
raise ApiServerError(
|
||||
f"The API server at {default_base_url()} is currently unreachable. "
|
||||
f"The service may be experiencing issues."
|
||||
) from e
|
||||
finally:
|
||||
stop_event.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||||
_display_time_progress(
|
||||
cfg.node_cls,
|
||||
status=cfg.final_label_on_success,
|
||||
elapsed_seconds=(
|
||||
final_elapsed_seconds
|
||||
if final_elapsed_seconds is not None
|
||||
else int(time.monotonic() - start_time)
|
||||
),
|
||||
estimated_total=cfg.estimated_total,
|
||||
price=None,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=final_elapsed_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
try:
|
||||
return response_model.model_validate(payload)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
"Response validation failed for %s: %s",
|
||||
getattr(response_model, "__name__", response_model),
|
||||
e,
|
||||
)
|
||||
raise Exception(
|
||||
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def _wrap_model_extractor(
|
||||
response_model: Type[M],
|
||||
extractor: Optional[Callable[[M], Any]],
|
||||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||
Validates the dict into `response_model` before invoking `extractor`.
|
||||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||
the same response for multiple extractors in a single poll attempt.
|
||||
"""
|
||||
if extractor is None:
|
||||
return None
|
||||
_cache: dict[int, M] = {}
|
||||
|
||||
def _wrapped(d: dict[str, Any]) -> Any:
|
||||
try:
|
||||
key = id(d)
|
||||
model = _cache.get(key)
|
||||
if model is None:
|
||||
model = response_model.model_validate(d)
|
||||
_cache[key] = model
|
||||
return extractor(model)
|
||||
except Exception as e:
|
||||
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
|
||||
raise
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
||||
if not values:
|
||||
return set()
|
||||
out: set[Union[str, int]] = set()
|
||||
for v in values:
|
||||
nv = _normalize_status_value(v)
|
||||
if nv is not None:
|
||||
out.add(nv)
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||||
if isinstance(val, str):
|
||||
return val.strip().lower()
|
||||
return val
|
||||
14
comfy_api_nodes/util/common_exceptions.py
Normal file
14
comfy_api_nodes/util/common_exceptions.py
Normal file
@ -0,0 +1,14 @@
|
||||
class NetworkError(Exception):
|
||||
"""Base exception for network-related errors with diagnostic information."""
|
||||
|
||||
|
||||
class LocalNetworkError(NetworkError):
|
||||
"""Exception raised when local network connectivity issues are detected."""
|
||||
|
||||
|
||||
class ApiServerError(NetworkError):
|
||||
"""Exception raised when the API server is unreachable but internet is working."""
|
||||
|
||||
|
||||
class ProcessingInterrupted(Exception):
|
||||
"""Operation was interrupted by user/runtime via processing_interrupted()."""
|
||||
470
comfy_api_nodes/util/conversions.py
Normal file
470
comfy_api_nodes/util/conversions.py
Normal file
@ -0,0 +1,470 @@
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import Input, InputImpl
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
|
||||
from ._helpers import mimetype_to_extension
|
||||
|
||||
|
||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||
"""Converts image data from BytesIO to a torch.Tensor.
|
||||
|
||||
Args:
|
||||
image_bytesio: BytesIO object containing the image data.
|
||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
|
||||
Raises:
|
||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||
ValueError: If the specified mode is invalid.
|
||||
"""
|
||||
image = Image.open(image_bytesio)
|
||||
image = image.convert(mode)
|
||||
image_array = np.array(image).astype(np.float32) / 255.0
|
||||
return torch.from_numpy(image_array).unsqueeze(0)
|
||||
|
||||
|
||||
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Converts a pair of image tensors to a batch tensor.
|
||||
If the images are not the same size, the smaller image is resized to
|
||||
match the larger image.
|
||||
"""
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = common_upscale(
|
||||
image2.movedim(-1, 1),
|
||||
image1.shape[2],
|
||||
image1.shape[1],
|
||||
"bilinear",
|
||||
"center",
|
||||
).movedim(1, -1)
|
||||
return torch.cat((image1, image2), dim=0)
|
||||
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||
|
||||
Args:
|
||||
image: Input torch.Tensor image.
|
||||
name: Optional filename for the BytesIO object.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||
"""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
|
||||
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||
return img_binary
|
||||
|
||||
|
||||
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||
if len(image.shape) > 3:
|
||||
image = image[0]
|
||||
# TODO: remove alpha if not allowed and present
|
||||
input_tensor = image.cpu()
|
||||
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
|
||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||
img = Image.fromarray(image_np)
|
||||
return img
|
||||
|
||||
|
||||
def tensor_to_base64_string(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||
|
||||
Returns:
|
||||
Base64 encoded string of the image.
|
||||
"""
|
||||
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||
img_bytes = img_byte_arr.getvalue()
|
||||
# Encode bytes to base64 string
|
||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||
return base64_encoded_string
|
||||
|
||||
|
||||
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||
"""Converts a PIL Image to a BytesIO object."""
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
|
||||
img_byte_arr = BytesIO()
|
||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||
pil_format = mime_type.split("/")[-1].upper()
|
||||
if pil_format == "JPG":
|
||||
pil_format = "JPEG"
|
||||
img.save(img_byte_arr, format=pil_format)
|
||||
img_byte_arr.seek(0)
|
||||
return img_byte_arr
|
||||
|
||||
|
||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(total_pixels)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
if scale_by >= 1:
|
||||
return image
|
||||
width = round(samples.shape[3] * scale_by)
|
||||
height = round(samples.shape[2] * scale_by)
|
||||
|
||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||
s = s.movedim(1, -1)
|
||||
return s
|
||||
|
||||
|
||||
def tensor_to_data_uri(
|
||||
image_tensor: torch.Tensor,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> str:
|
||||
"""Converts a tensor image to a Data URI string.
|
||||
|
||||
Args:
|
||||
image_tensor: Input torch.Tensor image.
|
||||
total_pixels: Maximum total pixels for potential downscaling.
|
||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||
|
||||
Returns:
|
||||
Data URI string (e.g., 'data:image/png;base64,...').
|
||||
"""
|
||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
|
||||
|
||||
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
|
||||
"""Converts an audio input to a base64 string."""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
audio_bytes = audio_bytes_io.getvalue()
|
||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
|
||||
def video_to_base64_string(
|
||||
video: Input.Video,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
|
||||
Args:
|
||||
video: The video input to convert
|
||||
container_format: Optional container format to use (defaults to video.container if available)
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def audio_ndarray_to_bytesio(
|
||||
audio_data_np: np.ndarray,
|
||||
sample_rate: int,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
) -> BytesIO:
|
||||
"""
|
||||
Encodes a numpy array of audio data into a BytesIO object.
|
||||
"""
|
||||
audio_bytes_io = BytesIO()
|
||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
audio_data_np,
|
||||
format="fltp",
|
||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
|
||||
# Flush stream
|
||||
for packet in audio_stream.encode(None):
|
||||
output_container.mux(packet)
|
||||
|
||||
audio_bytes_io.seek(0)
|
||||
return audio_bytes_io
|
||||
|
||||
|
||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||
"""
|
||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||
|
||||
Args:
|
||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||
|
||||
Returns:
|
||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||
the first item is taken.
|
||||
"""
|
||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||
|
||||
# If batch is > 1, take first item
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform[0]
|
||||
|
||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||
if audio_data_np.dtype != np.float32:
|
||||
audio_data_np = audio_data_np.astype(np.float32)
|
||||
|
||||
return audio_data_np
|
||||
|
||||
|
||||
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
|
||||
waveform = audio["waveform"].cpu()
|
||||
|
||||
output_buffer = BytesIO()
|
||||
output_container = av.open(output_buffer, mode="w", format="mp3")
|
||||
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||
out_stream.bit_rate = 320000
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
)
|
||||
frame.sample_rate = audio["sample_rate"]
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
output_container.mux(out_stream.encode(None))
|
||||
output_container.close()
|
||||
output_buffer.seek(0)
|
||||
return output_buffer
|
||||
|
||||
|
||||
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||
"""
|
||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||
using av to avoid loading entire video into memory.
|
||||
|
||||
Args:
|
||||
video: Input video to trim
|
||||
duration_sec: Duration in seconds to keep from the beginning
|
||||
|
||||
Returns:
|
||||
VideoFromFile object that owns the output buffer
|
||||
"""
|
||||
output_buffer = BytesIO()
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
try:
|
||||
# Get the stream source - this avoids loading entire video into memory
|
||||
# when the source is already a file path
|
||||
input_source = video.get_stream_source()
|
||||
|
||||
# Open containers
|
||||
input_container = av.open(input_source, mode="r")
|
||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||
|
||||
# Set up output streams for re-encoding
|
||||
video_stream = None
|
||||
audio_stream = None
|
||||
|
||||
for stream in input_container.streams:
|
||||
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||
if isinstance(stream, av.VideoStream):
|
||||
# Create output video stream with same parameters
|
||||
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
|
||||
video_stream.width = stream.width
|
||||
video_stream.height = stream.height
|
||||
video_stream.pix_fmt = "yuv420p"
|
||||
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
|
||||
elif isinstance(stream, av.AudioStream):
|
||||
# Create output audio stream with same parameters
|
||||
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
|
||||
audio_stream.sample_rate = stream.sample_rate
|
||||
audio_stream.layout = stream.layout
|
||||
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||
|
||||
# Calculate target frame count that's divisible by 16
|
||||
fps = input_container.streams.video[0].average_rate
|
||||
estimated_frames = int(duration_sec * fps)
|
||||
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||
|
||||
if target_frames == 0:
|
||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||
|
||||
frame_count = 0
|
||||
audio_frame_count = 0
|
||||
|
||||
# Decode and re-encode video frames
|
||||
if video_stream:
|
||||
for frame in input_container.decode(video=0):
|
||||
if frame_count >= target_frames:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||
|
||||
# Decode and re-encode audio frames
|
||||
if audio_stream:
|
||||
input_container.seek(0) # Reset to beginning for audio
|
||||
for frame in input_container.decode(audio=0):
|
||||
if frame.time >= duration_sec:
|
||||
break
|
||||
|
||||
# Re-encode frame
|
||||
for packet in audio_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
audio_frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
for packet in audio_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
input_container.close()
|
||||
|
||||
# Return as VideoFromFile using the buffer
|
||||
output_buffer.seek(0)
|
||||
return InputImpl.VideoFromFile(output_buffer)
|
||||
|
||||
except Exception as e:
|
||||
# Clean up on error
|
||||
if input_container is not None:
|
||||
input_container.close()
|
||||
if output_container is not None:
|
||||
output_container.close()
|
||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||
|
||||
|
||||
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||
if wav.dtype.is_floating_point:
|
||||
return wav
|
||||
elif wav.dtype == torch.int16:
|
||||
return wav.float() / (2**15)
|
||||
elif wav.dtype == torch.int32:
|
||||
return wav.float() / (2**31)
|
||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||
|
||||
|
||||
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
|
||||
"""
|
||||
Decode any common audio container from bytes using PyAV and return
|
||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||
"""
|
||||
with av.open(BytesIO(audio_bytes)) as af:
|
||||
if not af.streams.audio:
|
||||
raise ValueError("No audio stream found in response.")
|
||||
stream = af.streams.audio[0]
|
||||
|
||||
in_sr = int(stream.codec_context.sample_rate)
|
||||
out_sr = in_sr
|
||||
|
||||
frames: list[torch.Tensor] = []
|
||||
n_channels = stream.channels or 1
|
||||
|
||||
for frame in af.decode(streams=stream.index):
|
||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||
buf = torch.from_numpy(arr)
|
||||
if buf.ndim == 1:
|
||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||
elif buf.shape[0] != n_channels:
|
||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||
frames.append(buf)
|
||||
|
||||
if not frames:
|
||||
raise ValueError("Decoded zero audio frames.")
|
||||
|
||||
wav = torch.cat(frames, dim=1) # [C, T]
|
||||
wav = _f32_pcm(wav)
|
||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||
|
||||
|
||||
def resize_mask_to_image(
|
||||
mask: torch.Tensor,
|
||||
image: torch.Tensor,
|
||||
upscale_method="nearest-exact",
|
||||
crop="disabled",
|
||||
allow_gradient=True,
|
||||
add_channel_dim=False,
|
||||
):
|
||||
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
|
||||
_, height, width, _ = image.shape
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = mask.movedim(-1, 1)
|
||||
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
|
||||
mask = mask.movedim(1, -1)
|
||||
if not add_channel_dim:
|
||||
mask = mask.squeeze(-1)
|
||||
if not allow_gradient:
|
||||
mask = (mask > 0.5).float()
|
||||
return mask
|
||||
|
||||
|
||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||
"""Converts a text file to a base64 string."""
|
||||
with open(filepath, "rb") as f:
|
||||
file_content = f.read()
|
||||
return base64.b64encode(file_content).decode("utf-8")
|
||||
|
||||
|
||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||
"""Converts a text file to a data URI."""
|
||||
base64_string = text_filepath_to_base64_string(filepath)
|
||||
mime_type, _ = mimetypes.guess_type(filepath)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
return f"data:{mime_type};base64,{base64_string}"
|
||||
262
comfy_api_nodes/util/download_helpers.py
Normal file
262
comfy_api_nodes/util/download_helpers.py
Normal file
@ -0,0 +1,262 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO, Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
from comfy_api_nodes.apis import request_logger
|
||||
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
)
|
||||
from .client import _diagnose_connectivity
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
from .conversions import bytesio_to_image_tensor
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str,
|
||||
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> None:
|
||||
"""Stream-download a URL to `dest`.
|
||||
|
||||
`dest` must be one of:
|
||||
- a BytesIO (rewound to 0 after write),
|
||||
- a file-like object opened in binary write mode (must implement .write()),
|
||||
- a filesystem path (str | pathlib.Path), which will be opened with 'wb'.
|
||||
|
||||
If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded
|
||||
to an absolute URL and authentication headers can be applied.
|
||||
|
||||
Raises:
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
|
||||
"""
|
||||
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
|
||||
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
headers: dict[str, str] = {}
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_auth_header(cls)
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
op_id = _generate_operation_id("GET", url, attempt)
|
||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout)
|
||||
|
||||
is_path_sink = isinstance(dest, (str, Path))
|
||||
fhandle = None
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
stop_evt: Optional[asyncio.Event] = None
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
req_task: Optional[asyncio.Task] = None
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
|
||||
|
||||
session = aiohttp.ClientSession(timeout=timeout_cfg)
|
||||
stop_evt = asyncio.Event()
|
||||
|
||||
async def _monitor():
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
|
||||
req_task = asyncio.create_task(session.get(url, headers=headers))
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
req_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await req_task
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
try:
|
||||
resp = await req_task
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (ContentTypeError, ValueError):
|
||||
text = await resp.text()
|
||||
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=f"HTTP {resp.status}",
|
||||
)
|
||||
|
||||
if resp.status in _RETRY_STATUS and attempt <= max_retries:
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to download (HTTP {resp.status}).")
|
||||
|
||||
if is_path_sink:
|
||||
p = Path(str(dest))
|
||||
with contextlib.suppress(Exception):
|
||||
p.parent.mkdir(parents=True, exist_ok=True)
|
||||
fhandle = open(p, "wb")
|
||||
sink = fhandle
|
||||
else:
|
||||
sink = dest # BytesIO or file-like
|
||||
|
||||
written = 0
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
chunk = b""
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
|
||||
if is_processing_interrupted():
|
||||
raise ProcessingInterrupted("Task cancelled")
|
||||
|
||||
if not chunk:
|
||||
if resp.content.at_eof():
|
||||
break
|
||||
continue
|
||||
|
||||
sink.write(chunk)
|
||||
written += len(chunk)
|
||||
|
||||
if isinstance(dest, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
dest.seek(0)
|
||||
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=f"[streamed {written} bytes to dest]",
|
||||
)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=op_id,
|
||||
request_method="GET",
|
||||
request_url=url,
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the network. Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError("The remote service appears unreachable at this time.") from e
|
||||
finally:
|
||||
if stop_evt is not None:
|
||||
stop_evt.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if req_task and not req_task.done():
|
||||
req_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await req_task
|
||||
if session:
|
||||
with contextlib.suppress(Exception):
|
||||
await session.close()
|
||||
if fhandle:
|
||||
with contextlib.suppress(Exception):
|
||||
fhandle.flush()
|
||||
fhandle.close()
|
||||
|
||||
|
||||
async def download_url_to_image_tensor(
|
||||
url: str,
|
||||
*,
|
||||
timeout: float = None,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
|
||||
return bytesio_to_image_tensor(result)
|
||||
|
||||
|
||||
async def download_url_to_video_output(
|
||||
video_url: str,
|
||||
*,
|
||||
timeout: float = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||
return VideoFromFile(result)
|
||||
|
||||
|
||||
async def download_url_as_bytesio(
|
||||
url: str,
|
||||
*,
|
||||
timeout: float = None,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> BytesIO:
|
||||
"""Downloads content from a URL and returns a new BytesIO (rewound to 0)."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
|
||||
return result
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
|
||||
except Exception:
|
||||
slug = "download"
|
||||
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||
338
comfy_api_nodes/util/upload_helpers.py
Normal file
338
comfy_api_nodes/util/upload_helpers.py
Normal file
@ -0,0 +1,338 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import IO, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api_nodes.apis import request_logger
|
||||
|
||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||
from .client import (
|
||||
ApiEndpoint,
|
||||
_diagnose_connectivity,
|
||||
_display_time_progress,
|
||||
sync_op,
|
||||
)
|
||||
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||
from .conversions import (
|
||||
audio_ndarray_to_bytesio,
|
||||
audio_tensor_to_contiguous_ndarray,
|
||||
tensor_to_bytesio,
|
||||
)
|
||||
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
|
||||
|
||||
class UploadResponse(BaseModel):
|
||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||
|
||||
|
||||
async def upload_images_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
max_images: int = 8,
|
||||
mime_type: Optional[str] = None,
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
|
||||
async def upload_audio_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
audio: Input.Audio,
|
||||
*,
|
||||
container_format: str = "mp4",
|
||||
codec_name: str = "aac",
|
||||
mime_type: str = "audio/mp4",
|
||||
filename: str = "uploaded_audio.mp4",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||
Encodes the raw waveform into the specified format before uploading.
|
||||
"""
|
||||
sample_rate: int = audio["sample_rate"]
|
||||
waveform: torch.Tensor = audio["waveform"]
|
||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
|
||||
|
||||
|
||||
async def upload_video_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
Uses the specified container and codec for saving the video before upload.
|
||||
"""
|
||||
if max_duration is not None:
|
||||
try:
|
||||
actual_duration = video.get_duration()
|
||||
if actual_duration > max_duration:
|
||||
raise ValueError(
|
||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error("Error getting video duration: %s", str(e))
|
||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||
|
||||
upload_mime_type = f"video/{container.value.lower()}"
|
||||
filename = f"uploaded_video.{container.value.lower()}"
|
||||
|
||||
# Convert VideoInput to BytesIO using specified container/codec
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
) -> str:
|
||||
"""Uploads a single file to ComfyUI API and returns its download URL."""
|
||||
if upload_mime_type is None:
|
||||
request_object = UploadRequest(file_name=filename)
|
||||
else:
|
||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
|
||||
data=request_object,
|
||||
response_model=UploadResponse,
|
||||
final_label_on_success=None,
|
||||
monitor_progress=False,
|
||||
)
|
||||
await upload_file(
|
||||
cls,
|
||||
create_resp.upload_url,
|
||||
file_bytes_io,
|
||||
content_type=upload_mime_type,
|
||||
wait_label=wait_label,
|
||||
)
|
||||
return create_resp.download_url
|
||||
|
||||
|
||||
async def upload_file(
|
||||
cls: type[IO.ComfyNode],
|
||||
upload_url: str,
|
||||
file: Union[BytesIO, str],
|
||||
*,
|
||||
content_type: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
|
||||
|
||||
Args:
|
||||
cls: Node class (provides auth context + UI progress hooks).
|
||||
upload_url: Pre-signed PUT URL.
|
||||
file: BytesIO or path string.
|
||||
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
|
||||
max_retries: Maximum retry attempts.
|
||||
retry_delay: Initial delay in seconds.
|
||||
retry_backoff: Exponential backoff factor.
|
||||
wait_label: Progress label shown in Comfy UI.
|
||||
|
||||
Raises:
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
|
||||
"""
|
||||
if isinstance(file, BytesIO):
|
||||
with contextlib.suppress(Exception):
|
||||
file.seek(0)
|
||||
data = file.read()
|
||||
elif isinstance(file, str):
|
||||
with open(file, "rb") as f:
|
||||
data = f.read()
|
||||
else:
|
||||
raise ValueError("file must be a BytesIO or a filesystem path string")
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
skip_auto_headers: set[str] = set()
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
else:
|
||||
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
start_ts = time.monotonic()
|
||||
op_uuid = uuid.uuid4().hex[:8]
|
||||
while True:
|
||||
attempt += 1
|
||||
operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
|
||||
timeout = aiohttp.ClientTimeout(total=None)
|
||||
stop_evt = asyncio.Event()
|
||||
|
||||
async def _monitor():
|
||||
try:
|
||||
while not stop_evt.is_set():
|
||||
if is_processing_interrupted():
|
||||
return
|
||||
if wait_label:
|
||||
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
|
||||
await asyncio.sleep(1.0)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_params=None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||
|
||||
sess = aiohttp.ClientSession(timeout=timeout)
|
||||
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||
req_task = asyncio.create_task(req)
|
||||
|
||||
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if monitor_task in done and req_task in pending:
|
||||
req_task.cancel()
|
||||
raise ProcessingInterrupted("Upload cancelled")
|
||||
|
||||
try:
|
||||
resp = await req_task
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Upload cancelled") from None
|
||||
|
||||
async with resp:
|
||||
if resp.status >= 400:
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
body = await resp.json()
|
||||
except Exception:
|
||||
body = await resp.text()
|
||||
msg = f"Upload failed with status {resp.status}"
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content=body,
|
||||
error_message=msg,
|
||||
)
|
||||
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
wait_label,
|
||||
start_ts,
|
||||
None,
|
||||
display_callback=_display_time_progress if wait_label else None,
|
||||
)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_content="File uploaded successfully.",
|
||||
)
|
||||
except Exception as e:
|
||||
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||
return
|
||||
except asyncio.CancelledError:
|
||||
raise ProcessingInterrupted("Task cancelled") from None
|
||||
except (aiohttp.ClientError, OSError) as e:
|
||||
if attempt <= max_retries:
|
||||
with contextlib.suppress(Exception):
|
||||
request_logger.log_request_response(
|
||||
operation_id=operation_id,
|
||||
request_method="PUT",
|
||||
request_url=upload_url,
|
||||
request_headers=headers or None,
|
||||
request_data=f"[File data {len(data)} bytes]",
|
||||
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||
)
|
||||
await sleep_with_interrupt(
|
||||
delay,
|
||||
cls,
|
||||
wait_label,
|
||||
start_ts,
|
||||
None,
|
||||
display_callback=_display_time_progress if wait_label else None,
|
||||
)
|
||||
delay *= retry_backoff
|
||||
continue
|
||||
|
||||
diag = await _diagnose_connectivity()
|
||||
if not diag["internet_accessible"]:
|
||||
raise LocalNetworkError(
|
||||
"Unable to connect to the network. Please check your internet connection and try again."
|
||||
) from e
|
||||
raise ApiServerError("The API service appears unreachable at this time.") from e
|
||||
finally:
|
||||
stop_evt.set()
|
||||
if monitor_task:
|
||||
monitor_task.cancel()
|
||||
with contextlib.suppress(Exception):
|
||||
await monitor_task
|
||||
if sess:
|
||||
with contextlib.suppress(Exception):
|
||||
await sess.close()
|
||||
|
||||
|
||||
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
|
||||
except Exception:
|
||||
slug = "upload"
|
||||
return f"{method}_{slug}_{op_uuid}_try{attempt}"
|
||||
@ -2,6 +2,8 @@ import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
@ -28,76 +30,69 @@ def validate_image_dimensions(
|
||||
if max_width is not None and width > max_width:
|
||||
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
|
||||
if min_height is not None and height < min_height:
|
||||
raise ValueError(
|
||||
f"Image height must be at least {min_height}px, got {height}px"
|
||||
)
|
||||
raise ValueError(f"Image height must be at least {min_height}px, got {height}px")
|
||||
if max_height is not None and height > max_height:
|
||||
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
|
||||
|
||||
|
||||
def validate_image_aspect_ratio(
|
||||
image: torch.Tensor,
|
||||
min_aspect_ratio: Optional[float] = None,
|
||||
max_aspect_ratio: Optional[float] = None,
|
||||
):
|
||||
width, height = get_image_dimensions(image)
|
||||
aspect_ratio = width / height
|
||||
|
||||
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
|
||||
raise ValueError(
|
||||
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
|
||||
)
|
||||
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
|
||||
raise ValueError(
|
||||
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
|
||||
)
|
||||
|
||||
|
||||
def validate_image_aspect_ratio_range(
|
||||
image: torch.Tensor,
|
||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
a1, b1 = min_ratio
|
||||
a2, b2 = max_ratio
|
||||
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
|
||||
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
|
||||
lo, hi = (a1 / b1), (a2 / b2)
|
||||
if lo > hi:
|
||||
lo, hi = hi, lo
|
||||
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
|
||||
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
|
||||
w, h = get_image_dimensions(image)
|
||||
if w <= 0 or h <= 0:
|
||||
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
||||
ar = w / h
|
||||
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
|
||||
if not ok:
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
|
||||
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||
return ar
|
||||
|
||||
|
||||
def validate_aspect_ratio_closeness(
|
||||
start_img,
|
||||
end_img,
|
||||
min_rel: float,
|
||||
max_rel: float,
|
||||
def validate_images_aspect_ratio_closeness(
|
||||
first_image: torch.Tensor,
|
||||
second_image: torch.Tensor,
|
||||
min_rel: float, # e.g. 0.8
|
||||
max_rel: float, # e.g. 1.25
|
||||
*,
|
||||
strict: bool = False, # True => exclusive, False => inclusive
|
||||
) -> None:
|
||||
w1, h1 = get_image_dimensions(start_img)
|
||||
w2, h2 = get_image_dimensions(end_img)
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
"""
|
||||
Validates that the two images' aspect ratios are 'close'.
|
||||
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
|
||||
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
|
||||
|
||||
Returns the computed closeness factor C.
|
||||
"""
|
||||
w1, h1 = get_image_dimensions(first_image)
|
||||
w2, h2 = get_image_dimensions(second_image)
|
||||
if min(w1, h1, w2, h2) <= 0:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
ar1 = w1 / h1
|
||||
ar2 = w2 / h2
|
||||
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
|
||||
closeness = max(ar1, ar2) / min(ar1, ar2)
|
||||
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
|
||||
limit = max(max_rel, 1.0 / min_rel)
|
||||
if (closeness >= limit) if strict else (closeness > limit):
|
||||
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
||||
raise ValueError(
|
||||
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
|
||||
f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
|
||||
)
|
||||
return closeness
|
||||
|
||||
|
||||
def validate_aspect_ratio_string(
|
||||
aspect_ratio: str,
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
|
||||
ar = _parse_aspect_ratio_string(aspect_ratio)
|
||||
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||
return ar
|
||||
|
||||
|
||||
def validate_video_dimensions(
|
||||
@ -118,9 +113,7 @@ def validate_video_dimensions(
|
||||
if max_width is not None and width > max_width:
|
||||
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
|
||||
if min_height is not None and height < min_height:
|
||||
raise ValueError(
|
||||
f"Video height must be at least {min_height}px, got {height}px"
|
||||
)
|
||||
raise ValueError(f"Video height must be at least {min_height}px, got {height}px")
|
||||
if max_height is not None and height > max_height:
|
||||
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
|
||||
|
||||
@ -138,13 +131,9 @@ def validate_video_duration(
|
||||
|
||||
epsilon = 0.0001
|
||||
if min_duration is not None and min_duration - epsilon > duration:
|
||||
raise ValueError(
|
||||
f"Video duration must be at least {min_duration}s, got {duration}s"
|
||||
)
|
||||
raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
|
||||
if max_duration is not None and duration > max_duration + epsilon:
|
||||
raise ValueError(
|
||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
||||
)
|
||||
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
@ -165,3 +154,77 @@ def validate_audio_duration(
|
||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||
if max_duration is not None and dur - eps > max_duration:
|
||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||
|
||||
|
||||
def validate_string(
|
||||
string: str,
|
||||
strip_whitespace=True,
|
||||
field_name="prompt",
|
||||
min_length=None,
|
||||
max_length=None,
|
||||
):
|
||||
if string is None:
|
||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||
if strip_whitespace:
|
||||
string = string.strip()
|
||||
if min_length and len(string) < min_length:
|
||||
raise Exception(
|
||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
if max_length and len(string) > max_length:
|
||||
raise Exception(
|
||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||
)
|
||||
|
||||
|
||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||
|
||||
|
||||
def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||||
a, b = r
|
||||
if a <= 0 or b <= 0:
|
||||
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
|
||||
return a / b
|
||||
|
||||
|
||||
def _assert_ratio_bounds(
|
||||
ar: float,
|
||||
*,
|
||||
min_ratio: Optional[tuple[float, float]] = None,
|
||||
max_ratio: Optional[tuple[float, float]] = None,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
|
||||
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
|
||||
|
||||
if lo is not None and hi is not None and lo > hi:
|
||||
lo, hi = hi, lo # normalize order if caller swapped them
|
||||
|
||||
if lo is not None:
|
||||
if (ar <= lo) if strict else (ar < lo):
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
|
||||
if hi is not None:
|
||||
if (ar >= hi) if strict else (ar > hi):
|
||||
op = "<" if strict else "≤"
|
||||
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
|
||||
|
||||
|
||||
def _parse_aspect_ratio_string(ar_str: str) -> float:
|
||||
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
|
||||
parts = ar_str.split(":")
|
||||
if len(parts) != 2:
|
||||
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
|
||||
try:
|
||||
a = int(parts[0].strip())
|
||||
b = int(parts[1].strip())
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
|
||||
if a <= 0 or b <= 0:
|
||||
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
|
||||
return a / b
|
||||
|
||||
@ -1,4 +1,9 @@
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
import psutil
|
||||
import time
|
||||
import torch
|
||||
from typing import Sequence, Mapping, Dict
|
||||
from comfy_execution.graph import DynamicPrompt
|
||||
from abc import ABC, abstractmethod
|
||||
@ -48,7 +53,7 @@ class Unhashable:
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
@ -188,6 +193,9 @@ class BasicCache:
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
@ -265,6 +273,29 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
return await cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
class NullCache:
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
pass
|
||||
|
||||
def all_node_ids(self):
|
||||
return []
|
||||
|
||||
def clean_unused(self):
|
||||
pass
|
||||
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
return None
|
||||
|
||||
def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
@ -318,155 +349,75 @@ class LRUCache(BasicCache):
|
||||
return self
|
||||
|
||||
|
||||
class DependencyAwareCache(BasicCache):
|
||||
"""
|
||||
A cache implementation that tracks dependencies between nodes and manages
|
||||
their execution and caching accordingly. It extends the BasicCache class.
|
||||
Nodes are removed from this cache once all of their descendants have been
|
||||
executed.
|
||||
"""
|
||||
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
|
||||
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
|
||||
|
||||
RAM_CACHE_HYSTERESIS = 1.1
|
||||
|
||||
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
|
||||
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
|
||||
|
||||
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
||||
|
||||
#Exponential bias towards evicting older workflows so garbage will be taken out
|
||||
#in constantly changing setups.
|
||||
|
||||
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
"""
|
||||
Initialize the DependencyAwareCache.
|
||||
|
||||
Args:
|
||||
key_class: The class used for generating cache keys.
|
||||
"""
|
||||
super().__init__(key_class)
|
||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
"""
|
||||
Clear the entire cache and rebuild the dependency graph.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to initialize the cache for.
|
||||
is_changed_cache: Flag indicating if the cache has changed.
|
||||
"""
|
||||
# Clear all existing cache data
|
||||
self.cache.clear()
|
||||
self.subcaches.clear()
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
self.executed_nodes.clear()
|
||||
|
||||
# Call the parent method to initialize the cache with the new prompt
|
||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
# Rebuild the dependency graph
|
||||
self._build_dependency_graph(dynprompt, node_ids)
|
||||
|
||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
||||
"""
|
||||
Build the dependency graph for all nodes.
|
||||
|
||||
Args:
|
||||
dynprompt: The dynamic prompt object containing node information.
|
||||
node_ids: List of node IDs to build the graph for.
|
||||
"""
|
||||
self.descendants.clear()
|
||||
self.ancestors.clear()
|
||||
for node_id in node_ids:
|
||||
self.descendants[node_id] = set()
|
||||
self.ancestors[node_id] = set()
|
||||
|
||||
for node_id in node_ids:
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
for input_data in inputs.values():
|
||||
if is_link(input_data): # Check if the input is a link to another node
|
||||
ancestor_id = input_data[0]
|
||||
self.descendants[ancestor_id].add(node_id)
|
||||
self.ancestors[node_id].add(ancestor_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
"""
|
||||
Mark a node as executed and store its value in the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to store.
|
||||
value: The value to store for the node.
|
||||
"""
|
||||
self._set_immediate(node_id, value)
|
||||
self.executed_nodes.add(node_id)
|
||||
self._cleanup_ancestors(node_id)
|
||||
|
||||
def get(self, node_id):
|
||||
"""
|
||||
Retrieve the cached value for a node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to retrieve.
|
||||
|
||||
Returns:
|
||||
The cached value for the node.
|
||||
"""
|
||||
return self._get_immediate(node_id)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
"""
|
||||
Ensure a subcache exists for a node and update dependencies.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the parent node.
|
||||
children_ids: List of child node IDs to associate with the parent node.
|
||||
|
||||
Returns:
|
||||
The subcache object for the node.
|
||||
"""
|
||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
||||
for child_id in children_ids:
|
||||
self.descendants[node_id].add(child_id)
|
||||
self.ancestors[child_id].add(node_id)
|
||||
return subcache
|
||||
|
||||
def _cleanup_ancestors(self, node_id):
|
||||
"""
|
||||
Check if ancestors of a node can be removed from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node whose ancestors are to be checked.
|
||||
"""
|
||||
for ancestor_id in self.ancestors.get(node_id, []):
|
||||
if ancestor_id in self.executed_nodes:
|
||||
# Remove ancestor if all its descendants have been executed
|
||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
||||
self._remove_node(ancestor_id)
|
||||
|
||||
def _remove_node(self, node_id):
|
||||
"""
|
||||
Remove a node from the cache.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to remove.
|
||||
"""
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
del self.cache[cache_key]
|
||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||
if subcache_key in self.subcaches:
|
||||
del self.subcaches[subcache_key]
|
||||
super().__init__(key_class, 0)
|
||||
self.timestamps = {}
|
||||
|
||||
def clean_unused(self):
|
||||
"""
|
||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
||||
"""
|
||||
pass
|
||||
self._clean_subcaches()
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
"""
|
||||
Dump the cache and dependency graph for debugging.
|
||||
def set(self, node_id, value):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
super().set(node_id, value)
|
||||
|
||||
Returns:
|
||||
A list containing the cache state and dependency graph.
|
||||
"""
|
||||
result = super().recursive_debug_dump()
|
||||
result.append({
|
||||
"descendants": self.descendants,
|
||||
"ancestors": self.ancestors,
|
||||
"executed_nodes": list(self.executed_nodes),
|
||||
})
|
||||
return result
|
||||
def get(self, node_id):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
|
||||
clean_list = []
|
||||
|
||||
for key, (outputs, _), in self.cache.items():
|
||||
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||
|
||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||
def scan_list_for_ram_usage(outputs):
|
||||
nonlocal ram_usage
|
||||
if outputs is None:
|
||||
return
|
||||
for output in outputs:
|
||||
if isinstance(output, list):
|
||||
scan_list_for_ram_usage(output)
|
||||
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||
#score Tensors at a 50% discount for RAM usage as they are likely to
|
||||
#be high value intermediates
|
||||
ram_usage += (output.numel() * output.element_size()) * 0.5
|
||||
elif hasattr(output, "get_ram_usage"):
|
||||
ram_usage += output.get_ram_usage()
|
||||
scan_list_for_ram_usage(outputs)
|
||||
|
||||
oom_score *= ram_usage
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
|
||||
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
gc.collect()
|
||||
|
||||
@ -153,8 +153,9 @@ class TopologicalSort:
|
||||
continue
|
||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
if (include_lazy or not is_lazy):
|
||||
if not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
|
||||
for link in links:
|
||||
@ -194,10 +195,40 @@ class ExecutionList(TopologicalSort):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
self.execution_cache = {}
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if not from_node_id in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
|
||||
def get_cache(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
return None
|
||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
self.output_cache.set(from_node_id, value)
|
||||
return value
|
||||
|
||||
def cache_update(self, node_id, value):
|
||||
if node_id in self.execution_cache_listeners:
|
||||
for to_node_id in self.execution_cache_listeners[node_id]:
|
||||
if to_node_id in self.execution_cache:
|
||||
self.execution_cache[to_node_id][node_id] = value
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
self.cache_link(from_node_id, to_node_id)
|
||||
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
@ -277,6 +308,8 @@ class ExecutionList(TopologicalSort):
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
self.execution_cache.pop(node_id, None)
|
||||
self.execution_cache_listeners.pop(node_id, None)
|
||||
self.staged_node_id = None
|
||||
|
||||
def get_nodes_in_cycle(self):
|
||||
|
||||
@ -2,6 +2,9 @@ import comfy.utils
|
||||
import folder_paths
|
||||
import torch
|
||||
import logging
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def load_hypernetwork_patch(path, strength):
|
||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||
@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
|
||||
|
||||
return hypernetwork_patch(out, strength)
|
||||
|
||||
class HypernetworkLoader:
|
||||
class HypernetworkLoader(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_hypernetwork"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="HypernetworkLoader",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
||||
@classmethod
|
||||
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||
model_hypernetwork = model.clone()
|
||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||
if patch is not None:
|
||||
model_hypernetwork.set_model_attn1_patch(patch)
|
||||
model_hypernetwork.set_model_attn2_patch(patch)
|
||||
return (model_hypernetwork,)
|
||||
return IO.NodeOutput(model_hypernetwork)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"HypernetworkLoader": HypernetworkLoader
|
||||
}
|
||||
load_hypernetwork = execute # TODO: remove
|
||||
|
||||
|
||||
class HyperNetworkExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
HypernetworkLoader,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> HyperNetworkExtension:
|
||||
return HyperNetworkExtension()
|
||||
|
||||
47
comfy_extras/nodes_rope.py
Normal file
47
comfy_extras/nodes_rope.py
Normal file
@ -0,0 +1,47 @@
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class ScaleROPE(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ScaleROPE",
|
||||
category="advanced/model_patches",
|
||||
description="Scale and shift the ROPE of the model.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
|
||||
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
|
||||
|
||||
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
|
||||
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
|
||||
|
||||
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
|
||||
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
|
||||
|
||||
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class RopeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
ScaleROPE
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> RopeExtension:
|
||||
return RopeExtension()
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.66"
|
||||
__version__ = "0.3.68"
|
||||
|
||||
114
execution.py
114
execution.py
@ -18,9 +18,10 @@ from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
CacheKeySetID,
|
||||
CacheKeySetInputSignature,
|
||||
DependencyAwareCache,
|
||||
NullCache,
|
||||
HierarchicalCache,
|
||||
LRUCache,
|
||||
RAMPressureCache,
|
||||
)
|
||||
from comfy_execution.graph import (
|
||||
DynamicPrompt,
|
||||
@ -88,54 +89,62 @@ class IsChangedCache:
|
||||
return self.is_changed[node_id]
|
||||
|
||||
|
||||
class CacheEntry(NamedTuple):
|
||||
ui: dict
|
||||
outputs: list
|
||||
|
||||
|
||||
class CacheType(Enum):
|
||||
CLASSIC = 0
|
||||
LRU = 1
|
||||
DEPENDENCY_AWARE = 2
|
||||
NONE = 2
|
||||
RAM_PRESSURE = 3
|
||||
|
||||
|
||||
class CacheSet:
|
||||
def __init__(self, cache_type=None, cache_size=None):
|
||||
if cache_type == CacheType.DEPENDENCY_AWARE:
|
||||
self.init_dependency_aware_cache()
|
||||
def __init__(self, cache_type=None, cache_args={}):
|
||||
if cache_type == CacheType.NONE:
|
||||
self.init_null_cache()
|
||||
logging.info("Disabling intermediate node cache.")
|
||||
elif cache_type == CacheType.RAM_PRESSURE:
|
||||
cache_ram = cache_args.get("ram", 16.0)
|
||||
self.init_ram_cache(cache_ram)
|
||||
logging.info("Using RAM pressure cache.")
|
||||
elif cache_type == CacheType.LRU:
|
||||
if cache_size is None:
|
||||
cache_size = 0
|
||||
cache_size = cache_args.get("lru", 0)
|
||||
self.init_lru_cache(cache_size)
|
||||
logging.info("Using LRU cache")
|
||||
else:
|
||||
self.init_classic_cache()
|
||||
|
||||
self.all = [self.outputs, self.ui, self.objects]
|
||||
self.all = [self.outputs, self.objects]
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# only hold cached items while the decendents have not executed
|
||||
def init_dependency_aware_cache(self):
|
||||
self.outputs = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.ui = DependencyAwareCache(CacheKeySetInputSignature)
|
||||
self.objects = DependencyAwareCache(CacheKeySetID)
|
||||
def init_ram_cache(self, min_headroom):
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
self.outputs = NullCache()
|
||||
self.objects = NullCache()
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
result = {
|
||||
"outputs": self.outputs.recursive_debug_dump(),
|
||||
"ui": self.ui.recursive_debug_dump(),
|
||||
}
|
||||
return result
|
||||
|
||||
SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
|
||||
|
||||
def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data={}):
|
||||
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
|
||||
is_v3 = issubclass(class_def, _ComfyNodeInternal)
|
||||
if is_v3:
|
||||
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
|
||||
@ -153,17 +162,17 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
||||
if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)):
|
||||
input_unique_id = input_data[0]
|
||||
output_index = input_data[1]
|
||||
if outputs is None:
|
||||
if execution_list is None:
|
||||
mark_missing()
|
||||
continue # This might be a lazily-evaluated input
|
||||
cached_output = outputs.get(input_unique_id)
|
||||
if cached_output is None:
|
||||
cached = execution_list.get_cache(input_unique_id, unique_id)
|
||||
if cached is None or cached.outputs is None:
|
||||
mark_missing()
|
||||
continue
|
||||
if output_index >= len(cached_output):
|
||||
if output_index >= len(cached.outputs):
|
||||
mark_missing()
|
||||
continue
|
||||
obj = cached_output[output_index]
|
||||
obj = cached.outputs[output_index]
|
||||
input_data_all[x] = obj
|
||||
elif input_category is not None:
|
||||
input_data_all[x] = [input_data]
|
||||
@ -392,7 +401,7 @@ def format_value(x):
|
||||
else:
|
||||
return str(x)
|
||||
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
|
||||
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
|
||||
unique_id = current_item
|
||||
real_node_id = dynprompt.get_real_node_id(unique_id)
|
||||
display_node_id = dynprompt.get_display_node_id(unique_id)
|
||||
@ -400,11 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if caches.outputs.get(unique_id) is not None:
|
||||
cached = caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
if server.client_id is not None:
|
||||
cached_output = caches.ui.get(unique_id) or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
cached_ui = cached.ui or {}
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
|
||||
if cached.ui is not None:
|
||||
ui_outputs[unique_id] = cached.ui
|
||||
get_progress_state().finish_progress(unique_id)
|
||||
execution_list.cache_update(unique_id, cached)
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
input_data_all = None
|
||||
@ -434,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
for r in result:
|
||||
if is_link(r):
|
||||
source_node, source_output = r[0], r[1]
|
||||
node_output = caches.outputs.get(source_node)[source_output]
|
||||
for o in node_output:
|
||||
node_cached = execution_list.get_cache(source_node, unique_id)
|
||||
for o in node_cached.outputs[source_output]:
|
||||
resolved_output.append(o)
|
||||
|
||||
else:
|
||||
@ -443,10 +456,11 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
resolved_outputs.append(tuple(resolved_output))
|
||||
output_data = merge_result_data(resolved_outputs, class_def)
|
||||
output_ui = []
|
||||
del pending_subgraph_results[unique_id]
|
||||
has_subgraph = False
|
||||
else:
|
||||
get_progress_state().start_progress(unique_id)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data)
|
||||
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
|
||||
if server.client_id is not None:
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
@ -504,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
caches.ui.set(unique_id, {
|
||||
ui_outputs[unique_id] = {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
"display_node": display_node_id,
|
||||
@ -512,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
"real_node_id": real_node_id,
|
||||
},
|
||||
"output": output_ui
|
||||
})
|
||||
}
|
||||
if server.client_id is not None:
|
||||
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
|
||||
if has_subgraph:
|
||||
@ -525,10 +539,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if new_graph is None:
|
||||
cached_outputs.append((False, node_outputs))
|
||||
else:
|
||||
# Check for conflicts
|
||||
for node_id in new_graph.keys():
|
||||
if dynprompt.has_node(node_id):
|
||||
raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.")
|
||||
for node_id, node_info in new_graph.items():
|
||||
new_node_ids.append(node_id)
|
||||
display_id = node_info.get("override_display_id", unique_id)
|
||||
@ -549,11 +559,16 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
subcache.clean_unused()
|
||||
for node_id in new_output_ids:
|
||||
execution_list.add_node(node_id)
|
||||
execution_list.cache_link(node_id, unique_id)
|
||||
for link in new_output_links:
|
||||
execution_list.add_strong_link(link[0], link[1], unique_id)
|
||||
pending_subgraph_results[unique_id] = cached_outputs
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
caches.outputs.set(unique_id, output_data)
|
||||
|
||||
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
||||
execution_list.cache_update(unique_id, cache_entry)
|
||||
caches.outputs.set(unique_id, cache_entry)
|
||||
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
|
||||
@ -597,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
return (ExecutionResult.SUCCESS, None, None)
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server, cache_type=False, cache_size=None):
|
||||
self.cache_size = cache_size
|
||||
def __init__(self, server, cache_type=False, cache_args=None):
|
||||
self.cache_args = cache_args
|
||||
self.cache_type = cache_type
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
|
||||
@ -679,6 +694,7 @@ class PromptExecutor:
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
@ -692,7 +708,7 @@ class PromptExecutor:
|
||||
break
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
@ -701,18 +717,16 @@ class PromptExecutor:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
all_node_ids = self.caches.ui.all_node_ids()
|
||||
for node_id in all_node_ids:
|
||||
ui_info = self.caches.ui.get(node_id)
|
||||
if ui_info is not None:
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
@ -1110,7 +1124,7 @@ class PromptQueue:
|
||||
messages: List[str]
|
||||
|
||||
def task_done(self, item_id, history_result,
|
||||
status: Optional['PromptQueue.ExecutionStatus']):
|
||||
status: Optional['PromptQueue.ExecutionStatus'], process_item=None):
|
||||
with self.mutex:
|
||||
prompt = self.currently_running.pop(item_id)
|
||||
if len(self.history) > MAXIMUM_HISTORY_SIZE:
|
||||
@ -1120,10 +1134,8 @@ class PromptQueue:
|
||||
if status is not None:
|
||||
status_dict = copy.deepcopy(status._asdict())
|
||||
|
||||
# Remove sensitive data from extra_data before storing in history
|
||||
for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS:
|
||||
if sensitive_val in prompt[3]:
|
||||
prompt[3].pop(sensitive_val)
|
||||
if process_item is not None:
|
||||
prompt = process_item(prompt)
|
||||
|
||||
self.history[prompt[1]] = {
|
||||
"prompt": prompt,
|
||||
|
||||
17
main.py
17
main.py
@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
|
||||
cache_type = execution.CacheType.CLASSIC
|
||||
if args.cache_lru > 0:
|
||||
cache_type = execution.CacheType.LRU
|
||||
elif args.cache_ram > 0:
|
||||
cache_type = execution.CacheType.RAM_PRESSURE
|
||||
elif args.cache_none:
|
||||
cache_type = execution.CacheType.DEPENDENCY_AWARE
|
||||
cache_type = execution.CacheType.NONE
|
||||
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
|
||||
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
|
||||
last_gc_collect = 0
|
||||
need_gc = False
|
||||
gc_collect_interval = 10.0
|
||||
@ -192,14 +194,21 @@ def prompt_worker(q, server_instance):
|
||||
prompt_id = item[1]
|
||||
server_instance.last_prompt_id = prompt_id
|
||||
|
||||
e.execute(item[2], prompt_id, item[3], item[4])
|
||||
sensitive = item[5]
|
||||
extra_data = item[3].copy()
|
||||
for k in sensitive:
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
need_gc = True
|
||||
|
||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||
q.task_done(item_id,
|
||||
e.history_result,
|
||||
status=execution.PromptQueue.ExecutionStatus(
|
||||
status_str='success' if e.success else 'error',
|
||||
completed=e.success,
|
||||
messages=e.status_messages))
|
||||
messages=e.status_messages), process_item=remove_sensitive)
|
||||
if server_instance.client_id is not None:
|
||||
server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
|
||||
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_model_patch.py",
|
||||
"nodes_easycache.py",
|
||||
"nodes_audio_encoder.py",
|
||||
"nodes_rope.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@ -2349,6 +2350,7 @@ async def init_builtin_api_nodes():
|
||||
"nodes_kling.py",
|
||||
"nodes_bfl.py",
|
||||
"nodes_bytedance.py",
|
||||
"nodes_ltxv.py",
|
||||
"nodes_luma.py",
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.66"
|
||||
version = "0.3.68"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
@ -50,6 +50,8 @@ messages_control.disable = [
|
||||
"too-many-branches",
|
||||
"too-many-locals",
|
||||
"too-many-arguments",
|
||||
"too-many-return-statements",
|
||||
"too-many-nested-blocks",
|
||||
"duplicate-code",
|
||||
"abstract-method",
|
||||
"superfluous-parens",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.28.7
|
||||
comfyui-workflow-templates==0.2.1
|
||||
comfyui-embedded-docs==0.3.0
|
||||
comfyui-frontend-package==1.28.8
|
||||
comfyui-workflow-templates==0.2.11
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
14
server.py
14
server.py
@ -35,6 +35,7 @@ from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from app.subgraph_manager import SubgraphManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from protocol import BinaryEventTypes
|
||||
@ -173,6 +174,7 @@ class PromptServer():
|
||||
self.user_manager = UserManager()
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
self.subgraph_manager = SubgraphManager()
|
||||
self.internal_routes = InternalRoutes(self)
|
||||
self.supports = ["custom_nodes_from_web"]
|
||||
self.prompt_queue = execution.PromptQueue(self)
|
||||
@ -689,8 +691,9 @@ class PromptServer():
|
||||
async def get_queue(request):
|
||||
queue_info = {}
|
||||
current_queue = self.prompt_queue.get_current_queue_volatile()
|
||||
queue_info['queue_running'] = current_queue[0]
|
||||
queue_info['queue_pending'] = current_queue[1]
|
||||
remove_sensitive = lambda queue: [x[:5] for x in queue]
|
||||
queue_info['queue_running'] = remove_sensitive(current_queue[0])
|
||||
queue_info['queue_pending'] = remove_sensitive(current_queue[1])
|
||||
return web.json_response(queue_info)
|
||||
|
||||
@routes.post("/prompt")
|
||||
@ -726,7 +729,11 @@ class PromptServer():
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
if valid[0]:
|
||||
outputs_to_execute = valid[2]
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute))
|
||||
sensitive = {}
|
||||
for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
else:
|
||||
@ -819,6 +826,7 @@ class PromptServer():
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.model_file_manager.add_routes(self.routes)
|
||||
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
|
||||
232
tests-unit/comfy_quant/test_mixed_precision.py
Normal file
232
tests-unit/comfy_quant/test_mixed_precision.py
Normal file
@ -0,0 +1,232 @@
|
||||
import unittest
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add comfy to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
def has_gpu():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
from comfy.cli_args import args
|
||||
if not has_gpu():
|
||||
args.cpu = True
|
||||
|
||||
from comfy import ops
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
def __init__(self, operations=ops.disable_weight_init):
|
||||
super().__init__()
|
||||
self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
|
||||
self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
|
||||
self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.layer1(x)
|
||||
x = torch.nn.functional.relu(x)
|
||||
x = self.layer2(x)
|
||||
x = torch.nn.functional.relu(x)
|
||||
x = self.layer3(x)
|
||||
return x
|
||||
|
||||
|
||||
class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
def test_all_layers_standard(self):
|
||||
"""Test that model with no quantization works normally"""
|
||||
# Configure no quantization
|
||||
ops.MixedPrecisionOps._layer_quant_config = {}
|
||||
|
||||
# Create model
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
|
||||
# Initialize weights manually
|
||||
model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
|
||||
model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
|
||||
model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
|
||||
model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
|
||||
model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
|
||||
model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))
|
||||
|
||||
# Initialize weight_function and bias_function
|
||||
for layer in [model.layer1, model.layer2, model.layer3]:
|
||||
layer.weight_function = []
|
||||
layer.bias_function = []
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
self.assertEqual(output.dtype, torch.bfloat16)
|
||||
|
||||
def test_mixed_precision_load(self):
|
||||
"""Test loading a mixed precision model from state dict"""
|
||||
# Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard
|
||||
layer_quant_config = {
|
||||
"layer1": {
|
||||
"format": "float8_e4m3fn",
|
||||
"params": {}
|
||||
},
|
||||
"layer3": {
|
||||
"format": "float8_e4m3fn",
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create state dict with mixed precision
|
||||
fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
|
||||
state_dict = {
|
||||
# Layer 1: FP8 E4M3FN
|
||||
"layer1.weight": fp8_weight1,
|
||||
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
|
||||
|
||||
# Layer 2: Standard BF16
|
||||
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||
|
||||
# Layer 3: FP8 E4M3FN
|
||||
"layer3.weight": fp8_weight3,
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
# Create model and load state dict (strict=False because custom loading pops keys)
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Layer 2 should NOT be quantized
|
||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||
|
||||
# Layer 3 should be quantized
|
||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify scales were loaded
|
||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
||||
self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5)
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
def test_state_dict_quantized_preserved(self):
|
||||
"""Test that quantized weights are preserved in state_dict()"""
|
||||
# Configure mixed precision
|
||||
layer_quant_config = {
|
||||
"layer1": {
|
||||
"format": "float8_e4m3fn",
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create and load model
|
||||
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
state_dict1 = {
|
||||
"layer1.weight": fp8_weight,
|
||||
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||
"layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
|
||||
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model.load_state_dict(state_dict1, strict=False)
|
||||
|
||||
# Save state dict
|
||||
state_dict2 = model.state_dict()
|
||||
|
||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify non-quantized layers are standard tensors
|
||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||
self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)
|
||||
|
||||
def test_weight_function_compatibility(self):
|
||||
"""Test that weight_function (LoRA) works with quantized layers"""
|
||||
# Configure FP8 quantization
|
||||
layer_quant_config = {
|
||||
"layer1": {
|
||||
"format": "float8_e4m3fn",
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create and load model
|
||||
fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
state_dict = {
|
||||
"layer1.weight": fp8_weight,
|
||||
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||
"layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
|
||||
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Add a weight function (simulating LoRA)
|
||||
# This should trigger dequantization during forward pass
|
||||
def apply_lora(weight):
|
||||
lora_delta = torch.randn_like(weight) * 0.01
|
||||
return weight + lora_delta
|
||||
|
||||
model.layer1.weight_function.append(apply_lora)
|
||||
|
||||
# Forward pass should work with LoRA (triggers weight_function path)
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
def test_error_handling_unknown_format(self):
|
||||
"""Test that unknown formats raise error"""
|
||||
# Configure with unknown format
|
||||
layer_quant_config = {
|
||||
"layer1": {
|
||||
"format": "unknown_format_xyz",
|
||||
"params": {}
|
||||
}
|
||||
}
|
||||
ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
|
||||
|
||||
# Create state dict
|
||||
state_dict = {
|
||||
"layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
|
||||
"layer1.bias": torch.randn(20, dtype=torch.bfloat16),
|
||||
"layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
|
||||
"layer2.bias": torch.randn(30, dtype=torch.bfloat16),
|
||||
"layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
||||
model = SimpleModel(operations=ops.MixedPrecisionOps)
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
190
tests-unit/comfy_quant/test_quant_registry.py
Normal file
190
tests-unit/comfy_quant/test_quant_registry.py
Normal file
@ -0,0 +1,190 @@
|
||||
import unittest
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add comfy to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
def has_gpu():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
from comfy.cli_args import args
|
||||
if not has_gpu():
|
||||
args.cpu = True
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
||||
|
||||
|
||||
class TestQuantizedTensor(unittest.TestCase):
|
||||
"""Test the QuantizedTensor subclass with FP8 layout"""
|
||||
|
||||
def test_creation(self):
|
||||
"""Test creating a QuantizedTensor with TensorCoreFP8Layout"""
|
||||
fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(2.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.shape, (256, 128))
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt._layout_params['scale'], scale)
|
||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test explicit dequantization"""
|
||||
|
||||
fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(3.0)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
dequantized = qt.dequantize()
|
||||
|
||||
self.assertEqual(dequantized.dtype, torch.float32)
|
||||
self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))
|
||||
|
||||
def test_from_float(self):
|
||||
"""Test creating QuantizedTensor from float tensor"""
|
||||
float_tensor = torch.randn(64, 32, dtype=torch.float32)
|
||||
scale = torch.tensor(1.5)
|
||||
|
||||
qt = QuantizedTensor.from_float(
|
||||
float_tensor,
|
||||
"TensorCoreFP8Layout",
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
self.assertIsInstance(qt, QuantizedTensor)
|
||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qt.shape, (64, 32))
|
||||
|
||||
# Verify dequantization gives approximately original values
|
||||
dequantized = qt.dequantize()
|
||||
mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
|
||||
self.assertLess(mean_rel_error, 0.1)
|
||||
|
||||
|
||||
class TestGenericUtilities(unittest.TestCase):
|
||||
"""Test generic utility operations"""
|
||||
|
||||
def test_detach(self):
|
||||
"""Test detach operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Detach should return a new QuantizedTensor
|
||||
qt_detached = qt.detach()
|
||||
|
||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
||||
self.assertEqual(qt_detached.shape, qt.shape)
|
||||
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
def test_clone(self):
|
||||
"""Test clone operation on quantized tensor"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Clone should return a new QuantizedTensor
|
||||
qt_cloned = qt.clone()
|
||||
|
||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
||||
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
|
||||
|
||||
# Verify it's a deep copy
|
||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||
|
||||
@unittest.skipUnless(has_gpu(), "GPU not available")
|
||||
def test_to_device(self):
|
||||
"""Test device transfer"""
|
||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||
scale = torch.tensor(1.5)
|
||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||
|
||||
# Moving to same device should work (CPU to CPU)
|
||||
qt_cpu = qt.to('cpu')
|
||||
|
||||
self.assertIsInstance(qt_cpu, QuantizedTensor)
|
||||
self.assertEqual(qt_cpu.device.type, 'cpu')
|
||||
self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')
|
||||
|
||||
|
||||
class TestTensorCoreFP8Layout(unittest.TestCase):
|
||||
"""Test the TensorCoreFP8Layout implementation"""
|
||||
|
||||
def test_quantize(self):
|
||||
"""Test quantization method"""
|
||||
float_tensor = torch.randn(32, 64, dtype=torch.float32)
|
||||
scale = torch.tensor(1.5)
|
||||
|
||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
||||
float_tensor,
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
|
||||
self.assertEqual(qdata.shape, float_tensor.shape)
|
||||
self.assertIn('scale', layout_params)
|
||||
self.assertIn('orig_dtype', layout_params)
|
||||
self.assertEqual(layout_params['orig_dtype'], torch.float32)
|
||||
|
||||
def test_dequantize(self):
|
||||
"""Test dequantization method"""
|
||||
float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
|
||||
scale = torch.tensor(1.0)
|
||||
|
||||
qdata, layout_params = TensorCoreFP8Layout.quantize(
|
||||
float_tensor,
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)
|
||||
|
||||
# Should approximately match original
|
||||
self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
|
||||
|
||||
|
||||
class TestFallbackMechanism(unittest.TestCase):
|
||||
"""Test fallback for unsupported operations"""
|
||||
|
||||
def test_unsupported_op_dequantizes(self):
|
||||
"""Test that unsupported operations fall back to dequantization"""
|
||||
# Set seed for reproducibility
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Create quantized tensor
|
||||
a_fp32 = torch.randn(10, 20, dtype=torch.float32)
|
||||
scale = torch.tensor(1.0)
|
||||
a_q = QuantizedTensor.from_float(
|
||||
a_fp32,
|
||||
"TensorCoreFP8Layout",
|
||||
scale=scale,
|
||||
dtype=torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
# Call an operation that doesn't have a registered handler
|
||||
# For example, torch.abs
|
||||
result = torch.abs(a_q)
|
||||
|
||||
# Should work via fallback (dequantize → abs → return)
|
||||
self.assertNotIsInstance(result, QuantizedTensor)
|
||||
expected = torch.abs(a_fp32)
|
||||
# FP8 introduces quantization error, so use loose tolerance
|
||||
mean_error = (result - expected).abs().mean()
|
||||
self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -152,12 +152,12 @@ class TestExecution:
|
||||
# Initialize server and client
|
||||
#
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
# (use_lru, lru_size)
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
{ "extra_args" : [], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-lru", 0], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-lru", 100], "should_cache_results" : True },
|
||||
{ "extra_args" : ["--cache-none"], "should_cache_results" : False },
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
def server(self, args_pytest, request):
|
||||
# Start server
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
@ -167,12 +167,10 @@ class TestExecution:
|
||||
'--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
|
||||
'--cpu',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
pargs += [ str(param) for param in request.param["extra_args"] ]
|
||||
print("Running server with args:", pargs) # noqa: T201
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
yield request.param
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -193,7 +191,7 @@ class TestExecution:
|
||||
return comfy_client
|
||||
|
||||
@fixture(scope="class", autouse=True)
|
||||
def shared_client(self, args_pytest, _server):
|
||||
def shared_client(self, args_pytest, server):
|
||||
client = self.start_client(args_pytest["listen"], args_pytest["port"])
|
||||
yield client
|
||||
del client
|
||||
@ -225,7 +223,7 @@ class TestExecution:
|
||||
assert result.did_run(mask)
|
||||
assert result.did_run(lazy_mix)
|
||||
|
||||
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@ -237,9 +235,12 @@ class TestExecution:
|
||||
client.run(g)
|
||||
result2 = client.run(g)
|
||||
for node_id, node in g.nodes.items():
|
||||
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
||||
else:
|
||||
assert result2.did_run(node), f"Node {node_id} was cached, but should have been run"
|
||||
|
||||
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@ -251,8 +252,12 @@ class TestExecution:
|
||||
client.run(g)
|
||||
mask.inputs['value'] = 0.4
|
||||
result2 = client.run(g)
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
else:
|
||||
assert result2.did_run(input1), "Input1 should have been rerun"
|
||||
assert result2.did_run(input2), "Input2 should have been rerun"
|
||||
|
||||
def test_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@ -411,7 +416,7 @@ class TestExecution:
|
||||
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||
client.run(g)
|
||||
|
||||
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
g = builder
|
||||
# Creating the nodes in this specific order previously caused a bug
|
||||
save = g.node("SaveImage")
|
||||
@ -427,7 +432,10 @@ class TestExecution:
|
||||
result3 = client.run(g)
|
||||
result4 = client.run(g)
|
||||
assert result1.did_run(is_changed), "is_changed should have been run"
|
||||
assert not result2.did_run(is_changed), "is_changed should have been cached"
|
||||
if server["should_cache_results"]:
|
||||
assert not result2.did_run(is_changed), "is_changed should have been cached"
|
||||
else:
|
||||
assert result2.did_run(is_changed), "is_changed should have been re-run"
|
||||
assert result3.did_run(is_changed), "is_changed should have been re-run"
|
||||
assert result4.did_run(is_changed), "is_changed should not have been cached"
|
||||
|
||||
@ -514,7 +522,7 @@ class TestExecution:
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
|
||||
g = builder
|
||||
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
|
||||
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
|
||||
@ -530,7 +538,11 @@ class TestExecution:
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
if server["should_cache_results"]:
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
else:
|
||||
assert result.did_run(test_node), "The execution should have been re-run"
|
||||
|
||||
|
||||
def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
|
||||
# Warmup execution to ensure server is fully initialized
|
||||
|
||||
Loading…
Reference in New Issue
Block a user