mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-23 17:02:38 +08:00
Merge 7658d7c155 into 3d816db07f
This commit is contained in:
commit
173532e006
@ -839,6 +839,14 @@ class MiniTrainDIT(nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
orig_shape = list(x.shape)
|
orig_shape = list(x.shape)
|
||||||
|
|
||||||
|
ref_latents = kwargs.get('ref_latents', None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
for ref in ref_latents:
|
||||||
|
if ref.ndim == 4:
|
||||||
|
ref = ref.unsqueeze(2)
|
||||||
|
x = torch.cat([x, ref.to(dtype=x.dtype, device=x.device)], dim=2)
|
||||||
|
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial))
|
||||||
x_B_C_T_H_W = x
|
x_B_C_T_H_W = x
|
||||||
timesteps_B_T = timesteps
|
timesteps_B_T = timesteps
|
||||||
|
|||||||
@ -1201,6 +1201,7 @@ class CosmosPredict2(BaseModel):
|
|||||||
class Anima(BaseModel):
|
class Anima(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima)
|
||||||
|
self.memory_usage_factor_conds = ("ref_latents",)
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
@ -1221,6 +1222,20 @@ class Anima(BaseModel):
|
|||||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||||
|
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
latents = []
|
||||||
|
for lat in ref_latents:
|
||||||
|
latents.append(self.process_latent_in(lat))
|
||||||
|
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def extra_conds_shapes(self, **kwargs):
|
||||||
|
out = {}
|
||||||
|
ref_latents = kwargs.get("reference_latents", None)
|
||||||
|
if ref_latents is not None:
|
||||||
|
out['ref_latents'] = [1, 16, sum(math.prod(lat.size()[2:]) for lat in ref_latents)]
|
||||||
return out
|
return out
|
||||||
|
|
||||||
class Lumina2(BaseModel):
|
class Lumina2(BaseModel):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user