mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-06 23:32:30 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
39a5c5621e
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -8,13 +8,15 @@ body:
|
|||||||
Before submitting a **Bug Report**, please ensure the following:
|
Before submitting a **Bug Report**, please ensure the following:
|
||||||
|
|
||||||
- **1:** You are running the latest version of ComfyUI.
|
- **1:** You are running the latest version of ComfyUI.
|
||||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
||||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||||
`--disable-all-custom-nodes` command line argument.
|
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
||||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||||
|
|
||||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
## Very Important
|
||||||
|
|
||||||
|
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
id: custom-nodes-test
|
id: custom-nodes-test
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
20
.github/workflows/test-ci.yml
vendored
20
.github/workflows/test-ci.yml
vendored
@ -21,14 +21,15 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
# os: [macos, linux, windows]
|
# os: [macos, linux, windows]
|
||||||
os: [macos, linux]
|
# os: [macos, linux]
|
||||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
os: [linux]
|
||||||
|
python_version: ["3.10", "3.11", "3.12"]
|
||||||
cuda_version: ["12.1"]
|
cuda_version: ["12.1"]
|
||||||
torch_version: ["stable"]
|
torch_version: ["stable"]
|
||||||
include:
|
include:
|
||||||
- os: macos
|
# - os: macos
|
||||||
runner_label: [self-hosted, macOS]
|
# runner_label: [self-hosted, macOS]
|
||||||
flags: "--use-pytorch-cross-attention"
|
# flags: "--use-pytorch-cross-attention"
|
||||||
- os: linux
|
- os: linux
|
||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
@ -73,14 +74,15 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos, linux]
|
# os: [macos, linux]
|
||||||
|
os: [linux]
|
||||||
python_version: ["3.11"]
|
python_version: ["3.11"]
|
||||||
cuda_version: ["12.1"]
|
cuda_version: ["12.1"]
|
||||||
torch_version: ["nightly"]
|
torch_version: ["nightly"]
|
||||||
include:
|
include:
|
||||||
- os: macos
|
# - os: macos
|
||||||
runner_label: [self-hosted, macOS]
|
# runner_label: [self-hosted, macOS]
|
||||||
flags: "--use-pytorch-cross-attention"
|
# flags: "--use-pytorch-cross-attention"
|
||||||
- os: linux
|
- os: linux
|
||||||
runner_label: [self-hosted, Linux]
|
runner_label: [self-hosted, Linux]
|
||||||
flags: ""
|
flags: ""
|
||||||
|
|||||||
13
README.md
13
README.md
@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
|
|
||||||
## Release Process
|
## Release Process
|
||||||
|
|
||||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||||
|
|
||||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||||
- Releases a new stable version (e.g., v0.7.0)
|
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||||
|
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||||
- Serves as the foundation for the desktop release
|
- Serves as the foundation for the desktop release
|
||||||
|
|
||||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||||
@ -199,7 +200,7 @@ comfy install
|
|||||||
|
|
||||||
## Manual Install (Windows, Linux)
|
## Manual Install (Windows, Linux)
|
||||||
|
|
||||||
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
|
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||||
|
|
||||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||||
|
|
||||||
@ -241,7 +242,7 @@ RDNA 4 (RX 9000 series):
|
|||||||
|
|
||||||
### Intel GPUs (Windows and Linux)
|
### Intel GPUs (Windows and Linux)
|
||||||
|
|
||||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||||
|
|
||||||
1. To install PyTorch xpu, use the following command:
|
1. To install PyTorch xpu, use the following command:
|
||||||
|
|
||||||
@ -251,10 +252,6 @@ This is the command to install the Pytorch xpu nightly which might have some per
|
|||||||
|
|
||||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||||
|
|
||||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
|
||||||
|
|
||||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
|
||||||
|
|
||||||
### NVIDIA
|
### NVIDIA
|
||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install stable pytorch using this command:
|
||||||
|
|||||||
@ -145,9 +145,10 @@ class PerformanceFeature(enum.Enum):
|
|||||||
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
Fp8MatrixMultiplication = "fp8_matrix_mult"
|
||||||
CublasOps = "cublas_ops"
|
CublasOps = "cublas_ops"
|
||||||
AutoTune = "autotune"
|
AutoTune = "autotune"
|
||||||
PinnedMem = "pinned_memory"
|
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||||
|
|
||||||
|
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||||
|
|
||||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||||
|
|||||||
@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||||
|
|
||||||
# calculate the txt bloks
|
# calculate the txt bloks
|
||||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||||
|
|||||||
@ -7,15 +7,7 @@ import comfy.model_management
|
|||||||
|
|
||||||
|
|
||||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||||
q_shape = q.shape
|
q, k = apply_rope(q, k, pe)
|
||||||
k_shape = k.shape
|
|
||||||
|
|
||||||
if pe is not None:
|
|
||||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
|
||||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
|
||||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
|
||||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
|
||||||
|
|
||||||
heads = q.shape[1]
|
heads = q.shape[1]
|
||||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -210,7 +210,7 @@ class Flux(nn.Module):
|
|||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
@ -222,10 +222,22 @@ class Flux(nn.Module):
|
|||||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
steps_h = h_len
|
||||||
|
steps_w = w_len
|
||||||
|
|
||||||
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
if rope_options is not None:
|
||||||
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||||
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||||
|
|
||||||
|
index += rope_options.get("shift_t", 0.0)
|
||||||
|
h_offset += rope_options.get("shift_y", 0.0)
|
||||||
|
w_offset += rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||||
@ -241,7 +253,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||||
img, img_ids = self.process_img(x)
|
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||||
img_tokens = img.shape[1]
|
img_tokens = img.shape[1]
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
h = 0
|
h = 0
|
||||||
|
|||||||
@ -3,12 +3,11 @@ from torch import nn
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
from einops import rearrange
|
|
||||||
import math
|
import math
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
@ -238,20 +237,6 @@ class FeedForward(nn.Module):
|
|||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
|
||||||
cos_freqs = freqs_cis[0]
|
|
||||||
sin_freqs = freqs_cis[1]
|
|
||||||
|
|
||||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
|
||||||
t1, t2 = t_dup.unbind(dim=-1)
|
|
||||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
|
||||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
|
||||||
|
|
||||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
q = apply_rotary_emb(q, pe)
|
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||||
k = apply_rotary_emb(k, pe)
|
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||||
|
|
||||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||||
|
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||||
|
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||||
|
x.addcmul_(attn1_input, gate_msa)
|
||||||
|
del attn1_input
|
||||||
|
|
||||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
y = comfy.ldm.common_dit.rms_norm(x)
|
||||||
x += self.ff(y) * gate_mlp
|
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||||
|
x.addcmul_(self.ff(y), gate_mlp)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
|||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||||
dtype = torch.float32 #self.dtype
|
dtype = torch.float32
|
||||||
|
device = indices_grid.device
|
||||||
|
|
||||||
|
# Get fractional positions and compute frequency indices
|
||||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||||
|
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||||
|
|
||||||
start = 1
|
# Compute frequencies and apply cos/sin
|
||||||
end = theta
|
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||||
device = fractional_positions.device
|
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||||
|
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||||
|
|
||||||
indices = theta ** (
|
# Pad if dim is not divisible by 6
|
||||||
torch.linspace(
|
|
||||||
math.log(start, theta),
|
|
||||||
math.log(end, theta),
|
|
||||||
dim // 6,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
indices = indices.to(dtype=dtype)
|
|
||||||
|
|
||||||
indices = indices * math.pi / 2
|
|
||||||
|
|
||||||
freqs = (
|
|
||||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
|
||||||
.transpose(-1, -2)
|
|
||||||
.flatten(2)
|
|
||||||
)
|
|
||||||
|
|
||||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
|
||||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
|
||||||
if dim % 6 != 0:
|
if dim % 6 != 0:
|
||||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
padding_size = dim % 6
|
||||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
|
||||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||||
|
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||||
|
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||||
|
|
||||||
|
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||||
|
freqs_cis = torch.stack([
|
||||||
|
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||||
|
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||||
|
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||||
|
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
class LTXVModel(torch.nn.Module):
|
class LTXVModel(torch.nn.Module):
|
||||||
@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
# Modulation
|
# Modulation
|
||||||
x = x * (1 + scale) + shift
|
x = torch.addcmul(x, x, scale).add_(shift)
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
||||||
x = self.patchifier.unpatchify(
|
x = self.patchifier.unpatchify(
|
||||||
|
|||||||
@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
|||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
del ids, txt_ids, img_ids
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
|||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
|
||||||
class GELU(nn.Module):
|
class GELU(nn.Module):
|
||||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||||
@ -134,33 +135,34 @@ class Attention(nn.Module):
|
|||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
seq_img = hidden_states.shape[1]
|
||||||
seq_txt = encoder_hidden_states.shape[1]
|
seq_txt = encoder_hidden_states.shape[1]
|
||||||
|
|
||||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
|
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
img_query = self.norm_q(img_query)
|
img_query = self.norm_q(img_query)
|
||||||
img_key = self.norm_k(img_key)
|
img_key = self.norm_k(img_key)
|
||||||
txt_query = self.norm_added_q(txt_query)
|
txt_query = self.norm_added_q(txt_query)
|
||||||
txt_key = self.norm_added_k(txt_key)
|
txt_key = self.norm_added_k(txt_key)
|
||||||
|
|
||||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||||
|
|
||||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||||
|
|
||||||
joint_query = joint_query.flatten(start_dim=2)
|
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||||
joint_key = joint_key.flatten(start_dim=2)
|
attention_mask, transformer_options=transformer_options,
|
||||||
joint_value = joint_value.flatten(start_dim=2)
|
skip_reshape=True)
|
||||||
|
|
||||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
|
||||||
|
|
||||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||||
@ -234,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
|
|||||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||||
|
|
||||||
img_normed = self.img_norm1(hidden_states)
|
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
del img_mod1
|
||||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
del txt_mod1
|
||||||
|
|
||||||
img_attn_output, txt_attn_output = self.attn(
|
img_attn_output, txt_attn_output = self.attn(
|
||||||
hidden_states=img_modulated,
|
hidden_states=img_modulated,
|
||||||
@ -246,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
|
|||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
)
|
)
|
||||||
|
del img_modulated
|
||||||
|
del txt_modulated
|
||||||
|
|
||||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||||
|
del img_attn_output
|
||||||
|
del txt_attn_output
|
||||||
|
del img_gate1
|
||||||
|
del txt_gate1
|
||||||
|
|
||||||
img_normed2 = self.img_norm2(hidden_states)
|
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
|
||||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||||
|
|
||||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
|
||||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||||
|
|
||||||
return encoder_hidden_states, hidden_states
|
return encoder_hidden_states, hidden_states
|
||||||
@ -413,7 +419,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
del ids, txt_ids, img_ids
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
|
|||||||
@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
# assert e[0].dtype == torch.float32
|
# assert e[0].dtype == torch.float32
|
||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
|
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
freqs, transformer_options=transformer_options)
|
freqs, transformer_options=transformer_options)
|
||||||
|
|||||||
@ -504,6 +504,7 @@ class LoadedModel:
|
|||||||
if use_more_vram == 0:
|
if use_more_vram == 0:
|
||||||
use_more_vram = 1e32
|
use_more_vram = 1e32
|
||||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||||
|
|
||||||
real_model = self.model.model
|
real_model = self.model.model
|
||||||
|
|
||||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||||
@ -689,7 +690,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||||
|
|
||||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||||
|
|
||||||
|
if lowvram_model_memory == 0:
|
||||||
|
lowvram_model_memory = 0.1
|
||||||
|
|
||||||
if vram_set_state == VRAMState.NO_VRAM:
|
if vram_set_state == VRAMState.NO_VRAM:
|
||||||
lowvram_model_memory = 0.1
|
lowvram_model_memory = 0.1
|
||||||
@ -1082,32 +1086,75 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
non_blocking = device_supports_non_blocking(device)
|
non_blocking = device_supports_non_blocking(device)
|
||||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
|
PINNED_MEMORY = {}
|
||||||
|
TOTAL_PINNED_MEMORY = 0
|
||||||
|
MAX_PINNED_MEMORY = -1
|
||||||
|
if not args.disable_pinned_memory:
|
||||||
|
if is_nvidia() or is_amd():
|
||||||
|
if WINDOWS:
|
||||||
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||||
|
else:
|
||||||
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||||
|
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||||
|
|
||||||
|
|
||||||
def pin_memory(tensor):
|
def pin_memory(tensor):
|
||||||
if PerformanceFeature.PinnedMem not in args.fast:
|
global TOTAL_PINNED_MEMORY
|
||||||
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not is_nvidia():
|
if type(tensor) is not torch.nn.parameter.Parameter:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not is_device_cpu(tensor.device):
|
if not is_device_cpu(tensor.device):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if torch.cuda.cudart().cudaHostRegister(tensor.data_ptr(), tensor.numel() * tensor.element_size(), 1) == 0:
|
if tensor.is_pinned():
|
||||||
|
#NOTE: Cuda does detect when a tensor is already pinned and would
|
||||||
|
#error below, but there are proven cases where this also queues an error
|
||||||
|
#on the GPU async. So dont trust the CUDA API and guard here
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not tensor.is_contiguous():
|
||||||
|
return False
|
||||||
|
|
||||||
|
size = tensor.numel() * tensor.element_size()
|
||||||
|
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||||
|
return False
|
||||||
|
|
||||||
|
ptr = tensor.data_ptr()
|
||||||
|
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||||
|
PINNED_MEMORY[ptr] = size
|
||||||
|
TOTAL_PINNED_MEMORY += size
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def unpin_memory(tensor):
|
def unpin_memory(tensor):
|
||||||
if PerformanceFeature.PinnedMem not in args.fast:
|
global TOTAL_PINNED_MEMORY
|
||||||
return False
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
|
|
||||||
if not is_nvidia():
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not is_device_cpu(tensor.device):
|
if not is_device_cpu(tensor.device):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if torch.cuda.cudart().cudaHostUnregister(tensor.data_ptr()) == 0:
|
ptr = tensor.data_ptr()
|
||||||
|
size = tensor.numel() * tensor.element_size()
|
||||||
|
|
||||||
|
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||||
|
if size_stored is None:
|
||||||
|
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if size != size_stored:
|
||||||
|
logging.warning("Size of pinned tensor changed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||||
|
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||||
|
if len(PINNED_MEMORY) == 0:
|
||||||
|
TOTAL_PINNED_MEMORY = 0
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -298,6 +298,7 @@ class ModelPatcher:
|
|||||||
n.backup = self.backup
|
n.backup = self.backup
|
||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
n.parent = self
|
n.parent = self
|
||||||
|
n.pinned = self.pinned
|
||||||
|
|
||||||
n.force_cast_weights = self.force_cast_weights
|
n.force_cast_weights = self.force_cast_weights
|
||||||
|
|
||||||
@ -842,7 +843,7 @@ class ModelPatcher:
|
|||||||
|
|
||||||
self.object_patches_backup.clear()
|
self.object_patches_backup.clear()
|
||||||
|
|
||||||
def partially_unload(self, device_to, memory_to_free=0):
|
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||||
with self.use_ejected():
|
with self.use_ejected():
|
||||||
hooks_unpatched = False
|
hooks_unpatched = False
|
||||||
memory_freed = 0
|
memory_freed = 0
|
||||||
@ -886,13 +887,19 @@ class ModelPatcher:
|
|||||||
module_mem += move_weight_functions(m, device_to)
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
if force_patch_weights:
|
||||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
self.patch_weight_to_device(weight_key)
|
||||||
patch_counter += 1
|
else:
|
||||||
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||||
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
if force_patch_weights:
|
||||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
self.patch_weight_to_device(bias_key)
|
||||||
patch_counter += 1
|
else:
|
||||||
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||||
|
patch_counter += 1
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
|
|
||||||
if cast_weight:
|
if cast_weight:
|
||||||
@ -908,6 +915,7 @@ class ModelPatcher:
|
|||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.model_loaded_weight_memory -= memory_freed
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
|
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||||
return memory_freed
|
return memory_freed
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||||
@ -920,6 +928,9 @@ class ModelPatcher:
|
|||||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||||
|
|
||||||
self.patch_model(load_weights=False)
|
self.patch_model(load_weights=False)
|
||||||
|
if extra_memory < 0 and not unpatch_weights:
|
||||||
|
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||||
|
return 0
|
||||||
full_load = False
|
full_load = False
|
||||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||||
|
|||||||
50
comfy/ops.py
50
comfy/ops.py
@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
import inspect
|
import inspect
|
||||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||||
@ -71,7 +71,6 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
|||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
@torch.compiler.disable()
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||||
@ -84,7 +83,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
if offloadable:
|
if offloadable and (device != s.weight.device or
|
||||||
|
(s.bias is not None and device != s.bias.device)):
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
else:
|
else:
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
@ -94,21 +94,25 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
else:
|
else:
|
||||||
wf_context = contextlib.nullcontext()
|
wf_context = contextlib.nullcontext()
|
||||||
|
|
||||||
bias = None
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
|
||||||
has_function = len(s.bias_function) > 0
|
|
||||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
|
||||||
|
|
||||||
if has_function:
|
weight_has_function = len(s.weight_function) > 0
|
||||||
|
bias_has_function = len(s.bias_function) > 0
|
||||||
|
|
||||||
|
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||||
|
|
||||||
|
bias = None
|
||||||
|
if s.bias is not None:
|
||||||
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||||
|
|
||||||
|
if bias_has_function:
|
||||||
with wf_context:
|
with wf_context:
|
||||||
for f in s.bias_function:
|
for f in s.bias_function:
|
||||||
bias = f(bias)
|
bias = f(bias)
|
||||||
|
|
||||||
has_function = len(s.weight_function) > 0
|
if weight_has_function or weight.dtype != dtype:
|
||||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
|
||||||
if has_function:
|
|
||||||
with wf_context:
|
with wf_context:
|
||||||
|
weight = weight.to(dtype=dtype)
|
||||||
for f in s.weight_function:
|
for f in s.weight_function:
|
||||||
weight = f(weight)
|
weight = f(weight)
|
||||||
|
|
||||||
@ -401,15 +405,9 @@ def fp8_linear(self, input):
|
|||||||
if dtype not in [torch.float8_e4m3fn]:
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
tensor_2d = False
|
|
||||||
if len(input.shape) == 2:
|
|
||||||
tensor_2d = True
|
|
||||||
input = input.unsqueeze(1)
|
|
||||||
|
|
||||||
input_shape = input.shape
|
|
||||||
input_dtype = input.dtype
|
input_dtype = input.dtype
|
||||||
|
|
||||||
if len(input.shape) == 3:
|
if input.ndim == 3 or input.ndim == 2:
|
||||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
@ -422,24 +420,20 @@ def fp8_linear(self, input):
|
|||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
|
||||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||||
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight)
|
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||||
else:
|
else:
|
||||||
scale_input = scale_input.to(input.device)
|
scale_input = scale_input.to(input.device)
|
||||||
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype)
|
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||||
|
|
||||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||||
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||||
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight)
|
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||||
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||||
|
|
||||||
uncast_bias_weight(self, w, bias, offload_stream)
|
uncast_bias_weight(self, w, bias, offload_stream)
|
||||||
|
return o
|
||||||
if tensor_2d:
|
|
||||||
return o.reshape(input_shape[0], -1)
|
|
||||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -540,12 +534,12 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Mixed Precision Operations
|
# Mixed Precision Operations
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
from .quant_ops import QuantizedTensor
|
||||||
|
|
||||||
QUANT_FORMAT_MIXINS = {
|
QUANT_FORMAT_MIXINS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
"dtype": torch.float8_e4m3fn,
|
"dtype": torch.float8_e4m3fn,
|
||||||
"layout_type": TensorCoreFP8Layout,
|
"layout_type": "TensorCoreFP8Layout",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||||
|
|||||||
@ -123,15 +123,15 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
layout_type: Layout class (subclass of QuantizedLayout)
|
layout_type: Layout class (subclass of QuantizedLayout)
|
||||||
layout_params: Dict with layout-specific parameters
|
layout_params: Dict with layout-specific parameters
|
||||||
"""
|
"""
|
||||||
return torch.Tensor._make_subclass(cls, qdata, require_grad=False)
|
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||||
|
|
||||||
def __init__(self, qdata, layout_type, layout_params):
|
def __init__(self, qdata, layout_type, layout_params):
|
||||||
self._qdata = qdata.contiguous()
|
self._qdata = qdata
|
||||||
self._layout_type = layout_type
|
self._layout_type = layout_type
|
||||||
self._layout_params = layout_params
|
self._layout_params = layout_params
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
layout_name = self._layout_type.__name__
|
layout_name = self._layout_type
|
||||||
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||||
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||||
|
|
||||||
@ -179,15 +179,15 @@ class QuantizedTensor(torch.Tensor):
|
|||||||
attr_name = f"_layout_param_{key}"
|
attr_name = f"_layout_param_{key}"
|
||||||
layout_params[key] = inner_tensors[attr_name]
|
layout_params[key] = inner_tensors[attr_name]
|
||||||
|
|
||||||
return QuantizedTensor(inner_tensors["_q_data"], layout_type, layout_params)
|
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||||
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs)
|
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||||
return cls(qdata, layout_type, layout_params)
|
return cls(qdata, layout_type, layout_params)
|
||||||
|
|
||||||
def dequantize(self) -> torch.Tensor:
|
def dequantize(self) -> torch.Tensor:
|
||||||
return self._layout_type.dequantize(self._qdata, **self._layout_params)
|
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||||
@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
|||||||
return qtensor._qdata, qtensor._layout_params['scale']
|
return qtensor._qdata, qtensor._layout_params['scale']
|
||||||
|
|
||||||
|
|
||||||
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
LAYOUTS = {
|
||||||
|
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||||
def fp8_linear(func, args, kwargs):
|
def fp8_linear(func, args, kwargs):
|
||||||
input_tensor = args[0]
|
input_tensor = args[0]
|
||||||
weight = args[1]
|
weight = args[1]
|
||||||
@ -406,13 +411,17 @@ def fp8_linear(func, args, kwargs):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
output = torch._scaled_mm(
|
output = torch._scaled_mm(
|
||||||
plain_input.reshape(-1, input_shape[2]),
|
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||||
weight_t,
|
weight_t,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
scale_a=scale_a,
|
scale_a=scale_a,
|
||||||
scale_b=scale_b,
|
scale_b=scale_b,
|
||||||
out_dtype=out_dtype,
|
out_dtype=out_dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
if not tensor_2d:
|
if not tensor_2d:
|
||||||
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||||
|
|
||||||
@ -422,7 +431,7 @@ def fp8_linear(func, args, kwargs):
|
|||||||
'scale': output_scale,
|
'scale': output_scale,
|
||||||
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||||
}
|
}
|
||||||
return QuantizedTensor(output, TensorCoreFP8Layout, output_params)
|
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||||
else:
|
else:
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -436,3 +445,68 @@ def fp8_linear(func, args, kwargs):
|
|||||||
input_tensor = input_tensor.dequantize()
|
input_tensor = input_tensor.dequantize()
|
||||||
|
|
||||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||||
|
|
||||||
|
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||||
|
if out_dtype is None:
|
||||||
|
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||||
|
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||||
|
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
plain_input.contiguous(),
|
||||||
|
plain_weight,
|
||||||
|
bias=bias,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||||
|
output = output[0]
|
||||||
|
return output
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_addmm(func, args, kwargs):
|
||||||
|
input_tensor = args[1]
|
||||||
|
weight = args[2]
|
||||||
|
bias = args[0]
|
||||||
|
|
||||||
|
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||||
|
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||||
|
|
||||||
|
a = list(args)
|
||||||
|
if isinstance(args[0], QuantizedTensor):
|
||||||
|
a[0] = args[0].dequantize()
|
||||||
|
if isinstance(args[1], QuantizedTensor):
|
||||||
|
a[1] = args[1].dequantize()
|
||||||
|
if isinstance(args[2], QuantizedTensor):
|
||||||
|
a[2] = args[2].dequantize()
|
||||||
|
|
||||||
|
return func(*a, **kwargs)
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_mm(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
weight = args[1]
|
||||||
|
|
||||||
|
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||||
|
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||||
|
|
||||||
|
a = list(args)
|
||||||
|
if isinstance(args[0], QuantizedTensor):
|
||||||
|
a[0] = args[0].dequantize()
|
||||||
|
if isinstance(args[1], QuantizedTensor):
|
||||||
|
a[1] = args[1].dequantize()
|
||||||
|
return func(*a, **kwargs)
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||||
|
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_func(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
ar = list(args)
|
||||||
|
ar[0] = plain_input
|
||||||
|
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|||||||
@ -1,73 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import aiohttp
|
|
||||||
import mimetypes
|
|
||||||
from typing import Union
|
|
||||||
from server import PromptServer
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
import base64
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_and_cast_response(
|
|
||||||
response, timeout: int = None, node_id: Union[str, None] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Validates and casts a response to a torch.Tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The response to validate and cast.
|
|
||||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A torch.Tensor representing the image (1, H, W, C).
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the response is not valid.
|
|
||||||
"""
|
|
||||||
# validate raw JSON response
|
|
||||||
data = response.data
|
|
||||||
if not data or len(data) == 0:
|
|
||||||
raise ValueError("No images returned from API endpoint")
|
|
||||||
|
|
||||||
# Initialize list to store image tensors
|
|
||||||
image_tensors: list[torch.Tensor] = []
|
|
||||||
|
|
||||||
# Process each image in the data array
|
|
||||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
|
||||||
for img_data in data:
|
|
||||||
img_bytes: bytes
|
|
||||||
if img_data.b64_json:
|
|
||||||
img_bytes = base64.b64decode(img_data.b64_json)
|
|
||||||
elif img_data.url:
|
|
||||||
if node_id:
|
|
||||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
|
||||||
async with session.get(img_data.url) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
raise ValueError("Failed to download generated image")
|
|
||||||
img_bytes = await resp.read()
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
|
||||||
|
|
||||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
|
||||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
|
||||||
image_tensors.append(torch.from_numpy(arr))
|
|
||||||
|
|
||||||
return torch.stack(image_tensors, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
|
||||||
"""Converts a text file to a base64 string."""
|
|
||||||
with open(filepath, "rb") as f:
|
|
||||||
file_content = f.read()
|
|
||||||
return base64.b64encode(file_content).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
|
||||||
"""Converts a text file to a data URI."""
|
|
||||||
base64_string = text_filepath_to_base64_string(filepath)
|
|
||||||
mime_type, _ = mimetypes.guess_type(filepath)
|
|
||||||
if mime_type is None:
|
|
||||||
mime_type = "application/octet-stream"
|
|
||||||
return f"data:{mime_type};base64,{base64_string}"
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: filtered-openapi.yaml
|
|
||||||
# timestamp: 2025-04-29T23:44:54+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from . import PixverseDto
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseData(BaseModel):
|
|
||||||
ErrCode: Optional[int] = None
|
|
||||||
ErrMsg: Optional[str] = None
|
|
||||||
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
|
||||||
@ -1,57 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: filtered-openapi.yaml
|
|
||||||
# timestamp: 2025-04-29T23:44:54+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class V2OpenAPII2VResp(BaseModel):
|
|
||||||
video_id: Optional[int] = Field(None, description='Video_id')
|
|
||||||
|
|
||||||
|
|
||||||
class V2OpenAPIT2VReq(BaseModel):
|
|
||||||
aspect_ratio: str = Field(
|
|
||||||
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
|
||||||
)
|
|
||||||
duration: int = Field(
|
|
||||||
...,
|
|
||||||
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
|
||||||
examples=[5],
|
|
||||||
)
|
|
||||||
model: str = Field(
|
|
||||||
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
|
||||||
)
|
|
||||||
motion_mode: Optional[str] = Field(
|
|
||||||
'normal',
|
|
||||||
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
|
||||||
examples=['normal'],
|
|
||||||
)
|
|
||||||
negative_prompt: Optional[str] = Field(
|
|
||||||
None, description='Negative prompt\n', max_length=2048
|
|
||||||
)
|
|
||||||
prompt: str = Field(..., description='Prompt', max_length=2048)
|
|
||||||
quality: str = Field(
|
|
||||||
...,
|
|
||||||
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
|
||||||
examples=['540p'],
|
|
||||||
)
|
|
||||||
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
|
||||||
style: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
|
||||||
examples=['anime'],
|
|
||||||
)
|
|
||||||
template_id: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='Template ID (template_id must be activated before use)',
|
|
||||||
examples=[302325299692608],
|
|
||||||
)
|
|
||||||
water_mark: Optional[bool] = Field(
|
|
||||||
False,
|
|
||||||
description='Watermark (true: add watermark, false: no watermark)',
|
|
||||||
examples=[False],
|
|
||||||
)
|
|
||||||
@ -1,981 +0,0 @@
|
|||||||
"""
|
|
||||||
API Client Framework for api.comfy.org.
|
|
||||||
|
|
||||||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
|
||||||
It supports both synchronous and asynchronous API operations with proper type validation.
|
|
||||||
|
|
||||||
Key Components:
|
|
||||||
--------------
|
|
||||||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
|
||||||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
|
||||||
3. ApiOperation - Executes a single synchronous API operation
|
|
||||||
|
|
||||||
Usage Examples:
|
|
||||||
--------------
|
|
||||||
|
|
||||||
# Example 1: Synchronous API Operation
|
|
||||||
# ------------------------------------
|
|
||||||
# For a simple API call that returns the result immediately:
|
|
||||||
|
|
||||||
# 1. Create the API client
|
|
||||||
api_client = ApiClient(
|
|
||||||
base_url="https://api.example.com",
|
|
||||||
auth_token="your_auth_token_here",
|
|
||||||
comfy_api_key="your_comfy_api_key_here",
|
|
||||||
timeout=30.0,
|
|
||||||
verify_ssl=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Define the endpoint
|
|
||||||
user_info_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/users/me",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest, # No request body needed
|
|
||||||
response_model=UserProfile, # Pydantic model for the response
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Create the request object
|
|
||||||
request = EmptyRequest()
|
|
||||||
|
|
||||||
# 4. Create and execute the operation
|
|
||||||
operation = ApiOperation(
|
|
||||||
endpoint=user_info_endpoint,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
|
|
||||||
|
|
||||||
|
|
||||||
# Example 2: Asynchronous API Operation with Polling
|
|
||||||
# -------------------------------------------------
|
|
||||||
# For an API that starts a task and requires polling for completion:
|
|
||||||
|
|
||||||
# 1. Define the endpoints (initial request and polling)
|
|
||||||
generate_image_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/images/generate",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=ImageGenerationRequest,
|
|
||||||
response_model=TaskCreatedResponse,
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
check_task_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/tasks/{task_id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=ImageGenerationResult,
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Create the request object
|
|
||||||
request = ImageGenerationRequest(
|
|
||||||
prompt="a beautiful sunset over mountains",
|
|
||||||
width=1024,
|
|
||||||
height=1024,
|
|
||||||
num_images=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Create and execute the polling operation
|
|
||||||
operation = PollingOperation(
|
|
||||||
initial_endpoint=generate_image_endpoint,
|
|
||||||
initial_request=request,
|
|
||||||
poll_endpoint=check_task_endpoint,
|
|
||||||
task_id_field="task_id",
|
|
||||||
status_field="status",
|
|
||||||
completed_statuses=["completed"],
|
|
||||||
failed_statuses=["failed", "error"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# This will make the initial request and then poll until completion
|
|
||||||
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
|
||||||
from typing import Type, Optional, Any, TypeVar, Generic, Callable
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
from urllib.parse import urljoin, urlparse
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import uuid # For generating unique operation IDs
|
|
||||||
|
|
||||||
from server import PromptServer
|
|
||||||
from comfy.cli_args import args
|
|
||||||
from comfy import utils
|
|
||||||
from . import request_logger
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
|
||||||
R = TypeVar("R", bound=BaseModel)
|
|
||||||
P = TypeVar("P", bound=BaseModel) # For poll response
|
|
||||||
|
|
||||||
PROGRESS_BAR_MAX = 100
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkError(Exception):
|
|
||||||
"""Base exception for network-related errors with diagnostic information."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LocalNetworkError(NetworkError):
|
|
||||||
"""Exception raised when local network connectivity issues are detected."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ApiServerError(NetworkError):
|
|
||||||
"""Exception raised when the API server is unreachable but internet is working."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyRequest(BaseModel):
|
|
||||||
"""Base class for empty request bodies.
|
|
||||||
For GET requests, fields will be sent as query parameters."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UploadRequest(BaseModel):
|
|
||||||
file_name: str = Field(..., description="Filename to upload")
|
|
||||||
content_type: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UploadResponse(BaseModel):
|
|
||||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
|
||||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
|
||||||
|
|
||||||
|
|
||||||
class HttpMethod(str, Enum):
|
|
||||||
GET = "GET"
|
|
||||||
POST = "POST"
|
|
||||||
PUT = "PUT"
|
|
||||||
DELETE = "DELETE"
|
|
||||||
PATCH = "PATCH"
|
|
||||||
|
|
||||||
|
|
||||||
class ApiClient:
|
|
||||||
"""
|
|
||||||
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
base_url: str,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
comfy_api_key: Optional[str] = None,
|
|
||||||
timeout: float = 3600.0,
|
|
||||||
verify_ssl: bool = True,
|
|
||||||
max_retries: int = 3,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
retry_status_codes: Optional[tuple[int, ...]] = None,
|
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
|
||||||
):
|
|
||||||
self.base_url = base_url
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.comfy_api_key = comfy_api_key
|
|
||||||
self.timeout = timeout
|
|
||||||
self.verify_ssl = verify_ssl
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.retry_backoff_factor = retry_backoff_factor
|
|
||||||
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
|
|
||||||
# 500, 502, 503, 504 (Server Errors)
|
|
||||||
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
|
|
||||||
self._session: Optional[aiohttp.ClientSession] = session
|
|
||||||
self._owns_session = session is None # Track if we have to close it
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _generate_operation_id(path: str) -> str:
|
|
||||||
"""Generates a unique operation ID for logging."""
|
|
||||||
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_json_payload_args(
|
|
||||||
data: Optional[dict[str, Any]] = None,
|
|
||||||
headers: Optional[dict[str, str]] = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"json": data,
|
|
||||||
"headers": headers,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _create_form_data_args(
|
|
||||||
self,
|
|
||||||
data: dict[str, Any] | None,
|
|
||||||
files: dict[str, Any] | None,
|
|
||||||
headers: Optional[dict[str, str]] = None,
|
|
||||||
multipart_parser: Callable | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
if headers and "Content-Type" in headers:
|
|
||||||
del headers["Content-Type"]
|
|
||||||
|
|
||||||
if multipart_parser and data:
|
|
||||||
data = multipart_parser(data)
|
|
||||||
|
|
||||||
if isinstance(data, aiohttp.FormData):
|
|
||||||
form = data # If the parser already returned a FormData, pass it through
|
|
||||||
else:
|
|
||||||
form = aiohttp.FormData(default_to_multipart=True)
|
|
||||||
if data: # regular text fields
|
|
||||||
for k, v in data.items():
|
|
||||||
if v is None:
|
|
||||||
continue # aiohttp fails to serialize "None" values
|
|
||||||
# aiohttp expects strings or bytes; convert enums etc.
|
|
||||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
|
||||||
|
|
||||||
if files:
|
|
||||||
file_iter = files if isinstance(files, list) else files.items()
|
|
||||||
for field_name, file_obj in file_iter:
|
|
||||||
if file_obj is None:
|
|
||||||
continue # aiohttp fails to serialize "None" values
|
|
||||||
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
|
|
||||||
if isinstance(file_obj, tuple):
|
|
||||||
filename, file_value, content_type = self._unpack_tuple(file_obj)
|
|
||||||
else:
|
|
||||||
file_value = file_obj
|
|
||||||
filename = getattr(file_obj, "name", field_name)
|
|
||||||
content_type = "application/octet-stream"
|
|
||||||
|
|
||||||
form.add_field(
|
|
||||||
name=field_name,
|
|
||||||
value=file_value,
|
|
||||||
filename=filename,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
return {"data": form, "headers": headers or {}}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_urlencoded_form_data_args(
|
|
||||||
data: dict[str, Any],
|
|
||||||
headers: Optional[dict[str, str]] = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
headers = headers or {}
|
|
||||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
|
||||||
return {
|
|
||||||
"data": data,
|
|
||||||
"headers": headers,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_headers(self) -> dict[str, str]:
|
|
||||||
"""Get headers for API requests, including authentication if available"""
|
|
||||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
|
||||||
|
|
||||||
if self.auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
|
||||||
elif self.comfy_api_key:
|
|
||||||
headers["X-API-KEY"] = self.comfy_api_key
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
||||||
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
|
|
||||||
"""
|
|
||||||
Check connectivity to determine if network issues are local or server-related.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_url: URL to check connectivity to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with connectivity status details
|
|
||||||
"""
|
|
||||||
results = {
|
|
||||||
"internet_accessible": False,
|
|
||||||
"api_accessible": False,
|
|
||||||
"is_local_issue": False,
|
|
||||||
"is_api_issue": False,
|
|
||||||
}
|
|
||||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
try:
|
|
||||||
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
|
|
||||||
results["internet_accessible"] = resp.status < 500
|
|
||||||
except (ClientError, asyncio.TimeoutError, socket.gaierror):
|
|
||||||
results["is_local_issue"] = True
|
|
||||||
return results # cannot reach the internet – early exit
|
|
||||||
|
|
||||||
# Now check API health endpoint
|
|
||||||
parsed = urlparse(target_url)
|
|
||||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
|
||||||
try:
|
|
||||||
async with session.get(health_url, ssl=self.verify_ssl) as resp:
|
|
||||||
results["api_accessible"] = resp.status < 500
|
|
||||||
except ClientError:
|
|
||||||
pass # leave as False
|
|
||||||
|
|
||||||
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def request(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
path: str,
|
|
||||||
params: Optional[dict[str, Any]] = None,
|
|
||||||
data: Optional[dict[str, Any]] = None,
|
|
||||||
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
|
||||||
headers: Optional[dict[str, str]] = None,
|
|
||||||
content_type: str = "application/json",
|
|
||||||
multipart_parser: Callable | None = None,
|
|
||||||
retry_count: int = 0, # Used internally for tracking retries
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Make an HTTP request to the API with automatic retries for transient errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method: HTTP method (GET, POST, etc.)
|
|
||||||
path: API endpoint path (will be joined with base_url)
|
|
||||||
params: Query parameters
|
|
||||||
data: body data
|
|
||||||
files: Files to upload
|
|
||||||
headers: Additional headers
|
|
||||||
content_type: Content type of the request. Defaults to application/json.
|
|
||||||
retry_count: Internal parameter for tracking retries, do not set manually
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed JSON response
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
LocalNetworkError: If local network connectivity issues are detected
|
|
||||||
ApiServerError: If the API server is unreachable but internet is working
|
|
||||||
Exception: For other request failures
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Build full URL and merge headers
|
|
||||||
relative_path = path.lstrip("/")
|
|
||||||
url = urljoin(self.base_url, relative_path)
|
|
||||||
self._check_auth(self.auth_token, self.comfy_api_key)
|
|
||||||
|
|
||||||
request_headers = self.get_headers()
|
|
||||||
if headers:
|
|
||||||
request_headers.update(headers)
|
|
||||||
if files:
|
|
||||||
request_headers.pop("Content-Type", None)
|
|
||||||
if params:
|
|
||||||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
|
||||||
|
|
||||||
logging.debug("[DEBUG] Request Headers: %s", request_headers)
|
|
||||||
logging.debug("[DEBUG] Files: %s", files)
|
|
||||||
logging.debug("[DEBUG] Params: %s", params)
|
|
||||||
logging.debug("[DEBUG] Data: %s", data)
|
|
||||||
|
|
||||||
if content_type == "application/x-www-form-urlencoded":
|
|
||||||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
|
||||||
elif content_type == "multipart/form-data":
|
|
||||||
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
|
|
||||||
else:
|
|
||||||
payload_args = self._create_json_payload_args(data, request_headers)
|
|
||||||
|
|
||||||
operation_id = self._generate_operation_id(path)
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method=method,
|
|
||||||
request_url=url,
|
|
||||||
request_headers=request_headers,
|
|
||||||
request_params=params,
|
|
||||||
request_data=data if content_type == "application/json" else "[form-data or other]",
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await self._get_session()
|
|
||||||
try:
|
|
||||||
async with session.request(
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
ssl=self.verify_ssl,
|
|
||||||
**payload_args,
|
|
||||||
) as resp:
|
|
||||||
if resp.status >= 400:
|
|
||||||
try:
|
|
||||||
error_data = await resp.json()
|
|
||||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
|
||||||
error_data = await resp.text()
|
|
||||||
|
|
||||||
return await self._handle_http_error(
|
|
||||||
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
|
|
||||||
operation_id,
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
params,
|
|
||||||
data,
|
|
||||||
files,
|
|
||||||
headers,
|
|
||||||
content_type,
|
|
||||||
multipart_parser,
|
|
||||||
retry_count=retry_count,
|
|
||||||
response_content=error_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Success – parse JSON (safely) and log
|
|
||||||
try:
|
|
||||||
payload = await resp.json()
|
|
||||||
response_content_to_log = payload
|
|
||||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
|
||||||
payload = {}
|
|
||||||
response_content_to_log = await resp.text()
|
|
||||||
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method=method,
|
|
||||||
request_url=url,
|
|
||||||
response_status_code=resp.status,
|
|
||||||
response_headers=dict(resp.headers),
|
|
||||||
response_content=response_content_to_log,
|
|
||||||
)
|
|
||||||
return payload
|
|
||||||
|
|
||||||
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
|
|
||||||
# Treat as *connection* problem – optionally retry, else escalate
|
|
||||||
if retry_count < self.max_retries:
|
|
||||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
|
||||||
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
|
|
||||||
self.max_retries, str(e))
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
return await self.request(
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
params=params,
|
|
||||||
data=data,
|
|
||||||
files=files,
|
|
||||||
headers=headers,
|
|
||||||
content_type=content_type,
|
|
||||||
multipart_parser=multipart_parser,
|
|
||||||
retry_count=retry_count + 1,
|
|
||||||
)
|
|
||||||
# One final connectivity check for diagnostics
|
|
||||||
connectivity = await self._check_connectivity(self.base_url)
|
|
||||||
if connectivity["is_local_issue"]:
|
|
||||||
raise LocalNetworkError(
|
|
||||||
"Unable to connect to the API server due to local network issues. "
|
|
||||||
"Please check your internet connection and try again."
|
|
||||||
) from e
|
|
||||||
raise ApiServerError(
|
|
||||||
f"The API server at {self.base_url} is currently unreachable. "
|
|
||||||
f"The service may be experiencing issues. Please try again later."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_auth(auth_token, comfy_api_key):
|
|
||||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
|
||||||
if auth_token is None and comfy_api_key is None:
|
|
||||||
raise Exception("Unauthorized: Please login first to use this node.")
|
|
||||||
return auth_token or comfy_api_key
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def upload_file(
|
|
||||||
upload_url: str,
|
|
||||||
file: io.BytesIO | str,
|
|
||||||
content_type: str | None = None,
|
|
||||||
max_retries: int = 3,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
) -> aiohttp.ClientResponse:
|
|
||||||
"""Upload a file to the API with retry logic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
upload_url: The URL to upload to
|
|
||||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
|
||||||
content_type: Optional mime type to set for the upload
|
|
||||||
max_retries: Maximum number of retry attempts
|
|
||||||
retry_delay: Initial delay between retries in seconds
|
|
||||||
retry_backoff_factor: Multiplier for the delay after each retry
|
|
||||||
"""
|
|
||||||
headers: dict[str, str] = {}
|
|
||||||
skip_auto_headers: set[str] = set()
|
|
||||||
if content_type:
|
|
||||||
headers["Content-Type"] = content_type
|
|
||||||
else:
|
|
||||||
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
|
|
||||||
skip_auto_headers.add("Content-Type")
|
|
||||||
|
|
||||||
# Extract file bytes
|
|
||||||
if isinstance(file, io.BytesIO):
|
|
||||||
file.seek(0)
|
|
||||||
data = file.read()
|
|
||||||
elif isinstance(file, str):
|
|
||||||
with open(file, "rb") as f:
|
|
||||||
data = f.read()
|
|
||||||
else:
|
|
||||||
raise ValueError("File must be BytesIO or str path")
|
|
||||||
|
|
||||||
parsed = urlparse(upload_url)
|
|
||||||
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
|
|
||||||
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method="PUT",
|
|
||||||
request_url=upload_url,
|
|
||||||
request_headers=headers,
|
|
||||||
request_data=f"[File data {len(data)} bytes]",
|
|
||||||
)
|
|
||||||
|
|
||||||
delay = retry_delay
|
|
||||||
for attempt in range(max_retries + 1):
|
|
||||||
try:
|
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
async with session.put(
|
|
||||||
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
|
|
||||||
) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method="PUT",
|
|
||||||
request_url=upload_url,
|
|
||||||
response_status_code=resp.status,
|
|
||||||
response_headers=dict(resp.headers),
|
|
||||||
response_content="File uploaded successfully.",
|
|
||||||
)
|
|
||||||
return resp
|
|
||||||
except (ClientError, asyncio.TimeoutError) as e:
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method="PUT",
|
|
||||||
request_url=upload_url,
|
|
||||||
response_status_code=e.status if hasattr(e, "status") else None,
|
|
||||||
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
|
|
||||||
response_content=None,
|
|
||||||
error_message=f"{type(e).__name__}: {str(e)}",
|
|
||||||
)
|
|
||||||
if attempt < max_retries:
|
|
||||||
logging.warning(
|
|
||||||
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
delay *= retry_backoff_factor
|
|
||||||
else:
|
|
||||||
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
|
|
||||||
|
|
||||||
async def _handle_http_error(
|
|
||||||
self,
|
|
||||||
exc: ClientResponseError,
|
|
||||||
operation_id: str,
|
|
||||||
*req_meta,
|
|
||||||
retry_count: int,
|
|
||||||
response_content: dict | str = "",
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
status_code = exc.status
|
|
||||||
if status_code == 401:
|
|
||||||
user_friendly = "Unauthorized: Please login first to use this node."
|
|
||||||
elif status_code == 402:
|
|
||||||
user_friendly = "Payment Required: Please add credits to your account to use this node."
|
|
||||||
elif status_code == 409:
|
|
||||||
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
|
|
||||||
elif status_code == 429:
|
|
||||||
user_friendly = "Rate Limit Exceeded: Please try again later."
|
|
||||||
else:
|
|
||||||
if isinstance(response_content, dict):
|
|
||||||
if "error" in response_content and "message" in response_content["error"]:
|
|
||||||
user_friendly = f"API Error: {response_content['error']['message']}"
|
|
||||||
if "type" in response_content["error"]:
|
|
||||||
user_friendly += f" (Type: {response_content['error']['type']})"
|
|
||||||
else: # Handle cases where error is just a JSON dict with unknown format
|
|
||||||
user_friendly = f"API Error: {json.dumps(response_content)}"
|
|
||||||
else:
|
|
||||||
if len(response_content) < 200: # Arbitrary limit for display
|
|
||||||
user_friendly = f"API Error (raw): {response_content}"
|
|
||||||
else:
|
|
||||||
user_friendly = f"API Error (raw, status {response_content})"
|
|
||||||
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method=req_meta[0],
|
|
||||||
request_url=req_meta[1],
|
|
||||||
response_status_code=exc.status,
|
|
||||||
response_headers=dict(req_meta[5]) if req_meta[5] else None,
|
|
||||||
response_content=response_content,
|
|
||||||
error_message=f"HTTP Error {exc.status}",
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
|
|
||||||
if response_content:
|
|
||||||
logging.debug("[DEBUG] Response content: %s", response_content)
|
|
||||||
|
|
||||||
# Retry if eligible
|
|
||||||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
|
||||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
|
||||||
logging.warning(
|
|
||||||
"HTTP error %s. Retrying in %.2fs (%s/%s)",
|
|
||||||
status_code,
|
|
||||||
delay,
|
|
||||||
retry_count + 1,
|
|
||||||
self.max_retries,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
return await self.request(
|
|
||||||
req_meta[0], # method
|
|
||||||
req_meta[1].replace(self.base_url, ""), # path
|
|
||||||
params=req_meta[2],
|
|
||||||
data=req_meta[3],
|
|
||||||
files=req_meta[4],
|
|
||||||
headers=req_meta[5],
|
|
||||||
content_type=req_meta[6],
|
|
||||||
multipart_parser=req_meta[7],
|
|
||||||
retry_count=retry_count + 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise Exception(user_friendly) from exc
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unpack_tuple(t):
|
|
||||||
"""Helper to normalise (filename, file, content_type) tuples."""
|
|
||||||
if len(t) == 3:
|
|
||||||
return t
|
|
||||||
elif len(t) == 2:
|
|
||||||
return t[0], t[1], "application/octet-stream"
|
|
||||||
else:
|
|
||||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
|
||||||
|
|
||||||
async def _get_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._session is None or self._session.closed:
|
|
||||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
|
||||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
|
||||||
self._owns_session = True
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
if self._owns_session and self._session and not self._session.closed:
|
|
||||||
await self._session.close()
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "ApiClient":
|
|
||||||
"""Allow usage as async‑context‑manager – ensures clean teardown"""
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
|
|
||||||
class ApiEndpoint(Generic[T, R]):
|
|
||||||
"""Defines an API endpoint with its request and response types"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
method: HttpMethod,
|
|
||||||
request_model: Type[T],
|
|
||||||
response_model: Type[R],
|
|
||||||
query_params: Optional[dict[str, Any]] = None,
|
|
||||||
):
|
|
||||||
"""Initialize an API endpoint definition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: The URL path for this endpoint, can include placeholders like {id}
|
|
||||||
method: The HTTP method to use (GET, POST, etc.)
|
|
||||||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
|
||||||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
|
||||||
query_params: Optional dictionary of query parameters to include in the request
|
|
||||||
"""
|
|
||||||
self.path = path
|
|
||||||
self.method = method
|
|
||||||
self.request_model = request_model
|
|
||||||
self.response_model = response_model
|
|
||||||
self.query_params = query_params or {}
|
|
||||||
|
|
||||||
|
|
||||||
class SynchronousOperation(Generic[T, R]):
|
|
||||||
"""Represents a single synchronous API operation."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
endpoint: ApiEndpoint[T, R],
|
|
||||||
request: T,
|
|
||||||
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
|
|
||||||
api_base: str | None = None,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
comfy_api_key: Optional[str] = None,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
timeout: float = 7200.0,
|
|
||||||
verify_ssl: bool = True,
|
|
||||||
content_type: str = "application/json",
|
|
||||||
multipart_parser: Callable | None = None,
|
|
||||||
max_retries: int = 3,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
) -> None:
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.request = request
|
|
||||||
self.files = files
|
|
||||||
self.api_base: str = api_base or args.comfy_api_base
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.comfy_api_key = comfy_api_key
|
|
||||||
if auth_kwargs is not None:
|
|
||||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
|
||||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
|
||||||
self.timeout = timeout
|
|
||||||
self.verify_ssl = verify_ssl
|
|
||||||
self.content_type = content_type
|
|
||||||
self.multipart_parser = multipart_parser
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.retry_backoff_factor = retry_backoff_factor
|
|
||||||
|
|
||||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
|
||||||
owns_client = client is None
|
|
||||||
if owns_client:
|
|
||||||
client = ApiClient(
|
|
||||||
base_url=self.api_base,
|
|
||||||
auth_token=self.auth_token,
|
|
||||||
comfy_api_key=self.comfy_api_key,
|
|
||||||
timeout=self.timeout,
|
|
||||||
verify_ssl=self.verify_ssl,
|
|
||||||
max_retries=self.max_retries,
|
|
||||||
retry_delay=self.retry_delay,
|
|
||||||
retry_backoff_factor=self.retry_backoff_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
request_dict: Optional[dict[str, Any]]
|
|
||||||
if isinstance(self.request, EmptyRequest):
|
|
||||||
request_dict = None
|
|
||||||
else:
|
|
||||||
request_dict = self.request.model_dump(exclude_none=True)
|
|
||||||
for k, v in list(request_dict.items()):
|
|
||||||
if isinstance(v, Enum):
|
|
||||||
request_dict[k] = v.value
|
|
||||||
|
|
||||||
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
|
|
||||||
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
|
|
||||||
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
|
|
||||||
|
|
||||||
response_json = await client.request(
|
|
||||||
self.endpoint.method.value,
|
|
||||||
self.endpoint.path,
|
|
||||||
params=self.endpoint.query_params,
|
|
||||||
data=request_dict,
|
|
||||||
files=self.files,
|
|
||||||
content_type=self.content_type,
|
|
||||||
multipart_parser=self.multipart_parser,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.debug("=" * 50)
|
|
||||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
|
||||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
|
||||||
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
|
|
||||||
logging.debug("=" * 50)
|
|
||||||
|
|
||||||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
|
||||||
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
|
|
||||||
return parsed_response
|
|
||||||
finally:
|
|
||||||
if owns_client:
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
|
||||||
"""Enum for task status values"""
|
|
||||||
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
PENDING = "pending"
|
|
||||||
|
|
||||||
|
|
||||||
class PollingOperation(Generic[T, R]):
|
|
||||||
"""Represents an asynchronous API operation that requires polling for completion."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
|
||||||
completed_statuses: list[str],
|
|
||||||
failed_statuses: list[str],
|
|
||||||
*,
|
|
||||||
status_extractor: Callable[[R], Optional[str]],
|
|
||||||
progress_extractor: Callable[[R], Optional[float]] | None = None,
|
|
||||||
result_url_extractor: Callable[[R], Optional[str]] | None = None,
|
|
||||||
price_extractor: Callable[[R], Optional[float]] | None = None,
|
|
||||||
request: Optional[T] = None,
|
|
||||||
api_base: str | None = None,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
comfy_api_key: Optional[str] = None,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
poll_interval: float = 5.0,
|
|
||||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
|
||||||
max_retries: int = 3, # Max retries per individual API call
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
estimated_duration: Optional[float] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
self.poll_endpoint = poll_endpoint
|
|
||||||
self.request = request
|
|
||||||
self.api_base: str = api_base or args.comfy_api_base
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.comfy_api_key = comfy_api_key
|
|
||||||
if auth_kwargs is not None:
|
|
||||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
|
||||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
|
||||||
self.poll_interval = poll_interval
|
|
||||||
self.max_poll_attempts = max_poll_attempts
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.retry_backoff_factor = retry_backoff_factor
|
|
||||||
self.estimated_duration = estimated_duration
|
|
||||||
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
|
|
||||||
self.progress_extractor = progress_extractor
|
|
||||||
self.result_url_extractor = result_url_extractor
|
|
||||||
self.price_extractor = price_extractor
|
|
||||||
self.node_id = node_id
|
|
||||||
self.completed_statuses = completed_statuses
|
|
||||||
self.failed_statuses = failed_statuses
|
|
||||||
self.final_response: Optional[R] = None
|
|
||||||
self.extracted_price: Optional[float] = None
|
|
||||||
|
|
||||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
|
||||||
owns_client = client is None
|
|
||||||
if owns_client:
|
|
||||||
client = ApiClient(
|
|
||||||
base_url=self.api_base,
|
|
||||||
auth_token=self.auth_token,
|
|
||||||
comfy_api_key=self.comfy_api_key,
|
|
||||||
max_retries=self.max_retries,
|
|
||||||
retry_delay=self.retry_delay,
|
|
||||||
retry_backoff_factor=self.retry_backoff_factor,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return await self._poll_until_complete(client)
|
|
||||||
finally:
|
|
||||||
if owns_client:
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
def _display_text_on_node(self, text: str):
|
|
||||||
if not self.node_id:
|
|
||||||
return
|
|
||||||
if self.extracted_price is not None:
|
|
||||||
text = f"Price: ${self.extracted_price}\n{text}"
|
|
||||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
|
||||||
|
|
||||||
def _display_time_progress_on_node(self, time_completed: int | float):
|
|
||||||
if not self.node_id:
|
|
||||||
return
|
|
||||||
if self.estimated_duration is not None:
|
|
||||||
remaining = max(0, int(self.estimated_duration) - time_completed)
|
|
||||||
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
|
|
||||||
else:
|
|
||||||
message = f"Task in progress: {time_completed}s"
|
|
||||||
self._display_text_on_node(message)
|
|
||||||
|
|
||||||
def _check_task_status(self, response: R) -> TaskStatus:
|
|
||||||
try:
|
|
||||||
status = self.status_extractor(response)
|
|
||||||
if status in self.completed_statuses:
|
|
||||||
return TaskStatus.COMPLETED
|
|
||||||
if status in self.failed_statuses:
|
|
||||||
return TaskStatus.FAILED
|
|
||||||
return TaskStatus.PENDING
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error extracting status: %s", e)
|
|
||||||
return TaskStatus.PENDING
|
|
||||||
|
|
||||||
async def _poll_until_complete(self, client: ApiClient) -> R:
|
|
||||||
"""Poll until the task is complete"""
|
|
||||||
consecutive_errors = 0
|
|
||||||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
|
||||||
|
|
||||||
if self.progress_extractor:
|
|
||||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
|
||||||
|
|
||||||
status = TaskStatus.PENDING
|
|
||||||
for poll_count in range(1, self.max_poll_attempts + 1):
|
|
||||||
try:
|
|
||||||
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
|
|
||||||
|
|
||||||
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
|
|
||||||
|
|
||||||
if poll_count == 1:
|
|
||||||
logging.debug(
|
|
||||||
"[DEBUG] Poll Request: %s %s",
|
|
||||||
self.poll_endpoint.method.value,
|
|
||||||
self.poll_endpoint.path,
|
|
||||||
)
|
|
||||||
logging.debug(
|
|
||||||
"[DEBUG] Poll Request Data: %s",
|
|
||||||
json.dumps(request_dict, indent=2) if request_dict else "None",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query task status
|
|
||||||
resp = await client.request(
|
|
||||||
self.poll_endpoint.method.value,
|
|
||||||
self.poll_endpoint.path,
|
|
||||||
params=self.poll_endpoint.query_params,
|
|
||||||
data=request_dict,
|
|
||||||
)
|
|
||||||
consecutive_errors = 0 # reset on success
|
|
||||||
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
|
|
||||||
|
|
||||||
# Check if task is complete
|
|
||||||
status = self._check_task_status(response_obj)
|
|
||||||
logging.debug("[DEBUG] Task Status: %s", status)
|
|
||||||
|
|
||||||
# If progress extractor is provided, extract progress
|
|
||||||
if self.progress_extractor:
|
|
||||||
new_progress = self.progress_extractor(response_obj)
|
|
||||||
if new_progress is not None:
|
|
||||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
|
||||||
|
|
||||||
if self.price_extractor:
|
|
||||||
price = self.price_extractor(response_obj)
|
|
||||||
if price is not None:
|
|
||||||
self.extracted_price = price
|
|
||||||
|
|
||||||
if status == TaskStatus.COMPLETED:
|
|
||||||
message = "Task completed successfully"
|
|
||||||
if self.result_url_extractor:
|
|
||||||
result_url = self.result_url_extractor(response_obj)
|
|
||||||
if result_url:
|
|
||||||
message = f"Result URL: {result_url}"
|
|
||||||
logging.debug("[DEBUG] %s", message)
|
|
||||||
self._display_text_on_node(message)
|
|
||||||
self.final_response = response_obj
|
|
||||||
if self.progress_extractor:
|
|
||||||
progress.update(100)
|
|
||||||
return self.final_response
|
|
||||||
if status == TaskStatus.FAILED:
|
|
||||||
message = f"Task failed: {json.dumps(resp)}"
|
|
||||||
logging.error("[DEBUG] %s", message)
|
|
||||||
raise Exception(message)
|
|
||||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
|
||||||
# Task pending – wait
|
|
||||||
for i in range(int(self.poll_interval)):
|
|
||||||
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
except (LocalNetworkError, ApiServerError, NetworkError) as e:
|
|
||||||
consecutive_errors += 1
|
|
||||||
if consecutive_errors >= max_consecutive_errors:
|
|
||||||
raise Exception(
|
|
||||||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
|
||||||
) from e
|
|
||||||
logging.warning(
|
|
||||||
"Network error (%s/%s): %s",
|
|
||||||
consecutive_errors,
|
|
||||||
max_consecutive_errors,
|
|
||||||
str(e),
|
|
||||||
)
|
|
||||||
await asyncio.sleep(self.poll_interval)
|
|
||||||
except Exception as e:
|
|
||||||
# For other errors, increment count and potentially abort
|
|
||||||
consecutive_errors += 1
|
|
||||||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
|
||||||
raise Exception(
|
|
||||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
logging.error("[DEBUG] Polling error: %s", str(e))
|
|
||||||
logging.warning(
|
|
||||||
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
|
|
||||||
poll_count,
|
|
||||||
self.max_poll_attempts,
|
|
||||||
str(e),
|
|
||||||
self.poll_interval,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(self.poll_interval)
|
|
||||||
|
|
||||||
# If we've exhausted all polling attempts
|
|
||||||
raise Exception(
|
|
||||||
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
|
|
||||||
"The operation may still be running on the server but is taking longer than expected."
|
|
||||||
)
|
|
||||||
@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode):
|
|||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
),
|
),
|
||||||
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
|
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[
|
options=[
|
||||||
@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode):
|
|||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1, max_length=10000)
|
validate_string(prompt, min_length=1, max_length=10000)
|
||||||
|
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||||
|
raise ValueError(
|
||||||
|
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||||
|
)
|
||||||
response = await sync_op_raw(
|
response = await sync_op_raw(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
||||||
@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
|||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
),
|
),
|
||||||
IO.Combo.Input("duration", options=[6, 8, 10], default=8),
|
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[
|
options=[
|
||||||
@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode):
|
|||||||
generate_audio: bool = False,
|
generate_audio: bool = False,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1, max_length=10000)
|
validate_string(prompt, min_length=1, max_length=10000)
|
||||||
|
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||||
|
raise ValueError(
|
||||||
|
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||||
|
)
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
raise ValueError("Currently only one input image is supported.")
|
raise ValueError("Currently only one input image is supported.")
|
||||||
response = await sync_op_raw(
|
response = await sync_op_raw(
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -7,24 +7,23 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, TypeVar
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||||
from comfy_api_nodes.apis import pika_defs
|
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
|
validate_string,
|
||||||
|
download_url_to_video_output,
|
||||||
|
tensor_to_bytesio,
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
EmptyRequest,
|
sync_op,
|
||||||
HttpMethod,
|
poll_op,
|
||||||
PollingOperation,
|
|
||||||
SynchronousOperation,
|
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio
|
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||||
@ -40,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
|
|||||||
|
|
||||||
|
|
||||||
async def execute_task(
|
async def execute_task(
|
||||||
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse],
|
task_id: str,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
cls: type[IO.ComfyNode],
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
task_id = (await initial_operation.execute()).video_id
|
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||||
final_response: pika_defs.PikaVideoResponse = await PollingOperation(
|
cls,
|
||||||
poll_endpoint=ApiEndpoint(
|
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
response_model=pika_defs.PikaVideoResponse,
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=pika_defs.PikaVideoResponse,
|
|
||||||
),
|
|
||||||
completed_statuses=["finished"],
|
|
||||||
failed_statuses=["failed", "cancelled"],
|
|
||||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
|
|
||||||
node_id=node_id,
|
|
||||||
estimated_duration=60,
|
estimated_duration=60,
|
||||||
max_poll_attempts=240,
|
max_poll_attempts=240,
|
||||||
).execute()
|
)
|
||||||
if not final_response.url:
|
if not final_response.url:
|
||||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
@ -124,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
|
|||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_IMAGE_TO_VIDEO,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaTextToVideoNode(IO.ComfyNode):
|
class PikaTextToVideoNode(IO.ComfyNode):
|
||||||
@ -183,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
|||||||
duration: int,
|
duration: int,
|
||||||
aspect_ratio: float,
|
aspect_ratio: float,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_TEXT_TO_VIDEO,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -202,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
aspectRatio=aspect_ratio,
|
aspectRatio=aspect_ratio,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
content_type="application/x-www-form-urlencoded",
|
content_type="application/x-www-form-urlencoded",
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaScenes(IO.ComfyNode):
|
class PikaScenes(IO.ComfyNode):
|
||||||
@ -309,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
aspectRatio=aspect_ratio,
|
aspectRatio=aspect_ratio,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKASCENES,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikAdditionsNode(IO.ComfyNode):
|
class PikAdditionsNode(IO.ComfyNode):
|
||||||
@ -383,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
|
|||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKADDITIONS,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaSwapsNode(IO.ComfyNode):
|
class PikaSwapsNode(IO.ComfyNode):
|
||||||
@ -472,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKASWAPS,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaffectsNode(IO.ComfyNode):
|
class PikaffectsNode(IO.ComfyNode):
|
||||||
@ -528,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
|
|||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKAFFECTS,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
|
||||||
pikaffect=pikaffect,
|
pikaffect=pikaffect,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -547,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||||
@ -592,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
|||||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||||
]
|
]
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKAFRAMES,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
|
||||||
response_model=pika_defs.PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -612,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaApiNodesExtension(ComfyExtension):
|
class PikaApiNodesExtension(ComfyExtension):
|
||||||
|
|||||||
@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
import folder_paths as comfy_paths
|
import folder_paths as comfy_paths
|
||||||
import aiohttp
|
|
||||||
import os
|
import os
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -26,11 +23,11 @@ from comfy_api_nodes.apis.rodin_api import (
|
|||||||
Rodin3DDownloadResponse,
|
Rodin3DDownloadResponse,
|
||||||
JobStatus,
|
JobStatus,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
|
sync_op,
|
||||||
|
poll_op,
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_bytesio,
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
)
|
)
|
||||||
from comfy_api.latest import ComfyExtension, IO
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
|
||||||
@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
|||||||
|
|
||||||
|
|
||||||
async def create_generate_task(
|
async def create_generate_task(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
images=None,
|
images=None,
|
||||||
seed=1,
|
seed=1,
|
||||||
material="PBR",
|
material="PBR",
|
||||||
quality_override=18000,
|
quality_override=18000,
|
||||||
tier="Regular",
|
tier="Regular",
|
||||||
mesh_mode="Quad",
|
mesh_mode="Quad",
|
||||||
TAPose = False,
|
ta_pose: bool = False,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
):
|
):
|
||||||
if images is None:
|
if images is None:
|
||||||
raise Exception("Rodin 3D generate requires at least 1 image.")
|
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||||
if len(images) > 5:
|
if len(images) > 5:
|
||||||
raise Exception("Rodin 3D generate requires up to 5 image.")
|
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||||
|
|
||||||
path = "/proxy/rodin/api/v2/rodin"
|
response = await sync_op(
|
||||||
operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
|
||||||
path=path,
|
response_model=Rodin3DGenerateResponse,
|
||||||
method=HttpMethod.POST,
|
data=Rodin3DGenerateRequest(
|
||||||
request_model=Rodin3DGenerateRequest,
|
|
||||||
response_model=Rodin3DGenerateResponse,
|
|
||||||
),
|
|
||||||
request=Rodin3DGenerateRequest(
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
material=material,
|
material=material,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
TAPose=TAPose,
|
TAPose=ta_pose,
|
||||||
),
|
),
|
||||||
files=[
|
files=[
|
||||||
(
|
(
|
||||||
@ -159,11 +152,8 @@ async def create_generate_task(
|
|||||||
for image in images if image is not None
|
for image in images if image is not None
|
||||||
],
|
],
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await operation.execute()
|
|
||||||
|
|
||||||
if hasattr(response, "error"):
|
if hasattr(response, "error"):
|
||||||
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
@ -187,75 +177,46 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
|||||||
return "DONE"
|
return "DONE"
|
||||||
return "Generating"
|
return "Generating"
|
||||||
|
|
||||||
|
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
|
||||||
|
if not response.jobs:
|
||||||
|
return None
|
||||||
|
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
|
||||||
|
return int((completed_count / len(response.jobs)) * 100)
|
||||||
|
|
||||||
async def poll_for_task_status(
|
|
||||||
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
|
async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
|
||||||
) -> Rodin3DCheckStatusResponse:
|
|
||||||
poll_operation = PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/rodin/api/v2/status",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=Rodin3DCheckStatusRequest,
|
|
||||||
response_model=Rodin3DCheckStatusResponse,
|
|
||||||
),
|
|
||||||
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
|
|
||||||
completed_statuses=["DONE"],
|
|
||||||
failed_statuses=["FAILED"],
|
|
||||||
status_extractor=check_rodin_status,
|
|
||||||
poll_interval=3.0,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||||
return await poll_operation.execute()
|
return await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
|
||||||
async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
|
response_model=Rodin3DCheckStatusResponse,
|
||||||
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
|
||||||
operation = SynchronousOperation(
|
status_extractor=check_rodin_status,
|
||||||
endpoint=ApiEndpoint(
|
progress_extractor=extract_progress,
|
||||||
path="/proxy/rodin/api/v2/download",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=Rodin3DDownloadRequest,
|
|
||||||
response_model=Rodin3DDownloadResponse,
|
|
||||||
),
|
|
||||||
request=Rodin3DDownloadRequest(task_uuid=uuid),
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
return await operation.execute()
|
|
||||||
|
|
||||||
|
|
||||||
async def download_files(url_list, task_uuid):
|
async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
|
||||||
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
|
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||||
|
return await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
|
||||||
|
response_model=Rodin3DDownloadResponse,
|
||||||
|
data=Rodin3DDownloadRequest(task_uuid=uuid),
|
||||||
|
monitor_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_files(url_list, task_uuid: str):
|
||||||
|
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||||
|
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||||
os.makedirs(save_path, exist_ok=True)
|
os.makedirs(save_path, exist_ok=True)
|
||||||
model_file_path = None
|
model_file_path = None
|
||||||
async with aiohttp.ClientSession() as session:
|
for i in url_list.list:
|
||||||
for i in url_list.list:
|
file_path = os.path.join(save_path, i.name)
|
||||||
url = i.url
|
if file_path.endswith(".glb"):
|
||||||
file_name = i.name
|
model_file_path = os.path.join(result_folder_name, i.name)
|
||||||
file_path = os.path.join(save_path, file_name)
|
await download_url_to_bytesio(i.url, file_path)
|
||||||
if file_path.endswith(".glb"):
|
|
||||||
model_file_path = file_path
|
|
||||||
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
|
|
||||||
max_retries = 5
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
async with session.get(url) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
async for chunk in resp.content.iter_chunked(32 * 1024):
|
|
||||||
f.write(chunk)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logging.info("Retrying...")
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
else:
|
|
||||||
logging.info(
|
|
||||||
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
|
|
||||||
file_path,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
return model_file_path
|
return model_file_path
|
||||||
|
|
||||||
|
|
||||||
@ -277,6 +238,7 @@ class Rodin3D_Regular(IO.ComfyNode):
|
|||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
IO.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -295,21 +257,17 @@ class Rodin3D_Regular(IO.ComfyNode):
|
|||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return IO.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
@ -333,6 +291,7 @@ class Rodin3D_Detail(IO.ComfyNode):
|
|||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
IO.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -351,21 +310,17 @@ class Rodin3D_Detail(IO.ComfyNode):
|
|||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return IO.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
@ -389,6 +344,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
|||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
IO.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -401,27 +357,22 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
|||||||
Material_Type,
|
Material_Type,
|
||||||
Polygon_count,
|
Polygon_count,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Smooth"
|
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier="Smooth",
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return IO.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
@ -452,6 +403,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
|||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
IO.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -462,29 +414,21 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
|||||||
Images,
|
Images,
|
||||||
Seed,
|
Seed,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Sketch"
|
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
material_type = "PBR"
|
|
||||||
quality_override = 18000
|
|
||||||
mesh_mode = "Quad"
|
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=material_type,
|
material="PBR",
|
||||||
quality_override=quality_override,
|
quality_override=18000,
|
||||||
tier=tier,
|
tier="Sketch",
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode="Quad",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return IO.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
@ -523,6 +467,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
|||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
IO.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -542,22 +487,18 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
|||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
TAPose=TAPose,
|
ta_pose=TAPose,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return IO.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|||||||
@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import (
|
|||||||
StabilityAudioInpaintRequest,
|
StabilityAudioInpaintRequest,
|
||||||
StabilityAudioResponse,
|
StabilityAudioResponse,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
validate_audio_duration,
|
validate_audio_duration,
|
||||||
validate_string,
|
validate_string,
|
||||||
@ -34,6 +27,9 @@ from comfy_api_nodes.util import (
|
|||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
audio_bytes_to_audio_input,
|
audio_bytes_to_audio_input,
|
||||||
|
sync_op,
|
||||||
|
poll_op,
|
||||||
|
ApiEndpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
data=StabilityStableUltraRequest(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityStableUltraRequest,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=StabilityStableUltraRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||||
@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
data=StabilityStable3_5Request(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityStable3_5Request,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=StabilityStable3_5Request(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||||
@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
data=StabilityUpscaleConservativeRequest(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityUpscaleConservativeRequest,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=StabilityUpscaleConservativeRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
creativity=round(creativity,2),
|
creativity=round(creativity,2),
|
||||||
@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||||
@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||||
}
|
response_model=StabilityAsyncResponse,
|
||||||
|
data=StabilityUpscaleCreativeRequest(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityUpscaleCreativeRequest,
|
|
||||||
response_model=StabilityAsyncResponse,
|
|
||||||
),
|
|
||||||
request=StabilityUpscaleCreativeRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
creativity=round(creativity,2),
|
creativity=round(creativity,2),
|
||||||
@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||||
method=HttpMethod.GET,
|
response_model=StabilityResultsGetResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=StabilityResultsGetResponse,
|
|
||||||
),
|
|
||||||
poll_interval=3,
|
poll_interval=3,
|
||||||
completed_statuses=[StabilityPollStatus.finished],
|
|
||||||
failed_statuses=[StabilityPollStatus.failed],
|
|
||||||
status_extractor=lambda x: get_async_dummy_status(x),
|
status_extractor=lambda x: get_async_dummy_status(x),
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
)
|
)
|
||||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
|
||||||
|
|
||||||
if response_poll.finish_reason != "SUCCESS":
|
if response_poll.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||||
@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=EmptyRequest(),
|
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||||
@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
|
|||||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||||
validate_string(prompt, max_length=10000)
|
validate_string(prompt, max_length=10000)
|
||||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
|
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=StabilityAudioResponse,
|
||||||
request_model=StabilityTextToAudioRequest,
|
data=payload,
|
||||||
response_model=StabilityAudioResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs= {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
if not response_api.audio:
|
if not response_api.audio:
|
||||||
raise ValueError("No audio file was received in response.")
|
raise ValueError("No audio file was received in response.")
|
||||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
|
|||||||
payload = StabilityAudioToAudioRequest(
|
payload = StabilityAudioToAudioRequest(
|
||||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||||
)
|
)
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
|
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=StabilityAudioResponse,
|
||||||
request_model=StabilityAudioToAudioRequest,
|
data=payload,
|
||||||
response_model=StabilityAudioResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
files={"audio": audio_input_to_mp3(audio)},
|
files={"audio": audio_input_to_mp3(audio)},
|
||||||
auth_kwargs= {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
if not response_api.audio:
|
if not response_api.audio:
|
||||||
raise ValueError("No audio file was received in response.")
|
raise ValueError("No audio file was received in response.")
|
||||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
|
|||||||
mask_start=mask_start,
|
mask_start=mask_start,
|
||||||
mask_end=mask_end,
|
mask_end=mask_end,
|
||||||
)
|
)
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
|
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=StabilityAudioResponse,
|
||||||
request_model=StabilityAudioInpaintRequest,
|
data=payload,
|
||||||
response_model=StabilityAudioResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
files={"audio": audio_input_to_mp3(audio)},
|
files={"audio": audio_input_to_mp3(audio)},
|
||||||
auth_kwargs={
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
if not response_api.audio:
|
if not response_api.audio:
|
||||||
raise ValueError("No audio file was received in response.")
|
raise ValueError("No audio file was received in response.")
|
||||||
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|||||||
@ -18,6 +18,8 @@ from .conversions import (
|
|||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
tensor_to_pil,
|
tensor_to_pil,
|
||||||
|
text_filepath_to_base64_string,
|
||||||
|
text_filepath_to_data_uri,
|
||||||
trim_video,
|
trim_video,
|
||||||
video_to_base64_string,
|
video_to_base64_string,
|
||||||
)
|
)
|
||||||
@ -75,6 +77,8 @@ __all__ = [
|
|||||||
"tensor_to_base64_string",
|
"tensor_to_base64_string",
|
||||||
"tensor_to_bytesio",
|
"tensor_to_bytesio",
|
||||||
"tensor_to_pil",
|
"tensor_to_pil",
|
||||||
|
"text_filepath_to_base64_string",
|
||||||
|
"text_filepath_to_data_uri",
|
||||||
"trim_video",
|
"trim_video",
|
||||||
"video_to_base64_string",
|
"video_to_base64_string",
|
||||||
# Validation utilities
|
# Validation utilities
|
||||||
|
|||||||
@ -16,9 +16,9 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from comfy import utils
|
from comfy import utils
|
||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
from comfy_api_nodes.apis import request_logger
|
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
|
|
||||||
|
from . import request_logger
|
||||||
from ._helpers import (
|
from ._helpers import (
|
||||||
default_base_url,
|
default_base_url,
|
||||||
get_auth_header,
|
get_auth_header,
|
||||||
@ -77,7 +77,7 @@ class _PollUIState:
|
|||||||
|
|
||||||
|
|
||||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
|
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
|
||||||
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
||||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||||||
|
|
||||||
@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||||
|
|
||||||
payload_headers = {"Accept": "*/*"}
|
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||||
if cfg.endpoint.headers:
|
if cfg.endpoint.headers:
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import mimetypes
|
||||||
import uuid
|
import uuid
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -12,7 +13,7 @@ from PIL import Image
|
|||||||
|
|
||||||
from comfy.utils import common_upscale
|
from comfy.utils import common_upscale
|
||||||
from comfy_api.latest import Input, InputImpl
|
from comfy_api.latest import Input, InputImpl
|
||||||
from comfy_api.util import VideoContainer, VideoCodec
|
from comfy_api.util import VideoCodec, VideoContainer
|
||||||
|
|
||||||
from ._helpers import mimetype_to_extension
|
from ._helpers import mimetype_to_extension
|
||||||
|
|
||||||
@ -451,3 +452,19 @@ def resize_mask_to_image(
|
|||||||
if not allow_gradient:
|
if not allow_gradient:
|
||||||
mask = (mask > 0.5).float()
|
mask = (mask > 0.5).float()
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||||
|
"""Converts a text file to a base64 string."""
|
||||||
|
with open(filepath, "rb") as f:
|
||||||
|
file_content = f.read()
|
||||||
|
return base64.b64encode(file_content).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||||
|
"""Converts a text file to a data URI."""
|
||||||
|
base64_string = text_filepath_to_base64_string(filepath)
|
||||||
|
mime_type, _ = mimetypes.guess_type(filepath)
|
||||||
|
if mime_type is None:
|
||||||
|
mime_type = "application/octet-stream"
|
||||||
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
|
|||||||
@ -12,8 +12,8 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError
|
|||||||
|
|
||||||
from comfy_api.input_impl import VideoFromFile
|
from comfy_api.input_impl import VideoFromFile
|
||||||
from comfy_api.latest import IO as COMFY_IO
|
from comfy_api.latest import IO as COMFY_IO
|
||||||
from comfy_api_nodes.apis import request_logger
|
|
||||||
|
|
||||||
|
from . import request_logger
|
||||||
from ._helpers import (
|
from ._helpers import (
|
||||||
default_base_url,
|
default_base_url,
|
||||||
get_auth_header,
|
get_auth_header,
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import hashlib
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
@ -13,8 +13,8 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from comfy_api.latest import IO, Input
|
from comfy_api.latest import IO, Input
|
||||||
from comfy_api.util import VideoCodec, VideoContainer
|
from comfy_api.util import VideoCodec, VideoContainer
|
||||||
from comfy_api_nodes.apis import request_logger
|
|
||||||
|
|
||||||
|
from . import request_logger
|
||||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||||
from .client import (
|
from .client import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class Unhashable:
|
|||||||
def to_hashable(obj):
|
def to_hashable(obj):
|
||||||
# So that we don't infinitely recurse since frozenset and tuples
|
# So that we don't infinitely recurse since frozenset and tuples
|
||||||
# are Sequences.
|
# are Sequences.
|
||||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, Mapping):
|
elif isinstance(obj, Mapping):
|
||||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||||
@ -399,6 +399,8 @@ class RAMPressureCache(LRUCache):
|
|||||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||||
def scan_list_for_ram_usage(outputs):
|
def scan_list_for_ram_usage(outputs):
|
||||||
nonlocal ram_usage
|
nonlocal ram_usage
|
||||||
|
if outputs is None:
|
||||||
|
return
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
if isinstance(output, list):
|
if isinstance(output, list):
|
||||||
scan_list_for_ram_usage(output)
|
scan_list_for_ram_usage(output)
|
||||||
|
|||||||
@ -2,6 +2,9 @@ import comfy.utils
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetwork_patch(path, strength):
|
def load_hypernetwork_patch(path, strength):
|
||||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||||
@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
|
|
||||||
return hypernetwork_patch(out, strength)
|
return hypernetwork_patch(out, strength)
|
||||||
|
|
||||||
class HypernetworkLoader:
|
class HypernetworkLoader(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return IO.Schema(
|
||||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
node_id="HypernetworkLoader",
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
category="loaders",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
IO.Model.Input("model"),
|
||||||
FUNCTION = "load_hypernetwork"
|
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||||
|
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
@classmethod
|
||||||
|
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
|
||||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
model_hypernetwork.set_model_attn1_patch(patch)
|
model_hypernetwork.set_model_attn1_patch(patch)
|
||||||
model_hypernetwork.set_model_attn2_patch(patch)
|
model_hypernetwork.set_model_attn2_patch(patch)
|
||||||
return (model_hypernetwork,)
|
return IO.NodeOutput(model_hypernetwork)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
load_hypernetwork = execute # TODO: remove
|
||||||
"HypernetworkLoader": HypernetworkLoader
|
|
||||||
}
|
|
||||||
|
class HyperNetworkExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
HypernetworkLoader,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> HyperNetworkExtension:
|
||||||
|
return HyperNetworkExtension()
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.67"
|
__version__ = "0.3.68"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.67"
|
version = "0.3.68"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
comfyui-frontend-package==1.28.8
|
comfyui-frontend-package==1.28.8
|
||||||
comfyui-workflow-templates==0.2.4
|
comfyui-workflow-templates==0.2.11
|
||||||
comfyui-embedded-docs==0.3.0
|
comfyui-embedded-docs==0.3.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
torchvision
|
torchvision
|
||||||
|
|||||||
@ -14,7 +14,7 @@ if not has_gpu():
|
|||||||
args.cpu = True
|
args.cpu = True
|
||||||
|
|
||||||
from comfy import ops
|
from comfy import ops
|
||||||
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
|
||||||
class SimpleModel(torch.nn.Module):
|
class SimpleModel(torch.nn.Module):
|
||||||
@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
|
|
||||||
# Verify weights are wrapped in QuantizedTensor
|
# Verify weights are wrapped in QuantizedTensor
|
||||||
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
self.assertIsInstance(model.layer1.weight, QuantizedTensor)
|
||||||
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Layer 2 should NOT be quantized
|
# Layer 2 should NOT be quantized
|
||||||
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
|
||||||
|
|
||||||
# Layer 3 should be quantized
|
# Layer 3 should be quantized
|
||||||
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
self.assertIsInstance(model.layer3.weight, QuantizedTensor)
|
||||||
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Verify scales were loaded
|
# Verify scales were loaded
|
||||||
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
|
||||||
@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
|||||||
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
# Verify layer1.weight is a QuantizedTensor with scale preserved
|
||||||
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
|
||||||
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
|
||||||
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Verify non-quantized layers are standard tensors
|
# Verify non-quantized layers are standard tensors
|
||||||
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
|
||||||
|
|||||||
@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
scale = torch.tensor(2.0)
|
scale = torch.tensor(2.0)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
|
||||||
|
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
self.assertIsInstance(qt, QuantizedTensor)
|
self.assertIsInstance(qt, QuantizedTensor)
|
||||||
self.assertEqual(qt.shape, (256, 128))
|
self.assertEqual(qt.shape, (256, 128))
|
||||||
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
self.assertEqual(qt.dtype, torch.float8_e4m3fn)
|
||||||
self.assertEqual(qt._layout_params['scale'], scale)
|
self.assertEqual(qt._layout_params['scale'], scale)
|
||||||
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
|
||||||
self.assertEqual(qt._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
def test_dequantize(self):
|
def test_dequantize(self):
|
||||||
"""Test explicit dequantization"""
|
"""Test explicit dequantization"""
|
||||||
@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
scale = torch.tensor(3.0)
|
scale = torch.tensor(3.0)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
|
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
dequantized = qt.dequantize()
|
dequantized = qt.dequantize()
|
||||||
|
|
||||||
self.assertEqual(dequantized.dtype, torch.float32)
|
self.assertEqual(dequantized.dtype, torch.float32)
|
||||||
@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
|
|||||||
|
|
||||||
qt = QuantizedTensor.from_float(
|
qt = QuantizedTensor.from_float(
|
||||||
float_tensor,
|
float_tensor,
|
||||||
TensorCoreFP8Layout,
|
"TensorCoreFP8Layout",
|
||||||
scale=scale,
|
scale=scale,
|
||||||
dtype=torch.float8_e4m3fn
|
dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
|
|||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
scale = torch.tensor(1.5)
|
scale = torch.tensor(1.5)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
# Detach should return a new QuantizedTensor
|
# Detach should return a new QuantizedTensor
|
||||||
qt_detached = qt.detach()
|
qt_detached = qt.detach()
|
||||||
|
|
||||||
self.assertIsInstance(qt_detached, QuantizedTensor)
|
self.assertIsInstance(qt_detached, QuantizedTensor)
|
||||||
self.assertEqual(qt_detached.shape, qt.shape)
|
self.assertEqual(qt_detached.shape, qt.shape)
|
||||||
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
def test_clone(self):
|
def test_clone(self):
|
||||||
"""Test clone operation on quantized tensor"""
|
"""Test clone operation on quantized tensor"""
|
||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
scale = torch.tensor(1.5)
|
scale = torch.tensor(1.5)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
# Clone should return a new QuantizedTensor
|
# Clone should return a new QuantizedTensor
|
||||||
qt_cloned = qt.clone()
|
qt_cloned = qt.clone()
|
||||||
|
|
||||||
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
self.assertIsInstance(qt_cloned, QuantizedTensor)
|
||||||
self.assertEqual(qt_cloned.shape, qt.shape)
|
self.assertEqual(qt_cloned.shape, qt.shape)
|
||||||
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout)
|
self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
|
||||||
|
|
||||||
# Verify it's a deep copy
|
# Verify it's a deep copy
|
||||||
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
self.assertIsNot(qt_cloned._qdata, qt._qdata)
|
||||||
@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
|
|||||||
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
|
||||||
scale = torch.tensor(1.5)
|
scale = torch.tensor(1.5)
|
||||||
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
layout_params = {'scale': scale, 'orig_dtype': torch.float32}
|
||||||
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params)
|
qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
|
||||||
|
|
||||||
# Moving to same device should work (CPU to CPU)
|
# Moving to same device should work (CPU to CPU)
|
||||||
qt_cpu = qt.to('cpu')
|
qt_cpu = qt.to('cpu')
|
||||||
@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
|
|||||||
scale = torch.tensor(1.0)
|
scale = torch.tensor(1.0)
|
||||||
a_q = QuantizedTensor.from_float(
|
a_q = QuantizedTensor.from_float(
|
||||||
a_fp32,
|
a_fp32,
|
||||||
TensorCoreFP8Layout,
|
"TensorCoreFP8Layout",
|
||||||
scale=scale,
|
scale=scale,
|
||||||
dtype=torch.float8_e4m3fn
|
dtype=torch.float8_e4m3fn
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user