Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-04-25 18:02:37 +08:00)

Commit d90f213946: Merge upstream/master, keep local README.md
@@ -387,6 +387,9 @@ class Kandinsky5(nn.Module):
         return self.out_layer(visual_embed, time_embed)

     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+        original_dims = x.ndim
+        if original_dims == 4:
+            x = x.unsqueeze(2)
         bs, c, t_len, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

@@ -397,7 +400,10 @@ class Kandinsky5(nn.Module):
         freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)

-        return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        if original_dims == 4:
+            out = out.squeeze(2)
+        return out

     def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
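These two hunks let Kandinsky5._forward accept 4D image latents as well as 5D video latents: a length-1 time axis is inserted before the video path runs and removed from the output afterwards. Below is a minimal sketch of that shape round-trip, using toy tensors and a no-op in place of forward_orig:

```python
import torch

# Toy stand-in for the 4D -> 5D -> 4D round-trip added to Kandinsky5._forward.
# An image latent comes in as (batch, channels, height, width); the video path
# expects (batch, channels, time, height, width), so a length-1 time axis is
# added before the forward pass and squeezed out of the result.
x = torch.randn(2, 16, 64, 64)          # 4D image latent
original_dims = x.ndim

if original_dims == 4:
    x = x.unsqueeze(2)                   # -> (2, 16, 1, 64, 64)

out = x * 1.0                            # placeholder for forward_orig(...)

if original_dims == 4:
    out = out.squeeze(2)                 # back to (2, 16, 64, 64)

assert out.shape == (2, 16, 64, 64)
```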
@@ -377,6 +377,7 @@ class NextDiT(nn.Module):
         z_image_modulation=False,
         time_scale=1.0,
         pad_tokens_multiple=None,
+        clip_text_dim=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -447,6 +448,31 @@ class NextDiT(nn.Module):
             ),
         )

+        self.clip_text_pooled_proj = None
+
+        if clip_text_dim is not None:
+            self.clip_text_dim = clip_text_dim
+            self.clip_text_pooled_proj = nn.Sequential(
+                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+                operation_settings.get("operations").Linear(
+                    clip_text_dim,
+                    clip_text_dim,
+                    bias=True,
+                    device=operation_settings.get("device"),
+                    dtype=operation_settings.get("dtype"),
+                ),
+            )
+            self.time_text_embed = nn.Sequential(
+                nn.SiLU(),
+                operation_settings.get("operations").Linear(
+                    min(dim, 1024) + clip_text_dim,
+                    min(dim, 1024),
+                    bias=True,
+                    device=operation_settings.get("device"),
+                    dtype=operation_settings.get("dtype"),
+                ),
+            )
+
         self.layers = nn.ModuleList(
             [
                 JointTransformerBlock(
@@ -585,6 +611,15 @@ class NextDiT(nn.Module):

         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

+        if self.clip_text_pooled_proj is not None:
+            pooled = kwargs.get("clip_text_pooled", None)
+            if pooled is not None:
+                pooled = self.clip_text_pooled_proj(pooled)
+            else:
+                pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
+
+            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
+
         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
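Taken together, the NextDiT hunks add an optional pooled-text path: when clip_text_dim is set, the pooled CLIP text embedding is RMS-normalized and projected, then concatenated with the timestep embedding and mapped back down to min(dim, 1024) to form the AdaLN input. The following is a rough shape sketch only, with plain torch.nn layers (including nn.RMSNorm, available in recent PyTorch) standing in for the repo's operation_settings wrappers, and illustrative dims rather than the model's actual sizes:

```python
import torch
import torch.nn as nn

# Illustrative dims; plain torch.nn modules replace the operation_settings wrappers.
dim, clip_text_dim = 2304, 1280
t_dim = min(dim, 1024)

clip_text_pooled_proj = nn.Sequential(
    nn.RMSNorm(clip_text_dim),
    nn.Linear(clip_text_dim, clip_text_dim, bias=True),
)
time_text_embed = nn.Sequential(
    nn.SiLU(),
    nn.Linear(t_dim + clip_text_dim, t_dim, bias=True),
)

t = torch.randn(1, t_dim)               # timestep embedding
pooled = torch.randn(1, clip_text_dim)  # pooled CLIP text embedding
pooled = clip_text_pooled_proj(pooled)

# The fused vector replaces the plain timestep embedding as the AdaLN input.
adaln_input = time_text_embed(torch.cat((t, pooled), dim=-1))
assert adaln_input.shape == (1, t_dim)
```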
@@ -1110,6 +1110,10 @@ class Lumina2(BaseModel):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

+        clip_text_pooled = kwargs["pooled_output"] # Newbie
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
         return out

class WAN21(BaseModel):
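This hunk forwards the text encoder's pooled output to the model as a "clip_text_pooled" cond, which is what the NextDiT forward above reads back from kwargs. A simplified, self-contained view of that hand-off follows; CONDRegular is mimicked by a trivial placeholder class and the tensor size is made up:

```python
import torch

# Hypothetical stand-in for comfy.conds.CONDRegular: just holds the tensor.
class CONDRegular:
    def __init__(self, cond):
        self.cond = cond

# What extra_conds receives from the text encoder (size is illustrative).
kwargs = {"pooled_output": torch.randn(1, 1280)}

out = {}
clip_text_pooled = kwargs["pooled_output"]
if clip_text_pooled is not None:
    out["clip_text_pooled"] = CONDRegular(clip_text_pooled)

# Downstream, NextDiT._forward picks it up via kwargs.get("clip_text_pooled").
print(out["clip_text_pooled"].cond.shape)  # torch.Size([1, 1280])
```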
@@ -423,6 +423,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["axes_lens"] = [300, 512, 512]
            dit_config["rope_theta"] = 10000.0
            dit_config["ffn_dim_multiplier"] = 4.0
+            ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
+            if ctd_weight is not None:
+                dit_config["clip_text_dim"] = ctd_weight.shape[0]
        elif dit_config["dim"] == 3840: # Z image
            dit_config["n_heads"] = 30
            dit_config["n_kv_heads"] = 30
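The detection hunk infers clip_text_dim from the checkpoint itself: if a clip_text_pooled_proj.0.weight tensor (the RMSNorm weight of the new projection) is present, its length becomes the config value. A toy illustration, with a made-up key prefix and tensor size:

```python
import torch

# Fake state dict: the RMSNorm weight of clip_text_pooled_proj is 1D, so its
# first dimension equals clip_text_dim. Prefix and size are illustrative only.
key_prefix = "model.diffusion_model."
state_dict = {key_prefix + "clip_text_pooled_proj.0.weight": torch.zeros(1280)}

dit_config = {}
ctd_weight = state_dict.get("{}clip_text_pooled_proj.0.weight".format(key_prefix), None)
if ctd_weight is not None:
    dit_config["clip_text_dim"] = ctd_weight.shape[0]

print(dit_config)  # {'clip_text_dim': 1280}
```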
@@ -24,10 +24,10 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):

class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
            model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

@@ -56,12 +56,12 @@ class Kandinsky5TEModel(QwenImageTEModel):
        else:
            return super().load_sd(sd)

-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
    class Kandinsky5TEModel_(Kandinsky5TEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                model_options = model_options.copy()
-                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
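The last two hunks replace the scaled_fp8-only plumbing with a generic quantization-metadata hand-off: te() stashes it in model_options under "llama_quantization_metadata", and the Qwen text-encoder __init__ copies it across to the generic "quantization_metadata" key. A simplified sketch of that flow, using plain functions and dicts rather than the actual classes, with a placeholder metadata value:

```python
# Simplified dict-only view of the renamed plumbing; these functions are stand-ins
# for the Kandinsky5 te() factory and Qwen25_7BVLIModel.__init__.
def te(dtype_llama=None, llama_quantization_metadata=None):
    def make_model_options(model_options={}):
        # Mirrors the factory: stash the metadata in model_options for the model class.
        if llama_quantization_metadata is not None:
            model_options = model_options.copy()
            model_options["llama_quantization_metadata"] = llama_quantization_metadata
        return model_options
    return make_model_options

def qwen_init(model_options):
    # Mirrors the __init__ side: move it to the generic "quantization_metadata" key.
    metadata = model_options.get("llama_quantization_metadata", None)
    if metadata is not None:
        model_options = model_options.copy()
        model_options["quantization_metadata"] = metadata
    return model_options

opts = qwen_init(te(llama_quantization_metadata={"scaled_fp8": True})())
print(sorted(opts))  # ['llama_quantization_metadata', 'quantization_metadata']
```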