Mirror of https://github.com/comfyanonymous/ComfyUI.git
final changes
This commit is contained in:
parent 4908e7412e
commit 4653b9008d
@@ -19,6 +19,8 @@ class Output:
 
 def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True, resize_mode="bicubic"):
     image = image[:, :, :, :3] if image.shape[3] > 3 else image
+    if image.dtype == torch.uint8:
+        image = image.float() / 255.0
     mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
     std = torch.tensor(std, device=image.device, dtype=image.dtype)
     image = image.movedim(-1, 1)
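With this change clip_preprocess accepts raw uint8 frames as well as float tensors in [0, 1]: uint8 input is rescaled to [0, 1] before the mean/std normalization. A minimal sketch of that normalization path (illustrative only, assuming just torch; it skips the crop/resize handling and the helper name is made up):

import torch

def normalize_like_clip_preprocess(image, mean, std):
    # New behaviour: raw uint8 frames are rescaled to [0, 1] first.
    if image.dtype == torch.uint8:
        image = image.float() / 255.0
    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
    std = torch.tensor(std, device=image.device, dtype=image.dtype)
    image = image.movedim(-1, 1)  # NHWC -> NCHW
    return (image - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

frames_u8 = torch.randint(0, 256, (1, 224, 224, 3), dtype=torch.uint8)
out = normalize_like_clip_preprocess(frames_u8, [0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711])
print(out.shape, out.dtype)  # torch.Size([1, 3, 224, 224]) torch.float32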
@@ -885,6 +885,8 @@ class HunyuanVideoFoley(nn.Module):
             clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
             clip_feat = clip_feat.view(2, -1, 768)
 
+            self.conditions = (clip_feat, sync_feat, cond)
+
         else:
             clip_feat, sync_feat, cond = self.conditions
 
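The new line stores the concatenated negative/positive conditioning on the module so the else branch can reuse it on later calls instead of rebuilding it. A minimal sketch of that cache-on-first-call pattern (illustrative only; ConditionCache and its fields are hypothetical stand-ins, not ComfyUI code):

import torch

class ConditionCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conditions = None

    def get(self, uncond, cond):
        if self.conditions is None:
            # First call: build the CFG batch (negative first, then positive) once.
            self.conditions = (torch.cat([uncond, cond]),)
        return self.conditions

cache = ConditionCache()
neg, pos = torch.zeros(1, 4), torch.ones(1, 4)
first = cache.get(neg, pos)
again = cache.get(neg, pos)  # reuses the stored tensors
print(first[0].shape, first is again)  # torch.Size([2, 4]) True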
@@ -957,6 +957,7 @@ class Synchformer(nn.Module):
         )
 
     def forward(self, vis):
+        vis = vis.to(next(self.parameters()).dtype)
         vis = vis.permute(0, 1, 3, 2, 4, 5)  # (B, S, C, Tv, H, W)
         vis = self.vfeat_extractor(vis)
         return vis
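The added cast matches the input to the dtype of the Synchformer weights, so fp16/bf16 checkpoints do not fail on a dtype mismatch inside the feature extractor. A minimal sketch of the next(self.parameters()).dtype idiom (illustrative only; ToyExtractor is a hypothetical stand-in):

import torch

class ToyExtractor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 4)

    def forward(self, x):
        # Cast the input to whatever dtype the weights were loaded in.
        x = x.to(next(self.parameters()).dtype)
        return self.proj(x)

model = ToyExtractor().to(torch.bfloat16)
out = model(torch.randn(2, 8))  # float32 input is cast to bfloat16 before the matmul
print(out.dtype)  # torch.bfloat16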
@@ -223,6 +223,7 @@ class FoleyVae(torch.nn.Module):
         return self.synchformer(x)
 
     def video_encoding(self, video, step):
+        video = video.to(torch.uint8)
         video = torch.stack([self.syncformer_preprocess(t) for t in video])
 
         t, c, h, w = video.shape
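The cast presumably gives syncformer_preprocess the 0-255 uint8 frames it expects. Note that .to(torch.uint8) truncates rather than rounds, so the incoming video is assumed to already hold values in the 0-255 range (an assumption about the caller, not stated in the hunk):

import torch

frames = torch.tensor([[0.0, 127.9, 255.0]])  # float frames in 0..255
print(frames.to(torch.uint8))    # tensor([[  0, 127, 255]], dtype=torch.uint8) -- truncated
frames01 = torch.tensor([[0.0, 0.5, 1.0]])    # 0..1 floats would collapse to 0 and 1
print(frames01.to(torch.uint8))  # tensor([[0, 0, 1]], dtype=torch.uint8)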
@@ -237,8 +238,9 @@ class FoleyVae(torch.nn.Module):
         data = video.as_strided(
             size=(nseg, seg_len, c, h, w),
             stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
-        )
+        ).contiguous()
         data = data.unsqueeze(0)  # b
         data = rearrange(data, "b s t c h w -> (b s) 1 t c h w")
+        data = data.float()
 
         return data, nseg, lambda x: rearrange(x, "(b s) 1 t d -> b (s t) d", b=1)
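as_strided builds overlapping segments as views that share memory; the added .contiguous() copies them into real storage so the later rearrange and the sync feature extractor work on independent data, and .float() upcasts the uint8 windows before encoding. A minimal sketch of the windowing (illustrative only; the seg_len/step values and the nseg formula are made up here, variable names follow the diff):

import torch

video = torch.arange(10 * 3 * 2 * 2, dtype=torch.uint8).reshape(10, 3, 2, 2)  # (t, c, h, w)
t, c, h, w = video.shape
seg_len, step = 4, 2
nseg = (t - seg_len) // step + 1  # 4 overlapping segments
stride_t, stride_c, stride_h, stride_w = video.stride()

data = video.as_strided(
    size=(nseg, seg_len, c, h, w),
    stride=(stride_t * step, stride_t, stride_c, stride_h, stride_w),
).contiguous()  # copy the overlapping windows into their own memory
data = data.float()  # upcast before feature extraction
print(data.shape)  # torch.Size([4, 4, 3, 2, 2])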
@@ -1165,23 +1165,14 @@ class MultiheadAttentionComfyv(nn.Module):
         )
 
     def forward(self, src, k=None, v=None, attn_mask=None, key_padding_mask=None):
-        self._q_proj, self._k_proj, self._v_proj = [t.to(src.device).to(src.dtype) for t in (self._q_proj, self._k_proj, self._v_proj)]
+        self._q_proj, self._k_proj, self._v_proj = [
+            t.to(src.device).to(src.dtype)
+            for t in (self._q_proj, self._k_proj, self._v_proj)
+        ]
 
         q = self._q_proj(src)
-        if k is None:
-            k = self._k_proj(src)
-        if v is None:
-            v = self._v_proj(src)
-        k, v = k.to(src.device).to(src.dtype), v.to(src.device).to(src.dtype)
-
-        if k is v:
-            if q is k:
-                q = k = v = q.transpose(1, 0)
-            else:
-                q, k = (x.transpose(1, 0) for x in (q, k))
-                v = k
-        else:
-            q, k, v = (x.transpose(1, 0) for x in (q, k, v))
+        k = self._k_proj(src if k is None else k.to(src.device).to(src.dtype))
+        v = self._v_proj(src if v is None else v.to(src.device).to(src.dtype))
 
         output = optimized_attention(q, k, v, self.num_heads, mask=attn_mask)
         return self.out_proj(output)
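The rewritten forward always runs the inputs through the q/k/v projections (falling back to src for self-attention) and drops the seq-first transpose bookkeeping, leaving optimized_attention to consume the tensors as-is. A minimal sketch of the same structure (illustrative only, not ComfyUI code; torch.nn.functional.scaled_dot_product_attention stands in for optimized_attention and TinySelfAttention is a made-up class):

import torch
import torch.nn.functional as F

class TinySelfAttention(torch.nn.Module):
    def __init__(self, dim=64, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self._q_proj = torch.nn.Linear(dim, dim)
        self._k_proj = torch.nn.Linear(dim, dim)
        self._v_proj = torch.nn.Linear(dim, dim)
        self.out_proj = torch.nn.Linear(dim, dim)

    def forward(self, src, k=None, v=None):
        # Always project; fall back to src when k/v are not supplied, as in the hunk.
        q = self._q_proj(src)
        k = self._k_proj(src if k is None else k.to(src.device).to(src.dtype))
        v = self._v_proj(src if v is None else v.to(src.device).to(src.dtype))
        b, t, dim = q.shape
        split = lambda x: x.view(b, -1, self.num_heads, dim // self.num_heads).transpose(1, 2)
        out = F.scaled_dot_product_attention(split(q), split(k), split(v))
        out = out.transpose(1, 2).reshape(b, t, dim)
        return self.out_proj(out)

attn = TinySelfAttention()
print(attn(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])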
@@ -94,13 +94,15 @@ class EncodeVideo(io.ComfyNode):
             chunk = chunk.to(model_dtype)
             if hasattr(vae, "encode"):
                 try:
+                    if chunk.ndim > 5:
+                        raise ValueError("chunk.ndim > 5")
                     chunk = chunk.movedim(1, -1)
                     out = vae.encode(chunk)
-                except:
+                except Exception:
                     out = model.encode(chunk)
             else:
                 chunk = chunk.movedim(1, -1)
-                out = vae.encode_image(chunk, crop=False, resize_mode="bilinear")
+                out = vae.encode_image(chunk.to(torch.uint8), crop=False, resize_mode="bilinear")
                 out = out["image_embeds"]
 
             out_cpu = out.cpu()
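Two separate fixes here: the ndim guard forces oversized chunks onto the model.encode fallback, and the bare except: is narrowed to except Exception: so control-flow exceptions are no longer swallowed; encode_image also now receives uint8 frames, matching the uint8 branch added to clip_preprocess above. A small sketch of why the narrower except matters (illustrative only):

def risky(kind):
    raise kind("boom")

# except Exception still catches ordinary errors, so the fallback encoder runs...
try:
    risky(ValueError)
except Exception:
    print("ValueError handled, fallback path runs")

# ...but it deliberately lets KeyboardInterrupt/SystemExit propagate.
caught = False
try:
    try:
        risky(KeyboardInterrupt)
    except Exception:
        caught = True
except KeyboardInterrupt:
    pass
print("KeyboardInterrupt caught by 'except Exception':", caught)  # False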