Merge branch 'master' into dr-support-pip-cm

This commit is contained in:
Dr.Lt.Data 2025-09-13 07:30:55 +09:00
commit 9d70d75f20
44 changed files with 1625 additions and 977 deletions

View File

@ -606,6 +606,11 @@ class HunyuanImage21(LatentFormat):
latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206]
class HunyuanImage21Refiner(LatentFormat):
latent_channels = 64
latent_dimensions = 3
scale_factor = 1.03682
class Hunyuan3Dv2(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -133,6 +133,7 @@ class Attention(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
transformer_options={},
**cross_attention_kwargs,
) -> torch.Tensor:
return self.processor(
@ -140,6 +141,7 @@ class Attention(nn.Module):
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
transformer_options=transformer_options,
**cross_attention_kwargs,
)
@ -366,6 +368,7 @@ class CustomerAttnProcessor2_0:
encoder_attention_mask: Optional[torch.FloatTensor] = None,
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
transformer_options={},
*args,
**kwargs,
) -> torch.Tensor:
@ -433,7 +436,7 @@ class CustomerAttnProcessor2_0:
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = optimized_attention(
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True, transformer_options=transformer_options,
).to(query.dtype)
# linear proj
@ -697,6 +700,7 @@ class LinearTransformerBlock(nn.Module):
rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
temb: torch.FloatTensor = None,
transformer_options={},
):
N = hidden_states.shape[0]
@ -720,6 +724,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
else:
attn_output, _ = self.attn(
@ -729,6 +734,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=None,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=None,
transformer_options=transformer_options,
)
if self.use_adaln_single:
@ -743,6 +749,7 @@ class LinearTransformerBlock(nn.Module):
encoder_attention_mask=encoder_attention_mask,
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=rotary_freqs_cis_cross,
transformer_options=transformer_options,
)
hidden_states = attn_output + hidden_states

View File

@ -314,6 +314,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length: int = 0,
block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
controlnet_scale: Union[float, torch.Tensor] = 1.0,
transformer_options={},
):
embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
temb = self.t_block(embedded_timestep)
@ -339,6 +340,7 @@ class ACEStepTransformer2DModel(nn.Module):
rotary_freqs_cis=rotary_freqs_cis,
rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
temb=temb,
transformer_options=transformer_options,
)
output = self.final_layer(hidden_states, embedded_timestep, output_length)
@ -393,6 +395,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length = hidden_states.shape[-1]
transformer_options = kwargs.get("transformer_options", {})
output = self.decode(
hidden_states=hidden_states,
attention_mask=attention_mask,
@ -402,6 +405,7 @@ class ACEStepTransformer2DModel(nn.Module):
output_length=output_length,
block_controlnet_hidden_states=block_controlnet_hidden_states,
controlnet_scale=controlnet_scale,
transformer_options=transformer_options,
)
return output

View File

@ -298,7 +298,8 @@ class Attention(nn.Module):
mask = None,
context_mask = None,
rotary_pos_emb = None,
causal = None
causal = None,
transformer_options={},
):
h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
@ -363,7 +364,7 @@ class Attention(nn.Module):
heads_per_kv_head = h // kv_h
k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
out = optimized_attention(q, k, v, h, skip_reshape=True)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = self.to_out(out)
if mask is not None:
@ -488,7 +489,8 @@ class TransformerBlock(nn.Module):
global_cond=None,
mask = None,
context_mask = None,
rotary_pos_emb = None
rotary_pos_emb = None,
transformer_options={}
):
if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
@ -498,12 +500,12 @@ class TransformerBlock(nn.Module):
residual = x
x = self.pre_norm(x)
x = x * (1 + scale_self) + shift_self
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
x = x * torch.sigmoid(1 - gate_self)
x = x + residual
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@ -517,10 +519,10 @@ class TransformerBlock(nn.Module):
x = x + residual
else:
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb, transformer_options=transformer_options)
if context is not None:
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask, transformer_options=transformer_options)
if self.conformer is not None:
x = x + self.conformer(x)
@ -606,7 +608,8 @@ class ContinuousTransformer(nn.Module):
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
transformer_options = kwargs.get("transformer_options", {})
patches_replace = transformer_options.get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]
@ -645,13 +648,13 @@ class ContinuousTransformer(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
if return_info:

View File

@ -85,7 +85,7 @@ class SingleAttention(nn.Module):
)
#@torch.compile()
def forward(self, c):
def forward(self, c, transformer_options={}):
bsz, seqlen1, _ = c.shape
@ -95,7 +95,7 @@ class SingleAttention(nn.Module):
v = v.view(bsz, seqlen1, self.n_heads, self.head_dim)
q, k = self.q_norm1(q), self.k_norm1(k)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c = self.w1o(output)
return c
@ -144,7 +144,7 @@ class DoubleAttention(nn.Module):
#@torch.compile()
def forward(self, c, x):
def forward(self, c, x, transformer_options={}):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
@ -168,7 +168,7 @@ class DoubleAttention(nn.Module):
torch.cat([cv, xv], dim=1),
)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True)
output = optimized_attention(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3), self.n_heads, skip_reshape=True, transformer_options=transformer_options)
c, x = output.split([seqlen1, seqlen2], dim=1)
c = self.w1o(c)
@ -207,7 +207,7 @@ class MMDiTBlock(nn.Module):
self.is_last = is_last
#@torch.compile()
def forward(self, c, x, global_cond, **kwargs):
def forward(self, c, x, global_cond, transformer_options={}, **kwargs):
cres, xres = c, x
@ -225,7 +225,7 @@ class MMDiTBlock(nn.Module):
x = modulate(self.normX1(x), xshift_msa, xscale_msa)
# attention
c, x = self.attn(c, x)
c, x = self.attn(c, x, transformer_options=transformer_options)
c = self.normC2(cres + cgate_msa.unsqueeze(1) * c)
@ -255,13 +255,13 @@ class DiTBlock(nn.Module):
self.mlp = MLP(dim, hidden_dim=dim * 4, dtype=dtype, device=device, operations=operations)
#@torch.compile()
def forward(self, cx, global_cond, **kwargs):
def forward(self, cx, global_cond, transformer_options={}, **kwargs):
cxres = cx
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.modCX(
global_cond
).chunk(6, dim=1)
cx = modulate(self.norm1(cx), shift_msa, scale_msa)
cx = self.attn(cx)
cx = self.attn(cx, transformer_options=transformer_options)
cx = self.norm2(cxres + gate_msa.unsqueeze(1) * cx)
mlpout = self.mlp(modulate(cx, shift_mlp, scale_mlp))
cx = gate_mlp.unsqueeze(1) * mlpout
@ -473,13 +473,14 @@ class MMDiT(nn.Module):
out = {}
out["txt"], out["img"] = layer(args["txt"],
args["img"],
args["vec"])
args["vec"],
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": c, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
c = out["txt"]
x = out["img"]
else:
c, x = layer(c, x, global_cond, **kwargs)
c, x = layer(c, x, global_cond, transformer_options=transformer_options, **kwargs)
if len(self.single_layers) > 0:
c_len = c.size(1)
@ -488,13 +489,13 @@ class MMDiT(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], args["vec"])
out["img"] = layer(args["img"], args["vec"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": cx, "vec": global_cond, "transformer_options": transformer_options}, {"original_block": block_wrap})
cx = out["img"]
else:
cx = layer(cx, global_cond, **kwargs)
cx = layer(cx, global_cond, transformer_options=transformer_options, **kwargs)
x = cx[:, c_len:]

View File

@ -32,12 +32,12 @@ class OptimizedAttention(nn.Module):
self.out_proj = operations.Linear(c, c, bias=True, dtype=dtype, device=device)
def forward(self, q, k, v):
def forward(self, q, k, v, transformer_options={}):
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
out = optimized_attention(q, k, v, self.heads)
out = optimized_attention(q, k, v, self.heads, transformer_options=transformer_options)
return self.out_proj(out)
@ -47,13 +47,13 @@ class Attention2D(nn.Module):
self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)
# self.attn = nn.MultiheadAttention(c, nhead, dropout=dropout, bias=True, batch_first=True, dtype=dtype, device=device)
def forward(self, x, kv, self_attn=False):
def forward(self, x, kv, self_attn=False, transformer_options={}):
orig_shape = x.shape
x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1) # Bx4xHxW -> Bx(HxW)x4
if self_attn:
kv = torch.cat([x, kv], dim=1)
# x = self.attn(x, kv, kv, need_weights=False)[0]
x = self.attn(x, kv, kv)
x = self.attn(x, kv, kv, transformer_options=transformer_options)
x = x.permute(0, 2, 1).view(*orig_shape)
return x
@ -114,9 +114,9 @@ class AttnBlock(nn.Module):
operations.Linear(c_cond, c, dtype=dtype, device=device)
)
def forward(self, x, kv):
def forward(self, x, kv, transformer_options={}):
kv = self.kv_mapper(kv)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn)
x = x + self.attention(self.norm(x), kv, self_attn=self.self_attn, transformer_options=transformer_options)
return x

View File

@ -173,7 +173,7 @@ class StageB(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip):
def _down_encode(self, x, r_embed, clip, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@ -187,7 +187,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -199,7 +199,7 @@ class StageB(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip):
def _up_decode(self, level_outputs, r_embed, clip, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@ -216,7 +216,7 @@ class StageB(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -228,7 +228,7 @@ class StageB(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, effnet, clip, pixels=None, **kwargs):
def forward(self, x, r, effnet, clip, pixels=None, transformer_options={}, **kwargs):
if pixels is None:
pixels = x.new_zeros(x.size(0), 3, 8, 8)
@ -245,8 +245,8 @@ class StageB(nn.Module):
nn.functional.interpolate(effnet, size=x.shape[-2:], mode='bilinear', align_corners=True))
x = x + nn.functional.interpolate(self.pixels_mapper(pixels), size=x.shape[-2:], mode='bilinear',
align_corners=True)
level_outputs = self._down_encode(x, r_embed, clip)
x = self._up_decode(level_outputs, r_embed, clip)
level_outputs = self._down_encode(x, r_embed, clip, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@ -182,7 +182,7 @@ class StageC(nn.Module):
clip = self.clip_norm(clip)
return clip
def _down_encode(self, x, r_embed, clip, cnet=None):
def _down_encode(self, x, r_embed, clip, cnet=None, transformer_options={}):
level_outputs = []
block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
for down_block, downscaler, repmap in block_group:
@ -201,7 +201,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -213,7 +213,7 @@ class StageC(nn.Module):
level_outputs.insert(0, x)
return level_outputs
def _up_decode(self, level_outputs, r_embed, clip, cnet=None):
def _up_decode(self, level_outputs, r_embed, clip, cnet=None, transformer_options={}):
x = level_outputs[0]
block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
for i, (up_block, upscaler, repmap) in enumerate(block_group):
@ -235,7 +235,7 @@ class StageC(nn.Module):
elif isinstance(block, AttnBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
AttnBlock)):
x = block(x, clip)
x = block(x, clip, transformer_options=transformer_options)
elif isinstance(block, TimestepBlock) or (
hasattr(block, '_fsdp_wrapped_module') and isinstance(block._fsdp_wrapped_module,
TimestepBlock)):
@ -247,7 +247,7 @@ class StageC(nn.Module):
x = upscaler(x)
return x
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, **kwargs):
def forward(self, x, r, clip_text, clip_text_pooled, clip_img, control=None, transformer_options={}, **kwargs):
# Process the conditioning embeddings
r_embed = self.gen_r_embedding(r).to(dtype=x.dtype)
for c in self.t_conds:
@ -262,8 +262,8 @@ class StageC(nn.Module):
# Model Blocks
x = self.embedding(x)
level_outputs = self._down_encode(x, r_embed, clip, cnet)
x = self._up_decode(level_outputs, r_embed, clip, cnet)
level_outputs = self._down_encode(x, r_embed, clip, cnet, transformer_options=transformer_options)
x = self._up_decode(level_outputs, r_embed, clip, cnet, transformer_options=transformer_options)
return self.clf(x)
def update_weights_ema(self, src_model, beta=0.999):

View File

@ -76,7 +76,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
@ -95,7 +95,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
@ -148,7 +148,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
mod = vec
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -157,7 +157,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x.addcmul_(mod.gate, output)

View File

@ -193,14 +193,16 @@ class Chroma(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": double_mod,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -209,7 +211,8 @@ class Chroma(nn.Module):
txt=txt,
vec=double_mod,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -229,17 +232,19 @@ class Chroma(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": single_mod,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")

View File

@ -176,6 +176,7 @@ class Attention(nn.Module):
context=None,
mask=None,
rope_emb=None,
transformer_options={},
**kwargs,
):
"""
@ -184,7 +185,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True, transformer_options=transformer_options)
del q, k, v
out = rearrange(out, " b n s c -> s b (n c)")
return self.to_out(out)
@ -546,6 +547,7 @@ class VideoAttn(nn.Module):
context: Optional[torch.Tensor] = None,
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for video attention.
@ -571,6 +573,7 @@ class VideoAttn(nn.Module):
context_M_B_D,
crossattn_mask,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
x_T_H_W_B_D = rearrange(x_THW_B_D, "(t h w) b d -> t h w b d", h=H, w=W)
return x_T_H_W_B_D
@ -665,6 +668,7 @@ class DITBuildingBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Forward pass for dynamically configured blocks with adaptive normalization.
@ -702,6 +706,7 @@ class DITBuildingBlock(nn.Module):
adaln_norm_state(self.norm_state, x, scale_1_1_1_B_D, shift_1_1_1_B_D),
context=None,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
elif self.block_type in ["cross_attn", "ca"]:
x = x + gate_1_1_1_B_D * self.block(
@ -709,6 +714,7 @@ class DITBuildingBlock(nn.Module):
context=crossattn_emb,
crossattn_mask=crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
transformer_options=transformer_options,
)
else:
raise ValueError(f"Unknown block type: {self.block_type}")
@ -784,6 +790,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
for block in self.blocks:
x = block(
@ -793,5 +800,6 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
return x

View File

@ -520,6 +520,7 @@ class GeneralDIT(nn.Module):
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
), f"{x.shape} != {extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape} {original_shape}"
transformer_options = kwargs.get("transformer_options", {})
for _, block in self.blocks.items():
assert (
self.blocks["block0"].x_format == block.x_format
@ -534,6 +535,7 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")

View File

@ -44,7 +44,7 @@ class GPT2FeedForward(nn.Module):
return x
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor) -> torch.Tensor:
def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H_D: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
"""Computes multi-head attention using PyTorch's native implementation.
This function provides a PyTorch backend alternative to Transformer Engine's attention operation.
@ -71,7 +71,7 @@ def torch_attention_op(q_B_S_H_D: torch.Tensor, k_B_S_H_D: torch.Tensor, v_B_S_H
q_B_H_S_D = rearrange(q_B_S_H_D, "b ... h k -> b h ... k").view(in_q_shape[0], in_q_shape[-2], -1, in_q_shape[-1])
k_B_H_S_D = rearrange(k_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
v_B_H_S_D = rearrange(v_B_S_H_D, "b ... h v -> b h ... v").view(in_k_shape[0], in_k_shape[-2], -1, in_k_shape[-1])
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True)
return optimized_attention(q_B_H_S_D, k_B_H_S_D, v_B_H_S_D, in_q_shape[-2], skip_reshape=True, transformer_options=transformer_options)
class Attention(nn.Module):
@ -180,8 +180,8 @@ class Attention(nn.Module):
return q, k, v
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
result = self.attn_op(q, k, v) # [B, S, H, D]
def compute_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, transformer_options: Optional[dict] = {}) -> torch.Tensor:
result = self.attn_op(q, k, v, transformer_options=transformer_options) # [B, S, H, D]
return self.output_dropout(self.output_proj(result))
def forward(
@ -189,6 +189,7 @@ class Attention(nn.Module):
x: torch.Tensor,
context: Optional[torch.Tensor] = None,
rope_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
"""
Args:
@ -196,7 +197,7 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
return self.compute_attention(q, k, v)
return self.compute_attention(q, k, v, transformer_options=transformer_options)
class Timesteps(nn.Module):
@ -459,6 +460,7 @@ class Block(nn.Module):
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_T_3D: Optional[torch.Tensor] = None,
extra_per_block_pos_emb: Optional[torch.Tensor] = None,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
if extra_per_block_pos_emb is not None:
x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb
@ -512,6 +514,7 @@ class Block(nn.Module):
rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
None,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@ -525,6 +528,7 @@ class Block(nn.Module):
layer_norm_cross_attn: Callable,
_scale_cross_attn_B_T_1_1_D: torch.Tensor,
_shift_cross_attn_B_T_1_1_D: torch.Tensor,
transformer_options: Optional[dict] = {},
) -> torch.Tensor:
_normalized_x_B_T_H_W_D = _fn(
_x_B_T_H_W_D, layer_norm_cross_attn, _scale_cross_attn_B_T_1_1_D, _shift_cross_attn_B_T_1_1_D
@ -534,6 +538,7 @@ class Block(nn.Module):
rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"),
crossattn_emb,
rope_emb=rope_emb_L_1_1_D,
transformer_options=transformer_options,
),
"b (t h w) d -> b t h w d",
t=T,
@ -547,6 +552,7 @@ class Block(nn.Module):
self.layer_norm_cross_attn,
scale_cross_attn_B_T_1_1_D,
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D
@ -865,6 +871,7 @@ class MiniTrainDIT(nn.Module):
"rope_emb_L_1_1_D": rope_emb_L_1_1_D.unsqueeze(1).unsqueeze(0),
"adaln_lora_B_T_3D": adaln_lora_B_T_3D,
"extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
"transformer_options": kwargs.get("transformer_options", {}),
}
for block in self.blocks:
x_B_T_H_W_D = block(

View File

@ -159,7 +159,7 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
@ -182,7 +182,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((img_q, txt_q), dim=2),
torch.cat((img_k, txt_k), dim=2),
torch.cat((img_v, txt_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
else:
@ -190,7 +190,7 @@ class DoubleStreamBlock(nn.Module):
attn = attention(torch.cat((txt_q, img_q), dim=2),
torch.cat((txt_k, img_k), dim=2),
torch.cat((txt_v, img_v), dim=2),
pe=pe, mask=attn_mask)
pe=pe, mask=attn_mask, transformer_options=transformer_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
@ -244,7 +244,7 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
mod, _ = self.modulation(vec)
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@ -252,7 +252,7 @@ class SingleStreamBlock(nn.Module):
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask)
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)

View File

@ -6,7 +6,7 @@ from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape
@ -17,7 +17,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x

View File

@ -144,14 +144,16 @@ class Flux(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -160,7 +162,8 @@ class Flux(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -181,17 +184,19 @@ class Flux(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options"))
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")

View File

@ -109,6 +109,7 @@ class AsymmetricAttention(nn.Module):
scale_x: torch.Tensor, # (B, dim_x), modulation for pre-RMSNorm.
scale_y: torch.Tensor, # (B, dim_y), modulation for pre-RMSNorm.
crop_y,
transformer_options={},
**rope_rotation,
) -> Tuple[torch.Tensor, torch.Tensor]:
rope_cos = rope_rotation.get("rope_cos")
@ -143,7 +144,7 @@ class AsymmetricAttention(nn.Module):
xy = optimized_attention(q,
k,
v, self.num_heads, skip_reshape=True)
v, self.num_heads, skip_reshape=True, transformer_options=transformer_options)
x, y = torch.tensor_split(xy, (q_x.shape[1],), dim=1)
x = self.proj_x(x)
@ -224,6 +225,7 @@ class AsymmetricJointBlock(nn.Module):
x: torch.Tensor,
c: torch.Tensor,
y: torch.Tensor,
transformer_options={},
**attn_kwargs,
):
"""Forward pass of a block.
@ -256,6 +258,7 @@ class AsymmetricJointBlock(nn.Module):
y,
scale_x=scale_msa_x,
scale_y=scale_msa_y,
transformer_options=transformer_options,
**attn_kwargs,
)
@ -524,10 +527,11 @@ class AsymmDiTJoint(nn.Module):
args["txt"],
rope_cos=args["rope_cos"],
rope_sin=args["rope_sin"],
crop_y=args["num_tokens"]
crop_y=args["num_tokens"],
transformer_options=args["transformer_options"]
)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": y_feat, "vec": c, "rope_cos": rope_cos, "rope_sin": rope_sin, "num_tokens": num_tokens, "transformer_options": transformer_options}, {"original_block": block_wrap})
y_feat = out["txt"]
x = out["img"]
else:
@ -538,6 +542,7 @@ class AsymmDiTJoint(nn.Module):
rope_cos=rope_cos,
rope_sin=rope_sin,
crop_y=num_tokens,
transformer_options=transformer_options,
) # (B, M, D), (B, L, D)
del y_feat # Final layers don't use dense text features.

View File

@ -72,8 +72,8 @@ class TimestepEmbed(nn.Module):
return t_emb
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, transformer_options={}):
return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2], transformer_options=transformer_options)
class HiDreamAttnProcessor_flashattn:
@ -86,6 +86,7 @@ class HiDreamAttnProcessor_flashattn:
image_tokens_masks: Optional[torch.FloatTensor] = None,
text_tokens: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
*args,
**kwargs,
) -> torch.FloatTensor:
@ -133,7 +134,7 @@ class HiDreamAttnProcessor_flashattn:
query = torch.cat([query_1, query_2], dim=-1)
key = torch.cat([key_1, key_2], dim=-1)
hidden_states = attention(query, key, value)
hidden_states = attention(query, key, value, transformer_options=transformer_options)
if not attn.single:
hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
@ -199,6 +200,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks: torch.FloatTensor = None,
norm_text_tokens: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.Tensor:
return self.processor(
self,
@ -206,6 +208,7 @@ class HiDreamAttention(nn.Module):
image_tokens_masks = image_tokens_masks,
text_tokens = norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
@ -406,7 +409,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
@ -419,6 +422,7 @@ class HiDreamImageSingleTransformerBlock(nn.Module):
norm_image_tokens,
image_tokens_masks,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@ -483,6 +487,7 @@ class HiDreamImageTransformerBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: Optional[torch.FloatTensor] = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
wtype = image_tokens.dtype
shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
@ -500,6 +505,7 @@ class HiDreamImageTransformerBlock(nn.Module):
image_tokens_masks,
norm_text_tokens,
rope = rope,
transformer_options=transformer_options,
)
image_tokens = gate_msa_i * attn_output_i + image_tokens
@ -550,6 +556,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens: Optional[torch.FloatTensor] = None,
adaln_input: torch.FloatTensor = None,
rope: torch.FloatTensor = None,
transformer_options={},
) -> torch.FloatTensor:
return self.block(
image_tokens,
@ -557,6 +564,7 @@ class HiDreamImageBlock(nn.Module):
text_tokens,
adaln_input,
rope,
transformer_options=transformer_options,
)
@ -786,6 +794,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens = cur_encoder_hidden_states,
adaln_input = adaln_input,
rope = rope,
transformer_options=transformer_options,
)
initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
block_id += 1
@ -809,6 +818,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
text_tokens=None,
adaln_input=adaln_input,
rope=rope,
transformer_options=transformer_options,
)
hidden_states = hidden_states[:, :hidden_states_seq_len]
block_id += 1

View File

@ -99,14 +99,16 @@ class Hunyuan3Dv2(nn.Module):
txt=args["txt"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img,
"txt": txt,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
@ -115,7 +117,8 @@ class Hunyuan3Dv2(nn.Module):
txt=txt,
vec=vec,
pe=pe,
attn_mask=attn_mask)
attn_mask=attn_mask,
transformer_options=transformer_options)
img = torch.cat((txt, img), 1)
@ -126,17 +129,19 @@ class Hunyuan3Dv2(nn.Module):
out["img"] = block(args["img"],
vec=args["vec"],
pe=args["pe"],
attn_mask=args.get("attn_mask"))
attn_mask=args.get("attn_mask"),
transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img,
"vec": vec,
"pe": pe,
"attn_mask": attn_mask},
"attn_mask": attn_mask,
"transformer_options": transformer_options},
{"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
img = img[:, txt.shape[1]:, ...]
img = self.final_layer(img, vec)

View File

@ -80,13 +80,13 @@ class TokenRefinerBlock(nn.Module):
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, x, c, mask):
def forward(self, x, c, mask, transformer_options={}):
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn.qkv(norm_x)
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True, transformer_options=transformer_options)
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
@ -117,14 +117,14 @@ class IndividualTokenRefiner(nn.Module):
]
)
def forward(self, x, c, mask):
def forward(self, x, c, mask, transformer_options={}):
m = None
if mask is not None:
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
m = m + m.transpose(2, 3)
for block in self.blocks:
x = block(x, c, m)
x = block(x, c, m, transformer_options=transformer_options)
return x
@ -152,6 +152,7 @@ class TokenRefiner(nn.Module):
x,
timesteps,
mask,
transformer_options={},
):
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
# m = mask.float().unsqueeze(-1)
@ -160,7 +161,7 @@ class TokenRefiner(nn.Module):
c = t + self.c_embedder(c.to(x.dtype))
x = self.input_embedder(x)
x = self.individual_token_refiner(x, c, mask)
x = self.individual_token_refiner(x, c, mask, transformer_options=transformer_options)
return x
@ -278,6 +279,7 @@ class HunyuanVideo(nn.Module):
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
disable_time_r=False,
control=None,
transformer_options={},
) -> Tensor:
@ -288,7 +290,7 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if self.time_r_in is not None:
if (self.time_r_in is not None) and (not disable_time_r):
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
@ -327,7 +329,7 @@ class HunyuanVideo(nn.Module):
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
txt = self.txt_in(txt, timesteps, txt_mask)
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
if self.byt5_in is not None and txt_byt5 is not None:
txt_byt5 = self.byt5_in(txt_byt5)
@ -351,14 +353,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt, 'transformer_options': transformer_options}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt, transformer_options=transformer_options)
if control is not None: # Controlnet
control_i = control.get("input")
@ -373,13 +375,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims, 'transformer_options': transformer_options}, {"original_block": block_wrap})
img = out["img"]
else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims, transformer_options=transformer_options)
if control is not None: # Controlnet
control_o = control.get("output")
@ -428,14 +430,14 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, control, transformer_options, **kwargs)
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
bs = x.shape[0]
if len(self.patch_size) == 3:
img_ids = self.img_ids(x)
@ -443,5 +445,5 @@ class HunyuanVideo(nn.Module):
else:
img_ids = self.img_ids_2d(x)
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
return out

View File

@ -0,0 +1,268 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
import comfy.ops
import comfy.ldm.models.autoencoder
ops = comfy.ops.disable_weight_init
class RMS_norm(nn.Module):
def __init__(self, dim):
super().__init__()
shape = (dim, 1, 1, 1)
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.empty(shape))
def forward(self, x):
return F.normalize(x, dim=1) * self.scale * self.gamma
class DnSmpl(nn.Module):
def __init__(self, ic, oc, tds=True):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
self.conv = VideoConv3d(ic, oc // fct, kernel_size=3)
self.tds = tds
self.gs = fct * ic // oc
def forward(self, x):
r1 = 2 if self.tds else 1
h = self.conv(x)
if self.tds:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
hf = hf.permute(0, 4, 6, 1, 2, 3, 5)
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
hf = torch.cat([hf, hf], dim=1)
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nf = frms // r1
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
h = torch.cat([hf, hn], dim=2)
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.reshape(b, ci, f, ht // 2, 2, wd // 2, 2)
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
B, C, T, H, W = xf.shape
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
xn = x[:, :, 1:, :, :]
b, ci, frms, ht, wd = xn.shape
nf = frms // r1
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = xn.shape
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nf = frms // r1
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
b, ci, frms, ht, wd = x.shape
nf = frms // r1
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
B, C, T, H, W = sc.shape
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
return h + sc
class UpSmpl(nn.Module):
def __init__(self, ic, oc, tus=True):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = VideoConv3d(ic, oc * fct, kernel_size=3)
self.tus = tus
self.rp = fct * oc // ic
def forward(self, x):
r1 = 2 if self.tus else 1
h = self.conv(x)
if self.tus:
hf = h[:, :, :1, :, :]
b, c, f, ht, wd = hf.shape
nc = c // (2 * 2)
hf = hf.reshape(b, 2, 2, nc, f, ht, wd)
hf = hf.permute(0, 3, 4, 5, 1, 6, 2)
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
hf = hf[:, : hf.shape[1] // 2]
hn = h[:, :, 1:, :, :]
b, c, frms, ht, wd = hn.shape
nc = c // (r1 * 2 * 2)
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
h = torch.cat([hf, hn], dim=2)
xf = x[:, :, :1, :, :]
b, ci, f, ht, wd = xf.shape
xf = xf.repeat_interleave(repeats=self.rp // 2, dim=1)
b, c, f, ht, wd = xf.shape
nc = c // (2 * 2)
xf = xf.reshape(b, 2, 2, nc, f, ht, wd)
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
xn = x[:, :, 1:, :, :]
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = xn.shape
nc = c // (r1 * 2 * 2)
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = torch.cat([xf, xn], dim=2)
else:
b, c, frms, ht, wd = h.shape
nc = c // (r1 * 2 * 2)
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
sc = x.repeat_interleave(repeats=self.rp, dim=1)
b, c, frms, ht, wd = sc.shape
nc = c // (r1 * 2 * 2)
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
return h + sc
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
super().__init__()
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
self.conv_in = VideoConv3d(in_channels, block_out_channels[0], 3, 1, 1)
self.down = nn.ModuleList()
ch = block_out_channels[0]
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = ((ffactor_spatial // ffactor_temporal) >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
stage.downsample = DnSmpl(ch, nxt, tds=i >= depth_temporal)
ch = nxt
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)
self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
def forward(self, x):
x = x.unsqueeze(2)
x = self.conv_in(x)
for stage in self.down:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'downsample'):
x = stage.downsample(x)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
b, c, t, h, w = x.shape
grp = c // (self.z_channels << 1)
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
out = self.conv_out(F.silu(self.norm_out(x))) + skip
out = self.regul(out)[0]
out = torch.cat((out[:, :, :1], out), dim=2)
out = out.permute(0, 2, 1, 3, 4)
b, f_times_2, c, h, w = out.shape
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
out = out.permute(0, 2, 1, 3, 4).contiguous()
return out
class Decoder(nn.Module):
def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
super().__init__()
block_out_channels = block_out_channels[::-1]
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
ch = block_out_channels[0]
self.conv_in = VideoConv3d(z_channels, ch, 3)
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=RMS_norm)
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=VideoConv3d, norm_op=RMS_norm)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
depth_temporal = (ffactor_temporal >> 1).bit_length()
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
out_channels=tgt,
temb_channels=0,
conv_op=VideoConv3d, norm_op=RMS_norm)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
stage.upsample = UpSmpl(ch, nxt, tus=i < depth_temporal)
ch = nxt
self.up.append(stage)
self.norm_out = RMS_norm(ch)
self.conv_out = VideoConv3d(ch, out_channels, 3)
def forward(self, z):
z = z.permute(0, 2, 1, 3, 4)
b, f, c, h, w = z.shape
z = z.reshape(b, f, 2, c // 2, h, w)
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
z = z.permute(0, 2, 1, 3, 4)
z = z[:, :, 1:]
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
for stage in self.up:
for blk in stage.block:
x = blk(x)
if hasattr(stage, 'upsample'):
x = stage.upsample(x)
return self.conv_out(F.silu(self.norm_out(x)))

View File

@ -271,7 +271,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@ -285,9 +285,9 @@ class CrossAttention(nn.Module):
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@ -303,12 +303,12 @@ class BasicTransformerBlock(nn.Module):
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
@ -479,10 +479,10 @@ class LTXVModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@ -490,7 +490,8 @@ class LTXVModel(torch.nn.Module):
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
pe=pe,
transformer_options=transformer_options,
)
# 3. Output

View File

@ -104,6 +104,7 @@ class JointAttention(nn.Module):
x: torch.Tensor,
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
transformer_options={},
) -> torch.Tensor:
"""
@ -140,7 +141,7 @@ class JointAttention(nn.Module):
if n_rep >= 1:
xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True, transformer_options=transformer_options)
return self.out(output)
@ -268,6 +269,7 @@ class JointTransformerBlock(nn.Module):
x_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor]=None,
transformer_options={},
):
"""
Perform a forward pass through the TransformerBlock.
@ -290,6 +292,7 @@ class JointTransformerBlock(nn.Module):
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
@ -304,6 +307,7 @@ class JointTransformerBlock(nn.Module):
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
)
)
x = x + self.ffn_norm2(
@ -494,7 +498,7 @@ class NextDiT(nn.Module):
return imgs
def patchify_and_embed(
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
bsz = len(x)
pH = pW = self.patch_size
@ -554,7 +558,7 @@ class NextDiT(nn.Module):
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
# refine image
flat_x = []
@ -573,7 +577,7 @@ class NextDiT(nn.Module):
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
@ -616,12 +620,13 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(x.device)
for layer in self.layers:
x = layer(x, mask, freqs_cis, adaln_input)
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
x = self.final_layer(x, adaln_input)
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]

View File

@ -26,6 +26,12 @@ class DiagonalGaussianRegularizer(torch.nn.Module):
z = posterior.mode()
return z, None
class EmptyRegularizer(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
return z, None
class AbstractAutoencoder(torch.nn.Module):
"""

View File

@ -5,8 +5,9 @@ import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from typing import Optional
from typing import Optional, Any, Callable, Union
import logging
import functools
from .diffusionmodules.util import AlphaBlender, timestep_embedding
from .sub_quadratic_attention import efficient_dot_product_attention
@ -17,23 +18,45 @@ if model_management.xformers_enabled():
import xformers
import xformers.ops
if model_management.sage_attention_enabled():
SAGE_ATTENTION_IS_AVAILABLE = False
try:
from sageattention import sageattn
SAGE_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError as e:
if model_management.sage_attention_enabled():
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
if model_management.flash_attention_enabled():
FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
except ModuleNotFoundError:
if model_management.flash_attention_enabled():
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
if name not in REGISTERED_ATTENTION_FUNCTIONS:
REGISTERED_ATTENTION_FUNCTIONS[name] = func
else:
logging.warning(f"Attention function {name} already registered, skipping registration.")
def get_attention_function(name: str, default: Any=...) -> Union[Callable, None]:
if name == "optimized":
return optimized_attention
elif name not in REGISTERED_ATTENTION_FUNCTIONS:
if default is ...:
raise KeyError(f"Attention function {name} not found.")
else:
return default
return REGISTERED_ATTENTION_FUNCTIONS[name]
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@ -91,7 +114,27 @@ class FeedForward(nn.Module):
def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
def wrap_attn(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
remove_attn_wrapper_key = False
try:
if "_inside_attn_wrapper" not in kwargs:
transformer_options = kwargs.get("transformer_options", None)
remove_attn_wrapper_key = True
kwargs["_inside_attn_wrapper"] = True
if transformer_options is not None:
if "optimized_attention_override" in transformer_options:
return transformer_options["optimized_attention_override"](func, *args, **kwargs)
return func(*args, **kwargs)
finally:
if remove_attn_wrapper_key:
del kwargs["_inside_attn_wrapper"]
return wrapper
@wrap_attn
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@ -159,8 +202,8 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
@ -230,7 +273,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
@ -359,7 +403,8 @@ try:
except:
pass
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
b = q.shape[0]
dim_head = q.shape[-1]
# check to make sure xformers isn't broken
@ -374,7 +419,7 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
disabled_xformers = True
if disabled_xformers:
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape, **kwargs)
if skip_reshape:
# b h k d -> b k h d
@ -427,8 +472,8 @@ else:
#TODO: other GPUs ?
SDP_BATCH_LIMIT = 2**31
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -470,8 +515,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@ -501,7 +546,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, **kwargs)
if tensor_layout == "HND":
if not skip_output_reshape:
@ -534,8 +579,8 @@ except AttributeError as error:
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
@wrap_attn
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
@ -597,6 +642,19 @@ else:
optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
if small_input:
if model_management.pytorch_attention_enabled():
@ -629,7 +687,7 @@ class CrossAttention(nn.Module):
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, value=None, mask=None):
def forward(self, x, context=None, value=None, mask=None, transformer_options={}):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
@ -640,9 +698,9 @@ class CrossAttention(nn.Module):
v = self.to_v(context)
if mask is None:
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
return self.to_out(out)
@ -746,7 +804,7 @@ class BasicTransformerBlock(nn.Module):
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
n = self.attn1.to_out(n)
else:
n = self.attn1(n, context=context_attn1, value=value_attn1)
n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=transformer_options)
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
@ -786,7 +844,7 @@ class BasicTransformerBlock(nn.Module):
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
n = self.attn2.to_out(n)
else:
n = self.attn2(n, context=context_attn2, value=value_attn2)
n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=transformer_options)
if "attn2_output_patch" in transformer_patches:
patch = transformer_patches["attn2_output_patch"]
@ -1017,7 +1075,7 @@ class SpatialVideoTransformer(SpatialTransformer):
B, S, C = x_mix.shape
x_mix = rearrange(x_mix, "(b t) s c -> (b s) t c", t=timesteps)
x_mix = mix_block(x_mix, context=time_context) #TODO: transformer_options
x_mix = mix_block(x_mix, context=time_context, transformer_options=transformer_options)
x_mix = rearrange(
x_mix, "(b s) t c -> (b t) s c", s=S, b=B // timesteps, c=C, t=timesteps
)

View File

@ -606,7 +606,7 @@ def block_mixing(*args, use_checkpoint=True, **kwargs):
return _block_mixing(*args, **kwargs)
def _block_mixing(context, x, context_block, x_block, c):
def _block_mixing(context, x, context_block, x_block, c, transformer_options={}):
context_qkv, context_intermediates = context_block.pre_attention(context, c)
if x_block.x_block_self_attn:
@ -622,6 +622,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn = optimized_attention(
qkv[0], qkv[1], qkv[2],
heads=x_block.attn.num_heads,
transformer_options=transformer_options,
)
context_attn, x_attn = (
attn[:, : context_qkv[0].shape[1]],
@ -637,6 +638,7 @@ def _block_mixing(context, x, context_block, x_block, c):
attn2 = optimized_attention(
x_qkv2[0], x_qkv2[1], x_qkv2[2],
heads=x_block.attn2.num_heads,
transformer_options=transformer_options,
)
x = x_block.post_attention_x(x_attn, attn2, *x_intermediates)
else:
@ -958,10 +960,10 @@ class MMDiT(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"])
out["txt"], out["img"] = self.joint_blocks[i](args["txt"], args["img"], c=args["vec"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": c_mod, "transformer_options": transformer_options}, {"original_block": block_wrap})
context = out["txt"]
x = out["img"]
else:
@ -970,6 +972,7 @@ class MMDiT(nn.Module):
x,
c=c_mod,
use_checkpoint=self.use_checkpoint,
transformer_options=transformer_options,
)
if control is not None:
control_o = control.get("output")

View File

@ -145,7 +145,7 @@ class Downsample(nn.Module):
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d):
dropout=0.0, temb_channels=512, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
@ -153,7 +153,7 @@ class ResnetBlock(nn.Module):
self.use_conv_shortcut = conv_shortcut
self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.norm1 = norm_op(in_channels)
self.conv1 = conv_op(in_channels,
out_channels,
kernel_size=3,
@ -162,7 +162,7 @@ class ResnetBlock(nn.Module):
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.norm2 = norm_op(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = conv_op(out_channels,
out_channels,
@ -305,11 +305,11 @@ def vae_attention():
return normal_attention
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
def __init__(self, in_channels, conv_op=ops.Conv2d, norm_op=Normalize):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.norm = norm_op(in_channels)
self.q = conv_op(in_channels,
in_channels,
kernel_size=1,

View File

@ -120,7 +120,7 @@ class Attention(nn.Module):
nn.Dropout(0.0)
)
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
query = self.to_q(hidden_states)
@ -146,7 +146,7 @@ class Attention(nn.Module):
key = key.repeat_interleave(self.heads // self.kv_heads, dim=1)
value = value.repeat_interleave(self.heads // self.kv_heads, dim=1)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True)
hidden_states = optimized_attention_masked(query, key, value, self.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
hidden_states = self.to_out[0](hidden_states)
return hidden_states
@ -182,16 +182,16 @@ class OmniGen2TransformerBlock(nn.Module):
self.norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, transformer_options={}) -> torch.Tensor:
if self.modulation:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
else:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb)
attn_output = self.attn(norm_hidden_states, norm_hidden_states, attention_mask, image_rotary_emb, transformer_options=transformer_options)
hidden_states = hidden_states + self.norm2(attn_output)
mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
hidden_states = hidden_states + self.ffn_norm2(mlp_output)
@ -390,7 +390,7 @@ class OmniGen2Transformer2DModel(nn.Module):
ref_img_sizes, img_sizes,
)
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb):
def img_patch_embed_and_refine(self, hidden_states, ref_image_hidden_states, padded_img_mask, padded_ref_img_mask, noise_rotary_emb, ref_img_rotary_emb, l_effective_ref_img_len, l_effective_img_len, temb, transformer_options={}):
batch_size = len(hidden_states)
hidden_states = self.x_embedder(hidden_states)
@ -405,17 +405,17 @@ class OmniGen2Transformer2DModel(nn.Module):
shift += ref_img_len
for layer in self.noise_refiner:
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb)
hidden_states = layer(hidden_states, padded_img_mask, noise_rotary_emb, temb, transformer_options=transformer_options)
if ref_image_hidden_states is not None:
for layer in self.ref_image_refiner:
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb)
ref_image_hidden_states = layer(ref_image_hidden_states, padded_ref_img_mask, ref_img_rotary_emb, temb, transformer_options=transformer_options)
hidden_states = torch.cat([ref_image_hidden_states, hidden_states], dim=1)
return hidden_states
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, **kwargs):
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
B, C, H, W = x.shape
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
_, _, H_padded, W_padded = hidden_states.shape
@ -444,7 +444,7 @@ class OmniGen2Transformer2DModel(nn.Module):
)
for layer in self.context_refiner:
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb)
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
img_len = hidden_states.shape[1]
combined_img_hidden_states = self.img_patch_embed_and_refine(
@ -453,13 +453,14 @@ class OmniGen2Transformer2DModel(nn.Module):
noise_rotary_emb, ref_img_rotary_emb,
l_effective_ref_img_len, l_effective_img_len,
temb,
transformer_options=transformer_options,
)
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
attention_mask = None
for layer in self.layers:
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb)
hidden_states = layer(hidden_states, attention_mask, rotary_emb, temb, transformer_options=transformer_options)
hidden_states = self.norm_out(hidden_states, temb)

View File

@ -132,6 +132,7 @@ class Attention(nn.Module):
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
seq_txt = encoder_hidden_states.shape[1]
@ -159,7 +160,7 @@ class Attention(nn.Module):
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@ -226,6 +227,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
txt_mod_params = self.txt_mod(temb)
@ -242,6 +244,7 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
hidden_states = hidden_states + img_gate1 * img_attn_output
@ -434,9 +437,9 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
else:
@ -446,11 +449,12 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i})
out = p({"img": hidden_states, "txt": encoder_hidden_states, "x": x, "block_index": i, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]

View File

@ -52,7 +52,7 @@ class WanSelfAttention(nn.Module):
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, freqs):
def forward(self, x, freqs, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L, num_heads, C / num_heads]
@ -75,6 +75,7 @@ class WanSelfAttention(nn.Module):
k.view(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
@ -83,7 +84,7 @@ class WanSelfAttention(nn.Module):
class WanT2VCrossAttention(WanSelfAttention):
def forward(self, x, context, **kwargs):
def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -95,7 +96,7 @@ class WanT2VCrossAttention(WanSelfAttention):
v = self.v(context)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
x = self.o(x)
return x
@ -116,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention):
# self.alpha = nn.Parameter(torch.zeros((1, )))
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
def forward(self, x, context, context_img_len):
def forward(self, x, context, context_img_len, transformer_options={}):
r"""
Args:
x(Tensor): Shape [B, L1, C]
@ -131,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention):
v = self.v(context)
k_img = self.norm_k_img(self.k_img(context_img))
v_img = self.v_img(context_img)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
# compute attention
x = optimized_attention(q, k, v, heads=self.num_heads)
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
# output
x = x + img_x
@ -206,6 +207,7 @@ class WanAttentionBlock(nn.Module):
freqs,
context,
context_img_len=257,
transformer_options={},
):
r"""
Args:
@ -224,12 +226,12 @@ class WanAttentionBlock(nn.Module):
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs)
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
@ -559,12 +561,12 @@ class WanModel(torch.nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)
@ -742,17 +744,17 @@ class VaceWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
x += c_skip * vace_strength[iii]
del c_skip
# head
@ -841,12 +843,12 @@ class CameraWanModel(WanModel):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
# head
x = self.head(x, e)

View File

@ -1432,3 +1432,31 @@ class HunyuanImage21(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class HunyuanImage21Refiner(HunyuanImage21):
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = self.process_latent_in(image)
image = utils.resize_to_batch_size(image, noise.shape[0])
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
noise = torch.randn(image.shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image
else:
image = 0.75 * image
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(True)
return out

View File

@ -285,6 +285,7 @@ class VAE:
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False
self.not_video = False
self.downscale_index_formula = None
self.upscale_index_formula = None
@ -409,6 +410,20 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.downscale_ratio = 16
self.upscale_ratio = 16
self.latent_dim = 3
self.not_video = True
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@ -669,7 +684,7 @@ class VAE:
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5:
if not self.not_video and self.latent_dim == 3 and pixel_samples.ndim < 5:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)

View File

@ -1321,6 +1321,23 @@ class HunyuanImage21(HunyuanVideo):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
class HunyuanImage21Refiner(HunyuanVideo):
unet_config = {
"image_model": "hunyuan_video",
"patch_size": [1, 1, 1],
"vec_in_dim": None,
}
sampling_settings = {
"shift": 4.0,
}
latent_format = latent_formats.HunyuanImage21Refiner
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ACEStep, Omnigen2, QwenImage]
models += [SVD_img2vid]

View File

@ -331,7 +331,7 @@ class String(ComfyTypeIO):
})
@comfytype(io_type="COMBO")
class Combo(ComfyTypeI):
class Combo(ComfyTypeIO):
Type = str
class Input(WidgetInput):
"""Combo input (dropdown)."""
@ -360,6 +360,14 @@ class Combo(ComfyTypeI):
"remote": self.remote.as_dict() if self.remote else None,
})
class Output(Output):
def __init__(self, id: str=None, display_name: str=None, options: list[str]=None, tooltip: str=None, is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.options = options if options is not None else []
@property
def io_type(self):
return self.options
@comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI):

View File

@ -846,6 +846,8 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
}
@classmethod

View File

@ -1,9 +1,10 @@
from inspect import cleandoc
from typing import Union
from typing import Optional
import logging
import torch
from comfy.comfy_types.node_typing import IO
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io as comfy_io
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
MinimaxVideoGenerationRequest,
@ -11,7 +12,7 @@ from comfy_api_nodes.apis import (
MinimaxFileRetrieveResponse,
MinimaxTaskResultResponse,
SubjectReferenceItem,
MiniMaxModel
MiniMaxModel,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
@ -31,85 +32,29 @@ from server import PromptServer
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
class MinimaxTextToVideoNode:
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"T2V-01",
"T2V-01-Director",
],
{
"default": "T2V-01",
"tooltip": "Model to use for video generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
async def generate_video(
self,
prompt_text,
seed=0,
model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo
unique_id: Union[str, None]=None,
**kwargs,
):
'''
Function used between MiniMax nodes - supports T2V, I2V, and S2V, based on provided arguments.
'''
async def _generate_mm_video(
*,
auth: dict[str, str],
node_id: str,
prompt_text: str,
seed: int,
model: str,
image: Optional[torch.Tensor] = None, # used for ImageToVideo
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
average_duration: Optional[int] = None,
) -> comfy_io.NodeOutput:
if image is None:
validate_string(prompt_text, field_name="prompt_text")
# upload image, if passed in
image_url = None
if image is not None:
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0]
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
subject_reference = None
if subject is not None:
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0]
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
subject_reference = [SubjectReferenceItem(image=subject_url)]
@ -128,7 +73,7 @@ class MinimaxTextToVideoNode:
subject_reference=subject_reference,
prompt_optimizer=None,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
@ -147,9 +92,9 @@ class MinimaxTextToVideoNode:
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=self.AVERAGE_DURATION,
node_id=unique_id,
auth_kwargs=kwargs,
estimated_duration=average_duration,
node_id=node_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
@ -165,7 +110,7 @@ class MinimaxTextToVideoNode:
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
file_result = await file_retrieve_operation.execute()
@ -174,229 +119,311 @@ class MinimaxTextToVideoNode:
raise Exception(
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
logging.info("Generated video URL: %s", file_url)
if node_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
PromptServer.instance.send_progress_text(message, node_id)
# Download and return as VideoFromFile
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
return comfy_io.NodeOutput(VideoFromFile(video_io))
class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
"model",
options=["T2V-01", "T2V-01-Director"],
default="T2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
model: str = "T2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
image=None,
subject=None,
average_duration=T2V_AVERAGE_DURATION,
)
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = I2V_AVERAGE_DURATION
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"image",
tooltip="Image to use as first frame of video generation",
),
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
"model",
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
default="I2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
{
"tooltip": "Image to use as first frame of video generation"
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
model: str = "I2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
),
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"I2V-01-Director",
"I2V-01",
"I2V-01-live",
],
{
"default": "I2V-01",
"tooltip": "Model to use for video generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
image=image,
subject=None,
average_duration=I2V_AVERAGE_DURATION,
)
class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
"""
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.Image.Input(
"subject",
tooltip="Image of subject to reference for video generation",
),
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation",
),
comfy_io.Combo.Input(
"model",
options=["S2V-01"],
default="S2V-01",
tooltip="Model to use for video generation",
),
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"subject": (
IO.IMAGE,
{
"tooltip": "Image of subject to reference video generation"
async def execute(
cls,
subject: torch.Tensor,
prompt_text: str,
model: str = "S2V-01",
seed: int = 0,
) -> comfy_io.NodeOutput:
return await _generate_mm_video(
auth={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
),
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation",
},
),
"model": (
[
"S2V-01",
],
{
"default": "S2V-01",
"tooltip": "Model to use for video generation",
},
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
node_id=cls.hidden.unique_id,
prompt_text=prompt_text,
seed=seed,
model=model,
image=None,
subject=subject,
average_duration=T2V_AVERAGE_DURATION,
)
class MinimaxHailuoVideoNode:
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt_text": (
"STRING",
{
"multiline": True,
"default": "",
"tooltip": "Text prompt to guide the video generation.",
},
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
description=cleandoc(cls.__doc__ or ""),
inputs=[
comfy_io.String.Input(
"prompt_text",
multiline=True,
default="",
tooltip="Text prompt to guide the video generation.",
),
},
"optional": {
"seed": (
IO.INT,
{
"default": 0,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
comfy_io.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
step=1,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
"first_frame_image": (
IO.IMAGE,
{
"tooltip": "Optional image to use as the first frame to generate a video."
},
comfy_io.Image.Input(
"first_frame_image",
tooltip="Optional image to use as the first frame to generate a video.",
optional=True,
),
"prompt_optimizer": (
IO.BOOLEAN,
{
"tooltip": "Optimize prompt to improve generation quality when needed.",
"default": True,
},
comfy_io.Boolean.Input(
"prompt_optimizer",
default=True,
tooltip="Optimize prompt to improve generation quality when needed.",
optional=True,
),
"duration": (
IO.COMBO,
{
"tooltip": "The length of the output video in seconds.",
"default": 6,
"options": [6, 10],
},
comfy_io.Combo.Input(
"duration",
options=[6, 10],
default=6,
tooltip="The length of the output video in seconds.",
optional=True,
),
"resolution": (
IO.COMBO,
{
"tooltip": "The dimensions of the video display. "
"1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels.",
"default": "768P",
"options": ["768P", "1080P"],
},
comfy_io.Combo.Input(
"resolution",
options=["768P", "1080P"],
default="768P",
tooltip="The dimensions of the video display. 1080p is 1920x1080, 768p is 1366x768.",
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
seed: int = 0,
first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo
prompt_optimizer: bool = True,
duration: int = 6,
resolution: str = "768P",
model: str = "MiniMax-Hailuo-02",
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
RETURN_TYPES = ("VIDEO",)
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "generate_video"
CATEGORY = "api node/video/MiniMax"
API_NODE = True
async def generate_video(
self,
prompt_text,
seed=0,
first_frame_image: torch.Tensor=None, # used for ImageToVideo
prompt_optimizer=True,
duration=6,
resolution="768P",
model="MiniMax-Hailuo-02",
unique_id: Union[str, None]=None,
**kwargs,
):
if first_frame_image is None:
validate_string(prompt_text, field_name="prompt_text")
@ -408,7 +435,7 @@ class MinimaxHailuoVideoNode:
# upload image, if passed in
image_url = None
if first_frame_image is not None:
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=kwargs))[0]
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
video_generate_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -426,7 +453,7 @@ class MinimaxHailuoVideoNode:
duration=duration,
resolution=resolution,
),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
response = await video_generate_operation.execute()
@ -447,8 +474,8 @@ class MinimaxHailuoVideoNode:
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=average_duration,
node_id=unique_id,
auth_kwargs=kwargs,
node_id=cls.hidden.unique_id,
auth_kwargs=auth,
)
task_result = await video_generate_operation.execute()
@ -464,7 +491,7 @@ class MinimaxHailuoVideoNode:
query_params={"file_id": int(file_id)},
),
request=EmptyRequest(),
auth_kwargs=kwargs,
auth_kwargs=auth,
)
file_result = await file_retrieve_operation.execute()
@ -474,34 +501,31 @@ class MinimaxHailuoVideoNode:
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if cls.hidden.unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
video_io = await download_url_to_bytesio(file_url)
if video_io is None:
error_msg = f"Failed to download video from {file_url}"
logging.error(error_msg)
raise Exception(error_msg)
return (VideoFromFile(video_io),)
return comfy_io.NodeOutput(VideoFromFile(video_io))
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"MinimaxTextToVideoNode": MinimaxTextToVideoNode,
"MinimaxImageToVideoNode": MinimaxImageToVideoNode,
# "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
"MinimaxHailuoVideoNode": MinimaxHailuoVideoNode,
}
class MinimaxExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
MinimaxTextToVideoNode,
MinimaxImageToVideoNode,
# MinimaxSubjectToVideoNode,
MinimaxHailuoVideoNode,
]
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"MinimaxTextToVideoNode": "MiniMax Text to Video",
"MinimaxImageToVideoNode": "MiniMax Image to Video",
"MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
"MinimaxHailuoVideoNode": "MiniMax Hailuo Video",
}
async def comfy_entrypoint() -> MinimaxExtension:
return MinimaxExtension()

View File

@ -1,6 +1,7 @@
import logging
from typing import Any, Callable, Optional, TypeVar
import torch
from typing_extensions import override
from comfy_api_nodes.util.validation_utils import (
get_image_dimensions,
validate_image_dimensions,
@ -26,11 +27,9 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input.video_types import VideoInput
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl import VideoFromFile
from comfy_api.input import VideoInput
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
import av
import io
@ -362,7 +361,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return VideoFromFile(output_buffer)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
@ -373,9 +372,7 @@ def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
# --- BaseMoonvalleyVideoNode ---
class BaseMoonvalleyVideoNode:
def parseWidthHeightFromRes(self, resolution: str):
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
@ -385,26 +382,21 @@ class BaseMoonvalleyVideoNode:
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
if resolution in res_map:
return res_map[resolution]
else:
# Default to 1920x1080 if unknown
return {"width": 1920, "height": 1080}
return res_map.get(resolution, {"width": 1920, "height": 1080})
def parseControlParameter(self, value):
def parse_control_parameter(value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control",
}
if value in control_map:
return control_map[value]
else:
return control_map["Motion Transfer"]
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> MoonvalleyPromptResponse:
return await poll_until_finished(
auth_kwargs,
@ -418,121 +410,112 @@ class BaseMoonvalleyVideoNode:
node_id=node_id,
)
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoRequest,
"prompt_text",
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MoonvalleyImg2VideoNode",
display_name="Moonvalley Marey Image to Video",
category="api node/video/Moonvalley Marey",
description="Moonvalley Marey Image to Video Node",
inputs=[
comfy_io.Image.Input(
"image",
tooltip="The reference image used to generate the video",
),
comfy_io.String.Input(
"prompt",
multiline=True,
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyTextToVideoInferenceParams,
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
"resolution": (
IO.COMBO,
{
"options": [
comfy_io.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1440 x 1080)",
"3:4 (1080 x 1440)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
],
"default": "16:9 (1920 x 1080)",
"tooltip": "Resolution of the output video",
},
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyTextToVideoInferenceParams,
"guidance_scale",
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
step=1,
min=1,
max=20,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
comfy_io.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display="number",
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
),
"steps": model_field_to_node_input(
IO.INT,
MoonvalleyTextToVideoInferenceParams,
comfy_io.Int.Input(
"steps",
default=100,
min=1,
max=100,
step=1,
tooltip="Number of denoising steps",
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "generate"
CATEGORY = "api node/video/Moonvalley Marey"
API_NODE = True
def generate(self, **kwargs):
return None
# --- MoonvalleyImg2VideoNode ---
class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
def INPUT_TYPES(cls):
return super().INPUT_TYPES()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
DESCRIPTION = "Moonvalley Marey Image to Video Node"
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
image = kwargs.get("image", None)
if image is None:
raise MoonvalleyApiError("image is required")
async def execute(
cls,
image: torch.Tensor,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_input_image(image, True)
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
"""Upload image to comfy backend to have a URL available for further processing"""
@ -541,7 +524,7 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
image_url = (
await upload_images_to_comfyapi(
image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
)
)[0]
@ -556,127 +539,102 @@ class MoonvalleyImg2VideoNode(BaseMoonvalleyVideoNode):
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
return comfy_io.NodeOutput(video)
# --- MoonvalleyVid2VidNode ---
class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoRequest,
"prompt_text",
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MoonvalleyVideo2VideoNode",
display_name="Moonvalley Marey Video to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
tooltip="Describes the video to generate",
),
"negative_prompt": model_field_to_node_input(
IO.STRING,
MoonvalleyVideoToVideoInferenceParams,
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, wobbly, weird, low quality, plastic, stock footage, video camera, boring",
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
"seed": model_field_to_node_input(
IO.INT,
MoonvalleyVideoToVideoInferenceParams,
comfy_io.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display="number",
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=False,
),
"prompt_adherence": model_field_to_node_input(
IO.FLOAT,
MoonvalleyVideoToVideoInferenceParams,
"guidance_scale",
default=10.0,
comfy_io.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
comfy_io.Combo.Input(
"control_type",
options=["Motion Transfer", "Pose Transfer"],
default="Motion Transfer",
optional=True,
),
comfy_io.Int.Input(
"motion_intensity",
default=100,
min=0,
max=100,
step=1,
min=1,
max=20,
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
"optional": {
"video": (
IO.VIDEO,
{
"default": "",
"multiline": False,
"tooltip": "The reference video used to generate the output video. Must be at least 5 seconds long. Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
},
),
"control_type": (
["Motion Transfer", "Pose Transfer"],
{"default": "Motion Transfer"},
),
"motion_intensity": (
"INT",
{
"default": 100,
"step": 1,
"min": 0,
"max": 100,
"tooltip": "Only used if control_type is 'Motion Transfer'",
},
),
"image": model_field_to_node_input(
IO.IMAGE,
MoonvalleyTextToVideoRequest,
"image_url",
tooltip="The reference image used to generate the video",
),
},
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
) -> comfy_io.NodeOutput:
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
video = kwargs.get("video")
image = kwargs.get("image", None)
if not video:
raise MoonvalleyApiError("video is required")
video_url = ""
if video:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(
validated_video, auth_kwargs=kwargs
)
mime_type = "image/png"
if not image is None:
validate_input_image(image, with_frame_conditioning=True)
image_url = await upload_images_to_comfyapi(
image=image, auth_kwargs=kwargs, max_images=1, mime_type=mime_type
)
control_type = kwargs.get("control_type")
motion_intensity = kwargs.get("motion_intensity")
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
"""Validate prompts and inference input"""
validate_prompts(prompt, negative_prompt)
@ -688,11 +646,11 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
inference_params = MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=kwargs.get("seed"),
seed=seed,
control_params=control_params,
)
control = self.parseControlParameter(control_type)
control = parse_control_parameter(control_type)
request = MoonvalleyVideoToVideoRequest(
control_type=control,
@ -700,7 +658,6 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
prompt_text=prompt,
inference_params=inference_params,
)
request.image_url = image_url if not image is None else None
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
@ -710,58 +667,125 @@ class MoonvalleyVideo2VideoNode(BaseMoonvalleyVideoNode):
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
return comfy_io.NodeOutput(video)
# --- MoonvalleyTxt2VideoNode ---
class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
def __init__(self):
super().__init__()
RETURN_TYPES = ("VIDEO",)
RETURN_NAMES = ("video",)
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
input_types = super().INPUT_TYPES()
# Remove image-specific parameters
for param in ["image"]:
if param in input_types["optional"]:
del input_types["optional"][param]
return input_types
def define_schema(cls) -> comfy_io.Schema:
return comfy_io.Schema(
node_id="MoonvalleyTxt2VideoNode",
display_name="Moonvalley Marey Text to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
comfy_io.String.Input(
"prompt",
multiline=True,
),
comfy_io.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
comfy_io.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
comfy_io.Float.Input(
"prompt_adherence",
default=10.0,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
comfy_io.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=comfy_io.NumberDisplay.number,
tooltip="Random seed value",
),
comfy_io.Int.Input(
"steps",
default=100,
min=1,
max=100,
step=1,
tooltip="Inference steps",
),
],
outputs=[comfy_io.Video.Output()],
hidden=[
comfy_io.Hidden.auth_token_comfy_org,
comfy_io.Hidden.api_key_comfy_org,
comfy_io.Hidden.unique_id,
],
is_api_node=True,
)
async def generate(
self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs
):
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> comfy_io.NodeOutput:
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = self.parseWidthHeightFromRes(kwargs.get("resolution"))
width_height = parse_width_height_from_res(resolution)
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=kwargs.get("steps"),
seed=kwargs.get("seed"),
guidance_scale=kwargs.get("prompt_adherence"),
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height.get("width"),
height=width_height.get("height"),
width=width_height["width"],
height=width_height["height"],
)
request = MoonvalleyTextToVideoRequest(
prompt_text=prompt, inference_params=inference_params
)
initial_operation = SynchronousOperation(
init_op = SynchronousOperation(
endpoint=ApiEndpoint(
path=API_TXT2VIDEO_ENDPOINT,
method=HttpMethod.POST,
@ -769,29 +793,29 @@ class MoonvalleyTxt2VideoNode(BaseMoonvalleyVideoNode):
response_model=MoonvalleyPromptResponse,
),
request=request,
auth_kwargs=kwargs,
auth_kwargs=auth,
)
task_creation_response = await initial_operation.execute()
task_creation_response = await init_op.execute()
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.id
final_response = await self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
final_response = await get_response(
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
)
video = await download_url_to_video_output(final_response.output_url)
return (video,)
return comfy_io.NodeOutput(video)
NODE_CLASS_MAPPINGS = {
"MoonvalleyImg2VideoNode": MoonvalleyImg2VideoNode,
"MoonvalleyTxt2VideoNode": MoonvalleyTxt2VideoNode,
"MoonvalleyVideo2VideoNode": MoonvalleyVideo2VideoNode,
}
class MoonvalleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
return [
MoonvalleyImg2VideoNode,
MoonvalleyTxt2VideoNode,
MoonvalleyVideo2VideoNode,
]
NODE_DISPLAY_NAME_MAPPINGS = {
"MoonvalleyImg2VideoNode": "Moonvalley Marey Image to Video",
"MoonvalleyTxt2VideoNode": "Moonvalley Marey Text to Video",
"MoonvalleyVideo2VideoNode": "Moonvalley Marey Video to Video",
}
async def comfy_entrypoint() -> MoonvalleyExtension:
return MoonvalleyExtension()

View File

@ -2,12 +2,12 @@ import nodes
import torch
import numpy as np
from einops import rearrange
from typing_extensions import override
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
MAX_RESOLUTION = nodes.MAX_RESOLUTION
CAMERA_DICT = {
"base_T_norm": 1.5,
"base_angle": np.pi/3,
@ -148,32 +148,47 @@ def get_camera_motion(angle, T, speed, n=81):
RT = np.stack(RT)
return RT
class WanCameraEmbedding:
class WanCameraEmbedding(io.ComfyNode):
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
},
"optional":{
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
}
def define_schema(cls):
return io.Schema(
node_id="WanCameraEmbedding",
category="camera",
inputs=[
io.Combo.Input(
"camera_pose",
options=[
"Static",
"Pan Up",
"Pan Down",
"Pan Left",
"Pan Right",
"Zoom In",
"Zoom Out",
"Anti Clockwise (ACW)",
"ClockWise (CW)",
],
default="Static",
),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True),
io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True),
io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True),
io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True),
io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True),
],
outputs=[
io.WanCameraEmbedding.Output(display_name="camera_embedding"),
io.Int.Output(display_name="width"),
io.Int.Output(display_name="height"),
io.Int.Output(display_name="length"),
],
)
}
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
RETURN_NAMES = ("camera_embedding","width","height","length")
FUNCTION = "run"
CATEGORY = "camera"
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
@classmethod
def execute(cls, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5) -> io.NodeOutput:
"""
Use Camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmannet al., 2021)
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
@ -210,9 +225,15 @@ class WanCameraEmbedding:
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
return (control_camera_video, width, height, length)
return io.NodeOutput(control_camera_video, width, height, length)
NODE_CLASS_MAPPINGS = {
"WanCameraEmbedding": WanCameraEmbedding,
}
class CameraTrajectoryExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanCameraEmbedding,
]
async def comfy_entrypoint() -> CameraTrajectoryExtension:
return CameraTrajectoryExtension()

View File

@ -1,25 +1,41 @@
from kornia.filters import canny
from typing_extensions import override
import comfy.model_management
from comfy_api.latest import ComfyExtension, io
class Canny:
class Canny(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"image": ("IMAGE",),
"low_threshold": ("FLOAT", {"default": 0.4, "min": 0.01, "max": 0.99, "step": 0.01}),
"high_threshold": ("FLOAT", {"default": 0.8, "min": 0.01, "max": 0.99, "step": 0.01})
}}
def define_schema(cls):
return io.Schema(
node_id="Canny",
category="image/preprocessors",
inputs=[
io.Image.Input("image"),
io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01),
io.Float.Input("high_threshold", default=0.8, min=0.01, max=0.99, step=0.01),
],
outputs=[io.Image.Output()],
)
RETURN_TYPES = ("IMAGE",)
FUNCTION = "detect_edge"
@classmethod
def detect_edge(cls, image, low_threshold, high_threshold):
# Deprecated: use the V3 schema's `execute` method instead of this.
return cls.execute(image, low_threshold, high_threshold)
CATEGORY = "image/preprocessors"
def detect_edge(self, image, low_threshold, high_threshold):
@classmethod
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
return (img_out,)
return io.NodeOutput(img_out)
NODE_CLASS_MAPPINGS = {
"Canny": Canny,
}
class CannyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [Canny]
async def comfy_entrypoint() -> CannyExtension:
return CannyExtension()

View File

@ -1,5 +1,10 @@
from typing_extensions import override
import torch
from comfy_api.latest import ComfyExtension, io
# https://github.com/WeichenFan/CFG-Zero-star
def optimized_scale(positive, negative):
positive_flat = positive.reshape(positive.shape[0], -1)
@ -16,17 +21,20 @@ def optimized_scale(positive, negative):
return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
class CFGZeroStar:
class CFGZeroStar(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CFGZeroStar",
category="advanced/guidance",
inputs=[
io.Model.Input("model"),
],
outputs=[io.Model.Output(display_name="patched_model")],
)
def patch(self, model):
@classmethod
def execute(cls, model) -> io.NodeOutput:
m = model.clone()
def cfg_zero_star(args):
guidance_scale = args['cond_scale']
@ -38,21 +46,24 @@ class CFGZeroStar:
return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
m.set_model_sampler_post_cfg_function(cfg_zero_star)
return (m, )
return io.NodeOutput(m)
class CFGNorm:
class CFGNorm(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL",),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
}}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("patched_model",)
FUNCTION = "patch"
CATEGORY = "advanced/guidance"
EXPERIMENTAL = True
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CFGNorm",
category="advanced/guidance",
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
def patch(self, model, strength):
@classmethod
def execute(cls, model, strength) -> io.NodeOutput:
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised']
@ -64,9 +75,17 @@ class CFGNorm:
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
return (m, )
return io.NodeOutput(m)
NODE_CLASS_MAPPINGS = {
"CFGZeroStar": CFGZeroStar,
"CFGNorm": CFGNorm,
}
class CfgExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CFGZeroStar,
CFGNorm,
]
async def comfy_entrypoint() -> CfgExtension:
return CfgExtension()

View File

@ -1,15 +1,25 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class CLIPTextEncodeControlnet:
class CLIPTextEncodeControlnet(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"clip": ("CLIP", ), "conditioning": ("CONDITIONING", ), "text": ("STRING", {"multiline": True, "dynamicPrompts": True})}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPTextEncodeControlnet",
category="_for_testing/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"),
io.String.Input("text", multiline=True, dynamic_prompts=True),
],
outputs=[io.Conditioning.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing/conditioning"
def encode(self, clip, conditioning, text):
@classmethod
def execute(cls, clip, conditioning, text) -> io.NodeOutput:
tokens = clip.tokenize(text)
cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
c = []
@ -18,32 +28,41 @@ class CLIPTextEncodeControlnet:
n[1]['cross_attn_controlnet'] = cond
n[1]['pooled_output_controlnet'] = pooled
c.append(n)
return (c, )
return io.NodeOutput(c)
class T5TokenizerOptions:
class T5TokenizerOptions(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"clip": ("CLIP", ),
"min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
"min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
}
}
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="T5TokenizerOptions",
category="_for_testing/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1),
io.Int.Input("min_length", default=0, min=0, max=10000, step=1),
],
outputs=[io.Clip.Output()],
is_experimental=True,
)
CATEGORY = "_for_testing/conditioning"
RETURN_TYPES = ("CLIP",)
FUNCTION = "set_options"
def set_options(self, clip, min_padding, min_length):
@classmethod
def execute(cls, clip, min_padding, min_length) -> io.NodeOutput:
clip = clip.clone()
for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
return (clip, )
return io.NodeOutput(clip)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
"T5TokenizerOptions": T5TokenizerOptions,
}
class CondExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
CLIPTextEncodeControlnet,
T5TokenizerOptions,
]
async def comfy_entrypoint() -> CondExtension:
return CondExtension()

View File

@ -1,25 +1,32 @@
from typing_extensions import override
import nodes
import torch
import comfy.model_management
import comfy.utils
import comfy.latent_formats
from comfy_api.latest import ComfyExtension, io
class EmptyCosmosLatentVideo:
class EmptyCosmosLatentVideo(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyCosmosLatentVideo",
category="latent/video",
inputs=[
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
],
outputs=[io.Latent.Output()],
)
CATEGORY = "latent/video"
def generate(self, width, height, length, batch_size=1):
@classmethod
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
return io.NodeOutput({"samples": latent})
def vae_encode_with_padding(vae, image, width, height, length, padding=0):
@ -33,31 +40,31 @@ def vae_encode_with_padding(vae, image, width, height, length, padding=0):
return latent_temp[:, :, :latent_len]
class CosmosImageToVideoLatent:
class CosmosImageToVideoLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CosmosImageToVideoLatent",
category="conditioning/inpaint",
inputs=[
io.Vae.Input("vae"),
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[io.Latent.Output()],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
@classmethod
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
return io.NodeOutput(out_latent)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
@ -74,33 +81,33 @@ class CosmosImageToVideoLatent:
out_latent = {}
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
return io.NodeOutput(out_latent)
class CosmosPredict2ImageToVideoLatent:
class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
@classmethod
def INPUT_TYPES(s):
return {"required": {"vae": ("VAE", ),
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 93, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"start_image": ("IMAGE", ),
"end_image": ("IMAGE", ),
}}
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CosmosPredict2ImageToVideoLatent",
category="conditioning/inpaint",
inputs=[
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=93, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image", optional=True),
io.Image.Input("end_image", optional=True),
],
outputs=[io.Latent.Output()],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "conditioning/inpaint"
def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
@classmethod
def execute(cls, vae, width, height, length, batch_size, start_image=None, end_image=None) -> io.NodeOutput:
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is None and end_image is None:
out_latent = {}
out_latent["samples"] = latent
return (out_latent,)
return io.NodeOutput(out_latent)
mask = torch.ones([latent.shape[0], 1, ((length - 1) // 4) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
@ -119,10 +126,18 @@ class CosmosPredict2ImageToVideoLatent:
latent = latent_format.process_out(latent) * mask + latent * (1.0 - mask)
out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
return (out_latent,)
return io.NodeOutput(out_latent)
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
"CosmosImageToVideoLatent": CosmosImageToVideoLatent,
"CosmosPredict2ImageToVideoLatent": CosmosPredict2ImageToVideoLatent,
}
class CosmosExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyCosmosLatentVideo,
CosmosImageToVideoLatent,
CosmosPredict2ImageToVideoLatent,
]
async def comfy_entrypoint() -> CosmosExtension:
return CosmosExtension()

View File

@ -128,6 +128,28 @@ class EmptyHunyuanImageLatent:
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
class HunyuanRefinerLatent:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"latent": ("LATENT", ),
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "execute"
def execute(self, positive, negative, latent, noise_augmentation):
latent = latent["samples"]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
out_latent = {}
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
return (positive, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
@ -135,4 +157,5 @@ NODE_CLASS_MAPPINGS = {
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
"HunyuanRefinerLatent": HunyuanRefinerLatent,
}

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.25.11
comfyui-frontend-package==1.26.11
comfyui-workflow-templates==0.1.81
comfyui-embedded-docs==0.2.6
comfyui_manager==4.0.1b5