Compare commits

...

7 Commits

Author SHA1 Message Date
levzzz
5791dfb8b3
Merge 7658d7c155 into cd8c7a2306 2026-05-06 23:44:07 -03:00
levzzz
7658d7c155
Merge branch 'master' into ref-latents-anima-pr 2026-04-15 19:17:26 +05:00
levzzz5154
ebdeb9caa7 Fix: memory estimation for Anima with reference latent(s) 2026-04-15 00:28:19 +05:00
levzzz
35f54f3920
Merge branch 'master' into ref-latents-anima-pr 2026-04-15 00:25:02 +05:00
levzzz5154
312adfc6f1 Merge remote-tracking branch 'myfork/ref-latents-anima-pr' into ref-latents-anima-pr
# Conflicts:
#	comfy/ldm/cosmos/predict2.py
#	comfy/model_base.py
2026-04-14 00:34:38 +05:00
levzzz5154
80bb64f770 feat: Add reference latent support for Anima 2026-04-14 00:33:40 +05:00
levzzz5154
ad647ea8ab feat: Add reference latent support for Anima 2026-04-14 00:28:10 +05:00
2 changed files with 23 additions and 0 deletions

View File

@ -839,6 +839,14 @@ class MiniTrainDIT(nn.Module):
**kwargs,
):
orig_shape = list(x.shape)
ref_latents = kwargs.get('ref_latents', None)
if ref_latents is not None:
for ref in ref_latents:
if ref.ndim == 4:
ref = ref.unsqueeze(2)
x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2)
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
x_B_C_T_H_W = x
timesteps_B_T = timesteps

View File

@ -1212,6 +1212,7 @@ class CosmosPredict2(BaseModel):
class Anima(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
self.memory_usage_factor_conds = ("ref_latents",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -1232,6 +1233,20 @@ class Anima(BaseModel):
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
latents = []
for lat in ref_latents:
latents.append(self.process_latent_in(lat))
out['ref_latents'] = comfy.conds.CONDList(latents)
return out
def extra_conds_shapes(self, **kwargs):
out = {}
ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None:
out['ref_latents'] = [1, 16, sum(math.prod(lat.size()[2:]) for lat in ref_latents)]
return out
class Lumina2(BaseModel):