feat: Add reference latent support for Anima

This commit is contained in:
levzzz5154 2026-04-14 00:28:10 +05:00
parent acd718598e
commit ad647ea8ab
2 changed files with 17 additions and 0 deletions

View File

@ -839,6 +839,15 @@ class MiniTrainDIT(nn.Module):
**kwargs,
):
orig_shape = list(x.shape)
# Reference latents: concat along temporal dim (Flux2-style)
ref_latents = kwargs.get('ref_latents', None)
if ref_latents is not None:
for ref in ref_latents:
if ref.ndim == 4:
ref = ref.unsqueeze(2)
x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2)
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
x_B_C_T_H_W = x
timesteps_B_T = timesteps

View File

@ -1221,6 +1221,14 @@ class Anima(BaseModel):
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
# Reference latents support (Flux2-style temporal concatenation)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
class Lumina2(BaseModel):