diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 33b268ac1..9d45ee3ac 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1828,6 +1828,20 @@ class CogVideoX_I2V(CogVideoX_T2V): out = model_base.CogVideoX(self, image_to_video=True, device=device) return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, CogVideoX_I2V, CogVideoX_T2V] +class CogVideoX_Inpaint(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 48, + } + + def get_model(self, state_dict, prefix="", device=None): + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, CogVideoX_Inpaint, CogVideoX_I2V, CogVideoX_T2V] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py new file mode 100644 index 000000000..79122d013 --- /dev/null +++ b/comfy_extras/nodes_void.py @@ -0,0 +1,172 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override + + +class VOIDQuadmaskPreprocess(io.ComfyNode): + """Preprocess a quadmask video for VOID inpainting. + + Quantizes mask values to four semantic levels, inverts, and normalizes: + 0 -> primary object to remove + 63 -> overlap of primary + affected + 127 -> affected region (interactions) + 255 -> background (keep) + + After inversion and normalization, the output mask has values in [0, 1] + with four discrete levels: 1.0 (remove), ~0.75, ~0.50, 0.0 (keep). + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDQuadmaskPreprocess", + category="mask/video", + inputs=[ + io.Mask.Input("mask"), + io.Int.Input("dilate_width", default=0, min=0, max=50, step=1, + tooltip="Dilation radius for the primary mask region (0 = no dilation)"), + ], + outputs=[ + io.Mask.Output(display_name="quadmask"), + ], + ) + + @classmethod + def execute(cls, mask, dilate_width=0) -> io.NodeOutput: + m = mask.clone() + + if m.max() <= 1.0: + m = m * 255.0 + + if dilate_width > 0 and m.ndim >= 3: + binary = (m < 128).float() + kernel_size = dilate_width * 2 + 1 + if binary.ndim == 3: + binary = binary.unsqueeze(1) + dilated = torch.nn.functional.max_pool2d( + binary, kernel_size=kernel_size, stride=1, padding=dilate_width + ) + if dilated.ndim == 4: + dilated = dilated.squeeze(1) + m = torch.where(dilated > 0.5, torch.zeros_like(m), m) + + m = torch.where(m <= 31, torch.zeros_like(m), m) + m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m) + m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m) + m = torch.where(m > 191, torch.full_like(m, 255), m) + + m = (255.0 - m) / 255.0 + + return io.NodeOutput(m) + + +class VOIDInpaintConditioning(io.ComfyNode): + """Build VOID inpainting conditioning for CogVideoX. + + Encodes the processed quadmask and masked source video through the VAE, + producing a 32-channel concat conditioning (16ch mask + 16ch masked video) + that gets concatenated with the 16ch noise latent by the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="VOIDInpaintConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("video", tooltip="Source video frames [T, H, W, 3]"), + io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"), + io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1, + tooltip="Number of pixel frames to process"), + io.Int.Input("batch_size", default=1, min=1, max=64), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, video, quadmask, + width, height, length, batch_size) -> io.NodeOutput: + + temporal_compression = 4 + latent_t = ((length - 1) // temporal_compression) + 1 + latent_h = height // 8 + latent_w = width // 8 + + vid = video[:length] + vid = comfy.utils.common_upscale( + vid.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + + qm = quadmask[:length] + if qm.ndim == 3: + qm = qm.unsqueeze(-1) + qm = comfy.utils.common_upscale( + qm.movedim(-1, 1), width, height, "bilinear", "center" + ).movedim(1, -1) + if qm.ndim == 4 and qm.shape[-1] == 1: + qm = qm.squeeze(-1) + + mask_condition = qm + if mask_condition.ndim == 3: + mask_condition_3ch = mask_condition.unsqueeze(-1).expand(-1, -1, -1, 3) + else: + mask_condition_3ch = mask_condition + + inverted_mask_3ch = 1.0 - mask_condition_3ch + masked_video = vid[:, :, :, :3] * (1.0 - mask_condition_3ch) + + mask_latents = vae.encode(inverted_mask_3ch) + masked_video_latents = vae.encode(masked_video) + + def _match_temporal(lat, target_t): + if lat.shape[2] > target_t: + return lat[:, :, :target_t] + elif lat.shape[2] < target_t: + pad = target_t - lat.shape[2] + return torch.cat([lat, lat[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2) + return lat + + mask_latents = _match_temporal(mask_latents, latent_t) + masked_video_latents = _match_temporal(masked_video_latents, latent_t) + + inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1) + + positive = node_helpers.conditioning_set_values( + positive, {"concat_latent_image": inpaint_latents} + ) + negative = node_helpers.conditioning_set_values( + negative, {"concat_latent_image": inpaint_latents} + ) + + noise_latent = torch.zeros( + [batch_size, 16, latent_t, latent_h, latent_w], + device=comfy.model_management.intermediate_device() + ) + + return io.NodeOutput(positive, negative, {"samples": noise_latent}) + + +class VOIDExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + VOIDQuadmaskPreprocess, + VOIDInpaintConditioning, + ] + + +async def comfy_entrypoint() -> VOIDExtension: + return VOIDExtension() diff --git a/nodes.py b/nodes.py index ba2fa0246..5a6258054 100644 --- a/nodes.py +++ b/nodes.py @@ -2458,6 +2458,7 @@ async def init_builtin_extra_nodes(): "nodes_painter.py", "nodes_curve.py", "nodes_rtdetr.py", + "nodes_void.py", ] import_failed = []