mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
feat: Add reference latent support for Anima
This commit is contained in:
parent
acd718598e
commit
80bb64f770
@ -839,6 +839,14 @@ class MiniTrainDIT(nn.Module):
|
||||
**kwargs,
|
||||
):
|
||||
orig_shape = list(x.shape)
|
||||
|
||||
ref_latents = kwargs.get('ref_latents', None)
|
||||
if ref_latents is not None:
|
||||
for ref in ref_latents:
|
||||
if ref.ndim == 4:
|
||||
ref = ref.unsqueeze(2)
|
||||
x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2)
|
||||
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
||||
x_B_C_T_H_W = x
|
||||
timesteps_B_T = timesteps
|
||||
|
||||
@ -1221,6 +1221,13 @@ class Anima(BaseModel):
|
||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
return out
|
||||
|
||||
class Lumina2(BaseModel):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user