mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 00:30:21 +08:00
Compare commits
18 Commits
84ebb205a2
...
87d8a3d2da
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87d8a3d2da | ||
|
|
8ccc0c94fa | ||
|
|
4edb87aa50 | ||
|
|
0fc3b6e3a6 | ||
|
|
2108167f9f | ||
|
|
9d273d3ab1 | ||
|
|
70c91b8248 | ||
|
|
0da5a0fe58 | ||
|
|
e0eacb0688 | ||
|
|
86e74e7f8b | ||
|
|
783da446c1 | ||
|
|
96ad4904fe | ||
|
|
4bb34b85b7 | ||
|
|
2c86040cf7 | ||
|
|
abe39647ee | ||
|
|
3f4ee9174c | ||
|
|
f190744f62 | ||
|
|
4612aab281 |
@ -189,9 +189,12 @@ class AudioVAE(torch.nn.Module):
|
||||
waveform = self.device_manager.move_to_load_device(waveform)
|
||||
expected_channels = self.autoencoder.encoder.in_channels
|
||||
if waveform.shape[1] != expected_channels:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
if waveform.shape[1] == 1:
|
||||
waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
|
||||
)
|
||||
|
||||
mel_spec = self.preprocessor.waveform_to_mel(
|
||||
waveform, waveform_sample_rate, device=self.device_manager.load_device
|
||||
|
||||
@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
|
||||
|
||||
def modulate(x, scale):
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def modulate(x, scale, timestep_zero_index=None):
|
||||
if timestep_zero_index is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
else:
|
||||
scale = (1 + scale.unsqueeze(1))
|
||||
actual_batch = scale.size(0) // 2
|
||||
slices = timestep_zero_index
|
||||
invert = invert_slices(timestep_zero_index, x.shape[1])
|
||||
for s in slices:
|
||||
x[:, s[0]:s[1]] *= scale[actual_batch:]
|
||||
for s in invert:
|
||||
x[:, s[0]:s[1]] *= scale[:actual_batch]
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(gate, x, timestep_zero_index=None):
|
||||
if timestep_zero_index is None:
|
||||
return gate * x
|
||||
else:
|
||||
actual_batch = gate.size(0) // 2
|
||||
|
||||
slices = timestep_zero_index
|
||||
invert = invert_slices(timestep_zero_index, x.shape[1])
|
||||
for s in slices:
|
||||
x[:, s[0]:s[1]] *= gate[actual_batch:]
|
||||
for s in invert:
|
||||
x[:, s[0]:s[1]] *= gate[:actual_batch]
|
||||
return x
|
||||
|
||||
#############################################################################
|
||||
# Core NextDiT Model #
|
||||
@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module):
|
||||
x_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor]=None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
):
|
||||
"""
|
||||
@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module):
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2(
|
||||
clamp_fp16(self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
))
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2(
|
||||
clamp_fp16(self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
))
|
||||
modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index),
|
||||
))), timestep_zero_index=timestep_zero_index
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
@ -345,13 +389,37 @@ class FinalLayer(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
def forward(self, x, c, timestep_zero_index=None):
|
||||
scale = self.adaLN_modulation(c)
|
||||
x = modulate(self.norm_final(x), scale)
|
||||
x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def pad_zimage(feats, pad_token, pad_tokens_multiple):
|
||||
pad_extra = (-feats.shape[1]) % pad_tokens_multiple
|
||||
return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra
|
||||
|
||||
|
||||
def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}):
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = start_t
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
return x_pos_ids
|
||||
|
||||
|
||||
class NextDiT(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
@ -378,6 +446,7 @@ class NextDiT(nn.Module):
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
siglip_feat_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -491,6 +560,41 @@ class NextDiT(nn.Module):
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
|
||||
if siglip_feat_dim is not None:
|
||||
self.siglip_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
siglip_feat_dim,
|
||||
dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.siglip_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
else:
|
||||
self.siglip_embedder = None
|
||||
self.siglip_refiner = None
|
||||
self.siglip_pad_token = None
|
||||
|
||||
# This norm final is in the lumina 2.0 code but isn't actually used for anything.
|
||||
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
@ -531,70 +635,168 @@ class NextDiT(nn.Module):
|
||||
imgs = torch.stack(imgs, dim=0)
|
||||
return imgs
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
orig_x = x
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None):
|
||||
if cap_feats is not None:
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats_len = cap_feats.shape[1]
|
||||
if self.pad_tokens_multiple is not None:
|
||||
cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple)
|
||||
else:
|
||||
cap_feats_len = 0
|
||||
cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset
|
||||
embeds = (cap_feats,)
|
||||
freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len
|
||||
|
||||
def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}):
|
||||
bsz = 1
|
||||
pH = pW = self.patch_size
|
||||
device = x.device
|
||||
embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
|
||||
|
||||
if (not omni) or self.siglip_embedder is None:
|
||||
cap_feats_len = embeds[0].shape[1] + offset
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
cap_feats_len += offset
|
||||
if siglip_feats is not None:
|
||||
b, h, w, c = siglip_feats.shape
|
||||
siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c)
|
||||
siglip_feats = self.siglip_embedder(siglip_feats)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
siglip_pos_ids[:, :, 0] = cap_feats_len + 2
|
||||
siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten()
|
||||
siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten()
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check
|
||||
siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
|
||||
else:
|
||||
if self.siglip_pad_token is not None:
|
||||
siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
|
||||
siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
|
||||
|
||||
if siglip_feats is None:
|
||||
embeds += (None,)
|
||||
freqs_cis += (None,)
|
||||
else:
|
||||
embeds += (siglip_feats,)
|
||||
freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),)
|
||||
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options)
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
embeds += (x,)
|
||||
freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),)
|
||||
return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1
|
||||
|
||||
|
||||
def patchify_and_embed(
|
||||
self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
|
||||
bsz = x.shape[0]
|
||||
cap_mask = None # TODO?
|
||||
main_siglip = None
|
||||
orig_x = x
|
||||
|
||||
embeds = ([], [], [])
|
||||
freqs_cis = ([], [], [])
|
||||
leftover_cap = []
|
||||
|
||||
start_t = 0
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
for i, ref in enumerate(ref_latents):
|
||||
if i < len(ref_contexts):
|
||||
ref_con = ref_contexts[i]
|
||||
else:
|
||||
ref_con = None
|
||||
if i < len(siglip_feats):
|
||||
sig_feat = siglip_feats[i]
|
||||
else:
|
||||
sig_feat = None
|
||||
|
||||
out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
leftover_cap = ref_contexts[len(ref_latents):]
|
||||
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
img_sizes = [(H, W)] * bsz
|
||||
out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
|
||||
img_len = out[0][-1].shape[1]
|
||||
cap_len = out[0][0].shape[1]
|
||||
for i, e in enumerate(out[0]):
|
||||
if e is not None:
|
||||
e = comfy.utils.repeat_to_batch_size(e, bsz)
|
||||
embeds[i].append(e)
|
||||
freqs_cis[i].append(out[1][i])
|
||||
start_t = out[2]
|
||||
|
||||
for cap in leftover_cap:
|
||||
out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
|
||||
cap_len += out[0][0].shape[1]
|
||||
embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
|
||||
freqs_cis[0].append(out[1][0])
|
||||
start_t += out[2]
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
cap_feats = torch.cat(embeds[0], dim=1)
|
||||
cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
|
||||
feats = (cap_feats,)
|
||||
fc = (cap_freqs_cis,)
|
||||
|
||||
if omni and len(embeds[1]) > 0:
|
||||
siglip_mask = None
|
||||
siglip_feats_combined = torch.cat(embeds[1], dim=1)
|
||||
siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
|
||||
if self.siglip_refiner is not None:
|
||||
for layer in self.siglip_refiner:
|
||||
siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
|
||||
feats += (siglip_feats_combined,)
|
||||
fc += (siglip_feats_freqs_cis,)
|
||||
|
||||
padded_img_mask = None
|
||||
x = torch.cat(embeds[-1], dim=1)
|
||||
fc_x = torch.cat(freqs_cis[-1], dim=1)
|
||||
if omni:
|
||||
timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
|
||||
else:
|
||||
timestep_zero_index = None
|
||||
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
padded_full_embed = torch.cat(feats + (x,), dim=1)
|
||||
if timestep_zero_index is not None:
|
||||
ind = padded_full_embed.shape[1] - x.shape[1]
|
||||
timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])]
|
||||
timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1]))
|
||||
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@ -604,7 +806,11 @@ class NextDiT(nn.Module):
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
|
||||
omni = len(ref_latents) > 0
|
||||
if omni:
|
||||
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
|
||||
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@ -619,8 +825,6 @@ class NextDiT(nn.Module):
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
@ -632,7 +836,7 @@ class NextDiT(nn.Module):
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.layers)
|
||||
@ -640,7 +844,7 @@ class NextDiT(nn.Module):
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
transformer_options["block_index"] = i
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
@ -649,8 +853,7 @@ class NextDiT(nn.Module):
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -img
|
||||
|
||||
|
||||
@ -306,6 +306,7 @@ class BaseModel(torch.nn.Module):
|
||||
to_load[k[len(unet_prefix):]] = sd.pop(k)
|
||||
|
||||
to_load = self.model_config.process_unet_state_dict(to_load)
|
||||
comfy.model_management.free_ram(state_dict=to_load)
|
||||
m, u = self.diffusion_model.load_state_dict(to_load, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("unet missing: {}".format(m))
|
||||
@ -1150,6 +1151,7 @@ class CosmosPredict2(BaseModel):
|
||||
class Lumina2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
@ -1169,6 +1171,35 @@ class Lumina2(BaseModel):
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni
|
||||
if clip_vision_outputs is not None and len(clip_vision_outputs) > 0:
|
||||
sigfeats = []
|
||||
for clip_vision_output in clip_vision_outputs:
|
||||
if clip_vision_output is not None:
|
||||
image_size = clip_vision_output.image_sizes[0]
|
||||
shape = clip_vision_output.last_hidden_state.shape
|
||||
sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1]))
|
||||
if len(sigfeats) > 0:
|
||||
out['siglip_feats'] = comfy.conds.CONDList(sigfeats)
|
||||
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
latents = []
|
||||
for lat in ref_latents:
|
||||
latents.append(self.process_latent_in(lat))
|
||||
out['ref_latents'] = comfy.conds.CONDList(latents)
|
||||
|
||||
ref_contexts = kwargs.get("reference_latents_text_embeds", None)
|
||||
if ref_contexts is not None:
|
||||
out['ref_contexts'] = comfy.conds.CONDList(ref_contexts)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
|
||||
@ -446,6 +446,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["time_scale"] = 1000.0
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)
|
||||
if sig_weight is not None:
|
||||
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
|
||||
|
||||
return dit_config
|
||||
|
||||
|
||||
@ -459,6 +459,20 @@ try:
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
|
||||
current_ram_listeners = set()
|
||||
|
||||
def register_ram_listener(listener):
|
||||
current_ram_listeners.add(listener)
|
||||
|
||||
def unregister_ram_listener(listener):
|
||||
current_ram_listeners.discard(listener)
|
||||
|
||||
def free_ram(extra_ram=0, state_dict={}):
|
||||
for tensor in state_dict.values():
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
extra_ram += tensor.numel() * tensor.element_size()
|
||||
for listener in current_ram_listeners:
|
||||
listener.free_ram(extra_ram)
|
||||
|
||||
current_loaded_models = []
|
||||
|
||||
@ -535,12 +549,18 @@ class LoadedModel:
|
||||
return False
|
||||
|
||||
def model_unload(self, memory_to_free=None, unpatch_weights=True):
|
||||
if self.model is None:
|
||||
return True
|
||||
logging.debug(f"Unloading {self.model.model.__class__.__name__}")
|
||||
if memory_to_free is not None:
|
||||
if memory_to_free < self.model.loaded_size():
|
||||
freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
freed, modules_to_offload = self.model.partially_unload(self.model.offload_device, memory_to_free)
|
||||
offload_modules(modules_to_offload, self.model.offload_device)
|
||||
if freed >= memory_to_free:
|
||||
return False
|
||||
self.model.detach(unpatch_weights)
|
||||
if self.model is not None:
|
||||
modules_to_offload = self.model.detach(unpatch_weights)
|
||||
offload_modules(modules_to_offload, self.model.offload_device)
|
||||
self.model_finalizer.detach()
|
||||
self.model_finalizer = None
|
||||
self.real_model = None
|
||||
@ -557,7 +577,7 @@ class LoadedModel:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
return self.real_model is not None and self.real_model() is not None and self.model is None
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
@ -592,6 +612,13 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def offload_modules(modules, offload_device):
|
||||
for module in modules:
|
||||
if module() is None:
|
||||
continue
|
||||
module().to(offload_device)
|
||||
free_ram()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[]):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
@ -602,23 +629,25 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i, shift_model))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
i = x[-1]
|
||||
shift_model = x[-1]
|
||||
i = x[-2]
|
||||
memory_to_free = None
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
free_mem = get_free_memory(device)
|
||||
if free_mem > memory_required:
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
unloaded_model.append(i)
|
||||
if shift_model.model_unload(memory_to_free):
|
||||
unloaded_model.append((i, shift_model))
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
for i, shift_model in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(shift_model)
|
||||
if shift_model in current_loaded_models:
|
||||
current_loaded_models.remove(shift_model)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
@ -753,7 +782,7 @@ def cleanup_models_gc():
|
||||
def cleanup_models():
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if current_loaded_models[i].real_model() is None:
|
||||
if current_loaded_models[i].real_model is None or current_loaded_models[i].real_model() is None:
|
||||
to_delete = [i] + to_delete
|
||||
|
||||
for i in to_delete:
|
||||
|
||||
@ -24,6 +24,7 @@ import inspect
|
||||
import logging
|
||||
import math
|
||||
import uuid
|
||||
import weakref
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
@ -832,6 +833,7 @@ class ModelPatcher:
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
self.eject_model()
|
||||
modules_to_move = []
|
||||
if unpatch_weights:
|
||||
self.unpatch_hooks()
|
||||
self.unpin_all_weights()
|
||||
@ -856,7 +858,8 @@ class ModelPatcher:
|
||||
self.backup.clear()
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
modules_to_move = [ weakref.ref(m[3]) for m in self._load_list() ]
|
||||
modules_to_move.append(weakref.ref(self.model))
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
@ -870,12 +873,14 @@ class ModelPatcher:
|
||||
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
return modules_to_move
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
modules_to_move = []
|
||||
unload_list = self._load_list()
|
||||
unload_list.sort()
|
||||
|
||||
@ -916,7 +921,7 @@ class ModelPatcher:
|
||||
bias_key = "{}.bias".format(n)
|
||||
if move_weight:
|
||||
cast_weight = self.force_cast_weights
|
||||
m.to(device_to)
|
||||
modules_to_move.append(weakref.ref(m))
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
@ -954,20 +959,22 @@ class ModelPatcher:
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
return memory_freed, modules_to_move
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
with self.use_ejected(skip_and_inject_on_exit_only=True):
|
||||
unpatch_weights = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid or force_patch_weights)
|
||||
# TODO: force_patch_weights should not unload + reload full model
|
||||
used = self.model.model_loaded_weight_memory
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
|
||||
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_weights)
|
||||
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
|
||||
if unpatch_weights:
|
||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||
|
||||
self.patch_model(load_weights=False)
|
||||
if extra_memory < 0 and not unpatch_weights:
|
||||
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
_, modules_to_offload = self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
comfy.model_management.offload_modules(modules_to_offload, self.offload_device)
|
||||
return 0
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||
@ -979,7 +986,7 @@ class ModelPatcher:
|
||||
try:
|
||||
self.load(device_to, lowvram_model_memory=current_used + extra_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
||||
except Exception as e:
|
||||
self.detach()
|
||||
comfy.model_management.offload_modules(self.detach(), self.offload_device())
|
||||
raise e
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
@ -987,11 +994,12 @@ class ModelPatcher:
|
||||
def detach(self, unpatch_all=True):
|
||||
self.eject_model()
|
||||
self.model_patches_to(self.offload_device)
|
||||
modules_to_offload = []
|
||||
if unpatch_all:
|
||||
self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
||||
modules_to_offload = self.unpatch_model(self.offload_device, unpatch_weights=unpatch_all)
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_DETACH):
|
||||
callback(self, unpatch_all)
|
||||
return self.model
|
||||
return modules_to_offload
|
||||
|
||||
def current_loaded_device(self):
|
||||
return self.model.device
|
||||
|
||||
@ -288,6 +288,7 @@ class CLIP:
|
||||
|
||||
def load_sd(self, sd, full_model=False):
|
||||
if full_model:
|
||||
comfy.model_management.free_ram(state_dict=sd)
|
||||
return self.cond_stage_model.load_state_dict(sd, strict=False)
|
||||
else:
|
||||
return self.cond_stage_model.load_sd(sd)
|
||||
@ -665,6 +666,7 @@ class VAE:
|
||||
self.first_stage_model = AutoencoderKL(**(config['params']))
|
||||
self.first_stage_model = self.first_stage_model.eval()
|
||||
|
||||
comfy.model_management.free_ram(state_dict=sd)
|
||||
m, u = self.first_stage_model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0:
|
||||
logging.warning("Missing VAE keys {}".format(m))
|
||||
@ -986,6 +988,7 @@ def load_style_model(ckpt_path):
|
||||
model = comfy.ldm.flux.redux.ReduxImageEncoder()
|
||||
else:
|
||||
raise Exception("invalid style model {}".format(ckpt_path))
|
||||
comfy.model_management.free_ram(state_dict=model_data)
|
||||
model.load_state_dict(model_data)
|
||||
return StyleModel(model)
|
||||
|
||||
|
||||
@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return OvisTEModel_
|
||||
|
||||
@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return ZImageTEModel_
|
||||
|
||||
@ -639,6 +639,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
"proj_out.bias": "linear2.bias",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
|
||||
"attn.to_out.weight": "linear2.weight", # Flux 2
|
||||
}
|
||||
|
||||
for k in block_map:
|
||||
|
||||
@ -193,7 +193,7 @@ class BasicCache:
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def poll(self, **kwargs):
|
||||
def free_ram(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
@ -284,7 +284,7 @@ class NullCache:
|
||||
def clean_unused(self):
|
||||
pass
|
||||
|
||||
def poll(self, **kwargs):
|
||||
def free_ram(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
@ -366,9 +366,10 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
def __init__(self, key_class, min_headroom=4.0):
|
||||
super().__init__(key_class, 0)
|
||||
self.timestamps = {}
|
||||
self.min_headroom = min_headroom
|
||||
|
||||
def clean_unused(self):
|
||||
self._clean_subcaches()
|
||||
@ -381,19 +382,10 @@ class RAMPressureCache(LRUCache):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > ram_headroom:
|
||||
return
|
||||
|
||||
def _build_clean_list(self):
|
||||
clean_list = []
|
||||
|
||||
for key, (outputs, _), in self.cache.items():
|
||||
for key, (_, outputs), in self.cache.items():
|
||||
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||
|
||||
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||
@ -416,8 +408,22 @@ class RAMPressureCache(LRUCache):
|
||||
#In the case where we have no information on the node ram usage at all,
|
||||
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||
return clean_list
|
||||
|
||||
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
def free_ram(self, extra_ram=0):
|
||||
headroom_target = self.min_headroom + (extra_ram / (1024**3))
|
||||
def _ram_gb():
|
||||
return psutil.virtual_memory().available / (1024**3)
|
||||
|
||||
if _ram_gb() > headroom_target:
|
||||
return
|
||||
gc.collect()
|
||||
if _ram_gb() > headroom_target:
|
||||
return
|
||||
|
||||
clean_list = self._build_clean_list()
|
||||
|
||||
while _ram_gb() < headroom_target * RAM_CACHE_HYSTERESIS and clean_list:
|
||||
_, _, key = clean_list.pop()
|
||||
del self.cache[key]
|
||||
gc.collect()
|
||||
|
||||
@ -112,6 +112,8 @@ class TopologicalSort:
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.externalBlocks = 0
|
||||
self.unblockedEvent = asyncio.Event()
|
||||
self.priorities = {}
|
||||
self.barrierNodes = set()
|
||||
|
||||
def get_input_info(self, unique_id, input_name):
|
||||
class_type = self.dynprompt.get_node(unique_id)["class_type"]
|
||||
@ -130,13 +132,37 @@ class TopologicalSort:
|
||||
|
||||
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||
if not self.is_cached(from_node_id):
|
||||
self.add_node(from_node_id)
|
||||
self.add_node(from_node_id, priority=self.priorities.get(to_node_id, 0))
|
||||
if to_node_id not in self.blocking[from_node_id]:
|
||||
self.blocking[from_node_id][to_node_id] = {}
|
||||
self.blockCount[to_node_id] += 1
|
||||
self.blocking[from_node_id][to_node_id][from_socket] = True
|
||||
|
||||
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
|
||||
def is_barrier(self, node_id):
|
||||
return node_id in self.barrierNodes
|
||||
|
||||
def unbarrier(self, node_id):
|
||||
if not node_id in self.barrierNodes:
|
||||
return
|
||||
self.barrierNodes.remove(node_id)
|
||||
self.priorities[node_id] = self.priorities.get(node_id, 0) + 1
|
||||
|
||||
links = []
|
||||
inputs = self.dynprompt.get_node(node_id)["inputs"]
|
||||
|
||||
for input_name in inputs:
|
||||
value = inputs[input_name]
|
||||
if is_link(value):
|
||||
from_node_id, from_socket = value
|
||||
_, _, input_info = self.get_input_info(node_id, input_name)
|
||||
is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"]
|
||||
if is_barrier:
|
||||
links.append((from_node_id, from_socket, node_id))
|
||||
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
|
||||
def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None, priority=0):
|
||||
node_ids = [node_unique_id]
|
||||
links = []
|
||||
|
||||
@ -148,6 +174,7 @@ class TopologicalSort:
|
||||
self.pendingNodes[unique_id] = True
|
||||
self.blockCount[unique_id] = 0
|
||||
self.blocking[unique_id] = {}
|
||||
self.priorities[unique_id] = priority
|
||||
|
||||
inputs = self.dynprompt.get_node(unique_id)["inputs"]
|
||||
for input_name in inputs:
|
||||
@ -158,10 +185,13 @@ class TopologicalSort:
|
||||
continue
|
||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||
if (include_lazy or not is_lazy):
|
||||
is_barrier = input_info is not None and "cache-barrier" in input_info and input_info["cache-barrier"]
|
||||
if (include_lazy or (not is_lazy and not is_barrier)):
|
||||
if not self.is_cached(from_node_id):
|
||||
node_ids.append(from_node_id)
|
||||
links.append((from_node_id, from_socket, unique_id))
|
||||
if is_barrier:
|
||||
self.barrierNodes.add(unique_id)
|
||||
|
||||
for link in links:
|
||||
self.add_strong_link(*link)
|
||||
@ -180,7 +210,7 @@ class TopologicalSort:
|
||||
return False
|
||||
|
||||
def get_ready_nodes(self):
|
||||
return [node_id for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
return [(self.priorities.get(node_id, 0), node_id) for node_id in self.pendingNodes if self.blockCount[node_id] == 0]
|
||||
|
||||
def pop_node(self, unique_id):
|
||||
del self.pendingNodes[unique_id]
|
||||
@ -286,25 +316,34 @@ class ExecutionList(TopologicalSort):
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
for node_id in node_list:
|
||||
priority_level = 0
|
||||
priority_nodes = []
|
||||
for (priority, node_id) in node_list:
|
||||
if priority > priority_level:
|
||||
priority_level = priority
|
||||
priority_nodes = []
|
||||
if priority == priority_level:
|
||||
priority_nodes.append(node_id)
|
||||
|
||||
for node_id in priority_nodes:
|
||||
if is_output(node_id) or is_async(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for node_id in priority_nodes:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
if is_output(blocked_node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for node_id in priority_nodes:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||
if is_output(blocked_node_id1):
|
||||
return node_id
|
||||
|
||||
#TODO: this function should be improved
|
||||
return node_list[0]
|
||||
return priority_nodes[0]
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
|
||||
@ -19,7 +19,7 @@ class BasicScheduler(io.ComfyNode):
|
||||
node_id="BasicScheduler",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
@ -138,7 +138,7 @@ class SDTurboScheduler(io.ComfyNode):
|
||||
node_id="SDTurboScheduler",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Int.Input("steps", default=1, min=1, max=10),
|
||||
io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01),
|
||||
],
|
||||
@ -162,7 +162,7 @@ class BetaSamplingScheduler(io.ComfyNode):
|
||||
node_id="BetaSamplingScheduler",
|
||||
category="sampling/custom_sampling/schedulers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
|
||||
io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False),
|
||||
@ -352,7 +352,7 @@ class SamplingPercentToSigma(io.ComfyNode):
|
||||
node_id="SamplingPercentToSigma",
|
||||
category="sampling/custom_sampling/sigmas",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
|
||||
io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."),
|
||||
],
|
||||
@ -623,7 +623,7 @@ class SamplerSASolver(io.ComfyNode):
|
||||
node_id="SamplerSASolver",
|
||||
category="sampling/custom_sampling/samplers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
|
||||
io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001),
|
||||
io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001),
|
||||
@ -719,7 +719,7 @@ class SamplerCustom(io.ComfyNode):
|
||||
node_id="SamplerCustom",
|
||||
category="sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Boolean.Input("add_noise", default=True),
|
||||
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
||||
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||
@ -784,7 +784,7 @@ class BasicGuider(io.ComfyNode):
|
||||
node_id="BasicGuider",
|
||||
category="sampling/custom_sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
],
|
||||
outputs=[io.Guider.Output()]
|
||||
@ -805,7 +805,7 @@ class CFGGuider(io.ComfyNode):
|
||||
node_id="CFGGuider",
|
||||
category="sampling/custom_sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
|
||||
@ -858,7 +858,7 @@ class DualCFGGuider(io.ComfyNode):
|
||||
node_id="DualCFGGuider",
|
||||
category="sampling/custom_sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Conditioning.Input("cond1"),
|
||||
io.Conditioning.Input("cond2"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -973,7 +973,7 @@ class AddNoise(io.ComfyNode):
|
||||
category="_for_testing/custom_sampling/noise",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", extra_dict={"cache-barrier":True}),
|
||||
io.Noise.Input("noise"),
|
||||
io.Sigmas.Input("sigmas"),
|
||||
io.Latent.Input("latent_image"),
|
||||
|
||||
88
comfy_extras/nodes_zimage.py
Normal file
88
comfy_extras/nodes_zimage.py
Normal file
@ -0,0 +1,88 @@
|
||||
import node_helpers
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import math
|
||||
import comfy.utils
|
||||
|
||||
|
||||
class TextEncodeZImageOmni(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TextEncodeZImageOmni",
|
||||
category="advanced/conditioning",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Clip.Input("clip"),
|
||||
io.ClipVision.Input("image_encoder", optional=True),
|
||||
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||
io.Boolean.Input("auto_resize_images", default=True),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Image.Input("image1", optional=True),
|
||||
io.Image.Input("image2", optional=True),
|
||||
io.Image.Input("image3", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
|
||||
ref_latents = []
|
||||
images = list(filter(lambda a: a is not None, [image1, image2, image3]))
|
||||
|
||||
prompt_list = []
|
||||
template = None
|
||||
if len(images) > 0:
|
||||
prompt_list = ["<|im_start|>user\n<|vision_start|>"]
|
||||
prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1)
|
||||
prompt_list += ["<|vision_end|><|im_end|>"]
|
||||
template = "<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>"
|
||||
|
||||
encoded_images = []
|
||||
|
||||
for i, image in enumerate(images):
|
||||
if image_encoder is not None:
|
||||
encoded_images.append(image_encoder.encode_image(image))
|
||||
|
||||
if vae is not None:
|
||||
if auto_resize_images:
|
||||
samples = image.movedim(-1, 1)
|
||||
total = int(1024 * 1024)
|
||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||
width = round(samples.shape[3] * scale_by / 8.0) * 8
|
||||
height = round(samples.shape[2] * scale_by / 8.0) * 8
|
||||
|
||||
image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
|
||||
ref_latents.append(vae.encode(image))
|
||||
|
||||
tokens = clip.tokenize(prompt, llama_template=template)
|
||||
conditioning = clip.encode_from_tokens_scheduled(tokens)
|
||||
|
||||
extra_text_embeds = []
|
||||
for p in prompt_list:
|
||||
tokens = clip.tokenize(p, llama_template="{}")
|
||||
text_embeds = clip.encode_from_tokens_scheduled(tokens)
|
||||
extra_text_embeds.append(text_embeds[0][0])
|
||||
|
||||
if len(ref_latents) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
|
||||
if len(encoded_images) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True)
|
||||
if len(extra_text_embeds) > 0:
|
||||
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True)
|
||||
|
||||
return io.NodeOutput(conditioning)
|
||||
|
||||
|
||||
class ZImageExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
TextEncodeZImageOmni,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ZImageExtension:
|
||||
return ZImageExtension()
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.9.2"
|
||||
__version__ = "0.10.0"
|
||||
|
||||
19
execution.py
19
execution.py
@ -108,7 +108,7 @@ class CacheSet:
|
||||
self.init_null_cache()
|
||||
logging.info("Disabling intermediate node cache.")
|
||||
elif cache_type == CacheType.RAM_PRESSURE:
|
||||
cache_ram = cache_args.get("ram", 16.0)
|
||||
cache_ram = cache_args.get("ram", 4.0)
|
||||
self.init_ram_cache(cache_ram)
|
||||
logging.info("Using RAM pressure cache.")
|
||||
elif cache_type == CacheType.LRU:
|
||||
@ -130,7 +130,7 @@ class CacheSet:
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_ram_cache(self, min_headroom):
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature, min_headroom)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
@ -427,7 +427,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
input_data_all = None
|
||||
try:
|
||||
if unique_id in pending_async_nodes:
|
||||
if execution_list.is_barrier(unique_id):
|
||||
execution_list.unbarrier(unique_id)
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
elif unique_id in pending_async_nodes:
|
||||
results = []
|
||||
for r in pending_async_nodes[unique_id]:
|
||||
if isinstance(r, asyncio.Task):
|
||||
@ -622,13 +625,21 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
class PromptExecutor:
|
||||
def __init__(self, server, cache_type=False, cache_args=None):
|
||||
self.caches = None
|
||||
self.cache_args = cache_args
|
||||
self.cache_type = cache_type
|
||||
self.server = server
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
if self.caches is not None:
|
||||
for cache in self.caches.all:
|
||||
comfy.model_management.unregister_ram_listener(cache)
|
||||
|
||||
self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
|
||||
|
||||
for cache in self.caches.all:
|
||||
comfy.model_management.register_ram_listener(cache)
|
||||
self.status_messages = []
|
||||
self.success = True
|
||||
|
||||
@ -728,7 +739,7 @@ class PromptExecutor:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
self.caches.outputs.free_ram()
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
7
nodes.py
7
nodes.py
@ -61,7 +61,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
return {
|
||||
"required": {
|
||||
"text": (IO.STRING, {"multiline": True, "dynamicPrompts": True, "tooltip": "The text to be encoded."}),
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."})
|
||||
"clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text.", "cache-barrier" : True})
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = (IO.CONDITIONING,)
|
||||
@ -1521,7 +1521,7 @@ class KSampler:
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent."}),
|
||||
"model": ("MODEL", {"tooltip": "The model used for denoising the input latent.", "cache-barrier": True}),
|
||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000, "tooltip": "The number of steps used in the denoising process."}),
|
||||
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01, "tooltip": "The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."}),
|
||||
@ -1548,7 +1548,7 @@ class KSamplerAdvanced:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"model": ("MODEL",),
|
||||
{"model": ("MODEL", {"cache-barrier": True}),
|
||||
"add_noise": (["enable", "disable"], ),
|
||||
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
|
||||
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
||||
@ -2373,6 +2373,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_kandinsky5.py",
|
||||
"nodes_wanmove.py",
|
||||
"nodes_image_compare.py",
|
||||
"nodes_zimage.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.9.2"
|
||||
version = "0.10.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.36.14
|
||||
comfyui-workflow-templates==0.8.14
|
||||
comfyui-frontend-package==1.37.11
|
||||
comfyui-workflow-templates==0.8.15
|
||||
comfyui-embedded-docs==0.4.0
|
||||
torch
|
||||
torchsde
|
||||
|
||||
Loading…
Reference in New Issue
Block a user