Inpainting for z image fun control. Use the ZImageFunControlnet node. (#11346)

image -> control image (ex: pose)
inpaint_image -> image for inpainting
mask -> inpaint mask

parent 3d082c3206
commit 645ee1881e
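The commit message's input mapping corresponds directly to the node function in the diff below. A minimal sketch of invoking it outside a graph (variable names such as pose_image, source_image, and inpaint_mask are hypothetical; in practice model, model_patch, and vae come from the usual loader nodes and the node is wired in a ComfyUI workflow):

    # Hypothetical direct call of the node's function; all three image-like
    # inputs are optional per INPUT_TYPES below.
    node = ZImageFunControlnet()
    (patched_model,) = node.diffsynth_controlnet(
        model, model_patch, vae,
        image=pose_image,            # control image, ex: pose
        strength=1.0,
        inpaint_image=source_image,  # image for inpainting
        mask=inpaint_mask,           # inpaint mask
    )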
@@ -313,22 +313,46 @@ class ZImageControlPatch:
         self.inpaint_image = inpaint_image
         self.mask = mask
         self.strength = strength
-        self.encoded_image = self.encode_latent_cond(image)
-        self.encoded_image_size = (image.shape[1], image.shape[2])
+        self.is_inpaint = self.model_patch.model.additional_in_dim > 0
+        skip_encoding = False
+        if self.image is not None and self.inpaint_image is not None:
+            if self.image.shape != self.inpaint_image.shape:
+                skip_encoding = True
+
+        if skip_encoding:
+            self.encoded_image = None
+        else:
+            self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
+        if self.image is None:
+            self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
+        else:
+            self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
         self.temp_data = None
 
-    def encode_latent_cond(self, control_image, inpaint_image=None):
-        latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
-        if self.model_patch.model.additional_in_dim > 0:
-            if self.mask is None:
-                mask_ = torch.zeros_like(latent_image)[:, :1]
-            else:
-                mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
+    def encode_latent_cond(self, control_image=None, inpaint_image=None):
+        latent_image = None
+        if control_image is not None:
+            latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
+        if self.is_inpaint:
             if inpaint_image is None:
                 inpaint_image = torch.ones_like(control_image) * 0.5
+
+            if self.mask is not None:
+                mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
+                inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
+
             inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
+
+            if self.mask is None:
+                mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
+            else:
+                mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
+
+            if latent_image is None:
+                latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
+
             return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
         else:
             return latent_image
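Note: when the patch model carries inpaint channels (self.is_inpaint), encode_latent_cond returns the control latent, a single-channel mask, and the inpaint latent concatenated on the channel dimension. A standalone shape sketch, using hypothetical sizes and plain torch tensors as stand-ins for the VAE outputs:

    import torch

    b, c, h, w = 1, 16, 64, 64                      # assumed latent shape
    latent_image = torch.zeros(b, c, h, w)          # stand-in for the encoded control image
    mask_ = torch.zeros(b, 1, h, w)                 # zeros_like(...)[:, :1] default when mask is None
    inpaint_image_latent = torch.zeros(b, c, h, w)  # stand-in for the encoded inpaint image
    cond = torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
    print(cond.shape)                               # (1, 2*c + 1, 64, 64) -> here (1, 33, 64, 64)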
@@ -344,13 +368,18 @@ class ZImageControlPatch:
         block_type = kwargs.get("block_type", "")
         spacial_compression = self.vae.spacial_compression_encode()
         if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
-            image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+            image_scaled = None
+            if self.image is not None:
+                image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
+                self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
+
             inpaint_scaled = None
             if self.inpaint_image is not None:
                 inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
+                self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
+
             loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
-            self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
-            self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+            self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
             comfy.model_management.load_models_gpu(loaded_models)
 
         cnet_blocks = self.model_patch.model.n_control_layers
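Note: the re-encode path above rescales whichever conditioning images exist to the pixel size implied by the latent x, i.e. latent size times the VAE's spatial compression factor. A simplified sketch of the same resize, assuming BHWC images and using plain F.interpolate in place of comfy.utils.common_upscale:

    import torch
    import torch.nn.functional as F

    def rescale_to_latent(image_bhwc, x, compression=8):
        # target pixel size = latent spatial size * VAE compression factor
        target = (x.shape[-2] * compression, x.shape[-1] * compression)
        img = F.interpolate(image_bhwc.movedim(-1, 1), size=target, mode="area")
        return img.movedim(1, -1)  # back to BHWC, like common_upscale(...).movedim(1, -1)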
@@ -391,7 +420,8 @@ class ZImageControlPatch:
 
     def to(self, device_or_dtype):
         if isinstance(device_or_dtype, torch.device):
-            self.encoded_image = self.encoded_image.to(device_or_dtype)
+            if self.encoded_image is not None:
+                self.encoded_image = self.encoded_image.to(device_or_dtype)
             self.temp_data = None
         return self
 
@@ -414,9 +444,12 @@ class QwenImageDiffsynthControlnet:
 
     CATEGORY = "advanced/loaders/qwen"
 
-    def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
+    def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
         model_patched = model.clone()
-        image = image[:, :, :, :3]
+        if image is not None:
+            image = image[:, :, :, :3]
+        if inpaint_image is not None:
+            inpaint_image = inpaint_image[:, :, :, :3]
         if mask is not None:
             if mask.ndim == 3:
                 mask = mask.unsqueeze(1)
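Note: IMAGE inputs are BHWC floats, so the :3 slice drops any alpha channel, and MASK inputs arrive as (B, H, W) and get lifted to (B, 1, H, W) before the inversion in the next hunk. A tiny shape sketch with toy tensors:

    import torch

    rgba = torch.rand(1, 512, 512, 4)
    rgb = rgba[:, :, :, :3]        # (1, 512, 512, 3): alpha stripped
    mask = torch.zeros(1, 512, 512)
    if mask.ndim == 3:
        mask = mask.unsqueeze(1)   # (1, 1, 512, 512), as in the hunk above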
@@ -425,13 +458,24 @@ class QwenImageDiffsynthControlnet:
             mask = 1.0 - mask
 
         if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
-            patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
+            patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
             model_patched.set_model_noise_refiner_patch(patch)
             model_patched.set_model_double_block_patch(patch)
         else:
             model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
         return (model_patched,)
 
 
+class ZImageFunControlnet(QwenImageDiffsynthControlnet):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "model": ("MODEL",),
+                              "model_patch": ("MODEL_PATCH",),
+                              "vae": ("VAE",),
+                              "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                              },
+                "optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
+
+    CATEGORY = "advanced/loaders/zimage"
+
+
 class UsoStyleProjectorPatch:
     def __init__(self, model_patch, encoded_image):
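Note: because image is optional, ZImageFunControlnet also covers pure inpainting, assuming the loaded model_patch has inpaint channels: wire only inpaint_image and mask, and encode_latent_cond falls back to a flat 0.5 gray control latent (see the first hunk). A hypothetical call for that case (source_image and inpaint_mask are placeholder names):

    node = ZImageFunControlnet()
    (patched_model,) = node.diffsynth_controlnet(
        model, model_patch, vae,
        strength=1.0,
        inpaint_image=source_image,  # only the image to repaint
        mask=inpaint_mask,           # region to repaint
    )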
@@ -479,5 +523,6 @@ class USOStyleReference:
 NODE_CLASS_MAPPINGS = {
     "ModelPatchLoader": ModelPatchLoader,
     "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
+    "ZImageFunControlnet": ZImageFunControlnet,
     "USOStyleReference": USOStyleReference,
 }