From ad647ea8ab0e58e3ae05ebaeeaad7607825425d7 Mon Sep 17 00:00:00 2001 From: levzzz5154 Date: Tue, 14 Apr 2026 00:28:10 +0500 Subject: [PATCH 1/3] feat: Add reference latent support for Anima --- comfy/ldm/cosmos/predict2.py | 9 +++++++++ comfy/model_base.py | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 2268bff38..b9871e15d 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -839,6 +839,15 @@ class MiniTrainDIT(nn.Module): **kwargs, ): orig_shape = list(x.shape) + + # Reference latents: concat along temporal dim (Flux2-style) + ref_latents = kwargs.get('ref_latents', None) + if ref_latents is not None: + for ref in ref_latents: + if ref.ndim == 4: + ref = ref.unsqueeze(2) + x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) x_B_C_T_H_W = x timesteps_B_T = timesteps diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..69ae578aa 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1221,6 +1221,14 @@ class Anima(BaseModel): out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + # Reference latents support (Flux2-style temporal concatenation) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) return out class Lumina2(BaseModel): From 80bb64f7708180c99e555b2520712d059213e85d Mon Sep 17 00:00:00 2001 From: levzzz5154 Date: Tue, 14 Apr 2026 00:28:10 +0500 Subject: [PATCH 2/3] feat: Add reference latent support for Anima --- comfy/ldm/cosmos/predict2.py | 8 ++++++++ comfy/model_base.py | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 2268bff38..d7ba24782 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -839,6 +839,14 @@ class MiniTrainDIT(nn.Module): **kwargs, ): orig_shape = list(x.shape) + + ref_latents = kwargs.get('ref_latents', None) + if ref_latents is not None: + for ref in ref_latents: + if ref.ndim == 4: + ref = ref.unsqueeze(2) + x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) x_B_C_T_H_W = x timesteps_B_T = timesteps diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..fb9dcb032 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1221,6 +1221,13 @@ class Anima(BaseModel): out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) return out class Lumina2(BaseModel): From ebdeb9caa7a0e6486440860ff3f01701fe66d3b5 Mon Sep 17 00:00:00 2001 From: levzzz5154 Date: Wed, 15 Apr 2026 00:28:19 +0500 Subject: [PATCH 3/3] Fix: memory estimation for Anima with reference latent(s) --- comfy/model_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index fb9dcb032..07066569a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1201,6 +1201,7 @@ class CosmosPredict2(BaseModel): class Anima(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1230,6 +1231,13 @@ class Anima(BaseModel): out['ref_latents'] = comfy.conds.CONDList(latents) return out + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = [1, 16, sum(math.prod(lat.size()[2:]) for lat in ref_latents)] + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)