mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
fix: make attention_mask optional in LTXBaseModel.forward (fixes #13299)
The LTXBaseModel.forward and _forward methods required attention_mask as a positional argument with no default value. However, LTXV.extra_conds only conditionally adds attention_mask to model_conds when it is present in kwargs. If attention_mask is not provided by the text encoder, the diffusion_model forward call fails with: TypeError: LTXBaseModel.forward() missing 1 required positional argument: 'attention_mask' The model already handles attention_mask=None correctly in both _prepare_attention_mask and _prepare_context, so making the parameter optional is the minimal safe fix. This also aligns with how LTXAVDoubleStreamBlock.forward handles the parameter.
This commit is contained in:
parent
6648f31c55
commit
b8bb5427f9
@ -858,7 +858,7 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
self, x, timestep, context, attention_mask=None, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Forward pass for LTX models.
|
Forward pass for LTX models.
|
||||||
@ -867,7 +867,7 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
x: Input tensor
|
x: Input tensor
|
||||||
timestep: Timestep tensor
|
timestep: Timestep tensor
|
||||||
context: Context tensor (e.g., text embeddings)
|
context: Context tensor (e.g., text embeddings)
|
||||||
attention_mask: Attention mask tensor
|
attention_mask: Attention mask tensor (optional)
|
||||||
frame_rate: Frame rate for temporal processing
|
frame_rate: Frame rate for temporal processing
|
||||||
transformer_options: Additional options for transformer blocks
|
transformer_options: Additional options for transformer blocks
|
||||||
keyframe_idxs: Keyframe indices for temporal processing
|
keyframe_idxs: Keyframe indices for temporal processing
|
||||||
@ -885,7 +885,7 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
self, x, timestep, context, attention_mask=None, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Internal forward pass for LTX models.
|
Internal forward pass for LTX models.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user