Compare commits

...

37 Commits

Author SHA1 Message Date
Yousef R. Gamaleldin
34767ef157
Merge fb8726244c into 783782d5d7 2026-05-02 19:34:45 -04:00
rattus
783782d5d7
Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618)
* mm: Use Aimdo raw allocator for cast buffers

pytorch manages allocation of growing buffers on streams poorly. Pyt
has no windows support for the expandable segments allocator (which is
the right tool for this job), while also segmenting the memory by
stream such that it can be generally re-used. So kick the problem to
aimdo which can just grow a virtual region thats freed per stream.

* plan

* ops: move cpu handler up to the caller

* ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.

* ops: implement block prefetching API

allow a model to construct a prefetch list and operate it for increased
async offload.

* ltxv2: Implement block prefetching

* Implement lora async offload

Implement async offload of loras.
2026-05-02 19:23:24 -04:00
Yousef Rafat
fb8726244c Merge branch 'advanced_save' of https://github.com/yousef-rafat/ComfyUI into advanced_save 2026-05-01 14:29:56 +03:00
Yousef Rafat
88d7b1bcab workflow embedded fix 2026-05-01 14:29:44 +03:00
Yousef R. Gamaleldin
f6c6c4c2b7
remove linear from png
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-05-01 11:10:24 +03:00
Yousef R. Gamaleldin
0e3c8c07c3
remvoe linear and raw from avif
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-04-30 15:12:02 +03:00
Yousef R. Gamaleldin
f10bb1e780
remove srgb from exr
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-04-30 15:08:19 +03:00
Yousef Rafat
59075cf255 display_name 2026-04-30 15:02:26 +03:00
Yousef Rafat
15f993a036 remove 12-bit 2026-04-30 14:58:18 +03:00
Yousef R. Gamaleldin
e996b817cd
Update comfy_extras/nodes_convert_color_space.py
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-04-30 14:57:41 +03:00
Yousef R. Gamaleldin
632771d988
remove download
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-04-30 14:53:47 +03:00
Yousef Rafat
87514354a5 ... 2026-04-30 11:43:53 +03:00
Yousef R. Gamaleldin
87994368cc
Merge branch 'master' into advanced_save 2026-04-30 02:35:05 +03:00
Yousef Rafat
5648a89f94 exr fix 2026-04-29 23:15:09 +03:00
Yousef Rafat
6620c1898b .. 2026-04-29 23:11:36 +03:00
Yousef Rafat
7069f6a92f transpose error for exr 2026-04-29 22:36:01 +03:00
Yousef Rafat
c77b36d98f add interpret_as 2026-04-29 21:23:28 +03:00
Yousef Rafat
5cca14c798 quick fix for alpha 2026-04-29 15:07:35 +03:00
Yousef Rafat
1e8fc2f1a8 .. 2026-04-29 15:06:03 +03:00
Yousef Rafat
4b51c8f774 imports 2026-04-29 14:35:33 +03:00
Yousef Rafat
e3e26fbdb0 . 2026-04-29 14:35:00 +03:00
Yousef Rafat
8e3396c035 renaming 2026-04-29 14:32:48 +03:00
Yousef Rafat
cde0936e42 interept_as as combo 2026-04-29 14:06:40 +03:00
Yousef R. Gamaleldin
14e114a936
Merge pull request #4 from alexisrolland/advanced_save
Iterate on new Save Image node
2026-04-29 14:00:38 +03:00
Yousef R. Gamaleldin
595683653a
Merge branch 'advanced_save' into advanced_save 2026-04-29 13:56:54 +03:00
Yousef Rafat
69227ff007 Merge branch 'master' into advanced_save 2026-04-29 13:15:29 +03:00
Yousef Rafat
693919e787 moving to nodes_images 2026-04-29 13:09:44 +03:00
Yousef Rafat
941a4a9203 move covert color space node with grayscale fix 2026-04-29 13:02:53 +03:00
Alexis Rolland
923c2afd96 Rename file_format to format for consistency with Save Video node
Co-authored-by: Copilot <copilot@github.com>
2026-04-29 12:52:03 +08:00
Alexis Rolland
92ab48531f Iterate on new Save Image node 2026-04-29 12:12:12 +08:00
Alexis Rolland
8ef7bb175b
Merge branch 'master' into advanced_save 2026-04-29 09:02:39 +08:00
Yousef Rafat
34d7a5da73 default pyav build doesn't support float16 exr 2026-04-28 22:56:42 +03:00
Yousef Rafat
29653e7fcc png, exr, avif 2026-04-28 21:40:13 +03:00
Yousef Rafat
80e421b626 correct avif code 2026-04-27 21:52:44 +03:00
Yousef Rafat
ebb9acf3cf bytesio -> tempfile 2026-04-24 23:27:59 +03:00
Yousef R. Gamaleldin
6fa0488449
Merge branch 'master' into advanced_save 2026-04-24 23:17:46 +03:00
Yousef Rafat
8fff8814ad init 2026-04-24 23:09:23 +03:00
11 changed files with 711 additions and 58 deletions

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -17,6 +17,7 @@
"""
from __future__ import annotations
import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
transformer_options = transformer_options.copy()
transformer_options["prefetch_dynamic_vbars"] = (
self.current_patcher is not None and self.current_patcher.is_dynamic()
)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@ -1175,6 +1176,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):

View File

@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

65
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,65 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device, transformer_options):
if (not transformer_options.get("prefetch_dynamic_vbars", False)
or comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
# FIXME: add n=1 cache hit fast path
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
xfer_dest = None
cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
s._prefetch = prefetch
continue
if not resident:
materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ]
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None:
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
return offload_stream
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x
def to_dequant(tensor, dtype):
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
update_weight = signature is not None
update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):

View File

@ -0,0 +1,108 @@
import torch
from comfy_api.latest import IO
from typing_extensions import override
from comfy_api.latest import ComfyExtension
# Rec.709 to Rec.2020 Gamut Conversion Matrix
M_709_to_2020 = torch.tensor([[0.6274, 0.3293, 0.0433],[0.0691, 0.9195, 0.0114],[0.0164, 0.0880, 0.8956]
])
# Rec.2020 to Rec.709 Gamut Conversion Matrix
M_2020_to_709 = torch.tensor([[ 1.6605, -0.5876, -0.0728],[-0.1246, 1.1329, -0.0083],[-0.0182, -0.1006, 1.1187]
])
def srgb_to_linear(tensor):
mask = tensor <= 0.04045
return torch.where(mask, tensor / 12.92, torch.pow((tensor + 0.055) / 1.055, 2.4))
def linear_to_srgb(tensor):
mask = tensor <= 0.0031308
return torch.where(mask, tensor * 12.92, 1.055 * torch.pow(tensor.clamp(min=1e-8), 1.0 / 2.4) - 0.055)
def linear_to_pq(linear_tensor):
"""SMPTE ST 2084 (PQ) encoding"""
m1, m2 = (2610 / 4096 / 4), (2523 / 4096 * 128)
c1, c2, c3 = (3424 / 4096), (2413 / 4096 * 32), (2392 / 4096 * 32)
l_norm = torch.clamp(linear_tensor, 0.0, 1.0)
l_m1 = torch.pow(l_norm, m1)
return torch.pow((c1 + c2 * l_m1) / (1 + c3 * l_m1), m2)
def pq_to_linear(pq_tensor):
"""Inverse SMPTE ST 2084 (PQ) decoding"""
m1, m2 = (2610 / 4096 / 4), (2523 / 4096 * 128)
c1, c2, c3 = (3424 / 4096), (2413 / 4096 * 32), (2392 / 4096 * 32)
n = torch.pow(torch.clamp(pq_tensor, 0.0, 1.0), 1/m2)
return torch.pow(torch.clamp((n - c1) / (c2 - c3 * n), min=0.0), 1/m1)
class ConvertColorSpace(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ConvertColorSpace",
display_name="Convert Color Space",
category="image/color",
inputs=[
IO.Image.Input("images"),
IO.Combo.Input("source_color_space", options=["sRGB", "Linear", "HDR Display (PQ/Rec.2020)", "Grayscale"], default="sRGB"),
IO.Combo.Input("target_color_space", options=["sRGB", "Linear", "HDR Display (PQ/Rec.2020)", "Grayscale"], default="Linear"),
],
outputs=[
IO.Image.Output("images"),
]
)
@classmethod
def execute(cls, images, source_color_space, target_color_space) -> IO.NodeOutput:
img_tensor = images.clone()
device = img_tensor.device
has_alpha = img_tensor.shape[-1] == 4
alpha = img_tensor[..., 3:4] if has_alpha else None
rgb = img_tensor[..., :3]
# turn source into linear
if source_color_space == "sRGB":
rgb = srgb_to_linear(rgb)
elif source_color_space == "Grayscale":
# assume Grayscale has sRGB gamma
luma = 0.2126 * rgb[..., 0] + 0.7152 * rgb[..., 1] + 0.0722 * rgb[..., 2]
rgb = luma.unsqueeze(-1).repeat(1, 1, 1, 3)
rgb = srgb_to_linear(rgb)
elif source_color_space == "HDR Display (PQ/Rec.2020)":
# assuming Linear Rec.2020 input. Convert to Linear Rec.709
matrix = M_2020_to_709.to(device=device, dtype=rgb.dtype)
rgb = pq_to_linear(rgb)
rgb = torch.matmul(rgb, matrix.T)
# turn source into target space
if target_color_space == "sRGB":
rgb = linear_to_srgb(rgb)
elif target_color_space == "Grayscale":
luma = 0.2126 * rgb[..., 0] + 0.7152 * rgb[..., 1] + 0.0722 * rgb[..., 2]
rgb = luma.unsqueeze(-1).repeat(1, 1, 1, 3)
rgb = linear_to_srgb(rgb) # reapply srgb gamma
elif target_color_space == "HDR Display (PQ/Rec.2020)":
# convert Gamut from Linear Rec.709 to Linear Rec.2020
rgb = torch.matmul(rgb, M_709_to_2020.to(device=device, dtype=rgb.dtype).T).clamp(min=0)
rgb = linear_to_pq(rgb)
img_tensor = torch.cat([rgb, alpha], dim=-1) if has_alpha else rgb
return IO.NodeOutput(img_tensor)
class ConvertColorSpaceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ConvertColorSpace
]
async def comfy_entrypoint() -> ConvertColorSpaceExtension:
return ConvertColorSpaceExtension()

View File

@ -3,15 +3,25 @@ from __future__ import annotations
import nodes
import folder_paths
import av
import json
import os
import re
import math
import numpy as np
import struct
import torch
import zlib
import tempfile
import logging
import comfy.utils
from fractions import Fraction
from server import PromptServer
from comfy_api.latest import ComfyExtension, IO, UI
from comfy.cli_args import args
from typing_extensions import override
SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later.
@ -823,6 +833,344 @@ class ImageMergeTileList(IO.ComfyNode):
return IO.NodeOutput(merged_image)
def create_png_chunk(chunk_type: bytes, data: bytes) -> bytes:
"""Creates a valid PNG chunk with Length, Type, Data, and CRC32."""
chunk = struct.pack('>I', len(data)) + chunk_type + data
crc = zlib.crc32(chunk_type + data) & 0xffffffff
return chunk + struct.pack('>I', crc)
def inject_comfy_metadata_png(png_bytes, prompt=None, extra_pnginfo=None):
# IEND chunk is the last 12 bytes of png files
content = png_bytes[:-12]
iend = png_bytes[-12:]
metadata_chunks = b""
if prompt is not None:
payload = b'prompt\x00' + json.dumps(prompt).encode('utf-8')
metadata_chunks += create_png_chunk(b'tEXt', payload)
if extra_pnginfo is not None:
for k, v in extra_pnginfo.items():
payload = k.encode('utf-8') + b'\x00' + json.dumps(v).encode('utf-8')
metadata_chunks += create_png_chunk(b'tEXt', payload)
return content + metadata_chunks + iend
def inject_comfy_metadata_exr(exr_bytes: bytes, prompt, extra_pnginfo) -> bytes:
# skip magic and version
idx = 8
# parse through existing attributes to find the end of the header
while True:
name_start = idx
while exr_bytes[idx] != 0:
idx += 1
name = exr_bytes[name_start:idx]
idx += 1
# empty name means we hit the header terminator
if len(name) == 0:
break
# skip attribute type string
while exr_bytes[idx] != 0:
idx += 1
idx += 1
# read attribute size and skip the value
attr_size = struct.unpack('<I', exr_bytes[idx:idx+4])[0]
idx += 4 + attr_size
# offset table starts right after the header terminator
table_start = idx
# build comfyui metadata payload
payload = b""
if prompt is not None:
prompt_str = json.dumps(prompt).encode('utf-8')
payload += b"prompt\x00string\x00" + struct.pack('<I', len(prompt_str)) + prompt_str
if extra_pnginfo is not None:
for k, v in extra_pnginfo.items():
k_enc = k.encode('utf-8')[:254]
v_enc = json.dumps(v).encode('utf-8')
payload += k_enc + b"\x00string\x00" + struct.pack('<I', len(v_enc)) + v_enc
# find the first pixel offset to calculate the table size
min_offset = struct.unpack('<Q', exr_bytes[table_start:table_start+8])[0]
num_entries = 1
while table_start + num_entries * 8 < min_offset:
offset = struct.unpack('<Q', exr_bytes[table_start + num_entries*8 : table_start + num_entries*8 + 8])[0]
if offset < min_offset:
min_offset = offset
num_entries += 1
# shift table pointers by the payload size
shift_amount = len(payload)
new_table = bytearray()
for i in range(num_entries):
offset = struct.unpack('<Q', exr_bytes[table_start + i*8 : table_start + i*8 + 8])[0]
new_table.extend(struct.pack('<Q', offset + shift_amount))
# stitch the file back together with the new header and updated table
return exr_bytes[:table_start - 1] + payload + b'\x00' + new_table + exr_bytes[table_start + num_entries*8:]
def inject_comfy_metadata_avif(avif_bytes: bytes, prompt, extra_pnginfo) -> bytes:
metadata = {}
if prompt is not None:
metadata["prompt"] = prompt
if extra_pnginfo is not None:
for k, v in extra_pnginfo.items():
metadata[k] = v
payload = json.dumps(metadata).encode('utf-8')
# 16-byte uuid required by isobmff spec
# 'comfyui_workflow' is exactly 16 bytes long!
comfy_uuid = b'comfyui_workflow'
# box size: 4 (size) + 4 (type) + 16 (uuid) + payload length
box_size = 4 + 4 + 16 + len(payload)
uuid_box = struct.pack('>I', box_size) + b'uuid' + comfy_uuid + payload
# isobmff allows top-level boxes at the end of the file.
return avif_bytes + uuid_box
class SaveImageAdvanced(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveImageAdvanced",
search_aliases=["save", "save image", "export image", "output image", "write image"],
display_name="Save Image",
description="Saves the input images to your ComfyUI output directory.",
category="image",
essentials_category="Basics",
inputs=[
IO.Image.Input(
"images",
tooltip="The images to save."
),
IO.String.Input(
"filename_prefix",
default="ComfyUI",
tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.",
),
IO.DynamicCombo.Input(
"format",
options=[
IO.DynamicCombo.Option(
"png",
[
IO.Combo.Input(
"bit_depth",
options=["8-bit", "16-bit"],
default="8-bit",
advanced=True,
),
IO.Combo.Input(
"interpret_as",
options=["sRGB", "Raw/Data"],
default="sRGB",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"avif",
[
IO.Combo.Input(
"bit_depth",
options=["8-bit", "10-bit"],
default="8-bit",
advanced=True,
),
IO.Combo.Input(
"interpret_as",
options=["sRGB"],
default="sRGB",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"exr",
[
IO.Combo.Input(
"bit_depth",
options=["32-bit"],
default="32-bit",
advanced=True,
),
IO.Combo.Input(
"interpret_as",
options=["Linear", "Raw/Data"],
default="Linear",
advanced=True,
),
],
),
],
tooltip="The file format in which to save the image.",
),
IO.Boolean.Input("embed_workflow", default=True, advanced=True),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, images, filename_prefix: str, format: dict, embed_workflow: bool) -> IO.NodeOutput:
output_dir = folder_paths.get_output_directory()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.\
get_save_image_path(filename_prefix, output_dir, images[0].shape[1], images[0].shape[0])
results = list()
prompt = cls.hidden.prompt
extra_pnginfo = cls.hidden.extra_pnginfo
for batch_number, image in enumerate(images):
# get widget values from dynamic combo
file_format = format["format"]
bit_depth = format["bit_depth"]
interpret_as = format["interpret_as"]
img_tensor = image.clone()
height, width, num_channels = img_tensor.shape
has_alpha = (num_channels == 4)
# file pathing
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{file_format}"
file_path = os.path.join(full_output_folder, file)
if file_format in ["png", "exr", "avif"]:
if bit_depth == "32-bit":
img_np = img_tensor.cpu().numpy().astype(np.float32)
img_np = img_np[:, :, [1, 2, 0, 3]] if has_alpha else img_np[:, :,[1, 2, 0]]
av_fmt = 'gbrapf32le' if has_alpha else 'gbrpf32le'
elif bit_depth in ["10-bit", "12-bit", "16-bit"]:
img_np = (img_tensor * 65535.0).clamp(0, 65535).to(torch.int32).cpu().numpy().astype(np.uint16)
av_fmt = 'rgba64le' if has_alpha else 'rgb48le'
else:
img_np = (img_tensor * 255.0).clamp(0, 255).to(torch.int32).cpu().numpy().astype(np.uint8)
av_fmt = 'rgba' if has_alpha else 'rgb24'
fd, tmp_path = tempfile.mkstemp(suffix=f".{file_format}")
os.close(fd)
container_format = "image2" if file_format in ["png", "exr"] else "avif"
container = av.open(tmp_path, mode='w', format=container_format)
if file_format == "exr":
stream = container.add_stream('exr', rate=1)
stream.pix_fmt = av_fmt
elif file_format == "avif":
stream = container.add_stream('libsvtav1', rate=1)
stream.time_base = Fraction(1, 1)
if bit_depth in ["10-bit", "16-bit", "32-bit"]:
stream.pix_fmt = 'yuv420p10le'
else:
stream.pix_fmt = 'yuv420p'
stream.codec_context.color_range = 2
if interpret_as == "Raw/Data": # 2 == unspecified
stream.codec_context.colorspace = 2
stream.codec_context.color_primaries = 2
stream.codec_context.color_trc = 2
elif interpret_as == "Linear":
stream.codec_context.colorspace = 1
stream.codec_context.color_primaries = 1
stream.codec_context.color_trc = 8
else: # sRGB
stream.codec_context.colorspace = 1
stream.codec_context.color_primaries = 1
stream.codec_context.color_trc = 13
stream.options = {
'preset': '10',
'svtav1-params': 'rc=0:qp=20:color-range=1:color-matrix=1:enable-overlays=1',
'g': '1'
}
elif file_format == "png":
stream = container.add_stream('png', rate=1)
if bit_depth == "16-bit":
stream.pix_fmt = 'rgba64be' if has_alpha else 'rgb48be'
else:
stream.pix_fmt = av_fmt
stream.width = width
stream.height = height
stream.time_base = Fraction(1, 1)
is_planar = av_fmt.startswith('gbrp') or 'p' in av_fmt.split('rgba')[-1]
if is_planar:
if av_fmt.startswith('gbr'):
img_np = img_np[:, :, [1, 2, 0, 3]] if has_alpha else img_np[:, :, [1, 2, 0]]
img_np = img_np.transpose(2, 0, 1)
try:
frame = av.VideoFrame.from_ndarray(img_np, format=av_fmt)
except ValueError:
logging.warning("[WARNING] Current FFMPEG Binary can't save natively. Fallbacking.")
img_np = (img_tensor * 65535.0).clamp(0, 65535).to(torch.int32).cpu().numpy().astype(np.uint16)
av_fmt = 'rgba64le' if has_alpha else 'rgb48le'
frame = av.VideoFrame.from_ndarray(img_np, format=av_fmt)
# reformat for both avif and exr to ensure correct internal conversion
if file_format in ["avif", "exr"] or (file_format == "png" and bit_depth == "16-bit"):
reformat_kwargs = {"format": stream.pix_fmt}
if file_format == "avif":
reformat_kwargs.update({
"src_colorspace": 1, "dst_colorspace": 1,
"src_color_range": 2, "dst_color_range": 2
})
frame = frame.reformat(**reformat_kwargs)
frame.pts = 0
frame.time_base = stream.time_base
if file_format == "avif":
frame.color_range = 2
frame.colorspace = stream.codec_context.colorspace
for packet in stream.encode(frame):
container.mux(packet)
for packet in stream.encode():
container.mux(packet)
container.close()
with open(tmp_path, "rb") as f:
final_bytes = f.read()
os.remove(tmp_path)
if embed_workflow and not args.disable_metadata:
if file_format == "png":
final_bytes = inject_comfy_metadata_png(final_bytes, prompt, extra_pnginfo)
elif file_format == "exr":
final_bytes = inject_comfy_metadata_exr(final_bytes, prompt, extra_pnginfo)
else:
final_bytes = inject_comfy_metadata_avif(final_bytes, prompt, extra_pnginfo)
with open(file_path, "wb") as f:
f.write(final_bytes)
results.append({
"filename": file,
"subfolder": subfolder,
"type": "output"
})
counter += 1
return IO.NodeOutput(ui={"images": results})
class ImagesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -835,6 +1183,7 @@ class ImagesExtension(ComfyExtension):
ImageAddNoise,
SaveAnimatedWEBP,
SaveAnimatedPNG,
SaveImageAdvanced,
SaveSVGNode,
ImageStitch,
ResizeAndPadImage,

View File

@ -15,6 +15,7 @@ import torch
from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
import comfy.model_prefetch
import comfy_aimdo.model_vbar
from latent_preview import set_preview_method
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG":
comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers()
comfy.model_prefetch.cleanup_prefetch_queues()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:

View File

@ -1633,6 +1633,7 @@ class SaveImage:
ESSENTIALS_CATEGORY = "Basics"
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"]
DEPRECATED = True
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
@ -2138,7 +2139,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch",
# Image
"SaveImage": "Save Image",
"SaveImage": "Save Image (DEPRECATED)",
"PreviewImage": "Preview Image",
"LoadImage": "Load Image",
"LoadImageMask": "Load Image (as Mask)",
@ -2445,6 +2446,7 @@ async def init_builtin_extra_nodes():
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",
"nodes_convert_color_space.py",
]
import_failed = []