diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 716d728c2..cf3ebd520 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1,10 +1,8 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, List, Dict, Any, Callable
 import einops
-from einops import rearrange, repeat
+from einops import rearrange
 import comfy.model_management
-from torch import nn
-import torch.nn.utils.rnn as rnn_utils
 import torch.nn.functional as F
 from math import ceil, pi
 import torch
@@ -15,7 +13,6 @@ from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
-import logging
 
 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):
@@ -126,7 +123,7 @@ def safe_pad_operation(x, padding, mode='constant', value=0.0):
     """Safe padding operation that handles Half precision only for problematic modes"""
     # Modes that require the Half-precision fix
     problematic_modes = ['replicate', 'reflect', 'circular']
-    
+
     if mode in problematic_modes:
         try:
             return F.pad(x, padding, mode=mode, value=value)
@@ -189,7 +186,7 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup
     resized_h, resized_w = round(h * scale), round(w * scale)
     wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
     wt = ceil(min(t, 30) / resized_nt) # window size.
-    
+
     st, sh, sw = ( # shift size.
         0.5 if wt < t else 0,
         0.5 if wh < h else 0,
@@ -466,7 +463,7 @@ def apply_rotary_emb(
     freqs = freqs.to(t_middle.device)
     t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
-    
+
     out = torch.cat((t_left, t_transformed, t_right), dim=-1)
     return out.type(dtype)
 
 
@@ -655,7 +652,7 @@ class NaSwinAttention(NaMMAttention):
         self,
         *args,
         window: Union[int, Tuple[int, int, int]],
-        window_method: bool, # shifted or not 
+        window_method: bool, # shifted or not
         **kwargs,
     ):
         super().__init__(*args, **kwargs)
@@ -765,7 +762,7 @@ class NaSwinAttention(NaMMAttention):
         vid_out, txt_out = self.proj_out(vid_out, txt_out)
         return vid_out, txt_out
 
-    
+
 class MLP(nn.Module):
     def __init__(
         self,
@@ -1274,7 +1271,7 @@ class NaDiT(nn.Module):
             layers=["out"],
             modes=["in"],
         )
-        
+
         self.stop_cfg_index = -1
 
     def set_cfg_stop_index(self, cfg):
@@ -1290,7 +1287,7 @@ class NaDiT(nn.Module):
         context, # l c
         disable_cache: bool = False, # for test; TODO: gives an error when set to True
         **kwargs
-    ): 
+    ):
         transformer_options = kwargs.get("transformer_options", {})
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index a8f8c31f2..f30646dda 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -317,26 +317,26 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
     """Safe interpolate operation that handles Half precision for problematic modes"""
     # Modes that can cause problems with Half precision
    problematic_modes = ['bilinear', 'bicubic', 'trilinear']
-    
+
     if mode in problematic_modes:
         try:
             return F.interpolate(
-                x, 
-                size=size, 
-                scale_factor=scale_factor, 
-                mode=mode, 
+                x,
+                size=size,
+                scale_factor=scale_factor,
+                mode=mode,
                 align_corners=align_corners,
                 recompute_scale_factor=recompute_scale_factor
             )
         except RuntimeError as e:
-            if ("not implemented for 'Half'" in str(e) or 
+            if ("not implemented for 'Half'" in str(e) or
                 "compute_indices_weights" in str(e)):
                 original_dtype = x.dtype
                 return F.interpolate(
-                    x.float(), 
-                    size=size, 
-                    scale_factor=scale_factor, 
-                    mode=mode, 
+                    x.float(),
+                    size=size,
+                    scale_factor=scale_factor,
+                    mode=mode,
                     align_corners=align_corners,
                     recompute_scale_factor=recompute_scale_factor
                 ).to(original_dtype)
@@ -345,10 +345,10 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
     else:
         # For 'nearest' and other compatible modes, no fix is needed
         return F.interpolate(
-            x, 
-            size=size, 
-            scale_factor=scale_factor, 
-            mode=mode, 
+            x,
+            size=size,
+            scale_factor=scale_factor,
+            mode=mode,
             align_corners=align_corners,
             recompute_scale_factor=recompute_scale_factor
         )
@@ -426,7 +426,7 @@ class Upsample3D(nn.Module):
         **kwargs,
     ):
         super().__init__()
-        self.interpolate = interpolate 
+        self.interpolate = interpolate
         self.channels = channels
         self.out_channels = out_channels or channels
         self.use_conv_transpose = use_conv_transpose
@@ -444,7 +444,7 @@ class Upsample3D(nn.Module):
             if kernel_size is None:
                 kernel_size = 3
             self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
-        
+
         conv = self.conv if self.name == "conv" else self.Conv2d_0
         assert type(conv) is not nn.ConvTranspose2d
 
@@ -587,7 +587,7 @@ class Downsample3D(nn.Module):
                 kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
                 stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
             )
-        
+
         self.conv = conv
 
 
@@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         latent = latent / scale + shift
         latent = rearrange(latent, "b ... c -> b c ...")
         latent = latent.squeeze(2)
-        
+
         if latent.ndim == 4:
             latent = latent.unsqueeze(2)
 
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 2b354f418..53f953710 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -815,7 +815,7 @@ class HunyuanDiT(BaseModel):
         out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
         return out
 
-    
+
 class SeedVR2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index f1312c3ab..886409d47 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -445,7 +445,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["pad_tokens_multiple"] = 32
         return dit_config
 
-    
+
     elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
         dit_config = {}
         dit_config["image_model"] = "seedvr2"
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 1c325524d..9bbb1d0cd 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1287,7 +1287,7 @@ class Chroma(supported_models_base.BASE):
         pref = self.text_encoder_key_prefix[0]
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
-    
+
 class SeedVR2(supported_models_base.BASE):
     unet_config = {
         "image_model": "seedvr2"
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index e3281b5f3..ce5437517 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -15,7 +15,7 @@ from torchvision.transforms.functional import InterpolationMode
 
 @torch.inference_mode()
 def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
-    
+
     gc.collect()
     torch.cuda.empty_cache()
 
@@ -23,7 +23,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
         x = x.unsqueeze(2)
 
     b, c, d, h, w = x.shape
-    
+
     sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
     sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
 
@@ -32,7 +32,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     ov_h, ov_w = tile_overlap
     ti_t = temporal_size
     ov_t = temporal_overlap
-    
+
     target_d = (d + sf_t - 1) // sf_t
     target_h = (h + sf_s - 1) // sf_s
     target_w = (w + sf_s - 1) // sf_s
@@ -55,7 +55,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     storage_device = torch.device("cpu")
     result = None
     count = None
-    
+
     ramp_cache = {}
     def get_ramp(steps):
         if steps not in ramp_cache:
@@ -66,10 +66,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     bar = ProgressBar(d // stride_t)
     for t_idx in range(0, d, stride_t):
         t_end = min(t_idx + ti_t, d)
-        
+
         for y_idx in range(0, h, stride_h):
             y_end = min(y_idx + ti_h, h)
-            
+
             for x_idx in range(0, w, stride_w):
                 x_end = min(x_idx + ti_w, w)
 
@@ -94,7 +94,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                 ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
                 ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
                 xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
-                
+
                 cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
                 cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
                 cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
@@ -115,7 +115,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                     r = get_ramp(cur_ov_t)
                     if t_idx > 0: w_t[:cur_ov_t] = r
                     if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
-                
+
                 if cur_ov_h > 0:
                     r = get_ramp(cur_ov_h)
                     if y_idx > 0: w_h[:cur_ov_h] = r
@@ -131,11 +131,11 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                 tile_out.mul_(final_weight)
                 result[:, :, ts:te, ys:ye, xs:xe] += tile_out
                 count[:, :, ts:te, ys:ye, xs:xe] += final_weight
-                
+
                 del tile_out, final_weight, tile_x, w_t, w_h, w_w
                 bar.update(1)
 
     result.div_(count.clamp(min=1e-6))
-    
+
     if result.device != x.device:
         result = result.to(x.device).to(x.dtype)
@@ -238,7 +238,7 @@ def cut_videos(videos):
     videos = torch.cat([videos, padding], dim=1)
     assert (videos.size(1) - 1) % (4) == 0
     return videos
-    
+
 class SeedVR2InputProcessing(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -259,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
                 io.Latent.Output("vae_conditioning")
             ]
         )
-    
+
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
         device = vae.patcher.load_device
@@ -271,7 +271,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         scale = 0.9152; shift = 0
         if images.dim() != 5: # add the t dim
             images = images.unsqueeze(0)
-        images = images.permute(0, 1, 4, 2, 3) 
+        images = images.permute(0, 1, 4, 2, 3)
         b, t, c, h, w = images.shape
         images = images.reshape(b * t, c, h, w)
 
@@ -328,7 +328,7 @@ class SeedVR2Conditioning(io.ComfyNode):
 
     @classmethod
     def execute(cls, vae_conditioning, model) -> io.NodeOutput:
-        
+
         vae_conditioning = vae_conditioning["samples"]
         device = vae_conditioning.device
         model = model.model.diffusion_model
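
Note on the pattern shared by safe_pad_operation (model.py) and safe_interpolate_operation (vae.py) above: both attempt the operation in the tensor's native dtype and, when the backend reports a missing Half kernel, redo it in float32 and cast back. A minimal standalone sketch of that pattern (the helper name pad_with_fp32_fallback is illustrative, not from the patch):

    import torch
    import torch.nn.functional as F

    def pad_with_fp32_fallback(x, padding, mode='replicate'):
        # Try the op in the tensor's own dtype first.
        try:
            return F.pad(x, padding, mode=mode)
        except RuntimeError as e:
            # Some backends have no Half kernel for replicate/reflect/circular;
            # redo the op in float32 and cast the result back.
            if "not implemented for 'Half'" in str(e):
                return F.pad(x.float(), padding, mode=mode).to(x.dtype)
            raise

    x = torch.randn(1, 3, 8, 8, dtype=torch.float16)
    y = pad_with_fp32_fallback(x, (1, 1, 1, 1), mode='reflect')
    assert y.dtype == torch.float16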
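
Likewise, the overlap handling in tiled_vae feathers neighbouring tiles: per-axis linear ramps (w_t, w_h, w_w) over the overlap regions combine into one separable 3-D weight, tiles accumulate into result/count, and the output is normalised at the end. A reduced sketch of that blending, with assumed tile shapes and an illustrative axis_weight helper standing in for get_ramp and the w_t/w_h/w_w logic:

    import torch

    def axis_weight(n, ov, fade_in, fade_out):
        # Ones everywhere; linear fade-in over the first `ov` samples where a
        # previous tile overlaps, fade-out over the last `ov` where the next does.
        w = torch.ones(n)
        if ov > 0:
            r = torch.linspace(0.0, 1.0, steps=ov)
            if fade_in:
                w[:ov] = r
            if fade_out:
                w[-ov:] = 1.0 - r
        return w

    # Separable window for a latent tile of shape [t, h, w] = [8, 64, 64]:
    w_t = axis_weight(8, 2, fade_in=True, fade_out=True)
    w_h = axis_weight(64, 8, fade_in=True, fade_out=False)  # last row: no tile below
    w_w = axis_weight(64, 8, fade_in=False, fade_out=True)  # first column: no tile to the left
    weight = w_t[:, None, None] * w_h[None, :, None] * w_w[None, None, :]

    # Each tile is accumulated as result[region] += tile * weight and
    # count[region] += weight; the output is result / count.clamp(min=1e-6).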
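
Finally, the assert in cut_videos encodes the temporal constraint: with a temporal downsample factor of 4, a clip must contain 4n + 1 frames. A hypothetical sketch of padding to that length by repeating the last frame (the actual padding construction is not visible in the hunk):

    import torch

    def pad_to_4n_plus_1(videos):
        # videos: [b, t, c, h, w]; repeat the last frame until (t - 1) % 4 == 0.
        t = videos.size(1)
        need = (-(t - 1)) % 4
        if need:
            padding = videos[:, -1:].repeat(1, need, *([1] * (videos.dim() - 2)))
            videos = torch.cat([videos, padding], dim=1)
        assert (videos.size(1) - 1) % 4 == 0
        return videos

    v = torch.randn(1, 10, 4, 8, 8)  # 10 frames -> padded to 13 (= 4*3 + 1)
    assert pad_to_4n_plus_1(v).size(1) == 13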