Yousef Rafat 2025-12-23 12:35:00 +02:00
parent e30298dda2
commit 5b0c80a093
6 changed files with 43 additions and 46 deletions

View File

@@ -1,10 +1,8 @@
from dataclasses import dataclass
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import einops
from einops import rearrange, repeat
from einops import rearrange
import comfy.model_management
from torch import nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F
from math import ceil, pi
import torch
@@ -15,7 +13,6 @@ from comfy.rmsnorm import RMSNorm
from torch.nn.modules.utils import _triple
from torch import nn
import math
import logging
class Cache:
def __init__(self, disable=False, prefix="", cache=None):
@@ -126,7 +123,7 @@ def safe_pad_operation(x, padding, mode='constant', value=0.0):
"""Safe padding operation that handles Half precision only for problematic modes"""
# Modes qui nécessitent le fix Half precision
problematic_modes = ['replicate', 'reflect', 'circular']
if mode in problematic_modes:
try:
return F.pad(x, padding, mode=mode, value=value)
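The except branch of safe_pad_operation is cut off by this hunk. A minimal sketch of the fallback it presumably performs, assuming it mirrors the float32 retry visible in safe_interpolate_operation below (pad_half_fallback is a hypothetical name):

import torch
import torch.nn.functional as F

def pad_half_fallback(x, padding, mode, value=0.0):
    # Assumption: retry in float32 when the mode is "not implemented for 'Half'",
    # then cast back to the original dtype, as safe_interpolate_operation does.
    original_dtype = x.dtype
    return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype)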
@@ -189,7 +186,7 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup
resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
wt = ceil(min(t, 30) / resized_nt) # window size.
st, sh, sw = ( # shift size.
0.5 if wt < t else 0,
0.5 if wh < h else 0,
@@ -466,7 +463,7 @@ def apply_rotary_emb(
freqs = freqs.to(t_middle.device)
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
return out.type(dtype)
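rotate_half is defined outside the shown hunks. For readability, here is the conventional RoPE helper as a sketch, assuming this file uses the standard definition:

import torch

def rotate_half(x):
    # Standard RoPE rotation (assumed): split the last dim and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)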
@@ -655,7 +652,7 @@ class NaSwinAttention(NaMMAttention):
self,
*args,
window: Union[int, Tuple[int, int, int]],
window_method: bool,  # shifted or not
**kwargs,
):
super().__init__(*args, **kwargs)
@@ -765,7 +762,7 @@ class NaSwinAttention(NaMMAttention):
vid_out, txt_out = self.proj_out(vid_out, txt_out)
return vid_out, txt_out
class MLP(nn.Module):
def __init__(
self,
@@ -1274,7 +1271,7 @@ class NaDiT(nn.Module):
layers=["out"],
modes=["in"],
)
self.stop_cfg_index = -1
def set_cfg_stop_index(self, cfg):
@@ -1290,7 +1287,7 @@ class NaDiT(nn.Module):
context, # l c
disable_cache: bool = False,  # for testing; TODO: currently raises an error when set to True
**kwargs
):
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})

View File

@@ -317,26 +317,26 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
"""Safe interpolate operation that handles Half precision for problematic modes"""
# Modes qui peuvent causer des problèmes avec Half precision
problematic_modes = ['bilinear', 'bicubic', 'trilinear']
if mode in problematic_modes:
try:
return F.interpolate(
    x,
    size=size,
    scale_factor=scale_factor,
    mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor
)
except RuntimeError as e:
if ("not implemented for 'Half'" in str(e) or
if ("not implemented for 'Half'" in str(e) or
"compute_indices_weights" in str(e)):
original_dtype = x.dtype
return F.interpolate(
    x.float(),
    size=size,
    scale_factor=scale_factor,
    mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor
).to(original_dtype)
@@ -345,10 +345,10 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
else:
# For 'nearest' and other compatible modes, no fix is needed
return F.interpolate(
    x,
    size=size,
    scale_factor=scale_factor,
    mode=mode,
align_corners=align_corners,
recompute_scale_factor=recompute_scale_factor
)
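A hedged usage sketch of safe_interpolate_operation, assuming align_corners and recompute_scale_factor default to None in the full signature:

import torch

x = torch.randn(1, 3, 32, 32, dtype=torch.float16)
# 'bilinear' is a problematic mode, so the float32 fallback may engage on CPU;
# either path preserves the input dtype.
y = safe_interpolate_operation(x, scale_factor=2.0, mode='bilinear', align_corners=False)
assert y.shape == (1, 3, 64, 64) and y.dtype == torch.float16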
@@ -426,7 +426,7 @@ class Upsample3D(nn.Module):
**kwargs,
):
super().__init__()
self.interpolate = interpolate
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv_transpose = use_conv_transpose
@@ -444,7 +444,7 @@ class Upsample3D(nn.Module):
if kernel_size is None:
kernel_size = 3
self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
conv = self.conv if self.name == "conv" else self.Conv2d_0
assert type(conv) is not nn.ConvTranspose2d
@@ -587,7 +587,7 @@ class Downsample3D(nn.Module):
kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
)
self.conv = conv
@@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
latent = latent / scale + shift
latent = rearrange(latent, "b ... c -> b c ...")
latent = latent.squeeze(2)
if latent.ndim == 4:
latent = latent.unsqueeze(2)
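Read together, the squeeze/unsqueeze pair guarantees a 5-D latent: a singleton temporal axis is dropped and, when the result is 4-D, a temporal axis is restored. A worked shape trace under that reading (values illustrative):

import torch

lat = torch.randn(1, 16, 1, 32, 32)   # b c t h w after the rearrange above
lat = lat.squeeze(2)                  # singleton t removed -> (1, 16, 32, 32)
if lat.ndim == 4:
    lat = lat.unsqueeze(2)            # t restored -> (1, 16, 1, 32, 32)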

View File

@@ -815,7 +815,7 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)

View File

@@ -445,7 +445,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["pad_tokens_multiple"] = 32
return dit_config
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"

View File

@@ -1287,7 +1287,7 @@ class Chroma(supported_models_base.BASE):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"

View File

@@ -15,7 +15,7 @@ from torchvision.transforms.functional import InterpolationMode
@torch.inference_mode()
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
gc.collect()
torch.cuda.empty_cache()
@@ -23,7 +23,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
x = x.unsqueeze(2)
b, c, d, h, w = x.shape
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
@@ -32,7 +32,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
ov_h, ov_w = tile_overlap
ti_t = temporal_size
ov_t = temporal_overlap
target_d = (d + sf_t - 1) // sf_t
target_h = (h + sf_s - 1) // sf_s
target_w = (w + sf_s - 1) // sf_s
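The target extents use the standard integer ceil-division identity, ceil(n / k) == (n + k - 1) // k. A quick numeric check:

d, sf_t = 17, 4
assert (d + sf_t - 1) // sf_t == 5   # ceil(17 / 4) == 5 latent frames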
@@ -55,7 +55,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
storage_device = torch.device("cpu")
result = None
count = None
ramp_cache = {}
def get_ramp(steps):
if steps not in ramp_cache:
@@ -66,10 +66,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
bar = ProgressBar(d // stride_t)
for t_idx in range(0, d, stride_t):
t_end = min(t_idx + ti_t, d)
for y_idx in range(0, h, stride_h):
y_end = min(y_idx + ti_h, h)
for x_idx in range(0, w, stride_w):
x_end = min(x_idx + ti_w, w)
@@ -94,7 +94,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
@@ -115,7 +115,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
r = get_ramp(cur_ov_t)
if t_idx > 0: w_t[:cur_ov_t] = r
if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
if cur_ov_h > 0:
r = get_ramp(cur_ov_h)
if y_idx > 0: w_h[:cur_ov_h] = r
@@ -131,11 +131,11 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
tile_out.mul_(final_weight)
result[:, :, ts:te, ys:ye, xs:xe] += tile_out
count[:, :, ts:te, ys:ye, xs:xe] += final_weight
del tile_out, final_weight, tile_x, w_t, w_h, w_w
bar.update(1)
result.div_(count.clamp(min=1e-6))
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
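get_ramp itself falls outside the shown hunks. A minimal 1-D sketch of the cross-fade it supports, assuming a linear 0-to-1 ramp, which is consistent with the paired r and 1.0 - r assignments above (linear_ramp is a hypothetical stand-in):

import torch

def linear_ramp(steps):
    # Assumption: get_ramp likely caches a vector like this per overlap size.
    return torch.linspace(0.0, 1.0, steps)

tile_len, ov = 16, 4
w = torch.ones(tile_len)
r = linear_ramp(ov)
w[:ov] = r           # fade in over the overlap with the previous tile
w[-ov:] = 1.0 - r    # fade out over the overlap with the next tile
# Each tile accumulates tile * w into result and w into count;
# result / count then blends the seams, as the code above does per axis.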
@@ -238,7 +238,7 @@ def cut_videos(videos):
videos = torch.cat([videos, padding], dim=1)
assert (videos.size(1) - 1) % (4) == 0
return videos
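The assert requires the frame count t to satisfy (t - 1) % 4 == 0, i.e. t in {1, 5, 9, ...}, presumably matching the temporal downsample factor of 4. The padding amount that achieves this, as a quick check:

t = 7                            # illustrative input frame count
pad = (4 - (t - 1) % 4) % 4      # frames to append: 2
assert (t + pad - 1) % 4 == 0    # 7 + 2 = 9 satisfies the assert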
class SeedVR2InputProcessing(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -259,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
io.Latent.Output("vae_conditioning")
]
)
@classmethod
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
device = vae.patcher.load_device
@@ -271,7 +271,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
scale = 0.9152; shift = 0
if images.dim() != 5: # add the t dim
images = images.unsqueeze(0)
images = images.permute(0, 1, 4, 2, 3)
b, t, c, h, w = images.shape
images = images.reshape(b * t, c, h, w)
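A worked shape trace for the permute and reshape above, assuming channels-last frames as ComfyUI image batches provide:

import torch

imgs = torch.randn(1, 5, 256, 256, 3)    # b t h w c
imgs = imgs.permute(0, 1, 4, 2, 3)       # -> (1, 5, 3, 256, 256), i.e. b t c h w
imgs = imgs.reshape(1 * 5, 3, 256, 256)  # -> (5, 3, 256, 256): frames batched for the VAE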
@@ -328,7 +328,7 @@ class SeedVR2Conditioning(io.ComfyNode):
@classmethod
def execute(cls, vae_conditioning, model) -> io.NodeOutput:
vae_conditioning = vae_conditioning["samples"]
device = vae_conditioning.device
model = model.model.diffusion_model