diff --git a/comfy/model_base.py b/comfy/model_base.py index fb9dcb032..07066569a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1201,6 +1201,7 @@ class CosmosPredict2(BaseModel): class Anima(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1230,6 +1231,13 @@ class Anima(BaseModel): out['ref_latents'] = comfy.conds.CONDList(latents) return out + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = [1, 16, sum(math.prod(lat.size()[2:]) for lat in ref_latents)] + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)