Compare commits

...

17 Commits

Author SHA1 Message Date
tashiscool
6c7960ea10
Merge 45a2363e6a into f6d5068ac0 2026-05-02 22:35:24 -07:00
Alexis Rolland
f6d5068ac0
Update README (#13679)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Updated the README to include a new screenshot, improved description and add Ernie Image to supported models.
2026-05-03 12:20:17 +08:00
Jukka Seppänen
be95871adc
feat: Gemma4 text generation support (CORE-30) (#13376)
* initial gemma4 support

* parity with reference implementation

outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize

* Cleanup, video fixes

* cleanup, enable fused rms norm by default

* update comment

* Cleanup

* Update sd.py

* Various fixes

* Add fp8 scaled embedding support

* small fixes

* Translate think tokens

* Fix image encoder attention mask type

So it works with basic attention

* Handle thinking tokens different only for Gemma4

* Code cleanup

* Update nodes_textgen.py

* Use embed scale class instead of buffer

Slight difference to HF, but technically more accurate and simpler code

* Default to fused rms_norm

* Update gemma4.py
2026-05-02 22:46:15 -04:00
Alexander Piskun
f756d801a1
[Partner Nodes] Topaz Astra 2 model (#13672)
* feat(api-nodes): add Topaz Astra 2 model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat(api-nodes): make Astra 2 the default Topaz upscaler model

Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so
"Astra 2" appears first, surfacing it as the default selection.

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Marwan Mostafa <marawan206@gmail.com>
2026-05-02 19:29:00 -07:00
Daxiong (Lin)
1d23a875ed
chore: update workflow templates to v0.9.68 (#13678) 2026-05-03 10:06:55 +08:00
comfyanonymous
ef6722f6be
Some cleanups to the load image node. (#13677) 2026-05-02 20:34:27 -04:00
rattus
783782d5d7
Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618)
* mm: Use Aimdo raw allocator for cast buffers

pytorch manages allocation of growing buffers on streams poorly. Pyt
has no windows support for the expandable segments allocator (which is
the right tool for this job), while also segmenting the memory by
stream such that it can be generally re-used. So kick the problem to
aimdo which can just grow a virtual region thats freed per stream.

* plan

* ops: move cpu handler up to the caller

* ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.

* ops: implement block prefetching API

allow a model to construct a prefetch list and operate it for increased
async offload.

* ltxv2: Implement block prefetching

* Implement lora async offload

Implement async offload of loras.
2026-05-02 19:23:24 -04:00
comfyanonymous
3e3ed8cc2a
Add script in AMD portable to launch with dynamic vram. (#13667)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-01 20:19:46 -04:00
comfyanonymous
67f6cb3527
List all the portable downloads in the README section. (#13666) 2026-05-01 20:19:32 -04:00
Alexis Rolland
0230e0e7cc
Adding kijai (#13664)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-05-02 06:37:18 +08:00
Jukka Seppänen
b5921c8ac2
SDPose: resize fix (#13656) 2026-05-01 14:17:25 -07:00
Simon Lui
63103d519e
Remove IPEX and clean up checks and add missing synchronize during empty cache. (#13653) 2026-05-01 14:16:41 -07:00
Alexander Piskun
cf758bd256
chore(api-nodes): increase default timeout for partner API node tasks (#13663)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-05-01 12:48:41 -07:00
Daxiong (Lin)
10b45a71cd
chore: update workflow templates to v0.9.66 (#13662)
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-05-01 12:11:30 -07:00
Alexander Piskun
fa7553138e
chore(api-nodes): remove Moonvalley API nodes (#13659)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-01 11:09:25 -07:00
Tashdid Khan
45a2363e6a fix: remove hardcoded local paths from MPS FP8 tests
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 21:27:35 -05:00
Tashdid Khan
edd44a6874 fix: add CPU fallback for FP8 quantization on MPS (Apple Silicon)
MPS does not support float8_e4m3fn/float8_e5m2 dtypes. When FP8-quantized
models (FLUX, SD3.5, Wan 2.2, LTX-Video) are loaded on Apple Silicon, the
quantization step crashes with:

  TypeError: Trying to convert Float8_e4m3fn to the MPS backend but it does
  not have support for that dtype.

This adds device-aware fallbacks that move tensors to CPU for the FP8
quantization step only. The rest of inference remains on MPS.

Three code paths are patched:
- comfy/float.py: stochastic_rounding() — also fixes the secondary
  "Placeholder storage has not been allocated on MPS device!" error
  caused by torch.Generator being bound to MPS.
- comfy/float.py: stochastic_round_quantize_nvfp4*() — these create
  float8_e4m3fn block scales internally.
- comfy/quant_ops.py: _TensorCoreFP8LayoutBase.quantize() — the
  ck.quantize_per_tensor_fp8 path also fails on MPS.

Fixes: #6995, #9255, #11626, #11817

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-09 21:21:31 -05:00
40 changed files with 2289 additions and 876 deletions

View File

@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
pause

View File

@ -1,2 +1,2 @@
# Admins
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai

View File

@ -1,7 +1,7 @@
<div align="center">
# ComfyUI
**The most powerful and modular visual AI engine and application.**
**The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url]
@ -31,10 +31,15 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/4aab0bef-b413-4595-9766-a2c134676d27" />
</div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started
@ -77,6 +82,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@ -193,13 +199,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
#### All Official Portable Downloads:
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above).
[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?

View File

@ -90,7 +90,6 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):

View File

@ -55,6 +55,11 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.bfloat16:
return value.to(dtype=torch.bfloat16)
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
# MPS does not support FP8 dtypes — perform rounding on CPU and return the result there.
on_mps = value.device.type == "mps"
if on_mps:
value = value.cpu()
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
output = torch.empty_like(value, dtype=dtype)
@ -159,6 +164,12 @@ def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
# MPS does not support FP8 dtypes used for block scales — perform on CPU.
on_mps = x.device.type == "mps"
if on_mps:
x = x.cpu()
per_tensor_scale = per_tensor_scale.cpu() if isinstance(per_tensor_scale, torch.Tensor) else per_tensor_scale
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
@ -179,6 +190,12 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
# MPS does not support FP8 dtypes used for block scales — perform on CPU.
on_mps = x.device.type == "mps"
if on_mps:
x = x.cpu()
per_tensor_scale = per_tensor_scale.cpu() if isinstance(per_tensor_scale, torch.Tensor) else per_tensor_scale
orig_shape = x.shape
# Handle padding

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@ -17,6 +17,7 @@
"""
from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
transformer_options = transformer_options.copy()
transformer_options["prefetch_dynamic_vbars"] = (
self.current_patcher is not None and self.current_patcher.is_dynamic()
)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -112,10 +113,6 @@ if args.directml is not None:
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex # noqa: F401
except:
pass
try:
_ = torch.xpu.device_count()
@ -583,9 +580,6 @@ class LoadedModel:
real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
with torch.no_grad():
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
@ -1182,6 +1176,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@ -1215,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@ -1581,10 +1592,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.get_device_properties(device).has_fp16
return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu():
return True
@ -1650,10 +1658,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
if torch_version_numeric < (2, 3):
return True
else:
return torch.xpu.is_bf16_supported()
return torch.xpu.is_bf16_supported()
if is_ascend_npu():
return True
@ -1784,6 +1789,7 @@ def soft_empty_cache(force=False):
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif is_intel_xpu():
torch.xpu.synchronize()
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()

View File

@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

65
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,65 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device, transformer_options):
if (not transformer_options.get("prefetch_dynamic_vbars", False)
or comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
# FIXME: add n=1 cache hit fast path
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
xfer_dest = None
cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
s._prefetch = prefetch
continue
if not resident:
materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ]
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
return offload_stream
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x
def to_dequant(tensor, dtype):
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
update_weight = signature is not None
update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):
@ -1173,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -83,6 +83,12 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
# MPS does not support FP8 dtypes — move to CPU for quantization.
on_mps = tensor.device.type == "mps"
if on_mps:
tensor = tensor.cpu()
scale = scale.cpu()
if stochastic_rounding > 0:
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.model_patcher
import comfy.lora
@ -1271,6 +1272,9 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd):
@ -1296,6 +1300,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

File diff suppressed because it is too large Load Diff

View File

@ -521,7 +521,7 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -533,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else:
transformer = TransformerBlock
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
@ -850,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device
if stop_tokens is None:
@ -875,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1)
if token_id in stop_tokens:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = super().process_tokens(tokens, device)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -1,152 +0,0 @@
from enum import Enum
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, StrictBytes
class MoonvalleyPromptResponse(BaseModel):
error: Optional[Dict[str, Any]] = None
frame_conditioning: Optional[Dict[str, Any]] = None
id: Optional[str] = None
inference_params: Optional[Dict[str, Any]] = None
meta: Optional[Dict[str, Any]] = None
model_params: Optional[Dict[str, Any]] = None
output_url: Optional[str] = None
prompt_text: Optional[str] = None
status: Optional[str] = None
class MoonvalleyTextToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
75, description='Number of cooldown steps (calculated based on num_frames)'
)
fps: Optional[int] = Field(
24, description='Frames per second of the generated video'
)
guidance_scale: Optional[float] = Field(
10, description='Guidance scale for generation control'
)
height: Optional[int] = Field(
1080, description='Height of the generated video in pixels'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
0, description='Number of warmup steps (calculated based on num_frames)'
)
width: Optional[int] = Field(
1920, description='Width of the generated video in pixels'
)
class MoonvalleyTextToVideoRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class MoonvalleyUploadFileRequest(BaseModel):
file: Optional[StrictBytes] = None
class MoonvalleyUploadFileResponse(BaseModel):
access_url: Optional[str] = None
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
36, description='Number of cooldown steps (calculated based on num_frames)'
)
guidance_scale: Optional[float] = Field(
15, description='Guidance scale for generation control'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
24, description='Number of warmup steps (calculated based on num_frames)'
)
class ControlType(str, Enum):
motion_control = 'motion_control'
pose_control = 'pose_control'
class MoonvalleyVideoToVideoRequest(BaseModel):
control_type: ControlType = Field(
..., description='Supported types for video control'
)
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
prompt_text: str = Field(..., description='Describes the video to generate')
video_url: str = Field(..., description='Url to control video')
webhook_url: Optional[str] = Field(
None, description='Optional webhook URL for notifications'
)

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional
from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -1403,7 +1403,6 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
max_poll_attempts=180,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -1585,7 +1584,6 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=False),
poll_interval=9,
max_poll_attempts=180,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
@ -1907,7 +1905,6 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
status_extractor=lambda r: r.status,
price_extractor=_seedance2_price_extractor(model_id, has_video_input=has_video_input),
poll_interval=9,
max_poll_attempts=180,
)
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))

View File

@ -178,7 +178,6 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
status_extractor=lambda x: x.data.status,
price_extractor=lambda x: request_price,
poll_interval=10.0,
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
@ -324,7 +323,6 @@ class HitPawVideoEnhance(IO.ComfyNode):
status_extractor=lambda x: x.data.status,
price_extractor=lambda x: request_price,
poll_interval=10.0,
max_poll_attempts=320,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))

View File

@ -276,7 +276,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3062,7 +3061,6 @@ class KlingVideoNode(IO.ComfyNode):
cls,
ApiEndpoint(path=poll_path),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3188,7 +3186,6 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))

View File

@ -230,7 +230,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -391,7 +390,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -541,7 +539,6 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
response_model=TaskResponse,
status_extractor=lambda x: x.status,
poll_interval=10.0,
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -782,7 +779,6 @@ class MagnificImageRelightNode(IO.ComfyNode):
response_model=TaskResponse,
status_extractor=lambda x: x.status,
poll_interval=10.0,
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -924,7 +920,6 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
response_model=TaskResponse,
status_extractor=lambda x: x.status,
poll_interval=10.0,
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))

View File

@ -1,534 +0,0 @@
import logging
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.moonvalley import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
trim_video,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_image_dimensions,
validate_string,
)
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
MIN_WIDTH = 300
MIN_HEIGHT = 300
MAX_WIDTH = 10000
MAX_HEIGHT = 10000
MIN_VID_WIDTH = 300
MIN_VID_HEIGHT = 300
MAX_VID_WIDTH = 10000
MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
"""Verifies that the initial response contains a task ID."""
return bool(response.id)
def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
Args:
video: Input video to validate
Returns:
Validated and potentially trimmed video
Raises:
ValueError: If video doesn't meet requirements
MoonvalleyApiError: If video duration is too short
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
validate_container_format_is_mp4(video)
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080),
(1080, 1920),
(1152, 1152),
(1536, 1152),
(1152, 1536),
}
if (width, height) not in supported_resolutions:
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
return _trim_if_too_long(video, duration)
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
return video
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
def parse_control_parameter(value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control",
}
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
return await poll_op(
cls,
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
response_model=MoonvalleyPromptResponse,
status_extractor=lambda r: (r.status if r and r.status else None),
poll_interval=16.0,
max_poll_attempts=240,
)
class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyImg2VideoNode",
display_name="Moonvalley Marey Image to Video",
category="api node/video/Moonvalley Marey",
description="Moonvalley Marey Image to Video Node",
inputs=[
IO.Image.Input(
"image",
tooltip="The reference image used to generate the video",
),
IO.String.Input(
"prompt",
multiline=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
IO.Float.Input(
"prompt_adherence",
default=4.5,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Number of denoising steps",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
),
)
validate_task_creation_response(task_creation_response)
final_response = await get_response(cls, task_creation_response.id)
video = await download_url_to_video_output(final_response.output_url)
return IO.NodeOutput(video)
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyVideo2VideoNode",
display_name="Moonvalley Marey Video to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Describes the video to generate",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=False,
),
IO.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
IO.Combo.Input(
"control_type",
options=["Motion Transfer", "Pose Transfer"],
default="Motion Transfer",
optional=True,
),
IO.Int.Input(
"motion_intensity",
default=100,
min=0,
max=100,
step=1,
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
IO.Int.Input(
"steps",
default=60,
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 2.25}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
seed: int,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
steps=60,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(cls, validated_video)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params["motion_intensity"] = motion_intensity
inference_params = MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyVideoToVideoRequest(
control_type=parse_control_parameter(control_type),
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
),
)
validate_task_creation_response(task_creation_response)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyTxt2VideoNode",
display_name="Moonvalley Marey Text to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
IO.Float.Input(
"prompt_adherence",
default=4.0,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Inference steps",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
)
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
)
validate_task_creation_response(task_creation_response)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MoonvalleyImg2VideoNode,
MoonvalleyTxt2VideoNode,
MoonvalleyVideo2VideoNode,
]
async def comfy_entrypoint() -> MoonvalleyExtension:
return MoonvalleyExtension()

View File

@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
)
UPSCALER_MODELS_MAP = {
"Astra 2": "ast-2",
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}
AST2_MAX_FRAMES = 9000
AST2_MAX_FRAMES_WITH_PROMPT = 450
class TopazImageEnhance(IO.ComfyNode):
@classmethod
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance",
display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input(
"upscaler_model",
options=[
"Starlight (Astra) Fast",
"Starlight (Astra) Creative",
"Starlight Precise 2.5",
],
),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"upscaler_creativity",
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
@ -453,7 +465,350 @@ class TopazVideoEnhance(IO.ComfyNode):
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
max_poll_attempts=320,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazVideoEnhanceV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.DynamicCombo.Input(
"upscaler_model",
options=[
IO.DynamicCombo.Option(
"Astra 2",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Float.Input(
"creativity",
default=0.5,
min=0.0,
max=1.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Creative strength of the upscale.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional descriptive (not instructive) scene prompt."
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
),
IO.Float.Input(
"sharp",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pre-enhance sharpness: "
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
advanced=True,
),
IO.Float.Input(
"realism",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pulls output toward photographic realism."
"Leave at 0 for the model default.",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Fast",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Creative",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creative strength of the upscale.",
),
],
),
IO.DynamicCombo.Option(
"Starlight Precise 2.5",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
),
IO.DynamicCombo.Option("Disabled", []),
],
),
IO.DynamicCombo.Input(
"interpolation_model",
options=[
IO.DynamicCombo.Option("Disabled", []),
IO.DynamicCombo.Option(
"apo-8",
[
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
advanced=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
advanced=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
advanced=True,
),
],
),
],
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=[
"upscaler_model",
"upscaler_model.upscaler_resolution",
"interpolation_model",
]),
expr="""
(
$model := $lookup(widgets, "upscaler_model");
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
$interp := $lookup(widgets, "interpolation_model");
$is4k := $contains($res, "4k");
$hasInterp := $interp != "disabled";
$rates := {
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
"astra 2": {"hd": 1.72, "uhd": 2.85},
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
};
$surcharge := $is4k ? 0.28 : 0.14;
$entry := $lookup($rates, $model);
$base := $is4k ? $entry.uhd : $entry.hd;
$hi := $base + ($hasInterp ? $surcharge : 0);
$model = "disabled"
? {"type":"text","text":"Interpolation only"}
: ($hasInterp
? {"type":"text","text":"~" & $string($base) & "" & $string($hi) & " credits/src frame"}
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_model: dict,
interpolation_model: dict,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
upscaler_choice = upscaler_model["upscaler_model"]
interpolation_choice = interpolation_model["interpolation_model"]
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_choice != "Disabled":
if "1080p" in upscaler_model["upscaler_resolution"]:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
if model_id == "slc-1":
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
isOptimizedMode=True,
)
)
elif model_id == "ast-2":
n_frames = video.get_frame_count()
ast2_prompt = (upscaler_model["prompt"] or "").strip()
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
raise ValueError(
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
)
if n_frames > AST2_MAX_FRAMES:
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
realism = upscaler_model["realism"]
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
prompt=(ast2_prompt or None),
sharp=upscaler_model["sharp"],
realism=(realism if realism > 0 else None),
)
)
else:
filters.append(VideoEnhancementFilter(model=model_id))
if interpolation_choice != "Disabled":
target_frame_rate = interpolation_model["interpolation_frame_rate"]
filters.append(
VideoFrameInterpolationFilter(
model=interpolation_choice,
slowmo=interpolation_model["interpolation_slowmo"],
fps=interpolation_model["interpolation_frame_rate"],
duplicate=interpolation_model["interpolation_duplicate"],
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
@ -464,6 +819,7 @@ class TopazExtension(ComfyExtension):
return [
TopazImageEnhance,
TopazVideoEnhance,
TopazVideoEnhanceV2,
]

View File

@ -38,7 +38,7 @@ async def execute_task(
cls: type[IO.ComfyNode],
vidu_endpoint: str,
payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest,
max_poll_attempts: int = 320,
max_poll_attempts: int = 480,
) -> list[TaskResult]:
task_creation_response = await sync_op(
cls,
@ -1097,7 +1097,6 @@ class ViduExtendVideoNode(IO.ComfyNode):
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"),
images=[image_url] if image_url else None,
),
max_poll_attempts=480,
)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))

View File

@ -818,7 +818,6 @@ class WanReferenceVideoApi(IO.ComfyNode):
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=6,
max_poll_attempts=280,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))

View File

@ -84,7 +84,6 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
response_model=TaskResultResponse,
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
poll_interval=10.0,
max_poll_attempts=480,
)
if final_response.code != 200:
raise ValueError(
@ -156,7 +155,6 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
response_model=TaskResultResponse,
status_extractor=lambda x: "failed" if x.data is None else x.data.status,
poll_interval=10.0,
max_poll_attempts=480,
)
if final_response.code != 200:
raise ValueError(

View File

@ -148,7 +148,7 @@ async def poll_op(
queued_statuses: list[str | int] | None = None,
data: BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
max_poll_attempts: int = 480,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
@ -254,7 +254,7 @@ async def poll_op_raw(
queued_statuses: list[str | int] | None = None,
data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
max_poll_attempts: int = 480,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,

View File

@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode):
total_images = image.shape[0]
captured_feat = None
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
model_w = int(head.heatmap_size[0]) * 4 # 192 * 4 = 768
model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024
def _resize_to_model(imgs):
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
"""Stretch BHWC images to (model_h, model_w), model expects no aspect preservation."""
h, w = imgs.shape[-3], imgs.shape[-2]
scale = min(model_h / h, model_w / w)
sh, sw = int(round(h * scale)), int(round(w * scale))
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
method = "area" if (model_h <= h and model_w <= w) else "bilinear"
chw = imgs.permute(0, 3, 1, 2).float()
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
return padded.permute(0, 2, 3, 1), scale, pt, pl
scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled")
return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0):
"""Remap keypoints from model space back to original image space."""
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
invalid = kp[..., 0] < 0
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
kp[..., 0] = kp[..., 0] / scale_x + offset_x
kp[..., 1] = kp[..., 1] / scale_y + offset_y
kp[invalid] = -1
return kp
@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode):
continue
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
crop_resized, sx, sy = _resize_to_model(crop)
latent_crop = vae.encode(crop_resized)
kp_batch, sc_batch = _run_on_latent(latent_crop)
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1)
img_keypoints.append(kp)
img_scores.append(sc_batch[0])
else:
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
img_resized, sx, sy = _resize_to_model(img)
latent_img = vae.encode(img_resized)
kp_batch, sc_batch = _run_on_latent(latent_img)
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy))
img_scores.append(sc_batch[0])
all_keypoints.append(img_keypoints)
@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode):
else: # full-image mode, batched
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size])
latent_batch = vae.encode(batch_resized)
kp_batch, sc_batch = _run_on_latent(latent_batch)
for kp, sc in zip(kp_batch, sc_batch):
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
all_keypoints.append([_remap_keypoints(kp, sx, sy)])
all_scores.append([sc])
pbar.update(len(kp_batch))
@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode):
scale = min(output_width / crop_w, output_height / crop_h)
scaled_w = int(round(crop_w * scale))
scaled_h = int(round(crop_h * scale))
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled")
pad_left = (output_width - scaled_w) // 2
pad_top = (output_height - scaled_h) // 2
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
else: # "stretch"
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled")
crops.append(resized)
if not crops:

View File

@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
seed=seed
)
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text)
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
class TextgenExtension(ComfyExtension):

View File

@ -15,6 +15,7 @@ import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy.model_prefetch
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy.model_prefetch.cleanup_prefetch_queues()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:

View File

@ -1694,26 +1694,27 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
components = InputImpl.VideoFromFile(image_path).get_components()
if components.images.shape[0] > 0:
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
# This code is left here to handle animated webp which pyav does not support loading
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
@ -1728,25 +1729,15 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
@classmethod
def IS_CHANGED(s, image):

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.15
comfyui-workflow-templates==0.9.65
comfyui-workflow-templates==0.9.68
comfyui-embedded-docs==0.4.4
torch
torchsde

147
tests/test_fp8_mps.py Normal file
View File

@ -0,0 +1,147 @@
"""
Tests for FP8 quantization on MPS (Apple Silicon) devices.
MPS does not natively support float8_e4m3fn or float8_e5m2 dtypes.
These tests verify that:
1. FP8 operations correctly fall back to CPU when on MPS.
2. The round-trip (quantize on CPU -> result on original device) is numerically sound.
3. No "Placeholder storage has not been allocated on MPS device!" errors occur.
"""
import pytest
import torch
import comfy.float
from comfy.quant_ops import TensorCoreFP8E4M3Layout, TensorCoreFP8E5M2Layout
# Skip the entire module if MPS is not available
pytestmark = pytest.mark.skipif(
not torch.backends.mps.is_available(),
reason="MPS backend not available"
)
# ── helpers ──────────────────────────────────────────────────────────────────
def _make_mps_tensor(shape=(256, 256), dtype=torch.float32):
return torch.randn(shape, device="mps", dtype=dtype)
# ── Tests for comfy.float ────────────────────────────────────────────────────
class TestStochasticRoundingMPS:
"""Tests for comfy.float.stochastic_rounding on MPS device."""
def test_stochastic_rounding_fp8_e4m3fn_on_mps(self):
"""stochastic_rounding must not crash when input is on MPS and target dtype is float8_e4m3fn."""
x = _make_mps_tensor((64, 64), dtype=torch.float32)
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
assert result.dtype == torch.float8_e4m3fn
assert result.shape == x.shape
def test_stochastic_rounding_fp8_e5m2_on_mps(self):
"""stochastic_rounding must not crash when input is on MPS and target dtype is float8_e5m2."""
x = _make_mps_tensor((64, 64), dtype=torch.float32)
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e5m2, seed=42)
assert result.dtype == torch.float8_e5m2
assert result.shape == x.shape
def test_stochastic_rounding_fp8_result_on_cpu(self):
"""Result of FP8 rounding from MPS input should be on CPU (since MPS can't hold FP8)."""
x = _make_mps_tensor((32, 32), dtype=torch.float32)
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
# FP8 tensors cannot live on MPS, so result must be on CPU
assert result.device.type == "cpu"
def test_stochastic_rounding_non_fp8_still_works(self):
"""Non-FP8 dtypes on MPS must still work as before (no regression)."""
x = _make_mps_tensor((32, 32), dtype=torch.float32)
r16 = comfy.float.stochastic_rounding(x, dtype=torch.float16, seed=0)
assert r16.dtype == torch.float16
assert r16.device.type == "mps"
rbf16 = comfy.float.stochastic_rounding(x, dtype=torch.bfloat16, seed=0)
assert rbf16.dtype == torch.bfloat16
assert rbf16.device.type == "mps"
def test_stochastic_rounding_fp8_numerical_sanity(self):
"""FP8 round-trip (float32 -> fp8 -> float32) should have bounded error."""
x = torch.randn(128, 128, device="mps", dtype=torch.float32)
x_clamped = torch.clamp(x, min=-448, max=448) # FP8 e4m3fn range
fp8 = comfy.float.stochastic_rounding(x_clamped, dtype=torch.float8_e4m3fn, seed=123)
# Convert back to float32 for comparison
reconstructed = fp8.to(torch.float32)
# Max relative error should be bounded (FP8 e4m3fn has ~0.125 relative precision)
x_cpu = x_clamped.cpu()
max_abs_err = (reconstructed - x_cpu).abs().max().item()
# FP8 e4m3fn max value is 448, min subnormal ~0.001953
# For random normal data, error should be well under 1.0
assert max_abs_err < 2.0, f"FP8 round-trip error too large: {max_abs_err}"
class TestManualStochasticRoundMPS:
"""Tests for comfy.float.manual_stochastic_round_to_float8 on MPS device."""
def test_manual_round_fp8_on_mps_tensor(self):
"""stochastic_rounding handles MPS generator internally without 'Placeholder storage' error."""
x = _make_mps_tensor((16, 16), dtype=torch.float32)
result = comfy.float.stochastic_rounding(x, dtype=torch.float8_e4m3fn, seed=42)
assert result.dtype == torch.float8_e4m3fn
class TestNVFP4StochasticRoundMPS:
"""Tests for NVFP4 stochastic rounding on MPS - also creates FP8 tensors internally."""
def test_nvfp4_stochastic_round_on_mps(self):
"""stochastic_round_quantize_nvfp4 creates FP8 block scales internally."""
# NVFP4 requires 2D input with dimensions divisible by 16
x = torch.randn(32, 32, device="mps", dtype=torch.float32)
scale = torch.tensor(1.0, device="mps", dtype=torch.float32)
# This should not crash - internally creates float8_e4m3fn block scales
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(
x, scale, pad_16x=False, seed=42
)
assert qdata.dtype == torch.uint8
# ── Tests for comfy.quant_ops (integration) ──────────────────────────────────
class TestQuantOpsMPS:
"""Tests for the quantization ops layer that calls into comfy.float."""
def test_fp8_layout_quantize_on_mps(self):
"""TensorCoreFP8E4M3Layout.quantize must work with MPS tensors."""
x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
qdata, params = TensorCoreFP8E4M3Layout.quantize(
x, scale="recalculate", stochastic_rounding=42
)
assert qdata.dtype == torch.float8_e4m3fn
assert params.orig_dtype == torch.bfloat16
def test_fp8_layout_quantize_without_stochastic_on_mps(self):
"""TensorCoreFP8E4M3Layout.quantize with stochastic_rounding=0 uses ck.quantize_per_tensor_fp8."""
x = _make_mps_tensor((64, 64), dtype=torch.bfloat16)
qdata, params = TensorCoreFP8E4M3Layout.quantize(
x, scale="recalculate", stochastic_rounding=0
)
assert qdata.dtype == torch.float8_e4m3fn
def test_fp8_e5m2_layout_quantize_on_mps(self):
"""TensorCoreFP8E5M2Layout.quantize must work with MPS tensors."""
x = _make_mps_tensor((64, 64), dtype=torch.float32)
qdata, params = TensorCoreFP8E5M2Layout.quantize(
x, scale="recalculate", stochastic_rounding=42
)
assert qdata.dtype == torch.float8_e5m2
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])