Add noise to specific channels in a latent

This commit is contained in:
doctorpangloss 2024-09-23 08:51:48 -07:00
parent da4bc5dba3
commit 4bdc208f29
6 changed files with 80 additions and 47 deletions

View File

@ -6,7 +6,8 @@ from typing_extensions import NotRequired
ImageBatch = Float[Tensor, "batch height width channels"] ImageBatch = Float[Tensor, "batch height width channels"]
LatentBatch = Float[Tensor, "batch channels height width"] LatentBatch = Float[Tensor, "batch channels height width"]
SD1LatentBatch = Float[Tensor, "batch 8 height width"] SD15LatentBatch = Float[Tensor, "batch 4 height width"]
SDXLLatentBatch = Float[Tensor, "batch 8 height width"]
SD3LatentBatch = Float[Tensor, "batch 16 height width"] SD3LatentBatch = Float[Tensor, "batch 16 height width"]
MaskBatch = Float[Tensor, "batch height width"] MaskBatch = Float[Tensor, "batch height width"]
RGBImageBatch = Float[Tensor, "batch height width 3"] RGBImageBatch = Float[Tensor, "batch height width 3"]

View File

@ -480,13 +480,6 @@ class LoadLatent:
samples = {"samples": latent["latent_tensor"].float() * multiplier} samples = {"samples": latent["latent_tensor"].float() * multiplier}
return (samples, ) return (samples, )
@classmethod
def IS_CHANGED(s, latent):
image_path = folder_paths.get_annotated_filepath(latent)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, latent): def VALIDATE_INPUTS(s, latent):
@ -1625,14 +1618,6 @@ class LoadImage:
return (output_image, output_mask) return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, image): def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image): if not folder_paths.exists_annotated_filepath(image):
@ -1674,14 +1659,6 @@ class LoadImageMask:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (mask.unsqueeze(0),) return (mask.unsqueeze(0),)
@classmethod
def IS_CHANGED(s, image, channel):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, image): def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image): if not folder_paths.exists_annotated_filepath(image):

View File

@ -927,6 +927,7 @@ def seed_for_block(seed):
torch_rng_state = torch.get_rng_state() torch_rng_state = torch.get_rng_state()
random_state = random.getstate() random_state = random.getstate()
numpy_rng_state = np.random.get_state() numpy_rng_state = np.random.get_state()
# todo: investigate with torch.random.fork_rng(devices=(device,))
if torch.cuda.is_available(): if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state_all() cuda_rng_state = torch.cuda.get_rng_state_all()
else: else:

View File

@ -29,7 +29,7 @@ class FluxGuidance:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { return {"required": {
"conditioning": ("CONDITIONING", ), "conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}), "guidance": ("FLOAT", {"default": 3.5, "min": -100.0, "max": 100.0, "step": 0.1}),
}} }}
RETURN_TYPES = ("CONDITIONING",) RETURN_TYPES = ("CONDITIONING",)

View File

@ -1,6 +1,9 @@
import comfy.utils
import torch import torch
import comfy.utils
from comfy.component_model.tensor_types import Latent
def reshape_latent_to(target_shape, latent): def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]: if latent.shape[1:] != target_shape[1:]:
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center") latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
@ -10,7 +13,7 @@ def reshape_latent_to(target_shape, latent):
class LatentAdd: class LatentAdd:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "op" FUNCTION = "op"
@ -27,10 +30,11 @@ class LatentAdd:
samples_out["samples"] = s1 + s2 samples_out["samples"] = s1 + s2
return (samples_out,) return (samples_out,)
class LatentSubtract: class LatentSubtract:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "op" FUNCTION = "op"
@ -47,11 +51,12 @@ class LatentSubtract:
samples_out["samples"] = s1 - s2 samples_out["samples"] = s1 - s2
return (samples_out,) return (samples_out,)
class LatentMultiply: class LatentMultiply:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": {"samples": ("LATENT",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
@ -66,13 +71,14 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier samples_out["samples"] = s1 * multiplier
return (samples_out,) return (samples_out,)
class LatentInterpolate: class LatentInterpolate:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), return {"required": {"samples1": ("LATENT",),
"samples2": ("LATENT",), "samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "op" FUNCTION = "op"
@ -100,10 +106,11 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio)) samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,) return (samples_out,)
class LatentBatch: class LatentBatch:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}} return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "batch" FUNCTION = "batch"
@ -122,11 +129,12 @@ class LatentBatch:
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])]) samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,) return (samples_out,)
class LatentBatchSeedBehavior: class LatentBatchSeedBehavior:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",), return {"required": {"samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}} "seed_behavior": (["random", "fixed"], {"default": "fixed"}), }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "op" FUNCTION = "op"
@ -145,6 +153,44 @@ class LatentBatchSeedBehavior:
return (samples_out,) return (samples_out,)
class LatentAddNoiseChannels:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"samples": ("LATENT",),
"std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"slice_i": ("INT", {"default": 0, "min": -16, "max": 16}),
"slice_j": ("INT", {"default": 16, "min": -16, "max": 16}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "inject_noise"
CATEGORY = "latent/advanced"
def inject_noise(self, samples: Latent, std_dev, seed: int, slice_i: int, slice_j: int):
s = samples.copy()
latent = samples["samples"]
with comfy.utils.seed_for_block(seed):
if not isinstance(latent, torch.Tensor):
raise TypeError("Input must be a PyTorch tensor")
noise = torch.randn_like(latent[:, slice_i:slice_j, :, :]) * std_dev
noised_latent = latent.clone()
noised_latent[:, slice_i:slice_j, :, :] += noise
s["samples"] = noised_latent
return (s,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd, "LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract, "LatentSubtract": LatentSubtract,
@ -152,4 +198,5 @@ NODE_CLASS_MAPPINGS = {
"LatentInterpolate": LatentInterpolate, "LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch, "LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior, "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentAddNoiseChannels": LatentAddNoiseChannels,
} }

View File

@ -1,11 +1,13 @@
import torch
from torch import einsum
import torch.nn.functional as F
import math import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange, repeat
from comfy.ldm.modules.attention import optimized_attention from torch import einsum
from comfy import samplers from comfy import samplers
from comfy.ldm.modules.attention import optimized_attention
# from comfy/ldm/modules/attention.py # from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output # but modified to return attention scores as well as output
@ -50,6 +52,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
) )
return (out, sim) return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0): def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map # reshape and GAP the attention map
_, hw1, hw2 = attn.shape _, hw1, hw2 = attn.shape
@ -57,7 +60,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2) attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool # Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length() ratio = 2 ** (math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)] mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape # Reshape
@ -73,6 +76,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
blurred = blurred * mask + x0 * (1 - mask) blurred = blurred * mask + x0 * (1 - mask)
return blurred return blurred
def gaussian_blur_2d(img, kernel_size, sigma): def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5 ksize_half = (kernel_size - 1) * 0.5
@ -92,13 +96,15 @@ def gaussian_blur_2d(img, kernel_size, sigma):
img = F.conv2d(img, kernel2d, groups=img.shape[-3]) img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img return img
class SelfAttentionGuidance: class SelfAttentionGuidance:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": {"model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "patch" FUNCTION = "patch"
@ -123,7 +129,7 @@ class SelfAttentionGuidance:
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) (out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn] # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)] attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
return out return out
else: else:
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"]) return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
@ -142,7 +148,7 @@ class SelfAttentionGuidance:
sigma = args["sigma"] sigma = args["sigma"]
model_options = args["model_options"] model_options = args["model_options"]
x = args["input"] x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding if min(cfg_result.shape[2:]) <= 4: # skip when too small to add padding
return cfg_result return cfg_result
# create the adversarially blurred image # create the adversarially blurred image
@ -158,7 +164,8 @@ class SelfAttentionGuidance:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0) m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, ) return (m,)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance, "SelfAttentionGuidance": SelfAttentionGuidance,