mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 22:30:50 +08:00
Merge branch 'master' into feature/unified-jobs-api
This commit is contained in:
commit
7f4fb73f95
@ -53,6 +53,16 @@ try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
|
||||
27
README.md
27
README.md
@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@ -319,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||
|
||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install the manager dependencies:
|
||||
```bash
|
||||
pip install -r manager_requirements.txt
|
||||
```
|
||||
|
||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||
```bash
|
||||
python main.py --enable-manager
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
||||
@ -58,8 +58,13 @@ class InternalRoutes:
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
|
||||
Latent2RGB = "latent2rgb"
|
||||
TAESD = "taesd"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str):
|
||||
for member in cls:
|
||||
if member.value == value:
|
||||
return member
|
||||
return None
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
|
||||
|
||||
@ -51,32 +51,43 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
window[idx] = full[idx]
|
||||
return window.to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
def get_region_index(self, num_regions: int) -> int:
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
|
||||
EXECUTE_START = "execute_start"
|
||||
EXECUTE_CLEANUP = "execute_cleanup"
|
||||
RESIZE_COND_ITEM = "resize_cond_item"
|
||||
|
||||
def init_callbacks(self):
|
||||
return {}
|
||||
@ -94,7 +105,8 @@ class ContextFuseMethod:
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -103,13 +115,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -123,6 +140,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
@ -145,13 +167,32 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
new_cond_item = cond_item.copy()
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
# Allow callbacks to handle custom conditioning items
|
||||
handled = False
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(
|
||||
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
|
||||
):
|
||||
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
|
||||
if result is not None:
|
||||
new_cond_item[cond_key] = result
|
||||
handled = True
|
||||
break
|
||||
if handled:
|
||||
continue
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
@ -164,7 +205,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
@ -173,7 +214,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
@ -250,8 +291,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
@ -287,6 +328,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
)
|
||||
|
||||
|
||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||
model_options = extra_args.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is None:
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
"ContextWindows_sampler_sample",
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -538,3 +601,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
|
||||
|
||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||
logging.info("Context windows: Applying FreeNoise")
|
||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||
latent_video_length = noise.shape[dim]
|
||||
delta = context_length - context_overlap
|
||||
|
||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||
place_idx = start_idx + context_length
|
||||
|
||||
actual_delta = min(delta, latent_video_length - place_idx)
|
||||
if actual_delta <= 0:
|
||||
break
|
||||
|
||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||
|
||||
source_slice = [slice(None)] * noise.ndim
|
||||
source_slice[dim] = list_idx
|
||||
target_slice = [slice(None)] * noise.ndim
|
||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||
|
||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||
|
||||
return noise
|
||||
|
||||
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
|
||||
@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
|
||||
use_x0: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _apply_x0_residual(self, predicted, noisy, timesteps):
|
||||
|
||||
# non zero during training to prevent 0 div
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
# If x0 variant → v-pred, just return this instead
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
|
||||
@ -43,6 +43,7 @@ class HunyuanVideoParams:
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
meanflow_sum: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@ -317,7 +318,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) / 2
|
||||
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
|
||||
413
comfy/ldm/kandinsky5/model.py
Normal file
413
comfy/ldm/kandinsky5/model.py
Normal file
@ -0,0 +1,413 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
original_dims = x.ndim
|
||||
if original_dims == 4:
|
||||
x = x.unsqueeze(2)
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
if original_dims == 4:
|
||||
out = out.squeeze(2)
|
||||
return out
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
|
||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
n_control_layers=6,
|
||||
control_in_dim=16,
|
||||
additional_in_dim=0,
|
||||
broken=False,
|
||||
refiner_control=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
|
||||
super().__init__()
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.additional_in_dim = 0
|
||||
self.control_in_dim = 16
|
||||
self.broken = broken
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.control_in_dim = control_in_dim
|
||||
n_refiner_layers = 2
|
||||
self.n_control_layers = 6
|
||||
self.n_control_layers = n_control_layers
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
|
||||
all_x_embedder = {}
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.refiner_control = refiner_control
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
if self.refiner_control:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=layer_id,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||
patch_size = 2
|
||||
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
x_attn_mask = None
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
if not self.refiner_control:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
|
||||
return control_context
|
||||
|
||||
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
if self.refiner_control:
|
||||
if self.broken:
|
||||
if layer_id == 0:
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if layer_id > 0:
|
||||
out = None
|
||||
for i in range(1, len(self.control_layers)):
|
||||
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if out is None:
|
||||
out = o
|
||||
|
||||
return (out, control_context)
|
||||
else:
|
||||
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
else:
|
||||
return (None, control_context)
|
||||
|
||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
|
||||
@ -377,6 +377,7 @@ class NextDiT(nn.Module):
|
||||
z_image_modulation=False,
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -447,6 +448,31 @@ class NextDiT(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
self.clip_text_pooled_proj = None
|
||||
|
||||
if clip_text_dim is not None:
|
||||
self.clip_text_dim = clip_text_dim
|
||||
self.clip_text_pooled_proj = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
clip_text_dim,
|
||||
clip_text_dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_text_embed = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024) + clip_text_dim,
|
||||
min(dim, 1024),
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
@ -510,6 +536,7 @@ class NextDiT(nn.Module):
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
orig_x = x
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
@ -546,13 +573,21 @@ class NextDiT(nn.Module):
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
|
||||
padded_img_mask = None
|
||||
for layer in self.noise_refiner:
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
@ -585,17 +620,26 @@ class NextDiT(nn.Module):
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
|
||||
@ -320,8 +320,16 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
@ -329,18 +330,6 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
"format_version": "1.0",
|
||||
"layers": self.model_config.layer_quant_config
|
||||
}
|
||||
unet_state_dict["_quantization_metadata"] = metadata
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@ -1121,6 +1110,10 @@ class Lumina2(BaseModel):
|
||||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
@ -1642,3 +1635,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
class Kandinsky5(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||
if time_dim_replace is not None:
|
||||
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||
|
||||
return out
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
@ -6,20 +6,6 @@ import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
@ -194,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["meanflow_sum"] = True
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
@ -271,6 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "__x0__" in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
@ -437,6 +429,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None:
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
@ -625,6 +620,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["model_dim"] = model_dim
|
||||
if model_dim in [4096, 2560]: # pro video and lite image
|
||||
dit_config["axes_dims"] = (32, 48, 48)
|
||||
if model_dim == 2560: # lite image
|
||||
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||
elif model_dim == 1792: # lite video
|
||||
dit_config["axes_dims"] = (16, 24, 24)
|
||||
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["image_model"] = "kandinsky5"
|
||||
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@ -767,22 +780,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
|
||||
if quant_config:
|
||||
model_config.quant_config = quant_config
|
||||
logging.info("Detected mixed precision quantization")
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@ -1492,6 +1492,20 @@ def extended_fp16_support():
|
||||
|
||||
return True
|
||||
|
||||
LORA_COMPUTE_DTYPES = {}
|
||||
def lora_compute_dtype(device):
|
||||
dtype = LORA_COMPUTE_DTYPES.get(device, None)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
|
||||
if should_use_fp16(device):
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
||||
@ -35,6 +35,7 @@ import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
@ -126,36 +127,23 @@ class LowVramPatch:
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
|
||||
def __call__(self, weight):
|
||||
intermediate_dtype = weight.dtype
|
||||
if self.convert_func is not None:
|
||||
weight = self.convert_func(weight, inplace=False)
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||
intermediate_dtype = torch.float32
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is None:
|
||||
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||
else:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is not None:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||
else:
|
||||
return out
|
||||
|
||||
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
def low_vram_patch_estimate_vram(model, key):
|
||||
weight, set_func, convert_func = get_key_weight(model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
|
||||
if model_dtype is None:
|
||||
model_dtype = weight.dtype
|
||||
|
||||
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
@ -466,6 +454,9 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
@ -630,10 +621,11 @@ class ModelPatcher:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
temp_weight = weight.to(temp_dtype, copy=True)
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
@ -677,12 +669,18 @@ class ModelPatcher:
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if weight_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
|
||||
if bias_key in self.patches:
|
||||
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
|
||||
def check_module_offload_mem(key):
|
||||
if key in self.patches:
|
||||
return low_vram_patch_estimate_vram(self.model, key)
|
||||
model_dtype = getattr(self.model, "manual_cast_dtype", None)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if model_dtype is None or weight is None:
|
||||
return 0
|
||||
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
|
||||
return weight.numel() * model_dtype.itemsize
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
@ -777,6 +775,8 @@ class ModelPatcher:
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@ -933,7 +933,7 @@ class ModelPatcher:
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight:
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
|
||||
180
comfy/ops.py
180
comfy/ops.py
@ -22,7 +22,7 @@ import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
@ -93,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(offload_stream)
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
@ -422,22 +415,12 @@ def fp8_linear(self, input):
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
@ -458,7 +441,7 @@ class fp8_ops(manual_cast):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if not self.training:
|
||||
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
@ -471,59 +454,6 @@ class fp8_ops(manual_cast):
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if return_weight:
|
||||
return weight
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
@ -550,9 +480,9 @@ if CUBLAS_IS_AVAILABLE:
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
|
||||
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_layer_quant_config = layer_quant_config
|
||||
_quant_config = quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
|
||||
@ -567,15 +497,14 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
if dtype is None:
|
||||
dtype = MixedPrecisionOps._compute_dtype
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
self._has_bias = bias
|
||||
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
@ -595,36 +524,59 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
dtype = self.factory_kwargs["dtype"]
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
|
||||
if dtype != MixedPrecisionOps._compute_dtype:
|
||||
self.comfy_cast_weights = True
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(weight_scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.to(device)
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'scale': scale,
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
|
||||
if scale is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
@ -633,6 +585,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
@ -648,9 +610,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
@ -661,7 +622,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
@ -670,17 +631,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful
|
||||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
return MixedPrecisionOps
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor):
|
||||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
def storage(self):
|
||||
return self._qdata.storage()
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
# ==============================================================================
|
||||
@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
if target_dtype is not None:
|
||||
new_params["orig_dtype"] = target_dtype
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
orig_dtype = qt_dest._layout_params["orig_dtype"]
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
qt_dest._layout_params["orig_dtype"] = orig_dtype
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
@ -397,17 +398,23 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||
if isinstance(scale, str) and scale == "recalculate":
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
|
||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||
tensor_info = torch.finfo(tensor.dtype)
|
||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
if scale is not None:
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
||||
|
||||
131
comfy/sd.py
131
comfy/sd.py
@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -98,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@ -126,9 +127,32 @@ class CLIP:
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match torch.float32 hardcode upcast in TE implemention
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if len(state_dict) > 0:
|
||||
if isinstance(state_dict, list):
|
||||
for c in state_dict:
|
||||
m, u = self.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
else:
|
||||
m, u = self.load_sd(state_dict, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
@ -525,8 +549,10 @@ class VAE:
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
|
||||
|
||||
|
||||
# Hunyuan 3d v2 2.0 & 2.1
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
|
||||
@ -745,6 +771,8 @@ class VAE:
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -962,16 +990,16 @@ class CLIPType(Enum):
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
OVIS = 21
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if metadata is not None:
|
||||
quant_metadata = metadata.get("_quantization_metadata", None)
|
||||
if quant_metadata is not None:
|
||||
sd["_quantization_metadata"] = quant_metadata
|
||||
if model_options.get("custom_operations", None) is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||
clip_data.append(sd)
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
@ -1088,7 +1116,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
@ -1112,7 +1140,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
@ -1141,7 +1169,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
@ -1169,7 +1197,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
@ -1212,6 +1240,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@ -1224,19 +1258,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
|
||||
parameters = 0
|
||||
for c in clip_data:
|
||||
if "_quantization_metadata" in c:
|
||||
c.pop("_quantization_metadata")
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
@ -1295,6 +1320,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
@ -1303,18 +1332,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
@ -1332,22 +1365,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
scaled_fp8_list = []
|
||||
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
|
||||
if k.endswith(".scaled_fp8"):
|
||||
scaled_fp8_list.append(k[:-len("scaled_fp8")])
|
||||
|
||||
if len(scaled_fp8_list) > 0:
|
||||
out_sd = {}
|
||||
for k in sd:
|
||||
skip = False
|
||||
for pref in scaled_fp8_list:
|
||||
skip = skip or k.startswith(pref)
|
||||
if not skip:
|
||||
out_sd[k] = sd[k]
|
||||
|
||||
for pref in scaled_fp8_list:
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
for k in quant_sd:
|
||||
out_sd[k] = quant_sd[k]
|
||||
sd = out_sd
|
||||
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
@ -1394,6 +1438,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
@ -1424,7 +1471,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if dtype is None:
|
||||
@ -1432,12 +1479,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
if model_config.layer_quant_config is not None:
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
@ -1476,6 +1526,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
|
||||
@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
config[k] = v
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
quantization_metadata = model_options.get("quantization_metadata", None)
|
||||
quant_config = model_options.get("quantization_metadata", None)
|
||||
|
||||
if operations is None:
|
||||
layer_quant_config = None
|
||||
if quantization_metadata is not None:
|
||||
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
|
||||
|
||||
if layer_quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
|
||||
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
|
||||
if quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
logging.info("Using MixedPrecisionOps for text encoder")
|
||||
else:
|
||||
# Fallback to scaled_fp8_ops for backward compatibility
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
from . import supported_models_base
|
||||
@ -540,7 +541,7 @@ class SD3(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
memory_usage_factor = 1.6
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@ -964,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
@ -1025,7 +1026,7 @@ class ZImage(Lumina2):
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.7
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
@ -1288,7 +1289,7 @@ class ChromaRadiance(Chroma):
|
||||
latent_format = comfy.latent_formats.ChromaRadiance
|
||||
|
||||
# Pixel-space model, no spatial compression for model input.
|
||||
memory_usage_factor = 0.038
|
||||
memory_usage_factor = 0.044
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ChromaRadiance(self, device=device)
|
||||
@ -1331,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE):
|
||||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
memory_usage_factor = 1.95 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
@ -1396,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo):
|
||||
|
||||
latent_format = latent_formats.HunyuanImage21
|
||||
|
||||
memory_usage_factor = 7.7
|
||||
memory_usage_factor = 8.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
|
||||
|
||||
class Kandinsky5(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 10.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
"model_dim": 2560,
|
||||
"visual_embed_dim": 64,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Flux
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5Image(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from . import model_base
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
@ -49,8 +50,7 @@ class BASE:
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||
quant_config = None # quantization configuration for mixed precision
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
@ -118,3 +118,7 @@ class BASE:
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
def __getattr__(self, name):
|
||||
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
|
||||
return None
|
||||
|
||||
@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
|
||||
if t5xxl_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["quantization_metadata"] = t5xxl_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
|
||||
|
||||
@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class CosmosTEModel_(CosmosT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
|
||||
out = out.reshape(out.shape[0], out.shape[1], -1)
|
||||
return out, pooled, extra
|
||||
|
||||
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
|
||||
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Flux2TEModel_(Flux2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if pruned:
|
||||
model_options = model_options.copy()
|
||||
|
||||
@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class MochiTEModel_(MochiT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
|
||||
class HiDreamTEModel_(HiDreamTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HiDreamTEModel_
|
||||
|
||||
@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
||||
def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
|
||||
class QwenImageTEModel_(HunyuanImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
import numbers
|
||||
|
||||
import comfy.utils
|
||||
|
||||
def llama_detect(state_dict, prefix=""):
|
||||
out = {}
|
||||
@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""):
|
||||
if t5_key in state_dict:
|
||||
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
|
||||
if "_quantization_metadata" in state_dict:
|
||||
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["llama_quantization_metadata"] = quant
|
||||
|
||||
return out
|
||||
|
||||
@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
textmodel_json_config = {}
|
||||
vocab_size = model_options.get("vocab_size", None)
|
||||
@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HunyuanVideoClipModel_
|
||||
|
||||
68
comfy/text_encoders/kandinsky5.py
Normal file
68
comfy/text_encoders/kandinsky5.py
Normal file
@ -0,0 +1,68 @@
|
||||
from comfy import sd1_clip
|
||||
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||
from .llama import Qwen25_7BVLI
|
||||
|
||||
|
||||
class Kandinsky5Tokenizer(QwenImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Kandinsky5TEModel(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
|
||||
|
||||
return cond, l_pooled, extra
|
||||
|
||||
def set_clip_options(self, options):
|
||||
super().set_clip_options(options)
|
||||
self.clip_l.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
super().reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Kandinsky5TEModel_(Kandinsky5TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Kandinsky5TEModel_
|
||||
@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
|
||||
if model_type == "gemma2_2b":
|
||||
model = Gemma2_2BModel
|
||||
elif model_type == "gemma3_4b":
|
||||
@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
|
||||
|
||||
@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Omnigen2TEModel_(Omnigen2Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -55,12 +55,9 @@ class OvisTEModel(sd1_clip.SD1ClipModel):
|
||||
return out, pooled, {}
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class OvisTEModel_(OvisTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
|
||||
@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class QwenImageTEModel_(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -6,14 +6,15 @@ import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
import logging
|
||||
import comfy.utils
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
|
||||
if t5xxl_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["quantization_metadata"] = t5xxl_quantization_metadata
|
||||
|
||||
model_options = {**model_options, "model_name": "t5xxl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
|
||||
if t5_key in state_dict:
|
||||
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["t5_quantization_metadata"] = quant
|
||||
|
||||
return out
|
||||
|
||||
@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
||||
@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
|
||||
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class WanTEModel(WanT5Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class ZImageTEModel_(ZImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
|
||||
@ -29,6 +29,7 @@ import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@ -52,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
@ -802,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
return None
|
||||
return f.read(length_of_header)
|
||||
|
||||
ATTR_UNSET={}
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], value)
|
||||
prev = getattr(obj, attrs[-1], ATTR_UNSET)
|
||||
if value is ATTR_UNSET:
|
||||
delattr(obj, attrs[-1])
|
||||
else:
|
||||
setattr(obj, attrs[-1], value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
@ -1194,3 +1200,68 @@ def unpack_latents(combined_latent, latent_shapes):
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
|
||||
|
||||
def detect_layer_quantization(state_dict, prefix):
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and k.endswith(".comfy_quant"):
|
||||
logging.info("Found quantization metadata version 1")
|
||||
return {"mixed_ops": True}
|
||||
return None
|
||||
|
||||
def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
quant_metadata = None
|
||||
if "_quantization_metadata" not in metadata:
|
||||
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
|
||||
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict[scaled_fp8_key]
|
||||
scaled_fp8_dtype = scaled_fp8_weight.dtype
|
||||
if scaled_fp8_dtype == torch.float32:
|
||||
scaled_fp8_dtype = torch.float8_e4m3fn
|
||||
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
full_precision_matrix_mult = True
|
||||
else:
|
||||
full_precision_matrix_mult = False
|
||||
|
||||
out_sd = {}
|
||||
layers = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
k_out = k
|
||||
w = state_dict.pop(k)
|
||||
layer = None
|
||||
if k_out.endswith(".scale_weight"):
|
||||
layer = k_out[:-len(".scale_weight")]
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
if k_out.endswith(".scale_input"):
|
||||
layer = k_out[:-len(".scale_input")]
|
||||
k_out = "{}.input_scale".format(layer)
|
||||
if w.item() == 1.0:
|
||||
continue
|
||||
|
||||
out_sd[k_out] = w
|
||||
|
||||
state_dict = out_sd
|
||||
quant_metadata = {"layers": layers}
|
||||
else:
|
||||
quant_metadata = json.loads(metadata["_quantization_metadata"])
|
||||
|
||||
if quant_metadata is not None:
|
||||
layers = quant_metadata["layers"]
|
||||
for k, v in layers.items():
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
|
||||
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
# Default server capabilities
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
|
||||
|
||||
def get_connection_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str,
|
||||
default: Any = False
|
||||
@ -42,7 +42,7 @@ def get_connection_feature(
|
||||
|
||||
|
||||
def supports_feature(
|
||||
sockets_metadata: Dict[str, Dict[str, Any]],
|
||||
sockets_metadata: dict[str, dict[str, Any]],
|
||||
sid: str,
|
||||
feature_name: str
|
||||
) -> bool:
|
||||
@ -60,7 +60,7 @@ def supports_feature(
|
||||
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
|
||||
|
||||
|
||||
def get_server_features() -> Dict[str, Any]:
|
||||
def get_server_features() -> dict[str, Any]:
|
||||
"""
|
||||
Get the server's feature flags.
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Type, List, NamedTuple
|
||||
from typing import NamedTuple
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from packaging import version as packaging_version
|
||||
|
||||
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
|
||||
|
||||
class ComfyAPIWithVersion(NamedTuple):
|
||||
version: str
|
||||
api_class: Type[ComfyAPIBase]
|
||||
api_class: type[ComfyAPIBase]
|
||||
|
||||
|
||||
def parse_version(version_str: str) -> packaging_version.Version:
|
||||
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
|
||||
return packaging_version.parse(version_str)
|
||||
|
||||
|
||||
registered_versions: List[ComfyAPIWithVersion] = []
|
||||
registered_versions: list[ComfyAPIWithVersion] = []
|
||||
|
||||
|
||||
def register_versions(versions: List[ComfyAPIWithVersion]):
|
||||
def register_versions(versions: list[ComfyAPIWithVersion]):
|
||||
versions.sort(key=lambda x: parse_version(x.version))
|
||||
global registered_versions
|
||||
registered_versions = versions
|
||||
|
||||
|
||||
def get_all_versions() -> List[ComfyAPIWithVersion]:
|
||||
def get_all_versions() -> list[ComfyAPIWithVersion]:
|
||||
"""
|
||||
Returns a list of all registered ComfyAPI versions.
|
||||
"""
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
from typing import Optional, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
|
||||
return result_container["result"]
|
||||
|
||||
@classmethod
|
||||
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
|
||||
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
|
||||
"""
|
||||
Creates a new class with synchronous versions of all async methods.
|
||||
|
||||
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
|
||||
|
||||
@classmethod
|
||||
def _generate_imports(
|
||||
cls, async_class: Type, type_tracker: TypeTracker
|
||||
cls, async_class: type, type_tracker: TypeTracker
|
||||
) -> list[str]:
|
||||
"""Generate import statements for the stub file."""
|
||||
imports = []
|
||||
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
|
||||
return imports
|
||||
|
||||
@classmethod
|
||||
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
|
||||
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
|
||||
def _generate_inner_class_stub(
|
||||
cls,
|
||||
name: str,
|
||||
attr: Type,
|
||||
attr: type,
|
||||
indent: str = " ",
|
||||
type_tracker: Optional[TypeTracker] = None,
|
||||
) -> list[str]:
|
||||
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
|
||||
return processed
|
||||
|
||||
@classmethod
|
||||
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
|
||||
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
|
||||
"""
|
||||
Generate a .pyi stub file for the sync class to help IDEs with type checking.
|
||||
"""
|
||||
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
|
||||
logging.error(traceback.format_exc())
|
||||
|
||||
|
||||
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
|
||||
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
|
||||
"""
|
||||
Creates a sync version of an async class
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Type, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
class SingletonMetaclass(type):
|
||||
T = TypeVar("T", bound="SingletonMetaclass")
|
||||
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
|
||||
)
|
||||
return cls._instances[cls]
|
||||
|
||||
def inject_instance(cls: Type[T], instance: T) -> None:
|
||||
def inject_instance(cls: type[T], instance: T) -> None:
|
||||
assert cls not in SingletonMetaclass._instances, (
|
||||
"Cannot inject instance after first instantiation"
|
||||
)
|
||||
SingletonMetaclass._instances[cls] = instance
|
||||
|
||||
def get_instance(cls: Type[T], *args, **kwargs) -> T:
|
||||
def get_instance(cls: type[T], *args, **kwargs) -> T:
|
||||
"""
|
||||
Gets the singleton instance of the class, creating it if it doesn't exist.
|
||||
"""
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
This should be used to initialize any global resources needed by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
|
||||
if TYPE_CHECKING:
|
||||
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
|
||||
|
||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||
|
||||
# create new aliases for io and ui
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import torch
|
||||
from typing import TypedDict, List, Optional
|
||||
from typing import TypedDict, Optional
|
||||
|
||||
ImageInput = torch.Tensor
|
||||
"""
|
||||
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
|
||||
Optional noise mask tensor in the same format as samples.
|
||||
"""
|
||||
|
||||
batch_index: Optional[List[int]]
|
||||
batch_index: Optional[list[int]]
|
||||
|
||||
@ -4,7 +4,7 @@ from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
|
||||
@ -3,14 +3,14 @@ from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
|
||||
@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from ._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO):
|
||||
'''Used by WAN Camera.'''
|
||||
time_dim_concat: NotRequired[torch.Tensor]
|
||||
'''Used by WAN Phantom Subject.'''
|
||||
time_dim_replace: NotRequired[torch.Tensor]
|
||||
'''Used by Kandinsky5 I2V.'''
|
||||
|
||||
CondList = list[tuple[torch.Tensor, PooledDict]]
|
||||
Type = CondList
|
||||
@ -772,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
|
||||
class AudioEncoderOutput(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="TRACKS")
|
||||
class Tracks(ComfyTypeIO):
|
||||
class TrackDict(TypedDict):
|
||||
track_path: torch.Tensor
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@ -1813,7 +1822,7 @@ class NodeOutput(_NodeOutputInternal):
|
||||
ui = data["ui"]
|
||||
if "expand" in data:
|
||||
expand = data["expand"]
|
||||
return cls(args=args, ui=ui, expand=expand)
|
||||
return cls(*args, ui=ui, expand=expand)
|
||||
|
||||
def __getitem__(self, index) -> Any:
|
||||
return self.args[index]
|
||||
@ -1892,6 +1901,7 @@ __all__ = [
|
||||
"SEGS",
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
# "DynamicCombo",
|
||||
|
||||
@ -5,7 +5,6 @@ import os
|
||||
import random
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
@ -22,7 +21,7 @@ import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
from ._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
@ -83,7 +82,7 @@ class ImageSaveHelper:
|
||||
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
|
||||
|
||||
@staticmethod
|
||||
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
@ -96,7 +95,7 @@ class ImageSaveHelper:
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
|
||||
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
|
||||
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
|
||||
if args.disable_metadata or cls is None or not cls.hidden:
|
||||
return None
|
||||
@ -121,7 +120,7 @@ class ImageSaveHelper:
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
|
||||
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
|
||||
"""Creates EXIF metadata bytes for WebP images."""
|
||||
exif_data = pil_image.getexif()
|
||||
if args.disable_metadata or cls is None or cls.hidden is None:
|
||||
@ -137,7 +136,7 @@ class ImageSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def save_images(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
|
||||
) -> list[SavedResult]:
|
||||
"""Saves a batch of images as individual PNG files."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
@ -155,7 +154,7 @@ class ImageSaveHelper:
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
|
||||
"""Saves a batch of images and returns a UI object for the node output."""
|
||||
return SavedImages(
|
||||
ImageSaveHelper.save_images(
|
||||
@ -169,7 +168,7 @@ class ImageSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def save_animated_png(
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedResult:
|
||||
"""Saves a batch of images as a single animated PNG."""
|
||||
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
|
||||
@ -191,7 +190,7 @@ class ImageSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_save_animated_png_ui(
|
||||
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
|
||||
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
|
||||
) -> SavedImages:
|
||||
"""Saves an animated PNG and returns a UI object for the node output."""
|
||||
result = ImageSaveHelper.save_animated_png(
|
||||
@ -209,7 +208,7 @@ class ImageSaveHelper:
|
||||
images,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
cls: type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
@ -238,7 +237,7 @@ class ImageSaveHelper:
|
||||
def get_save_animated_webp_ui(
|
||||
images,
|
||||
filename_prefix: str,
|
||||
cls: Type[ComfyNode] | None,
|
||||
cls: type[ComfyNode] | None,
|
||||
fps: float,
|
||||
lossless: bool,
|
||||
quality: int,
|
||||
@ -267,7 +266,7 @@ class AudioSaveHelper:
|
||||
audio: dict,
|
||||
filename_prefix: str,
|
||||
folder_type: FolderType,
|
||||
cls: Type[ComfyNode] | None,
|
||||
cls: type[ComfyNode] | None,
|
||||
format: str = "flac",
|
||||
quality: str = "128k",
|
||||
) -> list[SavedResult]:
|
||||
@ -372,7 +371,7 @@ class AudioSaveHelper:
|
||||
|
||||
@staticmethod
|
||||
def get_save_audio_ui(
|
||||
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
|
||||
) -> SavedAudios:
|
||||
"""Save and instantly wrap for UI."""
|
||||
return SavedAudios(
|
||||
@ -388,7 +387,7 @@ class AudioSaveHelper:
|
||||
|
||||
|
||||
class PreviewImage(_UIOutput):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
|
||||
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
|
||||
self.values = ImageSaveHelper.save_images(
|
||||
image,
|
||||
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
|
||||
@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
|
||||
|
||||
|
||||
class PreviewAudio(_UIOutput):
|
||||
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
|
||||
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
|
||||
self.values = AudioSaveHelper.save_audio(
|
||||
audio,
|
||||
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
|
||||
|
||||
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
from .._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
|
||||
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
|
||||
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from typing import List, Type
|
||||
|
||||
supported_versions: List[Type[ComfyAPIBase]] = [
|
||||
supported_versions: list[type[ComfyAPIBase]] = [
|
||||
ComfyAPI_latest,
|
||||
ComfyAPIAdapter_v0_0_2,
|
||||
ComfyAPIAdapter_v0_0_1,
|
||||
|
||||
144
comfy_api_nodes/apis/bytedance_api.py
Normal file
144
comfy_api_nodes/apis/bytedance_api.py
Normal file
@ -0,0 +1,144 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
size: str | None = Field(None)
|
||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: list[str] | None = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
"seedance-1-0-pro-fast-251015": {
|
||||
"480p": 50,
|
||||
"720p": 65,
|
||||
"1080p": 100,
|
||||
},
|
||||
}
|
||||
@ -84,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
|
||||
@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
|
||||
url: str = Field(..., description="URL for generated image")
|
||||
|
||||
|
||||
class OmniTaskStatusResults(BaseModel):
|
||||
class TaskStatusResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
images: list[TaskStatusImageResult] | None = Field(None)
|
||||
|
||||
|
||||
class OmniTaskStatusResponseData(BaseModel):
|
||||
class TaskStatusResponseData(BaseModel):
|
||||
created_at: int | None = Field(None, description="Task creation time")
|
||||
updated_at: int | None = Field(None, description="Task update time")
|
||||
task_status: str | None = None
|
||||
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
|
||||
task_id: str | None = Field(None, description="Task ID")
|
||||
task_result: OmniTaskStatusResults | None = Field(None)
|
||||
task_result: TaskStatusResults | None = Field(None)
|
||||
|
||||
|
||||
class OmniTaskStatusResponse(BaseModel):
|
||||
class TaskStatusResponse(BaseModel):
|
||||
code: int | None = Field(None, description="Error code")
|
||||
message: str | None = Field(None, description="Error message")
|
||||
request_id: str | None = Field(None, description="Request ID")
|
||||
data: OmniTaskStatusResponseData | None = Field(None)
|
||||
data: TaskStatusResponseData | None = Field(None)
|
||||
|
||||
|
||||
class OmniImageParamImage(BaseModel):
|
||||
@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
|
||||
mode: str = Field("pro")
|
||||
n: int | None = Field(1, le=9)
|
||||
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
|
||||
|
||||
|
||||
class TextToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class ImageToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
image: str = Field(...)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
@ -1,100 +0,0 @@
|
||||
from typing import Optional
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Pikaffect(str, Enum):
|
||||
Cake_ify = "Cake-ify"
|
||||
Crumble = "Crumble"
|
||||
Crush = "Crush"
|
||||
Decapitate = "Decapitate"
|
||||
Deflate = "Deflate"
|
||||
Dissolve = "Dissolve"
|
||||
Explode = "Explode"
|
||||
Eye_pop = "Eye-pop"
|
||||
Inflate = "Inflate"
|
||||
Levitate = "Levitate"
|
||||
Melt = "Melt"
|
||||
Peel = "Peel"
|
||||
Poke = "Poke"
|
||||
Squish = "Squish"
|
||||
Ta_da = "Ta-da"
|
||||
Tear = "Tear"
|
||||
|
||||
|
||||
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||
duration: Optional[int] = Field(5)
|
||||
ingredientsMode: str = Field(...)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = Field('1080p')
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaGenerateResponse(BaseModel):
|
||||
video_id: str = Field(...)
|
||||
|
||||
|
||||
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||
aspectRatio: Optional[float] = Field(
|
||||
1.7777777777777777,
|
||||
description='Aspect ratio (width / height)',
|
||||
ge=0.4,
|
||||
le=2.5,
|
||||
)
|
||||
duration: Optional[int] = 5
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: str = Field(...)
|
||||
resolution: Optional[str] = '1080p'
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
pikaffect: Optional[str] = None
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
|
||||
|
||||
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||
negativePrompt: Optional[str] = Field(None)
|
||||
promptText: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
modifyRegionRoi: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class PikaStatusEnum(str, Enum):
|
||||
queued = "queued"
|
||||
started = "started"
|
||||
finished = "finished"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class PikaVideoResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
status: PikaStatusEnum
|
||||
url: Optional[str] = Field(None)
|
||||
@ -85,7 +85,7 @@ class Response1(BaseModel):
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
videos: Optional[list[Video]] = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_api import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
|
||||
class Text2ImageModelName(str, Enum):
|
||||
seedream_3 = "seedream-3-0-t2i-250415"
|
||||
|
||||
|
||||
class Image2ImageModelName(str, Enum):
|
||||
seededit_3 = "seededit-3-0-i2i-250628"
|
||||
|
||||
|
||||
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
|
||||
|
||||
|
||||
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
|
||||
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: Text2ImageModelName = Text2ImageModelName.seedream_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
size: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: Image2ImageModelName = Image2ImageModelName.seededit_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: Optional[str] = Field("adaptive")
|
||||
seed: Optional[int] = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field("seedream-4-0-250828")
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: Optional[list[str]] = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
|
||||
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: Optional[TaskStatusError] = Field(None)
|
||||
content: Optional[TaskStatusResult] = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
return response.data[0]["url"]
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "content") and response.content:
|
||||
return response.content.video_url
|
||||
return None
|
||||
|
||||
|
||||
class ByteDanceImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2ImageModelName,
|
||||
default=Text2ImageModelName.seedream_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Edit images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2ImageModelName,
|
||||
default=Image2ImageModelName.seededit_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The base image to edit",
|
||||
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedream-4-0-250828"],
|
||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor = None,
|
||||
image: Input.Image | None = None,
|
||||
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
|
||||
width: int = 2048,
|
||||
height: int = 2048,
|
||||
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2VideoModelName,
|
||||
default=Text2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2VideoModelName,
|
||||
default=Image2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in Image2VideoModelName],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[Image2VideoModelName.seedance_1_lite.value],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
estimated_duration: Optional[int],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -1085,7 +934,7 @@ async def process_video_task(
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
|
||||
@ -13,8 +13,7 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
@ -27,6 +26,8 @@ from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiRole,
|
||||
GeminiSystemInstructionContent,
|
||||
GeminiTextPart,
|
||||
Modality,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -43,6 +44,14 @@ from comfy_api_nodes.util import (
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
"format, intent, or abstraction—as literal visual directives for image composition.\n"
|
||||
"If a prompt is conversational or lacks specific visual details, "
|
||||
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
|
||||
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
||||
)
|
||||
|
||||
|
||||
class GeminiModel(str, Enum):
|
||||
@ -68,7 +77,7 @@ class GeminiImageModel(str, Enum):
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
@ -154,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
@ -277,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
@ -293,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@ -343,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
audio: Input.Audio | None = None,
|
||||
video: Input.Video | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -363,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -373,7 +395,8 @@ class GeminiNode(IO.ComfyNode):
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
]
|
||||
],
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
@ -523,6 +546,13 @@ class GeminiImage(IO.ComfyNode):
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -542,10 +572,11 @@ class GeminiImage(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
aspect_ratio: str = "auto",
|
||||
response_modalities: str = "IMAGE+TEXT",
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
@ -559,6 +590,10 @@ class GeminiImage(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -570,6 +605,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
@ -640,6 +676,13 @@ class GeminiImage2(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -662,8 +705,9 @@ class GeminiImage2(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
response_modalities: str,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
||||
@ -679,6 +723,10 @@ class GeminiImage2(IO.ComfyNode):
|
||||
if aspect_ratio != "auto":
|
||||
image_config.aspectRatio = aspect_ratio
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -690,6 +738,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
|
||||
@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
|
||||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.apis.kling_api import (
|
||||
ImageToVideoWithAudioRequest,
|
||||
OmniImageParamImage,
|
||||
OmniParamImage,
|
||||
OmniParamVideo,
|
||||
@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
|
||||
OmniProImageRequest,
|
||||
OmniProReferences2VideoRequest,
|
||||
OmniProText2VideoRequest,
|
||||
OmniTaskStatusResponse,
|
||||
TaskStatusResponse,
|
||||
TextToVideoWithAudioRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -103,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
|
||||
|
||||
|
||||
MODE_TEXT2VIDEO = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
|
||||
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
|
||||
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
|
||||
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
|
||||
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
|
||||
@ -127,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
|
||||
|
||||
|
||||
MODE_START_END_FRAME = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
|
||||
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
|
||||
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
|
||||
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
|
||||
@ -242,7 +238,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
|
||||
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
|
||||
|
||||
|
||||
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
|
||||
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
@ -250,7 +246,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
max_poll_attempts=160,
|
||||
)
|
||||
@ -483,12 +479,12 @@ async def execute_image2video(
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
@ -752,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[4],
|
||||
default=modes[8],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
@ -834,7 +830,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProText2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
@ -929,7 +925,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProFirstLastFrameRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
@ -997,7 +993,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
@ -1081,7 +1077,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
@ -1162,7 +1158,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
@ -1237,7 +1233,7 @@ class OmniProImageNode(IO.ComfyNode):
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProImageRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
@ -1253,7 +1249,7 @@ class OmniProImageNode(IO.ComfyNode):
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
|
||||
response_model=OmniTaskStatusResponse,
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
|
||||
@ -1328,9 +1324,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingImage2VideoNode",
|
||||
display_name="Kling Image to Video",
|
||||
display_name="Kling Image(First Frame) to Video",
|
||||
category="api node/video/Kling",
|
||||
description="Kling Image to Video Node",
|
||||
inputs=[
|
||||
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1488,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[8],
|
||||
default=modes[6],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
@ -1951,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"model_name",
|
||||
options=[i.value for i in KlingImageGenModelName],
|
||||
default="kling-v1",
|
||||
default="kling-v2",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -2034,6 +2029,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await image_result_to_node_output(images))
|
||||
|
||||
|
||||
class TextToVideoWithAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoWithAudio",
|
||||
display_name="Kling Text to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
IO.Combo.Input("mode", options=["pro"]),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
generate_audio: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=TextToVideoWithAudioRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
mode=mode,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingImageToVideoWithAudio",
|
||||
display_name="Kling Image(First Frame) to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.Image.Input("start_frame"),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
IO.Combo.Input("mode", options=["pro"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
start_frame: Input.Image,
|
||||
prompt: str,
|
||||
mode: str,
|
||||
duration: int,
|
||||
generate_audio: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_image_dimensions(start_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=ImageToVideoWithAudioRequest(
|
||||
model_name=model_name,
|
||||
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
|
||||
prompt=prompt,
|
||||
mode=mode,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class KlingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -2056,7 +2181,9 @@ class KlingExtension(ComfyExtension):
|
||||
OmniProImageToVideoNode,
|
||||
OmniProVideoToVideoNode,
|
||||
OmniProEditVideoNode,
|
||||
# OmniProImageNode, # need support from backend
|
||||
OmniProImageNode,
|
||||
TextToVideoWithAudio,
|
||||
ImageToVideoWithAudio,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
fps: int | None = Field(25)
|
||||
generate_audio: bool | None = Field(True)
|
||||
image_uri: str | None = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
video: Optional[VideoInput] = None,
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: Optional[int] = 100,
|
||||
motion_intensity: int | None = 100,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
|
||||
@ -1,568 +0,0 @@
|
||||
"""
|
||||
Pika x ComfyUI API Nodes
|
||||
|
||||
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||
from comfy_api_nodes.util import (
|
||||
validate_string,
|
||||
download_url_to_video_output,
|
||||
tensor_to_bytesio,
|
||||
ApiEndpoint,
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
|
||||
|
||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
|
||||
|
||||
PIKA_API_VERSION = "2.2"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
|
||||
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
|
||||
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
|
||||
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
||||
|
||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||
|
||||
|
||||
async def execute_task(
|
||||
task_id: str,
|
||||
cls: type[IO.ComfyNode],
|
||||
) -> IO.NodeOutput:
|
||||
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||
response_model=pika_defs.PikaVideoResponse,
|
||||
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||
estimated_duration=60,
|
||||
max_poll_attempts=240,
|
||||
)
|
||||
if not final_response.url:
|
||||
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||
logging.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
video_url = final_response.url
|
||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||
|
||||
|
||||
def get_base_inputs_types() -> list[IO.Input]:
|
||||
"""Get the base required inputs types common to all Pika nodes."""
|
||||
return [
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||
]
|
||||
|
||||
|
||||
class PikaImageToVideo(IO.ComfyNode):
|
||||
"""Pika 2.2 Image to Video Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaImageToVideoNode2_2",
|
||||
display_name="Pika Image to Video",
|
||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaTextToVideoNode(IO.ComfyNode):
|
||||
"""Pika Text2Video v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaTextToVideoNode2_2",
|
||||
display_name="Pika Text to Video",
|
||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
aspect_ratio: float,
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
),
|
||||
content_type="application/x-www-form-urlencoded",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaScenes(IO.ComfyNode):
|
||||
"""PikaScenes v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaScenesV2_2",
|
||||
display_name="Pika Scenes (Video Image Composition)",
|
||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
*get_base_inputs_types(),
|
||||
IO.Combo.Input(
|
||||
"ingredients_mode",
|
||||
options=["creative", "precise"],
|
||||
default="creative",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"aspect_ratio",
|
||||
step=0.001,
|
||||
min=0.4,
|
||||
max=2.5,
|
||||
default=1.7777777777777777,
|
||||
tooltip="Aspect ratio (width / height)",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_1",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_2",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_3",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_4",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_ingredient_5",
|
||||
optional=True,
|
||||
tooltip="Image that will be used as ingredient to create a video.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
ingredients_mode: str,
|
||||
aspect_ratio: float,
|
||||
image_ingredient_1: Optional[torch.Tensor] = None,
|
||||
image_ingredient_2: Optional[torch.Tensor] = None,
|
||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||
) -> IO.NodeOutput:
|
||||
all_image_bytes_io = []
|
||||
for image in [
|
||||
image_ingredient_1,
|
||||
image_ingredient_2,
|
||||
image_ingredient_3,
|
||||
image_ingredient_4,
|
||||
image_ingredient_5,
|
||||
]:
|
||||
if image is not None:
|
||||
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||
|
||||
pika_files = [
|
||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||
]
|
||||
|
||||
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||
ingredientsMode=ingredients_mode,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
aspectRatio=aspect_ratio,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikAdditionsNode(IO.ComfyNode):
|
||||
"""Pika Pikadditions Node. Add an image into a video."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikadditions",
|
||||
display_name="Pikadditions (Video Object Insertion)",
|
||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
image_bytes_io = tensor_to_bytesio(image)
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
"image": ("image.png", image_bytes_io, "image/png"),
|
||||
}
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaSwapsNode(IO.ComfyNode):
|
||||
"""Pika Pikaswaps Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaswaps",
|
||||
display_name="Pika Swaps (Video Object Replacement)",
|
||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The image used to replace the masked object in the video.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Mask.Input(
|
||||
"mask",
|
||||
tooltip="Use the mask to define areas in the video to replace.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||
IO.String.Input(
|
||||
"region_to_modify",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Plaintext description of the object / region to modify.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
prompt_text: str = "",
|
||||
negative_prompt: str = "",
|
||||
seed: int = 0,
|
||||
region_to_modify: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
video_bytes_io = BytesIO()
|
||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
video_bytes_io.seek(0)
|
||||
pika_files = {
|
||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||
}
|
||||
if mask is not None:
|
||||
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||
if image is not None:
|
||||
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||
|
||||
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||
)
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_request_data,
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaffectsNode(IO.ComfyNode):
|
||||
"""Pika Pikaffects Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Pikaffects",
|
||||
display_name="Pikaffects (Video Effects)",
|
||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||
IO.Combo.Input(
|
||||
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||
),
|
||||
IO.String.Input("prompt_text", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True),
|
||||
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
pikaffect: str,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||
pikaffect=pikaffect,
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
),
|
||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
"""PikaFrames v2.2 Node."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="PikaStartEndFrameNode2_2",
|
||||
display_name="Pika Start and End Frame to Video",
|
||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||
category="api node/video/Pika",
|
||||
inputs=[
|
||||
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||
*get_base_inputs_types(),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image_start: torch.Tensor,
|
||||
image_end: torch.Tensor,
|
||||
prompt_text: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
resolution: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||
pika_files = [
|
||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||
]
|
||||
initial_operation = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||
response_model=pika_defs.PikaGenerateResponse,
|
||||
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||
promptText=prompt_text,
|
||||
negativePrompt=negative_prompt,
|
||||
seed=seed,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
),
|
||||
files=pika_files,
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
return await execute_task(initial_operation.video_id, cls)
|
||||
|
||||
|
||||
class PikaApiNodesExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
PikaImageToVideo,
|
||||
PikaTextToVideoNode,
|
||||
PikaScenes,
|
||||
PikAdditionsNode,
|
||||
PikaSwapsNode,
|
||||
PikaffectsNode,
|
||||
PikaStartEndFrameNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> PikaApiNodesExtension:
|
||||
return PikaApiNodesExtension()
|
||||
@ -11,12 +11,11 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
|
||||
field_1280_768 = "1280:768"
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
@ -119,8 +116,8 @@ async def get_response(
|
||||
async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
estimated_duration: int | None = None,
|
||||
) -> InputImpl.VideoFromFile:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
cls,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
@ -232,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
@ -431,8 +429,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
@ -493,7 +491,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
|
||||
@ -47,6 +47,7 @@ from .validation_utils import (
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
validate_video_frame_count,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -94,6 +95,7 @@ __all__ = [
|
||||
"validate_string",
|
||||
"validate_video_dimensions",
|
||||
"validate_video_duration",
|
||||
"validate_video_frame_count",
|
||||
# Misc functions
|
||||
"get_fs_object_size",
|
||||
]
|
||||
|
||||
@ -2,8 +2,8 @@ import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.model_management import processing_interrupted
|
||||
@ -35,12 +35,12 @@ def default_base_url() -> str:
|
||||
|
||||
async def sleep_with_interrupt(
|
||||
seconds: float,
|
||||
node_cls: Optional[type[IO.ComfyNode]],
|
||||
label: Optional[str] = None,
|
||||
start_ts: Optional[float] = None,
|
||||
estimated_total: Optional[int] = None,
|
||||
node_cls: type[IO.ComfyNode] | None,
|
||||
label: str | None = None,
|
||||
start_ts: float | None = None,
|
||||
estimated_total: int | None = None,
|
||||
*,
|
||||
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
|
||||
display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
|
||||
):
|
||||
"""
|
||||
Sleep in 1s slices while:
|
||||
@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
|
||||
def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
|
||||
|
||||
@ -4,10 +4,11 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Literal, TypeVar
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
@ -37,8 +38,8 @@ class ApiEndpoint:
|
||||
path: str,
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
|
||||
*,
|
||||
query_params: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
query_params: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
self.path = path
|
||||
self.method = method
|
||||
@ -52,18 +53,18 @@ class _RequestConfig:
|
||||
endpoint: ApiEndpoint
|
||||
timeout: float
|
||||
content_type: str
|
||||
data: Optional[dict[str, Any]]
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
|
||||
multipart_parser: Optional[Callable]
|
||||
data: dict[str, Any] | None
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
monitor_progress: bool = True
|
||||
estimated_total: Optional[int] = None
|
||||
final_label_on_success: Optional[str] = "Completed"
|
||||
progress_origin_ts: Optional[float] = None
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
|
||||
estimated_total: int | None = None
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -71,10 +72,10 @@ class _PollUIState:
|
||||
started: float
|
||||
status_label: str = "Queued"
|
||||
is_queued: bool = True
|
||||
price: Optional[float] = None
|
||||
estimated_duration: Optional[int] = None
|
||||
price: float | None = None
|
||||
estimated_duration: int | None = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
@ -87,20 +88,20 @@ async def sync_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
response_model: type[M],
|
||||
price_extractor: Callable[[M | Any], float | None] | None = None,
|
||||
data: BaseModel | None = None,
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
estimated_duration: int | None = None,
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
@ -131,22 +132,22 @@ async def poll_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
status_extractor: Callable[[M], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
response_model: type[M],
|
||||
status_extractor: Callable[[M | Any], str | int | None],
|
||||
progress_extractor: Callable[[M | Any], int | None] | None = None,
|
||||
price_extractor: Callable[[M | Any], float | None] | None = None,
|
||||
completed_statuses: list[str | int] | None = None,
|
||||
failed_statuses: list[str | int] | None = None,
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
@ -178,22 +179,22 @@ async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||
data: dict[str, Any] | BaseModel | None = None,
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
estimated_duration: int | None = None,
|
||||
as_binary: bool = False,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> Union[dict[str, Any], bytes]:
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
@ -229,21 +230,21 @@ async def poll_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
status_extractor: Callable[[dict[str, Any]], str | int | None],
|
||||
progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||
completed_statuses: list[str | int] | None = None,
|
||||
failed_statuses: list[str | int] | None = None,
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: dict[str, Any] | BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@ -261,7 +262,7 @@ async def poll_op_raw(
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||
last_progress: Optional[int] = None
|
||||
last_progress: int | None = None
|
||||
|
||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||
stop_ticker = asyncio.Event()
|
||||
@ -420,10 +421,10 @@ async def poll_op_raw(
|
||||
|
||||
def _display_text(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
text: Optional[str],
|
||||
text: str | None,
|
||||
*,
|
||||
status: Optional[Union[str, int]] = None,
|
||||
price: Optional[float] = None,
|
||||
status: str | int | None = None,
|
||||
price: float | None = None,
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
@ -440,13 +441,13 @@ def _display_text(
|
||||
|
||||
def _display_time_progress(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
status: Optional[Union[str, int]],
|
||||
status: str | int | None,
|
||||
elapsed_seconds: int,
|
||||
estimated_total: Optional[int] = None,
|
||||
estimated_total: int | None = None,
|
||||
*,
|
||||
price: Optional[float] = None,
|
||||
is_queued: Optional[bool] = None,
|
||||
processing_elapsed_seconds: Optional[int] = None,
|
||||
price: float | None = None,
|
||||
is_queued: bool | None = None,
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||||
def _snapshot_request_body_for_logging(
|
||||
content_type: str,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]],
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||||
) -> Optional[Union[dict[str, Any], str]]:
|
||||
data: dict[str, Any] | None,
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None,
|
||||
) -> dict[str, Any] | str | None:
|
||||
if method.upper() == "GET":
|
||||
return None
|
||||
if content_type == "multipart/form-data":
|
||||
@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: Optional[int] = None
|
||||
extracted_price: Optional[float] = None
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
while True:
|
||||
attempt += 1
|
||||
stop_event = asyncio.Event()
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
monitor_task: asyncio.Task | None = None
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
)
|
||||
|
||||
|
||||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
def _validate_or_raise(response_model: type[M], payload: Any) -> M:
|
||||
try:
|
||||
return response_model.model_validate(payload)
|
||||
except Exception as e:
|
||||
@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
|
||||
|
||||
def _wrap_model_extractor(
|
||||
response_model: Type[M],
|
||||
extractor: Optional[Callable[[M], Any]],
|
||||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||||
response_model: type[M],
|
||||
extractor: Callable[[M], Any] | None,
|
||||
) -> Callable[[dict[str, Any]], Any] | None:
|
||||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||
Validates the dict into `response_model` before invoking `extractor`.
|
||||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||
@ -929,10 +930,10 @@ def _wrap_model_extractor(
|
||||
return _wrapped
|
||||
|
||||
|
||||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
||||
def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
|
||||
if not values:
|
||||
return set()
|
||||
out: set[Union[str, int]] = set()
|
||||
out: set[str | int] = set()
|
||||
for v in values:
|
||||
nv = _normalize_status_value(v)
|
||||
if nv is not None:
|
||||
@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||||
def _normalize_status_value(val: str | int | None) -> str | int | None:
|
||||
if isinstance(val, str):
|
||||
return val.strip().lower()
|
||||
return val
|
||||
|
||||
@ -4,7 +4,6 @@ import math
|
||||
import mimetypes
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
@ -12,8 +11,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import Input, InputImpl
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
|
||||
from ._helpers import mimetype_to_extension
|
||||
|
||||
@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
|
||||
|
||||
def video_to_base64_string(
|
||||
video: Input.Video,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
container_format: Types.VideoContainer | None = None,
|
||||
codec: Types.VideoCodec | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
@ -189,12 +187,11 @@ def video_to_base64_string(
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video.save_to(
|
||||
video_bytes_io,
|
||||
format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
|
||||
codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
|
||||
)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
@ -3,15 +3,15 @@ import contextlib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO, Optional, Union
|
||||
from typing import IO
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
from comfy_api.latest import InputImpl
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str,
|
||||
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
|
||||
dest: BytesIO | IO[bytes] | str | Path | None,
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
@ -71,10 +71,10 @@ async def download_url_to_bytesio(
|
||||
|
||||
is_path_sink = isinstance(dest, (str, Path))
|
||||
fhandle = None
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
stop_evt: Optional[asyncio.Event] = None
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
req_task: Optional[asyncio.Task] = None
|
||||
session: aiohttp.ClientSession | None = None
|
||||
stop_evt: asyncio.Event | None = None
|
||||
monitor_task: asyncio.Task | None = None
|
||||
req_task: asyncio.Task | None = None
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
@ -234,11 +234,11 @@ async def download_url_to_video_output(
|
||||
timeout: float = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> VideoFromFile:
|
||||
) -> InputImpl.VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||
return VideoFromFile(result)
|
||||
return InputImpl.VideoFromFile(result)
|
||||
|
||||
|
||||
async def download_url_as_bytesio(
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
@ -4,15 +4,13 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import IO, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import IO, Input, Types
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||
@ -32,7 +30,7 @@ from .conversions import (
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: Optional[str] = Field(
|
||||
content_type: str | None = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
# if batched, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
@ -100,9 +98,9 @@ async def upload_video_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
container: Types.VideoContainer = Types.VideoContainer.MP4,
|
||||
codec: Types.VideoCodec = Types.VideoCodec.H264,
|
||||
max_duration: int | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
) -> str:
|
||||
"""
|
||||
@ -220,7 +218,7 @@ async def upload_file(
|
||||
return
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
||||
|
||||
def validate_image_dimensions(
|
||||
image: torch.Tensor,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
max_height: Optional[int] = None,
|
||||
min_width: int | None = None,
|
||||
max_width: int | None = None,
|
||||
min_height: int | None = None,
|
||||
max_height: int | None = None,
|
||||
):
|
||||
height, width = get_image_dimensions(image)
|
||||
|
||||
@ -37,8 +35,8 @@ def validate_image_dimensions(
|
||||
|
||||
def validate_image_aspect_ratio(
|
||||
image: torch.Tensor,
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
@ -54,8 +52,8 @@ def validate_image_aspect_ratio(
|
||||
def validate_images_aspect_ratio_closeness(
|
||||
first_image: torch.Tensor,
|
||||
second_image: torch.Tensor,
|
||||
min_rel: float, # e.g. 0.8
|
||||
max_rel: float, # e.g. 1.25
|
||||
min_rel: float, # e.g. 0.8
|
||||
max_rel: float, # e.g. 1.25
|
||||
*,
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
|
||||
|
||||
def validate_aspect_ratio_string(
|
||||
aspect_ratio: str,
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
|
||||
|
||||
def validate_video_dimensions(
|
||||
video: Input.Video,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
max_height: Optional[int] = None,
|
||||
min_width: int | None = None,
|
||||
max_width: int | None = None,
|
||||
min_height: int | None = None,
|
||||
max_height: int | None = None,
|
||||
):
|
||||
try:
|
||||
width, height = video.get_dimensions()
|
||||
@ -120,8 +118,8 @@ def validate_video_dimensions(
|
||||
|
||||
def validate_video_duration(
|
||||
video: Input.Video,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
min_duration: float | None = None,
|
||||
max_duration: float | None = None,
|
||||
):
|
||||
try:
|
||||
duration = video.get_duration()
|
||||
@ -136,6 +134,23 @@ def validate_video_duration(
|
||||
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||
|
||||
|
||||
def validate_video_frame_count(
|
||||
video: Input.Video,
|
||||
min_frame_count: int | None = None,
|
||||
max_frame_count: int | None = None,
|
||||
):
|
||||
try:
|
||||
frame_count = video.get_frame_count()
|
||||
except Exception as e:
|
||||
logging.error("Error getting frame count of video: %s", e)
|
||||
return
|
||||
|
||||
if min_frame_count is not None and min_frame_count > frame_count:
|
||||
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
|
||||
if max_frame_count is not None and frame_count > max_frame_count:
|
||||
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
if isinstance(images, torch.Tensor):
|
||||
return images.shape[0] if images.ndim >= 4 else 1
|
||||
@ -144,8 +159,8 @@ def get_number_of_images(images):
|
||||
|
||||
def validate_audio_duration(
|
||||
audio: Input.Audio,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
min_duration: float | None = None,
|
||||
max_duration: float | None = None,
|
||||
) -> None:
|
||||
sr = int(audio["sample_rate"])
|
||||
dur = int(audio["waveform"].shape[-1]) / sr
|
||||
@ -177,7 +192,7 @@ def validate_string(
|
||||
)
|
||||
|
||||
|
||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||
def validate_container_format_is_mp4(video: Input.Video) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||||
def _assert_ratio_bounds(
|
||||
ar: float,
|
||||
*,
|
||||
min_ratio: Optional[tuple[float, float]] = None,
|
||||
max_ratio: Optional[tuple[float, float]] = None,
|
||||
min_ratio: tuple[float, float] | None = None,
|
||||
max_ratio: tuple[float, float] | None = None,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||
|
||||
@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode):
|
||||
step=0.01,
|
||||
tooltip="Duration of the empty audio clip in seconds",
|
||||
),
|
||||
IO.Float.Input(
|
||||
IO.Int.Input(
|
||||
"sample_rate",
|
||||
default=44100,
|
||||
tooltip="Sample rate of the empty audio clip.",
|
||||
min=1,
|
||||
max=192000,
|
||||
),
|
||||
IO.Float.Input(
|
||||
IO.Int.Input(
|
||||
"channels",
|
||||
default=2,
|
||||
min=1,
|
||||
|
||||
@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
context_overlap=context_overlap,
|
||||
context_stride=context_stride,
|
||||
closed_loop=closed_loop,
|
||||
dim=dim)
|
||||
dim=dim,
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
if freenoise: # no other use for this wrapper at this time
|
||||
comfy.context_windows.create_sampler_sample_wrapper(model)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
]
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
|
||||
@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
|
||||
get_sampler = execute
|
||||
|
||||
|
||||
class SamplerSEEDS2(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
|
||||
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
|
||||
],
|
||||
outputs=[io.Sampler.Output()]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
|
||||
sampler_name = "seeds_2"
|
||||
sampler = comfy.samplers.ksampler(
|
||||
sampler_name,
|
||||
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
|
||||
)
|
||||
return io.NodeOutput(sampler)
|
||||
|
||||
|
||||
class Noise_EmptyNoise:
|
||||
def __init__(self):
|
||||
self.seed = 0
|
||||
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
|
||||
SamplerDPMAdaptative,
|
||||
SamplerER_SDE,
|
||||
SamplerSASolver,
|
||||
SamplerSEEDS2,
|
||||
SplitSigmas,
|
||||
SplitSigmasDenoise,
|
||||
FlipSigmas,
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
def Fourier_filter(x, threshold, scale):
|
||||
# FFT
|
||||
@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale):
|
||||
return x_filtered.to(x.dtype)
|
||||
|
||||
|
||||
class FreeU:
|
||||
class FreeU(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FreeU",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
@classmethod
|
||||
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
on_cpu_devices = {}
|
||||
@ -59,23 +66,31 @@ class FreeU:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
class FreeU_V2:
|
||||
patch = execute # TODO: remove
|
||||
|
||||
|
||||
class FreeU_V2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FreeU_V2",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
|
||||
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
CATEGORY = "model_patches/unet"
|
||||
|
||||
def patch(self, model, b1, b2, s1, s2):
|
||||
@classmethod
|
||||
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
|
||||
model_channels = model.model.model_config.unet_config["model_channels"]
|
||||
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
|
||||
on_cpu_devices = {}
|
||||
@ -105,9 +120,19 @@ class FreeU_V2:
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_output_block_patch(output_block_patch)
|
||||
return (m, )
|
||||
return IO.NodeOutput(m)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"FreeU": FreeU,
|
||||
"FreeU_V2": FreeU_V2,
|
||||
}
|
||||
patch = execute # TODO: remove
|
||||
|
||||
|
||||
class FreelunchExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
FreeU,
|
||||
FreeU_V2,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> FreelunchExtension:
|
||||
return FreelunchExtension()
|
||||
|
||||
136
comfy_extras/nodes_kandinsky5.py
Normal file
136
comfy_extras/nodes_kandinsky5.py
Normal file
@ -0,0 +1,136 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class Kandinsky5ImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Kandinsky5ImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
|
||||
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
cond_latent_out = {}
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
encoded = vae.encode(start_image[:, :, :, :3])
|
||||
cond_latent_out["samples"] = encoded
|
||||
|
||||
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
|
||||
|
||||
|
||||
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
|
||||
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
|
||||
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
|
||||
|
||||
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
|
||||
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
|
||||
|
||||
# normalization
|
||||
normalized = (source - source_mean) / (source_std + 1e-8)
|
||||
normalized = normalized * reference_std + reference_mean
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
class NormalizeVideoLatentStart(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="NormalizeVideoLatentStart",
|
||||
category="conditioning/video_models",
|
||||
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
|
||||
inputs=[
|
||||
io.Latent.Input("latent"),
|
||||
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
|
||||
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
|
||||
if latent["samples"].shape[2] <= 1:
|
||||
return io.NodeOutput(latent)
|
||||
s = latent.copy()
|
||||
samples = latent["samples"].clone()
|
||||
|
||||
first_frames = samples[:, :, :start_frame_count]
|
||||
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
|
||||
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
|
||||
|
||||
samples[:, :, :start_frame_count] = normalized_first_frames
|
||||
s["samples"] = samples
|
||||
return io.NodeOutput(s)
|
||||
|
||||
|
||||
class CLIPTextEncodeKandinsky5(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="CLIPTextEncodeKandinsky5",
|
||||
category="advanced/conditioning/kandinsky5",
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
|
||||
tokens = clip.tokenize(clip_l)
|
||||
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
|
||||
|
||||
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||
|
||||
|
||||
class Kandinsky5Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
Kandinsky5ImageToVideo,
|
||||
NormalizeVideoLatentStart,
|
||||
CLIPTextEncodeKandinsky5,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> Kandinsky5Extension:
|
||||
return Kandinsky5Extension()
|
||||
@ -4,7 +4,7 @@ import torch
|
||||
import nodes
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import logging
|
||||
|
||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||
if latent.shape[1:] != target_shape[1:]:
|
||||
@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode):
|
||||
return luminance * sharpened
|
||||
return io.NodeOutput(sharpen)
|
||||
|
||||
class ReplaceVideoLatentFrames(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReplaceVideoLatentFrames",
|
||||
category="latent/batch",
|
||||
inputs=[
|
||||
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
|
||||
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
|
||||
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, index, source=None) -> io.NodeOutput:
|
||||
if source is None:
|
||||
return io.NodeOutput(destination)
|
||||
dest_frames = destination["samples"].shape[2]
|
||||
source_frames = source["samples"].shape[2]
|
||||
if index < 0:
|
||||
index = dest_frames + index
|
||||
if index > dest_frames:
|
||||
logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.")
|
||||
return io.NodeOutput(destination)
|
||||
if index + source_frames > dest_frames:
|
||||
logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.")
|
||||
return io.NodeOutput(destination)
|
||||
s = source.copy()
|
||||
s_source = source["samples"]
|
||||
s_destination = destination["samples"].clone()
|
||||
s_destination[:, :, index:index + s_source.shape[2]] = s_source
|
||||
s["samples"] = s_destination
|
||||
return io.NodeOutput(s)
|
||||
|
||||
class LatentExtension(ComfyExtension):
|
||||
@override
|
||||
@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension):
|
||||
LatentApplyOperationCFG,
|
||||
LatentOperationTonemapReinhard,
|
||||
LatentOperationSharpen,
|
||||
ReplaceVideoLatentFrames
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -3,11 +3,10 @@ import scipy.ndimage
|
||||
import torch
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
import folder_paths
|
||||
import random
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
|
||||
import nodes
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False):
|
||||
source = source.to(destination.device)
|
||||
@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
return destination
|
||||
|
||||
class LatentCompositeMasked:
|
||||
class LatentCompositeMasked(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("LATENT",),
|
||||
"source": ("LATENT",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}),
|
||||
"resize_source": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "composite"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LatentCompositeMasked",
|
||||
category="latent",
|
||||
inputs=[
|
||||
IO.Latent.Input("destination"),
|
||||
IO.Latent.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8),
|
||||
IO.Boolean.Input("resize_source", default=False),
|
||||
IO.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "latent"
|
||||
|
||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||
output = destination.copy()
|
||||
destination = destination["samples"].clone()
|
||||
source = source["samples"]
|
||||
output["samples"] = composite(destination, source, x, y, mask, 8, resize_source)
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class ImageCompositeMasked:
|
||||
composite = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageCompositeMasked(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("IMAGE",),
|
||||
"source": ("IMAGE",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"resize_source": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "composite"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageCompositeMasked",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("destination"),
|
||||
IO.Image.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Boolean.Input("resize_source", default=False),
|
||||
IO.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def composite(self, destination, source, x, y, resize_source, mask = None):
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||
destination = destination.clone().movedim(-1, 1)
|
||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class MaskToImage:
|
||||
composite = execute # TODO: remove
|
||||
|
||||
|
||||
class MaskToImage(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskToImage",
|
||||
display_name="Convert Mask to Image",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "mask_to_image"
|
||||
|
||||
def mask_to_image(self, mask):
|
||||
@classmethod
|
||||
def execute(cls, mask) -> IO.NodeOutput:
|
||||
result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return (result,)
|
||||
return IO.NodeOutput(result)
|
||||
|
||||
class ImageToMask:
|
||||
mask_to_image = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageToMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"channel": (["red", "green", "blue", "alpha"],),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageToMask",
|
||||
display_name="Convert Image to Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, channel):
|
||||
@classmethod
|
||||
def execute(cls, image, channel) -> IO.NodeOutput:
|
||||
channels = ["red", "green", "blue", "alpha"]
|
||||
mask = image[:, :, :, channels.index(channel)]
|
||||
return (mask,)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
class ImageColorToMask:
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class ImageColorToMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ImageColorToMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, image, color):
|
||||
@classmethod
|
||||
def execute(cls, image, color) -> IO.NodeOutput:
|
||||
temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int)
|
||||
temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2]
|
||||
mask = torch.where(temp == color, 1.0, 0).float()
|
||||
return (mask,)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
class SolidMask:
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class SolidMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SolidMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "solid"
|
||||
|
||||
def solid(self, value, width, height):
|
||||
@classmethod
|
||||
def execute(cls, value, width, height) -> IO.NodeOutput:
|
||||
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu")
|
||||
return (out,)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class InvertMask:
|
||||
solid = execute # TODO: remove
|
||||
|
||||
|
||||
class InvertMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="InvertMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "invert"
|
||||
|
||||
def invert(self, mask):
|
||||
@classmethod
|
||||
def execute(cls, mask) -> IO.NodeOutput:
|
||||
out = 1.0 - mask
|
||||
return (out,)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class CropMask:
|
||||
invert = execute # TODO: remove
|
||||
|
||||
|
||||
class CropMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="CropMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "crop"
|
||||
|
||||
def crop(self, mask, x, y, width, height):
|
||||
@classmethod
|
||||
def execute(cls, mask, x, y, width, height) -> IO.NodeOutput:
|
||||
mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1]))
|
||||
out = mask[:, y:y + height, x:x + width]
|
||||
return (out,)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class MaskComposite:
|
||||
crop = execute # TODO: remove
|
||||
|
||||
|
||||
class MaskComposite(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"destination": ("MASK",),
|
||||
"source": ("MASK",),
|
||||
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"operation": (["multiply", "add", "subtract", "and", "or", "xor"],),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskComposite",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("destination"),
|
||||
IO.Mask.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "combine"
|
||||
|
||||
def combine(self, destination, source, x, y, operation):
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
|
||||
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
|
||||
source = source.reshape((-1, source.shape[-2], source.shape[-1]))
|
||||
|
||||
@ -267,28 +277,29 @@ class MaskComposite:
|
||||
|
||||
output = torch.clamp(output, 0.0, 1.0)
|
||||
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class FeatherMask:
|
||||
combine = execute # TODO: remove
|
||||
|
||||
|
||||
class FeatherMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FeatherMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "feather"
|
||||
|
||||
def feather(self, mask, left, top, right, bottom):
|
||||
@classmethod
|
||||
def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput:
|
||||
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
|
||||
|
||||
left = min(left, output.shape[-1])
|
||||
@ -312,26 +323,28 @@ class FeatherMask:
|
||||
feather_rate = (y + 1) / bottom
|
||||
output[:, -y, :] *= feather_rate
|
||||
|
||||
return (output,)
|
||||
return IO.NodeOutput(output)
|
||||
|
||||
class GrowMask:
|
||||
feather = execute # TODO: remove
|
||||
|
||||
|
||||
class GrowMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}),
|
||||
"tapered_corners": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GrowMask",
|
||||
display_name="Grow Mask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Boolean.Input("tapered_corners", default=True),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
FUNCTION = "expand_mask"
|
||||
|
||||
def expand_mask(self, mask, expand, tapered_corners):
|
||||
@classmethod
|
||||
def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput:
|
||||
c = 0 if tapered_corners else 1
|
||||
kernel = np.array([[c, 1, c],
|
||||
[1, 1, 1],
|
||||
@ -347,69 +360,74 @@ class GrowMask:
|
||||
output = scipy.ndimage.grey_dilation(output, footprint=kernel)
|
||||
output = torch.from_numpy(output)
|
||||
out.append(output)
|
||||
return (torch.stack(out, dim=0),)
|
||||
return IO.NodeOutput(torch.stack(out, dim=0))
|
||||
|
||||
class ThresholdMask:
|
||||
expand_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class ThresholdMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"mask": ("MASK",),
|
||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ThresholdMask",
|
||||
category="mask",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[IO.Mask.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
FUNCTION = "image_to_mask"
|
||||
|
||||
def image_to_mask(self, mask, value):
|
||||
@classmethod
|
||||
def execute(cls, mask, value) -> IO.NodeOutput:
|
||||
mask = (mask > value).float()
|
||||
return (mask,)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
image_to_mask = execute # TODO: remove
|
||||
|
||||
|
||||
# Mask Preview - original implement from
|
||||
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
|
||||
# upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes
|
||||
class MaskPreview(nodes.SaveImage):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
self.type = "temp"
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
self.compress_level = 4
|
||||
class MaskPreview(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="MaskPreview",
|
||||
display_name="Preview Mask",
|
||||
category="mask",
|
||||
description="Saves the input images to your ComfyUI output directory.",
|
||||
inputs=[
|
||||
IO.Mask.Input("mask"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {"mask": ("MASK",), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
|
||||
FUNCTION = "execute"
|
||||
CATEGORY = "mask"
|
||||
|
||||
def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||
preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3)
|
||||
return self.save_images(preview, filename_prefix, prompt, extra_pnginfo)
|
||||
def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(ui=UI.PreviewMask(mask))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LatentCompositeMasked": LatentCompositeMasked,
|
||||
"ImageCompositeMasked": ImageCompositeMasked,
|
||||
"MaskToImage": MaskToImage,
|
||||
"ImageToMask": ImageToMask,
|
||||
"ImageColorToMask": ImageColorToMask,
|
||||
"SolidMask": SolidMask,
|
||||
"InvertMask": InvertMask,
|
||||
"CropMask": CropMask,
|
||||
"MaskComposite": MaskComposite,
|
||||
"FeatherMask": FeatherMask,
|
||||
"GrowMask": GrowMask,
|
||||
"ThresholdMask": ThresholdMask,
|
||||
"MaskPreview": MaskPreview
|
||||
}
|
||||
class MaskExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
LatentCompositeMasked,
|
||||
ImageCompositeMasked,
|
||||
MaskToImage,
|
||||
ImageToMask,
|
||||
ImageColorToMask,
|
||||
SolidMask,
|
||||
InvertMask,
|
||||
CropMask,
|
||||
MaskComposite,
|
||||
FeatherMask,
|
||||
GrowMask,
|
||||
ThresholdMask,
|
||||
MaskPreview,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"ImageToMask": "Convert Image to Mask",
|
||||
"MaskToImage": "Convert Mask to Image",
|
||||
}
|
||||
|
||||
async def comfy_entrypoint() -> MaskExtension:
|
||||
return MaskExtension()
|
||||
|
||||
@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode):
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
# Sampling
|
||||
"PatchModelAddDownscale": "",
|
||||
}
|
||||
|
||||
class ModelDownscaleExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
|
||||
@ -243,7 +243,13 @@ class ModelPatchLoader:
|
||||
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
|
||||
sd = z_image_convert(sd)
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
|
||||
config = {}
|
||||
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
|
||||
config['n_control_layers'] = 15
|
||||
config['additional_in_dim'] = 17
|
||||
config['refiner_control'] = True
|
||||
config['broken'] = True
|
||||
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
|
||||
|
||||
model.load_state_dict(sd)
|
||||
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
|
||||
@ -297,56 +303,86 @@ class DiffSynthCnetPatch:
|
||||
return [self.model_patch]
|
||||
|
||||
class ZImageControlPatch:
|
||||
def __init__(self, model_patch, vae, image, strength):
|
||||
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
|
||||
self.model_patch = model_patch
|
||||
self.vae = vae
|
||||
self.image = image
|
||||
self.inpaint_image = inpaint_image
|
||||
self.mask = mask
|
||||
self.strength = strength
|
||||
self.encoded_image = self.encode_latent_cond(image)
|
||||
self.encoded_image_size = (image.shape[1], image.shape[2])
|
||||
self.temp_data = None
|
||||
|
||||
def encode_latent_cond(self, image):
|
||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
|
||||
return latent_image
|
||||
def encode_latent_cond(self, control_image, inpaint_image=None):
|
||||
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
|
||||
if self.model_patch.model.additional_in_dim > 0:
|
||||
if self.mask is None:
|
||||
mask_ = torch.zeros_like(latent_image)[:, :1]
|
||||
else:
|
||||
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
|
||||
if inpaint_image is None:
|
||||
inpaint_image = torch.ones_like(control_image) * 0.5
|
||||
|
||||
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
|
||||
|
||||
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
|
||||
else:
|
||||
return latent_image
|
||||
|
||||
def __call__(self, kwargs):
|
||||
x = kwargs.get("x")
|
||||
img = kwargs.get("img")
|
||||
img_input = kwargs.get("img_input")
|
||||
txt = kwargs.get("txt")
|
||||
pe = kwargs.get("pe")
|
||||
vec = kwargs.get("vec")
|
||||
block_index = kwargs.get("block_index")
|
||||
block_type = kwargs.get("block_type", "")
|
||||
spacial_compression = self.vae.spacial_compression_encode()
|
||||
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
|
||||
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
|
||||
inpaint_scaled = None
|
||||
if self.inpaint_image is not None:
|
||||
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
|
||||
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
|
||||
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
|
||||
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
|
||||
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
|
||||
comfy.model_management.load_models_gpu(loaded_models)
|
||||
|
||||
cnet_index = (block_index // 5)
|
||||
cnet_index_float = (block_index / 5)
|
||||
cnet_blocks = self.model_patch.model.n_control_layers
|
||||
div = round(30 / cnet_blocks)
|
||||
|
||||
cnet_index = (block_index // div)
|
||||
cnet_index_float = (block_index / div)
|
||||
|
||||
kwargs.pop("img") # we do ops in place
|
||||
kwargs.pop("txt")
|
||||
|
||||
cnet_blocks = self.model_patch.model.n_control_layers
|
||||
if cnet_index_float > (cnet_blocks - 1):
|
||||
self.temp_data = None
|
||||
return kwargs
|
||||
|
||||
if self.temp_data is None or self.temp_data[0] > cnet_index:
|
||||
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||
if block_type == "noise_refiner":
|
||||
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||
else:
|
||||
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
|
||||
|
||||
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
||||
if block_type == "noise_refiner":
|
||||
next_layer = self.temp_data[0] + 1
|
||||
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||
if self.temp_data[1][0] is not None:
|
||||
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||
else:
|
||||
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
|
||||
next_layer = self.temp_data[0] + 1
|
||||
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
|
||||
|
||||
if cnet_index_float == self.temp_data[0]:
|
||||
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||
if cnet_blocks == self.temp_data[0] + 1:
|
||||
self.temp_data = None
|
||||
if cnet_index_float == self.temp_data[0]:
|
||||
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
|
||||
if cnet_blocks == self.temp_data[0] + 1:
|
||||
self.temp_data = None
|
||||
|
||||
return kwargs
|
||||
|
||||
@ -386,7 +422,9 @@ class QwenImageDiffsynthControlnet:
|
||||
mask = 1.0 - mask
|
||||
|
||||
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
|
||||
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
|
||||
patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
|
||||
model_patched.set_model_noise_refiner_patch(patch)
|
||||
model_patched.set_model_double_block_patch(patch)
|
||||
else:
|
||||
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
|
||||
return (model_patched,)
|
||||
|
||||
@ -2,6 +2,8 @@ from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||
|
||||
def skip_torch_compile_dict(guard_entries):
|
||||
return [("transformer_options" not in entry.name) for entry in guard_entries]
|
||||
|
||||
class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model, backend) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
set_torch_compile_wrapper(model=m, backend=backend)
|
||||
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
|
||||
@ -8,10 +8,7 @@ import json
|
||||
from typing import Optional
|
||||
from typing_extensions import override
|
||||
from fractions import Fraction
|
||||
from comfy_api.input import AudioInput, ImageInput, VideoInput
|
||||
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
|
||||
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
|
||||
from comfy_api.latest import ComfyExtension, io, ui
|
||||
from comfy_api.latest import ComfyExtension, io, ui, Input, InputImpl, Types
|
||||
from comfy.cli_args import args
|
||||
|
||||
class SaveWEBM(io.ComfyNode):
|
||||
@ -28,7 +25,6 @@ class SaveWEBM(io.ComfyNode):
|
||||
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
|
||||
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
@ -79,16 +75,15 @@ class SaveVideo(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to save."),
|
||||
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
|
||||
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
|
||||
io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput:
|
||||
width, height = video.get_dimensions()
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
filename_prefix,
|
||||
@ -105,10 +100,10 @@ class SaveVideo(io.ComfyNode):
|
||||
metadata["prompt"] = cls.hidden.prompt
|
||||
if len(metadata) > 0:
|
||||
saved_metadata = metadata
|
||||
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
|
||||
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
|
||||
video.save_to(
|
||||
os.path.join(full_output_folder, file),
|
||||
format=VideoContainer(format),
|
||||
format=Types.VideoContainer(format),
|
||||
codec=codec,
|
||||
metadata=saved_metadata
|
||||
)
|
||||
@ -135,9 +130,9 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -159,11 +154,11 @@ class GetVideoComponents(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: VideoInput) -> io.NodeOutput:
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -185,7 +180,7 @@ class LoadVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, file) -> io.NodeOutput:
|
||||
video_path = folder_paths.get_annotated_filepath(file)
|
||||
return io.NodeOutput(VideoFromFile(video_path))
|
||||
return io.NodeOutput(InputImpl.VideoFromFile(video_path))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(s, file):
|
||||
|
||||
535
comfy_extras/nodes_wanmove.py
Normal file
535
comfy_extras/nodes_wanmove.py
Normal file
@ -0,0 +1,535 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import torchvision.transforms.functional as TF
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy_extras.nodes_wan import parse_json_tracks
|
||||
|
||||
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
SKIP_ZERO = False
|
||||
|
||||
def get_pos_emb(
|
||||
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
|
||||
pos_emb_dim: int,
|
||||
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
|
||||
device: torch.device = torch.device("cpu"),
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
|
||||
|
||||
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
|
||||
pos_k = pos_k.to(device, dtype)
|
||||
if SKIP_ZERO:
|
||||
pos_k = pos_k + 1
|
||||
batch_size = pos_k.size(0)
|
||||
|
||||
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
|
||||
# Expand denominator to match the shape needed for broadcasting
|
||||
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
|
||||
|
||||
thetas = theta_func(denominator_expanded, pos_emb_dim)
|
||||
|
||||
# Ensure pos_k is in the correct shape for broadcasting
|
||||
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
|
||||
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
|
||||
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
|
||||
|
||||
# Concatenate sine and cosine embeddings along the last dimension
|
||||
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
|
||||
|
||||
return pos_emb
|
||||
|
||||
def create_pos_embeddings(
|
||||
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
|
||||
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
|
||||
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
|
||||
height: int, # the height of the feature map
|
||||
width: int, # the width of the feature map
|
||||
track_num: int = -1, # the number of tracks to use
|
||||
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
|
||||
):
|
||||
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
|
||||
|
||||
t, n, _ = pred_tracks.shape
|
||||
t_down, h_down, w_down = downsample_ratios
|
||||
track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
|
||||
|
||||
if track_num == -1:
|
||||
track_num = n
|
||||
|
||||
tracks_idx = torch.randperm(n)[:track_num]
|
||||
tracks = pred_tracks[:, tracks_idx]
|
||||
visibility = pred_visibility[:, tracks_idx]
|
||||
|
||||
for t_idx in range(0, t, t_down):
|
||||
if t_down_strategy == "sample" or t_idx == 0:
|
||||
cur_tracks = tracks[t_idx] # [N, 2]
|
||||
cur_visibility = visibility[t_idx] # [N]
|
||||
else:
|
||||
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
|
||||
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
|
||||
|
||||
for i in range(track_num):
|
||||
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
|
||||
continue
|
||||
x, y = cur_tracks[i]
|
||||
x, y = int(x // w_down), int(y // h_down)
|
||||
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
|
||||
|
||||
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
|
||||
|
||||
def replace_feature(
|
||||
vae_feature: torch.Tensor, # [B, C', T', H', W']
|
||||
track_pos: torch.Tensor, # [B, N, T', 2]
|
||||
strength: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
b, _, t, h, w = vae_feature.shape
|
||||
assert b == track_pos.shape[0], "Batch size mismatch."
|
||||
n = track_pos.shape[1]
|
||||
|
||||
# Shuffle the trajectory order
|
||||
track_pos = track_pos[:, torch.randperm(n), :, :]
|
||||
|
||||
# Extract coordinates at time steps ≥ 1 and generate a valid mask
|
||||
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
|
||||
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
|
||||
|
||||
# Get all valid indices
|
||||
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
|
||||
num_valid = valid_indices.shape[0]
|
||||
|
||||
if num_valid == 0:
|
||||
return vae_feature
|
||||
|
||||
# Decompose valid indices into each dimension
|
||||
batch_idx = valid_indices[:, 0]
|
||||
track_idx = valid_indices[:, 1]
|
||||
t_rel = valid_indices[:, 2]
|
||||
t_target = t_rel + 1 # Convert to original time step indices
|
||||
|
||||
# Extract target position coordinates
|
||||
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
|
||||
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
|
||||
|
||||
# Extract source position coordinates (t=0)
|
||||
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
|
||||
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
|
||||
|
||||
# Get source features and assign to target positions
|
||||
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
|
||||
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
|
||||
|
||||
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
|
||||
|
||||
|
||||
return vae_feature
|
||||
|
||||
# Visualize functions
|
||||
|
||||
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
|
||||
draw = ImageDraw.Draw(overlay, 'RGBA')
|
||||
points = points[::-1]
|
||||
|
||||
# Compute total length
|
||||
total_length = 0
|
||||
segment_lengths = []
|
||||
for i in range(len(points) - 1):
|
||||
dx = points[i + 1][0] - points[i][0]
|
||||
dy = points[i + 1][1] - points[i][1]
|
||||
length = (dx * dx + dy * dy) ** 0.5
|
||||
segment_lengths.append(length)
|
||||
total_length += length
|
||||
|
||||
if total_length == 0:
|
||||
return
|
||||
|
||||
accumulated_length = 0
|
||||
|
||||
# Draw the gradient polyline
|
||||
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
|
||||
segment_length = segment_lengths[idx]
|
||||
steps = max(int(segment_length), 1)
|
||||
|
||||
for i in range(steps):
|
||||
current_length = accumulated_length + (i / steps) * segment_length
|
||||
ratio = current_length / total_length
|
||||
|
||||
alpha = int(255 * (1 - ratio) * opacity)
|
||||
color = (*start_color, alpha)
|
||||
|
||||
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
|
||||
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
|
||||
|
||||
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
|
||||
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
|
||||
|
||||
accumulated_length += segment_length
|
||||
|
||||
|
||||
def add_weighted(rgb, track):
|
||||
rgb = np.array(rgb) # [H, W, C] "RGB"
|
||||
track = np.array(track) # [H, W, C] "RGBA"
|
||||
|
||||
alpha = track[:, :, 3] / 255.0
|
||||
alpha = np.stack([alpha] * 3, axis=-1)
|
||||
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
|
||||
|
||||
return Image.fromarray(blend_img.astype(np.uint8))
|
||||
|
||||
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
|
||||
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
|
||||
|
||||
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
|
||||
tracks = tracks[0].long().detach().cpu().numpy()
|
||||
if visibility is not None:
|
||||
visibility = visibility[0].detach().cpu().numpy()
|
||||
|
||||
num_frames, height, width = video.shape[:3]
|
||||
num_tracks = tracks.shape[1]
|
||||
alpha_opacity = int(255 * opacity)
|
||||
|
||||
output_frames = []
|
||||
for t in range(num_frames):
|
||||
frame_rgb = video[t].astype(np.float32)
|
||||
|
||||
# Create a single RGBA overlay for all tracks in this frame
|
||||
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
draw_overlay = ImageDraw.Draw(overlay)
|
||||
|
||||
polyline_data = []
|
||||
|
||||
# Draw all circles on a single overlay
|
||||
for n in range(num_tracks):
|
||||
if visibility is not None and visibility[t, n] == 0:
|
||||
continue
|
||||
|
||||
track_coord = tracks[t, n]
|
||||
color = color_map[n % len(color_map)]
|
||||
circle_color = color + (alpha_opacity,)
|
||||
|
||||
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
|
||||
fill=circle_color
|
||||
)
|
||||
|
||||
# Store polyline data for batch processing
|
||||
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
|
||||
if len(tracks_coord) > 1:
|
||||
polyline_data.append((tracks_coord, color))
|
||||
|
||||
# Blend circles overlay once
|
||||
overlay_np = np.array(overlay)
|
||||
alpha = overlay_np[:, :, 3:4] / 255.0
|
||||
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
|
||||
|
||||
# Draw all polylines on a single overlay
|
||||
if polyline_data:
|
||||
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
for tracks_coord, color in polyline_data:
|
||||
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
|
||||
|
||||
# Blend polylines overlay once
|
||||
polyline_np = np.array(polyline_overlay)
|
||||
alpha = polyline_np[:, :, 3:4] / 255.0
|
||||
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
|
||||
|
||||
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
|
||||
|
||||
return output_frames
|
||||
|
||||
|
||||
class WanMoveVisualizeTracks(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanMoveVisualizeTracks",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.Tracks.Input("tracks", optional=True),
|
||||
io.Int.Input("line_resolution", default=24, min=1, max=1024),
|
||||
io.Int.Input("circle_size", default=12, min=1, max=128),
|
||||
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
|
||||
io.Int.Input("line_width", default=16, min=1, max=128),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
|
||||
if tracks is None:
|
||||
return io.NodeOutput(images)
|
||||
|
||||
track_path = tracks["track_path"].unsqueeze(0)
|
||||
track_visibility = tracks["track_visibility"].unsqueeze(0)
|
||||
images_in = images * 255.0
|
||||
if images_in.shape[0] != track_path.shape[1]:
|
||||
repeat_count = track_path.shape[1] // images.shape[0]
|
||||
images_in = images_in.repeat(repeat_count, 1, 1, 1)
|
||||
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
|
||||
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
|
||||
|
||||
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
|
||||
|
||||
|
||||
class WanMoveTracksFromCoords(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanMoveTracksFromCoords",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
|
||||
io.Mask.Input("track_mask", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Tracks.Output(),
|
||||
io.Int.Output(display_name="track_length"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
|
||||
device=comfy.model_management.intermediate_device()
|
||||
|
||||
tracks_data = parse_json_tracks(track_coords)
|
||||
track_length = len(tracks_data[0])
|
||||
|
||||
track_list = [
|
||||
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
|
||||
for frame in range(len(tracks_data[0]))
|
||||
]
|
||||
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
|
||||
|
||||
num_tracks = tracks.shape[-2]
|
||||
if track_mask is None:
|
||||
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
|
||||
else:
|
||||
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
|
||||
|
||||
out_track_info = {}
|
||||
out_track_info["track_path"] = tracks
|
||||
out_track_info["track_visibility"] = track_visibility
|
||||
return io.NodeOutput(out_track_info, track_length)
|
||||
|
||||
|
||||
class GenerateTracks(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GenerateTracks",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=4096, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=4096, step=16),
|
||||
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
|
||||
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
|
||||
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
|
||||
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
|
||||
io.Int.Input("num_frames", default=81, min=1, max=1024),
|
||||
io.Int.Input("num_tracks", default=5, min=1, max=100),
|
||||
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
|
||||
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
|
||||
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
|
||||
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
|
||||
io.Combo.Input(
|
||||
"interpolation",
|
||||
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
|
||||
tooltip="Controls the timing/speed of movement along the path.",
|
||||
),
|
||||
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
|
||||
],
|
||||
outputs=[
|
||||
io.Tracks.Output(),
|
||||
io.Int.Output(display_name="track_length"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
|
||||
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
|
||||
device = comfy.model_management.intermediate_device()
|
||||
track_length = num_frames
|
||||
|
||||
# normalized coordinates to pixel coordinates
|
||||
start_x_px = start_x * width
|
||||
start_y_px = start_y * height
|
||||
mid_x_px = mid_x * width
|
||||
mid_y_px = mid_y * height
|
||||
end_x_px = end_x * width
|
||||
end_y_px = end_y * height
|
||||
|
||||
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
|
||||
|
||||
t = torch.linspace(0, 1, num_frames, device=device)
|
||||
if interpolation == "constant": # All points stay at start position
|
||||
interp_values = torch.zeros_like(t)
|
||||
elif interpolation == "linear":
|
||||
interp_values = t
|
||||
elif interpolation == "ease_in":
|
||||
interp_values = t ** 2
|
||||
elif interpolation == "ease_out":
|
||||
interp_values = 1 - (1 - t) ** 2
|
||||
elif interpolation == "ease_in_out":
|
||||
interp_values = t * t * (3 - 2 * t)
|
||||
|
||||
if bezier: # apply interpolation to t for timing control along the bezier path
|
||||
t_interp = interp_values
|
||||
one_minus_t = 1 - t_interp
|
||||
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
|
||||
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
|
||||
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
|
||||
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
|
||||
else: # calculate base x and y positions for each frame (center track)
|
||||
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
|
||||
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
|
||||
# For non-bezier, tangent is constant (direction from start to end)
|
||||
tangent_x = torch.full_like(t, end_x_px - start_x_px)
|
||||
tangent_y = torch.full_like(t, end_y_px - start_y_px)
|
||||
|
||||
track_list = []
|
||||
for frame_idx in range(num_frames):
|
||||
# Calculate perpendicular direction at this frame
|
||||
tx = tangent_x[frame_idx].item()
|
||||
ty = tangent_y[frame_idx].item()
|
||||
length = (tx ** 2 + ty ** 2) ** 0.5
|
||||
|
||||
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
|
||||
perp_x = -ty / length
|
||||
perp_y = tx / length
|
||||
else: # If tangent is zero, spread horizontally
|
||||
perp_x = 1.0
|
||||
perp_y = 0.0
|
||||
|
||||
frame_tracks = []
|
||||
for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
|
||||
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
|
||||
track_x = x_positions[frame_idx].item() + perp_x * offset
|
||||
track_y = y_positions[frame_idx].item() + perp_y * offset
|
||||
frame_tracks.append([track_x, track_y])
|
||||
track_list.append(frame_tracks)
|
||||
|
||||
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
|
||||
|
||||
if track_mask is None:
|
||||
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
|
||||
else:
|
||||
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
|
||||
|
||||
out_track_info = {}
|
||||
out_track_info["track_path"] = tracks
|
||||
out_track_info["track_visibility"] = track_visibility
|
||||
return io.NodeOutput(out_track_info, track_length)
|
||||
|
||||
|
||||
class WanMoveConcatTrack(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanMoveConcatTrack",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Tracks.Input("tracks_1"),
|
||||
io.Tracks.Input("tracks_2", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Tracks.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
|
||||
if tracks_2 is None:
|
||||
return io.NodeOutput(tracks_1)
|
||||
|
||||
tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
|
||||
mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
|
||||
|
||||
out_track_info = {}
|
||||
out_track_info["track_path"] = tracks_out
|
||||
out_track_info["track_visibility"] = mask_out
|
||||
return io.NodeOutput(out_track_info)
|
||||
|
||||
|
||||
class WanMoveTrackToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanMoveTrackToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Tracks.Input("tracks", optional=True),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image"),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
device=comfy.model_management.intermediate_device()
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
|
||||
image[:start_image.shape[0]] = start_image
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
if tracks is not None and strength > 0.0:
|
||||
tracks_path = tracks["track_path"][:length] # [T, N, 2]
|
||||
num_tracks = tracks_path.shape[-2]
|
||||
|
||||
track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
|
||||
|
||||
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
|
||||
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
|
||||
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
|
||||
else:
|
||||
concat_latent_image_pos = concat_latent_image
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanMoveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
WanMoveTrackToVideo,
|
||||
WanMoveTracksFromCoords,
|
||||
WanMoveConcatTrack,
|
||||
WanMoveVisualizeTracks,
|
||||
GenerateTracks,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanMoveExtension:
|
||||
return WanMoveExtension()
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.76"
|
||||
__version__ = "0.4.0"
|
||||
|
||||
@ -63,18 +63,22 @@ def cuda_malloc_supported():
|
||||
return True
|
||||
|
||||
|
||||
version = ""
|
||||
|
||||
try:
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
ver_file = os.path.join(folder, "version.py")
|
||||
if os.path.isfile(ver_file):
|
||||
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
except:
|
||||
pass
|
||||
|
||||
if not args.cuda_malloc:
|
||||
try:
|
||||
version = ""
|
||||
torch_spec = importlib.util.find_spec("torch")
|
||||
for folder in torch_spec.submodule_search_locations:
|
||||
ver_file = os.path.join(folder, "version.py")
|
||||
if os.path.isfile(ver_file):
|
||||
spec = importlib.util.spec_from_file_location("torch_version_import", ver_file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
version = module.__version__
|
||||
|
||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc:
|
||||
env_var += ",backend:cudaMallocAsync"
|
||||
|
||||
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var
|
||||
|
||||
def get_torch_version_noimport():
|
||||
return str(version)
|
||||
|
||||
@ -13,6 +13,7 @@ import asyncio
|
||||
import torch
|
||||
|
||||
import comfy.model_management
|
||||
from latent_preview import set_preview_method
|
||||
import nodes
|
||||
from comfy_execution.caching import (
|
||||
BasicCache,
|
||||
@ -669,6 +670,8 @@ class PromptExecutor:
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
|
||||
nodes.interrupt_processing(False)
|
||||
|
||||
if "client_id" in extra_data:
|
||||
|
||||
@ -8,6 +8,8 @@ import folder_paths
|
||||
import comfy.utils
|
||||
import logging
|
||||
|
||||
default_preview_method = args.preview_method
|
||||
|
||||
MAX_PREVIEW_RESOLUTION = args.preview_size
|
||||
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
|
||||
|
||||
@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
|
||||
pbar.update_absolute(step + 1, total_steps, preview_bytes)
|
||||
return callback
|
||||
|
||||
def set_preview_method(override: str = None):
|
||||
if override and override != "default":
|
||||
method = LatentPreviewMethod.from_string(override)
|
||||
if method is not None:
|
||||
args.preview_method = method
|
||||
return
|
||||
args.preview_method = default_preview_method
|
||||
|
||||
|
||||
3
main.py
3
main.py
@ -167,6 +167,9 @@ if __name__ == "__main__":
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
|
||||
|
||||
import cuda_malloc
|
||||
if "rocm" in cuda_malloc.get_torch_version_noimport():
|
||||
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
|
||||
|
||||
|
||||
if 'torch' in sys.modules:
|
||||
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.0.3b3
|
||||
comfyui_manager==4.0.3b5
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -970,7 +970,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -2357,6 +2357,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_rope.py",
|
||||
"nodes_logic.py",
|
||||
"nodes_nop.py",
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@ -2382,7 +2384,6 @@ async def init_builtin_api_nodes():
|
||||
"nodes_recraft.py",
|
||||
"nodes_pixverse.py",
|
||||
"nodes_stability.py",
|
||||
"nodes_pika.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_sora.py",
|
||||
"nodes_topaz.py",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.76"
|
||||
version = "0.4.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.33.10
|
||||
comfyui-workflow-templates==0.7.25
|
||||
comfyui-frontend-package==1.34.8
|
||||
comfyui-workflow-templates==0.7.59
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
|
||||
@ -2,6 +2,7 @@ import unittest
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
|
||||
# Add comfy to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
@ -15,6 +16,7 @@ if not has_gpu():
|
||||
|
||||
from comfy import ops
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class SimpleModel(torch.nn.Module):
|
||||
@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
|
||||
}
|
||||
|
||||
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
# Create model and load state dict (strict=False because custom loading pops keys)
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Verify weights are wrapped in QuantizedTensor
|
||||
@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
|
||||
# Forward pass
|
||||
input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
|
||||
output = model(input_tensor)
|
||||
with torch.inference_mode():
|
||||
output = model(input_tensor)
|
||||
|
||||
self.assertEqual(output.shape, (5, 40))
|
||||
|
||||
@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
model.load_state_dict(state_dict1, strict=False)
|
||||
|
||||
# Save state dict
|
||||
@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Add a weight function (simulating LoRA)
|
||||
@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase):
|
||||
"layer3.bias": torch.randn(40, dtype=torch.bfloat16),
|
||||
}
|
||||
|
||||
state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
|
||||
|
||||
# Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config))
|
||||
model = SimpleModel(operations=ops.mixed_precision_ops({}))
|
||||
with self.assertRaises(KeyError):
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
352
tests-unit/execution_test/preview_method_override_test.py
Normal file
352
tests-unit/execution_test/preview_method_override_test.py
Normal file
@ -0,0 +1,352 @@
|
||||
"""
|
||||
Unit tests for Queue-specific Preview Method Override feature.
|
||||
|
||||
Tests the preview method override functionality:
|
||||
- LatentPreviewMethod.from_string() method
|
||||
- set_preview_method() function in latent_preview.py
|
||||
- default_preview_method variable
|
||||
- Integration with args.preview_method
|
||||
"""
|
||||
import pytest
|
||||
from comfy.cli_args import args, LatentPreviewMethod
|
||||
from latent_preview import set_preview_method, default_preview_method
|
||||
|
||||
|
||||
class TestLatentPreviewMethodFromString:
|
||||
"""Test LatentPreviewMethod.from_string() classmethod."""
|
||||
|
||||
@pytest.mark.parametrize("value,expected", [
|
||||
("auto", LatentPreviewMethod.Auto),
|
||||
("latent2rgb", LatentPreviewMethod.Latent2RGB),
|
||||
("taesd", LatentPreviewMethod.TAESD),
|
||||
("none", LatentPreviewMethod.NoPreviews),
|
||||
])
|
||||
def test_valid_values_return_enum(self, value, expected):
|
||||
"""Valid string values should return corresponding enum."""
|
||||
assert LatentPreviewMethod.from_string(value) == expected
|
||||
|
||||
@pytest.mark.parametrize("invalid", [
|
||||
"invalid",
|
||||
"TAESD", # Case sensitive
|
||||
"AUTO", # Case sensitive
|
||||
"Latent2RGB", # Case sensitive
|
||||
"latent",
|
||||
"",
|
||||
"default", # default is special, not a method
|
||||
])
|
||||
def test_invalid_values_return_none(self, invalid):
|
||||
"""Invalid string values should return None."""
|
||||
assert LatentPreviewMethod.from_string(invalid) is None
|
||||
|
||||
|
||||
class TestLatentPreviewMethodEnumValues:
|
||||
"""Test LatentPreviewMethod enum has expected values."""
|
||||
|
||||
def test_enum_values(self):
|
||||
"""Verify enum values match expected strings."""
|
||||
assert LatentPreviewMethod.NoPreviews.value == "none"
|
||||
assert LatentPreviewMethod.Auto.value == "auto"
|
||||
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
|
||||
assert LatentPreviewMethod.TAESD.value == "taesd"
|
||||
|
||||
def test_enum_count(self):
|
||||
"""Verify exactly 4 preview methods exist."""
|
||||
assert len(LatentPreviewMethod) == 4
|
||||
|
||||
|
||||
class TestSetPreviewMethod:
|
||||
"""Test set_preview_method() function from latent_preview.py."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_override_with_taesd(self):
|
||||
"""'taesd' should set args.preview_method to TAESD."""
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
def test_override_with_latent2rgb(self):
|
||||
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
|
||||
set_preview_method("latent2rgb")
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
|
||||
def test_override_with_auto(self):
|
||||
"""'auto' should set args.preview_method to Auto."""
|
||||
set_preview_method("auto")
|
||||
assert args.preview_method == LatentPreviewMethod.Auto
|
||||
|
||||
def test_override_with_none_value(self):
|
||||
"""'none' should set args.preview_method to NoPreviews."""
|
||||
set_preview_method("none")
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
def test_default_restores_original(self):
|
||||
"""'default' should restore to default_preview_method."""
|
||||
# First override to something else
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Then use 'default' to restore
|
||||
set_preview_method("default")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_none_param_restores_original(self):
|
||||
"""None parameter should restore to default_preview_method."""
|
||||
# First override to something else
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Then use None to restore
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_empty_string_restores_original(self):
|
||||
"""Empty string should restore to default_preview_method."""
|
||||
set_preview_method("taesd")
|
||||
set_preview_method("")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_invalid_value_restores_original(self):
|
||||
"""Invalid value should restore to default_preview_method."""
|
||||
set_preview_method("taesd")
|
||||
set_preview_method("invalid_method")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_case_sensitive_invalid_restores(self):
|
||||
"""Case-mismatched values should restore to default."""
|
||||
set_preview_method("taesd")
|
||||
set_preview_method("TAESD") # Wrong case
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
|
||||
class TestDefaultPreviewMethod:
|
||||
"""Test default_preview_method module variable."""
|
||||
|
||||
def test_default_is_not_none(self):
|
||||
"""default_preview_method should not be None."""
|
||||
assert default_preview_method is not None
|
||||
|
||||
def test_default_is_enum_member(self):
|
||||
"""default_preview_method should be a LatentPreviewMethod enum."""
|
||||
assert isinstance(default_preview_method, LatentPreviewMethod)
|
||||
|
||||
def test_default_matches_args_initial(self):
|
||||
"""default_preview_method should match CLI default or user setting."""
|
||||
# This tests that default_preview_method was captured at module load
|
||||
# After set_preview_method(None), args should equal default
|
||||
original = args.preview_method
|
||||
set_preview_method("taesd")
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
args.preview_method = original
|
||||
|
||||
|
||||
class TestArgsPreviewMethodModification:
|
||||
"""Test args.preview_method can be modified correctly."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_args_accepts_all_enum_values(self):
|
||||
"""args.preview_method should accept all LatentPreviewMethod values."""
|
||||
for method in LatentPreviewMethod:
|
||||
args.preview_method = method
|
||||
assert args.preview_method == method
|
||||
|
||||
def test_args_modification_and_restoration(self):
|
||||
"""args.preview_method should be modifiable and restorable."""
|
||||
original = args.preview_method
|
||||
|
||||
args.preview_method = LatentPreviewMethod.TAESD
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
args.preview_method = original
|
||||
assert args.preview_method == original
|
||||
|
||||
|
||||
class TestExecutionFlow:
|
||||
"""Test the execution flow pattern used in execution.py."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_sequential_executions_with_different_methods(self):
|
||||
"""Simulate multiple queue executions with different preview methods."""
|
||||
# Execution 1: taesd
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Execution 2: none
|
||||
set_preview_method("none")
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
# Execution 3: default (restore)
|
||||
set_preview_method("default")
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Execution 4: auto
|
||||
set_preview_method("auto")
|
||||
assert args.preview_method == LatentPreviewMethod.Auto
|
||||
|
||||
# Execution 5: no override (None)
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_override_then_default_pattern(self):
|
||||
"""Test the pattern: override -> execute -> next call restores."""
|
||||
# First execution with override
|
||||
set_preview_method("latent2rgb")
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
|
||||
# Second execution without override restores default
|
||||
set_preview_method(None)
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_extra_data_simulation(self):
|
||||
"""Simulate extra_data.get('preview_method') patterns."""
|
||||
# Simulate: extra_data = {"preview_method": "taesd"}
|
||||
extra_data = {"preview_method": "taesd"}
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Simulate: extra_data = {}
|
||||
extra_data = {}
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Simulate: extra_data = {"preview_method": "default"}
|
||||
extra_data = {"preview_method": "default"}
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
|
||||
class TestRealWorldScenarios:
|
||||
"""Tests using real-world prompt data patterns."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Store original value before each test."""
|
||||
self.original = args.preview_method
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore original value after each test."""
|
||||
args.preview_method = self.original
|
||||
|
||||
def test_captured_prompt_without_preview_method(self):
|
||||
"""
|
||||
Test with captured prompt that has no preview_method.
|
||||
Based on: tests-unit/execution_test/fixtures/default_prompt.json
|
||||
"""
|
||||
# Real captured extra_data structure (preview_method absent)
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
||||
"create_time": 1765416558179
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_captured_prompt_with_preview_method_taesd(self):
|
||||
"""Test captured prompt with preview_method: taesd."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
|
||||
"preview_method": "taesd"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
def test_captured_prompt_with_preview_method_none(self):
|
||||
"""Test captured prompt with preview_method: none (disable preview)."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "none"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
def test_captured_prompt_with_preview_method_latent2rgb(self):
|
||||
"""Test captured prompt with preview_method: latent2rgb."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "latent2rgb"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
|
||||
def test_captured_prompt_with_preview_method_auto(self):
|
||||
"""Test captured prompt with preview_method: auto."""
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "auto"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.Auto
|
||||
|
||||
def test_captured_prompt_with_preview_method_default(self):
|
||||
"""Test captured prompt with preview_method: default (use CLI setting)."""
|
||||
# First set to something else
|
||||
set_preview_method("taesd")
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Then simulate a prompt with "default"
|
||||
extra_data = {
|
||||
"extra_pnginfo": {"workflow": {}},
|
||||
"client_id": "test-client",
|
||||
"preview_method": "default"
|
||||
}
|
||||
|
||||
set_preview_method(extra_data.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
def test_sequential_queue_with_different_preview_methods(self):
|
||||
"""
|
||||
Simulate real queue scenario: multiple prompts with different settings.
|
||||
This tests the actual usage pattern in ComfyUI.
|
||||
"""
|
||||
# Queue 1: User wants TAESD preview
|
||||
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
|
||||
set_preview_method(extra_data_1.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.TAESD
|
||||
|
||||
# Queue 2: User wants no preview (faster execution)
|
||||
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
|
||||
set_preview_method(extra_data_2.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.NoPreviews
|
||||
|
||||
# Queue 3: User doesn't specify (use server default)
|
||||
extra_data_3 = {"client_id": "client-3"}
|
||||
set_preview_method(extra_data_3.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Queue 4: User explicitly wants default
|
||||
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
|
||||
set_preview_method(extra_data_4.get("preview_method"))
|
||||
assert args.preview_method == default_preview_method
|
||||
|
||||
# Queue 5: User wants latent2rgb
|
||||
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
|
||||
set_preview_method(extra_data_5.get("preview_method"))
|
||||
assert args.preview_method == LatentPreviewMethod.Latent2RGB
|
||||
358
tests/execution/test_preview_method.py
Normal file
358
tests/execution/test_preview_method.py
Normal file
@ -0,0 +1,358 @@
|
||||
"""
|
||||
E2E tests for Queue-specific Preview Method Override feature.
|
||||
|
||||
Tests actual execution with different preview_method values.
|
||||
Requires a running ComfyUI server with models.
|
||||
|
||||
Usage:
|
||||
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
|
||||
|
||||
Note:
|
||||
These tests execute actual image generation and wait for completion.
|
||||
Tests verify preview image transmission based on preview_method setting.
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import pytest
|
||||
import uuid
|
||||
import time
|
||||
import random
|
||||
import websocket
|
||||
import urllib.request
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# Server configuration
|
||||
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
|
||||
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
|
||||
|
||||
# Use existing inference graph fixture
|
||||
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
|
||||
|
||||
|
||||
def is_server_running() -> bool:
|
||||
"""Check if ComfyUI server is running."""
|
||||
try:
|
||||
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
|
||||
with urllib.request.urlopen(request, timeout=2.0):
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
|
||||
"""Prepare graph for testing: randomize seeds and reduce steps."""
|
||||
adapted = json.loads(json.dumps(graph)) # Deep copy
|
||||
for node_id, node in adapted.items():
|
||||
inputs = node.get("inputs", {})
|
||||
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
|
||||
if "seed" in inputs:
|
||||
inputs["seed"] = random.randint(0, 2**32 - 1)
|
||||
if "noise_seed" in inputs:
|
||||
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
|
||||
# Reduce steps for faster testing (default 20 -> 5)
|
||||
if "steps" in inputs:
|
||||
inputs["steps"] = steps
|
||||
return adapted
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
randomize_seed = prepare_graph_for_test
|
||||
|
||||
|
||||
class PreviewMethodClient:
|
||||
"""Client for testing preview_method with WebSocket execution tracking."""
|
||||
|
||||
def __init__(self, server_address: str):
|
||||
self.server_address = server_address
|
||||
self.client_id = str(uuid.uuid4())
|
||||
self.ws = None
|
||||
|
||||
def connect(self):
|
||||
"""Connect to WebSocket."""
|
||||
self.ws = websocket.WebSocket()
|
||||
self.ws.settimeout(120) # 2 minute timeout for sampling
|
||||
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
|
||||
|
||||
def close(self):
|
||||
"""Close WebSocket connection."""
|
||||
if self.ws:
|
||||
self.ws.close()
|
||||
|
||||
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
|
||||
"""Queue a prompt and return response with prompt_id."""
|
||||
data = {
|
||||
"prompt": prompt,
|
||||
"client_id": self.client_id,
|
||||
"extra_data": extra_data or {}
|
||||
}
|
||||
req = urllib.request.Request(
|
||||
f"http://{self.server_address}/prompt",
|
||||
data=json.dumps(data).encode("utf-8"),
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
return json.loads(urllib.request.urlopen(req).read())
|
||||
|
||||
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
|
||||
"""
|
||||
Wait for execution to complete via WebSocket.
|
||||
|
||||
Returns:
|
||||
dict with keys: completed, error, preview_count, execution_time
|
||||
"""
|
||||
result = {
|
||||
"completed": False,
|
||||
"error": None,
|
||||
"preview_count": 0,
|
||||
"execution_time": 0.0
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
self.ws.settimeout(timeout)
|
||||
|
||||
try:
|
||||
while True:
|
||||
out = self.ws.recv()
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
msg_type = message.get("type")
|
||||
data = message.get("data", {})
|
||||
|
||||
if data.get("prompt_id") != prompt_id:
|
||||
continue
|
||||
|
||||
if msg_type == "executing":
|
||||
if data.get("node") is None:
|
||||
# Execution complete
|
||||
result["completed"] = True
|
||||
result["execution_time"] = elapsed
|
||||
break
|
||||
|
||||
elif msg_type == "execution_error":
|
||||
result["error"] = data
|
||||
result["execution_time"] = elapsed
|
||||
break
|
||||
|
||||
elif msg_type == "progress":
|
||||
# Progress update during sampling
|
||||
pass
|
||||
|
||||
elif isinstance(out, bytes):
|
||||
# Binary data = preview image
|
||||
result["preview_count"] += 1
|
||||
|
||||
except websocket.WebSocketTimeoutException:
|
||||
result["error"] = "Timeout waiting for execution"
|
||||
result["execution_time"] = time.time() - start_time
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_graph() -> dict:
|
||||
"""Load the SDXL graph fixture with randomized seed."""
|
||||
with open(GRAPH_FILE) as f:
|
||||
graph = json.load(f)
|
||||
return randomize_seed(graph) # Avoid caching
|
||||
|
||||
|
||||
# Skip all tests if server is not running
|
||||
pytestmark = [
|
||||
pytest.mark.skipif(
|
||||
not is_server_running(),
|
||||
reason=f"ComfyUI server not running at {SERVER_URL}"
|
||||
),
|
||||
pytest.mark.preview_method,
|
||||
pytest.mark.execution,
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Create and connect a test client."""
|
||||
c = PreviewMethodClient(SERVER_HOST)
|
||||
c.connect()
|
||||
yield c
|
||||
c.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graph():
|
||||
"""Load the test graph."""
|
||||
return load_graph()
|
||||
|
||||
|
||||
class TestPreviewMethodExecution:
|
||||
"""Test actual execution with different preview methods."""
|
||||
|
||||
def test_execution_with_latent2rgb(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=latent2rgb.
|
||||
Should complete and potentially receive preview images.
|
||||
"""
|
||||
extra_data = {"preview_method": "latent2rgb"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Should complete (may error if model missing, but that's separate)
|
||||
assert result["completed"] or result["error"] is not None
|
||||
# Execution should take some time (sampling)
|
||||
if result["completed"]:
|
||||
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
|
||||
# latent2rgb should produce previews
|
||||
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_taesd(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=taesd.
|
||||
TAESD provides higher quality previews.
|
||||
"""
|
||||
extra_data = {"preview_method": "taesd"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
assert result["execution_time"] > 0.5
|
||||
# taesd should also produce previews
|
||||
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_none_preview(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=none.
|
||||
No preview images should be generated.
|
||||
"""
|
||||
extra_data = {"preview_method": "none"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
# With "none", should receive no preview images
|
||||
assert result["preview_count"] == 0, \
|
||||
f"Expected no previews with 'none', got {result['preview_count']}"
|
||||
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_with_default(self, client, graph):
|
||||
"""
|
||||
Execute with preview_method=default.
|
||||
Should use server's CLI default setting.
|
||||
"""
|
||||
extra_data = {"preview_method": "default"}
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
def test_execution_without_preview_method(self, client, graph):
|
||||
"""
|
||||
Execute without preview_method in extra_data.
|
||||
Should use server's default preview method.
|
||||
"""
|
||||
extra_data = {} # No preview_method
|
||||
|
||||
response = client.queue_prompt(graph, extra_data)
|
||||
assert "prompt_id" in response
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
assert result["completed"] or result["error"] is not None
|
||||
if result["completed"]:
|
||||
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
|
||||
|
||||
|
||||
class TestPreviewMethodComparison:
|
||||
"""Compare preview behavior between different methods."""
|
||||
|
||||
def test_none_vs_latent2rgb_preview_count(self, client, graph):
|
||||
"""
|
||||
Compare preview counts: 'none' should have 0, others should have >0.
|
||||
This is the key verification that preview_method actually works.
|
||||
"""
|
||||
results = {}
|
||||
|
||||
# Run with none (randomize seed to avoid caching)
|
||||
graph_none = randomize_seed(graph)
|
||||
extra_data_none = {"preview_method": "none"}
|
||||
response = client.queue_prompt(graph_none, extra_data_none)
|
||||
results["none"] = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Run with latent2rgb (randomize seed again)
|
||||
graph_rgb = randomize_seed(graph)
|
||||
extra_data_rgb = {"preview_method": "latent2rgb"}
|
||||
response = client.queue_prompt(graph_rgb, extra_data_rgb)
|
||||
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
|
||||
|
||||
# Verify both completed
|
||||
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
|
||||
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
|
||||
|
||||
# Key assertion: 'none' should have 0 previews
|
||||
assert results["none"]["preview_count"] == 0, \
|
||||
f"'none' should have 0 previews, got {results['none']['preview_count']}"
|
||||
|
||||
# 'latent2rgb' should have at least 1 preview (depends on steps)
|
||||
assert results["latent2rgb"]["preview_count"] > 0, \
|
||||
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
|
||||
|
||||
print("\nPreview count comparison:") # noqa: T201
|
||||
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
|
||||
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
|
||||
|
||||
|
||||
class TestPreviewMethodSequential:
|
||||
"""Test sequential execution with different preview methods."""
|
||||
|
||||
def test_sequential_different_methods(self, client, graph):
|
||||
"""
|
||||
Execute multiple prompts sequentially with different preview methods.
|
||||
Each should complete independently with correct preview behavior.
|
||||
"""
|
||||
methods = ["latent2rgb", "none", "default"]
|
||||
results = []
|
||||
|
||||
for method in methods:
|
||||
# Randomize seed for each execution to avoid caching
|
||||
graph_run = randomize_seed(graph)
|
||||
extra_data = {"preview_method": method}
|
||||
response = client.queue_prompt(graph_run, extra_data)
|
||||
|
||||
result = client.wait_for_execution(response["prompt_id"])
|
||||
results.append({
|
||||
"method": method,
|
||||
"completed": result["completed"],
|
||||
"preview_count": result["preview_count"],
|
||||
"execution_time": result["execution_time"],
|
||||
"error": result["error"]
|
||||
})
|
||||
|
||||
# All should complete or have clear errors
|
||||
for r in results:
|
||||
assert r["completed"] or r["error"] is not None, \
|
||||
f"Method {r['method']} neither completed nor errored"
|
||||
|
||||
# "none" should have zero previews if completed
|
||||
none_result = next(r for r in results if r["method"] == "none")
|
||||
if none_result["completed"]:
|
||||
assert none_result["preview_count"] == 0, \
|
||||
f"'none' should have 0 previews, got {none_result['preview_count']}"
|
||||
|
||||
print("\nSequential execution results:") # noqa: T201
|
||||
for r in results:
|
||||
status = "✓" if r["completed"] else f"✗ ({r['error']})"
|
||||
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201
|
||||
Loading…
Reference in New Issue
Block a user