Mirror of https://github.com/comfyanonymous/ComfyUI.git
Support fp16 for Cosmos-Predict2 and Anima (#12249)
This commit is contained in:
parent 204e65b8dc
commit 6a26328842
@@ -335,7 +335,7 @@ class FinalLayer(nn.Module):
         device=None, dtype=None, operations=None
     ):
         super().__init__()
-        self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+        self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.linear = operations.Linear(
             hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype
         )
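The first hunk routes the final LayerNorm through the model's `operations` factory instead of `torch.nn`, so it is created the same way as the other layers when weights are loaded in reduced precision. A minimal sketch of the kind of factory this implies (illustrative only, not ComfyUI's actual `comfy.ops` implementation; `CastLayerNorm` and the `operations` namespace are hypothetical):

```python
import torch
import torch.nn as nn

class CastLayerNorm(nn.LayerNorm):
    """Hypothetical stand-in for operations.LayerNorm: run the norm in the
    input's dtype, casting any affine parameters to match at call time."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.to(x.dtype) if self.weight is not None else None
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)

class operations:  # hypothetical namespace matching the `operations=...` argument
    LayerNorm = CastLayerNorm
    Linear = nn.Linear

# The module code stays dtype-agnostic; with elementwise_affine=False there are
# no parameters at all, only the normalization math.
norm = operations.LayerNorm(64, elementwise_affine=False, eps=1e-6)
print(norm(torch.randn(2, 64)).dtype)  # torch.float32
```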
@@ -463,6 +463,8 @@ class Block(nn.Module):
         extra_per_block_pos_emb: Optional[torch.Tensor] = None,
         transformer_options: Optional[dict] = {},
     ) -> torch.Tensor:
+        residual_dtype = x_B_T_H_W_D.dtype
+        compute_dtype = emb_B_T_D.dtype
         if extra_per_block_pos_emb is not None:
             x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
 
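The two added lines record the dtype of the incoming residual stream (`x_B_T_H_W_D`, fp32 after the upcast added further down) and the compute dtype taken from the modulation embedding (`emb_B_T_D`, fp16 when the weights are fp16). The hunks that follow cast activations down to `compute_dtype` for attention/MLP and cast the gated results back to `residual_dtype` before each residual add. A self-contained sketch of that pattern with a toy MLP and a hypothetical `gate` tensor standing in for the adaLN gates (bfloat16 is used only so the example runs on CPU; the diff targets fp16 weights):

```python
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Toy block mirroring the diff: fp32 residual stream, low-precision compute."""
    def __init__(self, dim: int, compute_dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.mlp.to(compute_dtype)

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        residual_dtype = x.dtype                           # fp32 residual stream
        compute_dtype = next(self.mlp.parameters()).dtype  # dtype of the weights
        result = self.mlp(x.to(compute_dtype))             # low-precision compute
        # Cast gate and result back before they touch the fp32 residual,
        # as in gate_*_B_T_1_1_D.to(residual_dtype) * result.to(residual_dtype).
        return x + gate.to(residual_dtype) * result.to(residual_dtype)

block = ToyBlock(8)
x = torch.randn(2, 8)                               # residual stays float32 end to end
gate = torch.ones(2, 1, dtype=torch.bfloat16)
print(block(x, gate).dtype)                         # torch.float32
```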
@@ -512,7 +514,7 @@ class Block(nn.Module):
         result_B_T_H_W_D = rearrange(
             self.self_attn(
                 # normalized_x_B_T_HW_D,
-                rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+                rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                 None,
                 rope_emb=rope_emb_L_1_1_D,
                 transformer_options=transformer_options,
@@ -522,7 +524,7 @@ class Block(nn.Module):
             h=H,
             w=W,
         )
-        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D
+        x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
 
         def _x_fn(
             _x_B_T_H_W_D: torch.Tensor,
@@ -536,7 +538,7 @@ class Block(nn.Module):
             )
             _result_B_T_H_W_D = rearrange(
                 self.cross_attn(
-                    rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
+                    rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"),
                     crossattn_emb,
                     rope_emb=rope_emb_L_1_1_D,
                     transformer_options=transformer_options,
@@ -555,7 +557,7 @@ class Block(nn.Module):
             shift_cross_attn_B_T_1_1_D,
             transformer_options=transformer_options,
         )
-        x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
+        x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
 
         normalized_x_B_T_H_W_D = _fn(
             x_B_T_H_W_D,
@@ -563,8 +565,8 @@ class Block(nn.Module):
             scale_mlp_B_T_1_1_D,
             shift_mlp_B_T_1_1_D,
         )
-        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D)
-        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D
+        result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
+        x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
         return x_B_T_H_W_D
 
 
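The motivation for keeping the residual in fp32 (spelled out in the comment added in the next hunk) is that this model's residual stream carries values near or beyond fp16's range, which tops out at 65504. A short numeric illustration; the magnitudes below are made up, the overflow and the clamping loss are not:

```python
import torch

print(torch.finfo(torch.float16).max)        # 65504.0

# Hypothetical residual-scale activations; real values come from the model.
residual = torch.tensor([40000.0, 50000.0])
update = torch.tensor([30000.0, 30000.0])

print(residual.half() + update.half())       # tensor([inf, inf], dtype=torch.float16)
print(residual + update)                     # tensor([70000., 80000.]) -- fine in fp32

# The clamping alternative mentioned in the commit comment keeps values finite
# but flattens distinct activations to the same number.
fp16_max = torch.finfo(torch.float16).max
print((residual + update).clamp(max=fp16_max).half())  # tensor([65504., 65504.], dtype=torch.float16)
```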
@@ -876,6 +878,14 @@ class MiniTrainDIT(nn.Module):
             "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
             "transformer_options": kwargs.get("transformer_options", {}),
         }
+
+        # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream
+        # in fp32, but run attention and MLP modules in fp16.
+        # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable
+        # quality degradation and visual artifacts.
+        if x_B_T_H_W_D.dtype == torch.float16:
+            x_B_T_H_W_D = x_B_T_H_W_D.float()
+
         for block in self.blocks:
             x_B_T_H_W_D = block(
                 x_B_T_H_W_D,
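At the model level, an fp16 input latent is promoted to fp32 before the block loop, so the residual stream threading through every `Block` accumulates in fp32 while the modules themselves still run in fp16. A compressed, self-contained sketch of that forward path (names are illustrative; bfloat16 stands in for fp16 so the sketch runs on CPU):

```python
import torch
import torch.nn as nn

class ToyDiT(nn.Module):
    """Sketch of the forward-path change: low-precision latents are promoted to
    fp32 before the block loop; each block computes in the weight dtype."""
    def __init__(self, dim: int = 8, depth: int = 2,
                 compute_dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.Linear(dim, dim).to(compute_dtype) for _ in range(depth)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dtype == torch.float16:   # mirror of the added upcast
            x = x.float()
        for block in self.blocks:
            compute_dtype = block.weight.dtype
            x = x + block(x.to(compute_dtype)).to(x.dtype)  # fp32 residual add
        return x

model = ToyDiT()
latent = torch.randn(1, 8).half()     # fp16 latent coming in
print(model(latent).dtype)            # torch.float32
```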
@@ -993,7 +993,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
 
     memory_usage_factor = 1.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def __init__(self, unet_config):
         super().__init__(unet_config)
@@ -1023,7 +1023,7 @@ class Anima(supported_models_base.BASE):
 
     memory_usage_factor = 1.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
 
     def __init__(self, unet_config):
         super().__init__(unet_config)
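The last two hunks add `torch.float16` to the ordered dtype preference lists of the `CosmosT2IPredict2` and `Anima` configs, which allows fp16 to be selected for these models. A hedged sketch of how such a preference list might be consumed; this is an illustration of the idea, not ComfyUI's actual `model_management` logic, and `pick_inference_dtype` is a hypothetical helper:

```python
import torch

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

def pick_inference_dtype(preferences: list, device: torch.device) -> torch.dtype:
    """Return the first preferred dtype the device can plausibly run (illustrative)."""
    for dtype in preferences:
        if dtype == torch.float32:
            return dtype  # always available
        if device.type == "cuda":
            if dtype == torch.bfloat16 and torch.cuda.is_bf16_supported():
                return dtype
            if dtype == torch.float16:
                return dtype  # assume fp16 kernels exist on CUDA
    return torch.float32

print(pick_inference_dtype(supported_inference_dtypes, torch.device("cpu")))
# -> torch.float32 on CPU; bf16 or fp16 on a capable GPU
```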