mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Cleanups using AGENTS.md
This commit is contained in:
parent
e595965392
commit
f437d87155
@ -60,14 +60,7 @@ def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *a
|
||||
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||
out_dtype = q_i.dtype
|
||||
if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
|
||||
q_i = q_i.to(torch.bfloat16)
|
||||
k_i = k_i.to(torch.bfloat16)
|
||||
v_i = v_i.to(torch.bfloat16)
|
||||
out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
|
||||
if out_i.dtype != out_dtype:
|
||||
out_i = out_i.to(out_dtype)
|
||||
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||
|
||||
out = torch.cat(out, dim=0)
|
||||
|
||||
@ -2,8 +2,6 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.vae import safe_interpolate_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
@ -28,7 +26,7 @@ def wavelet_blur(image: Tensor, radius):
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
|
||||
|
||||
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
|
||||
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
|
||||
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
|
||||
|
||||
return output
|
||||
@ -49,8 +47,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
if content_feat.shape != style_feat.shape:
|
||||
# Resize style to match content spatial dimensions
|
||||
if len(content_feat.shape) >= 3:
|
||||
# safe_interpolate_operation handles FP16 conversion automatically
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat = F.interpolate(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
@ -65,7 +62,7 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
del style_high_freq # Free memory immediately
|
||||
|
||||
if content_high_freq.shape != style_low_freq.shape:
|
||||
style_low_freq = safe_interpolate_operation(
|
||||
style_low_freq = F.interpolate(
|
||||
style_low_freq,
|
||||
size=content_high_freq.shape[-2:],
|
||||
mode='bilinear',
|
||||
@ -227,7 +224,7 @@ def lab_color_transfer(
|
||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat = F.interpolate(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
@ -308,7 +305,7 @@ def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat = F.interpolate(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union, List, Dict, Any, Callable
|
||||
import einops
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
from math import ceil, pi
|
||||
import torch
|
||||
@ -23,52 +21,6 @@ from comfy.ldm.seedvr.constants import (
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
|
||||
)
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
|
||||
def _torch_float8_types():
|
||||
return tuple(
|
||||
getattr(torch, name)
|
||||
for name in (
|
||||
"float8_e4m3fn",
|
||||
"float8_e4m3fnuz",
|
||||
"float8_e5m2",
|
||||
"float8_e5m2fnuz",
|
||||
"float8_e8m0fnu",
|
||||
)
|
||||
if hasattr(torch, name)
|
||||
)
|
||||
|
||||
class CustomRMSNorm(nn.Module):
    """Root-mean-square normalization over the trailing ``normalized_shape`` dims.

    The statistics are computed in float32 and the result is cast back to the
    dtype of the input tensor.

    Args:
        normalized_shape: an int or an iterable of ints giving the sizes of the
            trailing dimensions to normalize over.
        eps: value added to the mean of squares before the square root.
        elementwise_affine: when true, a learnable ``weight`` initialised to
            ones is multiplied into the normalized output; otherwise ``weight``
            is registered as ``None``.
        device: device passed to the ``weight`` allocation.
        dtype: dtype passed to the ``weight`` allocation.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
        super().__init__()

        # Accept a bare int the same way as a one-element shape.
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype))
        else:
            self.register_parameter('weight', None)

    def forward(self, input):
        """Normalize ``input`` by its RMS over the trailing dims and return it in ``input.dtype``."""
        # Negative indices of the trailing dimensions covered by normalized_shape.
        dims = tuple(range(-len(self.normalized_shape), 0))

        # Norm statistics in fp32 (fp16 variance underflows); activations return
        # in the input dtype so downstream linears run at the model compute dtype.
        normalized = input.float()
        variance = normalized.pow(2).mean(dim=dims, keepdim=True)
        rms = torch.sqrt(variance + self.eps)

        normalized = normalized / rms

        if self.elementwise_affine:
            return (normalized * self.weight.to(torch.float32)).to(input.dtype)
        return normalized.to(input.dtype)
|
||||
|
||||
class Cache:
|
||||
def __init__(self, disable=False, prefix="", cache=None):
|
||||
@ -81,12 +33,10 @@ class Cache:
|
||||
return fn()
|
||||
|
||||
key = self.prefix + key
|
||||
try:
|
||||
result = self.cache[key]
|
||||
except KeyError:
|
||||
if key not in self.cache:
|
||||
result = fn()
|
||||
self.cache[key] = result
|
||||
return result
|
||||
return self.cache[key]
|
||||
|
||||
def namespace(self, namespace: str):
|
||||
return Cache(
|
||||
@ -144,15 +94,6 @@ class MMArg:
|
||||
vid: Any
|
||||
txt: Any
|
||||
|
||||
def safe_pad_operation(x, padding, mode='constant', value=0.0):
    """Pad ``x`` with ``F.pad``, retrying in float32 for unsupported half dtypes.

    Args:
        x: tensor to pad.
        padding: padding sizes, forwarded unchanged to ``F.pad``.
        mode: padding mode, forwarded unchanged to ``F.pad``.
        value: fill value, forwarded unchanged to ``F.pad``.

    Returns:
        The padded tensor. On the fallback path the result is cast back to the
        dtype of ``x``.

    Raises:
        RuntimeError: re-raised unchanged when the failure is not a
            "not implemented for" error on a float16/bfloat16 input.
    """
    try:
        return F.pad(x, padding, mode=mode, value=value)
    except RuntimeError as e:
        # Only a missing half-precision kernel is recoverable: redo the pad in
        # float32 and cast back. Any other runtime error propagates.
        if "not implemented for" in str(e) and x.dtype in (torch.float16, torch.bfloat16):
            return F.pad(x.float(), padding, mode=mode, value=value).to(x.dtype)
        raise
|
||||
|
||||
|
||||
def get_args(key: str, args: List[Any]) -> List[Any]:
    """Resolve one named field out of every ``MMArg`` in ``args``.

    Entries that are ``MMArg`` instances are replaced by their ``key``
    attribute; every other entry is passed through unchanged.

    Args:
        key: attribute name to read from each ``MMArg``.
        args: positional arguments, possibly mixing ``MMArg`` and plain values.

    Returns:
        A new list of the same length as ``args``.
    """
    return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
|
||||
|
||||
@ -235,8 +176,6 @@ class RotaryEmbedding(nn.Module):
|
||||
theta = 10000,
|
||||
max_freq = 10,
|
||||
learned_freq = False,
|
||||
cache_if_possible = True,
|
||||
cache_max_seq_len = 8192
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -247,12 +186,6 @@ class RotaryEmbedding(nn.Module):
|
||||
elif freqs_for == 'pixel':
|
||||
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
||||
|
||||
self.cache_if_possible = cache_if_possible
|
||||
self.cache_max_seq_len = cache_max_seq_len
|
||||
|
||||
self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
|
||||
self.cached_freqs_seq_len = 0
|
||||
|
||||
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
|
||||
|
||||
self.learned_freq = learned_freq
|
||||
@ -310,29 +243,10 @@ class RotaryEmbedding(nn.Module):
|
||||
seq_len: int | None = None,
|
||||
offset = 0
|
||||
):
|
||||
should_cache = (
|
||||
self.cache_if_possible and
|
||||
not self.learned_freq and
|
||||
exists(seq_len) and
|
||||
self.freqs_for != 'pixel' and
|
||||
(offset + seq_len) <= self.cache_max_seq_len
|
||||
)
|
||||
|
||||
if (
|
||||
should_cache and \
|
||||
exists(self.cached_freqs) and \
|
||||
(offset + seq_len) <= self.cached_freqs_seq_len
|
||||
):
|
||||
return self.cached_freqs[offset:(offset + seq_len)].detach()
|
||||
|
||||
freqs = self.freqs
|
||||
|
||||
freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
|
||||
freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2)
|
||||
|
||||
if should_cache and offset == 0:
|
||||
self.cached_freqs[:seq_len] = freqs.detach()
|
||||
self.cached_freqs_seq_len = seq_len
|
||||
freqs = freqs.unsqueeze(-1).expand(*freqs.shape, 2).flatten(-2)
|
||||
|
||||
return freqs
|
||||
|
||||
@ -346,7 +260,7 @@ class RotaryEmbeddingBase(nn.Module):
|
||||
)
|
||||
freqs = self.rope.freqs
|
||||
del self.rope.freqs
|
||||
self.rope.register_buffer("freqs", freqs.data)
|
||||
self.rope.register_buffer("freqs", freqs.detach())
|
||||
|
||||
def get_axial_freqs(self, *dims):
|
||||
return self.rope.get_axial_freqs(*dims)
|
||||
@ -371,12 +285,12 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
|
||||
]:
|
||||
freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape))
|
||||
freqs = freqs.to(device=q.device)
|
||||
q = rearrange(q, "L h d -> h L d")
|
||||
k = rearrange(k, "L h d -> h L d")
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype)
|
||||
k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype)
|
||||
q = rearrange(q, "h L d -> L h d")
|
||||
k = rearrange(k, "h L d -> L h d")
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
return q, k
|
||||
|
||||
@torch._dynamo.disable
|
||||
@ -407,11 +321,10 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
|
||||
dim=dim // rope_dim,
|
||||
freqs_for="lang",
|
||||
theta=ROPE_THETA,
|
||||
cache_if_possible=False,
|
||||
)
|
||||
freqs = self.rope.freqs
|
||||
del self.rope.freqs
|
||||
self.rope.register_buffer("freqs", freqs.data)
|
||||
self.rope.register_buffer("freqs", freqs.detach())
|
||||
self.mm = True
|
||||
|
||||
def slice_at_dim(t, dim_slice: slice, *, dim):
|
||||
@ -423,10 +336,10 @@ def slice_at_dim(t, dim_slice: slice, *, dim):
|
||||
# rotary embedding helper functions
|
||||
|
||||
def rotate_half(x):
    """Rotate each adjacent pair of the last dimension: ``(a, b) -> (-b, a)``.

    The last dimension is viewed as pairs, each pair is rotated, and the pairs
    are flattened back so the output has the same shape as the input.

    Args:
        x: tensor whose last dimension has even length.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Split the last dim into (d, 2) pairs. Only one reshape is kept here: the
    # garbled text held two equivalent variants (einops and plain torch)
    # back to back, and applying both would corrupt the shape.
    x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
    x1, x2 = x.unbind(dim = -1)
    x = torch.stack((-x2, x1), dim = -1)
    # Merge the (d, 2) pairs back into a single trailing dimension.
    return x.flatten(-2)
|
||||
def exists(val):
    """Return ``True`` when ``val`` is not ``None`` (falsy values such as 0 still count)."""
    return val is not None
|
||||
|
||||
@ -465,7 +378,7 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
out = torch.stack([cos, -sin, sin, cos], dim=-1)
|
||||
return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2)
|
||||
return out.reshape(*out.shape[:-1], 2, 2)
|
||||
|
||||
|
||||
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
@ -516,19 +429,19 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
||||
vid_freqs = vid_freqs.to(target_device)
|
||||
if txt_freqs.device != target_device:
|
||||
txt_freqs = txt_freqs.to(target_device)
|
||||
vid_q = rearrange(vid_q, "L h d -> h L d")
|
||||
vid_k = rearrange(vid_k, "L h d -> h L d")
|
||||
vid_q = vid_q.transpose(0, 1)
|
||||
vid_k = vid_k.transpose(0, 1)
|
||||
vid_q = _apply_rope1_partial(vid_q, vid_freqs)
|
||||
vid_k = _apply_rope1_partial(vid_k, vid_freqs)
|
||||
vid_q = rearrange(vid_q, "h L d -> L h d")
|
||||
vid_k = rearrange(vid_k, "h L d -> L h d")
|
||||
vid_q = vid_q.transpose(0, 1)
|
||||
vid_k = vid_k.transpose(0, 1)
|
||||
|
||||
txt_q = rearrange(txt_q, "L h d -> h L d")
|
||||
txt_k = rearrange(txt_k, "L h d -> h L d")
|
||||
txt_q = txt_q.transpose(0, 1)
|
||||
txt_k = txt_k.transpose(0, 1)
|
||||
txt_q = _apply_rope1_partial(txt_q, txt_freqs)
|
||||
txt_k = _apply_rope1_partial(txt_k, txt_freqs)
|
||||
txt_q = rearrange(txt_q, "h L d -> L h d")
|
||||
txt_k = rearrange(txt_k, "h L d -> L h d")
|
||||
txt_q = txt_q.transpose(0, 1)
|
||||
txt_k = txt_k.transpose(0, 1)
|
||||
return vid_q, vid_k, txt_q, txt_k
|
||||
|
||||
@torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks
|
||||
@ -684,7 +597,7 @@ def window(
|
||||
):
|
||||
hid = unflatten(hid, hid_shape)
|
||||
hid = list(map(window_fn, hid))
|
||||
hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
|
||||
hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device)
|
||||
hid, hid_shape = flatten(list(chain(*hid)))
|
||||
return hid, hid_shape, hid_windows
|
||||
|
||||
@ -747,8 +660,8 @@ class NaSwinAttention(NaMMAttention):
|
||||
)
|
||||
vid_qkv_win = window_partition(vid_qkv)
|
||||
|
||||
vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
|
||||
txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
|
||||
vid_qkv_win = vid_qkv_win.reshape(vid_qkv_win.shape[0], 3, self.heads, self.head_dim)
|
||||
txt_qkv = txt_qkv.reshape(txt_qkv.shape[0], 3, self.heads, self.head_dim)
|
||||
|
||||
vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
|
||||
txt_q, txt_k, txt_v = txt_qkv.unbind(1)
|
||||
@ -768,19 +681,19 @@ class NaSwinAttention(NaMMAttention):
|
||||
elif self.rope.mm:
|
||||
# repeat text q and k for window mmrope
|
||||
_, num_h, _ = txt_q.shape
|
||||
txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)")
|
||||
txt_q_repeat = txt_q.flatten(1, 2)
|
||||
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
|
||||
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
|
||||
txt_q_repeat = list(chain(*txt_q_repeat))
|
||||
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
|
||||
txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h)
|
||||
txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim)
|
||||
|
||||
txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)")
|
||||
txt_k_repeat = txt_k.flatten(1, 2)
|
||||
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
|
||||
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
|
||||
txt_k_repeat = list(chain(*txt_k_repeat))
|
||||
txt_k_repeat, _ = flatten(txt_k_repeat)
|
||||
txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h)
|
||||
txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim)
|
||||
|
||||
vid_q, vid_k, txt_q, txt_k = self.rope(
|
||||
vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win
|
||||
@ -799,16 +712,16 @@ class NaSwinAttention(NaMMAttention):
|
||||
v=concat_win(vid_v, txt_v),
|
||||
heads=self.heads, skip_reshape=True, skip_output_reshape=True,
|
||||
cu_seqlens_q=cache_win(
|
||||
"vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||
"vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
cu_seqlens_k=cache_win(
|
||||
"vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int()
|
||||
"vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
|
||||
),
|
||||
)
|
||||
vid_out, txt_out = unconcat_win(out)
|
||||
|
||||
vid_out = rearrange(vid_out, "l h d -> l (h d)")
|
||||
txt_out = rearrange(txt_out, "l h d -> l (h d)")
|
||||
vid_out = vid_out.flatten(1, 2)
|
||||
txt_out = txt_out.flatten(1, 2)
|
||||
vid_out = window_reverse(vid_out)
|
||||
|
||||
vid_out, txt_out = self.proj_out(vid_out, txt_out)
|
||||
@ -1005,7 +918,9 @@ class PatchOut(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
t, h, w = self.patch_size
|
||||
vid = self.proj(vid)
|
||||
vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
|
||||
b, T, H, W, channels = vid.shape
|
||||
c = channels // (t * h * w)
|
||||
vid = vid.view(b, T, H, W, t, h, w, c).permute(0, 7, 1, 4, 2, 5, 3, 6).reshape(b, c, T * t, H * h, W * w)
|
||||
if t > 1:
|
||||
vid = vid[:, :, (t - 1) :]
|
||||
return vid
|
||||
@ -1015,7 +930,7 @@ class NaPatchOut(PatchOut):
|
||||
self,
|
||||
vid: torch.FloatTensor, # l c
|
||||
vid_shape: torch.LongTensor,
|
||||
cache: Cache = Cache(disable=True), # for test
|
||||
cache: Cache = Cache(disable=True),
|
||||
vid_shape_before_patchify = None
|
||||
) -> Tuple[
|
||||
torch.FloatTensor,
|
||||
@ -1028,7 +943,9 @@ class NaPatchOut(PatchOut):
|
||||
if not (t == h == w == 1):
|
||||
vid = unflatten(vid, vid_shape)
|
||||
for i in range(len(vid)):
|
||||
vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w)
|
||||
T, H, W, channels = vid[i].shape
|
||||
c = channels // (t * h * w)
|
||||
vid[i] = vid[i].view(T, H, W, t, h, w, c).permute(0, 3, 1, 4, 2, 5, 6).reshape(T * t, H * h, W * w, c)
|
||||
if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
|
||||
vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :]
|
||||
vid, vid_shape = flatten(vid)
|
||||
@ -1056,7 +973,8 @@ class PatchIn(nn.Module):
|
||||
if t > 1:
|
||||
assert vid.size(2) % t == 1
|
||||
vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
|
||||
vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
|
||||
b, c, Tt, Hh, Ww = vid.shape
|
||||
vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c)
|
||||
vid = self.proj(vid)
|
||||
return vid
|
||||
|
||||
@ -1065,7 +983,7 @@ class NaPatchIn(PatchIn):
|
||||
self,
|
||||
vid: torch.Tensor, # l c
|
||||
vid_shape: torch.LongTensor,
|
||||
cache: Cache = Cache(disable=True), # for test
|
||||
cache: Cache = Cache(disable=True),
|
||||
) -> torch.Tensor:
|
||||
cache = cache.namespace("patch")
|
||||
vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
|
||||
@ -1075,7 +993,8 @@ class NaPatchIn(PatchIn):
|
||||
for i in range(len(vid)):
|
||||
if t > 1 and vid_shape_before_patchify[i, 0] % t != 0:
|
||||
vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0)
|
||||
vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w)
|
||||
Tt, Hh, Ww, c = vid[i].shape
|
||||
vid[i] = vid[i].view(Tt // t, t, Hh // h, h, Ww // w, w, c).permute(0, 2, 4, 1, 3, 5, 6).reshape(Tt // t, Hh // h, Ww // w, t * h * w * c)
|
||||
vid, vid_shape = flatten(vid)
|
||||
|
||||
vid = self.proj(vid)
|
||||
@ -1102,17 +1021,14 @@ class AdaSingle(nn.Module):
|
||||
self.emb_dim = emb_dim
|
||||
self.layers = layers
|
||||
|
||||
param_kwargs = {"device": device}
|
||||
fp8_types = _torch_float8_types()
|
||||
if dtype is not None and dtype not in fp8_types:
|
||||
param_kwargs["dtype"] = dtype
|
||||
param_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
for l in layers:
|
||||
if "in" in modes:
|
||||
self.register_parameter(f"{l}_shift", nn.Parameter(torch.zeros(dim, **param_kwargs)))
|
||||
self.register_parameter(f"{l}_scale", nn.Parameter(torch.ones(dim, **param_kwargs)))
|
||||
self.register_parameter(f"{l}_shift", nn.Parameter(torch.empty(dim, **param_kwargs)))
|
||||
self.register_parameter(f"{l}_scale", nn.Parameter(torch.empty(dim, **param_kwargs)))
|
||||
if "out" in modes:
|
||||
self.register_parameter(f"{l}_gate", nn.Parameter(torch.zeros(dim, **param_kwargs)))
|
||||
self.register_parameter(f"{l}_gate", nn.Parameter(torch.empty(dim, **param_kwargs)))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -1125,7 +1041,7 @@ class AdaSingle(nn.Module):
|
||||
hid_len: Optional[torch.LongTensor] = None, # b
|
||||
) -> torch.FloatTensor:
|
||||
idx = self.layers.index(layer)
|
||||
emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
|
||||
emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :]
|
||||
emb = expand_dims(emb, 1, hid.ndim + 1)
|
||||
|
||||
if hid_len is not None:
|
||||
@ -1145,17 +1061,6 @@ class AdaSingle(nn.Module):
|
||||
getattr(self, f"{layer}_gate", None),
|
||||
)
|
||||
|
||||
fp8_types = _torch_float8_types()
|
||||
if fp8_types:
|
||||
target_dtype = hid.dtype
|
||||
|
||||
if shiftB is not None and shiftB.dtype in fp8_types:
|
||||
shiftB = shiftB.to(target_dtype)
|
||||
if scaleB is not None and scaleB.dtype in fp8_types:
|
||||
scaleB = scaleB.to(target_dtype)
|
||||
if gateB is not None and gateB.dtype in fp8_types:
|
||||
gateB = gateB.to(target_dtype)
|
||||
|
||||
if mode == "in":
|
||||
return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
|
||||
if mode == "out":
|
||||
@ -1213,7 +1118,7 @@ def flatten(
|
||||
torch.LongTensor, # (b n)
|
||||
]:
|
||||
assert len(hid) > 0
|
||||
shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
|
||||
shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device)
|
||||
hid = torch.cat([x.flatten(0, -2) for x in hid])
|
||||
return hid, shape
|
||||
|
||||
@ -1227,19 +1132,6 @@ def unflatten(
|
||||
hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
|
||||
return hid
|
||||
|
||||
def repeat(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: torch.LongTensor,  # each value: (b)
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply ``einops.repeat`` to each sample of a flattened batch.

    ``hid`` is split into per-sample tensors with ``unflatten``, each sample is
    repeated with ``pattern`` using its own axis sizes, and the results are
    packed again with ``flatten``.

    Args:
        hid: flattened batch, as accepted by ``unflatten``.
        hid_shape: per-sample shapes, as accepted by ``unflatten``.
        pattern: einops pattern applied to every sample.
        **kwargs: axis sizes for ``pattern``. Each value is indexed by sample
            position and converted with ``.item()``, so it must hold one
            entry per sample.

    Returns:
        Whatever ``flatten`` returns for the repeated samples (annotated as a
        tensor plus a shape tensor).
    """
    hid = unflatten(hid, hid_shape)
    # Turn {axis: per-sample tensor} into one plain {axis: int} dict per sample.
    per_sample_sizes = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
    return flatten([einops.repeat(h, pattern, **sizes) for h, sizes in zip(hid, per_sample_sizes)])
|
||||
|
||||
class NaDiT(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -1275,23 +1167,11 @@ class NaDiT(nn.Module):
|
||||
emb_dim = vid_dim * 6
|
||||
window = num_layers * [(4,3,3)]
|
||||
ada = AdaSingle
|
||||
norm = CustomRMSNorm
|
||||
qk_norm = CustomRMSNorm
|
||||
norm = operations.RMSNorm
|
||||
qk_norm = operations.RMSNorm
|
||||
super().__init__()
|
||||
# ``torch.empty`` returns uninitialized memory, not zeros. The
|
||||
# SeedVR2Conditioning fail-loud guard at
|
||||
# ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded"
|
||||
# from "buffer was never populated by the file" by checking
|
||||
# ``positive_conditioning.abs().sum() == 0``. That sentinel is only
|
||||
# reliable if the post-construction buffer state is deterministically
|
||||
# zero, so explicitly zero-fill here rather than relying on the
|
||||
# allocator's zero-on-alloc behavior (allocator-dependent and not
|
||||
# contractual). When ``load_state_dict`` populates these buffers
|
||||
# from a properly-baked SeedVR2 .safetensors, the in-place copy
|
||||
# overwrites the zeros with the universal SeedVR2 conditioning
|
||||
# tensors (shape (58, 5120) and (64, 5120) bf16).
|
||||
self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype))
|
||||
self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype))
|
||||
self.register_buffer("positive_conditioning", torch.empty((58, 5120), device=device, dtype=dtype))
|
||||
self.register_buffer("negative_conditioning", torch.empty((64, 5120), device=device, dtype=dtype))
|
||||
self.vid_in = NaPatchIn(
|
||||
in_channels=vid_in_channels,
|
||||
patch_size=patch_size,
|
||||
@ -1354,7 +1234,7 @@ class NaDiT(nn.Module):
|
||||
|
||||
self.vid_out_norm = None
|
||||
if vid_out_norm is not None:
|
||||
self.vid_out_norm = CustomRMSNorm(
|
||||
self.vid_out_norm = operations.RMSNorm(
|
||||
normalized_shape=vid_dim,
|
||||
eps=norm_eps,
|
||||
elementwise_affine=True,
|
||||
@ -1369,7 +1249,7 @@ class NaDiT(nn.Module):
|
||||
)
|
||||
|
||||
def _resolve_text_conditioning(self, context, cond_or_uncond=None):
|
||||
if context is None or getattr(context, "numel", lambda: None)() == 0:
|
||||
if context is None or context.numel() == 0:
|
||||
context = self.positive_conditioning
|
||||
return flatten([context])
|
||||
if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
|
||||
@ -1407,7 +1287,7 @@ class NaDiT(nn.Module):
|
||||
x,
|
||||
timestep,
|
||||
context, # l c
|
||||
disable_cache: bool = False, # for test # TODO ? // gives an error when set to True
|
||||
disable_cache: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
@ -1483,5 +1363,5 @@ class NaDiT(nn.Module):
|
||||
vid = unflatten(vid, vid_shape)
|
||||
out = torch.stack(vid)
|
||||
out = out.movedim(-1, 1)
|
||||
out = rearrange(out, "b c t h w -> b (c t) h w")
|
||||
out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4])
|
||||
return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond"))
|
||||
|
||||
@ -1,15 +1,11 @@
|
||||
from contextlib import nullcontext
|
||||
from typing import Literal, Optional, Tuple
|
||||
import gc
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from contextlib import contextmanager
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
BYTEDANCE_BLOCK_OUT_CHANNELS,
|
||||
BYTEDANCE_GN_CHUNKS_FP16,
|
||||
@ -58,13 +54,6 @@ def _seedvr2_clamped_spatial_overlap(overlap, tile_size):
|
||||
return min(overlap, tile_size - 1)
|
||||
|
||||
|
||||
def _seedvr2_clear_temporal_memory(model):
|
||||
for module in model.modules():
|
||||
if hasattr(module, "memory"):
|
||||
module.memory = None
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def tiled_vae(
|
||||
x,
|
||||
vae_model,
|
||||
@ -75,10 +64,6 @@ def tiled_vae(
|
||||
encode=True,
|
||||
**kwargs,
|
||||
):
|
||||
gc.collect()
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
x = x.to(next(vae_model.parameters()).dtype)
|
||||
if x.ndim != 5:
|
||||
x = x.unsqueeze(2)
|
||||
|
||||
@ -121,7 +106,6 @@ def tiled_vae(
|
||||
count = None
|
||||
def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device):
|
||||
device = torch.device(device)
|
||||
_seedvr2_clear_temporal_memory(model)
|
||||
t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous()
|
||||
old_device = getattr(model, "device", None)
|
||||
model.device = device
|
||||
@ -133,7 +117,7 @@ def tiled_vae(
|
||||
setattr(model, slicing_attr, slicing_min_size)
|
||||
try:
|
||||
if encode:
|
||||
out = model.encode(t_chunk)[0]
|
||||
out = model.encode(t_chunk)
|
||||
else:
|
||||
out = model.decode_(t_chunk)
|
||||
finally:
|
||||
@ -141,8 +125,6 @@ def tiled_vae(
|
||||
setattr(model, slicing_attr, old_slicing_min_size)
|
||||
if old_device is not None:
|
||||
model.device = old_device
|
||||
if isinstance(out, (tuple, list)):
|
||||
out = out[0]
|
||||
if out.ndim == 4:
|
||||
out = out.unsqueeze(2)
|
||||
return out.to(storage_device)
|
||||
@ -169,8 +151,6 @@ def tiled_vae(
|
||||
bar = ProgressBar(total_tiles)
|
||||
single_spatial_tile = h <= ti_h and w <= ti_w
|
||||
|
||||
_seedvr2_clear_temporal_memory(vae_model)
|
||||
|
||||
def run_tile(tile_index, tile_range):
|
||||
y_idx, y_end, x_idx, x_end = tile_range
|
||||
tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end]
|
||||
@ -186,7 +166,6 @@ def tiled_vae(
|
||||
|
||||
if single_spatial_tile:
|
||||
result = tile_out[:, :, :target_d, :target_h, :target_w]
|
||||
_seedvr2_clear_temporal_memory(vae_model)
|
||||
if result.device != x.device:
|
||||
result = result.to(x.device).to(x.dtype)
|
||||
if x.shape[2] == 1 and sf_t == 1:
|
||||
@ -241,7 +220,6 @@ def tiled_vae(
|
||||
bar.update(1)
|
||||
|
||||
result.div_(count.clamp(min=1e-6))
|
||||
_seedvr2_clear_temporal_memory(vae_model)
|
||||
|
||||
if result.device != x.device:
|
||||
result = result.to(x.device).to(x.dtype)
|
||||
@ -336,7 +314,6 @@ class Attention(nn.Module):
|
||||
eps: float = 1e-5,
|
||||
rescale_output_factor: float = 1.0,
|
||||
residual_connection: bool = False,
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
out_dim: int = None,
|
||||
pre_only=False,
|
||||
):
|
||||
@ -356,10 +333,6 @@ class Attention(nn.Module):
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.pre_only = pre_only
|
||||
|
||||
# we make use of this private variable to know whether this class is loaded
|
||||
# with a deprecated state dict so that we can convert it on the fly
|
||||
self._from_deprecated_attn_block = _from_deprecated_attn_block
|
||||
|
||||
self.scale_qk = scale_qk
|
||||
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
|
||||
|
||||
@ -480,21 +453,21 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = x.dtype
|
||||
if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)):
|
||||
if x.ndim == 4:
|
||||
x = rearrange(x, "b c h w -> b h w c")
|
||||
x = x.permute(0, 2, 3, 1)
|
||||
x = norm_layer(x)
|
||||
x = rearrange(x, "b h w c -> b c h w")
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return x.to(input_dtype)
|
||||
if x.ndim == 5:
|
||||
x = rearrange(x, "b c t h w -> b t h w c")
|
||||
x = x.permute(0, 2, 3, 4, 1)
|
||||
x = norm_layer(x)
|
||||
x = rearrange(x, "b t h w c -> b c t h w")
|
||||
x = x.permute(0, 4, 1, 2, 3)
|
||||
return x.to(input_dtype)
|
||||
if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
|
||||
if x.ndim <= 4:
|
||||
return norm_layer(x).to(input_dtype)
|
||||
if x.ndim == 5:
|
||||
t = x.size(2)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.transpose(1, 2).reshape(b * t, c, h, w)
|
||||
memory_occupy = x.numel() * x.element_size() / 1024**3
|
||||
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
|
||||
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
|
||||
@ -504,54 +477,16 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
x = list(x.chunk(num_chunks, dim=1))
|
||||
weights = norm_layer.weight.chunk(num_chunks, dim=0)
|
||||
biases = norm_layer.bias.chunk(num_chunks, dim=0)
|
||||
for i, (w, b) in enumerate(zip(weights, biases)):
|
||||
x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps)
|
||||
for i, (w, bias) in enumerate(zip(weights, biases)):
|
||||
x[i] = F.group_norm(x[i], num_groups_per_chunk, w, bias, norm_layer.eps)
|
||||
x[i] = x[i].to(input_dtype)
|
||||
x = torch.cat(x, dim=1)
|
||||
else:
|
||||
x = norm_layer(x)
|
||||
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
||||
x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2)
|
||||
return x.to(input_dtype)
|
||||
raise NotImplementedError
|
||||
|
||||
def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None):
    """Call ``F.interpolate``, retrying in float32 when a half kernel is missing.

    All arguments are forwarded unchanged to ``F.interpolate``. For the
    ``'bilinear'``, ``'bicubic'`` and ``'trilinear'`` modes, a ``RuntimeError``
    whose message contains ``"not implemented for 'Half'"`` or
    ``"compute_indices_weights"`` triggers a second attempt on ``x.float()``,
    and the result is cast back to the original dtype of ``x``.

    Returns:
        The interpolated tensor.

    Raises:
        RuntimeError: re-raised unchanged when it does not match either of the
            messages above, or when ``mode`` is not one of the three modes.
    """
    interpolate_kwargs = dict(
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
    )

    # For 'nearest' and the other compatible modes, no workaround is needed.
    if mode not in ('bilinear', 'bicubic', 'trilinear'):
        return F.interpolate(x, **interpolate_kwargs)

    try:
        return F.interpolate(x, **interpolate_kwargs)
    except RuntimeError as e:
        message = str(e)
        if "not implemented for 'Half'" not in message and "compute_indices_weights" not in message:
            # Unrelated failure: propagate with the original traceback.
            raise
        original_dtype = x.dtype
        return F.interpolate(x.float(), **interpolate_kwargs).to(original_dtype)
|
||||
|
||||
_receptive_field_t = Literal["half", "full"]
|
||||
|
||||
def extend_head(tensor, times: int = 2, memory = None):
|
||||
@ -585,7 +520,6 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
**kwargs,
|
||||
):
|
||||
self.inflation_mode = inflation_mode
|
||||
self.memory = None
|
||||
super().__init__(*args, **kwargs)
|
||||
self.temporal_padding = self.padding[0]
|
||||
self.padding = (0, *self.padding[1:])
|
||||
@ -620,18 +554,19 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
return super().forward(x)
|
||||
|
||||
# Compute tensor shape after concat & padding.
|
||||
shape = torch.tensor(x.size())
|
||||
shape = list(x.size())
|
||||
if prev_cache is not None:
|
||||
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
|
||||
shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0)
|
||||
memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB
|
||||
for i, pad_sum in enumerate((padding[4] + padding[5], padding[2] + padding[3], padding[0] + padding[1])):
|
||||
shape[-3 + i] += pad_sum
|
||||
memory_occupy = math.prod(shape) * x.element_size() / 1024**3 # GiB
|
||||
if memory_occupy < self.memory_limit or split_dim == x.ndim:
|
||||
x_concat = x
|
||||
if prev_cache is not None:
|
||||
x_concat = torch.cat([prev_cache, x], dim=split_dim - 1)
|
||||
|
||||
def pad_and_forward():
|
||||
padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0)
|
||||
padded = F.pad(x_concat, padding, mode='constant', value=0.0)
|
||||
if not padded.is_contiguous():
|
||||
padded = padded.contiguous()
|
||||
with ignore_padding(self):
|
||||
@ -689,46 +624,57 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
def forward(
|
||||
self,
|
||||
input,
|
||||
memory_state: MemoryState = MemoryState.UNSET
|
||||
memory_state: MemoryState = MemoryState.UNSET,
|
||||
memory_cache = None,
|
||||
) -> Tensor:
|
||||
assert memory_state != MemoryState.UNSET
|
||||
if memory_cache is None:
|
||||
memory_cache = {}
|
||||
if memory_state != MemoryState.ACTIVE:
|
||||
self.memory = None
|
||||
memory_cache.pop(self, None)
|
||||
if (
|
||||
math.isinf(self.memory_limit)
|
||||
and torch.is_tensor(input)
|
||||
):
|
||||
return self.basic_forward(input, memory_state)
|
||||
return self.slicing_forward(input, memory_state)
|
||||
return self.basic_forward(input, memory_state, memory_cache)
|
||||
return self.slicing_forward(input, memory_state, memory_cache)
|
||||
|
||||
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET):
|
||||
def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET, memory_cache = None):
|
||||
mem_size = self.stride[0] - self.kernel_size[0]
|
||||
if (self.memory is not None) and (memory_state == MemoryState.ACTIVE):
|
||||
input = extend_head(input, memory=self.memory, times=-1)
|
||||
memory = memory_cache.get(self) if memory_cache is not None else None
|
||||
if (memory is not None) and (memory_state == MemoryState.ACTIVE):
|
||||
input = extend_head(input, memory=memory, times=-1)
|
||||
else:
|
||||
input = extend_head(input, times=self.temporal_padding * 2)
|
||||
memory = (
|
||||
next_memory = (
|
||||
input[:, :, mem_size:].detach()
|
||||
if (mem_size != 0 and memory_state != MemoryState.DISABLED)
|
||||
else None
|
||||
)
|
||||
if memory_state != MemoryState.DISABLED:
|
||||
self.memory = memory
|
||||
if memory_cache is not None and memory_state != MemoryState.DISABLED:
|
||||
if next_memory is None:
|
||||
memory_cache.pop(self, None)
|
||||
else:
|
||||
memory_cache[self] = next_memory
|
||||
return super().forward(input)
|
||||
|
||||
def slicing_forward(
|
||||
self,
|
||||
input,
|
||||
memory_state: MemoryState = MemoryState.UNSET,
|
||||
memory_cache = None,
|
||||
) -> Tensor:
|
||||
if memory_cache is None:
|
||||
memory_cache = {}
|
||||
squeeze_out = False
|
||||
if torch.is_tensor(input):
|
||||
input = [input]
|
||||
squeeze_out = True
|
||||
|
||||
cache_size = self.kernel_size[0] - self.stride[0]
|
||||
memory = memory_cache.get(self) if memory_cache is not None else None
|
||||
cache = cache_send_recv(
|
||||
input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2
|
||||
input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2
|
||||
)
|
||||
|
||||
# Single GPU inference - simplified memory management
|
||||
@ -740,7 +686,7 @@ class InflatedCausalConv3d(ops.Conv3d):
|
||||
input[0] = torch.cat([cache, input[0]], dim=2)
|
||||
cache = None
|
||||
if cache_size <= input[-1].size(2):
|
||||
self.memory = input[-1][:, :, -cache_size:].detach().contiguous()
|
||||
memory_cache[self] = input[-1][:, :, -cache_size:].detach().contiguous()
|
||||
|
||||
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
|
||||
for i in range(len(input)):
|
||||
@ -802,17 +748,10 @@ class Upsample3D(nn.Module):
|
||||
self.temporal_ratio = 2 if temporal_up else 1
|
||||
self.spatial_ratio = 2 if spatial_up else 1
|
||||
|
||||
# [Override] MAGViT v2 learnable upsample
|
||||
upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
|
||||
self.upscale_conv = ops.Conv3d(
|
||||
self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
|
||||
)
|
||||
identity = (
|
||||
torch.eye(self.channels)
|
||||
.repeat(upscale_ratio, 1)
|
||||
.reshape_as(self.upscale_conv.weight)
|
||||
)
|
||||
self.upscale_conv.weight.data.copy_(identity)
|
||||
|
||||
self.conv = conv
|
||||
|
||||
@ -820,23 +759,27 @@ class Upsample3D(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state=None,
|
||||
memory_cache=None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
hidden_states = self.upscale_conv(hidden_states)
|
||||
hidden_states = rearrange(
|
||||
hidden_states,
|
||||
"b (x y z c) f h w -> b c (f z) (h x) (w y)",
|
||||
x=self.spatial_ratio,
|
||||
y=self.spatial_ratio,
|
||||
z=self.temporal_ratio,
|
||||
b, channels, f, h, w = hidden_states.shape
|
||||
c = channels // (self.spatial_ratio * self.spatial_ratio * self.temporal_ratio)
|
||||
hidden_states = hidden_states.view(b, self.spatial_ratio, self.spatial_ratio, self.temporal_ratio, c, f, h, w)
|
||||
hidden_states = hidden_states.permute(0, 4, 5, 3, 6, 1, 7, 2).reshape(
|
||||
b,
|
||||
c,
|
||||
f * self.temporal_ratio,
|
||||
h * self.spatial_ratio,
|
||||
w * self.spatial_ratio,
|
||||
)
|
||||
|
||||
if self.temporal_up and memory_state != MemoryState.ACTIVE:
|
||||
hidden_states = remove_head(hidden_states)
|
||||
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state)
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -879,6 +822,7 @@ class Downsample3D(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
@ -890,11 +834,11 @@ class Downsample3D(nn.Module):
|
||||
|
||||
if self.spatial_down:
|
||||
pad = (0, 1, 0, 1)
|
||||
hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0)
|
||||
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
|
||||
|
||||
assert hidden_states.shape[1] == self.channels
|
||||
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state)
|
||||
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -962,7 +906,7 @@ class ResnetBlock3D(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_tensor, temb, memory_state = None, **kwargs
|
||||
self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs
|
||||
):
|
||||
hidden_states = input_tensor
|
||||
|
||||
@ -970,7 +914,7 @@ class ResnetBlock3D(nn.Module):
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, memory_state=memory_state)
|
||||
hidden_states = self.conv1(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
if self.time_emb_proj is not None:
|
||||
if not self.skip_time_act:
|
||||
@ -985,10 +929,10 @@ class ResnetBlock3D(nn.Module):
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states, memory_state=memory_state)
|
||||
hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)
|
||||
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
|
||||
|
||||
@ -1055,15 +999,16 @@ class DownEncoderBlock3D(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, temporal in zip(self.resnets, self.temporal_modules):
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
|
||||
hidden_states = temporal(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, memory_state=memory_state)
|
||||
hidden_states = downsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -1132,15 +1077,16 @@ class UpDecoderBlock3D(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
memory_state=None
|
||||
memory_state=None,
|
||||
memory_cache=None,
|
||||
) -> torch.FloatTensor:
|
||||
for resnet, temporal in zip(self.resnets, self.temporal_modules):
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state)
|
||||
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
|
||||
hidden_states = temporal(hidden_states)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, memory_state=memory_state)
|
||||
hidden_states = upsampler(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -1203,7 +1149,6 @@ class UNetMidBlock3D(nn.Module):
|
||||
residual_connection=True,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
_from_deprecated_attn_block=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@ -1226,17 +1171,16 @@ class UNetMidBlock3D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, memory_state=None):
|
||||
def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None):
|
||||
video_length, frame_height, frame_width = hidden_states.size()[-3:]
|
||||
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state)
|
||||
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if attn is not None:
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
|
||||
b, c, f, h, w = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(b * f, c, h, w)
|
||||
hidden_states = attn(hidden_states, temb=temb)
|
||||
hidden_states = rearrange(
|
||||
hidden_states, "(b f) c h w -> b c f h w", f=video_length
|
||||
)
|
||||
hidden_states = resnet(hidden_states, temb, memory_state=memory_state)
|
||||
hidden_states = hidden_states.reshape(b, video_length, c, h, w).transpose(1, 2)
|
||||
hidden_states = resnet(hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@ -1327,22 +1271,23 @@ class Encoder3D(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
memory_state = None
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
sample = sample.to(next(self.parameters()).device)
|
||||
sample = self.conv_in(sample, memory_state = memory_state)
|
||||
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample, memory_state=memory_state)
|
||||
sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
# middle
|
||||
sample = self.mid_block(sample, memory_state=memory_state)
|
||||
sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
# post-process
|
||||
sample = causal_norm_wrapper(self.conv_norm_out, sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, memory_state = memory_state)
|
||||
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return sample
|
||||
|
||||
@ -1436,24 +1381,25 @@ class Decoder3D(nn.Module):
|
||||
sample: torch.FloatTensor,
|
||||
latent_embeds: Optional[torch.FloatTensor] = None,
|
||||
memory_state = None,
|
||||
memory_cache = None,
|
||||
) -> torch.FloatTensor:
|
||||
|
||||
sample = sample.to(next(self.parameters()).device)
|
||||
sample = self.conv_in(sample, memory_state=memory_state)
|
||||
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
# middle
|
||||
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state)
|
||||
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
|
||||
sample = sample.to(upscale_dtype)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = up_block(sample, latent_embeds, memory_state=memory_state)
|
||||
sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
# post-process
|
||||
sample = causal_norm_wrapper(self.conv_norm_out, sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, memory_state=memory_state)
|
||||
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
|
||||
|
||||
return sample
|
||||
|
||||
@ -1529,22 +1475,23 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return decoded
|
||||
|
||||
def _encode(
|
||||
self, x, memory_state = MemoryState.DISABLED
|
||||
self, x, memory_state = MemoryState.DISABLED, memory_cache = None
|
||||
) -> torch.Tensor:
|
||||
_x = x.to(self.device)
|
||||
h = self.encoder(_x, memory_state=memory_state)
|
||||
h = self.encoder(_x, memory_state=memory_state, memory_cache=memory_cache)
|
||||
return h.to(x.device)
|
||||
|
||||
def _decode(
|
||||
self, z, memory_state = MemoryState.DISABLED
|
||||
self, z, memory_state = MemoryState.DISABLED, memory_cache = None
|
||||
) -> torch.Tensor:
|
||||
_z = z.to(self.device)
|
||||
output = self.decoder(_z, memory_state=memory_state)
|
||||
output = self.decoder(_z, memory_state=memory_state, memory_cache=memory_cache)
|
||||
return output.to(z.device)
|
||||
|
||||
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
sp_size =1
|
||||
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
|
||||
memory_cache = {}
|
||||
split_size = max(
|
||||
self.slicing_sample_min_size * sp_size,
|
||||
getattr(self, "temporal_downsample_factor", 1),
|
||||
@ -1558,17 +1505,14 @@ class VideoAutoencoderKL(nn.Module):
|
||||
self._encode(
|
||||
torch.cat((x[:, :, :1], x_slices[0]), dim=2),
|
||||
memory_state=MemoryState.INITIALIZING,
|
||||
memory_cache=memory_cache,
|
||||
)
|
||||
]
|
||||
for x_idx in range(1, len(x_slices)):
|
||||
encoded_slices.append(
|
||||
self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
|
||||
self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache)
|
||||
)
|
||||
out = torch.cat(encoded_slices, dim=2)
|
||||
modules_with_memory = [m for m in self.modules()
|
||||
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
|
||||
for m in modules_with_memory:
|
||||
m.memory = None
|
||||
return out
|
||||
else:
|
||||
return self._encode(x)
|
||||
@ -1576,22 +1520,20 @@ class VideoAutoencoderKL(nn.Module):
|
||||
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
sp_size = 1
|
||||
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
|
||||
memory_cache = {}
|
||||
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
|
||||
decoded_slices = [
|
||||
self._decode(
|
||||
torch.cat((z[:, :, :1], z_slices[0]), dim=2),
|
||||
memory_state=MemoryState.INITIALIZING
|
||||
memory_state=MemoryState.INITIALIZING,
|
||||
memory_cache=memory_cache,
|
||||
)
|
||||
]
|
||||
for z_idx in range(1, len(z_slices)):
|
||||
decoded_slices.append(
|
||||
self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
|
||||
self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE, memory_cache=memory_cache)
|
||||
)
|
||||
out = torch.cat(decoded_slices, dim=2)
|
||||
modules_with_memory = [m for m in self.modules()
|
||||
if isinstance(m, InflatedCausalConv3d) and m.memory is not None]
|
||||
for m in modules_with_memory:
|
||||
m.memory = None
|
||||
return out
|
||||
else:
|
||||
return self._decode(z)
|
||||
@ -1612,32 +1554,25 @@ class VideoAutoencoderKL(nn.Module):
|
||||
return _unwrap(self.decode_(latent))
|
||||
|
||||
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
# Signals to comfy.sd.VAE that this model performs its own VAE tiling, so the
|
||||
# generic tiled-decode/encode dispatch defers to decode_tiled/encode_tiled below.
|
||||
comfy_handles_tiling = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
spatial_downsample_factor = 8,
|
||||
temporal_downsample_factor = 4,
|
||||
freeze_encoder = True,
|
||||
**kwargs,
|
||||
):
|
||||
self.spatial_downsample_factor = spatial_downsample_factor
|
||||
self.temporal_downsample_factor = temporal_downsample_factor
|
||||
self.freeze_encoder = freeze_encoder
|
||||
self.enable_tiling = False
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
|
||||
|
||||
def forward(self, x: torch.FloatTensor):
|
||||
with torch.no_grad() if self.freeze_encoder else nullcontext():
|
||||
z, p = self.encode(x)
|
||||
z, p = self._encode_with_raw_latent(x)
|
||||
x = self.decode(z)
|
||||
return x, z, p
|
||||
|
||||
def encode(self, x, orig_dims=None):
|
||||
def _encode_with_raw_latent(self, x):
|
||||
if x.ndim == 4:
|
||||
x = x.unsqueeze(2)
|
||||
x = x.to(dtype=next(self.parameters()).dtype)
|
||||
@ -1646,6 +1581,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
z = p.squeeze(2)
|
||||
return z, p
|
||||
|
||||
def encode(self, x, orig_dims=None):
|
||||
z, _ = self._encode_with_raw_latent(x)
|
||||
return z
|
||||
|
||||
def decode(self, z, seedvr2_tiling=None):
|
||||
seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling
|
||||
if not isinstance(seedvr2_tiling, dict):
|
||||
|
||||
@ -1151,9 +1151,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return unet_config
|
||||
|
||||
def model_config_from_unet_config(unet_config, state_dict=None):
|
||||
def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""):
|
||||
for model_config in comfy.supported_models.models:
|
||||
if model_config.matches(unet_config, state_dict):
|
||||
if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix):
|
||||
return model_config(unet_config)
|
||||
|
||||
logging.error("no match {}".format(unet_config))
|
||||
@ -1163,7 +1163,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
|
||||
if unet_config is None:
|
||||
return None
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict)
|
||||
model_config = model_config_from_unet_config(unet_config, state_dict, unet_key_prefix)
|
||||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
|
||||
41
comfy/sd.py
41
comfy/sd.py
@ -1,4 +1,3 @@
|
||||
import inspect
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -500,6 +499,8 @@ class VAE:
|
||||
self.upscale_index_formula = None
|
||||
self.extra_1d_channel = None
|
||||
self.crop_input = True
|
||||
self.handles_tiling = False
|
||||
self.format_encoded = None
|
||||
|
||||
self.audio_sample_rate = 44100
|
||||
|
||||
@ -554,6 +555,8 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
|
||||
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.handles_tiling = True
|
||||
self.format_encoded = self.first_stage_model.comfy_format_encoded
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
@ -1118,7 +1121,7 @@ class VAE:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
if self.handles_tiling:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
@ -1127,7 +1130,7 @@ class VAE:
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
if self.handles_tiling:
|
||||
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
@ -1149,7 +1152,7 @@ class VAE:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3):
|
||||
if self.handles_tiling and dims in (2, 3):
|
||||
tiled_args = {}
|
||||
if tile_x is not None:
|
||||
tiled_args["tile_x"] = tile_x
|
||||
@ -1204,8 +1207,6 @@ class VAE:
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
@ -1225,7 +1226,7 @@ class VAE:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
if self.handles_tiling:
|
||||
samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
@ -1234,9 +1235,8 @@ class VAE:
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
|
||||
if formatter is not None:
|
||||
samples = formatter(samples)
|
||||
if self.format_encoded is not None:
|
||||
samples = self.format_encoded(samples)
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
@ -1268,7 +1268,7 @@ class VAE:
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||
if self.handles_tiling:
|
||||
tiled_args = {}
|
||||
if tile_x is not None:
|
||||
tiled_args["tile_x"] = tile_x
|
||||
@ -1298,9 +1298,8 @@ class VAE:
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
|
||||
if formatter is not None:
|
||||
samples = formatter(samples)
|
||||
if self.format_encoded is not None:
|
||||
samples = self.format_encoded(samples)
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
@ -1852,16 +1851,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
    """Call ``model_config.set_inference_dtype``, passing ``device`` only when it is accepted.

    The target's signature is inspected: ``device=device`` is added when the
    callable declares a keyword-capable ``device`` parameter or takes
    ``**kwargs``. Otherwise the two-argument form is used.

    Args:
        model_config: Object exposing a ``set_inference_dtype`` callable.
        dtype: First positional argument forwarded to the callable.
        manual_cast_dtype: Second positional argument forwarded to the callable.
        device: Forwarded as the ``device`` keyword when supported.
    """
    set_dtype = model_config.set_inference_dtype
    try:
        parameters = tuple(inspect.signature(set_dtype).parameters.values())
    except (TypeError, ValueError):
        # inspect.signature cannot describe every callable; without a
        # signature, stay on the two-argument call.
        parameters = ()

    # A positional-only or *args parameter that merely happens to be named
    # "device" cannot receive device=... as a keyword, so it must not count.
    keyword_capable = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    supports_device = any(
        p.kind == inspect.Parameter.VAR_KEYWORD
        or (p.name == "device" and p.kind in keyword_capable)
        for p in parameters
    )

    if supports_device:
        set_dtype(dtype, manual_cast_dtype, device=device)
    else:
        set_dtype(dtype, manual_cast_dtype)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
@ -1969,7 +1958,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
@ -2110,7 +2099,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
@ -1688,6 +1688,10 @@ class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
required_keys = {
|
||||
"{}positive_conditioning",
|
||||
"{}negative_conditioning",
|
||||
}
|
||||
latent_format = comfy.latent_formats.SeedVR2
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
|
||||
@ -54,13 +54,13 @@ class BASE:
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
def matches(s, unet_config, state_dict=None):
|
||||
def matches(s, unet_config, state_dict=None, unet_key_prefix=""):
|
||||
for k in s.unet_config:
|
||||
if k not in unet_config or s.unet_config[k] != unet_config[k]:
|
||||
return False
|
||||
if state_dict is not None:
|
||||
for k in s.required_keys:
|
||||
if k not in state_dict:
|
||||
if k.format(unet_key_prefix) not in state_dict:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from comfy_api.latest import ComfyExtension, io
|
||||
import torch
|
||||
import math
|
||||
import logging
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.sample
|
||||
@ -101,14 +100,6 @@ def _resolve_seedvr2_diffusion_model(model):
|
||||
return diffusion_model
|
||||
|
||||
|
||||
def _apply_rope_freqs_float32_cast(diffusion_model):
    """Cast every submodule's ``rope.freqs`` tensor to float32 in place.

    The dtype of each tensor is checked on every call rather than recording
    a "done" flag on the model. Per the original author's note, this lets the
    cast self-correct after Comfy's unload/reload, which would otherwise
    restore the archived fp16/bf16 dtype.

    Args:
        diffusion_model: Object whose ``modules()`` yields the submodules to
            scan; those without a ``rope.freqs`` tensor are left untouched.
    """
    for module in diffusion_model.modules():
        rope = getattr(module, 'rope', None)
        freqs = getattr(rope, 'freqs', None)
        # Skip modules without a rope, and ropes whose `freqs` is absent or
        # not a tensor (e.g. None), instead of failing on `.data`.
        if not torch.is_tensor(freqs):
            continue
        if freqs.data.dtype != torch.float32:
            # Rebinding `.data` keeps the same parameter/buffer object.
            freqs.data = freqs.data.to(torch.float32)
|
||||
|
||||
|
||||
def get_conditions(latent, latent_blur):
|
||||
t, h, w, c = latent.shape
|
||||
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
|
||||
@ -193,7 +184,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
|
||||
|
||||
images = images.reshape(b, t, c, new_h, new_w)
|
||||
images = cut_videos(images)
|
||||
images_bthwc = rearrange(images, "b t c h w -> b t h w c")
|
||||
images_bthwc = images.permute(0, 1, 3, 4, 2).contiguous()
|
||||
|
||||
return io.NodeOutput(images_bthwc)
|
||||
|
||||
@ -265,12 +256,12 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
output_device = decoded_5d.device
|
||||
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
|
||||
reference_raw = cls._to_seedvr2_raw(reference_5d)
|
||||
decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w")
|
||||
reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w")
|
||||
decoded_flat = decoded_raw.permute(0, 1, 4, 2, 3).reshape(b * t, decoded_raw.shape[4], target_h, target_w)
|
||||
reference_flat = reference_raw.permute(0, 1, 4, 2, 3).reshape(b * t, reference_raw.shape[4], target_h, target_w)
|
||||
output = cls._color_transfer_chunked(
|
||||
decoded_flat, reference_flat, output_device, color_correction_method,
|
||||
)
|
||||
output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t)
|
||||
output = output.reshape(b, t, output.shape[1], output.shape[2], output.shape[3]).permute(0, 1, 3, 4, 2)
|
||||
output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||
elif color_correction_method == "none":
|
||||
output = decoded_5d
|
||||
@ -359,7 +350,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
) from e
|
||||
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
|
||||
|
||||
comfy.model_management.soft_empty_cache()
|
||||
chunk_size = next_chunk_size
|
||||
|
||||
@classmethod
|
||||
@ -419,14 +409,14 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
if reference.shape[2] == height and reference.shape[3] == width:
|
||||
return reference
|
||||
b, t = reference.shape[:2]
|
||||
reference_flat = rearrange(reference, "b t h w c -> (b t) c h w")
|
||||
reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(b * t, reference.shape[4], reference.shape[2], reference.shape[3])
|
||||
resized = TVF.resize(
|
||||
reference_flat,
|
||||
size=(height, width),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"),
|
||||
)
|
||||
return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t)
|
||||
return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2)
|
||||
|
||||
|
||||
class SeedVR2Conditioning(io.ComfyNode):
|
||||
@ -471,39 +461,12 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
pos_cond = model.positive_conditioning
|
||||
neg_cond = model.negative_conditioning
|
||||
|
||||
# Fail-loud guard against silently-wrong output when a
|
||||
# DiT-only ``.safetensors`` (no ``positive_conditioning`` /
|
||||
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
|
||||
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
|
||||
# ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)``
|
||||
# leaves them at zero when the keys are absent. Detect that state
|
||||
# here rather than at ``BaseModel.extra_conds`` (per sampling step,
|
||||
# wasteful) or at the resolver helper (mixes structural shape with
|
||||
# semantic content). Both buffers must be checked together — partial
|
||||
# bake regressions could populate one but not the other.
|
||||
if (
|
||||
pos_cond.float().abs().sum().item() == 0
|
||||
and neg_cond.float().abs().sum().item() == 0
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
|
||||
f"and negative_conditioning buffers are zero-valued — model "
|
||||
f"file appears to be a DiT-only export missing "
|
||||
f"the SeedVR2 conditioning tensors. "
|
||||
f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
|
||||
f"and ``negative_conditioning`` (64, 5120) keys at top level, "
|
||||
f"or load via CheckpointLoaderSimple from a bundled "
|
||||
f"checkpoint."
|
||||
)
|
||||
|
||||
_apply_rope_freqs_float32_cast(model)
|
||||
|
||||
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning])
|
||||
condition = condition.movedim(-1, 1)
|
||||
latent = vae_conditioning.movedim(-1, 1)
|
||||
|
||||
latent = rearrange(latent, "b c t h w -> b (c t) h w")
|
||||
condition = rearrange(condition, "b c t h w -> b (c t) h w")
|
||||
latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4])
|
||||
condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4])
|
||||
|
||||
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
|
||||
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
|
||||
@ -723,7 +686,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that
|
||||
OOM on long sequences. The latent enters the sampler in SeedVR2's
|
||||
collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning``
|
||||
at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that
|
||||
at ``reshape(b, c * t, h, w)``); this node slices that
|
||||
tensor along the temporal axis, runs the configured inner sampler
|
||||
sequentially per chunk against the standard ``comfy.sample.sample``
|
||||
entry point, and concatenates per-chunk outputs back into a single
|
||||
@ -882,7 +845,6 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
"frames_per_chunk=%s.",
|
||||
attempt_frames_per_chunk, attempts[i + 1],
|
||||
)
|
||||
comfy.model_management.soft_empty_cache()
|
||||
|
||||
# Short-circuit: total fits in one chunk -> standard path with no
|
||||
# chunking overhead. Output of this branch is byte-identical to the
|
||||
|
||||
@ -11,7 +11,6 @@ import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -53,7 +52,7 @@ def _import_nodes_seedvr_isolated():
|
||||
mock_mm.WINDOWS = False
|
||||
sys.modules["comfy.model_management"] = mock_mm
|
||||
if sys.modules.get("comfy") is None:
|
||||
import comfy as _comfy_pkg # noqa: F401
|
||||
importlib.import_module("comfy")
|
||||
comfy_pkg = sys.modules.get("comfy")
|
||||
if comfy_pkg is not None:
|
||||
setattr(comfy_pkg, "model_management", mock_mm)
|
||||
@ -95,11 +94,10 @@ class _Block(nn.Module):
|
||||
|
||||
class _DiffusionModel(nn.Module):
|
||||
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
|
||||
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
|
||||
def __init__(self, n_blocks=3, conditioning_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
|
||||
pos = torch.zeros if zero_conditioning else torch.ones
|
||||
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
|
||||
self.register_buffer("positive_conditioning", torch.ones((2, 4), dtype=conditioning_dtype))
|
||||
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
|
||||
|
||||
|
||||
@ -185,29 +183,3 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||
)
|
||||
finally:
|
||||
restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||
patcher = _ModelPatcher(diffusion_model)
|
||||
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||
patcher, vae_conditioning,
|
||||
)
|
||||
|
||||
message = str(excinfo.value)
|
||||
assert message.startswith(
|
||||
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
|
||||
), (
|
||||
"Fail-loud message must use the standard "
|
||||
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
|
||||
f"can match it. Got: {message!r}"
|
||||
)
|
||||
assert "positive_conditioning" in message
|
||||
assert "negative_conditioning" in message
|
||||
finally:
|
||||
restore()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
@ -32,26 +33,19 @@ def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypa
|
||||
|
||||
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
|
||||
|
||||
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||
try:
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
|
||||
decoded, reference, torch.device("cpu"), "lab",
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
assert "color_correction_method=lab" in str(exc)
|
||||
assert " method=lab" not in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
|
||||
assert "color_correction_method=lab" in str(excinfo.value)
|
||||
assert " method=lab" not in str(excinfo.value)
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
|
||||
decoded = torch.zeros(1, 2, 4, 4, 3)
|
||||
original = torch.zeros(1, 2, 4, 4, 3)
|
||||
try:
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
|
||||
except ValueError as exc:
|
||||
assert "color_correction_method" in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected ValueError for unknown color_correction_method")
|
||||
assert "color_correction_method" in str(excinfo.value)
|
||||
|
||||
@ -2,7 +2,7 @@ from collections import defaultdict
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
|
||||
from comfy.model_detection import detect_unet_config, model_config_from_unet, model_config_from_unet_config
|
||||
import comfy.supported_models
|
||||
|
||||
|
||||
@ -76,21 +76,31 @@ def _make_flux_schnell_comfyui_sd():
|
||||
def _make_seedvr2_7b_separate_mm_sd():
|
||||
return {
|
||||
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
|
||||
"positive_conditioning": torch.empty(58, 5120),
|
||||
"negative_conditioning": torch.empty(64, 5120),
|
||||
}
|
||||
|
||||
|
||||
def _make_seedvr2_7b_shared_mm_sd():
|
||||
return {
|
||||
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||
"positive_conditioning": torch.empty(58, 5120),
|
||||
"negative_conditioning": torch.empty(64, 5120),
|
||||
}
|
||||
|
||||
|
||||
def _make_seedvr2_3b_shared_mm_sd():
|
||||
return {
|
||||
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||
"positive_conditioning": torch.empty(58, 5120),
|
||||
"negative_conditioning": torch.empty(64, 5120),
|
||||
}
|
||||
|
||||
|
||||
def _add_model_diffusion_prefix(sd):
|
||||
return {f"model.diffusion_model.{k}": v for k, v in sd.items()}
|
||||
|
||||
|
||||
class TestModelDetection:
|
||||
"""Verify that first-match model detection selects the correct model
|
||||
based on list ordering and unet_config specificity."""
|
||||
@ -182,6 +192,20 @@ class TestModelDetection:
|
||||
assert unet_config["num_layers"] == 32
|
||||
assert unet_config["mlp_type"] == "swiglu"
|
||||
|
||||
def test_seedvr2_model_match_requires_conditioning_tensors(self):
|
||||
sd = _make_seedvr2_7b_shared_mm_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
assert type(model_config_from_unet_config(unet_config, sd)).__name__ == "SeedVR2"
|
||||
|
||||
del sd["positive_conditioning"]
|
||||
assert model_config_from_unet_config(unet_config, sd) is None
|
||||
|
||||
def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self):
|
||||
sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd())
|
||||
|
||||
assert type(model_config_from_unet(sd, "model.diffusion_model.")).__name__ == "SeedVR2"
|
||||
|
||||
def test_unet_config_and_required_keys_combination_is_unique(self):
|
||||
"""Each model in the registry must have a unique combination of
|
||||
``unet_config`` and ``required_keys``. If two models share the same
|
||||
|
||||
@ -103,7 +103,7 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
|
||||
heads=heads,
|
||||
head_dim=head_dim,
|
||||
qk_bias=False,
|
||||
qk_norm=seedvr_model.CustomRMSNorm,
|
||||
qk_norm=comfy_ops.disable_weight_init.RMSNorm,
|
||||
qk_norm_eps=1e-6,
|
||||
rope_type=None,
|
||||
rope_dim=head_dim,
|
||||
|
||||
@ -26,6 +26,7 @@ import comfy.ldm.seedvr.model # noqa: E402
|
||||
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.model_management # noqa: E402
|
||||
import comfy.ops as comfy_ops # noqa: E402
|
||||
import comfy.sample # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
import nodes as nodes_mod # noqa: E402
|
||||
@ -81,6 +82,7 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis
|
||||
txt_in_dim=txt_in_dim,
|
||||
heads=24,
|
||||
mm_layers=3,
|
||||
operations=comfy_ops.disable_weight_init,
|
||||
)
|
||||
|
||||
return flags
|
||||
@ -140,6 +142,46 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
|
||||
raw_latent = torch.full((1, 16, 1, 4, 5), 2.0)
|
||||
seen_shapes = []
|
||||
|
||||
def base_encode(self, x):
|
||||
seen_shapes.append(tuple(x.shape))
|
||||
return raw_latent.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
|
||||
|
||||
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
nn.Module.__init__(vae)
|
||||
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
|
||||
|
||||
latent = vae.encode(torch.zeros(1, 3, 32, 40))
|
||||
|
||||
assert type(latent) is torch.Tensor
|
||||
assert tuple(latent.shape) == (1, 16, 4, 5)
|
||||
assert seen_shapes == [(1, 3, 1, 32, 40)]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
|
||||
raw_latent = torch.full((1, 16, 1, 4, 5), 3.0)
|
||||
|
||||
def base_encode(self, x):
|
||||
return raw_latent.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
|
||||
|
||||
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
nn.Module.__init__(vae)
|
||||
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
|
||||
|
||||
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
|
||||
|
||||
assert tuple(latent.shape) == (1, 16, 4, 5)
|
||||
assert tuple(raw.shape) == (1, 16, 1, 4, 5)
|
||||
assert torch.equal(raw, raw_latent)
|
||||
|
||||
|
||||
def _make_vae(wrapper):
|
||||
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||
vae.first_stage_model = wrapper
|
||||
@ -155,6 +197,8 @@ def _make_vae(wrapper):
|
||||
vae.extra_1d_channel = None
|
||||
vae.crop_input = False
|
||||
vae.not_video = False
|
||||
vae.handles_tiling = isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
vae.format_encoded = wrapper.comfy_format_encoded
|
||||
vae.patcher = _Patcher()
|
||||
vae.process_input = lambda image: image
|
||||
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@ -21,8 +22,6 @@ from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
|
||||
|
||||
|
||||
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
|
||||
|
||||
class StubVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -37,9 +36,9 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
|
||||
def decode_(self, t_chunk):
|
||||
self.decode_min_sizes.append(self.slicing_latent_min_size)
|
||||
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
|
||||
return vae_mod.VideoAutoencoderKL.slicing_decode(self, t_chunk)
|
||||
|
||||
def _decode(self, z, memory_state=MemoryState.DISABLED):
|
||||
def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None):
|
||||
self.memory_states.append(memory_state)
|
||||
b, c, d, h, w = z.shape
|
||||
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
|
||||
@ -68,8 +67,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
|
||||
|
||||
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
|
||||
from comfy.ldm.seedvr.vae import tiled_vae
|
||||
|
||||
class RaisingVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@ -85,8 +82,7 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises():
|
||||
vae = RaisingVAEModel()
|
||||
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
|
||||
|
||||
raised = False
|
||||
try:
|
||||
with pytest.raises(RuntimeError, match="simulated encode failure"):
|
||||
tiled_vae(
|
||||
x,
|
||||
vae,
|
||||
@ -96,15 +92,43 @@ def test_zero_temporal_size_preserves_min_size_when_encode_raises():
|
||||
temporal_overlap=0,
|
||||
encode=True,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if "simulated encode failure" not in str(exc):
|
||||
raise
|
||||
raised = True
|
||||
|
||||
assert raised
|
||||
assert vae.slicing_sample_min_size == 4
|
||||
|
||||
|
||||
def test_tiled_vae_encode_uses_tensor_return_without_indexing():
|
||||
class TensorEncodeVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.slicing_sample_min_size = 4
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.device = torch.device("cpu")
|
||||
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||
self.calls = []
|
||||
|
||||
def encode(self, t_chunk):
|
||||
self.calls.append(tuple(t_chunk.shape))
|
||||
b, _, _, h, w = t_chunk.shape
|
||||
return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype)
|
||||
|
||||
vae = TensorEncodeVAEModel()
|
||||
x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32)
|
||||
|
||||
out = tiled_vae(
|
||||
x,
|
||||
vae,
|
||||
tile_size=(64, 64),
|
||||
tile_overlap=(0, 0),
|
||||
temporal_size=0,
|
||||
temporal_overlap=0,
|
||||
encode=True,
|
||||
)
|
||||
|
||||
assert vae.calls == [(2, 3, 1, 64, 64)]
|
||||
assert tuple(out.shape) == (2, 16, 1, 8, 8)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_temporal_slicing.py
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -126,7 +150,7 @@ class _SlicingDecodeVAE(nn.Module):
|
||||
self.decode_min_sizes.append(self.slicing_latent_min_size)
|
||||
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
|
||||
|
||||
def _decode(self, z, memory_state=MemoryState.DISABLED):
|
||||
def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None):
|
||||
self.memory_states.append(memory_state)
|
||||
x = z[:, :1].repeat(
|
||||
1,
|
||||
@ -205,6 +229,8 @@ def _make_vae(first_stage_model, latent_channels, latent_dim):
|
||||
vae.latent_dim = latent_dim
|
||||
vae.vae_output_dtype = lambda: torch.float32
|
||||
vae.spacial_compression_decode = lambda: 8
|
||||
vae.handles_tiling = isinstance(first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
vae.format_encoded = None
|
||||
vae.process_input = lambda x: x
|
||||
vae.process_output = lambda x: x
|
||||
vae.throw_exception_if_invalid = lambda: None
|
||||
@ -240,7 +266,6 @@ def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
|
||||
|
||||
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
|
||||
first_stage = MagicMock()
|
||||
first_stage.comfy_handles_tiling = False
|
||||
first_stage.decode = MagicMock(side_effect=_force_oom)
|
||||
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
|
||||
@ -273,6 +298,8 @@ def _populate_common_vae_attrs_fallback(vae):
|
||||
vae.not_video = False
|
||||
vae.crop_input = False
|
||||
vae.pad_channel_value = None
|
||||
vae.handles_tiling = isinstance(vae.first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
vae.format_encoded = None
|
||||
|
||||
vae.vae_output_dtype = lambda: torch.float32
|
||||
vae.spacial_compression_encode = lambda: 8
|
||||
@ -295,7 +322,6 @@ def _make_seedvr2_vae_fallback():
|
||||
def _make_non_seedvr2_vae_fallback():
|
||||
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||
vae.first_stage_model = MagicMock()
|
||||
vae.first_stage_model.comfy_handles_tiling = False
|
||||
_populate_common_vae_attrs_fallback(vae)
|
||||
return vae
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user