mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 14:20:49 +08:00
outputs/speed/memory match custom node
This commit is contained in:
parent
1afc2ed8e6
commit
7b2e5ef0af
@ -491,6 +491,11 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
|||||||
"mmrope_freqs_3d",
|
"mmrope_freqs_3d",
|
||||||
lambda: self.get_freqs(vid_shape, txt_shape),
|
lambda: self.get_freqs(vid_shape, txt_shape),
|
||||||
)
|
)
|
||||||
|
target_device = vid_q.device
|
||||||
|
if vid_freqs.device != target_device:
|
||||||
|
vid_freqs = vid_freqs.to(target_device)
|
||||||
|
if txt_freqs.device != target_device:
|
||||||
|
txt_freqs = txt_freqs.to(target_device)
|
||||||
vid_q = rearrange(vid_q, "L h d -> h L d")
|
vid_q = rearrange(vid_q, "L h d -> h L d")
|
||||||
vid_k = rearrange(vid_k, "L h d -> h L d")
|
vid_k = rearrange(vid_k, "L h d -> h L d")
|
||||||
vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
|
vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
|
||||||
@ -506,6 +511,7 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
|||||||
txt_k = rearrange(txt_k, "h L d -> L h d")
|
txt_k = rearrange(txt_k, "h L d -> L h d")
|
||||||
return vid_q, vid_k, txt_q, txt_k
|
return vid_q, vid_k, txt_q, txt_k
|
||||||
|
|
||||||
|
@torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks
|
||||||
def get_freqs(
|
def get_freqs(
|
||||||
self,
|
self,
|
||||||
vid_shape: torch.LongTensor,
|
vid_shape: torch.LongTensor,
|
||||||
@ -514,8 +520,29 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
|
|||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
]:
|
]:
|
||||||
vid_freqs = self.get_axial_freqs(1024, 128, 128)
|
|
||||||
txt_freqs = self.get_axial_freqs(1024)
|
# Calculate actual max dimensions needed for this batch
|
||||||
|
max_temporal = 0
|
||||||
|
max_height = 0
|
||||||
|
max_width = 0
|
||||||
|
max_txt_len = 0
|
||||||
|
|
||||||
|
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||||
|
max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal
|
||||||
|
max_height = max(max_height, h)
|
||||||
|
max_width = max(max_width, w)
|
||||||
|
max_txt_len = max(max_txt_len, l)
|
||||||
|
|
||||||
|
# Compute frequencies for actual max dimensions needed
|
||||||
|
# Add small buffer to improve cache hits across similar batches
|
||||||
|
vid_freqs = self.get_axial_freqs(
|
||||||
|
min(max_temporal + 16, 1024), # Cap at 1024, add small buffer
|
||||||
|
min(max_height + 4, 128), # Cap at 128, add small buffer
|
||||||
|
min(max_width + 4, 128) # Cap at 128, add small buffer
|
||||||
|
)
|
||||||
|
txt_freqs = self.get_axial_freqs(min(max_txt_len + 16, 1024))
|
||||||
|
|
||||||
|
# Now slice as before
|
||||||
vid_freq_list, txt_freq_list = [], []
|
vid_freq_list, txt_freq_list = [], []
|
||||||
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||||
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
|
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
|
||||||
|
|||||||
@ -65,9 +65,9 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
|
|||||||
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
|
t_chunk = spatial_tile[:, :, i : i + input_chunk, :, :]
|
||||||
|
|
||||||
if encode:
|
if encode:
|
||||||
out = vae_model.encode(t_chunk)
|
out = vae_model.slicing_encode(t_chunk)
|
||||||
else:
|
else:
|
||||||
out = vae_model.decode_(t_chunk)
|
out = vae_model.slicing_decode(t_chunk)
|
||||||
|
|
||||||
if isinstance(out, (tuple, list)): out = out[0]
|
if isinstance(out, (tuple, list)): out = out[0]
|
||||||
|
|
||||||
@ -245,6 +245,11 @@ def cut_videos(videos):
|
|||||||
assert (videos.size(1) - 1) % (4) == 0
|
assert (videos.size(1) - 1) % (4) == 0
|
||||||
return videos
|
return videos
|
||||||
|
|
||||||
|
def side_resize(image, size):
    """Resize ``image`` to ``size`` using bicubic interpolation.

    Antialiasing is requested for every input except a ``torch.Tensor``
    whose device type is ``'mps'``; for that case it is turned off.

    Args:
        image: Input passed straight through to ``TVF.resize``.
        size: Target size passed straight through to ``TVF.resize``.

    Returns:
        Whatever ``TVF.resize`` returns for the given arguments.
    """
    # NOTE(review): antialias is presumably disabled on MPS because the
    # antialiased bicubic path is unsupported there — confirm.
    on_mps = isinstance(image, torch.Tensor) and image.device.type == 'mps'
    return TVF.resize(
        image,
        size,
        InterpolationMode.BICUBIC,
        antialias=not on_mps,
    )
|
||||||
|
|
||||||
class SeedVR2InputProcessing(io.ComfyNode):
|
class SeedVR2InputProcessing(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -285,7 +290,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
|||||||
max_area = ((resolution_height * resolution_width)** 0.5) ** 2
|
max_area = ((resolution_height * resolution_width)** 0.5) ** 2
|
||||||
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
||||||
normalize = Normalize(0.5, 0.5)
|
normalize = Normalize(0.5, 0.5)
|
||||||
images = area_resize(images, max_area)
|
#images = area_resize(images, max_area)
|
||||||
|
images = side_resize(images, resolution_height)
|
||||||
|
|
||||||
images = clip(images)
|
images = clip(images)
|
||||||
o_h, o_w = images.shape[-2:]
|
o_h, o_w = images.shape[-2:]
|
||||||
@ -348,7 +354,7 @@ class SeedVR2Conditioning(io.ComfyNode):
|
|||||||
|
|
||||||
noises = torch.randn_like(vae_conditioning).to(device)
|
noises = torch.randn_like(vae_conditioning).to(device)
|
||||||
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
||||||
|
aug_noises = noises * 0.1 + aug_noises * 0.05
|
||||||
cond_noise_scale = 0.0
|
cond_noise_scale = 0.0
|
||||||
t = (
|
t = (
|
||||||
torch.tensor([1000.0])
|
torch.tensor([1000.0])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user