Yousef Rafat 2025-12-23 12:35:00 +02:00
parent e30298dda2
commit 5b0c80a093
6 changed files with 43 additions and 46 deletions

View File

@@ -1,10 +1,8 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import einops
-from einops import rearrange, repeat
+from einops import rearrange
import comfy.model_management
-from torch import nn
-import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F
from math import ceil, pi
import torch
@@ -15,7 +13,6 @@ from comfy.rmsnorm import RMSNorm
from torch.nn.modules.utils import _triple
from torch import nn
import math
-import logging

class Cache:
    def __init__(self, disable=False, prefix="", cache=None):
@@ -126,7 +123,7 @@ def safe_pad_operation(x, padding, mode='constant', value=0.0):
    """Safe padding operation that handles Half precision only for problematic modes"""
    # Modes that require the Half-precision fix
    problematic_modes = ['replicate', 'reflect', 'circular']

    if mode in problematic_modes:
        try:
            return F.pad(x, padding, mode=mode, value=value)
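Note: the fallback above only matters for modes without a native float16 kernel. A standalone sketch of the same pattern (hypothetical helper name, not part of this commit): try the pad in the tensor's own dtype, and redo it in float32 if the Half kernel is missing.

import torch
import torch.nn.functional as F

def pad_with_fp32_fallback(x, padding, mode='constant', value=0.0):
    # Hypothetical helper mirroring safe_pad_operation: attempt the pad as-is,
    # and on a missing-Half-kernel error redo it in float32 and cast back.
    try:
        if mode == 'constant':
            return F.pad(x, padding, mode=mode, value=value)
        return F.pad(x, padding, mode=mode)
    except RuntimeError as e:
        if "not implemented for 'Half'" not in str(e):
            raise
        out = F.pad(x.float(), padding, mode=mode) if mode != 'constant' else \
            F.pad(x.float(), padding, mode=mode, value=value)
        return out.to(x.dtype)

# Example: reflect-pad the last dim of a half-precision tensor.
t = torch.randn(1, 3, 8, dtype=torch.float16)
print(pad_with_fp32_fallback(t, (2, 2), mode='reflect').shape)  # torch.Size([1, 3, 12])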
@@ -189,7 +186,7 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup
    resized_h, resized_w = round(h * scale), round(w * scale)
    wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # window size.
    wt = ceil(min(t, 30) / resized_nt)  # window size.
    st, sh, sw = (  # shift size.
        0.5 if wt < t else 0,
        0.5 if wh < h else 0,
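The hunk above sizes attention windows by rescaling the spatial grid toward a 720p reference and ceil-dividing by the window count; an axis only gets a half-window shift when it is actually split into more than one window. A small worked example of that arithmetic with assumed inputs:

from math import ceil

# Assumed example: a (t, h, w) = (30, 44, 80) token grid, (1, 6, 10) windows, scale 1.0.
t, h, w = 30, 44, 80
resized_nt, resized_nh, resized_nw = 1, 6, 10
scale = 1.0

resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)  # spatial window size
wt = ceil(min(t, 30) / resized_nt)                                   # temporal window size
st, sh, sw = (0.5 if wt < t else 0, 0.5 if wh < h else 0, 0.5 if ww < w else 0)
print((wt, wh, ww), (st, sh, sw))  # (30, 8, 8) (0, 0.5, 0.5): no temporal shift, half-window spatial shifts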
@@ -466,7 +463,7 @@ def apply_rotary_emb(
    freqs = freqs.to(t_middle.device)
    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
    return out.type(dtype)
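apply_rotary_emb keeps the untouched left/right slices and rotates only t_middle. A minimal self-contained sketch of that rotation, assuming the usual rotate_half convention (feature halves rotated by the per-position angles in freqs):

import torch

def rotate_half(x):
    # Split the feature dim in two and map (a, b) -> (-b, a).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(t, freqs, scale=1.0):
    # Standard RoPE: t * cos(theta) + rotate_half(t) * sin(theta).
    return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)

q = torch.randn(1, 8, 16, 64)        # (batch, heads, seq, head_dim)
freqs = torch.randn(16, 64)          # per-position angles, broadcast over batch and heads
print(apply_rotary(q, freqs).shape)  # torch.Size([1, 8, 16, 64])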
@@ -655,7 +652,7 @@ class NaSwinAttention(NaMMAttention):
        self,
        *args,
        window: Union[int, Tuple[int, int, int]],
        window_method: bool,  # shifted or not
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
@@ -765,7 +762,7 @@ class NaSwinAttention(NaMMAttention):
        vid_out, txt_out = self.proj_out(vid_out, txt_out)
        return vid_out, txt_out

class MLP(nn.Module):
    def __init__(
        self,
@@ -1274,7 +1271,7 @@ class NaDiT(nn.Module):
            layers=["out"],
            modes=["in"],
        )
        self.stop_cfg_index = -1

    def set_cfg_stop_index(self, cfg):
@@ -1290,7 +1287,7 @@
        context,  # l c
        disable_cache: bool = False,  # for test # TODO ? // gives an error when set to True
        **kwargs
    ):
        transformer_options = kwargs.get("transformer_options", {})
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

View File

@@ -317,26 +317,26 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
    """Safe interpolate operation that handles Half precision for problematic modes"""
    # Modes that can cause problems with Half precision
    problematic_modes = ['bilinear', 'bicubic', 'trilinear']

    if mode in problematic_modes:
        try:
            return F.interpolate(
                x,
                size=size,
                scale_factor=scale_factor,
                mode=mode,
                align_corners=align_corners,
                recompute_scale_factor=recompute_scale_factor
            )
        except RuntimeError as e:
            if ("not implemented for 'Half'" in str(e) or
                "compute_indices_weights" in str(e)):
                original_dtype = x.dtype
                return F.interpolate(
                    x.float(),
                    size=size,
                    scale_factor=scale_factor,
                    mode=mode,
                    align_corners=align_corners,
                    recompute_scale_factor=recompute_scale_factor
                ).to(original_dtype)
@@ -345,10 +345,10 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
    else:
        # For 'nearest' and other compatible modes, no fix needed
        return F.interpolate(
            x,
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor
        )
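Same fallback idea as the padding helper, applied to F.interpolate. A condensed, standalone sketch (hypothetical name) of how the try/except branch behaves for a half-precision input:

import torch
import torch.nn.functional as F

def interpolate_with_fp32_fallback(x, mode='bilinear', **kwargs):
    # Hypothetical condensed version: resize in the input dtype, and if the
    # half-precision kernel is missing, retry in float32 and cast back.
    try:
        return F.interpolate(x, mode=mode, **kwargs)
    except RuntimeError as e:
        if "not implemented for 'Half'" in str(e) or "compute_indices_weights" in str(e):
            return F.interpolate(x.float(), mode=mode, **kwargs).to(x.dtype)
        raise

x = torch.randn(1, 3, 32, 32, dtype=torch.float16)
out = interpolate_with_fp32_fallback(x, mode='bilinear', scale_factor=2.0, align_corners=False)
print(out.shape, out.dtype)  # torch.Size([1, 3, 64, 64]) torch.float16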
@@ -426,7 +426,7 @@ class Upsample3D(nn.Module):
        **kwargs,
    ):
        super().__init__()
        self.interpolate = interpolate
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv_transpose = use_conv_transpose
@@ -444,7 +444,7 @@
        if kernel_size is None:
            kernel_size = 3
        self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)

        conv = self.conv if self.name == "conv" else self.Conv2d_0
        assert type(conv) is not nn.ConvTranspose2d
@@ -587,7 +587,7 @@ class Downsample3D(nn.Module):
            kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
        )
        self.conv = conv
@@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
        latent = latent / scale + shift
        latent = rearrange(latent, "b ... c -> b c ...")
        latent = latent.squeeze(2)

        if latent.ndim == 4:
            latent = latent.unsqueeze(2)
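For reference, the path above un-scales the latent and moves channels back in front of the spatio-temporal dims. A toy run with assumed shapes and an assumed scale (values are illustrative only):

import torch
from einops import rearrange

scale, shift = 0.9152, 0.0                       # assumed values for illustration
latent = torch.randn(2, 4, 16, 16, 32)           # b, t, h, w, c (channels last)

latent = latent / scale + shift                  # undo the latent normalisation
latent = rearrange(latent, "b ... c -> b c ...") # -> b, c, t, h, w
latent = latent.squeeze(2)                       # drop a singleton time axis if present
if latent.ndim == 4:                             # image input: restore a length-1 time axis
    latent = latent.unsqueeze(2)
print(latent.shape)                              # torch.Size([2, 32, 4, 16, 16])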

View File

@@ -815,7 +815,7 @@ class HunyuanDiT(BaseModel):
        out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
        return out

class SeedVR2(BaseModel):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)

View File

@@ -445,7 +445,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["pad_tokens_multiple"] = 32
        return dit_config
    elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys:  # seedvr2 7b
        dit_config = {}
        dit_config["image_model"] = "seedvr2"

View File

@@ -1287,7 +1287,7 @@ class Chroma(supported_models_base.BASE):
        pref = self.text_encoder_key_prefix[0]
        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))

class SeedVR2(supported_models_base.BASE):
    unet_config = {
        "image_model": "seedvr2"

View File

@@ -15,7 +15,7 @@ from torchvision.transforms.functional import InterpolationMode
@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
    gc.collect()
    torch.cuda.empty_cache()
@@ -23,7 +23,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
        x = x.unsqueeze(2)
    b, c, d, h, w = x.shape

    sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
    sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
@@ -32,7 +32,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
    ov_h, ov_w = tile_overlap
    ti_t = temporal_size
    ov_t = temporal_overlap

    target_d = (d + sf_t - 1) // sf_t
    target_h = (h + sf_s - 1) // sf_s
    target_w = (w + sf_s - 1) // sf_s
@@ -55,7 +55,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
    storage_device = torch.device("cpu")
    result = None
    count = None

    ramp_cache = {}
    def get_ramp(steps):
        if steps not in ramp_cache:
@@ -66,10 +66,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
    bar = ProgressBar(d // stride_t)
    for t_idx in range(0, d, stride_t):
        t_end = min(t_idx + ti_t, d)

        for y_idx in range(0, h, stride_h):
            y_end = min(y_idx + ti_h, h)

            for x_idx in range(0, w, stride_w):
                x_end = min(x_idx + ti_w, w)
@@ -94,7 +94,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
                ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
                xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]

                cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
                cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
                cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
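The index mapping in this hunk converts an input-space tile origin into latent coordinates via the downsample factors and clamps the blend overlap to at most half the tile. A worked example with assumed numbers:

sf_t, sf_s = 4, 8                      # assumed temporal / spatial downsample factors
t_idx, y_idx, x_idx = 12, 448, 448     # input-space tile origin
tile_out_shape = (1, 16, 4, 64, 64)    # b, c, d, h, w of the processed tile
ov_t, ov_h = 4, 64                     # input-space overlaps

ts, ys, xs = t_idx // sf_t, y_idx // sf_s, x_idx // sf_s
te, ye, xe = ts + tile_out_shape[2], ys + tile_out_shape[3], xs + tile_out_shape[4]
cur_ov_t = max(0, min(ov_t // sf_t, tile_out_shape[2] // 2))
cur_ov_h = max(0, min(ov_h // sf_s, tile_out_shape[3] // 2))
print((ts, te), (ys, ye), (xs, xe), (cur_ov_t, cur_ov_h))
# (3, 7) (56, 120) (56, 120) (1, 8)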
@@ -115,7 +115,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                    r = get_ramp(cur_ov_t)
                    if t_idx > 0: w_t[:cur_ov_t] = r
                    if t_end < d: w_t[-cur_ov_t:] = 1.0 - r

                if cur_ov_h > 0:
                    r = get_ramp(cur_ov_h)
                    if y_idx > 0: w_h[:cur_ov_h] = r
@@ -131,11 +131,11 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
                tile_out.mul_(final_weight)
                result[:, :, ts:te, ys:ye, xs:xe] += tile_out
                count[:, :, ts:te, ys:ye, xs:xe] += final_weight

                del tile_out, final_weight, tile_x, w_t, w_h, w_w

        bar.update(1)

    result.div_(count.clamp(min=1e-6))
    if result.device != x.device:
        result = result.to(x.device).to(x.dtype)
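The accumulate-and-normalise blend above is easiest to see in one dimension: each tile's weight ramps up over the leading overlap and down over the trailing one, weighted tiles and weights are summed separately, and the final division makes overlapping regions a smooth average. A 1-D sketch with assumed sizes:

import torch

def get_ramp(steps):
    # Stand-in for the cached ramp above: a linear 0 -> 1 fade of `steps` samples.
    return torch.linspace(0.0, 1.0, steps)

tile_len, overlap, total = 8, 4, 12
result = torch.zeros(total)
count = torch.zeros(total)
for start in (0, tile_len - overlap):            # two tiles overlapping by 4 samples
    tile = torch.ones(tile_len)                  # stand-in for a decoded tile
    w = torch.ones(tile_len)
    r = get_ramp(overlap)
    if start > 0:                                # fade in at the leading edge
        w[:overlap] = r
    if start + tile_len < total:                 # fade out at the trailing edge
        w[-overlap:] = 1.0 - r
    result[start:start + tile_len] += tile * w
    count[start:start + tile_len] += w
result.div_(count.clamp(min=1e-6))
print(result)                                    # all ones: the seams blend away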
@@ -238,7 +238,7 @@ def cut_videos(videos):
        videos = torch.cat([videos, padding], dim=1)
        assert (videos.size(1) - 1) % (4) == 0
        return videos

class SeedVR2InputProcessing(io.ComfyNode):
    @classmethod
    def define_schema(cls):
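cut_videos enforces the 4n+1 frame count the temporal VAE expects. A hypothetical standalone sketch of that padding step; how the real function builds `padding` is not shown in this hunk, so last-frame repetition is an assumption:

import torch

def pad_frames_to_4n_plus_1(videos):
    # Assumed approach: repeat the last frame until (T - 1) % 4 == 0,
    # matching the assertion and concatenation shown in cut_videos above.
    t = videos.size(1)
    missing = (-(t - 1)) % 4
    if missing:
        padding = videos[:, -1:].repeat(1, missing, *([1] * (videos.dim() - 2)))
        videos = torch.cat([videos, padding], dim=1)
    assert (videos.size(1) - 1) % 4 == 0
    return videos

v = torch.randn(1, 14, 3, 64, 64)               # 14 frames
print(pad_frames_to_4n_plus_1(v).shape[1])       # 17 = 4*4 + 1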
@@ -259,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
                io.Latent.Output("vae_conditioning")
            ]
        )

    @classmethod
    def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
        device = vae.patcher.load_device
@@ -271,7 +271,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
        scale = 0.9152; shift = 0
        if images.dim() != 5:  # add the t dim
            images = images.unsqueeze(0)

        images = images.permute(0, 1, 4, 2, 3)
        b, t, c, h, w = images.shape
        images = images.reshape(b * t, c, h, w)
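The execute path above reshapes ComfyUI image batches so all frames can be resized as one 2-D batch. A shape-only sketch with assumed dimensions:

import torch

images = torch.rand(16, 480, 832, 3)             # t, h, w, c as ComfyUI provides them
if images.dim() != 5:                            # add the leading batch dim
    images = images.unsqueeze(0)
images = images.permute(0, 1, 4, 2, 3)           # -> b, t, c, h, w
b, t, c, h, w = images.shape
images = images.reshape(b * t, c, h, w)          # fold batch and time for per-frame ops
print(images.shape)                              # torch.Size([16, 3, 480, 832])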
@@ -328,7 +328,7 @@ class SeedVR2Conditioning(io.ComfyNode):
    @classmethod
    def execute(cls, vae_conditioning, model) -> io.NodeOutput:
        vae_conditioning = vae_conditioning["samples"]

        device = vae_conditioning.device
        model = model.model.diffusion_model