diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 60c0dfd7e..91bebed3d 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -793,9 +793,27 @@ class ZImagePixelSpace(ChromaRadiance): pass class CogVideoX(LatentFormat): + """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). + + scale_factor matches the vae/config.json scaling_factor for the 2b variant. + The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*) + use a different value; see CogVideoX1_5 below. + """ latent_channels = 16 latent_dimensions = 3 temporal_downscale_ratio = 4 def __init__(self): self.scale_factor = 1.15258426 + + +class CogVideoX1_5(CogVideoX): + """Latent format for 5b-class CogVideoX checkpoints. + + Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun + V1.5-5b family (including VOID inpainting). All of these have + scaling_factor=0.7 in their vae/config.json. Auto-selected in + supported_models.CogVideoX_T2V based on transformer hidden dim. + """ + def __init__(self): + self.scale_factor = 0.7 diff --git a/comfy/sd.py b/comfy/sd.py index 9fce0e7d0..749bdd710 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -66,6 +66,7 @@ import comfy.text_encoders.longcat_image import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 +import comfy.text_encoders.cogvideo import comfy.model_patcher import comfy.lora @@ -1224,6 +1225,7 @@ class CLIPType(Enum): NEWBIE = 24 FLUX2 = 25 LONGCAT_IMAGE = 26 + COGVIDEOX = 27 @@ -1428,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer + elif clip_type == CLIPType.COGVIDEOX: + clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index dff40461f..6a9613602 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1872,6 +1872,14 @@ class CogVideoX_T2V(supported_models_base.BASE): vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] + def __init__(self, unet_config): + # 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426. + # 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and + # Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json. 
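+        # (For reference, assuming LatentFormat's default behavior:
+        # process_in multiplies latents by scale_factor and process_out
+        # divides by it, so picking the wrong class here would skew latent
+        # magnitudes by roughly 1.65x in each direction.)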
+ if unet_config.get("num_attention_heads", 0) >= 48: + self.latent_format = latent_formats.CogVideoX1_5 + super().__init__(unet_config) + def get_model(self, state_dict, prefix="", device=None): # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE if self.unet_config.get("patch_size_t") is not None: @@ -1898,6 +1906,20 @@ class CogVideoX_I2V(CogVideoX_T2V): out = model_base.CogVideoX(self, image_to_video=True, device=device) return out +class CogVideoX_Inpaint(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + models = [ LotusD, @@ -1978,6 +2000,7 @@ models = [ ErnieImage, SAM3, SAM31, + CogVideoX_Inpaint, CogVideoX_I2V, CogVideoX_T2V, SVD_img2vid, diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py index f1e8e3f5d..b97310709 100644 --- a/comfy/text_encoders/cogvideo.py +++ b/comfy/text_encoders/cogvideo.py @@ -1,6 +1,48 @@ import comfy.text_encoders.sd3_clip +from comfy import sd1_clip class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + """Inner T5 tokenizer for CogVideoX. + + CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3). + Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with + the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below. + """ def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) + + +class CogVideoXTokenizer(sd1_clip.SD1Tokenizer): + """Outer tokenizer wrapper for CLIPLoader (type="cogvideox").""" + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer) + + +class CogVideoXT5XXL(sd1_clip.SD1ClipModel): + """Outer T5XXL model wrapper for CLIPLoader (type="cogvideox"). + + Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__ + (which reads self.dtypes) works correctly. The inner model is the standard + sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX). + """ + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", + clip_model=comfy.text_encoders.sd3_clip.T5XXLModel, + model_options=model_options) + + +def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None): + """Factory that returns a CogVideoXT5XXL class configured with the detected + T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts. 
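+
+    Illustrative use (mirroring the CLIPType.COGVIDEOX branch in sd.py
+    above, where ``clip_data`` and ``t5xxl_detect`` are defined)::
+
+        clip_target.clip = cogvideo_te(**t5xxl_detect(clip_data))
+        clip_target.tokenizer = CogVideoXTokenizer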
+ """ + class CogVideoXTEModel_(CogVideoXT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5_quantization_metadata is not None: + model_options = model_options.copy() + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if dtype_t5 is not None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return CogVideoXTEModel_ diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py new file mode 100644 index 000000000..e7a8f3757 --- /dev/null +++ b/comfy_extras/nodes_void.py @@ -0,0 +1,483 @@ +import logging + +import torch + +import comfy +import comfy.model_management +import comfy.model_patcher +import comfy.samplers +import comfy.utils +import folder_paths +import node_helpers +import nodes +from comfy.utils import model_trange as trange +from comfy_api.latest import ComfyExtension, io +from torchvision.models.optical_flow import raft_large +from typing_extensions import override + + +from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video + +OpticalFlow = io.Custom("OPTICAL_FLOW") + +TEMPORAL_COMPRESSION = 4 +PATCH_SIZE_T = 2 + + +def _valid_void_length(length: int) -> int: + """Round ``length`` down to a value that produces an even latent_t. + + VOID / CogVideoX-Fun-V1.5 uses patch_size_t=2, so the VAE-encoded latent + must have an even temporal dimension. If latent_t is odd, the transformer + pad_to_patch_size circular-wraps an extra latent frame onto the end; after + the post-transformer crop the last real latent frame has been influenced + by the wrapped phantom frame, producing visible jitter and "disappearing" + subjects near the end of the decoded video. Rounding down fixes this. + """ + latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1 + if latent_t % PATCH_SIZE_T == 0: + return length + # Round latent_t down to the nearest multiple of PATCH_SIZE_T, then invert + # the ((length - 1) // TEMPORAL_COMPRESSION) + 1 formula. Floor at 1 frame + # so we never return a non-positive length. + target_latent_t = max(PATCH_SIZE_T, (latent_t // PATCH_SIZE_T) * PATCH_SIZE_T) + return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1 + + +class OpticalFlowLoader(io.ComfyNode): + """Load an optical flow model from ``models/optical_flow/``. + + Only torchvision's RAFT-large format is recognized today (the model used + by VOIDWarpedNoise). The checkpoint must be placed under + ``models/optical_flow/`` — ComfyUI never downloads optical-flow weights + at runtime. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="OpticalFlowLoader", + display_name="Load Optical Flow Model", + category="loaders", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("optical_flow"), + tooltip=( + "Optical flow model to load. Files must be placed in the " + "'optical_flow' folder. Today only torchvision's " + "raft_large.pth is supported." 
+ ), + ), + ], + outputs=[ + OpticalFlow.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + + model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + has_raft_keys = ( + any(k.startswith("feature_encoder.") for k in sd) + and any(k.startswith("context_encoder.") for k in sd) + and any(k.startswith("update_block.") for k in sd) + ) + if not has_raft_keys: + raise ValueError( + "Unrecognized optical flow model format: expected a torchvision " + "RAFT-large state dict with 'feature_encoder.', 'context_encoder.' " + "and 'update_block.' prefixes." + ) + + model = raft_large(weights=None, progress=False) + model.load_state_dict(sd) + model.eval().to(torch.float32) + + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + +class VOIDQuadmaskPreprocess(io.ComfyNode): + """Preprocess a quadmask video for VOID inpainting. + + Quantizes mask values to four semantic levels, inverts, and normalizes: + 0 -> primary object to remove + 63 -> overlap of primary + affected + 127 -> affected region (interactions) + 255 -> background (keep) + + After inversion and normalization, the output mask has values in [0, 1] + with four discrete levels: 1.0 (remove), ~0.75, ~0.50, 0.0 (keep). + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDQuadmaskPreprocess", + category="mask/video", + inputs=[ + io.Mask.Input("mask"), + io.Int.Input("dilate_width", default=0, min=0, max=50, step=1, + tooltip="Dilation radius for the primary mask region (0 = no dilation)"), + ], + outputs=[ + io.Mask.Output(display_name="quadmask"), + ], + ) + + @classmethod + def execute(cls, mask, dilate_width=0) -> io.NodeOutput: + m = mask.clone() + + if m.max() <= 1.0: + m = m * 255.0 + + if dilate_width > 0 and m.ndim >= 3: + binary = (m < 128).float() + kernel_size = dilate_width * 2 + 1 + if binary.ndim == 3: + binary = binary.unsqueeze(1) + dilated = torch.nn.functional.max_pool2d( + binary, kernel_size=kernel_size, stride=1, padding=dilate_width + ) + if dilated.ndim == 4: + dilated = dilated.squeeze(1) + m = torch.where(dilated > 0.5, torch.zeros_like(m), m) + + m = torch.where(m <= 31, torch.zeros_like(m), m) + m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m) + m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m) + m = torch.where(m > 191, torch.full_like(m, 255), m) + + m = (255.0 - m) / 255.0 + + return io.NodeOutput(m) + + +class VOIDInpaintConditioning(io.ComfyNode): + """Build VOID inpainting conditioning for CogVideoX. + + Encodes the processed quadmask and masked source video through the VAE, + producing a 32-channel concat conditioning (16ch mask + 16ch masked video) + that gets concatenated with the 16ch noise latent by the model. 
+ """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDInpaintConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("video", tooltip="Source video frames [T, H, W, 3]"), + io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"), + io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1, + tooltip="Number of pixel frames to process. For CogVideoX-Fun-V1.5 " + "(patch_size_t=2), latent_t must be even — lengths that " + "produce odd latent_t are rounded down (e.g. 49 → 45)."), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, video, quadmask, + width, height, length, batch_size) -> io.NodeOutput: + + adjusted_length = _valid_void_length(length) + if adjusted_length != length: + logging.warning( + "VOIDInpaintConditioning: rounding length %d down to %d so that " + "latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2). " + "Using odd latent_t causes the last frame to be corrupted by " + "circular padding.", length, adjusted_length, + ) + length = adjusted_length + + latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1 + latent_h = height // 8 + latent_w = width // 8 + + vid = video[:length] + vid = comfy.utils.common_upscale( + vid.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + + qm = quadmask[:length] + if qm.ndim == 3: + qm = qm.unsqueeze(-1) + qm = comfy.utils.common_upscale( + qm.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + if qm.ndim == 4 and qm.shape[-1] == 1: + qm = qm.squeeze(-1) + + mask_condition = qm + if mask_condition.ndim == 3: + mask_condition_3ch = mask_condition.unsqueeze(-1).expand(-1, -1, -1, 3) + else: + mask_condition_3ch = mask_condition + + inverted_mask_3ch = 1.0 - mask_condition_3ch + masked_video = vid[:, :, :, :3] * (1.0 - mask_condition_3ch) + + mask_latents = vae.encode(inverted_mask_3ch) + masked_video_latents = vae.encode(masked_video) + + def _match_temporal(lat, target_t): + if lat.shape[2] > target_t: + return lat[:, :, :target_t] + elif lat.shape[2] < target_t: + pad = target_t - lat.shape[2] + return torch.cat([lat, lat[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2) + return lat + + mask_latents = _match_temporal(mask_latents, latent_t) + masked_video_latents = _match_temporal(masked_video_latents, latent_t) + + inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1) + + # No explicit scaling needed here: the model's CogVideoX.concat_cond() + # applies process_latent_in (×latent_format.scale_factor) to each 16-ch + # block of the stored conditioning. For 5b-class checkpoints (incl. the + # VOID/CogVideoX-Fun-V1.5 inpainting model) that scale_factor is auto- + # selected as 0.7 in supported_models.CogVideoX_T2V, which matches the + # diffusers vae/config.json scaling_factor VOID was trained with. 
+ + positive = node_helpers.conditioning_set_values( + positive, {"concat_latent_image": inpaint_latents} + ) + negative = node_helpers.conditioning_set_values( + negative, {"concat_latent_image": inpaint_latents} + ) + + noise_latent = torch.zeros( + [batch_size, 16, latent_t, latent_h, latent_w], + device=comfy.model_management.intermediate_device() + ) + + return io.NodeOutput(positive, negative, {"samples": noise_latent}) + + +class VOIDWarpedNoise(io.ComfyNode): + """Generate optical-flow warped noise for VOID Pass 2 refinement. + + Takes the Pass 1 output video and produces temporally-correlated noise + by warping Gaussian noise along optical flow vectors. This noise is used + as the initial latent for Pass 2, resulting in better temporal consistency. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDWarpedNoise", + category="latent/video", + inputs=[ + OpticalFlow.Input( + "optical_flow", + tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).", + ), + io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"), + io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1, + tooltip="Number of pixel frames. Rounded down to make latent_t " + "even (patch_size_t=2 requirement), e.g. 49 → 45."), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Latent.Output(display_name="warped_noise"), + ], + ) + + @classmethod + def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput: + + adjusted_length = _valid_void_length(length) + if adjusted_length != length: + logging.warning( + "VOIDWarpedNoise: rounding length %d down to %d so that " + "latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2).", + length, adjusted_length, + ) + length = adjusted_length + + latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1 + latent_h = height // 8 + latent_w = width // 8 + + # RAFT + noise warp is real compute, not an "intermediate" buffer, so + # we want the actual torch device (CUDA/MPS). The final latent is + # moved back to intermediate_device() before returning to match the + # rest of the ComfyUI pipeline. 
+ device = comfy.model_management.get_torch_device() + + comfy.model_management.load_model_gpu(optical_flow) + raft = RaftOpticalFlow(optical_flow.model, device=device) + + vid = video[:length].to(device) + vid = comfy.utils.common_upscale( + vid.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8) + + FRAME = 2**-1 + FLOW = 2**3 + LATENT_SCALE = 8 + + warped = get_noise_from_video( + vid_uint8, + raft, + noise_channels=16, + resize_frames=FRAME, + resize_flow=FLOW, + downscale_factor=round(FRAME * FLOW) * LATENT_SCALE, + device=device, + ) + + if warped.shape[0] != latent_t: + indices = torch.linspace(0, warped.shape[0] - 1, latent_t, + device=device).long() + warped = warped[indices] + + if warped.shape[1] != latent_h or warped.shape[2] != latent_w: + # (T, H, W, C) → (T, C, H, W) → bilinear resize → back + warped = warped.permute(0, 3, 1, 2) + warped = torch.nn.functional.interpolate( + warped, size=(latent_h, latent_w), + mode="bilinear", align_corners=False, + ) + warped = warped.permute(0, 2, 3, 1) + + # (T, H, W, C) → (B, C, T, H, W) + warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0) + if batch_size > 1: + warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1) + + warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": warped_tensor}) + + +class Noise_FromLatent: + """Wraps a pre-computed LATENT tensor as a NOISE source.""" + def __init__(self, latent_dict): + self.seed = 0 + self._samples = latent_dict["samples"] + + def generate_noise(self, input_latent): + return self._samples.clone().cpu() + + +class VOIDWarpedNoiseSource(io.ComfyNode): + """Convert a LATENT (e.g. from VOIDWarpedNoise) into a NOISE source + for use with SamplerCustomAdvanced.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDWarpedNoiseSource", + category="sampling/custom_sampling/noise", + inputs=[ + io.Latent.Input("warped_noise", + tooltip="Warped noise latent from VOIDWarpedNoise"), + ], + outputs=[io.Noise.Output()], + ) + + @classmethod + def execute(cls, warped_noise) -> io.NodeOutput: + return io.NodeOutput(Noise_FromLatent(warped_noise)) + + +class VOID_DDIM(comfy.samplers.Sampler): + """DDIM sampler for VOID inpainting models. + + VOID was trained with the diffusers CogVideoXDDIMScheduler which operates in + alpha-space (input std ≈ 1). The standard KSampler applies noise_scaling that + multiplies by sqrt(1+sigma^2) ≈ 4500x, which is incompatible with VOID's + training. This sampler skips noise_scaling and implements the DDIM update rule + directly using sigma-to-alpha conversion. 
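+
+    Concretely, each step below maps sigma to alpha via
+    alpha = 1 / (1 + sigma^2), recovers the noise estimate
+    eps = (x - sqrt(alpha_t) * x0) / sqrt(1 - alpha_t) from the model's
+    denoised prediction x0, then reprojects:
+    x_next = sqrt(alpha_prev) * x0 + sqrt(1 - alpha_prev) * eps.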
+ """ + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + x = noise.to(torch.float32) + model_options = extra_args.get("model_options", {}) + seed = extra_args.get("seed", None) + s_in = x.new_ones([x.shape[0]]) + + for i in trange(len(sigmas) - 1, disable=disable_pbar): + sigma = sigmas[i] + sigma_next = sigmas[i + 1] + + denoised = model_wrap(x, sigma * s_in, model_options=model_options, seed=seed) + + if callback is not None: + callback(i, denoised, x, len(sigmas) - 1) + + if sigma_next == 0: + x = denoised + else: + alpha_t = 1.0 / (1.0 + sigma ** 2) + alpha_prev = 1.0 / (1.0 + sigma_next ** 2) + + pred_eps = (x - (alpha_t ** 0.5) * denoised) / (1.0 - alpha_t) ** 0.5 + x = (alpha_prev ** 0.5) * denoised + (1.0 - alpha_prev) ** 0.5 * pred_eps + + return x + + +class VOIDSampler(io.ComfyNode): + """VOID DDIM sampler for use with SamplerCustom / SamplerCustomAdvanced. + + Required for VOID inpainting models. Implements the same DDIM loop that VOID + was trained with (diffusers CogVideoXDDIMScheduler), without the noise_scaling + that the standard KSampler applies. Use with RandomNoise or VOIDWarpedNoiseSource. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDSampler", + category="sampling/custom_sampling/samplers", + inputs=[], + outputs=[io.Sampler.Output()], + ) + + @classmethod + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(VOID_DDIM()) + + get_sampler = execute + + +class VOIDExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + OpticalFlowLoader, + VOIDQuadmaskPreprocess, + VOIDInpaintConditioning, + VOIDWarpedNoise, + VOIDWarpedNoiseSource, + VOIDSampler, + ] + + +async def comfy_entrypoint() -> VOIDExtension: + return VOIDExtension() diff --git a/comfy_extras/void_noise_warp.py b/comfy_extras/void_noise_warp.py new file mode 100644 index 000000000..fcc9a5f8b --- /dev/null +++ b/comfy_extras/void_noise_warp.py @@ -0,0 +1,494 @@ +""" +Optical-flow-warped noise for VOID Pass 2 refinement. + +Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert): + https://github.com/RyannDaGreat/CommonSource + - noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video) + - raft.py (RaftOpticalFlow) + +Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually +uses (torch THWC uint8 input, no background removal, no visualization, no disk +I/O, default warp/noise params) have been inlined. External ``rp`` utilities +have been replaced with equivalents from torch.nn.functional / einops. The +RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in +``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this +module never downloads weights at runtime. +""" + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange + +import comfy.model_management + + +# --------------------------------------------------------------------------- +# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives) +# --------------------------------------------------------------------------- + +def _torch_resize_chw(image, size, interp, copy=True): + """Resize a CHW tensor. + + ``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one + of ``"bilinear"``, ``"nearest"``, ``"area"``. 
When ``copy`` is False and + the requested size matches the input, returns the input tensor as is + (faster but callers must not mutate the result). + """ + if image.ndim != 3: + raise ValueError( + f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + _, in_h, in_w = image.shape + if isinstance(size, (int, float)) and not isinstance(size, bool): + new_h = max(1, int(in_h * size)) + new_w = max(1, int(in_w * size)) + else: + new_h, new_w = size + + if (new_h, new_w) == (in_h, in_w): + return image.clone() if copy else image + + kwargs = {} + if interp in ("bilinear", "bicubic"): + kwargs["align_corners"] = False + out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0] + return out + + +def _torch_remap_relative(image, dx, dy, interp="bilinear"): + """Relative remap of a CHW image via ``F.grid_sample``. + + Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)`` + for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0. + """ + if image.ndim != 3: + raise ValueError( + f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + if dx.shape != dy.shape: + raise ValueError( + f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}" + ) + _, h, w = image.shape + + x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype) + y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None] + + x_norm = (x_abs / (w - 1)) * 2 - 1 + y_norm = (y_abs / (h - 1)) * 2 - 1 + + grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype) + out = F.grid_sample( + image[None], grid, mode=interp, align_corners=True, padding_mode="zeros" + )[0] + return out + + +def _torch_scatter_add_relative(image, dx, dy): + """Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets. + + Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True, + interp='floor')``. Out-of-bounds targets are dropped. + """ + if image.ndim != 3: + raise ValueError( + f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}" + ) + in_c, in_h, in_w = image.shape + if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w): + raise ValueError( + f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), " + f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}" + ) + + x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long) + y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None] + + valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1) + indices = (y * in_w + x).reshape(-1)[valid] + + flat_image = rearrange(image, "c h w -> (h w) c")[valid] + out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device) + out.index_add_(0, indices, flat_image) + return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w) + + +# --------------------------------------------------------------------------- +# Noise warping primitives (ported from noise_warp.py) +# --------------------------------------------------------------------------- + +def unique_pixels(image): + """Find unique pixel values in a CHW tensor. + + Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where + ``index_matrix[i, j]`` is the index of the unique color at that pixel. 
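+
+    Example (illustrative): a 1x2x2 image [[[5., 7.], [5., 5.]]] yields
+    unique_colors=[[5.], [7.]], counts=[3, 1] and
+    index_matrix=[[0, 1], [0, 0]], up to torch.unique's ordering.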
+ """ + _, h, w = image.shape + flat = rearrange(image, "c h w -> (h w) c") + unique_colors, inverse_indices, counts = torch.unique( + flat, dim=0, return_inverse=True, return_counts=True, sorted=False, + ) + index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w) + return unique_colors, counts, index_matrix + + +def sum_indexed_values(image, index_matrix): + """For each unique index, sum the CHW image values at its pixels.""" + _, h, w = image.shape + u = int(index_matrix.max().item()) + 1 + flat = rearrange(image, "c h w -> (h w) c") + out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device) + out.index_add_(0, index_matrix.view(-1), flat) + return out + + +def indexed_to_image(index_matrix, unique_colors): + """Build a CHW image from an index matrix and a (U, C) color table.""" + h, w = index_matrix.shape + flat = unique_colors[index_matrix.view(-1)] + return rearrange(flat, "(h w) c -> c h w", h=h, w=w) + + +def regaussianize(noise): + """Variance-preserving re-sampling of a CHW noise tensor. + + Wherever the noise contains groups of identical pixel values (e.g. after + a nearest-neighbor warp that duplicated source pixels), adds zero-mean + foreign noise within each group and scales by ``1/sqrt(count)`` so the + output is unit-variance gaussian again. + """ + _, hs, ws = noise.shape + _, counts, index_matrix = unique_pixels(noise[:1]) + + foreign_noise = torch.randn_like(noise) + summed = sum_indexed_values(foreign_noise, index_matrix) + meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1")) + zeroed_foreign = foreign_noise - meaned + + counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1")) + + output = noise / counts_image ** 0.5 + zeroed_foreign + return output, counts_image + + +def xy_meshgrid_like_image(image): + """Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``.""" + _, h, w = image.shape + y, x = torch.meshgrid( + torch.arange(h, device=image.device, dtype=image.dtype), + torch.arange(w, device=image.device, dtype=image.dtype), + indexing="ij", + ) + return torch.stack([x, y]) + + +def noise_to_state(noise): + """Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise].""" + zeros = torch.zeros_like(noise[:1]) + ones = torch.ones_like(noise[:1]) + return torch.cat([zeros, zeros, ones, noise]) + + +def state_to_noise(state): + """Unpack the noise channels from a state tensor.""" + return state[3:] + + +def warp_state(state, flow): + """Warp a noise-warper state tensor along the given optical flow. + + ``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels). + ``flow`` has shape ``(2, h, w)`` (= dx, dy). 
+ """ + if flow.device != state.device: + raise ValueError( + f"warp_state: flow and state must be on the same device, " + f"got flow={flow.device} state={state.device}" + ) + if state.ndim != 3: + raise ValueError( + f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}" + ) + xyoc, h, w = state.shape + if flow.shape != (2, h, w): + raise ValueError( + f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}" + ) + device = state.device + + x_ch, y_ch = 0, 1 + xy = 2 # state[:xy] = [dx, dy] + xyw = 3 # state[:xyw] = [dx, dy, ω] + w_ch = 2 # state[w_ch] = ω + c = xyoc - xyw + oc = xyoc - xy + if c <= 0: + raise ValueError( + f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)" + ) + if not (state[w_ch] > 0).all(): + raise ValueError("warp_state: all weights in state[2] must be > 0") + + grid = xy_meshgrid_like_image(state) + + init = torch.empty_like(state) + init[:xy] = 0 + init[w_ch] = 1 + init[-c:] = 0 + + # --- Expansion branch: nearest-neighbor remap with negated flow --- + pre_expand = torch.empty_like(state) + pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest") + pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest") + pre_expand[w_ch][pre_expand[w_ch] == 0] = 1 + + # --- Shrink branch: scatter-add state into new positions --- + pre_shrink = state.clone() + pre_shrink[:xy] += flow + + pos = (grid + pre_shrink[:xy]).round() + in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h) + pre_shrink = torch.where(~in_bounds[None], init, pre_shrink) + + scat_xy = pre_shrink[:xy].round() + pre_shrink[:xy] -= scat_xy + pre_shrink[:xy] = 0 # xy_mode='none' in upstream + + def scat(tensor): + return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1]) + + # rp.torch_scatter_add_image on a bool tensor errors on modern torch; + # scatter-sum a float ones tensor and threshold to get the mask instead. + shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0 + + # Drop expansion samples at positions that will be filled by shrink. + pre_expand = torch.where(shrink_mask, init, pre_expand) + + # Regaussianize both branches together so duplicated-source groups are + # counted globally, then split back apart. + concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width + concat[-c:], counts_image = regaussianize(concat[-c:]) + concat[w_ch] = concat[w_ch] / counts_image[0] + concat[w_ch] = concat[w_ch].nan_to_num() + pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2) + + shrink = torch.empty_like(pre_shrink) + shrink[w_ch] = scat(pre_shrink[w_ch][None])[0] + shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None] + shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat( + pre_shrink[w_ch][None] ** 2 + ).sqrt() + + output = torch.where(shrink_mask, shrink, expand) + output[w_ch] = output[w_ch] / output[w_ch].mean() + output[w_ch] += 1e-5 + output[w_ch] **= 0.9999 + return output + + +class NoiseWarper: + """Maintain a warpable noise state and emit gaussian noise per frame. + + Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper: + ``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and + ``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults. 
+ """ + + def __init__(self, c, h, w, device, dtype=torch.float32): + if c <= 0 or h <= 0 or w <= 0: + raise ValueError( + f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}" + ) + self.c = c + self.h = h + self.w = w + self.device = device + self.dtype = dtype + + noise = torch.randn(c, h, w, dtype=dtype, device=device) + self._state = noise_to_state(noise) + + @property + def noise(self): + # With scale_factor=1 the "downsample to respect weights" step is a + # size-preserving no-op; the weight-variance correction math still + # runs to stay faithful to upstream. + n = state_to_noise(self._state) + weights = self._state[2:3] + return n * weights / (weights ** 2).sqrt() + + def __call__(self, dx, dy): + if dx.shape != dy.shape: + raise ValueError( + f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}" + ) + flow = torch.stack([dx, dy]).to(self.device, self.dtype) + _, oflowh, ofloww = flow.shape + + flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True) + flowh, floww = flow.shape[-2:] + + # Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww + # (channel-order appears swapped but harmless when H and W are scaled + # by the same factor, which is always the case for our callers). + flow[0] *= flowh / oflowh + flow[1] *= floww / ofloww + + self._state = warp_state(self._state, flow) + return self + + +# --------------------------------------------------------------------------- +# RAFT optical flow wrapper (ported from raft.py) +# --------------------------------------------------------------------------- + +class RaftOpticalFlow: + """RAFT-large wrapper around a pre-loaded torchvision model. + + ``model`` must be the ``torchvision.models.optical_flow.raft_large`` module + with its weights already populated; this class is load-agnostic so the + caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in + ``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow. 
+ """ + + def __init__(self, model, device=None): + if device is None: + device = comfy.model_management.get_torch_device() + device = torch.device(device) if not isinstance(device, torch.device) else device + + model = model.to(device) + model.eval() + self.device = device + self.model = model + + def _preprocess(self, image_chw): + image = image_chw.to(self.device, torch.float32) + _, h, w = image.shape + new_h = (h // 8) * 8 + new_w = (w // 8) * 8 + image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False) + image = image * 2 - 1 + return image[None] + + def __call__(self, from_image, to_image): + """``from_image``, ``to_image``: CHW float tensors in [0, 1].""" + if from_image.shape != to_image.shape: + raise ValueError( + f"RaftOpticalFlow: from_image and to_image must match, " + f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}" + ) + _, h, w = from_image.shape + with torch.no_grad(): + img1 = self._preprocess(from_image) + img2 = self._preprocess(to_image) + list_of_flows = self.model(img1, img2) + flow = list_of_flows[-1][0] # (2, new_h, new_w) + if flow.shape[-2:] != (h, w): + flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False) + return flow + + +# --------------------------------------------------------------------------- +# Narrow entry point used by VOIDWarpedNoise +# --------------------------------------------------------------------------- + +def get_noise_from_video( + video_frames: torch.Tensor, + raft: RaftOpticalFlow, + *, + noise_channels: int = 16, + resize_frames: float = 0.5, + resize_flow: int = 8, + downscale_factor: int = 32, + device: Optional[torch.device] = None, +) -> torch.Tensor: + """Produce optical-flow-warped gaussian noise from a video. + + Args: + video_frames: ``(T, H, W, 3)`` uint8 torch tensor. + raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``). + noise_channels: Channels in the output noise. + resize_frames: Pre-RAFT frame scale factor. + resize_flow: Post-flow up-scale factor applied to the optical flow; + the internal noise state is allocated at + ``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``. + downscale_factor: Area-pool factor applied to the noise before return; + should evenly divide the internal noise resolution. + device: Target device. Defaults to ``comfy.model_management.get_torch_device()``. + + Returns: + ``(T, H', W', noise_channels)`` float32 noise tensor on ``device``. + """ + if not isinstance(resize_flow, int) or resize_flow < 1: + raise ValueError( + f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}" + ) + if video_frames.ndim != 4 or video_frames.shape[-1] != 3: + raise ValueError( + "get_noise_from_video: video_frames must have shape (T, H, W, 3), " + f"got {tuple(video_frames.shape)}" + ) + if video_frames.dtype != torch.uint8: + raise TypeError( + "get_noise_from_video: video_frames must be uint8 in [0, 255], " + f"got dtype {video_frames.dtype}" + ) + + if device is None: + device = comfy.model_management.get_torch_device() + device = torch.device(device) if not isinstance(device, torch.device) else device + + if device.type == "cpu": + logging.warning( + "VOIDWarpedNoise: running get_noise_from_video on CPU; this will be " + "slow (minutes for ~45 frames). Use CUDA for interactive use." 
+ ) + + T = video_frames.shape[0] + frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0 + if resize_frames != 1.0: + new_h = max(1, int(frames.shape[2] * resize_frames)) + new_w = max(1, int(frames.shape[3] * resize_frames)) + frames = F.interpolate(frames, size=(new_h, new_w), mode="area") + + _, _, H, W = frames.shape + internal_h = resize_flow * H + internal_w = resize_flow * W + if internal_h % downscale_factor or internal_w % downscale_factor: + logging.warning( + "VOIDWarpedNoise: internal noise size %dx%d is not divisible by " + "downscale_factor %d; output noise may have artifacts.", + internal_h, internal_w, downscale_factor, + ) + + with torch.no_grad(): + warper = NoiseWarper( + c=noise_channels, h=internal_h, w=internal_w, device=device, + ) + down_h = warper.h // downscale_factor + down_w = warper.w // downscale_factor + output = torch.empty( + (T, down_h, down_w, noise_channels), dtype=torch.float32, device=device, + ) + + def downscale(noise_chw): + # Area-pool to 1/downscale_factor then multiply by downscale_factor + # to adjust std (sqrt of pool area == downscale_factor for a + # square pool). + down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False) + return down * downscale_factor + + output[0] = downscale(warper.noise).permute(1, 2, 0) + + prev = frames[0] + for i in range(1, T): + curr = frames[i] + flow = raft(prev, curr).to(device) + warper(flow[0], flow[1]) + output[i] = downscale(warper.noise).permute(1, 2, 0) + prev = curr + + return output diff --git a/folder_paths.py b/folder_paths.py index 039f72636..98d3b1880 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) +folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/models/optical_flow/put_optical_flow_models_here b/models/optical_flow/put_optical_flow_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index cf61d9df0..ad0cbc675 100644 --- a/nodes.py +++ b/nodes.py @@ -958,7 +958,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -968,7 +968,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: 
t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -2430,6 +2430,7 @@ async def init_builtin_extra_nodes(): "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", + "nodes_void.py", ] import_failed = []