Compare commits

..

No commits in common. "master" and "v0.3.76" have entirely different histories.

123 changed files with 2629 additions and 5951 deletions

View File

@ -53,16 +53,6 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
except:
print("Could not stash, cleaning index and trying again.") # noqa: T201
repo.state_cleanup()
repo.index.read_tree(repo.head.peel().tree)
repo.index.write()
try:
repo.stash(ident)
except KeyError:
print("nothing to stash.") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:
@ -76,10 +66,8 @@ if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
print("fetching.") # noqa: T201
for remote in repo.remotes:
if remote.name == "origin":
remote.fetch()
print("pulling.") # noqa: T201
pull(repo)
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
@ -161,4 +149,3 @@ try:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass

View File

@ -5,7 +5,6 @@ on:
push:
branches:
- master
- release/**
paths-ignore:
- 'app/**'
- 'input/**'

View File

@ -2,9 +2,9 @@ name: Execution Tests
on:
push:
branches: [ main, master, release/** ]
branches: [ main, master ]
pull_request:
branches: [ main, master, release/** ]
branches: [ main, master ]
jobs:
test:

View File

@ -2,9 +2,9 @@ name: Test server launches without errors
on:
push:
branches: [ main, master, release/** ]
branches: [ main, master ]
pull_request:
branches: [ main, master, release/** ]
branches: [ main, master ]
jobs:
test:

View File

@ -2,9 +2,9 @@ name: Unit Tests
on:
push:
branches: [ main, master, release/** ]
branches: [ main, master ]
pull_request:
branches: [ main, master, release/** ]
branches: [ main, master ]
jobs:
test:

View File

@ -6,7 +6,6 @@ on:
- "pyproject.toml"
branches:
- master
- release/**
jobs:
update-version:

View File

@ -1,2 +1,3 @@
# Admins
* @comfyanonymous @kosinkadink @guill
* @comfyanonymous
* @kosinkadink

View File

@ -81,7 +81,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
@ -320,32 +319,6 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
2. Launch ComfyUI by running `python main.py`
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
### Setup
1. Install the manager dependencies:
```bash
pip install -r manager_requirements.txt
```
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
```bash
python main.py --enable-manager
```
### Command Line Options
| Flag | Description |
|------|-------------|
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
# Running
```python main.py```

View File

@ -58,13 +58,8 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
def is_visible_file(entry: os.DirEntry) -> bool:
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
return entry.is_file() and not entry.name.startswith('.')
sorted_files = sorted(
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
(entry for entry in os.scandir(directory) if entry.is_file()),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)

View File

@ -97,13 +97,6 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
@classmethod
def from_string(cls, value: str):
for member in cls:
if member.value == value:
return member
return None
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
@ -128,12 +121,6 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
manager_group = parser.add_mutually_exclusive_group()
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@ -181,7 +168,6 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"

View File

@ -51,43 +51,32 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
def __init__(self, index_list: list[int], dim: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
idx = tuple([slice(None)] * dim + [self.index_list])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
window[idx] = full[idx]
return window.to(device)
idx = [slice(None)] * dim + [self.index_list]
return full[idx].to(device)
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
if dim is None:
dim = self.dim
idx = tuple([slice(None)] * dim + [self.index_list])
idx = [slice(None)] * dim + [self.index_list]
full[idx] += to_add
return full
def get_region_index(self, num_regions: int) -> int:
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self):
return {}
@ -105,8 +94,7 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -115,18 +103,13 @@ class IndexListContextHandler(ContextHandlerABC):
self.closed_loop = closed_loop
self.dim = dim
self._step = 0
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.callbacks = {}
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
return True
return False
@ -140,11 +123,6 @@ class IndexListContextHandler(ContextHandlerABC):
return None
# reuse or resize cond items to match context requirements
resized_cond = []
# if multiple conds, split based on primary region
if self.split_conds_to_windows and len(cond_in) > 1:
region = window.get_region_index(len(cond_in))
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
cond_in = [cond_in[region]]
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
for actual_cond in cond_in:
resized_actual_cond = actual_cond.copy()
@ -167,32 +145,13 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1)
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
audio_cond = cond_value.cond
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
# if has cond that is a Tensor, check if needs to be subset
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
elif cond_key == "num_video_frames": # for SVD
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
new_cond_item[cond_key].cond = window.context_length
@ -205,7 +164,7 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@ -214,7 +173,7 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@ -291,8 +250,8 @@ class IndexListContextHandler(ContextHandlerABC):
prev_weight = (bias_total / (bias_total + bias))
new_weight = (bias / (bias_total + bias))
# account for dims of tensors
idx_window = tuple([slice(None)] * self.dim + [idx])
pos_window = tuple([slice(None)] * self.dim + [pos])
idx_window = [slice(None)] * self.dim + [idx]
pos_window = [slice(None)] * self.dim + [pos]
# apply new values
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
biases_final[i][idx] = bias_total + bias
@ -328,28 +287,6 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
)
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
model_options = extra_args.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is None:
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
"ContextWindows_sampler_sample",
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -601,29 +538,3 @@ def shift_window_to_end(window: list[int], num_frames: int):
for i in range(len(window)):
# 2) add end_delta to each val to slide windows to end
window[i] = window[i] + end_delta
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
logging.info("Context windows: Applying FreeNoise")
generator = torch.Generator(device='cpu').manual_seed(seed)
latent_video_length = noise.shape[dim]
delta = context_length - context_overlap
for start_idx in range(0, latent_video_length - context_length, delta):
place_idx = start_idx + context_length
actual_delta = min(delta, latent_video_length - place_idx)
if actual_delta <= 0:
break
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
source_slice = [slice(None)] * noise.ndim
source_slice[dim] = list_idx
target_slice = [slice(None)] * noise.ndim
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
noise[tuple(target_slice)] = noise[tuple(source_slice)]
return noise

View File

@ -1557,13 +1557,10 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1603,14 +1600,8 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()

View File

@ -40,8 +40,7 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
txt_ids_dims: list
vec_in_dim: int

View File

@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
class ChromaRadiance(Chroma):
"""
@ -159,9 +159,6 @@ class ChromaRadiance(Chroma):
self.skip_dit = []
self.lite = False
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -279,12 +276,6 @@ class ChromaRadiance(Chroma):
params_dict |= overrides
return params.__class__(**params_dict)
def _apply_x0_residual(self, predicted, noisy, timesteps):
# non zero during training to prevent 0 div
eps = 0.0
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
def _forward(
self,
x: Tensor,
@ -325,11 +316,4 @@ class ChromaRadiance(Chroma):
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
# If x0 variant → v-pred, just return this instead
if hasattr(self, "__x0__"):
out = self._apply_x0_residual(out, img, timestep)
return out
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]

View File

@ -57,35 +57,6 @@ class MLPEmbedder(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class YakMLP(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
self.act_fn = nn.SiLU()
def forward(self, x: Tensor) -> Tensor:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
if yak_mlp:
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
if mlp_silu_act:
return nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
return nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@ -169,7 +140,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@ -185,7 +156,18 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if mlp_silu_act:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@ -195,7 +177,18 @@ class DoubleStreamBlock(nn.Module):
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if mlp_silu_act:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
SiLUActivation(),
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
)
else:
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.flipped_img_txt = flipped_img_txt
@ -282,7 +275,6 @@ class SingleStreamBlock(nn.Module):
modulation=True,
mlp_silu_act=False,
bias=True,
yak_mlp=False,
dtype=None,
device=None,
operations=None
@ -296,17 +288,12 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
self.yak_mlp = yak_mlp
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
if self.yak_mlp:
self.mlp_hidden_dim_first *= 2
self.mlp_act = nn.SiLU()
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
@ -338,10 +325,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
else:
mlp = self.mlp_act(mlp)
mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:

View File

@ -15,8 +15,7 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
Modulation,
RMSNorm
Modulation
)
@dataclass
@ -35,14 +34,11 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
txt_ids_dims: list
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
yak_mlp: bool = False
txt_norm: bool = False
class Flux(nn.Module):
@ -80,11 +76,6 @@ class Flux(nn.Module):
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
else:
self.txt_norm = None
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@ -95,7 +86,6 @@ class Flux(nn.Module):
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@ -104,7 +94,7 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@ -160,8 +150,6 @@ class Flux(nn.Module):
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
if self.txt_norm is not None:
txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
@ -344,9 +332,8 @@ class Flux(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
if len(self.params.txt_ids_dims) > 0:
for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
if len(self.params.axes_dim) == 4: # Flux 2
txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]

View File

@ -43,7 +43,6 @@ class HunyuanVideoParams:
meanflow: bool
use_cond_type_embedding: bool
vision_in_dim: int
meanflow_sum: bool
class SelfAttentionRef(nn.Module):
@ -318,7 +317,7 @@ class HunyuanVideo(nn.Module):
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
vec = (vec + vec_r) / 2
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)

View File

@ -1,8 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):

View File

@ -1,12 +1,42 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
class NoPadConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
if isinstance(op, NoPadConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
out = op(x)
return out
class RMS_norm(nn.Module):
def __init__(self, dim):
@ -19,7 +49,7 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds, refiner_vae, op):
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
@ -79,7 +109,7 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus, refiner_vae, op):
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@ -133,6 +163,23 @@ class UpSmpl(nn.Module):
return h + x
class HunyuanRefinerResnetBlock(ResnetBlock):
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
h = x
h = [ self.swish(self.norm1(x)) ]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = [ self.dropout(self.swish(self.norm2(h))) ]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x+h
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@ -144,7 +191,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = CarriedConv3d
conv_op = NoPadConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@ -159,10 +206,9 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
@ -172,9 +218,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@ -200,20 +246,22 @@ class Encoder(nn.Module):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
out = torch_cat_if_needed(out, dim=2)
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
@ -240,7 +288,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
conv_op = CarriedConv3d
conv_op = NoPadConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@ -250,9 +298,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@ -260,10 +308,9 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=conv_op, norm_op=norm_op)
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
@ -293,7 +340,7 @@ class Decoder(nn.Module):
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
x1 = blk(x1, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@ -303,7 +350,10 @@ class Decoder(nn.Module):
conv_carry_in = conv_carry_out
del x
out = torch_cat_if_needed(out, dim=2)
if len(out) > 1:
out = torch.cat(out, dim=2)
else:
out = out[0]
if not self.refiner_vae:
if z.shape[-3] == 1:

View File

@ -1,413 +0,0 @@
import torch
from torch import nn
import math
import comfy.ldm.common_dit
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
def attention(q, k, v, heads, transformer_options={}):
return optimized_attention(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
heads=heads,
skip_reshape=True,
transformer_options=transformer_options
)
def apply_scale_shift_norm(norm, x, scale, shift):
return torch.addcmul(shift, norm(x), scale + 1.0)
def apply_gate_sum(x, out, gate):
return torch.addcmul(x, gate, out)
def get_shift_scale_gate(params):
shift, scale, gate = torch.chunk(params, 3, dim=-1)
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
def get_freqs(dim, max_period=10000.0):
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
class TimeEmbeddings(nn.Module):
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
super().__init__()
assert model_dim % 2 == 0
self.model_dim = model_dim
self.max_period = max_period
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.SiLU()
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, timestep, dtype):
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
class TextEmbeddings(nn.Module):
def __init__(self, text_dim, model_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, text_embed):
text_embed = self.in_layer(text_embed)
return self.norm(text_embed).type_as(text_embed)
class VisualEmbeddings(nn.Module):
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
x = x.movedim(1, -1) # B C T H W -> B T H W C
B, T, H, W, dim = x.shape
pt, ph, pw = self.patch_size
x = x.view(
B,
T // pt, pt,
H // ph, ph,
W // pw, pw,
dim,
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
return self.in_layer(x)
class Modulation(nn.Module):
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
super().__init__()
self.activation = nn.SiLU()
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
return self.out_layer(self.activation(x))
class SelfAttention(nn.Module):
def __init__(self, num_channels, head_dim, operation_settings=None):
super().__init__()
assert num_channels % head_dim == 0
self.num_heads = num_channels // head_dim
self.head_dim = head_dim
operations = operation_settings.get("operations")
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 2
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
return apply_rope1(norm_fn(result), freqs)
def _forward(self, x, freqs, transformer_options={}):
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def _forward_chunked(self, x, freqs, transformer_options={}):
def process_chunks(proj_fn, norm_fn):
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
chunks = []
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
return torch.cat(chunks, dim=1)
q = process_chunks(self.to_query, self.query_norm)
k = process_chunks(self.to_key, self.key_norm)
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
def forward(self, x, freqs, transformer_options={}):
if x.shape[1] > 8192:
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
else:
return self._forward(x, freqs, transformer_options=transformer_options)
class CrossAttention(SelfAttention):
def get_qkv(self, x, context):
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
return q, k, v
def forward(self, x, context, transformer_options={}):
q, k, v = self.get_qkv(x, context)
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
return self.out_layer(out)
class FeedForward(nn.Module):
def __init__(self, dim, ff_dim, operation_settings=None):
super().__init__()
operations = operation_settings.get("operations")
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.activation = nn.GELU()
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.num_chunks = 4
def _forward(self, x):
return self.out_layer(self.activation(self.in_layer(x)))
def _forward_chunked(self, x):
chunks = torch.chunk(x, self.num_chunks, dim=1)
output_chunks = []
for chunk in chunks:
output_chunks.append(self._forward(chunk))
return torch.cat(output_chunks, dim=1)
def forward(self, x):
if x.shape[1] > 8192:
return self._forward_chunked(x)
else:
return self._forward(x)
class OutLayer(nn.Module):
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
super().__init__()
self.patch_size = patch_size
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, visual_embed, time_embed):
B, T, H, W, _ = visual_embed.shape
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
scale = scale[:, None, None, None, :]
shift = shift[:, None, None, None, :]
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
x = self.out_layer(visual_embed)
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
x = x.view(
B, T, H, W,
out_dim,
self.patch_size[0], self.patch_size[1], self.patch_size[2]
)
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
class TransformerEncoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, x, time_embed, freqs, transformer_options={}):
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
shift, scale, gate = get_shift_scale_gate(self_attn_params)
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
out = self.self_attention(out, freqs, transformer_options=transformer_options)
x = apply_gate_sum(x, out, gate)
shift, scale, gate = get_shift_scale_gate(ff_params)
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
out = self.feed_forward(out)
x = apply_gate_sum(x, out, gate)
return x
class TransformerDecoderBlock(nn.Module):
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
super().__init__()
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
operations = operation_settings.get("operations")
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
# self attention
shift, scale, gate = get_shift_scale_gate(self_attn_params)
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# cross attention
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
# feed forward
shift, scale, gate = get_shift_scale_gate(ff_params)
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
visual_out = self.feed_forward(visual_out)
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
return visual_embed
class Kandinsky5(nn.Module):
def __init__(
self,
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
dtype=None, device=None, operations=None, **kwargs
):
super().__init__()
head_dim = sum(axes_dims)
self.rope_scale_factor = rope_scale_factor
self.in_visual_dim = in_visual_dim
self.model_dim = model_dim
self.patch_size = patch_size
self.visual_embed_dim = visual_embed_dim
self.dtype = dtype
self.device = device
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
self.text_transformer_blocks = nn.ModuleList(
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
)
self.visual_transformer_blocks = nn.ModuleList(
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
)
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
steps = seq_len if steps is None else steps
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
return freqs
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
else:
rope_scale_factor = self.rope_scale_factor
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
if h * w >= 14080:
rope_scale_factor = (1.0, 3.16, 3.16)
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
return freqs
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
context = self.text_embeddings(context)
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
for block in self.text_transformer_blocks:
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
visual_embed = self.visual_embeddings(x)
visual_shape = visual_embed.shape[:-1]
visual_embed = visual_embed.flatten(1, -2)
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.visual_transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
else:
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
visual_embed = visual_embed.reshape(*visual_shape, -1)
return self.out_layer(visual_embed, time_embed)
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
original_dims = x.ndim
if original_dims == 4:
x = x.unsqueeze(2)
bs, c, t_len, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if time_dim_replace is not None:
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
if original_dims == 4:
out = out.squeeze(2)
return out
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)

View File

@ -1,160 +0,0 @@
import torch
from torch import nn
from .model import JointTransformerBlock
class ZImageControlTransformerBlock(JointTransformerBlock):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
multiple_of: int,
ffn_dim_multiplier: float,
norm_eps: float,
qk_norm: bool,
modulation=True,
block_id=0,
operation_settings=None,
):
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
self.block_id = block_id
if block_id == 0:
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, c, x, **kwargs):
if self.block_id == 0:
c = self.before_proj(c) + x
c = super().forward(c, **kwargs)
c_skip = self.after_proj(c)
return c_skip, c
class ZImage_Control(torch.nn.Module):
def __init__(
self,
dim: int = 3840,
n_heads: int = 30,
n_kv_heads: int = 30,
multiple_of: int = 256,
ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5,
qk_norm: bool = True,
n_control_layers=6,
control_in_dim=16,
additional_in_dim=0,
broken=False,
refiner_control=False,
dtype=None,
device=None,
operations=None,
**kwargs
):
super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.broken = broken
self.additional_in_dim = additional_in_dim
self.control_in_dim = control_in_dim
n_refiner_layers = 2
self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
i,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=i,
operation_settings=operation_settings,
)
for i in range(self.n_control_layers)
]
)
all_x_embedder = {}
patch_size = 2
f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.refiner_control = refiner_control
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
if self.refiner_control:
self.control_noise_refiner = nn.ModuleList(
[
ZImageControlTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=layer_id,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
else:
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2
f_patch_size = 1
pH = pW = patch_size
B, C, H, W = control_context.shape
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None
if not self.refiner_control:
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
if self.refiner_control:
if self.broken:
if layer_id == 0:
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if layer_id > 0:
out = None
for i in range(1, len(self.control_layers)):
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if out is None:
out = o
return (out, control_context)
else:
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
else:
return (None, control_context)
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

View File

@ -22,10 +22,6 @@ def modulate(x, scale):
# Core NextDiT Model #
#############################################################################
def clamp_fp16(x):
if x.dtype == torch.float16:
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
@ -173,7 +169,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
return clamp_fp16(F.silu(x1) * x3)
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@ -277,27 +273,27 @@ class JointTransformerBlock(nn.Module):
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
clamp_fp16(self.attention(
self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
))
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
clamp_fp16(self.feed_forward(
self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
))
)
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
clamp_fp16(self.attention(
self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
))
)
)
x = x + self.ffn_norm2(
self.feed_forward(
@ -377,7 +373,6 @@ class NextDiT(nn.Module):
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
clip_text_dim=None,
image_model=None,
device=None,
dtype=None,
@ -448,31 +443,6 @@ class NextDiT(nn.Module):
),
)
self.clip_text_pooled_proj = None
if clip_text_dim is not None:
self.clip_text_dim = clip_text_dim
self.clip_text_pooled_proj = nn.Sequential(
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear(
clip_text_dim,
clip_text_dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.time_text_embed = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024) + clip_text_dim,
min(dim, 1024),
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
self.layers = nn.ModuleList(
[
JointTransformerBlock(
@ -536,7 +506,6 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -573,21 +542,13 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None
x_input = x
for i, layer in enumerate(self.noise_refiner):
for layer in self.noise_refiner:
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
@ -603,7 +564,7 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@ -620,36 +581,16 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
patches = transformer_options.get("patches", {})
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
img = self.final_layer(img, adaln_input)
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
return -img
return -x

View File

@ -517,7 +517,6 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@ -542,8 +541,6 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
exception_fallback = True
if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),

View File

@ -13,12 +13,6 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
if len(xl) > 1:
return torch.cat(xl, dim)
else:
return xl[0]
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@ -49,37 +43,6 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class CarriedConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
super().__init__()
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
return self.conv(x)
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
x = xl[0]
xl.clear()
if isinstance(op, CarriedConv3d):
if conv_carry_in is None:
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
else:
carry_len = conv_carry_in[0].shape[2]
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
if conv_carry_out is not None:
to_push = x[:, :, -2:, :, :].clone()
conv_carry_out.append(to_push)
out = op(x)
return out
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@ -126,24 +89,29 @@ class Upsample(nn.Module):
stride=1,
padding=1)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
def forward(self, x):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
results = []
if conv_carry_in is None:
first = x[:, :, :1, :, :]
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
x = x[:, :, 1:, :, :]
if x.shape[2] > 0:
results.append(interpolate_up(x, scale_factor))
x = torch_cat_if_needed(results, dim=2)
t = x.shape[2]
if t > 1:
a, b = x.split((1, t - 1), dim=2)
del x
b = interpolate_up(b, scale_factor)
else:
a = x
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
if t > 1:
x = torch.cat((a, b), dim=2)
else:
x = a
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
x = self.conv(x)
return x
@ -159,20 +127,17 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
def forward(self, x):
if self.with_conv:
if isinstance(self.conv, CarriedConv3d):
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
elif x.ndim == 4:
if x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
x = self.conv(x)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@ -218,23 +183,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
def forward(self, x, temb=None):
h = x
h = self.norm1(h)
h = [ self.swish(h) ]
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = self.swish(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = self.swish(h)
h = [ self.dropout(h) ]
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
@ -314,7 +279,6 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@ -325,8 +289,6 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
@ -555,14 +517,9 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.carried = False
if conv3d:
if not attn_resolutions:
conv_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@ -575,7 +532,6 @@ class Encoder(nn.Module):
stride=1,
padding=1)
self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@ -602,15 +558,10 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
else:
self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
if time_compress is not None:
self.time_compress = time_compress
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@ -636,42 +587,15 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
if self.carried:
xl = [x[:, :, :1, :, :]]
if x.shape[2] > self.time_compress:
tc = self.time_compress
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
x = xl
else:
x = [x]
out = []
conv_carry_in = None
for i, x1 in enumerate(x):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
# downsampling
x1 = [ x1 ]
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
if len(self.down[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.down[i_level].attn[i_block](h1)
if i_level != self.num_resolutions-1:
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
out.append(h1)
conv_carry_in = conv_carry_out
h = torch_cat_if_needed(out, dim=2)
del out
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions-1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid.block_1(h, temb)
@ -680,15 +604,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
h = [ nonlinearity(h) ]
h = conv_carry_causal_3d(h, self.conv_out)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, tanh_out=False, use_linear_attn=False,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@ -702,18 +626,12 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.carried = False
if conv3d:
if not attn_resolutions and resnet_op == ResnetBlock:
conv_op = CarriedConv3d
conv_out_op = CarriedConv3d
self.carried = True
else:
conv_op = VideoConv3d
conv_out_op = VideoConv3d
conv_op = VideoConv3d
conv_out_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@ -788,43 +706,29 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h = conv_carry_causal_3d([z], self.conv_in)
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
if self.carried:
h = torch.split(h, 2, dim=2)
else:
h = [ h ]
out = []
conv_carry_in = None
# upsampling
for i, h1 in enumerate(h):
conv_carry_out = []
if i == len(h) - 1:
conv_carry_out = None
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
if len(self.up[i_level].attn) > 0:
assert i == 0 #carried should not happen if attn exists
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
if i_level != 0:
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb, **kwargs)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, **kwargs)
if i_level != 0:
h = self.up[i_level].upsample(h)
h1 = self.norm_out(h1)
h1 = [ nonlinearity(h1) ]
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
if self.tanh_out:
h1 = torch.tanh(h1)
out.append(h1)
conv_carry_in = conv_carry_out
# end
if self.give_pre_end:
return h
out = torch_cat_if_needed(out, dim=2)
return out
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h, **kwargs)
if self.tanh_out:
h = torch.tanh(h)
return h

View File

@ -218,24 +218,9 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
self,
@ -244,19 +229,14 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
@ -271,15 +251,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated
del txt_modulated
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -322,7 +302,6 @@ class QwenImageTransformer2DModel(nn.Module):
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
dtype=None,
@ -335,7 +314,6 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
@ -413,14 +391,11 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None:
h = 0
w = 0
index = 0
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
timestep_zero = ref_method == "index_timestep_zero"
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
for ref in ref_latents:
if index_ref_method:
index += 1
@ -440,10 +415,6 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -475,7 +446,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
@ -487,7 +458,6 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options,
)
@ -504,9 +474,6 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None:
hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@ -568,10 +568,7 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -766,10 +763,7 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -868,10 +862,7 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1335,19 +1326,16 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
x = block(x, e=e0, freqs=freqs, context=context)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
@ -1586,10 +1574,7 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -523,10 +523,7 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -320,16 +320,8 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
key_map["transformer.{}".format(key_lora)] = k
return key_map

View File

@ -47,7 +47,6 @@ import comfy.ldm.chroma_radiance.model
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.model_management
import comfy.patcher_extension
@ -135,7 +134,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@ -330,6 +329,18 @@ class BaseModel(torch.nn.Module):
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
unet_state_dict = self.diffusion_model.state_dict()
if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION:
@ -1110,10 +1121,6 @@ class Lumina2(BaseModel):
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
clip_text_pooled = kwargs["pooled_output"] # Newbie
if clip_text_pooled is not None:
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
return out
class WAN21(BaseModel):
@ -1635,49 +1642,3 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out
class Kandinsky5(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
device = kwargs["device"]
image = torch.zeros_like(noise)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
time_dim_replace = kwargs.get("time_dim_replace", None)
if time_dim_replace is not None:
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
return out
class Kandinsky5Image(Kandinsky5):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
return None

View File

@ -6,6 +6,20 @@ import math
import logging
import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
def count_blocks(state_dict_keys, prefix_string):
count = 0
while True:
@ -180,10 +194,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
dit_config["meanflow_sum"] = True
else:
dit_config["vision_in_dim"] = None
dit_config["meanflow_sum"] = False
return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@ -196,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
dit_config["txt_ids_dims"] = [3]
patch_size = 1
else:
dit_config["image_model"] = "flux"
@ -211,7 +223,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
dit_config["txt_ids_dims"] = []
patch_size = 2
dit_config["in_channels"] = 16
@ -234,8 +245,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
else:
dit_config["vec_in_dim"] = None
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@ -259,17 +268,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@ -429,9 +429,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
if ctd_weight is not None:
dit_config["clip_text_dim"] = ctd_weight.shape[0]
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
@ -620,24 +617,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["model_dim"] = model_dim
if model_dim in [4096, 2560]: # pro video and lite image
dit_config["axes_dims"] = (32, 48, 48)
if model_dim == 2560: # lite image
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
elif model_dim == 1792: # lite video
dit_config["axes_dims"] = (16, 24, 24)
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
dit_config["image_model"] = "kandinsky5"
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@ -780,11 +759,22 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config)
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
model_config.optimizations["fp8"] = False
else:
model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
if quant_config:
model_config.quant_config = quant_config
logging.info("Detected mixed precision quantization")
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config

View File

@ -1492,20 +1492,6 @@ def extended_fp16_support():
return True
LORA_COMPUTE_DTYPES = {}
def lora_compute_dtype(device):
dtype = LORA_COMPUTE_DTYPES.get(device, None)
if dtype is not None:
return dtype
if should_use_fp16(device):
dtype = torch.float16
else:
dtype = torch.float32
LORA_COMPUTE_DTYPES[device] = dtype
return dtype
def soft_empty_cache(force=False):
global cpu_state
if cpu_state == CPUState.MPS:

View File

@ -35,7 +35,6 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -127,23 +126,36 @@ class LowVramPatch:
def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.convert_func = convert_func
self.set_func = set_func
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
intermediate_dtype = weight.dtype
if self.convert_func is not None:
weight = self.convert_func(weight, inplace=False)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
intermediate_dtype = torch.float32
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is None:
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
else:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
if self.set_func is not None:
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
else:
return out
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
if model_dtype is None:
model_dtype = weight.dtype
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@ -454,9 +466,6 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
@ -621,11 +630,10 @@ class ModelPatcher:
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(temp_dtype, copy=True)
temp_weight = weight.to(torch.float32, copy=True)
if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True)
@ -669,18 +677,12 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
def check_module_offload_mem(key):
if key in self.patches:
return low_vram_patch_estimate_vram(self.model, key)
model_dtype = getattr(self.model, "manual_cast_dtype", None)
weight, _, _ = get_key_weight(self.model, key)
if model_dtype is None or weight is None:
return 0
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
return weight.numel() * model_dtype.itemsize
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
@ -697,12 +699,12 @@ class ModelPatcher:
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
for i, x in enumerate(loading):
for x in loading:
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
@ -775,8 +777,6 @@ class ModelPatcher:
key = "{}.{}".format(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
if comfy.model_management.is_device_cuda(device_to):
torch.cuda.synchronize()
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@ -876,18 +876,14 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
offload_buffer = self.model.model_offload_buffer_memory
if len(unload_list) > 0:
NS = comfy.model_management.NUM_STREAMS
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
module_offload_mem, module_mem, n, m, params = unload
potential_offload = module_offload_mem + sum(offload_weight_factor)
potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@ -933,14 +929,12 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
if cast_weight and hasattr(m, "comfy_cast_weights"):
if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
offload_weight_factor.append(module_mem)
offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
for param in params:

View File

@ -22,7 +22,7 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import json
import contextlib
def run_every_op():
if torch.compiler.is_compiling():
@ -93,6 +93,13 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0
@ -104,24 +111,22 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
bias_a = bias
weight_a = weight
if s.bias is not None:
for f in s.bias_function:
bias = f(bias)
if bias_has_function:
with wf_context:
for f in s.bias_function:
bias = f(bias)
if weight_has_function or weight.dtype != dtype:
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
with wf_context:
weight = weight.to(dtype=dtype)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
for f in s.weight_function:
weight = f(weight)
comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
return weight, bias, offload_stream
else:
#Legacy function signature
return weight, bias
@ -130,16 +135,13 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
os, weight_a, bias_a = offload_stream
if os is None:
return
if weight_a is not None:
device = weight_a.device
if weight is not None:
device = weight.device
else:
if bias_a is None:
if bias is None:
return
device = bias_a.device
os.wait_stream(comfy.model_management.current_stream(device))
device = bias.device
offload_stream.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
@ -415,12 +417,22 @@ def fp8_linear(self, input):
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
else:
scale_weight = scale_weight.to(input.device)
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else:
scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
@ -441,7 +453,7 @@ class fp8_ops(manual_cast):
return None
def forward_comfy_cast_weights(self, input):
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
if not self.training:
try:
out = fp8_linear(self, input)
if out is not None:
@ -454,6 +466,59 @@ class fp8_ops(manual_cast):
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
if override_dtype is not None:
kwargs['dtype'] = override_dtype
super().__init__(*args, **kwargs)
def reset_parameters(self):
if not hasattr(self, 'scale_weight'):
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
if not scale_input:
self.scale_input = None
if not hasattr(self, 'scale_input'):
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
return None
def forward_comfy_cast_weights(self, input):
if fp8_matrix_mult:
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
if weight.numel() < input.numel(): #TODO: optimize
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
else:
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def convert_weight(self, weight, inplace=False, **kwargs):
if inplace:
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
return weight
else:
return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
if return_weight:
return weight
if inplace_update:
self.weight.data.copy_(weight)
else:
self.weight = torch.nn.Parameter(weight, requires_grad=False)
return scaled_fp8_op
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
@ -480,9 +545,9 @@ if CUBLAS_IS_AVAILABLE:
from .quant_ops import QuantizedTensor, QUANT_ALGOS
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
_layer_quant_config = layer_quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
@ -497,14 +562,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self._has_bias = bias
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@ -524,59 +590,36 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
if layer_name not in MixedPrecisionOps._layer_quant_config:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
if self.quant_format is None:
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
if quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': scale,
'scale': state_dict.pop(weight_scale_key, None),
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
if scale is not None:
if layout_params['scale'] is not None:
manually_loaded_keys.append(weight_scale_key)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@ -585,16 +628,6 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if key in missing_keys:
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@ -610,8 +643,9 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and
getattr(self, 'input_scale', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
def convert_weight(self, weight, inplace=False, **kwargs):
@ -622,7 +656,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@ -631,28 +665,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)
for key, param in self._parameters.items():
if param is None:
continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
if scaled_fp8 is not None:
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
if (
fp8_compute and

View File

@ -238,9 +238,6 @@ class QuantizedTensor(torch.Tensor):
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
def storage(self):
return self._qdata.storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================
@ -252,6 +249,12 @@ def _create_transformed_qtensor(qt, transform_fn):
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_dtype is not None and target_dtype != qt.dtype:
logging.warning(
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
f"but not supported for quantized tensors. Ignoring dtype."
)
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
@ -271,8 +274,6 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
@ -338,9 +339,7 @@ def generic_copy_(func, args, kwargs):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
@ -398,23 +397,17 @@ class TensorCoreFP8Layout(QuantizedLayout):
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
orig_dtype = tensor.dtype
if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
tensor_info = torch.finfo(tensor.dtype)
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if scale is None:
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
if scale is not None:
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
tensor = tensor * (1.0 / scale).to(tensor.dtype)
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)

View File

@ -53,8 +53,6 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.model_patcher
import comfy.lora
@ -99,7 +97,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
if no_init:
return
params = target.params.copy()
@ -127,32 +125,9 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None
if len(state_dict) > 0:
if isinstance(state_dict, list):
for c in state_dict:
m, u = self.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
else:
m, u = self.load_sd(state_dict, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
if params['device'] == load_device:
model_management.load_models_gpu([self.patcher], force_full_load=True)
self.layer_idx = None
@ -217,7 +192,6 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@ -265,7 +239,6 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
@ -495,7 +468,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@ -507,10 +480,8 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
#This is likely to significantly over-estimate with single image or low frame counts as the
#implementation is able to completely skip caching. Rework if used as an image only VAE
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@ -549,10 +520,8 @@ class VAE:
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
@ -771,8 +740,6 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = None
do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0]
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -989,17 +956,16 @@ class CLIPType(Enum):
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
OVIS = 21
KANDINSKY5 = 22
KANDINSKY5_IMAGE = 23
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
if model_options.get("custom_operations", None) is None:
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
if metadata is not None:
quant_metadata = metadata.get("_quantization_metadata", None)
if quant_metadata is not None:
sd["_quantization_metadata"] = quant_metadata
clip_data.append(sd)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
@ -1021,7 +987,6 @@ class TEModel(Enum):
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
def detect_te_model(sd):
@ -1055,12 +1020,9 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if weight.shape[0] == 2560:
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
return TEModel.QWEN3_2B
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@ -1116,7 +1078,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
@ -1140,7 +1102,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
@ -1169,7 +1131,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif te_model == TEModel.QWEN25_3B:
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
@ -1188,16 +1150,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HIDREAM:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
@ -1240,12 +1199,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
elif clip_type == CLIPType.KANDINSKY5:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@ -1258,10 +1211,19 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0
for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
logging.warning("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected: {}".format(u))
return clip
def load_gligen(ckpt_path):
@ -1320,10 +1282,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
@ -1332,22 +1290,18 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
if model_config.scaled_fp8 is not None:
weight_dtype = None
if custom_operations is not None:
model_config.custom_operations = custom_operations
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
if model_config.quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
@ -1365,33 +1319,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
if te_model_options.get("custom_operations", None) is None:
scaled_fp8_list = []
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
if k.endswith(".scaled_fp8"):
scaled_fp8_list.append(k[:-len("scaled_fp8")])
if len(scaled_fp8_list) > 0:
out_sd = {}
for k in sd:
skip = False
for pref in scaled_fp8_list:
skip = skip or k.startswith(pref)
if not skip:
out_sd[k] = sd[k]
for pref in scaled_fp8_list:
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
for k in quant_sd:
out_sd[k] = quant_sd[k]
sd = out_sd
clip_target = model_config.clip_target(state_dict=sd)
if clip_target is not None:
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = comfy.utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
if len(m_filter) > 0:
logging.warning("clip missing: {}".format(m))
else:
logging.debug("clip missing: {}".format(m))
if len(u) > 0:
logging.debug("clip unexpected {}:".format(u))
else:
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
@ -1438,9 +1381,6 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
if len(temp_sd) > 0:
sd = temp_sd
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd)
@ -1471,7 +1411,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None:
if model_config.scaled_fp8 is not None:
weight_dtype = None
if dtype is None:
@ -1479,15 +1419,12 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
else:
unet_dtype = dtype
if model_config.quant_config is not None:
if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if custom_operations is not None:
model_config.custom_operations = custom_operations
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False):
model_config.optimizations["fp8"] = True
@ -1526,9 +1463,6 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if vae is not None:
vae_sd = vae.get_sd()
if metadata is None:
metadata = {}
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)

View File

@ -107,17 +107,29 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config[k] = v
operations = model_options.get("custom_operations", None)
quant_config = model_options.get("quantization_metadata", None)
scaled_fp8 = None
quantization_metadata = model_options.get("quantization_metadata", None)
if operations is None:
if quant_config is not None:
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
logging.info("Using MixedPrecisionOps for text encoder")
layer_quant_config = None
if quantization_metadata is not None:
layer_quant_config = json.loads(quantization_metadata).get("layers", None)
if layer_quant_config is not None:
operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else:
operations = comfy.ops.manual_cast
# Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
if scaled_fp8 is not None:
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
self.num_layers = self.transformer.num_layers
@ -135,7 +147,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@ -152,7 +163,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
@ -165,7 +175,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@ -249,11 +258,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
if self.execution_device is None:
device = self.transformer.get_input_embeddings().weight.device
else:
device = self.execution_device
device = self.transformer.get_input_embeddings().weight.device
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None

View File

@ -21,14 +21,12 @@ import comfy.text_encoders.ace
import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@ -542,7 +540,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
memory_usage_factor = 1.6
memory_usage_factor = 1.2
text_encoder_key_prefix = ["text_encoders."]
@ -966,7 +964,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@ -1027,15 +1025,7 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
memory_usage_factor = 1.7
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -1296,7 +1286,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.044
memory_usage_factor = 0.038
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
@ -1339,7 +1329,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
memory_usage_factor = 1.95 #TODO
memory_usage_factor = 1.65 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1404,7 +1394,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 8.7
memory_usage_factor = 7.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1482,60 +1472,7 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]
class Kandinsky5(supported_models_base.BASE):
unet_config = {
"image_model": "kandinsky5",
}
sampling_settings = {
"shift": 10.0,
}
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
class Kandinsky5Image(Kandinsky5):
unet_config = {
"image_model": "kandinsky5",
"model_dim": 2560,
"visual_embed_dim": 64,
}
sampling_settings = {
"shift": 3.0,
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]

View File

@ -17,7 +17,6 @@
"""
import torch
import logging
from . import model_base
from . import utils
from . import latent_formats
@ -50,7 +49,8 @@ class BASE:
manual_cast_dtype = None
custom_operations = None
quant_config = None # quantization configuration for mixed precision
scaled_fp8 = None
layer_quant_config = None # Per-layer quantization configuration for mixed precision
optimizations = {"fp8": False}
@classmethod
@ -118,7 +118,3 @@ class BASE:
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype
def __getattr__(self, name):
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
return None

View File

@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
if t5xxl_quantization_metadata is not None:
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = t5xxl_quantization_metadata
model_options["scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def te(dtype_t5=None, t5_quantization_metadata=None):
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class CosmosTEModel_(CosmosT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
@ -159,13 +159,15 @@ class Flux2TEModel(sd1_clip.SD1ClipModel):
out = out.reshape(out.shape[0], out.shape[1], -1)
return out, pooled, extra
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
class Flux2TEModel_(Flux2TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
if pruned:
model_options = model_options.copy()

View File

@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class MochiTEModel_(MochiT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
return self.llama.load_sd(sd)
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if llama_quantization_metadata is not None:
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
model_options["scaled_fp8"] = llama_scaled_fp8
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
else:
return super().load_sd(sd)
def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(HunyuanImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)

View File

@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast
import torch
import os
import numbers
import comfy.utils
def llama_detect(state_dict, prefix=""):
out = {}
@ -14,9 +14,12 @@ def llama_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_llama"] = state_dict[t5_key].dtype
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["llama_quantization_metadata"] = quant
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
if "_quantization_metadata" in state_dict:
out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
return out
@ -28,10 +31,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
class LLAMAModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
model_options["scaled_fp8"] = llama_scaled_fp8
textmodel_json_config = {}
vocab_size = model_options.get("vocab_size", None)
@ -158,11 +161,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
return self.llama.load_sd(sd)
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HunyuanVideoClipModel_

View File

@ -1,68 +0,0 @@
from comfy import sd1_clip
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
from .llama import Qwen25_7BVLI
class Kandinsky5Tokenizer(QwenImageTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class Kandinsky5TEModel(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
return cond, l_pooled, extra
def set_clip_options(self, options):
super().set_clip_options(options)
self.clip_l.set_clip_options(options)
def reset_clip_options(self):
super().reset_clip_options()
self.clip_l.reset_clip_options()
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return super().load_sd(sd)
def te(dtype_llama=None, llama_quantization_metadata=None):
class Kandinsky5TEModel_(Kandinsky5TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Kandinsky5TEModel_

View File

@ -100,28 +100,6 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
@dataclass
class Ovis25_2BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
@ -564,15 +542,6 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Ovis25_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Ovis25_2BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()

View File

@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
if model_type == "gemma2_2b":
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)

View File

@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
def te(dtype_llama=None, llama_quantization_metadata=None):
def te(dtype_llama=None, llama_scaled_fp8=None):
class Omnigen2TEModel_(Omnigen2Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -1,66 +0,0 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
import numbers
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
class OvisTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Ovis25_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
class OvisTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen3_2b"][0]
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 4004 and count_im_start < 1:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 1):
if tok_pairs[template_end + 1][0] == 25:
template_end += 1
out = out[:, template_end:]
return out, pooled, {}
def te(dtype_llama=None, llama_quantization_metadata=None):
class OvisTEModel_(OvisTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return OvisTEModel_

View File

@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
class PixArtTEModel_(PixArtT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
if dtype is None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
def te(dtype_llama=None, llama_scaled_fp8=None):
class QwenImageTEModel_(QwenImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -6,15 +6,14 @@ import torch
import os
import comfy.model_management
import logging
import comfy.utils
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
if t5xxl_quantization_metadata is not None:
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
if t5xxl_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = t5xxl_quantization_metadata
model_options["scaled_fp8"] = t5xxl_scaled_fp8
model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -26,9 +25,9 @@ def t5_xxl_detect(state_dict, prefix=""):
if t5_key in state_dict:
out["dtype_t5"] = state_dict[t5_key].dtype
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
if quant is not None:
out["t5_quantization_metadata"] = quant
scaled_fp8_key = "{}scaled_fp8".format(prefix)
if scaled_fp8_key in state_dict:
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
return out
@ -157,11 +156,11 @@ class SD3ClipModel(torch.nn.Module):
else:
return self.t5xxl.load_sd(sd)
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

View File

@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
def te(dtype_t5=None, t5_quantization_metadata=None):
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
class WanTEModel(WanT5Model):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["quantization_metadata"] = t5_quantization_metadata
model_options["scaled_fp8"] = t5xxl_scaled_fp8
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)

View File

@ -34,9 +34,12 @@ class ZImageTEModel(sd1_clip.SD1ClipModel):
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
def te(dtype_llama=None, llama_quantization_metadata=None):
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
class ZImageTEModel_(ZImageTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:

View File

@ -29,7 +29,6 @@ import itertools
from torch.nn.functional import interpolate
from einops import rearrange
from comfy.cli_args import args
import json
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
@ -53,7 +52,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
@ -803,17 +802,12 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
ATTR_UNSET={}
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
else:
setattr(obj, attrs[-1], value)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):
@ -1200,68 +1194,3 @@ def unpack_latents(combined_latent, latent_shapes):
else:
output_tensors = combined_latent
return output_tensors
def detect_layer_quantization(state_dict, prefix):
for k in state_dict:
if k.startswith(prefix) and k.endswith(".comfy_quant"):
logging.info("Found quantization metadata version 1")
return {"mixed_ops": True}
return None
def convert_old_quants(state_dict, model_prefix="", metadata={}):
if metadata is None:
metadata = {}
quant_metadata = None
if "_quantization_metadata" not in metadata:
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
if scaled_fp8_key in state_dict:
scaled_fp8_weight = state_dict[scaled_fp8_key]
scaled_fp8_dtype = scaled_fp8_weight.dtype
if scaled_fp8_dtype == torch.float32:
scaled_fp8_dtype = torch.float8_e4m3fn
if scaled_fp8_weight.nelement() == 2:
full_precision_matrix_mult = True
else:
full_precision_matrix_mult = False
out_sd = {}
layers = {}
for k in list(state_dict.keys()):
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue
k_out = k
w = state_dict.pop(k)
layer = None
if k_out.endswith(".scale_weight"):
layer = k_out[:-len(".scale_weight")]
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf
if k_out.endswith(".scale_input"):
layer = k_out[:-len(".scale_input")]
k_out = "{}.input_scale".format(layer)
if w.item() == 1.0:
continue
out_sd[k_out] = w
state_dict = out_sd
quant_metadata = {"layers": layers}
else:
quant_metadata = json.loads(metadata["_quantization_metadata"])
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata

View File

@ -5,20 +5,19 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any
from typing import Any, Dict
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: dict[str, Any] = {
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
}
def get_connection_feature(
sockets_metadata: dict[str, dict[str, Any]],
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
@ -42,7 +41,7 @@ def get_connection_feature(
def supports_feature(
sockets_metadata: dict[str, dict[str, Any]],
sockets_metadata: Dict[str, Dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
@ -60,7 +59,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> dict[str, Any]:
def get_server_features() -> Dict[str, Any]:
"""
Get the server's feature flags.

View File

@ -1,4 +1,4 @@
from typing import NamedTuple
from typing import Type, List, NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: type[ComfyAPIBase]
api_class: Type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str)
registered_versions: list[ComfyAPIWithVersion] = []
registered_versions: List[ComfyAPIWithVersion] = []
def register_versions(versions: list[ComfyAPIWithVersion]):
def register_versions(versions: List[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> list[ComfyAPIWithVersion]:
def get_all_versions() -> List[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""

View File

@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, get_origin, get_args, get_type_hints
from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod
def _generate_imports(
cls, async_class: type, type_tracker: TypeTracker
cls, async_class: Type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports
@classmethod
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub(
cls,
name: str,
attr: type,
attr: Type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed
@classmethod
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc())
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
"""
Creates a sync version of an async class

View File

@ -1,4 +1,4 @@
from typing import TypeVar
from typing import Type, TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
)
return cls._instances[cls]
def inject_instance(cls: type[T], instance: T) -> None:
def inject_instance(cls: Type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: type[T], *args, **kwargs) -> T:
def get_instance(cls: Type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""

View File

@ -1,15 +1,15 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import Type, TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io_public as io
from . import _ui_public as ui
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
Called when an extension is loaded.
This should be used to initialize any global resources needed by the extension.
This should be used to initialize any global resources neeeded by the extension.
"""
@abstractmethod
@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui

View File

@ -1,5 +1,5 @@
import torch
from typing import TypedDict, Optional
from typing import TypedDict, List, Optional
ImageInput = torch.Tensor
"""
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[list[int]]
batch_index: Optional[List[int]]

View File

@ -4,7 +4,7 @@ from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
from .._util import VideoContainer, VideoCodec, VideoComponents
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
class VideoInput(ABC):
"""

View File

@ -3,14 +3,14 @@ from av.container import InputContainer
from av.subtitles.stream import SubtitleStream
from fractions import Fraction
from typing import Optional
from .._input import AudioInput, VideoInput
from comfy_api.latest._input import AudioInput, VideoInput
import av
import io
import json
import numpy as np
import math
import torch
from .._util import VideoContainer, VideoCodec, VideoComponents
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
def container_to_output_format(container_format: str | None) -> str | None:
@ -337,7 +337,7 @@ class VideoFromComponents(VideoInput):
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
if format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams

View File

@ -4,8 +4,7 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final
@ -26,7 +25,7 @@ if TYPE_CHECKING:
from comfy_api.input import VideoInput
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from ._resources import Resources, ResourcesLocal
from comfy_api.latest._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL
@ -151,9 +150,6 @@ class _IO_V3:
def __init__(self):
pass
def validate(self):
pass
@property
def io_type(self):
return self.Parent.io_type
@ -186,9 +182,6 @@ class Input(_IO_V3):
def get_io_type(self):
return _StringIOType(self.io_type)
def get_all(self) -> list[Input]:
return [self]
class WidgetInput(Input):
'''
Base class for a V3 Input with widget.
@ -568,8 +561,6 @@ class Conditioning(ComfyTypeIO):
'''Used by WAN Camera.'''
time_dim_concat: NotRequired[torch.Tensor]
'''Used by WAN Phantom Subject.'''
time_dim_replace: NotRequired[torch.Tensor]
'''Used by Kandinsky5 I2V.'''
CondList = list[tuple[torch.Tensor, PooledDict]]
Type = CondList
@ -774,13 +765,6 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -830,61 +814,13 @@ class MultiType:
else:
return super().as_dict()
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
self.template_id = template_id
# account for syntactic sugar
if not isinstance(allowed_types, Iterable):
allowed_types = [allowed_types]
for t in allowed_types:
if not isinstance(t, type):
if not isinstance(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
else:
if not issubclass(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
self.allowed_types = allowed_types
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
}
class Input(Input):
def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class DynamicInput(Input, ABC):
'''
Abstract class for dynamic input registration.
'''
@abstractmethod
def get_dynamic(self) -> list[Input]:
return []
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
pass
...
class DynamicOutput(Output, ABC):
'''
@ -894,223 +830,99 @@ class DynamicOutput(Output, ABC):
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
@abstractmethod
def get_dynamic(self) -> list[Output]:
return []
...
@comfytype(io_type="COMFY_AUTOGROW_V3")
class Autogrow(ComfyTypeI):
Type = dict[str, Any]
_MaxNames = 100 # NOTE: max 100 names for sanity
class _AutogrowTemplate:
def __init__(self, input: Input):
# dynamic inputs are not allowed as the template input
assert(not isinstance(input, DynamicInput))
self.input = copy.copy(input)
if isinstance(self.input, WidgetInput):
self.input.force_input = True
self.names: list[str] = []
self.cached_inputs = {}
def _create_input(self, input: Input, name: str):
new_input = copy.copy(self.input)
new_input.id = name
return new_input
def _create_cached_inputs(self):
for name in self.names:
self.cached_inputs[name] = self._create_input(self.input, name)
def get_all(self) -> list[Input]:
return list(self.cached_inputs.values())
def as_dict(self):
return prune_dict({
"input": create_input_dict_v1([self.input]),
})
def validate(self):
self.input.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
real_inputs = []
for name, input in self.cached_inputs.items():
if name in live_inputs:
real_inputs.append(input)
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input)
self.prefix = prefix
assert(min >= 0)
assert(max >= 1)
assert(max <= Autogrow._MaxNames)
self.min = min
self.max = max
self.names = [f"{self.prefix}{i}" for i in range(self.max)]
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"prefix": self.prefix,
"min": self.min,
"max": self.max,
})
class TemplateNames(_AutogrowTemplate):
def __init__(self, input: Input, names: list[str], min: int=1):
super().__init__(input)
self.names = names[:Autogrow._MaxNames]
assert(min >= 0)
self.min = min
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"names": self.names,
"min": self.min,
})
class AutogrowDynamic(ComfyTypeI):
Type = list[Any]
class Input(DynamicInput):
def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template_input = template_input
if min is not None:
assert(min >= 1)
if max is not None:
assert(max >= 1)
self.min = min
self.max = max
def get_dynamic(self) -> list[Input]:
curr_count = 1
new_inputs = []
for i in range(self.min):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = self.optional or new_input.optional
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
# pretend to expand up to max
for i in range(curr_count-1, self.max):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = True
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
return new_inputs
@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
class ComboDynamic(ComfyTypeI):
class Input(DynamicInput):
def __init__(self, id: str):
pass
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
self.template_id = template_id
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": "".join(t.io_type for t in self.allowed_types),
}
class Input(DynamicInput):
def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
def get_dynamic(self) -> list[Input]:
return [self]
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
def get_dynamic(self) -> list[Input]:
return self.template.get_all()
class Output(DynamicOutput):
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def get_all(self) -> list[Input]:
return [self] + self.template.get_all()
def validate(self):
self.template.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
curr_prefix = f"{curr_prefix}{self.id}."
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
for inner_dict in d.values():
if self.id in inner_dict:
del inner_dict[self.id]
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI):
Type = dict[str, Any]
class Option:
def __init__(self, key: str, inputs: list[Input]):
self.key = key
self.inputs = inputs
def as_dict(self):
return {
"key": self.key,
"inputs": create_input_dict_v1(self.inputs),
}
class Input(DynamicInput):
def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.options = options
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
return [input for option in self.options for input in option.inputs]
def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs]
def get_dynamic(self) -> list[Output]:
return [self]
def as_dict(self):
return super().as_dict() | prune_dict({
"options": [o.as_dict() for o in self.options],
"template": self.template.as_dict(),
})
def validate(self):
# make sure all nested inputs are validated
for option in self.options:
for input in option.inputs:
input.validate()
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
class DynamicSlot(ComfyTypeI):
Type = dict[str, Any]
class Input(DynamicInput):
def __init__(self, slot: Input, inputs: list[Input],
display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
assert(not isinstance(slot, DynamicInput))
self.slot = copy.copy(slot)
self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
optional = True
self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
self.inputs = inputs
self.force_input = None
# force widget inputs to have no widgets, otherwise this would be awkward
if isinstance(self.slot, WidgetInput):
self.force_input = True
self.slot.force_input = True
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
def get_dynamic(self) -> list[Input]:
return [self.slot] + self.inputs
def get_all(self) -> list[Input]:
return [self] + [self.slot] + self.inputs
def as_dict(self):
return super().as_dict() | prune_dict({
"slotType": str(self.slot.get_io_type()),
"inputs": create_input_dict_v1(self.inputs),
"forceInput": self.force_input,
})
def validate(self):
self.slot.validate()
for input in self.inputs:
input.validate()
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
dynamic = d.setdefault("dynamic_paths", {})
if self is not None:
dynamic[self.id] = f"{curr_prefix}{self.id}"
for i in inputs:
if not isinstance(i, DynamicInput):
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
class V3Data(TypedDict):
hidden_inputs: dict[str, Any]
dynamic_paths: dict[str, Any]
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
@ -1172,7 +984,6 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@ -1208,9 +1019,9 @@ class Schema:
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
inputs: list[Input] = field(default_factory=list)
outputs: list[Output] = field(default_factory=list)
hidden: list[Hidden] = field(default_factory=list)
inputs: list[Input]=None
outputs: list[Output]=None
hidden: list[Hidden]=None
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
@ -1250,11 +1061,7 @@ class Schema:
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
nested_inputs: list[Input] = []
if self.inputs is not None:
for input in self.inputs:
nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids)
output_set = set(output_ids)
@ -1270,13 +1077,6 @@ class Schema:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
# validate inputs and outputs
if self.inputs is not None:
for input in self.inputs:
input.validate()
if self.outputs is not None:
for output in self.outputs:
output.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@ -1302,10 +1102,19 @@ class Schema:
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
def get_v1_info(self, cls) -> NodeInfoV1:
# get V1 inputs
input = create_input_dict_v1(self.inputs, live_inputs)
input = {
"required": {}
}
if self.inputs:
for i in self.inputs:
if isinstance(i, DynamicInput):
dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1314,24 +1123,12 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
output_matchtypes = []
any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
# special handling for MatchType
if isinstance(o, MatchType.Output):
output_matchtypes.append(o.template.template_id)
any_matchtypes = True
else:
output_matchtypes.append(None)
# clear out lists that are all None
if not any_matchtypes:
output_matchtypes = None
info = NodeInfoV1(
input=input,
@ -1340,7 +1137,6 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@ -1386,57 +1182,16 @@ class Schema:
return info
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
input = {
"required": {}
}
add_to_input_dict_v1(input, inputs, live_inputs)
return input
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
def add_to_dict_v1(i: Input, input: dict):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
as_dict.pop("optional", None)
if dynamic_dict is None:
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
d.setdefault(key, {})[i.id] = value
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
if paths is None:
return values
values = values.copy()
result = {}
for key, path in paths.items():
parts = path.split(".")
current = result
for i, p in enumerate(parts):
is_last = (i == len(parts) - 1)
if is_last:
current[p] = values.pop(key, None)
else:
current = current.setdefault(p, {})
values.update(result)
return values
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@ -1556,12 +1311,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
return type_clone
@final
@ -1678,18 +1433,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls, live_inputs)
info = schema.get_v1_info(cls)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
v3_data: V3Data = {}
dynamic = input.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
return input, schema, v3_data
return input, schema
return input
@final
@ -1762,7 +1513,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError
@classmethod
def validate_inputs(cls, **kwargs) -> bool | str:
def validate_inputs(cls, **kwargs) -> bool:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@ -1822,7 +1573,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
return cls(*args, ui=ui, expand=expand)
return cls(args=args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any:
return self.args[index]
@ -1877,7 +1628,6 @@ __all__ = [
"StyleModel",
"Gligen",
"UpscaleModel",
"LatentUpscaleModel",
"Audio",
"Video",
"SVG",
@ -1901,11 +1651,6 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
"Tracks",
# Dynamic Types
"MatchType",
# "DynamicCombo",
# "Autogrow",
# Other classes
"HiddenHolder",
"Hidden",
@ -1916,5 +1661,4 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
]

View File

@ -1 +0,0 @@
from ._io import * # noqa: F403

View File

@ -3,8 +3,8 @@ from __future__ import annotations
import json
import os
import random
import uuid
from io import BytesIO
from typing import Type
import av
import numpy as np
@ -21,7 +21,7 @@ import folder_paths
# used for image preview
from comfy.cli_args import args
from ._io import ComfyNode, FolderType, Image, _UIOutput
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
class SavedResult(dict):
@ -82,7 +82,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -95,7 +95,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -120,7 +120,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
@ -136,7 +136,7 @@ class ImageSaveHelper:
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -154,7 +154,7 @@ class ImageSaveHelper:
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
@ -168,7 +168,7 @@ class ImageSaveHelper:
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -190,7 +190,7 @@ class ImageSaveHelper:
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
@ -208,7 +208,7 @@ class ImageSaveHelper:
images,
filename_prefix: str,
folder_type: FolderType,
cls: type[ComfyNode] | None,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -237,7 +237,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: type[ComfyNode] | None,
cls: Type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -266,7 +266,7 @@ class AudioSaveHelper:
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: type[ComfyNode] | None,
cls: Type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
@ -318,10 +318,9 @@ class AudioSaveHelper:
for key, value in metadata.items():
output_container.metadata[key] = value
layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@ -333,7 +332,7 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@ -342,12 +341,12 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
out_stream = output_container.add_stream("flac", rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout=layout,
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = sample_rate
frame.pts = 0
@ -371,7 +370,7 @@ class AudioSaveHelper:
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
@ -387,7 +386,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -411,7 +410,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),
@ -437,19 +436,9 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
self.bg_image_path = None
bg_image = kwargs.get("bg_image", None)
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = PILImage.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
self.bg_image_path = f"temp/{filename}"
def as_dict(self):
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
return {"result": [self.model_file, self.camera_info]}
class PreviewText(_UIOutput):

View File

@ -1 +0,0 @@
from ._ui import * # noqa: F403

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum
from fractions import Fraction
from typing import Optional
from .._input import ImageInput, AudioInput
from comfy_api.latest._input import ImageInput, AudioInput
class VideoCodec(str, Enum):
AUTO = "auto"

View File

@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@ -42,8 +42,4 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
"io",
"IO",
"ui",
"UI",
]

View File

@ -2,8 +2,9 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: list[type[ComfyAPIBase]] = [
supported_versions: List[Type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,

View File

@ -1,144 +0,0 @@
from typing import Literal
from pydantic import BaseModel, Field
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
size: str | None = Field(None)
seed: int | None = Field(0, ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str | None = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: str | None = Field("adaptive")
seed: int | None = Field(..., ge=0, le=2147483647)
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
watermark: bool | None = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(...)
response_format: str = Field("url")
image: list[str] | None = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
"seedance-1-0-pro-fast-251015": {
"480p": 50,
"720p": 65,
"1080p": 100,
},
}

View File

@ -84,7 +84,15 @@ class GeminiSystemInstructionContent(BaseModel):
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
class GeminiFunctionDeclaration(BaseModel):

View File

@ -46,59 +46,21 @@ class TaskStatusVideoResult(BaseModel):
url: str | None = Field(None, description="URL for generated video")
class TaskStatusImageResult(BaseModel):
index: int = Field(..., description="Image Number0-9")
url: str = Field(..., description="URL for generated image")
class TaskStatusResults(BaseModel):
class TaskStatusVideoResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
class TaskStatusVideoResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: TaskStatusResults | None = Field(None)
task_result: TaskStatusVideoResults | None = Field(None)
class TaskStatusResponse(BaseModel):
class TaskStatusVideoResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
image: str = Field(...)
class OmniProImageRequest(BaseModel):
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
data: TaskStatusVideoResponseData | None = Field(None)

View File

@ -0,0 +1,100 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,17 +5,11 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
@ -67,20 +61,14 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
@ -117,11 +105,6 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
@ -159,7 +142,6 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -174,7 +156,6 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -192,7 +173,6 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -239,24 +219,14 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
class TripoTaskRequest(RootModel):
root: Union[

View File

@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = Field(None)
videos: Optional[list[Video]] = None
class VeoGenVidPollResponse(BaseModel):

View File

@ -1,27 +1,13 @@
import logging
import math
from enum import Enum
from typing import Literal, Optional, Union
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bytedance_api import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
VIDEO_TASKS_EXECUTION_TIME,
Image2ImageTaskCreationRequest,
Image2VideoTaskCreationRequest,
ImageTaskCreationResponse,
Seedream4Options,
Seedream4TaskCreationRequest,
TaskCreationResponse,
TaskImageContent,
TaskImageContentUrl,
TaskStatusResponse,
TaskTextContent,
Text2ImageTaskCreationRequest,
Text2VideoTaskCreationRequest,
)
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
@ -43,6 +29,162 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
class Text2ImageModelName(str, Enum):
seedream_3 = "seedream-3-0-t2i-250415"
class Image2ImageModelName(str, Enum):
seededit_3 = "seededit-3-0-i2i-250628"
class Text2VideoModelName(str, Enum):
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
class Image2VideoModelName(str, Enum):
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
seedance_1_pro = "seedance-1-0-pro-250528"
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
class Text2ImageTaskCreationRequest(BaseModel):
model: Text2ImageModelName = Text2ImageModelName.seedream_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
size: Optional[str] = Field(None)
seed: Optional[int] = Field(0, ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Image2ImageTaskCreationRequest(BaseModel):
model: Image2ImageModelName = Image2ImageModelName.seededit_3
prompt: str = Field(...)
response_format: Optional[str] = Field("url")
image: str = Field(..., description="Base64 encoded string or image URL")
size: Optional[str] = Field("adaptive")
seed: Optional[int] = Field(..., ge=0, le=2147483647)
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
watermark: Optional[bool] = Field(True)
class Seedream4Options(BaseModel):
max_images: int = Field(15)
class Seedream4TaskCreationRequest(BaseModel):
model: str = Field("seedream-4-0-250828")
prompt: str = Field(...)
response_format: str = Field("url")
image: Optional[list[str]] = Field(None, description="Image URLs")
size: str = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
sequential_image_generation: str = Field("disabled")
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
watermark: bool = Field(True)
class ImageTaskCreationResponse(BaseModel):
model: str = Field(...)
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
data: list = Field([], description="Contains information about the generated image(s).")
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
class TaskTextContent(BaseModel):
type: str = Field("text")
text: str = Field(...)
class TaskImageContentUrl(BaseModel):
url: str = Field(...)
class TaskImageContent(BaseModel):
type: str = Field("image_url")
image_url: TaskImageContentUrl = Field(...)
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
class Text2VideoTaskCreationRequest(BaseModel):
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
content: list[TaskTextContent] = Field(..., min_length=1)
class Image2VideoTaskCreationRequest(BaseModel):
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
class TaskCreationResponse(BaseModel):
id: str = Field(...)
class TaskStatusError(BaseModel):
code: str = Field(...)
message: str = Field(...)
class TaskStatusResult(BaseModel):
video_url: str = Field(...)
class TaskStatusResponse(BaseModel):
id: str = Field(...)
model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: Optional[TaskStatusError] = Field(None)
content: Optional[TaskStatusResult] = Field(None)
RECOMMENDED_PRESETS = [
("1024x1024 (1:1)", 1024, 1024),
("864x1152 (3:4)", 864, 1152),
("1152x864 (4:3)", 1152, 864),
("1280x720 (16:9)", 1280, 720),
("720x1280 (9:16)", 720, 1280),
("832x1248 (2:3)", 832, 1248),
("1248x832 (3:2)", 1248, 832),
("1512x648 (21:9)", 1512, 648),
("2048x2048 (1:1)", 2048, 2048),
("Custom", None, None),
]
RECOMMENDED_PRESETS_SEEDREAM_4 = [
("2048x2048 (1:1)", 2048, 2048),
("2304x1728 (4:3)", 2304, 1728),
("1728x2304 (3:4)", 1728, 2304),
("2560x1440 (16:9)", 2560, 1440),
("1440x2560 (9:16)", 1440, 2560),
("2496x1664 (3:2)", 2496, 1664),
("1664x2496 (2:3)", 1664, 2496),
("3024x1296 (21:9)", 3024, 1296),
("4096x4096 (1:1)", 4096, 4096),
("Custom", None, None),
]
# The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-lite-i2v-250428": {
"480p": 40,
"720p": 60,
"1080p": 90,
},
"seedance-1-0-pro-250528": {
"480p": 70,
"720p": 85,
"1080p": 115,
},
}
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error:
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
@ -52,6 +194,13 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
return response.data[0]["url"]
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "content") and response.content:
return response.content.video_url
return None
class ByteDanceImageNode(IO.ComfyNode):
@classmethod
@ -62,7 +211,12 @@ class ByteDanceImageNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
IO.Combo.Input(
"model",
options=Text2ImageModelName,
default=Text2ImageModelName.seedream_3,
tooltip="Model name",
),
IO.String.Input(
"prompt",
multiline=True,
@ -181,7 +335,12 @@ class ByteDanceImageEditNode(IO.ComfyNode):
category="api node/image/ByteDance",
description="Edit images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
IO.Combo.Input(
"model",
options=Image2ImageModelName,
default=Image2ImageModelName.seededit_3,
tooltip="Model name",
),
IO.Image.Input(
"image",
tooltip="The base image to edit",
@ -235,7 +394,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
async def execute(
cls,
model: str,
image: Input.Image,
image: torch.Tensor,
prompt: str,
seed: int,
guidance_scale: float,
@ -275,7 +434,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedream-4-5-251128", "seedream-4-0-250828"],
options=["seedream-4-0-250828"],
tooltip="Model name",
),
IO.String.Input(
@ -300,7 +459,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=8,
step=64,
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -309,7 +468,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
default=2048,
min=1024,
max=4096,
step=8,
step=64,
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
optional=True,
),
@ -373,7 +532,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: Input.Image | None = None,
image: torch.Tensor = None,
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
width: int = 2048,
height: int = 2048,
@ -396,18 +555,6 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
raise ValueError(
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
)
out_num_pixels = w * h
mp_provided = out_num_pixels / 1_000_000.0
if "seedream-4-5" in model and out_num_pixels < 3686400:
raise ValueError(
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
f"but {mp_provided:.2f}MP provided."
)
if "seedream-4-0" in model and out_num_pixels < 921600:
raise ValueError(
f"Minimum image resolution that the selected model can generate is 0.92MP, "
f"but {mp_provided:.2f}MP provided."
)
n_input_images = get_number_of_images(image) if image is not None else 0
if n_input_images > 10:
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
@ -460,8 +607,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
options=Text2VideoModelName,
default=Text2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
IO.String.Input(
"prompt",
@ -566,8 +714,9 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
default="seedance-1-0-pro-fast-251015",
options=Image2VideoModelName,
default=Image2VideoModelName.seedance_1_pro,
tooltip="Model name",
),
IO.String.Input(
"prompt",
@ -638,7 +787,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
image: Input.Image,
image: torch.Tensor,
resolution: str,
aspect_ratio: str,
duration: int,
@ -684,8 +833,9 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
options=[model.value for model in Image2VideoModelName],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
),
IO.String.Input(
"prompt",
@ -760,8 +910,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
first_frame: Input.Image,
last_frame: Input.Image,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
resolution: str,
aspect_ratio: str,
duration: int,
@ -818,8 +968,9 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
inputs=[
IO.Combo.Input(
"model",
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
default="seedance-1-0-lite-i2v-250428",
options=[Image2VideoModelName.seedance_1_lite.value],
default=Image2VideoModelName.seedance_1_lite.value,
tooltip="Model name",
),
IO.String.Input(
"prompt",
@ -883,7 +1034,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
cls,
model: str,
prompt: str,
images: Input.Image,
images: torch.Tensor,
resolution: str,
aspect_ratio: str,
duration: int,
@ -918,8 +1069,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
async def process_video_task(
cls: type[IO.ComfyNode],
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None,
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
estimated_duration: Optional[int],
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
@ -934,7 +1085,7 @@ async def process_video_task(
estimated_duration=estimated_duration,
response_model=TaskStatusResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:

View File

@ -13,7 +13,8 @@ import torch
from typing_extensions import override
import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
@ -26,8 +27,6 @@ from comfy_api_nodes.apis.gemini_api import (
GeminiMimeType,
GeminiPart,
GeminiRole,
GeminiSystemInstructionContent,
GeminiTextPart,
Modality,
)
from comfy_api_nodes.util import (
@ -44,14 +43,6 @@ from comfy_api_nodes.util import (
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
"format, intent, or abstraction—as literal visual directives for image composition.\n"
"If a prompt is conversational or lacks specific visual details, "
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
)
class GeminiModel(str, Enum):
@ -77,7 +68,7 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: Input.Image,
images: torch.Tensor,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
@ -163,8 +154,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
@ -286,13 +277,6 @@ class GeminiNode(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.String.Output(),
@ -309,9 +293,7 @@ class GeminiNode(IO.ComfyNode):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
return [
GeminiPart(
inlineData=GeminiInlineData(
@ -361,11 +343,10 @@ class GeminiNode(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: Input.Image | None = None,
images: torch.Tensor | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
@ -382,10 +363,7 @@ class GeminiNode(IO.ComfyNode):
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
# Create response
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -395,8 +373,7 @@ class GeminiNode(IO.ComfyNode):
role=GeminiRole.user,
parts=parts,
)
],
systemInstruction=gemini_system_prompt,
]
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -546,13 +523,6 @@ class GeminiImage(IO.ComfyNode):
"'IMAGE+TEXT' to return both the generated image and a text response.",
optional=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -572,11 +542,10 @@ class GeminiImage(IO.ComfyNode):
prompt: str,
model: str,
seed: int,
images: Input.Image | None = None,
images: torch.Tensor | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
@ -590,10 +559,6 @@ class GeminiImage(IO.ComfyNode):
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -605,7 +570,6 @@ class GeminiImage(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=None if aspect_ratio == "auto" else image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
@ -676,13 +640,6 @@ class GeminiImage2(IO.ComfyNode):
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
),
],
outputs=[
IO.Image.Output(),
@ -705,9 +662,8 @@ class GeminiImage2(IO.ComfyNode):
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: Input.Image | None = None,
images: torch.Tensor | None = None,
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
@ -723,10 +679,6 @@ class GeminiImage2(IO.ComfyNode):
if aspect_ratio != "auto":
image_config.aspectRatio = aspect_ratio
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
@ -738,7 +690,6 @@ class GeminiImage2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,

View File

@ -6,7 +6,6 @@ For source of truth on the allowed permutations of request fields, please refere
import logging
import math
import re
import torch
from typing_extensions import override
@ -50,16 +49,12 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
OmniProFirstLastFrameRequest,
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
TaskStatusResponse,
TextToVideoWithAudioRequest,
TaskStatusVideoResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -105,6 +100,10 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -125,6 +124,8 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -209,36 +210,7 @@ VOICES_CONFIG = {
}
def normalize_omni_prompt_references(prompt: str) -> str:
"""
Rewrites Kling Omni-style placeholders used in the app, like:
@image, @image1, @image2, ... @imageN
@video, @video1, @video2, ... @videoN
into the API-compatible form:
<<<image_1>>>, <<<image_2>>>, ...
<<<video_1>>>, <<<video_2>>>, ...
This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
"""
if not prompt:
return prompt
def _image_repl(match):
return f"<<<image_{match.group('idx') or '1'}>>>"
def _video_repl(match):
return f"<<<video_{match.group('idx') or '1'}>>>"
# (?<!\w) avoids matching e.g. "test@image.com"
# (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusVideoResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -246,9 +218,8 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse,
response_model=TaskStatusVideoResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -479,12 +450,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id
final_response = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@ -748,7 +719,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[8],
default=modes[4],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -830,7 +801,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusResponse,
response_model=TaskStatusVideoResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -893,7 +864,6 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
@ -925,7 +895,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusResponse,
response_model=TaskStatusVideoResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
@ -980,7 +950,6 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration: int,
reference_images: Input.Image,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
@ -993,7 +962,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusResponse,
response_model=TaskStatusVideoResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1054,7 +1023,6 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@ -1077,7 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusResponse,
response_model=TaskStatusVideoResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1135,7 +1103,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
@ -1158,7 +1125,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=TaskStatusResponse,
response_model=TaskStatusVideoResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1171,90 +1138,6 @@ class OmniProEditVideoNode(IO.ComfyNode):
return await finish_omni_video_task(cls, response)
class OmniProImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageNode",
display_name="Kling Omni Image (Pro)",
category="api node/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-image-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the image content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
),
IO.Image.Input(
"reference_images",
tooltip="Up to 10 additional reference images.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
resolution: str,
aspect_ratio: str,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
image_list: list[OmniImageParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 10:
raise ValueError("The maximum number of reference images is 10.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniImageParamImage(image=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
resolution=resolution.lower(),
aspect_ratio=aspect_ratio,
image_list=image_list if image_list else None,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
class KlingCameraControlT2VNode(IO.ComfyNode):
"""
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@ -1324,8 +1207,9 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image(First Frame) to Video",
display_name="Kling Image to Video",
category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1483,7 +1367,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[6],
default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1946,7 +1830,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"model_name",
options=[i.value for i in KlingImageGenModelName],
default="kling-v2",
default="kling-v1",
),
IO.Combo.Input(
"aspect_ratio",
@ -2029,136 +1913,6 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2181,9 +1935,6 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
]

View File

@ -1,9 +1,12 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
@ -23,9 +26,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: int | None = Field(25)
generate_audio: bool | None = Field(True)
image_uri: str | None = Field(None)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
class TextToVideoNode(IO.ComfyNode):
@ -100,7 +103,7 @@ class TextToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@ -150,7 +153,7 @@ class ImageToVideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: Input.Image,
image: torch.Tensor,
model: str,
prompt: str,
duration: int,
@ -180,7 +183,7 @@ class ImageToVideoNode(IO.ComfyNode):
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):

View File

@ -1,8 +1,11 @@
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
@ -58,7 +61,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
@ -79,7 +82,7 @@ def validate_video_to_video_input(video: Input.Video) -> Input.Video:
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
@ -103,7 +106,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
@ -116,7 +119,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
@ -238,7 +241,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
image: Input.Image,
image: torch.Tensor,
prompt: str,
negative_prompt: str,
resolution: str,
@ -359,9 +362,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
prompt: str,
negative_prompt: str,
seed: int,
video: Input.Video | None = None,
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
motion_intensity: Optional[int] = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:

View File

@ -0,0 +1,568 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

View File

@ -11,11 +11,12 @@ User Guides:
"""
from typing import Union, Optional
from typing_extensions import override
from enum import Enum
from typing_extensions import override
import torch
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
@ -43,6 +44,8 @@ from comfy_api_nodes.util import (
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@ -77,7 +80,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -86,13 +89,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> float | None:
) -> Union[float, None]:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
@ -100,7 +103,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
@ -116,8 +119,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
estimated_duration: int | None = None,
) -> InputImpl.VideoFromFile:
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
@ -190,7 +193,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: Input.Image,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
@ -280,7 +283,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: Input.Image,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
@ -378,8 +381,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
async def execute(
cls,
prompt: str,
start_frame: Input.Image,
end_frame: Input.Image,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
@ -464,7 +467,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
cls,
prompt: str,
ratio: str,
reference_image: Input.Image | None = None,
reference_image: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)

View File

@ -102,9 +102,8 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -132,7 +131,6 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
@ -156,7 +154,6 @@ class TripoTextToModelNode(IO.ComfyNode):
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
geometry_quality=geometry_quality,
auto_size=True,
quad=quad,
),
@ -197,7 +194,6 @@ class TripoImageToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -224,7 +220,6 @@ class TripoImageToModelNode(IO.ComfyNode):
orientation=None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -251,7 +246,6 @@ class TripoImageToModelNode(IO.ComfyNode):
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
@ -301,7 +295,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -330,7 +323,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -367,7 +359,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
@ -517,8 +508,6 @@ class TripoRetargetNode(IO.ComfyNode):
options=[
"preset:idle",
"preset:walk",
"preset:run",
"preset:dive",
"preset:climb",
"preset:jump",
"preset:slash",
@ -526,11 +515,6 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hurt",
"preset:fall",
"preset:turn",
"preset:quadruped:walk",
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
],
),
],
@ -579,7 +563,7 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit",
default=-1,
min=-1,
max=2000000,
max=500000,
optional=True,
),
IO.Int.Input(
@ -595,40 +579,6 @@ class TripoConversionNode(IO.ComfyNode):
default="JPEG",
optional=True,
),
IO.Boolean.Input("force_symmetry", default=False, optional=True),
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
IO.Float.Input(
"flatten_bottom_threshold",
default=0.0,
min=0.0,
max=1.0,
optional=True,
),
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
IO.Float.Input(
"scale_factor",
default=1.0,
min=0.0,
optional=True,
),
IO.Boolean.Input("with_animation", default=False, optional=True),
IO.Boolean.Input("pack_uv", default=False, optional=True),
IO.Boolean.Input("bake", default=False, optional=True),
IO.String.Input("part_names", default="", optional=True), # comma-separated list
IO.Combo.Input(
"fbx_preset",
options=["blender", "mixamo", "3dsmax"],
default="blender",
optional=True,
),
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
IO.Combo.Input(
"export_orientation",
options=["align_image", "default"],
default="default",
optional=True,
),
IO.Boolean.Input("animate_in_place", default=False, optional=True),
],
outputs=[],
hidden=[
@ -654,31 +604,12 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id,
format: str,
quad: bool,
force_symmetry: bool,
face_limit: int,
flatten_bottom: bool,
flatten_bottom_threshold: float,
texture_size: int,
texture_format: str,
pivot_to_center_bottom: bool,
scale_factor: float,
with_animation: bool,
pack_uv: bool,
bake: bool,
part_names: str,
fbx_preset: str,
export_vertex_colors: bool,
export_orientation: str,
animate_in_place: bool,
) -> IO.NodeOutput:
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@ -687,22 +618,9 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
force_symmetry=force_symmetry if force_symmetry else None,
face_limit=face_limit if face_limit != -1 else None,
flatten_bottom=flatten_bottom if flatten_bottom else None,
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None,
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
scale_factor=scale_factor if scale_factor != 1.0 else None,
with_animation=with_animation if with_animation else None,
pack_uv=pack_uv if pack_uv else None,
bake=bake if bake else None,
part_names=part_names_list,
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
export_orientation=export_orientation if export_orientation != "default" else None,
animate_in_place=animate_in_place if animate_in_place else None,
),
)
return await poll_until_finished(cls, response, average_duration=30)

View File

@ -1,9 +1,11 @@
import base64
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
@ -230,7 +232,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
@ -429,8 +431,8 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
seed: int,
first_frame: Input.Image,
last_frame: Input.Image,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
model: str,
generate_audio: bool,
):
@ -491,7 +493,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")

View File

@ -47,7 +47,6 @@ from .validation_utils import (
validate_string,
validate_video_dimensions,
validate_video_duration,
validate_video_frame_count,
)
__all__ = [
@ -95,7 +94,6 @@ __all__ = [
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
"validate_video_frame_count",
# Misc functions
"get_fs_object_size",
]

View File

@ -2,8 +2,8 @@ import asyncio
import contextlib
import os
import time
from collections.abc import Callable
from io import BytesIO
from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
@ -35,12 +35,12 @@ def default_base_url() -> str:
async def sleep_with_interrupt(
seconds: float,
node_cls: type[IO.ComfyNode] | None,
label: str | None = None,
start_ts: float | None = None,
estimated_total: int | None = None,
node_cls: Optional[type[IO.ComfyNode]],
label: Optional[str] = None,
start_ts: Optional[float] = None,
estimated_total: Optional[int] = None,
*,
display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
):
"""
Sleep in 1s slices while:
@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
def get_fs_object_size(path_or_object: str | BytesIO) -> int:
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())

View File

@ -4,11 +4,10 @@ import json
import logging
import time
import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Any, Literal, TypeVar
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse
import aiohttp
@ -38,8 +37,8 @@ class ApiEndpoint:
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
query_params: dict[str, Any] | None = None,
headers: dict[str, str] | None = None,
query_params: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
):
self.path = path
self.method = method
@ -53,18 +52,18 @@ class _RequestConfig:
endpoint: ApiEndpoint
timeout: float
content_type: str
data: dict[str, Any] | None
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
data: Optional[dict[str, Any]]
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
multipart_parser: Optional[Callable]
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
estimated_total: int | None = None
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
estimated_total: Optional[int] = None
final_label_on_success: Optional[str] = "Completed"
progress_origin_ts: Optional[float] = None
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
@dataclass
@ -72,10 +71,10 @@ class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
price: float | None = None
estimated_duration: int | None = None
price: Optional[float] = None
estimated_duration: Optional[int] = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: float | None = None # start time of current active interval (None if queued)
active_since: Optional[float] = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@ -88,20 +87,20 @@ async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
response_model: type[M],
price_extractor: Callable[[M | Any], float | None] | None = None,
data: BaseModel | None = None,
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
response_model: Type[M],
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
data: Optional[BaseModel] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
estimated_duration: Optional[int] = None,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
@ -132,22 +131,22 @@ async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
response_model: type[M],
status_extractor: Callable[[M | Any], str | int | None],
progress_extractor: Callable[[M | Any], int | None] | None = None,
price_extractor: Callable[[M | Any], float | None] | None = None,
completed_statuses: list[str | int] | None = None,
failed_statuses: list[str | int] | None = None,
queued_statuses: list[str | int] | None = None,
data: BaseModel | None = None,
response_model: Type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[BaseModel] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
@ -179,22 +178,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
data: dict[str, Any] | BaseModel | None = None,
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Callable | None = None,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: int | None = None,
estimated_duration: Optional[int] = None,
as_binary: bool = False,
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> dict[str, Any] | bytes:
) -> Union[dict[str, Any], bytes]:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@ -230,21 +229,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
status_extractor: Callable[[dict[str, Any]], str | int | None],
progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
completed_statuses: list[str | int] | None = None,
failed_statuses: list[str | int] | None = None,
queued_statuses: list[str | int] | None = None,
data: dict[str, Any] | BaseModel | None = None,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
@ -262,7 +261,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: int | None = None
last_progress: Optional[int] = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
@ -421,10 +420,10 @@ async def poll_op_raw(
def _display_text(
node_cls: type[IO.ComfyNode],
text: str | None,
text: Optional[str],
*,
status: str | int | None = None,
price: float | None = None,
status: Optional[Union[str, int]] = None,
price: Optional[float] = None,
) -> None:
display_lines: list[str] = []
if status:
@ -441,13 +440,13 @@ def _display_text(
def _display_time_progress(
node_cls: type[IO.ComfyNode],
status: str | int | None,
status: Optional[Union[str, int]],
elapsed_seconds: int,
estimated_total: int | None = None,
estimated_total: Optional[int] = None,
*,
price: float | None = None,
is_queued: bool | None = None,
processing_elapsed_seconds: int | None = None,
price: Optional[float] = None,
is_queued: Optional[bool] = None,
processing_elapsed_seconds: Optional[int] = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@ -489,7 +488,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@ -535,9 +534,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
data: dict[str, Any] | None,
files: dict[str, Any] | list[tuple[str, Any]] | None,
) -> dict[str, Any] | str | None:
data: Optional[dict[str, Any]],
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
) -> Optional[Union[dict[str, Any], str]]:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
@ -587,13 +586,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
final_elapsed_seconds: Optional[int] = None
extracted_price: Optional[float] = None
while True:
attempt += 1
stop_event = asyncio.Event()
monitor_task: asyncio.Task | None = None
sess: aiohttp.ClientSession | None = None
monitor_task: Optional[asyncio.Task] = None
sess: Optional[aiohttp.ClientSession] = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@ -888,7 +887,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
)
def _validate_or_raise(response_model: type[M], payload: Any) -> M:
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
@ -903,9 +902,9 @@ def _validate_or_raise(response_model: type[M], payload: Any) -> M:
def _wrap_model_extractor(
response_model: type[M],
extractor: Callable[[M], Any] | None,
) -> Callable[[dict[str, Any]], Any] | None:
response_model: Type[M],
extractor: Optional[Callable[[M], Any]],
) -> Optional[Callable[[dict[str, Any]], Any]]:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@ -930,10 +929,10 @@ def _wrap_model_extractor(
return _wrapped
def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
if not values:
return set()
out: set[str | int] = set()
out: set[Union[str, int]] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
@ -941,7 +940,7 @@ def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
return out
def _normalize_status_value(val: str | int | None) -> str | int | None:
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
if isinstance(val, str):
return val.strip().lower()
return val

View File

@ -4,6 +4,7 @@ import math
import mimetypes
import uuid
from io import BytesIO
from typing import Optional
import av
import numpy as np
@ -11,7 +12,8 @@ import torch
from PIL import Image
from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl, Types
from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension
@ -55,7 +57,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio(
image: torch.Tensor,
name: str | None = None,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
@ -175,8 +177,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
def video_to_base64_string(
video: Input.Video,
container_format: Types.VideoContainer | None = None,
codec: Types.VideoCodec | None = None,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
@ -187,11 +189,12 @@ def video_to_base64_string(
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = BytesIO()
video.save_to(
video_bytes_io,
format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
)
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")

View File

@ -3,15 +3,15 @@ import contextlib
import uuid
from io import BytesIO
from pathlib import Path
from typing import IO
from typing import IO, Optional, Union
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from comfy_api.latest import InputImpl
from . import request_logger
from ._helpers import (
@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
dest: BytesIO | IO[bytes] | str | Path | None,
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
*,
timeout: float | None = None,
timeout: Optional[float] = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
@ -71,10 +71,10 @@ async def download_url_to_bytesio(
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
session: aiohttp.ClientSession | None = None
stop_evt: asyncio.Event | None = None
monitor_task: asyncio.Task | None = None
req_task: asyncio.Task | None = None
session: Optional[aiohttp.ClientSession] = None
stop_evt: Optional[asyncio.Event] = None
monitor_task: Optional[asyncio.Task] = None
req_task: Optional[asyncio.Task] = None
try:
with contextlib.suppress(Exception):
@ -234,11 +234,11 @@ async def download_url_to_video_output(
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
) -> InputImpl.VideoFromFile:
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return InputImpl.VideoFromFile(result)
return VideoFromFile(result)
async def download_url_as_bytesio(

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import datetime
import hashlib
import json

View File

@ -4,13 +4,15 @@ import logging
import time
import uuid
from io import BytesIO
from typing import Optional
from urllib.parse import urlparse
import aiohttp
import torch
from pydantic import BaseModel, Field
from comfy_api.latest import IO, Input, Types
from comfy_api.latest import IO, Input
from comfy_api.util import VideoCodec, VideoContainer
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
@ -30,7 +32,7 @@ from .conversions import (
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: str | None = Field(
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@ -54,7 +56,7 @@ async def upload_images_to_comfyapi(
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
"""
# if batched, try to upload each file if max_images is greater than 0
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
@ -98,9 +100,9 @@ async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
container: Types.VideoContainer = Types.VideoContainer.MP4,
codec: Types.VideoCodec = Types.VideoCodec.H264,
max_duration: int | None = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
wait_label: str | None = "Uploading",
) -> str:
"""
@ -218,7 +220,7 @@ async def upload_file(
return
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
sess: Optional[aiohttp.ClientSession] = None
try:
try:
request_logger.log_request_response(

View File

@ -1,7 +1,9 @@
import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
@ -16,10 +18,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
def validate_image_dimensions(
image: torch.Tensor,
min_width: int | None = None,
max_width: int | None = None,
min_height: int | None = None,
max_height: int | None = None,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
max_height: Optional[int] = None,
):
height, width = get_image_dimensions(image)
@ -35,8 +37,8 @@ def validate_image_dimensions(
def validate_image_aspect_ratio(
image: torch.Tensor,
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
@ -52,8 +54,8 @@ def validate_image_aspect_ratio(
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
min_rel: float, # e.g. 0.8
max_rel: float, # e.g. 1.25
min_rel: float, # e.g. 0.8
max_rel: float, # e.g. 1.25
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@ -82,8 +84,8 @@ def validate_images_aspect_ratio_closeness(
def validate_aspect_ratio_string(
aspect_ratio: str,
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@ -95,10 +97,10 @@ def validate_aspect_ratio_string(
def validate_video_dimensions(
video: Input.Video,
min_width: int | None = None,
max_width: int | None = None,
min_height: int | None = None,
max_height: int | None = None,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
max_height: Optional[int] = None,
):
try:
width, height = video.get_dimensions()
@ -118,8 +120,8 @@ def validate_video_dimensions(
def validate_video_duration(
video: Input.Video,
min_duration: float | None = None,
max_duration: float | None = None,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
):
try:
duration = video.get_duration()
@ -134,23 +136,6 @@ def validate_video_duration(
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
def validate_video_frame_count(
video: Input.Video,
min_frame_count: int | None = None,
max_frame_count: int | None = None,
):
try:
frame_count = video.get_frame_count()
except Exception as e:
logging.error("Error getting frame count of video: %s", e)
return
if min_frame_count is not None and min_frame_count > frame_count:
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
if max_frame_count is not None and frame_count > max_frame_count:
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
@ -159,8 +144,8 @@ def get_number_of_images(images):
def validate_audio_duration(
audio: Input.Audio,
min_duration: float | None = None,
max_duration: float | None = None,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
@ -192,7 +177,7 @@ def validate_string(
)
def validate_container_format_is_mp4(video: Input.Video) -> None:
def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@ -209,8 +194,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
def _assert_ratio_bounds(
ar: float,
*,
min_ratio: tuple[float, float] | None = None,
max_ratio: tuple[float, float] | None = None,
min_ratio: Optional[tuple[float, float]] = None,
max_ratio: Optional[tuple[float, float]] = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""

View File

@ -1,5 +1,4 @@
from __future__ import annotations
from comfy_api.latest import IO
def validate_node_input(
@ -24,11 +23,6 @@ def validate_node_input(
if not received_type != input_type:
return True
# If the received type or input_type is a MatchType, we can return True immediately;
# validation for this is handled by the frontend
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
return True
# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False

View File

@ -6,80 +6,65 @@ import torch
import comfy.model_management
import folder_paths
import os
import io
import json
import random
import hashlib
import node_helpers
import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, UI
from comfy.cli_args import args
from comfy.comfy_types import FileLocator
class EmptyLatentAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[IO.Latent.Output()],
)
class EmptyLatentAudio:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
def INPUT_TYPES(s):
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
length = round((seconds * 44100 / 2048) / 2) * 2
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return IO.NodeOutput({"samples":latent, "type": "audio"})
latent = torch.zeros([batch_size, 64, length], device=self.device)
return ({"samples":latent, "type": "audio"}, )
generate = execute # TODO: remove
class ConditioningStableAudio(IO.ComfyNode):
class ConditioningStableAudio:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ConditioningStableAudio",
category="conditioning",
inputs=[
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
],
)
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
}}
@classmethod
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, positive, negative, seconds_start, seconds_total):
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
return IO.NodeOutput(positive, negative)
return (positive, negative)
append = execute # TODO: remove
class VAEEncodeAudio(IO.ComfyNode):
class VAEEncodeAudio:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEEncodeAudio",
display_name="VAE Encode Audio",
category="latent/audio",
inputs=[
IO.Audio.Input("audio"),
IO.Vae.Input("vae"),
],
outputs=[IO.Latent.Output()],
)
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
@classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
CATEGORY = "latent/audio"
def encode(self, vae, audio):
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
@ -87,134 +72,213 @@ class VAEEncodeAudio(IO.ComfyNode):
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
return IO.NodeOutput({"samples":t})
return ({"samples":t}, )
encode = execute # TODO: remove
class VAEDecodeAudio(IO.ComfyNode):
class VAEDecodeAudio:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeAudio",
display_name="VAE Decode Audio",
category="latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "decode"
@classmethod
def execute(cls, vae, samples) -> IO.NodeOutput:
CATEGORY = "latent/audio"
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
decode = execute # TODO: remove
return ({"waveform": audio, "sample_rate": 44100}, )
class SaveAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudio",
display_name="Save Audio (FLAC)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
# Prepare metadata dictionary
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
# Opus supported sample rates
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
save_flac = execute # TODO: remove
RETURN_TYPES = ()
FUNCTION = "save_flac"
OUTPUT_NODE = True
class SaveAudioMP3(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioMP3",
display_name="Save Audio (MP3)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
CATEGORY = "audio"
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
class SaveAudioMP3:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
save_mp3 = execute # TODO: remove
RETURN_TYPES = ()
FUNCTION = "save_mp3"
OUTPUT_NODE = True
class SaveAudioOpus(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioOpus",
display_name="Save Audio (Opus)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
CATEGORY = "audio"
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
return IO.NodeOutput(
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
save_opus = execute # TODO: remove
RETURN_TYPES = ()
FUNCTION = "save_opus"
OUTPUT_NODE = True
class PreviewAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PreviewAudio",
display_name="Preview Audio",
category="audio",
inputs=[
IO.Audio.Input("audio"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
CATEGORY = "audio"
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
self.output_dir = folder_paths.get_temp_directory()
self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove
def INPUT_TYPES(s):
return {"required":
{"audio": ("AUDIO", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
@ -252,30 +316,26 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
wav = f32_pcm(wav)
return wav, sr
class LoadAudio(IO.ComfyNode):
class LoadAudio:
@classmethod
def define_schema(cls):
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return IO.Schema(
node_id="LoadAudio",
display_name="Load Audio",
category="audio",
inputs=[
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
],
outputs=[IO.Audio.Output()],
)
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return IO.NodeOutput(audio)
return (audio, )
@classmethod
def fingerprint_inputs(cls, audio):
def IS_CHANGED(s, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
@ -283,69 +343,46 @@ class LoadAudio(IO.ComfyNode):
return m.digest().hex()
@classmethod
def validate_inputs(cls, audio):
def VALIDATE_INPUTS(s, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
load = execute # TODO: remove
class RecordAudio(IO.ComfyNode):
class RecordAudio:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RecordAudio",
display_name="Record Audio",
category="audio",
inputs=[
IO.Custom("AUDIO_RECORD").Input("audio"),
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(s):
return {"required": {"audio": ("AUDIO_RECORD", {})}}
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return IO.NodeOutput(audio)
load = execute # TODO: remove
return (audio, )
class TrimAudioDuration(IO.ComfyNode):
class TrimAudioDuration:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TrimAudioDuration",
display_name="Trim Audio Duration",
description="Trim audio tensor into chosen time range.",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.Float.Input(
"start_index",
default=0.0,
min=-0xffffffffffffffff,
max=0xffffffffffffffff,
step=0.01,
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
),
IO.Float.Input(
"duration",
default=60.0,
min=0.0,
step=0.01,
tooltip="Duration in seconds",
),
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
},
}
@classmethod
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
FUNCTION = "trim"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Trim audio tensor into chosen time range."
def trim(self, audio, start_index, duration):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
@ -362,30 +399,23 @@ class TrimAudioDuration(IO.ComfyNode):
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
trim = execute # TODO: remove
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
class SplitAudioChannels(IO.ComfyNode):
class SplitAudioChannels:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SplitAudioChannels",
display_name="Split Audio Channels",
description="Separates the audio into left and right channels.",
category="audio",
inputs=[
IO.Audio.Input("audio"),
],
outputs=[
IO.Audio.Output(display_name="left"),
IO.Audio.Output(display_name="right"),
],
)
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
}}
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
RETURN_TYPES = ("AUDIO", "AUDIO")
RETURN_NAMES = ("left", "right")
FUNCTION = "separate"
CATEGORY = "audio"
DESCRIPTION = "Separates the audio into left and right channels."
def separate(self, audio):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
@ -395,9 +425,7 @@ class SplitAudioChannels(IO.ComfyNode):
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
separate = execute # TODO: remove
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@ -415,29 +443,21 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
return waveform_1, waveform_2, output_sample_rate
class AudioConcat(IO.ComfyNode):
class AudioConcat:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="AudioConcat",
display_name="Audio Concat",
description="Concatenates the audio1 to audio2 in the specified direction.",
category="audio",
inputs=[
IO.Audio.Input("audio1"),
IO.Audio.Input("audio2"),
IO.Combo.Input(
"direction",
options=['after', 'before'],
default="after",
tooltip="Whether to append audio2 after or before audio1.",
)
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(s):
return {"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
}}
@classmethod
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
RETURN_TYPES = ("AUDIO",)
FUNCTION = "concat"
CATEGORY = "audio"
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
def concat(self, audio1, audio2, direction):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -457,33 +477,26 @@ class AudioConcat(IO.ComfyNode):
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
concat = execute # TODO: remove
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
class AudioMerge(IO.ComfyNode):
class AudioMerge:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="AudioMerge",
display_name="Audio Merge",
description="Combine two audio tracks by overlaying their waveforms.",
category="audio",
inputs=[
IO.Audio.Input("audio1"),
IO.Audio.Input("audio2"),
IO.Combo.Input(
"merge_method",
options=["add", "mean", "subtract", "multiply"],
tooltip="The method used to combine the audio waveforms.",
)
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(cls):
return {
"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
},
}
@classmethod
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
FUNCTION = "merge"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
def merge(self, audio1, audio2, merge_method):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@ -517,110 +530,85 @@ class AudioMerge(IO.ComfyNode):
if max_val > 1.0:
waveform = waveform / max_val
return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
merge = execute # TODO: remove
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
class AudioAdjustVolume(IO.ComfyNode):
class AudioAdjustVolume:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="AudioAdjustVolume",
display_name="Audio Adjust Volume",
category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.Int.Input(
"volume",
default=1,
min=-100,
max=100,
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
)
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
}}
@classmethod
def execute(cls, audio, volume) -> IO.NodeOutput:
RETURN_TYPES = ("AUDIO",)
FUNCTION = "adjust_volume"
CATEGORY = "audio"
def adjust_volume(self, audio, volume):
if volume == 0:
return IO.NodeOutput(audio)
return (audio,)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
adjust_volume = execute # TODO: remove
return ({"waveform": waveform, "sample_rate": sample_rate},)
class EmptyAudio(IO.ComfyNode):
class EmptyAudio:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="EmptyAudio",
display_name="Empty Audio",
category="audio",
inputs=[
IO.Float.Input(
"duration",
default=60.0,
min=0.0,
max=0xffffffffffffffff,
step=0.01,
tooltip="Duration of the empty audio clip in seconds",
),
IO.Int.Input(
"sample_rate",
default=44100,
tooltip="Sample rate of the empty audio clip.",
min=1,
max=192000,
),
IO.Int.Input(
"channels",
default=2,
min=1,
max=2,
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
),
],
outputs=[IO.Audio.Output()],
)
def INPUT_TYPES(s):
return {"required": {
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
}}
@classmethod
def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
RETURN_TYPES = ("AUDIO",)
FUNCTION = "create_empty_audio"
CATEGORY = "audio"
def create_empty_audio(self, duration, sample_rate, channels):
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
create_empty_audio = execute # TODO: remove
return ({"waveform": waveform, "sample_rate": sample_rate},)
class AudioExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
EmptyLatentAudio,
VAEEncodeAudio,
VAEDecodeAudio,
SaveAudio,
SaveAudioMP3,
SaveAudioOpus,
LoadAudio,
PreviewAudio,
ConditioningStableAudio,
RecordAudio,
TrimAudioDuration,
SplitAudioChannels,
AudioConcat,
AudioMerge,
AudioAdjustVolume,
EmptyAudio,
]
NODE_CLASS_MAPPINGS = {
"EmptyLatentAudio": EmptyLatentAudio,
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"SaveAudioMP3": SaveAudioMP3,
"SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio,
"TrimAudioDuration": TrimAudioDuration,
"SplitAudioChannels": SplitAudioChannels,
"AudioConcat": AudioConcat,
"AudioMerge": AudioMerge,
"AudioAdjustVolume": AudioAdjustVolume,
"EmptyAudio": EmptyAudio,
}
async def comfy_entrypoint() -> AudioExtension:
return AudioExtension()
NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentAudio": "Empty Latent Audio",
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
"TrimAudioDuration": "Trim Audio Duration",
"SplitAudioChannels": "Split Audio Channels",
"AudioConcat": "Audio Concat",
"AudioMerge": "Audio Merge",
"AudioAdjustVolume": "Audio Adjust Volume",
"EmptyAudio": "Empty Audio",
}

View File

@ -26,9 +26,6 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
],
outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."),
@ -37,8 +34,7 @@ class ContextWindowsManualNode(io.ComfyNode):
)
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -47,15 +43,9 @@ class ContextWindowsManualNode(io.ComfyNode):
context_overlap=context_overlap,
context_stride=context_stride,
closed_loop=closed_loop,
dim=dim,
freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows
)
dim=dim)
# make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model)
if freenoise: # no other use for this wrapper at this time
comfy.context_windows.create_sampler_sample_wrapper(model)
return io.NodeOutput(model)
class WanContextWindowsManualNode(ContextWindowsManualNode):
@ -78,18 +68,14 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
]
return schema
@classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
class ContextWindowsExtension(ComfyExtension):

View File

@ -659,31 +659,6 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()]
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@ -1021,7 +996,6 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,

View File

@ -154,13 +154,12 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
display_name="Edit Model Reference Method",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
options=["offset", "index", "uxo/uno"],
),
],
outputs=[

View File

@ -2,8 +2,6 @@
import torch
import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
def Fourier_filter(x, threshold, scale):
# FFT
@ -24,26 +22,21 @@ def Fourier_filter(x, threshold, scale):
return x_filtered.to(x.dtype)
class FreeU(IO.ComfyNode):
class FreeU:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="FreeU",
category="model_patches/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
CATEGORY = "model_patches/unet"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@ -66,31 +59,23 @@ class FreeU(IO.ComfyNode):
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return IO.NodeOutput(m)
return (m, )
patch = execute # TODO: remove
class FreeU_V2(IO.ComfyNode):
class FreeU_V2:
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="FreeU_V2",
category="model_patches/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01),
IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01),
IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}),
"b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}),
"s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}),
"s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@classmethod
def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput:
CATEGORY = "model_patches/unet"
def patch(self, model, b1, b2, s1, s2):
model_channels = model.model.model_config.unet_config["model_channels"]
scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)}
on_cpu_devices = {}
@ -120,19 +105,9 @@ class FreeU_V2(IO.ComfyNode):
m = model.clone()
m.set_model_output_block_patch(output_block_patch)
return IO.NodeOutput(m)
return (m, )
patch = execute # TODO: remove
class FreelunchExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
FreeU,
FreeU_V2,
]
async def comfy_entrypoint() -> FreelunchExtension:
return FreelunchExtension()
NODE_CLASS_MAPPINGS = {
"FreeU": FreeU,
"FreeU_V2": FreeU_V2,
}

View File

@ -1,136 +0,0 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.utils
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class Kandinsky5ImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Kandinsky5ImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty video latent"),
io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond_latent_out = {}
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encoded = vae.encode(start_image[:, :, :, :3])
cond_latent_out["samples"] = encoded
mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high)
reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high)
# normalization
normalized = (source - source_mean) / (source_std + 1e-8)
normalized = normalized * reference_std + reference_mean
return normalized
class NormalizeVideoLatentStart(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="NormalizeVideoLatentStart",
category="conditioning/video_models",
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
inputs=[
io.Latent.Input("latent"),
io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"),
io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"),
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput:
if latent["samples"].shape[2] <= 1:
return io.NodeOutput(latent)
s = latent.copy()
samples = latent["samples"].clone()
first_frames = samples[:, :, :start_frame_count]
reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)]
normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data)
samples[:, :, :start_frame_count] = normalized_first_frames
s["samples"] = samples
return io.NodeOutput(s)
class CLIPTextEncodeKandinsky5(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="CLIPTextEncodeKandinsky5",
category="advanced/conditioning/kandinsky5",
inputs=[
io.Clip.Input("clip"),
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
tokens = clip.tokenize(clip_l)
tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
class Kandinsky5Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
Kandinsky5ImageToVideo,
NormalizeVideoLatentStart,
CLIPTextEncodeKandinsky5,
]
async def comfy_entrypoint() -> Kandinsky5Extension:
return Kandinsky5Extension()

Some files were not shown because too many files have changed in this diff Show More