diff --git a/README.md b/README.md index 1eeb810de..f05311421 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,9 @@ The portable above currently comes with python 3.13 and pytorch cuda 13.0. Updat #### Alternative Downloads: -[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) +[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) + +[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z) [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py index b224306da..1477afa01 100644 --- a/api_server/routes/internal/internal_routes.py +++ b/api_server/routes/internal/internal_routes.py @@ -67,7 +67,7 @@ class InternalRoutes: (entry for entry in os.scandir(directory) if is_visible_file(entry)), key=lambda entry: -entry.stat().st_mtime ) - return web.json_response([entry.name for entry in sorted_files], status=200) + return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200) def get_app(self): diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py index f7cdb51e6..eba661aec 100644 --- a/comfy/ldm/ernie/model.py +++ b/comfy/ldm/ernie/model.py @@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module): query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - query, key = query.to(x.dtype), key.to(x.dtype) - q_flat = query.reshape(B, S, -1) k_flat = key.reshape(B, S, -1) @@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module): residual = x x_norm = self.adaLN_sa_ln(x) - x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) + x_norm = x_norm * (1 + scale_msa) + shift_msa attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) - x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) + x = residual + gate_msa * attn_out residual = x x_norm = self.adaLN_mlp_ln(x) - x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) + x_norm = x_norm * (1 + scale_mlp) + shift_mlp - return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype) + return residual + gate_mlp * self.mlp(x_norm) class ErnieImageAdaLNContinuous(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): @@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module): def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: scale, shift = self.linear(conditioning).chunk(2, dim=-1) x = self.norm(x) - x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)) return x class ErnieImageModel(nn.Module): diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index fa0a00748..dd5320c8f 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -4,9 +4,6 @@ import math import torch import torchaudio -import comfy.model_management -import comfy.model_patcher -import comfy.utils as utils from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( @@ -43,30 +40,6 @@ class AudioVAEComponentConfig: return cls(autoencoder=audio_config, vocoder=vocoder_config) - -class ModelDeviceManager: - """Manages device placement and GPU residency for the composed model.""" - - def __init__(self, module: torch.nn.Module): - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) - - def ensure_model_loaded(self) -> None: - comfy.model_management.free_memory( - self.patcher.model_size(), - self.patcher.load_device, - ) - comfy.model_management.load_model_gpu(self.patcher) - - def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(self.patcher.load_device) - - @property - def load_device(self): - return self.patcher.load_device - - class AudioLatentNormalizer: """Applies per-channel statistics in patch space and restores original layout.""" @@ -132,23 +105,17 @@ class AudioPreprocessor: class AudioVAE(torch.nn.Module): """High-level Audio VAE wrapper exposing encode and decode entry points.""" - def __init__(self, state_dict: dict, metadata: dict): + def __init__(self, metadata: dict): super().__init__() component_config = AudioVAEComponentConfig.from_metadata(metadata) - vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) - vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) - self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) if "bwe" in component_config.vocoder: self.vocoder = VocoderWithBWE(config=component_config.vocoder) else: self.vocoder = Vocoder(config=component_config.vocoder) - self.autoencoder.load_state_dict(vae_sd, strict=False) - self.vocoder.load_state_dict(vocoder_sd, strict=False) - autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( AudioPatchifier( @@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module): n_fft=autoencoder_config["n_fft"], ) - self.device_manager = ModelDeviceManager(self) - - def encode(self, audio: dict) -> torch.Tensor: + def encode(self, audio, sample_rate=44100) -> torch.Tensor: """Encode a waveform dictionary into normalized latent tensors.""" - waveform = audio["waveform"] - waveform_sample_rate = audio["sample_rate"] + waveform = audio + waveform_sample_rate = sample_rate input_device = waveform.device - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: if waveform.shape[1] == 1: @@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module): ) mel_spec = self.preprocessor.waveform_to_mel( - waveform, waveform_sample_rate, device=self.device_manager.load_device + waveform, waveform_sample_rate, device=waveform.device ) latents = self.autoencoder.encode(mel_spec) @@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module): """Decode normalized latent tensors into an audio waveform.""" original_shape = latents.shape - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - latents = self.device_manager.move_to_load_device(latents) latents = self.normalizer.denormalize(latents) target_shape = self.target_shape_from_latents(original_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) waveform = self.run_vocoder(mel_spec) - return self.device_manager.move_to_load_device(waveform) + return waveform def target_shape_from_latents(self, latents_shape): batch, _, time, _ = latents_shape diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 295310df6..4b92c44cf 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -34,6 +34,16 @@ class TimestepBlock(nn.Module): #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): for layer in ts: + if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: + found_patched = False + for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: + if isinstance(layer, class_type): + x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) + found_patched = True + break + if found_patched: + continue + if isinstance(layer, VideoResBlock): x = layer(x, emb, num_video_frames, image_only_indicator) elif isinstance(layer, TimestepBlock): @@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out elif isinstance(layer, Upsample): x = layer(x, output_shape=output_shape) else: - if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]: - found_patched = False - for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]: - if isinstance(layer, class_type): - x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator) - found_patched = True - break - if found_patched: - continue x = layer(x) return x @@ -894,6 +895,12 @@ class UNetModel(nn.Module): h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = apply_control(h, control, 'middle') + if "middle_block_after_patch" in transformer_patches: + patch = transformer_patches["middle_block_after_patch"] + for p in patch: + out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y, + "timesteps": timesteps, "transformer_options": transformer_options}) + h = out["h"] for id, module in enumerate(self.output_blocks): transformer_options["block"] = ("output", id) @@ -905,8 +912,9 @@ class UNetModel(nn.Module): for p in patch: h, hsp = p(h, hsp, transformer_options) - h = th.cat([h, hsp], dim=1) - del hsp + if hsp is not None: + h = th.cat([h, hsp], dim=1) + del hsp if len(hs) > 0: output_shape = hs[-1].shape else: diff --git a/comfy/ldm/supir/__init__.py b/comfy/ldm/supir/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/supir/supir_modules.py b/comfy/ldm/supir/supir_modules.py new file mode 100644 index 000000000..7389b01d2 --- /dev/null +++ b/comfy/ldm/supir/supir_modules.py @@ -0,0 +1,226 @@ +import torch +import torch.nn as nn + +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding +from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer +from comfy.ldm.modules.attention import optimized_attention + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None): + super().__init__() + + ks = 3 + pw = ks // 2 + + self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device) + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device), + nn.SiLU() + ) + self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device) + + self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device) + self.pre_concat = bool(concat_channels != 0) + + def forward(self, c, h, h_ori=None, control_scale=1): + if h_ori is not None and self.pre_concat: + h_raw = torch.cat([h_ori, h], dim=1) + else: + h_raw = h + + h = h + self.zero_conv(c) + if h_ori is not None and self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + h = self.param_free_norm(h) + h = torch.addcmul(h + beta, h, gamma) + if h_ori is not None and not self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + return torch.lerp(h_raw, h, control_scale) + + +class _CrossAttnInner(nn.Module): + """Inner cross-attention module matching the state_dict layout of the original CrossAttention.""" + def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None): + super().__init__() + inner_dim = dim_head * heads + self.heads = heads + self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), + ) + + def forward(self, x, context): + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + return self.to_out(optimized_attention(q, k, v, self.heads)) + + +class ZeroCrossAttn(nn.Module): + def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None): + super().__init__() + heads = query_dim // 64 + dim_head = 64 + self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations) + self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device) + self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device) + + def forward(self, context, x, control_scale=1): + b, c, h, w = x.shape + x_in = x + + x = self.attn( + self.norm1(x).flatten(2).transpose(1, 2), + self.norm2(context).flatten(2).transpose(1, 2), + ).transpose(1, 2).unflatten(2, (h, w)) + + return x_in + x * control_scale + + +class GLVControl(nn.Module): + """SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only).""" + def __init__( + self, + in_channels=4, + model_channels=320, + num_res_blocks=2, + attention_resolutions=(4, 2), + channel_mult=(1, 2, 4), + num_head_channels=64, + transformer_depth=(1, 2, 10), + context_dim=2048, + adm_in_channels=2816, + use_linear_in_transformer=True, + use_checkpoint=False, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__() + self.model_channels = model_channels + time_embed_dim = model_channels * 4 + + self.time_embed = nn.Sequential( + operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + + self.label_emb = nn.Sequential( + nn.Sequential( + operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device), + ) + ) + + self.input_blocks = nn.ModuleList([ + TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + ]) + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(num_res_blocks): + layers = [ + ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels, + dtype=dtype, device=device, operations=operations) + ] + ch = mult * model_channels + if ds in attention_resolutions: + num_heads = ch // num_head_channels + layers.append( + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[level], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + if level != len(channel_mult) - 1: + self.input_blocks.append( + TimestepEmbedSequential( + Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations) + ) + ) + ds *= 2 + + num_heads = ch // num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + SpatialTransformer(ch, num_heads, num_head_channels, + depth=transformer_depth[-1], context_dim=context_dim, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + dtype=dtype, device=device, operations=operations), + ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations), + ) + + self.input_hint_block = TimestepEmbedSequential( + operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device) + ) + + def forward(self, x, timesteps, xt, context=None, y=None, **kwargs): + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype) + emb = self.time_embed(t_emb) + self.label_emb(y) + + guided_hint = self.input_hint_block(x, emb, context) + + hs = [] + h = xt + for module in self.input_blocks: + if guided_hint is not None: + h = module(h, emb, context) + h += guided_hint + guided_hint = None + else: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + hs.append(h) + return hs + + +class SUPIR(nn.Module): + """ + SUPIR model containing GLVControl (control encoder) and project_modules (adapters). + State dict keys match the original SUPIR checkpoint layout: + control_model.* -> GLVControl + project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn + """ + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + + self.control_model = GLVControl(dtype=dtype, device=device, operations=operations) + + project_channel_scale = 2 + cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3 + project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3] + concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0] + cross_attn_insert_idx = [6, 3] + + self.project_modules = nn.ModuleList() + for i in range(len(cond_output_channels)): + self.project_modules.append(ZeroSFT( + project_channels[i], cond_output_channels[i], + concat_channels=concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) + + for i in cross_attn_insert_idx: + self.project_modules.insert(i, ZeroCrossAttn( + cond_output_channels[i], concat_channels[i], + dtype=dtype, device=device, operations=operations, + )) diff --git a/comfy/ldm/supir/supir_patch.py b/comfy/ldm/supir/supir_patch.py new file mode 100644 index 000000000..b67ab4cd8 --- /dev/null +++ b/comfy/ldm/supir/supir_patch.py @@ -0,0 +1,103 @@ +import torch +from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample + + +class SUPIRPatch: + """ + Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters). + Runs GLVControl lazily on first patch invocation per step, applies adapters through + middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch. + """ + SIGMA_MAX = 14.6146 + + def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end): + self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl + self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn + self.hint_latent = hint_latent # encoded LQ image latent + self.strength_start = strength_start + self.strength_end = strength_end + self.cached_features = None + self.adapter_idx = 0 + self.control_idx = 0 + self.current_control_idx = 0 + self.active = True + + def _ensure_features(self, kwargs): + """Run GLVControl on first call per step, cache results.""" + if self.cached_features is not None: + return + x = kwargs["x"] + b = x.shape[0] + hint = self.hint_latent.to(device=x.device, dtype=x.dtype) + if hint.shape[0] != b: + hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b] + self.cached_features = self.model_patch.model.control_model( + hint, kwargs["timesteps"], x, + kwargs["context"], kwargs["y"] + ) + self.adapter_idx = len(self.project_modules) - 1 + self.control_idx = len(self.cached_features) - 1 + + def _get_control_scale(self, kwargs): + if self.strength_start == self.strength_end: + return self.strength_end + sigma = kwargs["transformer_options"].get("sigmas") + if sigma is None: + return self.strength_end + s = sigma[0].item() if sigma.dim() > 0 else sigma.item() + t = min(s / self.SIGMA_MAX, 1.0) + return t * (self.strength_start - self.strength_end) + self.strength_end + + def middle_after(self, kwargs): + """middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block.""" + self.cached_features = None # reset from previous step + self.current_scale = self._get_control_scale(kwargs) + self.active = self.current_scale > 0 + if not self.active: + return {"h": kwargs["h"]} + self._ensure_features(kwargs) + h = kwargs["h"] + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return {"h": h} + + def output_block(self, h, hsp, transformer_options): + """output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat.""" + if not self.active: + return h, hsp + self.current_control_idx = self.control_idx + h = self.project_modules[self.adapter_idx]( + self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + self.control_idx -= 1 + return h, None + + def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw): + """forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample.""" + block_type, _ = transformer_options["block"] + if block_type == "output" and self.active and self.cached_features is not None: + x = self.project_modules[self.adapter_idx]( + self.cached_features[self.current_control_idx], x, control_scale=self.current_scale + ) + self.adapter_idx -= 1 + return layer(x, output_shape=output_shape) + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.cached_features = None + if self.hint_latent is not None: + self.hint_latent = self.hint_latent.to(device_or_dtype) + return self + + def models(self): + return [self.model_patch] + + def register(self, model_patcher): + """Register all patches on a cloned model patcher.""" + model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch") + model_patcher.set_model_output_block_patch(self.output_block) + model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 6deb71e12..93d19d6fe 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -506,6 +506,10 @@ class ModelPatcher: def set_model_noise_refiner_patch(self, patch): self.set_model_patch(patch, "noise_refiner") + def set_model_middle_block_after_patch(self, patch): + self.set_model_patch(patch, "middle_block_after_patch") + + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options["scale_x"] = scale_x diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..a4d3ee269 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder +import comfy.ldm.lightricks.vae.audio_vae import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 @@ -805,6 +806,23 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio + self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata) + self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) + self.latent_channels = self.first_stage_model.latent_channels + self.audio_sample_rate_output = self.first_stage_model.output_sample_rate + self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes + self.output_channels = 2 + self.pad_channel_value = "replicate" + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 2 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.disable_offload = True + self.extra_1d_channel = 16 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index cbfaf913d..1602add84 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -3,136 +3,136 @@ from typing_extensions import override import comfy.model_management import node_helpers -from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import ComfyExtension, IO -class TextEncodeAceStepAudio(io.ComfyNode): +class TextEncodeAceStepAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="TextEncodeAceStepAudio", category="conditioning", inputs=[ - io.Clip.Input("clip"), - io.String.Input("tags", multiline=True, dynamic_prompts=True), - io.String.Input("lyrics", multiline=True, dynamic_prompts=True), - io.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), + IO.Clip.Input("clip"), + IO.String.Input("tags", multiline=True, dynamic_prompts=True), + IO.String.Input("lyrics", multiline=True, dynamic_prompts=True), + IO.Float.Input("lyrics_strength", default=1.0, min=0.0, max=10.0, step=0.01), ], - outputs=[io.Conditioning.Output()], + outputs=[IO.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, lyrics_strength) -> io.NodeOutput: + def execute(cls, clip, tags, lyrics, lyrics_strength) -> IO.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics) conditioning = clip.encode_from_tokens_scheduled(tokens) conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength}) - return io.NodeOutput(conditioning) + return IO.NodeOutput(conditioning) -class TextEncodeAceStepAudio15(io.ComfyNode): +class TextEncodeAceStepAudio15(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="TextEncodeAceStepAudio1.5", category="conditioning", inputs=[ - io.Clip.Input("clip"), - io.String.Input("tags", multiline=True, dynamic_prompts=True), - io.String.Input("lyrics", multiline=True, dynamic_prompts=True), - io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), - io.Int.Input("bpm", default=120, min=10, max=300), - io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), - io.Combo.Input("timesignature", options=['2', '3', '4', '6']), - io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), - io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), - io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), - io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), - io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), - io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), - io.Int.Input("top_k", default=0, min=0, max=100, advanced=True), - io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), + IO.Clip.Input("clip"), + IO.String.Input("tags", multiline=True, dynamic_prompts=True), + IO.String.Input("lyrics", multiline=True, dynamic_prompts=True), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + IO.Int.Input("bpm", default=120, min=10, max=300), + IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), + IO.Combo.Input("timesignature", options=['2', '3', '4', '6']), + IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), + IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), + IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), + IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), + IO.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True), + IO.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True), + IO.Int.Input("top_k", default=0, min=0, max=100, advanced=True), + IO.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True), ], - outputs=[io.Conditioning.Output()], + outputs=[IO.Conditioning.Output()], ) @classmethod - def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput: + def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> IO.NodeOutput: tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p) conditioning = clip.encode_from_tokens_scheduled(tokens) - return io.NodeOutput(conditioning) + return IO.NodeOutput(conditioning) -class EmptyAceStepLatentAudio(io.ComfyNode): +class EmptyAceStepLatentAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="EmptyAceStepLatentAudio", display_name="Empty Ace Step 1.0 Latent Audio", category="latent/audio", inputs=[ - io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), - io.Int.Input( + IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." ), ], - outputs=[io.Latent.Output()], + outputs=[IO.Latent.Output()], ) @classmethod - def execute(cls, seconds, batch_size) -> io.NodeOutput: + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = int(seconds * 44100 / 512 / 8) latent = torch.zeros([batch_size, 8, 16, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - return io.NodeOutput({"samples": latent, "type": "audio"}) + return IO.NodeOutput({"samples": latent, "type": "audio"}) -class EmptyAceStep15LatentAudio(io.ComfyNode): +class EmptyAceStep15LatentAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="EmptyAceStep1.5LatentAudio", display_name="Empty Ace Step 1.5 Latent Audio", category="latent/audio", inputs=[ - io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), - io.Int.Input( + IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + IO.Int.Input( "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." ), ], - outputs=[io.Latent.Output()], + outputs=[IO.Latent.Output()], ) @classmethod - def execute(cls, seconds, batch_size) -> io.NodeOutput: + def execute(cls, seconds, batch_size) -> IO.NodeOutput: length = round((seconds * 48000 / 1920)) latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - return io.NodeOutput({"samples": latent, "type": "audio"}) + return IO.NodeOutput({"samples": latent, "type": "audio"}) -class ReferenceAudio(io.ComfyNode): +class ReferenceAudio(IO.ComfyNode): @classmethod def define_schema(cls): - return io.Schema( + return IO.Schema( node_id="ReferenceTimbreAudio", display_name="Reference Audio", category="advanced/conditioning/audio", is_experimental=True, description="This node sets the reference audio for ace step 1.5", inputs=[ - io.Conditioning.Input("conditioning"), - io.Latent.Input("latent", optional=True), + IO.Conditioning.Input("conditioning"), + IO.Latent.Input("latent", optional=True), ], outputs=[ - io.Conditioning.Output(), + IO.Conditioning.Output(), ] ) @classmethod - def execute(cls, conditioning, latent=None) -> io.NodeOutput: + def execute(cls, conditioning, latent=None) -> IO.NodeOutput: if latent is not None: conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) - return io.NodeOutput(conditioning) + return IO.NodeOutput(conditioning) class AceExtension(ComfyExtension): @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: + async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ TextEncodeAceStepAudio, EmptyAceStepLatentAudio, diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a395392d8..5f514716f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None): std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100)) return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3e4222264..3ec635c75 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -3,9 +3,8 @@ import comfy.utils import comfy.model_management import torch -from comfy.ldm.lightricks.vae.audio_vae import AudioVAE from comfy_api.latest import ComfyExtension, io - +from comfy_extras.nodes_audio import VAEEncodeAudio class LTXVAudioVAELoader(io.ComfyNode): @classmethod @@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode): def execute(cls, ckpt_name: str) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - return io.NodeOutput(AudioVAE(sd, metadata)) + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) + vae.throw_exception_if_invalid() + + return io.NodeOutput(vae) -class LTXVAudioVAEEncode(io.ComfyNode): +class LTXVAudioVAEEncode(VAEEncodeAudio): @classmethod def define_schema(cls) -> io.Schema: return io.Schema( @@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode): ) @classmethod - def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: - audio_latents = audio_vae.encode(audio) - return io.NodeOutput( - { - "samples": audio_latents, - "sample_rate": int(audio_vae.sample_rate), - "type": "audio", - } - ) + def execute(cls, audio, audio_vae) -> io.NodeOutput: + return super().execute(audio_vae, audio) class LTXVAudioVAEDecode(io.ComfyNode): @@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode): ) @classmethod - def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + def execute(cls, samples, audio_vae) -> io.NodeOutput: audio_latent = samples["samples"] if audio_latent.is_nested: audio_latent = audio_latent.unbind()[-1] - audio = audio_vae.decode(audio_latent).to(audio_latent.device) - output_audio_sample_rate = audio_vae.output_sample_rate + audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device) + output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate return io.NodeOutput( { "waveform": audio, @@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode): frames_number: int, frame_rate: int, batch_size: int, - audio_vae: AudioVAE, + audio_vae, ) -> io.NodeOutput: """Generate empty audio latents matching the reference pipeline structure.""" assert audio_vae is not None, "Audio VAE model is required" z_channels = audio_vae.latent_channels - audio_freq = audio_vae.latent_frequency_bins - sampling_rate = int(audio_vae.sample_rate) + audio_freq = audio_vae.first_stage_model.latent_frequency_bins + sampling_rate = int(audio_vae.first_stage_model.sample_rate) - num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) audio_latents = torch.zeros( (batch_size, z_channels, num_audio_latents, audio_freq), diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 176e6bc2f..748559a6b 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,7 +7,10 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +import comfy.ldm.supir.supir_modules from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel +from comfy_api.latest import io +from comfy.ldm.supir.supir_patch import SUPIRPatch class BlockWiseControlBlock(torch.nn.Module): @@ -266,6 +269,27 @@ class ModelPatchLoader: out_dim=sd["audio_proj.norm.weight"].shape[0], device=comfy.model_management.unet_offload_device(), operations=comfy.ops.manual_cast) + elif 'model.control_model.input_hint_block.0.weight' in sd or 'control_model.input_hint_block.0.weight' in sd: + prefix_replace = {} + if 'model.control_model.input_hint_block.0.weight' in sd: + prefix_replace["model.control_model."] = "control_model." + prefix_replace["model.diffusion_model.project_modules."] = "project_modules." + else: + prefix_replace["control_model."] = "control_model." + prefix_replace["project_modules."] = "project_modules." + + # Extract denoise_encoder weights before filter_keys discards them + de_prefix = "first_stage_model.denoise_encoder." + denoise_encoder_sd = {} + for k in list(sd.keys()): + if k.startswith(de_prefix): + denoise_encoder_sd[k[len(de_prefix):]] = sd.pop(k) + + sd = comfy.utils.state_dict_prefix_replace(sd, prefix_replace, filter_keys=True) + sd.pop("control_model.mask_LQ", None) + model = comfy.ldm.supir.supir_modules.SUPIR(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + if denoise_encoder_sd: + model.denoise_encoder_sd = denoise_encoder_sd model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model.load_state_dict(sd, assign=model_patcher.is_dynamic()) @@ -565,9 +589,89 @@ class MultiTalkModelPatch(torch.nn.Module): ) +class SUPIRApply(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="SUPIRApply", + category="model_patches/supir", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Float.Input("strength_start", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the start of sampling (high sigma)."), + io.Float.Input("strength_end", default=1.0, min=0.0, max=10.0, step=0.01, + tooltip="Control strength at the end of sampling (low sigma). Linearly interpolated from start."), + io.Float.Input("restore_cfg", default=4.0, min=0.0, max=20.0, step=0.1, advanced=True, + tooltip="Pulls denoised output toward the input latent. Higher = stronger fidelity to input. 0 to disable."), + io.Float.Input("restore_cfg_s_tmin", default=0.05, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Sigma threshold below which restore_cfg is disabled."), + ], + outputs=[io.Model.Output()], + ) + + @classmethod + def _encode_with_denoise_encoder(cls, vae, model_patch, image): + """Encode using denoise_encoder weights from SUPIR checkpoint if available.""" + denoise_sd = getattr(model_patch.model, 'denoise_encoder_sd', None) + if not denoise_sd: + return vae.encode(image) + + # Clone VAE patcher, apply denoise_encoder weights to clone, encode + orig_patcher = vae.patcher + vae.patcher = orig_patcher.clone() + patches = {f"encoder.{k}": (v,) for k, v in denoise_sd.items()} + vae.patcher.add_patches(patches, strength_patch=1.0, strength_model=0.0) + try: + return vae.encode(image) + finally: + vae.patcher = orig_patcher + + @classmethod + def execute(cls, *, model: io.Model.Type, model_patch: io.ModelPatch.Type, vae: io.Vae.Type, image: io.Image.Type, + strength_start: float, strength_end: float, restore_cfg: float, restore_cfg_s_tmin: float) -> io.NodeOutput: + model_patched = model.clone() + hint_latent = model.get_model_object("latent_format").process_in( + cls._encode_with_denoise_encoder(vae, model_patch, image[:, :, :, :3])) + patch = SUPIRPatch(model_patch, model_patch.model.project_modules, hint_latent, strength_start, strength_end) + patch.register(model_patched) + + if restore_cfg > 0.0: + # Round-trip to match original pipeline: decode hint, re-encode with regular VAE + latent_format = model.get_model_object("latent_format") + decoded = vae.decode(latent_format.process_out(hint_latent)) + x_center = latent_format.process_in(vae.encode(decoded[:, :, :, :3])) + sigma_max = 14.6146 + + def restore_cfg_function(args): + denoised = args["denoised"] + sigma = args["sigma"] + if sigma.dim() > 0: + s = sigma[0].item() + else: + s = sigma.item() + if s > restore_cfg_s_tmin: + ref = x_center.to(device=denoised.device, dtype=denoised.dtype) + b = denoised.shape[0] + if ref.shape[0] != b: + ref = ref.expand(b, -1, -1, -1) if ref.shape[0] == 1 else ref.repeat((b + ref.shape[0] - 1) // ref.shape[0], 1, 1, 1)[:b] + sigma_val = sigma.view(-1, 1, 1, 1) if sigma.dim() > 0 else sigma + d_center = denoised - ref + denoised = denoised - d_center * ((sigma_val / sigma_max) ** restore_cfg) + return denoised + + model_patched.set_model_sampler_post_cfg_function(restore_cfg_function) + + return io.NodeOutput(model_patched) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, "ZImageFunControlnet": ZImageFunControlnet, "USOStyleReference": USOStyleReference, + "SUPIRApply": SUPIRApply, } diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 9037c3d20..c932b747a 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -6,6 +6,7 @@ from PIL import Image import math from enum import Enum from typing import TypedDict, Literal +import kornia import comfy.utils import comfy.model_management @@ -660,6 +661,228 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): return io.NodeOutput(batched) +class ColorTransfer(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ColorTransfer", + category="image/postprocessing", + description="Match the colors of one image to another using various algorithms.", + search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], + inputs=[ + io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), + io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), + io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), + io.DynamicCombo.Input("source_stats", + tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", + options=[ + io.DynamicCombo.Option("per_frame", []), + io.DynamicCombo.Option("uniform", []), + io.DynamicCombo.Option("target_frame", [ + io.Int.Input("target_index", default=0, min=0, max=10000, + tooltip="Frame index used as the source baseline for computing the transform to image_ref"), + ]), + ]), + io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + io.Image.Output(display_name="image"), + ], + ) + + @staticmethod + def _to_lab(images, i, device): + return kornia.color.rgb_to_lab( + images[i:i+1].to(device, dtype=torch.float32).permute(0, 3, 1, 2)) + + @staticmethod + def _pool_stats(images, device, is_reinhard, eps): + """Two-pass pooled mean + std/cov across all frames.""" + N, C = images.shape[0], images.shape[3] + HW = images.shape[1] * images.shape[2] + mean = torch.zeros(C, 1, device=device, dtype=torch.float32) + for i in range(N): + mean += ColorTransfer._to_lab(images, i, device).view(C, -1).mean(dim=-1, keepdim=True) + mean /= N + acc = torch.zeros(C, 1 if is_reinhard else C, device=device, dtype=torch.float32) + for i in range(N): + centered = ColorTransfer._to_lab(images, i, device).view(C, -1) - mean + if is_reinhard: + acc += (centered * centered).mean(dim=-1, keepdim=True) + else: + acc += centered @ centered.T / HW + if is_reinhard: + return mean, torch.sqrt(acc / N).clamp_min_(eps) + return mean, acc / N + + @staticmethod + def _frame_stats(lab_flat, hw, is_reinhard, eps): + """Per-frame mean + std/cov.""" + mean = lab_flat.mean(dim=-1, keepdim=True) + if is_reinhard: + return mean, lab_flat.std(dim=-1, keepdim=True, unbiased=False).clamp_min_(eps) + centered = lab_flat - mean + return mean, centered @ centered.T / hw + + @staticmethod + def _mkl_matrix(cov_s, cov_r, eps): + """Compute MKL 3x3 transform matrix from source and ref covariances.""" + eig_val_s, eig_vec_s = torch.linalg.eigh(cov_s) + sqrt_val_s = torch.sqrt(eig_val_s.clamp_min(0)).clamp_min_(eps) + + scaled_V = eig_vec_s * sqrt_val_s.unsqueeze(0) + mid = scaled_V.T @ cov_r @ scaled_V + eig_val_m, eig_vec_m = torch.linalg.eigh(mid) + sqrt_m = torch.sqrt(eig_val_m.clamp_min(0)) + + inv_sqrt_s = 1.0 / sqrt_val_s + inv_scaled_V = eig_vec_s * inv_sqrt_s.unsqueeze(0) + M_half = (eig_vec_m * sqrt_m.unsqueeze(0)) @ eig_vec_m.T + return inv_scaled_V @ M_half @ inv_scaled_V.T + + @staticmethod + def _histogram_lut(src, ref, bins=256): + """Build per-channel LUT from source and ref histograms. src/ref: (C, HW) in [0,1].""" + s_bins = (src * (bins - 1)).long().clamp(0, bins - 1) + r_bins = (ref * (bins - 1)).long().clamp(0, bins - 1) + s_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype) + r_hist = torch.zeros(src.shape[0], bins, device=src.device, dtype=src.dtype) + ones_s = torch.ones_like(src) + ones_r = torch.ones_like(ref) + s_hist.scatter_add_(1, s_bins, ones_s) + r_hist.scatter_add_(1, r_bins, ones_r) + s_cdf = s_hist.cumsum(1) + s_cdf = s_cdf / s_cdf[:, -1:] + r_cdf = r_hist.cumsum(1) + r_cdf = r_cdf / r_cdf[:, -1:] + return torch.searchsorted(r_cdf, s_cdf).clamp_max_(bins - 1).float() / (bins - 1) + + @classmethod + def _pooled_cdf(cls, images, device, num_bins=256): + """Build pooled CDF across all frames, one frame at a time.""" + C = images.shape[3] + hist = torch.zeros(C, num_bins, device=device, dtype=torch.float32) + for i in range(images.shape[0]): + frame = images[i].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1) + bins = (frame * (num_bins - 1)).long().clamp(0, num_bins - 1) + hist.scatter_add_(1, bins, torch.ones_like(frame)) + cdf = hist.cumsum(1) + return cdf / cdf[:, -1:] + + @classmethod + def _build_histogram_transform(cls, image_target, image_ref, device, stats_mode, target_index, B): + """Build per-frame or uniform LUT transform for histogram mode.""" + if stats_mode == 'per_frame': + return None # LUT computed per-frame in the apply loop + + r_cdf = cls._pooled_cdf(image_ref, device) + if stats_mode == 'target_frame': + ti = min(target_index, B - 1) + s_cdf = cls._pooled_cdf(image_target[ti:ti+1], device) + else: + s_cdf = cls._pooled_cdf(image_target, device) + return torch.searchsorted(r_cdf, s_cdf).clamp_max_(255).float() / 255.0 + + @classmethod + def _build_lab_transform(cls, image_target, image_ref, device, stats_mode, target_index, is_reinhard): + """Build transform parameters for Lab-based methods. Returns a transform function.""" + eps = 1e-6 + B, H, W, C = image_target.shape + B_ref = image_ref.shape[0] + single_ref = B_ref == 1 + HW = H * W + HW_ref = image_ref.shape[1] * image_ref.shape[2] + + # Precompute ref stats + if single_ref or stats_mode in ('uniform', 'target_frame'): + ref_mean, ref_sc = cls._pool_stats(image_ref, device, is_reinhard, eps) + + # Uniform/target_frame: precompute single affine transform + if stats_mode in ('uniform', 'target_frame'): + if stats_mode == 'target_frame': + ti = min(target_index, B - 1) + s_lab = cls._to_lab(image_target, ti, device).view(C, -1) + s_mean, s_sc = cls._frame_stats(s_lab, HW, is_reinhard, eps) + else: + s_mean, s_sc = cls._pool_stats(image_target, device, is_reinhard, eps) + + if is_reinhard: + scale = ref_sc / s_sc + offset = ref_mean - scale * s_mean + return lambda src_flat, **_: src_flat * scale + offset + T = cls._mkl_matrix(s_sc, ref_sc, eps) + offset = ref_mean - T @ s_mean + return lambda src_flat, **_: T @ src_flat + offset + + # per_frame + def per_frame_transform(src_flat, frame_idx): + s_mean, s_sc = cls._frame_stats(src_flat, HW, is_reinhard, eps) + + if single_ref: + r_mean, r_sc = ref_mean, ref_sc + else: + ri = min(frame_idx, B_ref - 1) + r_mean, r_sc = cls._frame_stats(cls._to_lab(image_ref, ri, device).view(C, -1), HW_ref, is_reinhard, eps) + + centered = src_flat - s_mean + if is_reinhard: + return centered * (r_sc / s_sc) + r_mean + T = cls._mkl_matrix(centered @ centered.T / HW, r_sc, eps) + return T @ centered + r_mean + + return per_frame_transform + + @classmethod + def execute(cls, image_target, image_ref, method, source_stats, strength=1.0) -> io.NodeOutput: + stats_mode = source_stats["source_stats"] + target_index = source_stats.get("target_index", 0) + + if strength == 0 or image_ref is None: + return io.NodeOutput(image_target) + + device = comfy.model_management.get_torch_device() + intermediate_device = comfy.model_management.intermediate_device() + intermediate_dtype = comfy.model_management.intermediate_dtype() + + B, H, W, C = image_target.shape + B_ref = image_ref.shape[0] + pbar = comfy.utils.ProgressBar(B) + out = torch.empty(B, H, W, C, device=intermediate_device, dtype=intermediate_dtype) + + if method == 'histogram': + uniform_lut = cls._build_histogram_transform( + image_target, image_ref, device, stats_mode, target_index, B) + + for i in range(B): + src = image_target[i].to(device, dtype=torch.float32).permute(2, 0, 1) + src_flat = src.reshape(C, -1) + if uniform_lut is not None: + lut = uniform_lut + else: + ri = min(i, B_ref - 1) + ref = image_ref[ri].to(device, dtype=torch.float32).permute(2, 0, 1).reshape(C, -1) + lut = cls._histogram_lut(src_flat, ref) + bin_idx = (src_flat * 255).long().clamp(0, 255) + matched = lut.gather(1, bin_idx).view(C, H, W) + result = matched if strength == 1.0 else torch.lerp(src, matched, strength) + out[i] = result.permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype) + pbar.update(1) + else: + transform = cls._build_lab_transform(image_target, image_ref, device, stats_mode, target_index, is_reinhard=method == "reinhard_lab") + + for i in range(B): + src_frame = cls._to_lab(image_target, i, device) + corrected = transform(src_frame.view(C, -1), frame_idx=i) + if strength == 1.0: + result = kornia.color.lab_to_rgb(corrected.view(1, C, H, W)) + else: + result = kornia.color.lab_to_rgb(torch.lerp(src_frame, corrected.view(1, C, H, W), strength)) + out[i] = result.squeeze(0).permute(1, 2, 0).clamp_(0, 1).to(device=intermediate_device, dtype=intermediate_dtype) + pbar.update(1) + + return io.NodeOutput(out) + + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -673,6 +896,7 @@ class PostProcessingExtension(ComfyExtension): BatchImagesNode, BatchMasksNode, BatchLatentsNode, + ColorTransfer, # BatchImagesMasksLatentsNode, ] diff --git a/requirements.txt b/requirements.txt index 3de845f48..671bd5693 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.42.11 +comfyui-frontend-package==1.42.12 comfyui-workflow-templates==0.9.57 comfyui-embedded-docs==0.4.3 torch @@ -19,7 +19,7 @@ scipy tqdm psutil alembic -SQLAlchemy +SQLAlchemy>=2.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8