Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-10 21:42:37 +08:00)
Commit 4b6c08110d: "large optimizations and some fixes"
Parent: 663d971830
@@ -635,23 +635,45 @@ class SingleStreamBlock(nn.Module):
         return x


+def _ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def find_period_by_first_row(mat):
+    L, _ = mat.shape
+
+    first = mat[0:1]
+    matches = (mat[1:] == first).all(dim=1)
+    candidate_positions = (torch.nonzero(matches).squeeze(-1) + 1).tolist()
+    if isinstance(candidate_positions, int):
+        candidate_positions = [candidate_positions]
+    if not candidate_positions:
+        return L
+
+    for p in sorted(candidate_positions):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+
+    for p in range(1, L + 1):
+        base = mat[:p]
+        reps = _ceil_div(L, p)
+        tiled = base.repeat(reps, 1)[:L]
+        if torch.equal(tiled, mat):
+            return p
+
+    return L
+
+
 def trim_repeats(expanded):
     _, L, D = expanded.shape
     seq = expanded[0]
-    repeat_len = L
-    for k in range(1, L // 2 + 1):
-        if torch.equal(seq[:k], seq[k:2*k]):
-            repeat_len = k
-            break
-
-    repeat_dim = D
-    for k in range(1, D // 2 + 1):
-        if torch.equal(seq[:, :k], seq[:, k:2*k]):
-            repeat_dim = k
-            break
-
-    return expanded[:, :repeat_len, :repeat_dim]
+    p_len = find_period_by_first_row(seq)
+    seq_T = seq.transpose(0, 1)
+    p_dim = find_period_by_first_row(seq_T)
+
+    return expanded[:, :p_len, :p_dim]


 class HunyuanVideoFoley(nn.Module):
     def __init__(
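As a review aid, here is a standalone sanity check of the new helpers with a made-up shape: trim_repeats should collapse a tensor built by tiling a base block along the sequence and feature axes back to that block (assumes the two functions above are in scope).

import torch

# hypothetical base block, tiled 3x along the sequence axis and 2x along features
base = torch.randn(1, 4, 8)
expanded = base.repeat(1, 3, 2)      # (1, 12, 16)

trimmed = trim_repeats(expanded)     # period detection should recover (1, 4, 8)
assert torch.equal(trimmed, base)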
@@ -845,11 +867,12 @@ class HunyuanVideoFoley(nn.Module):

         uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
         clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-        cond_pos, cond_neg = trim_repeats(cond_pos), trim_repeats(cond_neg)
+        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]

         uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
         uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]

-        uncond_1, clip_feat = uncond_1.to(device, non_blocking=True), clip_feat.to(device, non_blocking=True)
-        uncond_2, sync_feat = uncond_2.to(device, non_blocking=True), sync_feat.to(device, non_blocking=True)
-        cond_neg, cond_pos = cond_neg.to(device, non_blocking=True), cond_pos.to(device, non_blocking=True)
+        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [t.to(device, allow_gpu=True) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]

         clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
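Worth flagging on the new batched transfer: allow_gpu is a custom keyword understood only by the CpuLockedTensor subclass added later in this commit, so this line assumes all six tensors arrive wrapped by the conditioning node. A plain tensor rejects the keyword, as this minimal check shows:

import torch

t = torch.zeros(2)
try:
    t.to("cpu", allow_gpu=True)          # not a real torch.Tensor.to() kwarg
except TypeError as e:
    print("plain tensors reject allow_gpu:", e)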
@@ -211,16 +211,24 @@ class FoleyVae(torch.nn.Module):
         return self.syncformer(x)

     def video_encoding(self, video, step: int):
         t, h, w, c = video.shape

         if not isinstance(video, torch.Tensor):
-            video = torch.from_numpy(video).permute(0, 3, 1, 2)
+            video = torch.from_numpy(video)

-        video = self.syncformer_preprocess(video).unsqueeze(0)
+        video = video.permute(0, 3, 1, 2)
+        video = torch.stack([self.syncformer_preprocess(frame) for frame in video]).unsqueeze(0)
         seg_len = 16
         t = video.size(1)
         nseg = max(0, (t - seg_len) // step + 1)
-        clips = [video[:, i*step:i*step + seg_len] for i in range(nseg)]
-        data = torch.stack(clips, dim=1)
+        # preprocessing can change the spatial size, so re-read the dims
+        _, _, c, h, w = video.shape
+        stride_b, stride_t, stride_c, stride_h, stride_w = video.stride()
+
+        # sliding windows as a strided view, no copies
+        data = video.as_strided(
+            size=(1, nseg, seg_len, c, h, w),
+            stride=(stride_b, stride_t * step, stride_t, stride_c, stride_h, stride_w),
+        )
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")

         return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=video.size(0))
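The hand-rolled as_strided view can be cross-checked against Tensor.unfold, which builds the same zero-copy sliding window with bounds checking. A small sketch with dummy sizes:

import torch

x = torch.arange(10, dtype=torch.float32).view(1, 10, 1, 1, 1)  # (b, t, c, h, w)
seg_len, step = 4, 2
windows = x.unfold(1, seg_len, step)     # (b, nseg, c, h, w, seg_len), still a view
windows = windows.movedim(-1, 2)         # (b, nseg, seg_len, c, h, w)
print(windows.shape)                     # torch.Size([1, 4, 4, 1, 1, 1])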
@@ -26,6 +26,54 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
         latent = torch.randn(shape, device=comfy.model_management.intermediate_device())
         return io.NodeOutput({"samples": latent, "type": "hunyuan_foley"}, )

+class CpuLockedTensor(torch.Tensor):
+    def __new__(cls, data):
+        base = torch.as_tensor(data, device='cpu')
+        return torch.Tensor._make_subclass(cls, base, require_grad=False)
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+
+        # unwrap to plain tensors first; dispatching with CpuLockedTensor
+        # args would re-enter __torch_function__ and recurse forever
+        def unwrap(x):
+            return x.as_subclass(torch.Tensor) if isinstance(x, CpuLockedTensor) else x
+
+        unwrapped_args = torch.utils._pytree.tree_map(unwrap, args)
+        unwrapped_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
+
+        result = func(*unwrapped_args, **unwrapped_kwargs)
+
+        # re-wrap resulting tensors so they stay locked to the CPU
+        if isinstance(result, torch.Tensor):
+            return CpuLockedTensor(result.detach().cpu())
+        elif isinstance(result, (list, tuple)):
+            return type(result)(
+                CpuLockedTensor(x.detach().cpu()) if isinstance(x, torch.Tensor) else x
+                for x in result
+            )
+        return result
+
+    def to(self, *args, allow_gpu=False, **kwargs):
+        if allow_gpu:
+            # escape the subclass first, otherwise __torch_function__
+            # re-wraps the result and pulls it straight back to the CPU
+            return self.as_subclass(torch.Tensor).to(*args, **kwargs)
+        return self.detach().clone().cpu()
+
+    def cuda(self, *args, **kwargs):
+        return self
+
+    def cpu(self):
+        return self
+
+    def pin_memory(self):
+        return self
+
+    def detach(self):
+        out = super().detach()
+        return CpuLockedTensor(out)
+
 class HunyuanFoleyConditioning(io.ComfyNode):
     @classmethod
     def define_schema(cls):
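A usage sketch of the wrapper's contract, illustrative only (the CUDA branch assumes a GPU build): any op routed through __torch_function__ comes back CPU-locked, and only the explicit allow_gpu opt-out yields a plain device tensor.

import torch

t = CpuLockedTensor(torch.randn(4))
y = t * 2                              # dispatched via __torch_function__
print(type(y).__name__, y.device)      # CpuLockedTensor cpu

if torch.cuda.is_available():
    g = t.to("cuda", allow_gpu=True)
    print(type(g).__name__, g.device)  # Tensor cuda:0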
@@ -53,6 +101,10 @@ class HunyuanFoleyConditioning(io.ComfyNode):
         max_d = max([t.size(2) for t in all_])

         def repeat_shapes(max_value, input, dim=1):
             if input.shape[dim] == max_value:
                 return input

             # temporarily repeat values on the cpu
             factor_pos, remainder = divmod(max_value, input.shape[dim])
@@ -61,19 +113,28 @@ class HunyuanFoleyConditioning(io.ComfyNode):
             input = input.cpu().repeat(*positions)

             if remainder > 0:
-                pad = input[:, :remainder, :]
-                input = torch.cat([input, pad], dim=1)
+                if dim == 1:
+                    pad = input[:, :remainder, :]
+                else:
+                    pad = input[:, :, :remainder]
+                input = torch.cat([input, pad], dim=dim)

             return input

         siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_l, t) for t in all_]
-        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim=2) for t in all_]
+        siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative = [repeat_shapes(max_d, t, dim=2) for t in
+            (siglip_encoding_1, synchformer_encoding_2, text_encoding_positive, text_encoding_negative)]

         embeds = torch.cat([siglip_encoding_1.cpu(), synchformer_encoding_2.cpu()], dim=0)

         x = siglip_encoding_1
-        negative = [[torch.cat([torch.zeros_like(embeds), text_encoding_negative]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
-        positive = [[torch.cat([embeds, text_encoding_positive]).contiguous().view(1, -1, x.size(-1)).pin_memory(), {}]]
+        negative_tensor = CpuLockedTensor(torch.cat([torch.zeros_like(embeds), text_encoding_negative])
+            .contiguous().view(1, -1, x.size(-1)))
+        positive_tensor = CpuLockedTensor(torch.cat([embeds, text_encoding_positive])
+            .contiguous().view(1, -1, x.size(-1)))

+        negative = [[negative_tensor, {}]]
+        positive = [[positive_tensor, {}]]

         return io.NodeOutput(positive, negative)
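For intuition, repeat_shapes is the inverse-direction operation to trim_repeats in the model file: it tiles a tensor up to a target length along one dim and pads any remainder with the leading slice. A standalone sketch with made-up sizes:

import torch

x = torch.arange(5).view(1, 5, 1).float()
factor, remainder = divmod(12, x.shape[1])            # 2 full repeats, 2 left over
out = x.repeat(1, factor, 1)                          # (1, 10, 1)
if remainder > 0:
    out = torch.cat([out, out[:, :remainder, :]], dim=1)
print(out.shape)                                      # torch.Size([1, 12, 1])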
@@ -52,6 +52,7 @@ class EncodeVideo(io.ComfyNode):
     def execute(cls, video, processing_batch_size, step_size, vae=None, clip_vision=None):

         t, c, h, w = video.shape
+        device = video.device
         b = 1
         batch_size = b * t
@@ -62,6 +63,8 @@ class EncodeVideo(io.ComfyNode):
         model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision

+        # should be the offload device
+        video = video.cpu()
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -72,25 +75,31 @@ class EncodeVideo(io.ComfyNode):
         if processing_batch_size != -1:
            batch_size = processing_batch_size

-        outputs = []
+        outputs = None
         total = data.shape[0]
         pbar = comfy.utils.ProgressBar(total/batch_size)
         with torch.inference_mode():
             for i in range(0, total, batch_size):
-                chunk = data[i : i + batch_size]
+                chunk = data[i : i + batch_size].to(device, non_blocking=True)
                 if hasattr(vae, "encode"):
                     out = vae.encode(chunk)
                 else:
                     out = vae.encode_image(chunk)
                     out = out["image_embeds"]
-                outputs.append(out)
-                del out, chunk
+
+                out_cpu = out.cpu()
+                if outputs is None:
+                    full_shape = (total, *out_cpu.shape[1:])
+                    outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
+
+                chunk_len = out_cpu.shape[0]
+                outputs[i : i + chunk_len].copy_(out_cpu)
+
+                del out, chunk, out_cpu
                 torch.cuda.empty_cache()
                 pbar.update(1)

-        output = torch.cat(outputs)
-
-        return io.NodeOutput(output_fn(output))
+        return io.NodeOutput(output_fn(outputs))

 class ResampleVideo(io.ComfyNode):
     @classmethod
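The rewritten loop streams each encoded chunk into a preallocated pinned staging buffer instead of appending to a list and calling torch.cat, which avoids a second full-size host copy and keeps the device-to-host transfers async-friendly. A standalone sketch of the pattern with a stand-in encoder (pinning is guarded since it requires a CUDA build):

import torch

data = torch.randn(10, 4)
batch_size, total = 3, data.shape[0]
outputs = None
for i in range(0, total, batch_size):
    out = data[i : i + batch_size] * 2        # stand-in for vae.encode(chunk)
    out_cpu = out.cpu()
    if outputs is None:
        outputs = torch.empty((total, *out_cpu.shape[1:]), dtype=out_cpu.dtype,
                              pin_memory=torch.cuda.is_available())
    outputs[i : i + out_cpu.shape[0]].copy_(out_cpu)
assert torch.equal(outputs, data * 2)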