mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-07 21:00:49 +08:00
ruff
This commit is contained in:
parent
e30298dda2
commit
5b0c80a093
@ -1,10 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
|
||||
import einops
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange
|
||||
import comfy.model_management
|
||||
from torch import nn
|
||||
import torch.nn.utils.rnn as rnn_utils
|
||||
import torch.nn.functional as F
|
||||
from math import ceil, pi
|
||||
import torch
|
||||
@ -15,7 +13,6 @@ from comfy.rmsnorm import RMSNorm
|
||||
from torch.nn.modules.utils import _triple
|
||||
from torch import nn
|
||||
import math
|
||||
import logging
|
||||
|
||||
class Cache:
|
||||
def __init__(self, disable=False, prefix="", cache=None):
|
||||
@ -126,7 +123,7 @@ def safe_pad_operation(x, padding, mode='constant', value=0.0):
|
||||
"""Safe padding operation that handles Half precision only for problematic modes"""
|
||||
# Modes qui nécessitent le fix Half precision
|
||||
problematic_modes = ['replicate', 'reflect', 'circular']
|
||||
|
||||
|
||||
if mode in problematic_modes:
|
||||
try:
|
||||
return F.pad(x, padding, mode=mode, value=value)
|
||||
@ -189,7 +186,7 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup
|
||||
resized_h, resized_w = round(h * scale), round(w * scale)
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
|
||||
wt = ceil(min(t, 30) / resized_nt) # window size.
|
||||
|
||||
|
||||
st, sh, sw = ( # shift size.
|
||||
0.5 if wt < t else 0,
|
||||
0.5 if wh < h else 0,
|
||||
@ -466,7 +463,7 @@ def apply_rotary_emb(
|
||||
|
||||
freqs = freqs.to(t_middle.device)
|
||||
t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
|
||||
|
||||
|
||||
out = torch.cat((t_left, t_transformed, t_right), dim=-1)
|
||||
|
||||
return out.type(dtype)
|
||||
@ -655,7 +652,7 @@ class NaSwinAttention(NaMMAttention):
|
||||
self,
|
||||
*args,
|
||||
window: Union[int, Tuple[int, int, int]],
|
||||
window_method: bool, # shifted or not
|
||||
window_method: bool, # shifted or not
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
@ -765,7 +762,7 @@ class NaSwinAttention(NaMMAttention):
|
||||
vid_out, txt_out = self.proj_out(vid_out, txt_out)
|
||||
|
||||
return vid_out, txt_out
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -1274,7 +1271,7 @@ class NaDiT(nn.Module):
|
||||
layers=["out"],
|
||||
modes=["in"],
|
||||
)
|
||||
|
||||
|
||||
self.stop_cfg_index = -1
|
||||
|
||||
def set_cfg_stop_index(self, cfg):
|
||||
@ -1290,7 +1287,7 @@ class NaDiT(nn.Module):
|
||||
context, # l c
|
||||
disable_cache: bool = False, # for test # TODO ? // gives an error when set to True
|
||||
**kwargs
|
||||
):
|
||||
):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
@ -317,26 +317,26 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
|
||||
"""Safe interpolate operation that handles Half precision for problematic modes"""
|
||||
# Modes qui peuvent causer des problèmes avec Half precision
|
||||
problematic_modes = ['bilinear', 'bicubic', 'trilinear']
|
||||
|
||||
|
||||
if mode in problematic_modes:
|
||||
try:
|
||||
return F.interpolate(
|
||||
x,
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
x,
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
align_corners=align_corners,
|
||||
recompute_scale_factor=recompute_scale_factor
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if ("not implemented for 'Half'" in str(e) or
|
||||
if ("not implemented for 'Half'" in str(e) or
|
||||
"compute_indices_weights" in str(e)):
|
||||
original_dtype = x.dtype
|
||||
return F.interpolate(
|
||||
x.float(),
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
x.float(),
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
align_corners=align_corners,
|
||||
recompute_scale_factor=recompute_scale_factor
|
||||
).to(original_dtype)
|
||||
@ -345,10 +345,10 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
|
||||
else:
|
||||
# Pour 'nearest' et autres modes compatibles, pas de fix nécessaire
|
||||
return F.interpolate(
|
||||
x,
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
x,
|
||||
size=size,
|
||||
scale_factor=scale_factor,
|
||||
mode=mode,
|
||||
align_corners=align_corners,
|
||||
recompute_scale_factor=recompute_scale_factor
|
||||
)
|
||||
@ -426,7 +426,7 @@ class Upsample3D(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.interpolate = interpolate
|
||||
self.interpolate = interpolate
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv_transpose = use_conv_transpose
|
||||
@ -444,7 +444,7 @@ class Upsample3D(nn.Module):
|
||||
if kernel_size is None:
|
||||
kernel_size = 3
|
||||
self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias)
|
||||
|
||||
|
||||
conv = self.conv if self.name == "conv" else self.Conv2d_0
|
||||
|
||||
assert type(conv) is not nn.ConvTranspose2d
|
||||
@ -587,7 +587,7 @@ class Downsample3D(nn.Module):
|
||||
kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
|
||||
stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
|
||||
)
|
||||
|
||||
|
||||
self.conv = conv
|
||||
|
||||
|
||||
@ -1565,7 +1565,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
latent = latent / scale + shift
|
||||
latent = rearrange(latent, "b ... c -> b c ...")
|
||||
latent = latent.squeeze(2)
|
||||
|
||||
|
||||
if latent.ndim == 4:
|
||||
latent = latent.unsqueeze(2)
|
||||
|
||||
|
||||
@ -815,7 +815,7 @@ class HunyuanDiT(BaseModel):
|
||||
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
||||
|
||||
class SeedVR2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
|
||||
|
||||
@ -445,7 +445,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
|
||||
return dit_config
|
||||
|
||||
|
||||
elif "{}blocks.36.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
|
||||
@ -1287,7 +1287,7 @@ class Chroma(supported_models_base.BASE):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
|
||||
class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
|
||||
@ -15,7 +15,7 @@ from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), temporal_size=16, temporal_overlap=4, encode=True):
|
||||
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -23,7 +23,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
x = x.unsqueeze(2)
|
||||
|
||||
b, c, d, h, w = x.shape
|
||||
|
||||
|
||||
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
|
||||
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
|
||||
|
||||
@ -32,7 +32,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
ov_h, ov_w = tile_overlap
|
||||
ti_t = temporal_size
|
||||
ov_t = temporal_overlap
|
||||
|
||||
|
||||
target_d = (d + sf_t - 1) // sf_t
|
||||
target_h = (h + sf_s - 1) // sf_s
|
||||
target_w = (w + sf_s - 1) // sf_s
|
||||
@ -55,7 +55,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
storage_device = torch.device("cpu")
|
||||
result = None
|
||||
count = None
|
||||
|
||||
|
||||
ramp_cache = {}
|
||||
def get_ramp(steps):
|
||||
if steps not in ramp_cache:
|
||||
@ -66,10 +66,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
bar = ProgressBar(d // stride_t)
|
||||
for t_idx in range(0, d, stride_t):
|
||||
t_end = min(t_idx + ti_t, d)
|
||||
|
||||
|
||||
for y_idx in range(0, h, stride_h):
|
||||
y_end = min(y_idx + ti_h, h)
|
||||
|
||||
|
||||
for x_idx in range(0, w, stride_w):
|
||||
x_end = min(x_idx + ti_w, w)
|
||||
|
||||
@ -94,7 +94,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
ts, te = t_idx // sf_t, (t_idx // sf_t) + tile_out.shape[2]
|
||||
ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3]
|
||||
xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4]
|
||||
|
||||
|
||||
cur_ov_t = max(0, min(ov_t // sf_t, tile_out.shape[2] // 2))
|
||||
cur_ov_h = max(0, min(ov_h // sf_s, tile_out.shape[3] // 2))
|
||||
cur_ov_w = max(0, min(ov_w // sf_s, tile_out.shape[4] // 2))
|
||||
@ -115,7 +115,7 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
r = get_ramp(cur_ov_t)
|
||||
if t_idx > 0: w_t[:cur_ov_t] = r
|
||||
if t_end < d: w_t[-cur_ov_t:] = 1.0 - r
|
||||
|
||||
|
||||
if cur_ov_h > 0:
|
||||
r = get_ramp(cur_ov_h)
|
||||
if y_idx > 0: w_h[:cur_ov_h] = r
|
||||
@ -131,11 +131,11 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
||||
tile_out.mul_(final_weight)
|
||||
result[:, :, ts:te, ys:ye, xs:xe] += tile_out
|
||||
count[:, :, ts:te, ys:ye, xs:xe] += final_weight
|
||||
|
||||
|
||||
del tile_out, final_weight, tile_x, w_t, w_h, w_w
|
||||
bar.update(1)
|
||||
result.div_(count.clamp(min=1e-6))
|
||||
|
||||
|
||||
if result.device != x.device:
|
||||
result = result.to(x.device).to(x.dtype)
|
||||
|
||||
@ -238,7 +238,7 @@ def cut_videos(videos):
|
||||
videos = torch.cat([videos, padding], dim=1)
|
||||
assert (videos.size(1) - 1) % (4) == 0
|
||||
return videos
|
||||
|
||||
|
||||
class SeedVR2InputProcessing(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -259,7 +259,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
io.Latent.Output("vae_conditioning")
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, vae, resolution_height, resolution_width, spatial_tile_size, temporal_tile_size, spatial_overlap, temporal_overlap):
|
||||
device = vae.patcher.load_device
|
||||
@ -271,7 +271,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
scale = 0.9152; shift = 0
|
||||
if images.dim() != 5: # add the t dim
|
||||
images = images.unsqueeze(0)
|
||||
images = images.permute(0, 1, 4, 2, 3)
|
||||
images = images.permute(0, 1, 4, 2, 3)
|
||||
|
||||
b, t, c, h, w = images.shape
|
||||
images = images.reshape(b * t, c, h, w)
|
||||
@ -328,7 +328,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, vae_conditioning, model) -> io.NodeOutput:
|
||||
|
||||
|
||||
vae_conditioning = vae_conditioning["samples"]
|
||||
device = vae_conditioning.device
|
||||
model = model.model.diffusion_model
|
||||
|
||||
Loading…
Reference in New Issue
Block a user