diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 2268bff38..b9871e15d 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -839,6 +839,15 @@ class MiniTrainDIT(nn.Module): **kwargs, ): orig_shape = list(x.shape) + + # Reference latents: concat along temporal dim (Flux2-style) + ref_latents = kwargs.get('ref_latents', None) + if ref_latents is not None: + for ref in ref_latents: + if ref.ndim == 4: + ref = ref.unsqueeze(2) + x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) x_B_C_T_H_W = x timesteps_B_T = timesteps diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..69ae578aa 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1221,6 +1221,14 @@ class Anima(BaseModel): out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + # Reference latents support (Flux2-style temporal concatenation) + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) return out class Lumina2(BaseModel):