mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 07:27:24 +08:00
Compare commits
19 Commits
2364d2e4c1
...
b5def6a96a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b5def6a96a | ||
|
|
783782d5d7 | ||
|
|
3e3ed8cc2a | ||
|
|
67f6cb3527 | ||
|
|
0230e0e7cc | ||
|
|
b5921c8ac2 | ||
|
|
63103d519e | ||
|
|
d56887ac52 | ||
|
|
752355991a | ||
|
|
74f5398b91 | ||
|
|
202e99b7e3 | ||
|
|
360b1cb1fb | ||
|
|
a5b0d08e7c | ||
|
|
e6c5ed5c7f | ||
|
|
506b880565 | ||
|
|
900aaaa445 | ||
|
|
61e4946fa2 | ||
|
|
a9396119ac | ||
|
|
eb7829b6ce |
@ -1,2 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
|
||||
pause
|
||||
@ -1,2 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128
|
||||
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
|
||||
|
||||
@ -193,13 +193,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
#### Alternative Downloads:
|
||||
#### All Official Portable Downloads:
|
||||
|
||||
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||
[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above).
|
||||
|
||||
[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
|
||||
@ -90,7 +90,6 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
|
||||
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
|
||||
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
|
||||
|
||||
class LatentPreviewMethod(enum.Enum):
|
||||
|
||||
@ -786,8 +786,26 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
scale_factor matches the vae/config.json scaling_factor for the 2b variant.
|
||||
The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
|
||||
use a different value; see CogVideoX1_5 below.
|
||||
"""
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.15258426
|
||||
|
||||
|
||||
class CogVideoX1_5(CogVideoX):
|
||||
"""Latent format for 5b-class CogVideoX checkpoints.
|
||||
|
||||
Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
|
||||
V1.5-5b family (including VOID inpainting). All of these have
|
||||
scaling_factor=0.7 in their vae/config.json. Auto-selected in
|
||||
supported_models.CogVideoX_T2V based on transformer hidden dim.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.scale_factor = 0.7
|
||||
|
||||
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_prefetch
|
||||
|
||||
class CompressedTimestep:
|
||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
|
||||
|
||||
# Process transformer blocks
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
import comfy.model_base
|
||||
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def prefetch_prepared_value(value, allocate_buffer, stream):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
|
||||
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
|
||||
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
|
||||
elif isinstance(value, weight_adapter.WeightAdapterBase):
|
||||
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
|
||||
elif isinstance(value, tuple):
|
||||
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
|
||||
elif isinstance(value, list):
|
||||
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
|
||||
if "latent_shapes" in extra_conds:
|
||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["prefetch_dynamic_vbars"] = (
|
||||
self.current_patcher is not None and self.current_patcher.is_dynamic()
|
||||
)
|
||||
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||
model_output, _ = utils.pack_latents(model_output)
|
||||
|
||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy.quant_ops
|
||||
import comfy_aimdo.vram_buffer
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 #No vram present: no need to move models to vram
|
||||
@ -112,10 +113,6 @@ if args.directml is not None:
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # noqa: F401
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
_ = torch.xpu.device_count()
|
||||
@ -583,9 +580,6 @@ class LoadedModel:
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
with torch.no_grad():
|
||||
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
@ -1182,6 +1176,10 @@ stream_counters = {}
|
||||
|
||||
STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
def get_cast_buffer(offload_stream, device, size, ref):
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
@ -1215,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
||||
|
||||
return cast_buffer
|
||||
|
||||
def get_aimdo_cast_buffer(offload_stream, device):
|
||||
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
|
||||
if cast_buffer is None:
|
||||
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
|
||||
return cast_buffer
|
||||
def reset_cast_buffers():
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1581,10 +1592,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
return torch.xpu.get_device_properties(device).has_fp16
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1650,10 +1658,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
if is_intel_xpu():
|
||||
if torch_version_numeric < (2, 3):
|
||||
return True
|
||||
else:
|
||||
return torch.xpu.is_bf16_supported()
|
||||
return torch.xpu.is_bf16_supported()
|
||||
|
||||
if is_ascend_npu():
|
||||
return True
|
||||
@ -1784,6 +1789,7 @@ def soft_empty_cache(force=False):
|
||||
if cpu_state == CPUState.MPS:
|
||||
torch.mps.empty_cache()
|
||||
elif is_intel_xpu():
|
||||
torch.xpu.synchronize()
|
||||
torch.xpu.empty_cache()
|
||||
elif is_ascend_npu():
|
||||
torch.npu.empty_cache()
|
||||
|
||||
@ -121,9 +121,20 @@ class LowVramPatch:
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
self.prepared_patches = None
|
||||
|
||||
def prepare(self, allocate_buffer, stream):
|
||||
self.prepared_patches = [
|
||||
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
|
||||
for patch in self.patches[self.key]
|
||||
]
|
||||
|
||||
def clear_prepared(self):
|
||||
self.prepared_patches = None
|
||||
|
||||
def __call__(self, weight):
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
|
||||
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
|
||||
65
comfy/model_prefetch.py
Normal file
65
comfy/model_prefetch.py
Normal file
@ -0,0 +1,65 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def cleanup_prefetched_modules(comfy_modules):
|
||||
for s in comfy_modules:
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
if prefetch is None:
|
||||
continue
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
if prefetch["signature"] is not None:
|
||||
comfy_aimdo.model_vbar.vbar_unpin(s._v)
|
||||
delattr(s, "_prefetch")
|
||||
|
||||
def cleanup_prefetch_queues():
|
||||
global PREFETCH_QUEUES
|
||||
|
||||
for queue in PREFETCH_QUEUES:
|
||||
for entry in queue:
|
||||
if entry is None or not isinstance(entry, tuple):
|
||||
continue
|
||||
_, prefetch_state = entry
|
||||
comfy_modules = prefetch_state[1]
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
PREFETCH_QUEUES = []
|
||||
|
||||
def prefetch_queue_pop(queue, device, module):
|
||||
if queue is None:
|
||||
return
|
||||
|
||||
consumed = queue.pop(0)
|
||||
if consumed is not None:
|
||||
offload_stream, prefetch_state = consumed
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
_, comfy_modules = prefetch_state
|
||||
if comfy_modules is not None:
|
||||
cleanup_prefetched_modules(comfy_modules)
|
||||
|
||||
prefetch = queue[0]
|
||||
if prefetch is not None:
|
||||
comfy_modules = []
|
||||
for s in prefetch.modules():
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
def make_prefetch_queue(queue, device, transformer_options):
|
||||
if (not transformer_options.get("prefetch_dynamic_vbars", False)
|
||||
or comfy.model_management.NUM_STREAMS == 0
|
||||
or comfy.model_management.is_device_cpu(device)
|
||||
or not comfy.model_management.device_supports_non_blocking(device)):
|
||||
return None
|
||||
|
||||
queue = [None] + queue + [None]
|
||||
PREFETCH_QUEUES.append(queue)
|
||||
return queue
|
||||
181
comfy/ops.py
181
comfy/ops.py
@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
|
||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||
|
||||
|
||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = None
|
||||
if s.bias is not None:
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
||||
return weight, bias, (None, None, None)
|
||||
|
||||
# FIXME: add n=1 cache hit fast path
|
||||
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||
offload_stream = None
|
||||
xfer_dest = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
|
||||
if offload_stream is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
|
||||
return
|
||||
|
||||
current_size = 0 if cast_buffer is None else cast_buffer.size()
|
||||
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
cast_buffer = None
|
||||
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
|
||||
|
||||
def get_cast_buffer(buffer_size):
|
||||
nonlocal offload_stream
|
||||
nonlocal cast_buffer
|
||||
nonlocal cast_buffer_offset
|
||||
|
||||
if buffer_size == 0:
|
||||
return None
|
||||
|
||||
if offload_stream is None:
|
||||
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
|
||||
|
||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
prefetch = {
|
||||
"signature": signature,
|
||||
"resident": resident,
|
||||
}
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
if signature is not None:
|
||||
if resident:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
||||
s._prefetch = prefetch
|
||||
continue
|
||||
|
||||
if not resident:
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||
cast_dest = None
|
||||
needs_cast = False
|
||||
|
||||
xfer_source = [ s.weight, s.bias ]
|
||||
|
||||
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
if data is None:
|
||||
continue
|
||||
if data.dtype != geometry.dtype:
|
||||
needs_cast = True
|
||||
cast_dest = xfer_dest
|
||||
if cast_dest is None:
|
||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
||||
xfer_dest = None
|
||||
break
|
||||
|
||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
if xfer_dest is None and offload_stream is not None:
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
if xfer_dest is None:
|
||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
||||
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
|
||||
if xfer_dest is None:
|
||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||
offload_stream = None
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
if signature is None and pin is None:
|
||||
comfy.pinned_memory.pin_memory(s)
|
||||
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
xfer_source = [ pin ]
|
||||
#send it over
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
if cast_dest is not None:
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||
|
||||
prefetch["xfer_dest"] = xfer_dest
|
||||
prefetch["cast_dest"] = cast_dest
|
||||
prefetch["cast_geometry"] = cast_geometry
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
|
||||
|
||||
prefetch = getattr(s, "_prefetch", None)
|
||||
|
||||
if prefetch["resident"]:
|
||||
weight = s._v_weight
|
||||
bias = s._v_bias
|
||||
else:
|
||||
xfer_dest = prefetch["xfer_dest"]
|
||||
if prefetch["needs_cast"]:
|
||||
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
||||
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||
if post_cast is not None:
|
||||
post_cast.copy_(pre_cast)
|
||||
xfer_dest = cast_dest
|
||||
|
||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
||||
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||
weight = params[0]
|
||||
bias = params[1]
|
||||
if signature is not None:
|
||||
if prefetch["signature"] is not None:
|
||||
s._v_weight = weight
|
||||
s._v_bias = bias
|
||||
s._v_signature=signature
|
||||
s._v_signature = prefetch["signature"]
|
||||
|
||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
fns = getattr(s, param_key + "_function", [])
|
||||
|
||||
if x is None:
|
||||
return None
|
||||
|
||||
orig = x
|
||||
|
||||
def to_dequant(tensor, dtype):
|
||||
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
||||
x = f(x)
|
||||
return x
|
||||
|
||||
update_weight = signature is not None
|
||||
update_weight = prefetch["signature"] is not None
|
||||
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||
if bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||
|
||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
||||
if s.bias is not None:
|
||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
||||
|
||||
#FIXME: weird offload return protocol
|
||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||
return weight, bias
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
def format_return(result, offloadable):
|
||||
weight, bias, offload_stream = result
|
||||
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
if not prefetched:
|
||||
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
|
||||
|
||||
if not prefetched:
|
||||
if getattr(s, "_prefetch")["signature"] is not None:
|
||||
offload_device = device
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
lowvram_fn.clear_prepared()
|
||||
delattr(s, "_prefetch")
|
||||
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||
|
||||
|
||||
if offloadable and (device != s.weight.device or
|
||||
(s.bias is not None and device != s.bias.device)):
|
||||
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
if offloadable:
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||
|
||||
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
|
||||
@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1223,6 +1224,7 @@ class CLIPType(Enum):
|
||||
NEWBIE = 24
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
|
||||
|
||||
|
||||
@ -1418,6 +1420,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif clip_type == CLIPType.COGVIDEOX:
|
||||
clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
|
||||
|
||||
@ -1853,6 +1853,14 @@ class CogVideoX_T2V(supported_models_base.BASE):
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
# 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
|
||||
# 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
|
||||
# Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
|
||||
if unet_config.get("num_attention_heads", 0) >= 48:
|
||||
self.latent_format = latent_formats.CogVideoX1_5
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
@ -1879,6 +1887,20 @@ class CogVideoX_I2V(CogVideoX_T2V):
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class CogVideoX_Inpaint(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 48,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
|
||||
models = [
|
||||
LotusD,
|
||||
@ -1958,6 +1980,7 @@ models = [
|
||||
ErnieImage,
|
||||
SAM3,
|
||||
SAM31,
|
||||
CogVideoX_Inpaint,
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
|
||||
@ -1,6 +1,48 @@
|
||||
import comfy.text_encoders.sd3_clip
|
||||
from comfy import sd1_clip
|
||||
|
||||
|
||||
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
|
||||
"""Inner T5 tokenizer for CogVideoX.
|
||||
|
||||
CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
|
||||
Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
|
||||
the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
|
||||
"""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
|
||||
|
||||
|
||||
class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
|
||||
"""Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
|
||||
|
||||
|
||||
class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
|
||||
"""Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
|
||||
|
||||
Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
|
||||
(which reads self.dtypes) works correctly. The inner model is the standard
|
||||
sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
|
||||
"""
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="t5xxl",
|
||||
clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
|
||||
model_options=model_options)
|
||||
|
||||
|
||||
def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
"""Factory that returns a CogVideoXT5XXL class configured with the detected
|
||||
T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
|
||||
"""
|
||||
class CogVideoXTEModel_(CogVideoXT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CogVideoXTEModel_
|
||||
|
||||
@ -459,27 +459,23 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
total_images = image.shape[0]
|
||||
captured_feat = None
|
||||
|
||||
model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768
|
||||
model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024
|
||||
model_w = int(head.heatmap_size[0]) * 4 # 192 * 4 = 768
|
||||
model_h = int(head.heatmap_size[1]) * 4 # 256 * 4 = 1024
|
||||
|
||||
def _resize_to_model(imgs):
|
||||
"""Aspect-preserving resize + zero-pad BHWC images to (model_h, model_w). Returns (resized_bhwc, scale, pad_top, pad_left)."""
|
||||
"""Stretch BHWC images to (model_h, model_w), model expects no aspect preservation."""
|
||||
h, w = imgs.shape[-3], imgs.shape[-2]
|
||||
scale = min(model_h / h, model_w / w)
|
||||
sh, sw = int(round(h * scale)), int(round(w * scale))
|
||||
pt, pl = (model_h - sh) // 2, (model_w - sw) // 2
|
||||
method = "area" if (model_h <= h and model_w <= w) else "bilinear"
|
||||
chw = imgs.permute(0, 3, 1, 2).float()
|
||||
scaled = comfy.utils.common_upscale(chw, sw, sh, upscale_method="bilinear", crop="disabled")
|
||||
padded = torch.zeros(scaled.shape[0], scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device)
|
||||
padded[:, :, pt:pt + sh, pl:pl + sw] = scaled
|
||||
return padded.permute(0, 2, 3, 1), scale, pt, pl
|
||||
scaled = comfy.utils.common_upscale(chw, model_w, model_h, upscale_method=method, crop="disabled")
|
||||
return scaled.permute(0, 2, 3, 1), model_w / w, model_h / h
|
||||
|
||||
def _remap_keypoints(kp, scale, pad_top, pad_left, offset_x=0, offset_y=0):
|
||||
def _remap_keypoints(kp, scale_x, scale_y, offset_x=0, offset_y=0):
|
||||
"""Remap keypoints from model space back to original image space."""
|
||||
kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32)
|
||||
invalid = kp[..., 0] < 0
|
||||
kp[..., 0] = (kp[..., 0] - pad_left) / scale + offset_x
|
||||
kp[..., 1] = (kp[..., 1] - pad_top) / scale + offset_y
|
||||
kp[..., 0] = kp[..., 0] / scale_x + offset_x
|
||||
kp[..., 1] = kp[..., 1] / scale_y + offset_y
|
||||
kp[invalid] = -1
|
||||
return kp
|
||||
|
||||
@ -529,18 +525,18 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
continue
|
||||
|
||||
crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C)
|
||||
crop_resized, scale, pad_top, pad_left = _resize_to_model(crop)
|
||||
crop_resized, sx, sy = _resize_to_model(crop)
|
||||
|
||||
latent_crop = vae.encode(crop_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_crop)
|
||||
kp = _remap_keypoints(kp_batch[0], scale, pad_top, pad_left, x1, y1)
|
||||
kp = _remap_keypoints(kp_batch[0], sx, sy, x1, y1)
|
||||
img_keypoints.append(kp)
|
||||
img_scores.append(sc_batch[0])
|
||||
else:
|
||||
img_resized, scale, pad_top, pad_left = _resize_to_model(img)
|
||||
img_resized, sx, sy = _resize_to_model(img)
|
||||
latent_img = vae.encode(img_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_img)
|
||||
img_keypoints.append(_remap_keypoints(kp_batch[0], scale, pad_top, pad_left))
|
||||
img_keypoints.append(_remap_keypoints(kp_batch[0], sx, sy))
|
||||
img_scores.append(sc_batch[0])
|
||||
|
||||
all_keypoints.append(img_keypoints)
|
||||
@ -549,12 +545,12 @@ class SDPoseKeypointExtractor(io.ComfyNode):
|
||||
|
||||
else: # full-image mode, batched
|
||||
for batch_start in tqdm(range(0, total_images, batch_size), desc="Extracting keypoints"):
|
||||
batch_resized, scale, pad_top, pad_left = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||
batch_resized, sx, sy = _resize_to_model(image[batch_start:batch_start + batch_size])
|
||||
latent_batch = vae.encode(batch_resized)
|
||||
kp_batch, sc_batch = _run_on_latent(latent_batch)
|
||||
|
||||
for kp, sc in zip(kp_batch, sc_batch):
|
||||
all_keypoints.append([_remap_keypoints(kp, scale, pad_top, pad_left)])
|
||||
all_keypoints.append([_remap_keypoints(kp, sx, sy)])
|
||||
all_scores.append([sc])
|
||||
|
||||
pbar.update(len(kp_batch))
|
||||
@ -727,13 +723,13 @@ class CropByBBoxes(io.ComfyNode):
|
||||
scale = min(output_width / crop_w, output_height / crop_h)
|
||||
scaled_w = int(round(crop_w * scale))
|
||||
scaled_h = int(round(crop_h * scale))
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="area", crop="disabled")
|
||||
pad_left = (output_width - scaled_w) // 2
|
||||
pad_top = (output_height - scaled_h) // 2
|
||||
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
|
||||
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||
else: # "stretch"
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="area", crop="disabled")
|
||||
crops.append(resized)
|
||||
|
||||
if not crops:
|
||||
|
||||
483
comfy_extras/nodes_void.py
Normal file
483
comfy_extras/nodes_void.py
Normal file
@ -0,0 +1,483 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
import comfy
|
||||
import comfy.model_management
|
||||
import comfy.model_patcher
|
||||
import comfy.samplers
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
import nodes
|
||||
from comfy.utils import model_trange as trange
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from torchvision.models.optical_flow import raft_large
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video
|
||||
|
||||
OpticalFlow = io.Custom("OPTICAL_FLOW")
|
||||
|
||||
TEMPORAL_COMPRESSION = 4
|
||||
PATCH_SIZE_T = 2
|
||||
|
||||
|
||||
def _valid_void_length(length: int) -> int:
|
||||
"""Round ``length`` down to a value that produces an even latent_t.
|
||||
|
||||
VOID / CogVideoX-Fun-V1.5 uses patch_size_t=2, so the VAE-encoded latent
|
||||
must have an even temporal dimension. If latent_t is odd, the transformer
|
||||
pad_to_patch_size circular-wraps an extra latent frame onto the end; after
|
||||
the post-transformer crop the last real latent frame has been influenced
|
||||
by the wrapped phantom frame, producing visible jitter and "disappearing"
|
||||
subjects near the end of the decoded video. Rounding down fixes this.
|
||||
"""
|
||||
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
|
||||
if latent_t % PATCH_SIZE_T == 0:
|
||||
return length
|
||||
# Round latent_t down to the nearest multiple of PATCH_SIZE_T, then invert
|
||||
# the ((length - 1) // TEMPORAL_COMPRESSION) + 1 formula. Floor at 1 frame
|
||||
# so we never return a non-positive length.
|
||||
target_latent_t = max(PATCH_SIZE_T, (latent_t // PATCH_SIZE_T) * PATCH_SIZE_T)
|
||||
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
|
||||
|
||||
|
||||
class OpticalFlowLoader(io.ComfyNode):
|
||||
"""Load an optical flow model from ``models/optical_flow/``.
|
||||
|
||||
Only torchvision's RAFT-large format is recognized today (the model used
|
||||
by VOIDWarpedNoise). The checkpoint must be placed under
|
||||
``models/optical_flow/`` — ComfyUI never downloads optical-flow weights
|
||||
at runtime.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="OpticalFlowLoader",
|
||||
display_name="Load Optical Flow Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"model_name",
|
||||
options=folder_paths.get_filename_list("optical_flow"),
|
||||
tooltip=(
|
||||
"Optical flow model to load. Files must be placed in the "
|
||||
"'optical_flow' folder. Today only torchvision's "
|
||||
"raft_large.pth is supported."
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
OpticalFlow.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
|
||||
model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
|
||||
has_raft_keys = (
|
||||
any(k.startswith("feature_encoder.") for k in sd)
|
||||
and any(k.startswith("context_encoder.") for k in sd)
|
||||
and any(k.startswith("update_block.") for k in sd)
|
||||
)
|
||||
if not has_raft_keys:
|
||||
raise ValueError(
|
||||
"Unrecognized optical flow model format: expected a torchvision "
|
||||
"RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
|
||||
"and 'update_block.' prefixes."
|
||||
)
|
||||
|
||||
model = raft_large(weights=None, progress=False)
|
||||
model.load_state_dict(sd)
|
||||
model.eval().to(torch.float32)
|
||||
|
||||
patcher = comfy.model_patcher.ModelPatcher(
|
||||
model,
|
||||
load_device=comfy.model_management.get_torch_device(),
|
||||
offload_device=comfy.model_management.unet_offload_device(),
|
||||
)
|
||||
return io.NodeOutput(patcher)
|
||||
|
||||
|
||||
class VOIDQuadmaskPreprocess(io.ComfyNode):
|
||||
"""Preprocess a quadmask video for VOID inpainting.
|
||||
|
||||
Quantizes mask values to four semantic levels, inverts, and normalizes:
|
||||
0 -> primary object to remove
|
||||
63 -> overlap of primary + affected
|
||||
127 -> affected region (interactions)
|
||||
255 -> background (keep)
|
||||
|
||||
After inversion and normalization, the output mask has values in [0, 1]
|
||||
with four discrete levels: 1.0 (remove), ~0.75, ~0.50, 0.0 (keep).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDQuadmaskPreprocess",
|
||||
category="mask/video",
|
||||
inputs=[
|
||||
io.Mask.Input("mask"),
|
||||
io.Int.Input("dilate_width", default=0, min=0, max=50, step=1,
|
||||
tooltip="Dilation radius for the primary mask region (0 = no dilation)"),
|
||||
],
|
||||
outputs=[
|
||||
io.Mask.Output(display_name="quadmask"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, dilate_width=0) -> io.NodeOutput:
|
||||
m = mask.clone()
|
||||
|
||||
if m.max() <= 1.0:
|
||||
m = m * 255.0
|
||||
|
||||
if dilate_width > 0 and m.ndim >= 3:
|
||||
binary = (m < 128).float()
|
||||
kernel_size = dilate_width * 2 + 1
|
||||
if binary.ndim == 3:
|
||||
binary = binary.unsqueeze(1)
|
||||
dilated = torch.nn.functional.max_pool2d(
|
||||
binary, kernel_size=kernel_size, stride=1, padding=dilate_width
|
||||
)
|
||||
if dilated.ndim == 4:
|
||||
dilated = dilated.squeeze(1)
|
||||
m = torch.where(dilated > 0.5, torch.zeros_like(m), m)
|
||||
|
||||
m = torch.where(m <= 31, torch.zeros_like(m), m)
|
||||
m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m)
|
||||
m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m)
|
||||
m = torch.where(m > 191, torch.full_like(m, 255), m)
|
||||
|
||||
m = (255.0 - m) / 255.0
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
class VOIDInpaintConditioning(io.ComfyNode):
|
||||
"""Build VOID inpainting conditioning for CogVideoX.
|
||||
|
||||
Encodes the processed quadmask and masked source video through the VAE,
|
||||
producing a 32-channel concat conditioning (16ch mask + 16ch masked video)
|
||||
that gets concatenated with the 16ch noise latent by the model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDInpaintConditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("video", tooltip="Source video frames [T, H, W, 3]"),
|
||||
io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"),
|
||||
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
|
||||
tooltip="Number of pixel frames to process. For CogVideoX-Fun-V1.5 "
|
||||
"(patch_size_t=2), latent_t must be even — lengths that "
|
||||
"produce odd latent_t are rounded down (e.g. 49 → 45)."),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, video, quadmask,
|
||||
width, height, length, batch_size) -> io.NodeOutput:
|
||||
|
||||
adjusted_length = _valid_void_length(length)
|
||||
if adjusted_length != length:
|
||||
logging.warning(
|
||||
"VOIDInpaintConditioning: rounding length %d down to %d so that "
|
||||
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2). "
|
||||
"Using odd latent_t causes the last frame to be corrupted by "
|
||||
"circular padding.", length, adjusted_length,
|
||||
)
|
||||
length = adjusted_length
|
||||
|
||||
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
|
||||
vid = video[:length]
|
||||
vid = comfy.utils.common_upscale(
|
||||
vid.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
|
||||
qm = quadmask[:length]
|
||||
if qm.ndim == 3:
|
||||
qm = qm.unsqueeze(-1)
|
||||
qm = comfy.utils.common_upscale(
|
||||
qm.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
if qm.ndim == 4 and qm.shape[-1] == 1:
|
||||
qm = qm.squeeze(-1)
|
||||
|
||||
mask_condition = qm
|
||||
if mask_condition.ndim == 3:
|
||||
mask_condition_3ch = mask_condition.unsqueeze(-1).expand(-1, -1, -1, 3)
|
||||
else:
|
||||
mask_condition_3ch = mask_condition
|
||||
|
||||
inverted_mask_3ch = 1.0 - mask_condition_3ch
|
||||
masked_video = vid[:, :, :, :3] * (1.0 - mask_condition_3ch)
|
||||
|
||||
mask_latents = vae.encode(inverted_mask_3ch)
|
||||
masked_video_latents = vae.encode(masked_video)
|
||||
|
||||
def _match_temporal(lat, target_t):
|
||||
if lat.shape[2] > target_t:
|
||||
return lat[:, :, :target_t]
|
||||
elif lat.shape[2] < target_t:
|
||||
pad = target_t - lat.shape[2]
|
||||
return torch.cat([lat, lat[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2)
|
||||
return lat
|
||||
|
||||
mask_latents = _match_temporal(mask_latents, latent_t)
|
||||
masked_video_latents = _match_temporal(masked_video_latents, latent_t)
|
||||
|
||||
inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1)
|
||||
|
||||
# No explicit scaling needed here: the model's CogVideoX.concat_cond()
|
||||
# applies process_latent_in (×latent_format.scale_factor) to each 16-ch
|
||||
# block of the stored conditioning. For 5b-class checkpoints (incl. the
|
||||
# VOID/CogVideoX-Fun-V1.5 inpainting model) that scale_factor is auto-
|
||||
# selected as 0.7 in supported_models.CogVideoX_T2V, which matches the
|
||||
# diffusers vae/config.json scaling_factor VOID was trained with.
|
||||
|
||||
positive = node_helpers.conditioning_set_values(
|
||||
positive, {"concat_latent_image": inpaint_latents}
|
||||
)
|
||||
negative = node_helpers.conditioning_set_values(
|
||||
negative, {"concat_latent_image": inpaint_latents}
|
||||
)
|
||||
|
||||
noise_latent = torch.zeros(
|
||||
[batch_size, 16, latent_t, latent_h, latent_w],
|
||||
device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": noise_latent})
|
||||
|
||||
|
||||
class VOIDWarpedNoise(io.ComfyNode):
|
||||
"""Generate optical-flow warped noise for VOID Pass 2 refinement.
|
||||
|
||||
Takes the Pass 1 output video and produces temporally-correlated noise
|
||||
by warping Gaussian noise along optical flow vectors. This noise is used
|
||||
as the initial latent for Pass 2, resulting in better temporal consistency.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDWarpedNoise",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
OpticalFlow.Input(
|
||||
"optical_flow",
|
||||
tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).",
|
||||
),
|
||||
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
|
||||
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
|
||||
tooltip="Number of pixel frames. Rounded down to make latent_t "
|
||||
"even (patch_size_t=2 requirement), e.g. 49 → 45."),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="warped_noise"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput:
|
||||
|
||||
adjusted_length = _valid_void_length(length)
|
||||
if adjusted_length != length:
|
||||
logging.warning(
|
||||
"VOIDWarpedNoise: rounding length %d down to %d so that "
|
||||
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2).",
|
||||
length, adjusted_length,
|
||||
)
|
||||
length = adjusted_length
|
||||
|
||||
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
|
||||
# RAFT + noise warp is real compute, not an "intermediate" buffer, so
|
||||
# we want the actual torch device (CUDA/MPS). The final latent is
|
||||
# moved back to intermediate_device() before returning to match the
|
||||
# rest of the ComfyUI pipeline.
|
||||
device = comfy.model_management.get_torch_device()
|
||||
|
||||
comfy.model_management.load_model_gpu(optical_flow)
|
||||
raft = RaftOpticalFlow(optical_flow.model, device=device)
|
||||
|
||||
vid = video[:length].to(device)
|
||||
vid = comfy.utils.common_upscale(
|
||||
vid.movedim(-1, 1), width, height, "bilinear", "center"
|
||||
).movedim(1, -1)
|
||||
vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8)
|
||||
|
||||
FRAME = 2**-1
|
||||
FLOW = 2**3
|
||||
LATENT_SCALE = 8
|
||||
|
||||
warped = get_noise_from_video(
|
||||
vid_uint8,
|
||||
raft,
|
||||
noise_channels=16,
|
||||
resize_frames=FRAME,
|
||||
resize_flow=FLOW,
|
||||
downscale_factor=round(FRAME * FLOW) * LATENT_SCALE,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if warped.shape[0] != latent_t:
|
||||
indices = torch.linspace(0, warped.shape[0] - 1, latent_t,
|
||||
device=device).long()
|
||||
warped = warped[indices]
|
||||
|
||||
if warped.shape[1] != latent_h or warped.shape[2] != latent_w:
|
||||
# (T, H, W, C) → (T, C, H, W) → bilinear resize → back
|
||||
warped = warped.permute(0, 3, 1, 2)
|
||||
warped = torch.nn.functional.interpolate(
|
||||
warped, size=(latent_h, latent_w),
|
||||
mode="bilinear", align_corners=False,
|
||||
)
|
||||
warped = warped.permute(0, 2, 3, 1)
|
||||
|
||||
# (T, H, W, C) → (B, C, T, H, W)
|
||||
warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0)
|
||||
if batch_size > 1:
|
||||
warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": warped_tensor})
|
||||
|
||||
|
||||
class Noise_FromLatent:
|
||||
"""Wraps a pre-computed LATENT tensor as a NOISE source."""
|
||||
def __init__(self, latent_dict):
|
||||
self.seed = 0
|
||||
self._samples = latent_dict["samples"]
|
||||
|
||||
def generate_noise(self, input_latent):
|
||||
return self._samples.clone().cpu()
|
||||
|
||||
|
||||
class VOIDWarpedNoiseSource(io.ComfyNode):
|
||||
"""Convert a LATENT (e.g. from VOIDWarpedNoise) into a NOISE source
|
||||
for use with SamplerCustomAdvanced."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDWarpedNoiseSource",
|
||||
category="sampling/custom_sampling/noise",
|
||||
inputs=[
|
||||
io.Latent.Input("warped_noise",
|
||||
tooltip="Warped noise latent from VOIDWarpedNoise"),
|
||||
],
|
||||
outputs=[io.Noise.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, warped_noise) -> io.NodeOutput:
|
||||
return io.NodeOutput(Noise_FromLatent(warped_noise))
|
||||
|
||||
|
||||
class VOID_DDIM(comfy.samplers.Sampler):
|
||||
"""DDIM sampler for VOID inpainting models.
|
||||
|
||||
VOID was trained with the diffusers CogVideoXDDIMScheduler which operates in
|
||||
alpha-space (input std ≈ 1). The standard KSampler applies noise_scaling that
|
||||
multiplies by sqrt(1+sigma^2) ≈ 4500x, which is incompatible with VOID's
|
||||
training. This sampler skips noise_scaling and implements the DDIM update rule
|
||||
directly using sigma-to-alpha conversion.
|
||||
"""
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
x = noise.to(torch.float32)
|
||||
model_options = extra_args.get("model_options", {})
|
||||
seed = extra_args.get("seed", None)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable_pbar):
|
||||
sigma = sigmas[i]
|
||||
sigma_next = sigmas[i + 1]
|
||||
|
||||
denoised = model_wrap(x, sigma * s_in, model_options=model_options, seed=seed)
|
||||
|
||||
if callback is not None:
|
||||
callback(i, denoised, x, len(sigmas) - 1)
|
||||
|
||||
if sigma_next == 0:
|
||||
x = denoised
|
||||
else:
|
||||
alpha_t = 1.0 / (1.0 + sigma ** 2)
|
||||
alpha_prev = 1.0 / (1.0 + sigma_next ** 2)
|
||||
|
||||
pred_eps = (x - (alpha_t ** 0.5) * denoised) / (1.0 - alpha_t) ** 0.5
|
||||
x = (alpha_prev ** 0.5) * denoised + (1.0 - alpha_prev) ** 0.5 * pred_eps
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class VOIDSampler(io.ComfyNode):
|
||||
"""VOID DDIM sampler for use with SamplerCustom / SamplerCustomAdvanced.
|
||||
|
||||
Required for VOID inpainting models. Implements the same DDIM loop that VOID
|
||||
was trained with (diffusers CogVideoXDDIMScheduler), without the noise_scaling
|
||||
that the standard KSampler applies. Use with RandomNoise or VOIDWarpedNoiseSource.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VOIDSampler",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[],
|
||||
outputs=[io.Sampler.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls) -> io.NodeOutput:
|
||||
return io.NodeOutput(VOID_DDIM())
|
||||
|
||||
get_sampler = execute
|
||||
|
||||
|
||||
class VOIDExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
OpticalFlowLoader,
|
||||
VOIDQuadmaskPreprocess,
|
||||
VOIDInpaintConditioning,
|
||||
VOIDWarpedNoise,
|
||||
VOIDWarpedNoiseSource,
|
||||
VOIDSampler,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> VOIDExtension:
|
||||
return VOIDExtension()
|
||||
494
comfy_extras/void_noise_warp.py
Normal file
494
comfy_extras/void_noise_warp.py
Normal file
@ -0,0 +1,494 @@
|
||||
"""
|
||||
Optical-flow-warped noise for VOID Pass 2 refinement.
|
||||
|
||||
Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
|
||||
https://github.com/RyannDaGreat/CommonSource
|
||||
- noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video)
|
||||
- raft.py (RaftOpticalFlow)
|
||||
|
||||
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
|
||||
uses (torch THWC uint8 input, no background removal, no visualization, no disk
|
||||
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
|
||||
have been replaced with equivalents from torch.nn.functional / einops. The
|
||||
RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in
|
||||
``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this
|
||||
module never downloads weights at runtime.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _torch_resize_chw(image, size, interp, copy=True):
|
||||
"""Resize a CHW tensor.
|
||||
|
||||
``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one
|
||||
of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and
|
||||
the requested size matches the input, returns the input tensor as is
|
||||
(faster but callers must not mutate the result).
|
||||
"""
|
||||
if image.ndim != 3:
|
||||
raise ValueError(
|
||||
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||
)
|
||||
_, in_h, in_w = image.shape
|
||||
if isinstance(size, (int, float)) and not isinstance(size, bool):
|
||||
new_h = max(1, int(in_h * size))
|
||||
new_w = max(1, int(in_w * size))
|
||||
else:
|
||||
new_h, new_w = size
|
||||
|
||||
if (new_h, new_w) == (in_h, in_w):
|
||||
return image.clone() if copy else image
|
||||
|
||||
kwargs = {}
|
||||
if interp in ("bilinear", "bicubic"):
|
||||
kwargs["align_corners"] = False
|
||||
out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0]
|
||||
return out
|
||||
|
||||
|
||||
def _torch_remap_relative(image, dx, dy, interp="bilinear"):
|
||||
"""Relative remap of a CHW image via ``F.grid_sample``.
|
||||
|
||||
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
|
||||
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
|
||||
"""
|
||||
if image.ndim != 3:
|
||||
raise ValueError(
|
||||
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||
)
|
||||
if dx.shape != dy.shape:
|
||||
raise ValueError(
|
||||
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
|
||||
)
|
||||
_, h, w = image.shape
|
||||
|
||||
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
|
||||
y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None]
|
||||
|
||||
x_norm = (x_abs / (w - 1)) * 2 - 1
|
||||
y_norm = (y_abs / (h - 1)) * 2 - 1
|
||||
|
||||
grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype)
|
||||
out = F.grid_sample(
|
||||
image[None], grid, mode=interp, align_corners=True, padding_mode="zeros"
|
||||
)[0]
|
||||
return out
|
||||
|
||||
|
||||
def _torch_scatter_add_relative(image, dx, dy):
|
||||
"""Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets.
|
||||
|
||||
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
|
||||
interp='floor')``. Out-of-bounds targets are dropped.
|
||||
"""
|
||||
if image.ndim != 3:
|
||||
raise ValueError(
|
||||
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
|
||||
)
|
||||
in_c, in_h, in_w = image.shape
|
||||
if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
|
||||
raise ValueError(
|
||||
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
|
||||
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
|
||||
)
|
||||
|
||||
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
|
||||
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
|
||||
|
||||
valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1)
|
||||
indices = (y * in_w + x).reshape(-1)[valid]
|
||||
|
||||
flat_image = rearrange(image, "c h w -> (h w) c")[valid]
|
||||
out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device)
|
||||
out.index_add_(0, indices, flat_image)
|
||||
return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Noise warping primitives (ported from noise_warp.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def unique_pixels(image):
|
||||
"""Find unique pixel values in a CHW tensor.
|
||||
|
||||
Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where
|
||||
``index_matrix[i, j]`` is the index of the unique color at that pixel.
|
||||
"""
|
||||
_, h, w = image.shape
|
||||
flat = rearrange(image, "c h w -> (h w) c")
|
||||
unique_colors, inverse_indices, counts = torch.unique(
|
||||
flat, dim=0, return_inverse=True, return_counts=True, sorted=False,
|
||||
)
|
||||
index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w)
|
||||
return unique_colors, counts, index_matrix
|
||||
|
||||
|
||||
def sum_indexed_values(image, index_matrix):
|
||||
"""For each unique index, sum the CHW image values at its pixels."""
|
||||
_, h, w = image.shape
|
||||
u = int(index_matrix.max().item()) + 1
|
||||
flat = rearrange(image, "c h w -> (h w) c")
|
||||
out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device)
|
||||
out.index_add_(0, index_matrix.view(-1), flat)
|
||||
return out
|
||||
|
||||
|
||||
def indexed_to_image(index_matrix, unique_colors):
|
||||
"""Build a CHW image from an index matrix and a (U, C) color table."""
|
||||
h, w = index_matrix.shape
|
||||
flat = unique_colors[index_matrix.view(-1)]
|
||||
return rearrange(flat, "(h w) c -> c h w", h=h, w=w)
|
||||
|
||||
|
||||
def regaussianize(noise):
|
||||
"""Variance-preserving re-sampling of a CHW noise tensor.
|
||||
|
||||
Wherever the noise contains groups of identical pixel values (e.g. after
|
||||
a nearest-neighbor warp that duplicated source pixels), adds zero-mean
|
||||
foreign noise within each group and scales by ``1/sqrt(count)`` so the
|
||||
output is unit-variance gaussian again.
|
||||
"""
|
||||
_, hs, ws = noise.shape
|
||||
_, counts, index_matrix = unique_pixels(noise[:1])
|
||||
|
||||
foreign_noise = torch.randn_like(noise)
|
||||
summed = sum_indexed_values(foreign_noise, index_matrix)
|
||||
meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1"))
|
||||
zeroed_foreign = foreign_noise - meaned
|
||||
|
||||
counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1"))
|
||||
|
||||
output = noise / counts_image ** 0.5 + zeroed_foreign
|
||||
return output, counts_image
|
||||
|
||||
|
||||
def xy_meshgrid_like_image(image):
|
||||
"""Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``."""
|
||||
_, h, w = image.shape
|
||||
y, x = torch.meshgrid(
|
||||
torch.arange(h, device=image.device, dtype=image.dtype),
|
||||
torch.arange(w, device=image.device, dtype=image.dtype),
|
||||
indexing="ij",
|
||||
)
|
||||
return torch.stack([x, y])
|
||||
|
||||
|
||||
def noise_to_state(noise):
|
||||
"""Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise]."""
|
||||
zeros = torch.zeros_like(noise[:1])
|
||||
ones = torch.ones_like(noise[:1])
|
||||
return torch.cat([zeros, zeros, ones, noise])
|
||||
|
||||
|
||||
def state_to_noise(state):
|
||||
"""Unpack the noise channels from a state tensor."""
|
||||
return state[3:]
|
||||
|
||||
|
||||
def warp_state(state, flow):
|
||||
"""Warp a noise-warper state tensor along the given optical flow.
|
||||
|
||||
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
|
||||
``flow`` has shape ``(2, h, w)`` (= dx, dy).
|
||||
"""
|
||||
if flow.device != state.device:
|
||||
raise ValueError(
|
||||
f"warp_state: flow and state must be on the same device, "
|
||||
f"got flow={flow.device} state={state.device}"
|
||||
)
|
||||
if state.ndim != 3:
|
||||
raise ValueError(
|
||||
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
|
||||
)
|
||||
xyoc, h, w = state.shape
|
||||
if flow.shape != (2, h, w):
|
||||
raise ValueError(
|
||||
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
|
||||
)
|
||||
device = state.device
|
||||
|
||||
x_ch, y_ch = 0, 1
|
||||
xy = 2 # state[:xy] = [dx, dy]
|
||||
xyw = 3 # state[:xyw] = [dx, dy, ω]
|
||||
w_ch = 2 # state[w_ch] = ω
|
||||
c = xyoc - xyw
|
||||
oc = xyoc - xy
|
||||
if c <= 0:
|
||||
raise ValueError(
|
||||
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
|
||||
)
|
||||
if not (state[w_ch] > 0).all():
|
||||
raise ValueError("warp_state: all weights in state[2] must be > 0")
|
||||
|
||||
grid = xy_meshgrid_like_image(state)
|
||||
|
||||
init = torch.empty_like(state)
|
||||
init[:xy] = 0
|
||||
init[w_ch] = 1
|
||||
init[-c:] = 0
|
||||
|
||||
# --- Expansion branch: nearest-neighbor remap with negated flow ---
|
||||
pre_expand = torch.empty_like(state)
|
||||
pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest")
|
||||
pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest")
|
||||
pre_expand[w_ch][pre_expand[w_ch] == 0] = 1
|
||||
|
||||
# --- Shrink branch: scatter-add state into new positions ---
|
||||
pre_shrink = state.clone()
|
||||
pre_shrink[:xy] += flow
|
||||
|
||||
pos = (grid + pre_shrink[:xy]).round()
|
||||
in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h)
|
||||
pre_shrink = torch.where(~in_bounds[None], init, pre_shrink)
|
||||
|
||||
scat_xy = pre_shrink[:xy].round()
|
||||
pre_shrink[:xy] -= scat_xy
|
||||
pre_shrink[:xy] = 0 # xy_mode='none' in upstream
|
||||
|
||||
def scat(tensor):
|
||||
return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1])
|
||||
|
||||
# rp.torch_scatter_add_image on a bool tensor errors on modern torch;
|
||||
# scatter-sum a float ones tensor and threshold to get the mask instead.
|
||||
shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0
|
||||
|
||||
# Drop expansion samples at positions that will be filled by shrink.
|
||||
pre_expand = torch.where(shrink_mask, init, pre_expand)
|
||||
|
||||
# Regaussianize both branches together so duplicated-source groups are
|
||||
# counted globally, then split back apart.
|
||||
concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width
|
||||
concat[-c:], counts_image = regaussianize(concat[-c:])
|
||||
concat[w_ch] = concat[w_ch] / counts_image[0]
|
||||
concat[w_ch] = concat[w_ch].nan_to_num()
|
||||
pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2)
|
||||
|
||||
shrink = torch.empty_like(pre_shrink)
|
||||
shrink[w_ch] = scat(pre_shrink[w_ch][None])[0]
|
||||
shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None]
|
||||
shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat(
|
||||
pre_shrink[w_ch][None] ** 2
|
||||
).sqrt()
|
||||
|
||||
output = torch.where(shrink_mask, shrink, expand)
|
||||
output[w_ch] = output[w_ch] / output[w_ch].mean()
|
||||
output[w_ch] += 1e-5
|
||||
output[w_ch] **= 0.9999
|
||||
return output
|
||||
|
||||
|
||||
class NoiseWarper:
|
||||
"""Maintain a warpable noise state and emit gaussian noise per frame.
|
||||
|
||||
Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper:
|
||||
``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and
|
||||
``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults.
|
||||
"""
|
||||
|
||||
def __init__(self, c, h, w, device, dtype=torch.float32):
|
||||
if c <= 0 or h <= 0 or w <= 0:
|
||||
raise ValueError(
|
||||
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
|
||||
)
|
||||
self.c = c
|
||||
self.h = h
|
||||
self.w = w
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
noise = torch.randn(c, h, w, dtype=dtype, device=device)
|
||||
self._state = noise_to_state(noise)
|
||||
|
||||
@property
|
||||
def noise(self):
|
||||
# With scale_factor=1 the "downsample to respect weights" step is a
|
||||
# size-preserving no-op; the weight-variance correction math still
|
||||
# runs to stay faithful to upstream.
|
||||
n = state_to_noise(self._state)
|
||||
weights = self._state[2:3]
|
||||
return n * weights / (weights ** 2).sqrt()
|
||||
|
||||
def __call__(self, dx, dy):
|
||||
if dx.shape != dy.shape:
|
||||
raise ValueError(
|
||||
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
|
||||
)
|
||||
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
|
||||
_, oflowh, ofloww = flow.shape
|
||||
|
||||
flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True)
|
||||
flowh, floww = flow.shape[-2:]
|
||||
|
||||
# Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww
|
||||
# (channel-order appears swapped but harmless when H and W are scaled
|
||||
# by the same factor, which is always the case for our callers).
|
||||
flow[0] *= flowh / oflowh
|
||||
flow[1] *= floww / ofloww
|
||||
|
||||
self._state = warp_state(self._state, flow)
|
||||
return self
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RAFT optical flow wrapper (ported from raft.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class RaftOpticalFlow:
|
||||
"""RAFT-large wrapper around a pre-loaded torchvision model.
|
||||
|
||||
``model`` must be the ``torchvision.models.optical_flow.raft_large`` module
|
||||
with its weights already populated; this class is load-agnostic so the
|
||||
caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in
|
||||
``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow.
|
||||
"""
|
||||
|
||||
def __init__(self, model, device=None):
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
self.device = device
|
||||
self.model = model
|
||||
|
||||
def _preprocess(self, image_chw):
|
||||
image = image_chw.to(self.device, torch.float32)
|
||||
_, h, w = image.shape
|
||||
new_h = (h // 8) * 8
|
||||
new_w = (w // 8) * 8
|
||||
image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False)
|
||||
image = image * 2 - 1
|
||||
return image[None]
|
||||
|
||||
def __call__(self, from_image, to_image):
|
||||
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
|
||||
if from_image.shape != to_image.shape:
|
||||
raise ValueError(
|
||||
f"RaftOpticalFlow: from_image and to_image must match, "
|
||||
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
|
||||
)
|
||||
_, h, w = from_image.shape
|
||||
with torch.no_grad():
|
||||
img1 = self._preprocess(from_image)
|
||||
img2 = self._preprocess(to_image)
|
||||
list_of_flows = self.model(img1, img2)
|
||||
flow = list_of_flows[-1][0] # (2, new_h, new_w)
|
||||
if flow.shape[-2:] != (h, w):
|
||||
flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False)
|
||||
return flow
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Narrow entry point used by VOIDWarpedNoise
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def get_noise_from_video(
|
||||
video_frames: torch.Tensor,
|
||||
raft: RaftOpticalFlow,
|
||||
*,
|
||||
noise_channels: int = 16,
|
||||
resize_frames: float = 0.5,
|
||||
resize_flow: int = 8,
|
||||
downscale_factor: int = 32,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Produce optical-flow-warped gaussian noise from a video.
|
||||
|
||||
Args:
|
||||
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
|
||||
raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``).
|
||||
noise_channels: Channels in the output noise.
|
||||
resize_frames: Pre-RAFT frame scale factor.
|
||||
resize_flow: Post-flow up-scale factor applied to the optical flow;
|
||||
the internal noise state is allocated at
|
||||
``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``.
|
||||
downscale_factor: Area-pool factor applied to the noise before return;
|
||||
should evenly divide the internal noise resolution.
|
||||
device: Target device. Defaults to ``comfy.model_management.get_torch_device()``.
|
||||
|
||||
Returns:
|
||||
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
|
||||
"""
|
||||
if not isinstance(resize_flow, int) or resize_flow < 1:
|
||||
raise ValueError(
|
||||
f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
|
||||
)
|
||||
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
|
||||
raise ValueError(
|
||||
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
|
||||
f"got {tuple(video_frames.shape)}"
|
||||
)
|
||||
if video_frames.dtype != torch.uint8:
|
||||
raise TypeError(
|
||||
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
|
||||
f"got dtype {video_frames.dtype}"
|
||||
)
|
||||
|
||||
if device is None:
|
||||
device = comfy.model_management.get_torch_device()
|
||||
device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
|
||||
if device.type == "cpu":
|
||||
logging.warning(
|
||||
"VOIDWarpedNoise: running get_noise_from_video on CPU; this will be "
|
||||
"slow (minutes for ~45 frames). Use CUDA for interactive use."
|
||||
)
|
||||
|
||||
T = video_frames.shape[0]
|
||||
frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0
|
||||
if resize_frames != 1.0:
|
||||
new_h = max(1, int(frames.shape[2] * resize_frames))
|
||||
new_w = max(1, int(frames.shape[3] * resize_frames))
|
||||
frames = F.interpolate(frames, size=(new_h, new_w), mode="area")
|
||||
|
||||
_, _, H, W = frames.shape
|
||||
internal_h = resize_flow * H
|
||||
internal_w = resize_flow * W
|
||||
if internal_h % downscale_factor or internal_w % downscale_factor:
|
||||
logging.warning(
|
||||
"VOIDWarpedNoise: internal noise size %dx%d is not divisible by "
|
||||
"downscale_factor %d; output noise may have artifacts.",
|
||||
internal_h, internal_w, downscale_factor,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
warper = NoiseWarper(
|
||||
c=noise_channels, h=internal_h, w=internal_w, device=device,
|
||||
)
|
||||
down_h = warper.h // downscale_factor
|
||||
down_w = warper.w // downscale_factor
|
||||
output = torch.empty(
|
||||
(T, down_h, down_w, noise_channels), dtype=torch.float32, device=device,
|
||||
)
|
||||
|
||||
def downscale(noise_chw):
|
||||
# Area-pool to 1/downscale_factor then multiply by downscale_factor
|
||||
# to adjust std (sqrt of pool area == downscale_factor for a
|
||||
# square pool).
|
||||
down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False)
|
||||
return down * downscale_factor
|
||||
|
||||
output[0] = downscale(warper.noise).permute(1, 2, 0)
|
||||
|
||||
prev = frames[0]
|
||||
for i in range(1, T):
|
||||
curr = frames[i]
|
||||
flow = raft(prev, curr).to(device)
|
||||
warper(flow[0], flow[1])
|
||||
output[i] = downscale(warper.noise).permute(1, 2, 0)
|
||||
prev = curr
|
||||
|
||||
return output
|
||||
@ -15,6 +15,7 @@ import torch
|
||||
from comfy.cli_args import args
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.model_prefetch
|
||||
import comfy_aimdo.model_vbar
|
||||
|
||||
from latent_preview import set_preview_method
|
||||
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
if args.verbose == "DEBUG":
|
||||
comfy_aimdo.control.analyze()
|
||||
comfy.model_management.reset_cast_buffers()
|
||||
comfy.model_prefetch.cleanup_prefetch_queues()
|
||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||
|
||||
if has_pending_tasks:
|
||||
|
||||
@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc
|
||||
|
||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
output_directory = os.path.join(base_path, "output")
|
||||
temp_directory = os.path.join(base_path, "temp")
|
||||
input_directory = os.path.join(base_path, "input")
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -958,7 +958,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -968,7 +968,7 @@ class CLIPLoader:
|
||||
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion", device="default"):
|
||||
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
|
||||
@ -2445,6 +2445,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py",
|
||||
"nodes_void.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
Loading…
Reference in New Issue
Block a user