diff --git a/comfy/model_base.py b/comfy/model_base.py
index c3c807a68..ad661ec7d 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -105,6 +105,29 @@ class BaseModel(torch.nn.Module):
         return {**unet_state_dict, **vae_state_dict, **clip_state_dict}
 
+def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0):
+    adm_inputs = []
+    weights = []
+    noise_aug = []
+    for unclip_cond in unclip_conditioning:
+        for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
+            weight = unclip_cond["strength"]
+            noise_augment = unclip_cond["noise_augmentation"]
+            noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+            c_adm, noise_level_emb = noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
+            adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
+            weights.append(weight)
+            noise_aug.append(noise_augment)
+            adm_inputs.append(adm_out)
+
+    if len(noise_aug) > 1:
+        adm_out = torch.stack(adm_inputs).sum(0)
+        noise_augment = noise_augment_merge
+        noise_level = round((noise_augmentor.max_noise_level - 1) * noise_augment)
+        c_adm, noise_level_emb = noise_augmentor(adm_out[:, :noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
+        adm_out = torch.cat((c_adm, noise_level_emb), 1)
+
+    return adm_out
 
 class SD21UNCLIP(BaseModel):
     def __init__(self, model_config, noise_aug_config, model_type=ModelType.V_PREDICTION, device=None):
@@ -114,33 +137,11 @@ class SD21UNCLIP(BaseModel):
     def encode_adm(self, **kwargs):
         unclip_conditioning = kwargs.get("unclip_conditioning", None)
         device = kwargs["device"]
-
-        if unclip_conditioning is not None:
-            adm_inputs = []
-            weights = []
-            noise_aug = []
-            for unclip_cond in unclip_conditioning:
-                for adm_cond in unclip_cond["clip_vision_output"].image_embeds:
-                    weight = unclip_cond["strength"]
-                    noise_augment = unclip_cond["noise_augmentation"]
-                    noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
-                    c_adm, noise_level_emb = self.noise_augmentor(adm_cond.to(device), noise_level=torch.tensor([noise_level], device=device))
-                    adm_out = torch.cat((c_adm, noise_level_emb), 1) * weight
-                    weights.append(weight)
-                    noise_aug.append(noise_augment)
-                    adm_inputs.append(adm_out)
-
-            if len(noise_aug) > 1:
-                adm_out = torch.stack(adm_inputs).sum(0)
-                #TODO: add a way to control this
-                noise_augment = 0.05
-                noise_level = round((self.noise_augmentor.max_noise_level - 1) * noise_augment)
-                c_adm, noise_level_emb = self.noise_augmentor(adm_out[:, :self.noise_augmentor.time_embed.dim], noise_level=torch.tensor([noise_level], device=device))
-                adm_out = torch.cat((c_adm, noise_level_emb), 1)
+        if unclip_conditioning is None:
+            return torch.zeros((1, self.adm_channels))
         else:
-            adm_out = torch.zeros((1, self.adm_channels))
-
-        return adm_out
+            return unclip_adm(unclip_conditioning, device, self.noise_augmentor, kwargs.get("unclip_noise_augment_merge", 0.05))
 
 class SDInpaint(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index 3be141dfe..a138b292e 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -59,8 +59,8 @@ class Blend:
     def g(self, x):
         return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
 
-def gaussian_kernel(kernel_size: int, sigma: float):
-    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size), indexing="ij")
+def gaussian_kernel(kernel_size: int, sigma: float, device=None):
+    x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size, device=device), torch.linspace(-1, 1, kernel_size, device=device), indexing="ij")
     d = torch.sqrt(x * x + y * y)
     g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
     return g / g.sum()
@@ -101,7 +101,7 @@ class Blur:
         batch_size, height, width, channels = image.shape
 
         kernel_size = blur_radius * 2 + 1
-        kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
+        kernel = gaussian_kernel(kernel_size, sigma, device=image.device).repeat(channels, 1, 1).unsqueeze(1)
 
         image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
         padded_image = F.pad(image, (blur_radius,blur_radius,blur_radius,blur_radius), 'reflect')
diff --git a/nodes.py b/nodes.py
index 5f7ea95c0..5b144c2fc 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1448,6 +1448,22 @@ class ImageInvert:
         s = 1.0 - image
         return (s,)
 
+class ImageBatch:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": { "image1": ("IMAGE",), "image2": ("IMAGE",)}}
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "batch"
+
+    CATEGORY = "image"
+
+    def batch(self, image1, image2):
+        if image1.shape[1:] != image2.shape[1:]:
+            image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
+        s = torch.cat((image1, image2), dim=0)
+        return (s,)
 
 class ImagePadForOutpaint:
 
@@ -1533,6 +1549,7 @@ NODE_CLASS_MAPPINGS = {
     "ImageScale": ImageScale,
     "ImageScaleBy": ImageScaleBy,
     "ImageInvert": ImageInvert,
+    "ImageBatch": ImageBatch,
     "ImagePadForOutpaint": ImagePadForOutpaint,
     "ConditioningAverage ": ConditioningAverage ,
    "ConditioningCombine": ConditioningCombine,
@@ -1627,6 +1644,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ImageUpscaleWithModel": "Upscale Image (using Model)",
     "ImageInvert": "Invert Image",
     "ImagePadForOutpaint": "Pad Image for Outpainting",
+    "ImageBatch": "Batch Images",
     # _for_testing
     "VAEDecodeTiled": "VAE Decode (Tiled)",
     "VAEEncodeTiled": "VAE Encode (Tiled)",