import torch

import nodes
import comfy.utils
from comfy.component_model.tensor_types import Latent
from comfy.nodes.package_typing import Seed64

from .nodes_post_processing import gaussian_kernel


def reshape_latent_to(target_shape, latent, repeat_batch=True):
    # Match the target's spatial dimensions first, then optionally repeat along
    # the batch dimension so elementwise ops between the two latents broadcast.
    if latent.shape[1:] != target_shape[1:]:
        latent = comfy.utils.common_upscale(latent, target_shape[-1], target_shape[-2], "bilinear", "center")
    if repeat_batch:
        return comfy.utils.repeat_to_batch_size(latent, target_shape[0])
    return latent

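# Illustration only (a sketch, not used by the nodes below; the 4-channel,
# SD-style latent shape is an assumption): a [1, 4, 32, 32] latent matched
# against a [2, 4, 64, 64] target is bilinearly upscaled to 64x64, then
# repeated to batch size 2.
#
#   target = torch.zeros(2, 4, 64, 64)
#   latent = torch.randn(1, 4, 32, 32)
#   assert reshape_latent_to(target.shape, latent).shape == target.shape
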
class LatentAdd:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples1, samples2):
        samples_out = samples1.copy()

        s1 = samples1["samples"]
        s2 = samples2["samples"]

        s2 = reshape_latent_to(s1.shape, s2)
        samples_out["samples"] = s1 + s2
        return (samples_out,)


class LatentSubtract:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples1, samples2):
        samples_out = samples1.copy()

        s1 = samples1["samples"]
        s2 = samples2["samples"]

        s2 = reshape_latent_to(s1.shape, s2)
        samples_out["samples"] = s1 - s2
        return (samples_out,)


class LatentMultiply:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples, multiplier):
        samples_out = samples.copy()

        s1 = samples["samples"]
        samples_out["samples"] = s1 * multiplier
        return (samples_out,)


class LatentInterpolate:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples1": ("LATENT",),
                             "samples2": ("LATENT",),
                             "ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples1, samples2, ratio):
        samples_out = samples1.copy()

        s1 = samples1["samples"]
        s2 = samples2["samples"]

        s2 = reshape_latent_to(s1.shape, s2)

        # Interpolate direction and magnitude separately so the blended latent
        # keeps a sensible norm. keepdim=True retains the reduced channel
        # dimension so the divisions broadcast for any batch size.
        m1 = torch.linalg.vector_norm(s1, dim=1, keepdim=True)
        m2 = torch.linalg.vector_norm(s2, dim=1, keepdim=True)

        s1 = torch.nan_to_num(s1 / m1)
        s2 = torch.nan_to_num(s2 / m2)

        t = s1 * ratio + s2 * (1.0 - ratio)
        mt = torch.linalg.vector_norm(t, dim=1, keepdim=True)
        st = torch.nan_to_num(t / mt)

        samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
        return (samples_out,)


class LatentConcat:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples1": ("LATENT",),
                             "samples2": ("LATENT",),
                             "dim": (["x", "-x", "y", "-y", "t", "-t"],)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples1, samples2, dim):
        samples_out = samples1.copy()

        s1 = samples1["samples"]
        s2 = samples2["samples"]
        s2 = comfy.utils.repeat_to_batch_size(s2, s1.shape[0])

        # A leading "-" puts samples2 before samples1 along the chosen axis.
        if "-" in dim:
            c = (s2, s1)
        else:
            c = (s1, s2)

        if "x" in dim:
            dim = -1
        elif "y" in dim:
            dim = -2
        elif "t" in dim:
            dim = -3

        samples_out["samples"] = torch.cat(c, dim=dim)
        return (samples_out,)


class LatentCut:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "dim": (["x", "y", "t"],),
                             "index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
                             "amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples, dim, index, amount):
        samples_out = samples.copy()

        s1 = samples["samples"]
        if "x" in dim:
            dim = s1.ndim - 1
        elif "y" in dim:
            dim = s1.ndim - 2
        elif "t" in dim:
            dim = s1.ndim - 3

        # Clamp index and amount so the slice always stays inside the tensor.
        if index >= 0:
            index = min(index, s1.shape[dim] - 1)
            amount = min(s1.shape[dim] - index, amount)
        else:
            index = max(index, -s1.shape[dim])
            amount = min(-index, amount)

        samples_out["samples"] = torch.narrow(s1, dim, index, amount)
        return (samples_out,)


class LatentBatch:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples1": ("LATENT",), "samples2": ("LATENT",)}}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch"

    CATEGORY = "latent/batch"

    def batch(self, samples1, samples2):
        samples_out = samples1.copy()

        s1 = samples1["samples"]
        s2 = samples2["samples"]

        s2 = reshape_latent_to(s1.shape, s2, repeat_batch=False)
        s = torch.cat((s1, s2), dim=0)
        samples_out["samples"] = s
        samples_out["batch_index"] = samples1.get("batch_index", list(range(s1.shape[0]))) + samples2.get("batch_index", list(range(s2.shape[0])))
        return (samples_out,)


class LatentBatchSeedBehavior:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "seed_behavior": (["random", "fixed"], {"default": "fixed"}),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced"

    def op(self, samples, seed_behavior):
        samples_out = samples.copy()
        latent = samples["samples"]

        if seed_behavior == "random":
            if "batch_index" in samples_out:
                samples_out.pop("batch_index")
        elif seed_behavior == "fixed":
            # Every latent in the batch reuses the first batch index, so each
            # one is denoised from the same noise.
            batch_number = samples_out.get("batch_index", [0])[0]
            samples_out["batch_index"] = [batch_number] * latent.shape[0]

        return (samples_out,)


class LatentAddNoiseChannels:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "samples": ("LATENT",),
                "std_dev": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
                "seed": Seed64,
                "slice_i": ("INT", {"default": 0, "min": -16, "max": 16}),
                "slice_j": ("INT", {"default": 16, "min": -16, "max": 16}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "inject_noise"

    CATEGORY = "latent/advanced"

    def inject_noise(self, samples: Latent, std_dev: float, seed: int, slice_i: int, slice_j: int):
        s = samples.copy()
        latent = samples["samples"]
        if not isinstance(latent, torch.Tensor):
            raise TypeError("Input must be a PyTorch tensor")

        # Add gaussian noise to the channel range [slice_i, slice_j) only.
        with comfy.utils.seed_for_block(seed):
            noise = torch.randn_like(latent[:, slice_i:slice_j, :, :]) * std_dev

        noised_latent = latent.clone()
        noised_latent[:, slice_i:slice_j, :, :] += noise
        s["samples"] = noised_latent
        return (s,)


class LatentApplyOperation:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"samples": ("LATENT",),
                             "operation": ("LATENT_OPERATION",),
                             }}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def op(self, samples, operation):
        samples_out = samples.copy()

        s1 = samples["samples"]
        samples_out["samples"] = operation(latent=s1)
        return (samples_out,)


class LatentApplyOperationCFG:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"model": ("MODEL",),
                             "operation": ("LATENT_OPERATION",),
                             }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def patch(self, model, operation):
        m = model.clone()

        def pre_cfg_function(args):
            conds_out = args["conds_out"]
            if len(conds_out) == 2:
                # Apply the operation to the guidance direction (cond - uncond)
                # rather than to the conditioned prediction itself.
                conds_out[0] = operation(latent=(conds_out[0] - conds_out[1])) + conds_out[1]
            else:
                conds_out[0] = operation(latent=conds_out[0])
            return conds_out

        m.set_model_sampler_pre_cfg_function(pre_cfg_function)
        return (m,)

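# For reference, a LATENT_OPERATION as consumed above is just a callable that
# takes the latent tensor as a keyword argument and returns a tensor of the
# same shape. A minimal custom operation node could be sketched like this
# (hypothetical example; LatentOperationScale is not registered below):
#
#   class LatentOperationScale:
#       RETURN_TYPES = ("LATENT_OPERATION",)
#       FUNCTION = "op"
#
#       def op(self, multiplier):
#           def scale(latent, **kwargs):
#               return latent * multiplier
#           return (scale,)
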
class LatentOperationTonemapReinhard:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
                             }}

    RETURN_TYPES = ("LATENT_OPERATION",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def op(self, multiplier):
        def tonemap_reinhard(latent, **kwargs):
            # Tonemap the per-pixel magnitude of the latent vectors, leaving
            # their direction untouched.
            latent_vector_magnitude = torch.linalg.vector_norm(latent, dim=1, keepdim=True) + 1e-10
            normalized_latent = latent / latent_vector_magnitude

            mean = torch.mean(latent_vector_magnitude, dim=(1, 2, 3), keepdim=True)
            std = torch.std(latent_vector_magnitude, dim=(1, 2, 3), keepdim=True)

            top = (std * 5 + mean) * multiplier

            # Reinhard curve x / (x + 1), applied to magnitudes rescaled by `top`.
            latent_vector_magnitude *= (1.0 / top)
            new_magnitude = latent_vector_magnitude / (latent_vector_magnitude + 1.0)
            new_magnitude *= top

            return normalized_latent * new_magnitude
        return (tonemap_reinhard,)


class LatentOperationSharpen:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "sharpen_radius": ("INT", {
                "default": 9,
                "min": 1,
                "max": 31,
                "step": 1
            }),
            "sigma": ("FLOAT", {
                "default": 1.0,
                "min": 0.1,
                "max": 10.0,
                "step": 0.1
            }),
            "alpha": ("FLOAT", {
                "default": 0.1,
                "min": 0.0,
                "max": 5.0,
                "step": 0.01
            }),
        }}

    RETURN_TYPES = ("LATENT_OPERATION",)
    FUNCTION = "op"

    CATEGORY = "latent/advanced/operations"
    EXPERIMENTAL = True

    def op(self, sharpen_radius, sigma, alpha):
        def sharpen(latent, **kwargs):
            luminance = torch.linalg.vector_norm(latent, dim=1, keepdim=True) + 1e-6
            normalized_latent = latent / luminance

            channels = latent.shape[1]
            kernel_size = sharpen_radius * 2 + 1
            kernel = gaussian_kernel(kernel_size, sigma, device=luminance.device)
            center = kernel_size // 2

            # Build an unsharp-mask kernel: a negative gaussian surround with a
            # boosted center so the weights sum to 1.
            kernel *= alpha * -10
            kernel[center, center] = kernel[center, center] - kernel.sum() + 1.0

            padded_image = torch.nn.functional.pad(normalized_latent, (sharpen_radius, sharpen_radius, sharpen_radius, sharpen_radius), 'reflect')
            sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:, :, sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]

            return luminance * sharpened
        return (sharpen,)


NODE_CLASS_MAPPINGS = {
    "LatentAdd": LatentAdd,
    "LatentSubtract": LatentSubtract,
    "LatentMultiply": LatentMultiply,
    "LatentInterpolate": LatentInterpolate,
    "LatentConcat": LatentConcat,
    "LatentCut": LatentCut,
    "LatentBatch": LatentBatch,
    "LatentBatchSeedBehavior": LatentBatchSeedBehavior,
    "LatentAddNoiseChannels": LatentAddNoiseChannels,
    "LatentApplyOperation": LatentApplyOperation,
    "LatentApplyOperationCFG": LatentApplyOperationCFG,
    "LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
    "LatentOperationSharpen": LatentOperationSharpen,
}
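
if __name__ == "__main__":
    # Minimal smoke test for the pure-tensor nodes; a sketch that assumes an
    # SD-style [batch, 4, H, W] latent and a working ComfyUI environment for
    # the imports at the top of this module.
    a = {"samples": torch.randn(1, 4, 64, 64)}
    b = {"samples": torch.randn(1, 4, 64, 64)}
    print(LatentAdd().op(a, b)[0]["samples"].shape)                     # [1, 4, 64, 64]
    print(LatentInterpolate().op(a, b, ratio=0.5)[0]["samples"].shape)  # [1, 4, 64, 64]
    print(LatentConcat().op(a, b, dim="x")[0]["samples"].shape)         # [1, 4, 64, 128]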