Add noise to specific channels in a latent

This commit is contained in:
doctorpangloss 2024-09-23 08:51:48 -07:00
parent da4bc5dba3
commit 4bdc208f29
6 changed files with 80 additions and 47 deletions

View File

@ -6,7 +6,8 @@ from typing_extensions import NotRequired
ImageBatch = Float[Tensor, "batch height width channels"]
LatentBatch = Float[Tensor, "batch channels height width"]
SD1LatentBatch = Float[Tensor, "batch 8 height width"]
SD15LatentBatch = Float[Tensor, "batch 4 height width"]
SDXLLatentBatch = Float[Tensor, "batch 8 height width"]
SD3LatentBatch = Float[Tensor, "batch 16 height width"]
MaskBatch = Float[Tensor, "batch height width"]
RGBImageBatch = Float[Tensor, "batch height width 3"]

View File

@ -480,13 +480,6 @@ class LoadLatent:
samples = {"samples": latent["latent_tensor"].float() * multiplier}
return (samples, )
@classmethod
def IS_CHANGED(s, latent):
image_path = folder_paths.get_annotated_filepath(latent)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, latent):
@ -1625,14 +1618,6 @@ class LoadImage:
return (output_image, output_mask)
@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
@ -1674,14 +1659,6 @@ class LoadImageMask:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
return (mask.unsqueeze(0),)
@classmethod
def IS_CHANGED(s, image, channel):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):

View File

@ -927,6 +927,7 @@ def seed_for_block(seed):
torch_rng_state = torch.get_rng_state()
random_state = random.getstate()
numpy_rng_state = np.random.get_state()
# todo: investigate with torch.random.fork_rng(devices=(device,))
if torch.cuda.is_available():
cuda_rng_state = torch.cuda.get_rng_state_all()
else:

View File

@ -29,7 +29,7 @@ class FluxGuidance:
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
"guidance": ("FLOAT", {"default": 3.5, "min": -100.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)

View File

@ -1,6 +1,9 @@
import comfy.utils
import torch
import comfy.utils
from comfy.component_model.tensor_types import Latent
def reshape_latent_to(target_shape, latent):
if latent.shape[1:] != target_shape[1:]:
latent = comfy.utils.common_upscale(latent, target_shape[3], target_shape[2], "bilinear", "center")
@ -10,7 +13,7 @@ def reshape_latent_to(target_shape, latent):
class LatentAdd:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
@ -27,10 +30,11 @@ class LatentAdd:
samples_out["samples"] = s1 + s2
return (samples_out,)
class LatentSubtract:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
@ -47,11 +51,12 @@ class LatentSubtract:
samples_out["samples"] = s1 - s2
return (samples_out,)
class LatentMultiply:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
return {"required": {"samples": ("LATENT",),
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
@ -66,13 +71,14 @@ class LatentMultiply:
samples_out["samples"] = s1 * multiplier
return (samples_out,)
class LatentInterpolate:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
return {"required": {"samples1": ("LATENT",),
"samples2": ("LATENT",),
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
@ -100,10 +106,11 @@ class LatentInterpolate:
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
return (samples_out,)
class LatentBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "batch"
@ -122,11 +129,12 @@ class LatentBatch:
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
return (samples_out,)
class LatentBatchSeedBehavior:
@classmethod
def INPUT_TYPES(s):
return {"required": { "samples": ("LATENT",),
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
return {"required": {"samples": ("LATENT",),
"seed_behavior": (["random", "fixed"], {"default": "fixed"}), }}
RETURN_TYPES = ("LATENT",)
FUNCTION = "op"
@ -145,6 +153,44 @@ class LatentBatchSeedBehavior:
return (samples_out,)
class LatentAddNoiseChannels:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"samples": ("LATENT",),
"std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"slice_i": ("INT", {"default": 0, "min": -16, "max": 16}),
"slice_j": ("INT", {"default": 16, "min": -16, "max": 16}),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "inject_noise"
CATEGORY = "latent/advanced"
def inject_noise(self, samples: Latent, std_dev, seed: int, slice_i: int, slice_j: int):
s = samples.copy()
latent = samples["samples"]
with comfy.utils.seed_for_block(seed):
if not isinstance(latent, torch.Tensor):
raise TypeError("Input must be a PyTorch tensor")
noise = torch.randn_like(latent[:, slice_i:slice_j, :, :]) * std_dev
noised_latent = latent.clone()
noised_latent[:, slice_i:slice_j, :, :] += noise
s["samples"] = noised_latent
return (s,)
NODE_CLASS_MAPPINGS = {
"LatentAdd": LatentAdd,
"LatentSubtract": LatentSubtract,
@ -152,4 +198,5 @@ NODE_CLASS_MAPPINGS = {
"LatentInterpolate": LatentInterpolate,
"LatentBatch": LatentBatch,
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
"LatentAddNoiseChannels": LatentAddNoiseChannels,
}

View File

@ -1,11 +1,13 @@
import torch
from torch import einsum
import torch.nn.functional as F
import math
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.modules.attention import optimized_attention
from torch import einsum
from comfy import samplers
from comfy.ldm.modules.attention import optimized_attention
# from comfy/ldm/modules/attention.py
# but modified to return attention scores as well as output
@ -50,6 +52,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
)
return (out, sim)
def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
# reshape and GAP the attention map
_, hw1, hw2 = attn.shape
@ -57,7 +60,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
attn = attn.reshape(b, -1, hw1, hw2)
# Global Average Pool
mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
ratio = 2 ** (math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
mid_shape = [math.ceil(lh / ratio), math.ceil(lw / ratio)]
# Reshape
@ -73,6 +76,7 @@ def create_blur_map(x0, attn, sigma=3.0, threshold=1.0):
blurred = blurred * mask + x0 * (1 - mask)
return blurred
def gaussian_blur_2d(img, kernel_size, sigma):
ksize_half = (kernel_size - 1) * 0.5
@ -92,13 +96,15 @@ def gaussian_blur_2d(img, kernel_size, sigma):
img = F.conv2d(img, kernel2d, groups=img.shape[-3])
return img
class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
return {"required": {"model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
@ -123,7 +129,7 @@ class SelfAttentionGuidance:
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
n_slices = heads * b
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index + 1)]
return out
else:
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
@ -142,7 +148,7 @@ class SelfAttentionGuidance:
sigma = args["sigma"]
model_options = args["model_options"]
x = args["input"]
if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
if min(cfg_result.shape[2:]) <= 4: # skip when too small to add padding
return cfg_result
# create the adversarially blurred image
@ -158,7 +164,8 @@ class SelfAttentionGuidance:
# unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
return (m, )
return (m,)
NODE_CLASS_MAPPINGS = {
"SelfAttentionGuidance": SelfAttentionGuidance,