mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-14 11:59:21 +08:00
Merge branch 'master' into scail
This commit is contained in:
commit
e5f999702a
@ -485,7 +485,7 @@ class WanVAE(nn.Module):
|
|||||||
iter_ = 1 + (t - 1) // 4
|
iter_ = 1 + (t - 1) // 4
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.decoder)
|
feat_map = [None] * count_conv3d(self.encoder)
|
||||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
## 对encode输入的x,按时间拆分为1、4、4、4....
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
|
|||||||
@ -925,6 +925,25 @@ class Flux(BaseModel):
|
|||||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class LongCatImage(Flux):
|
||||||
|
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
|
rope_opts = transformer_options.get("rope_options", {})
|
||||||
|
rope_opts = dict(rope_opts)
|
||||||
|
rope_opts.setdefault("shift_t", 1.0)
|
||||||
|
rope_opts.setdefault("shift_y", 512.0)
|
||||||
|
rope_opts.setdefault("shift_x", 512.0)
|
||||||
|
transformer_options["rope_options"] = rope_opts
|
||||||
|
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||||
|
|
||||||
|
def encode_adm(self, **kwargs):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
out.pop('guidance', None)
|
||||||
|
return out
|
||||||
|
|
||||||
class Flux2(Flux):
|
class Flux2(Flux):
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = super().extra_conds(**kwargs)
|
out = super().extra_conds(**kwargs)
|
||||||
|
|||||||
@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
|
||||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||||
dit_config["txt_ids_dims"] = [1, 2]
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image
|
||||||
|
dit_config["txt_ids_dims"] = [1, 2]
|
||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
|||||||
19
comfy/ops.py
19
comfy/ops.py
@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
x = to_dequant(x, dtype)
|
x = to_dequant(x, dtype)
|
||||||
if not resident and lowvram_fn is not None:
|
if not resident and lowvram_fn is not None:
|
||||||
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
|
||||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
|
||||||
x = lowvram_fn(x)
|
x = lowvram_fn(x)
|
||||||
if (isinstance(orig, QuantizedTensor) and
|
if (want_requant and len(fns) == 0 or update_weight):
|
||||||
(want_requant and len(fns) == 0 or update_weight)):
|
|
||||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
if isinstance(orig, QuantizedTensor):
|
||||||
if want_requant and len(fns) == 0:
|
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||||
#The layer actually wants our freshly saved QT
|
else:
|
||||||
x = y
|
y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
|
||||||
elif update_weight:
|
if want_requant and len(fns) == 0:
|
||||||
y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
|
x = y
|
||||||
if update_weight:
|
if update_weight:
|
||||||
orig.copy_(y)
|
orig.copy_(y)
|
||||||
for f in fns:
|
for f in fns:
|
||||||
@ -617,7 +615,8 @@ def fp8_linear(self, input):
|
|||||||
|
|
||||||
if input.ndim != 2:
|
if input.ndim != 2:
|
||||||
return None
|
return None
|
||||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device)
|
||||||
|
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True)
|
||||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
|
|||||||
@ -60,6 +60,7 @@ import comfy.text_encoders.jina_clip_2
|
|||||||
import comfy.text_encoders.newbie
|
import comfy.text_encoders.newbie
|
||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
|
import comfy.text_encoders.longcat_image
|
||||||
|
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
import comfy.lora
|
import comfy.lora
|
||||||
@ -1160,6 +1161,7 @@ class CLIPType(Enum):
|
|||||||
KANDINSKY5_IMAGE = 23
|
KANDINSKY5_IMAGE = 23
|
||||||
NEWBIE = 24
|
NEWBIE = 24
|
||||||
FLUX2 = 25
|
FLUX2 = 25
|
||||||
|
LONGCAT_IMAGE = 26
|
||||||
|
|
||||||
|
|
||||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||||
@ -1372,6 +1374,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
if clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||||
|
elif clip_type == CLIPType.LONGCAT_IMAGE:
|
||||||
|
clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data))
|
||||||
|
clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer
|
||||||
else:
|
else:
|
||||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||||
|
|||||||
@ -25,6 +25,7 @@ import comfy.text_encoders.kandinsky5
|
|||||||
import comfy.text_encoders.z_image
|
import comfy.text_encoders.z_image
|
||||||
import comfy.text_encoders.anima
|
import comfy.text_encoders.anima
|
||||||
import comfy.text_encoders.ace15
|
import comfy.text_encoders.ace15
|
||||||
|
import comfy.text_encoders.longcat_image
|
||||||
|
|
||||||
from . import supported_models_base
|
from . import supported_models_base
|
||||||
from . import latent_formats
|
from . import latent_formats
|
||||||
@ -1688,6 +1689,37 @@ class ACEStep15(supported_models_base.BASE):
|
|||||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||||
|
|
||||||
|
|
||||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
class LongCatImage(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "flux",
|
||||||
|
"guidance_embed": False,
|
||||||
|
"vec_in_dim": None,
|
||||||
|
"context_in_dim": 3584,
|
||||||
|
"txt_ids_dims": [1, 2],
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
}
|
||||||
|
|
||||||
|
unet_extra_config = {}
|
||||||
|
latent_format = latent_formats.Flux
|
||||||
|
|
||||||
|
memory_usage_factor = 2.5
|
||||||
|
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.LongCatImage(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
pref = self.text_encoder_key_prefix[0]
|
||||||
|
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||||
|
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||||
|
|
||||||
|
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||||
|
|
||||||
models += [SVD_img2vid]
|
models += [SVD_img2vid]
|
||||||
|
|||||||
184
comfy/text_encoders/longcat_image.py
Normal file
184
comfy/text_encoders/longcat_image.py
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
import re
|
||||||
|
import numbers
|
||||||
|
import torch
|
||||||
|
from comfy import sd1_clip
|
||||||
|
from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")]
|
||||||
|
QUOTE_PATTERN = "|".join(
|
||||||
|
[
|
||||||
|
re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2)
|
||||||
|
for q1, q2 in QUOTE_PAIRS
|
||||||
|
]
|
||||||
|
)
|
||||||
|
WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+")
|
||||||
|
|
||||||
|
|
||||||
|
def split_quotation(prompt):
|
||||||
|
matches = WORD_INTERNAL_QUOTE_RE.findall(prompt)
|
||||||
|
mapping = []
|
||||||
|
for i, word_src in enumerate(set(matches)):
|
||||||
|
word_tgt = "longcat_$##$_longcat" * (i + 1)
|
||||||
|
prompt = prompt.replace(word_src, word_tgt)
|
||||||
|
mapping.append((word_src, word_tgt))
|
||||||
|
|
||||||
|
parts = re.split(f"({QUOTE_PATTERN})", prompt)
|
||||||
|
result = []
|
||||||
|
for part in parts:
|
||||||
|
for word_src, word_tgt in mapping:
|
||||||
|
part = part.replace(word_tgt, word_src)
|
||||||
|
if not part:
|
||||||
|
continue
|
||||||
|
is_quoted = bool(re.match(QUOTE_PATTERN, part))
|
||||||
|
result.append((part, is_quoted))
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.max_length = 512
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||||
|
parts = split_quotation(text)
|
||||||
|
all_tokens = []
|
||||||
|
for part_text, is_quoted in parts:
|
||||||
|
if is_quoted:
|
||||||
|
for char in part_text:
|
||||||
|
ids = self.tokenizer(char, add_special_tokens=False)["input_ids"]
|
||||||
|
all_tokens.extend(ids)
|
||||||
|
else:
|
||||||
|
ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"]
|
||||||
|
all_tokens.extend(ids)
|
||||||
|
|
||||||
|
if len(all_tokens) > self.max_length:
|
||||||
|
all_tokens = all_tokens[: self.max_length]
|
||||||
|
logger.warning(f"Truncated prompt to {self.max_length} tokens")
|
||||||
|
|
||||||
|
output = [(t, 1.0) for t in all_tokens]
|
||||||
|
# Pad to max length
|
||||||
|
self.pad_tokens(output, self.max_length - len(output))
|
||||||
|
return [output]
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
|
super().__init__(
|
||||||
|
embedding_directory=embedding_directory,
|
||||||
|
tokenizer_data=tokenizer_data,
|
||||||
|
name="qwen25_7b",
|
||||||
|
tokenizer=LongCatImageBaseTokenizer,
|
||||||
|
)
|
||||||
|
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
|
||||||
|
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||||
|
skip_template = False
|
||||||
|
if text.startswith("<|im_start|>"):
|
||||||
|
skip_template = True
|
||||||
|
if text.startswith("<|start_header_id|>"):
|
||||||
|
skip_template = True
|
||||||
|
if text == "":
|
||||||
|
text = " "
|
||||||
|
|
||||||
|
base_tok = getattr(self, "qwen25_7b")
|
||||||
|
if skip_template:
|
||||||
|
tokens = super().tokenize_with_weights(
|
||||||
|
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefix_ids = base_tok.tokenizer(
|
||||||
|
self.longcat_template_prefix, add_special_tokens=False
|
||||||
|
)["input_ids"]
|
||||||
|
suffix_ids = base_tok.tokenizer(
|
||||||
|
self.longcat_template_suffix, add_special_tokens=False
|
||||||
|
)["input_ids"]
|
||||||
|
|
||||||
|
prompt_tokens = base_tok.tokenize_with_weights(
|
||||||
|
text, return_word_ids=return_word_ids, **kwargs
|
||||||
|
)
|
||||||
|
prompt_pairs = prompt_tokens[0]
|
||||||
|
|
||||||
|
prefix_pairs = [(t, 1.0) for t in prefix_ids]
|
||||||
|
suffix_pairs = [(t, 1.0) for t in suffix_ids]
|
||||||
|
|
||||||
|
combined = prefix_pairs + prompt_pairs + suffix_pairs
|
||||||
|
tokens = {"qwen25_7b": [combined]}
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
class LongCatImageTEModel(sd1_clip.SD1ClipModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
super().__init__(
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
name="qwen25_7b",
|
||||||
|
clip_model=Qwen25_7BVLIModel,
|
||||||
|
model_options=model_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||||
|
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||||
|
tok_pairs = token_weight_pairs["qwen25_7b"][0]
|
||||||
|
count_im_start = 0
|
||||||
|
if template_end == -1:
|
||||||
|
for i, v in enumerate(tok_pairs):
|
||||||
|
elem = v[0]
|
||||||
|
if not torch.is_tensor(elem):
|
||||||
|
if isinstance(elem, numbers.Integral):
|
||||||
|
if elem == 151644 and count_im_start < 2:
|
||||||
|
template_end = i
|
||||||
|
count_im_start += 1
|
||||||
|
|
||||||
|
if out.shape[1] > (template_end + 3):
|
||||||
|
if tok_pairs[template_end + 1][0] == 872:
|
||||||
|
if tok_pairs[template_end + 2][0] == 198:
|
||||||
|
template_end += 3
|
||||||
|
|
||||||
|
if template_end == -1:
|
||||||
|
template_end = 0
|
||||||
|
|
||||||
|
suffix_start = None
|
||||||
|
for i in range(len(tok_pairs) - 1, -1, -1):
|
||||||
|
elem = tok_pairs[i][0]
|
||||||
|
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
|
||||||
|
if elem == 151645:
|
||||||
|
suffix_start = i
|
||||||
|
break
|
||||||
|
|
||||||
|
out = out[:, template_end:]
|
||||||
|
|
||||||
|
if "attention_mask" in extra:
|
||||||
|
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||||
|
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||||
|
extra.pop("attention_mask")
|
||||||
|
|
||||||
|
if suffix_start is not None:
|
||||||
|
suffix_len = len(tok_pairs) - suffix_start
|
||||||
|
if suffix_len > 0 and out.shape[1] > suffix_len:
|
||||||
|
out = out[:, :-suffix_len]
|
||||||
|
if "attention_mask" in extra:
|
||||||
|
extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len]
|
||||||
|
if extra["attention_mask"].sum() == torch.numel(
|
||||||
|
extra["attention_mask"]
|
||||||
|
):
|
||||||
|
extra.pop("attention_mask")
|
||||||
|
|
||||||
|
return out, pooled, extra
|
||||||
|
|
||||||
|
|
||||||
|
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
class LongCatImageTEModel_(LongCatImageTEModel):
|
||||||
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
if llama_quantization_metadata is not None:
|
||||||
|
model_options = model_options.copy()
|
||||||
|
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||||
|
if dtype_llama is not None:
|
||||||
|
dtype = dtype_llama
|
||||||
|
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||||
|
|
||||||
|
return LongCatImageTEModel_
|
||||||
@ -865,14 +865,15 @@ class GLSLShader(io.ComfyNode):
|
|||||||
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
|
cls, image_list: list[torch.Tensor], output_batch: torch.Tensor
|
||||||
) -> dict[str, list]:
|
) -> dict[str, list]:
|
||||||
"""Build UI output with input and output images for client-side shader execution."""
|
"""Build UI output with input and output images for client-side shader execution."""
|
||||||
combined_inputs = torch.cat(image_list, dim=0)
|
input_images_ui = []
|
||||||
input_images_ui = ui.ImageSaveHelper.save_images(
|
for img in image_list:
|
||||||
combined_inputs,
|
input_images_ui.extend(ui.ImageSaveHelper.save_images(
|
||||||
filename_prefix="GLSLShader_input",
|
img,
|
||||||
folder_type=io.FolderType.temp,
|
filename_prefix="GLSLShader_input",
|
||||||
cls=None,
|
folder_type=io.FolderType.temp,
|
||||||
compress_level=1,
|
cls=None,
|
||||||
)
|
compress_level=1,
|
||||||
|
))
|
||||||
|
|
||||||
output_images_ui = ui.ImageSaveHelper.save_images(
|
output_images_ui = ui.ImageSaveHelper.save_images(
|
||||||
output_batch,
|
output_batch,
|
||||||
|
|||||||
@ -706,8 +706,8 @@ class SplitImageToTileList(IO.ComfyNode):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
||||||
coords = []
|
coords = []
|
||||||
stride_x = max(1, tile_width - overlap)
|
stride_x = round(max(tile_width * 0.25, tile_width - overlap))
|
||||||
stride_y = max(1, tile_height - overlap)
|
stride_y = round(max(tile_width * 0.25, tile_height - overlap))
|
||||||
|
|
||||||
y = 0
|
y = 0
|
||||||
while y < height:
|
while y < height:
|
||||||
@ -764,34 +764,6 @@ class ImageMergeTileList(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_grid_coords(width, height, tile_width, tile_height, overlap):
|
|
||||||
coords = []
|
|
||||||
stride_x = max(1, tile_width - overlap)
|
|
||||||
stride_y = max(1, tile_height - overlap)
|
|
||||||
|
|
||||||
y = 0
|
|
||||||
while y < height:
|
|
||||||
x = 0
|
|
||||||
y_end = min(y + tile_height, height)
|
|
||||||
y_start = max(0, y_end - tile_height)
|
|
||||||
|
|
||||||
while x < width:
|
|
||||||
x_end = min(x + tile_width, width)
|
|
||||||
x_start = max(0, x_end - tile_width)
|
|
||||||
|
|
||||||
coords.append((x_start, y_start, x_end, y_end))
|
|
||||||
|
|
||||||
if x_end >= width:
|
|
||||||
break
|
|
||||||
x += stride_x
|
|
||||||
|
|
||||||
if y_end >= height:
|
|
||||||
break
|
|
||||||
y += stride_y
|
|
||||||
|
|
||||||
return coords
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image_list, final_width, final_height, overlap):
|
def execute(cls, image_list, final_width, final_height, overlap):
|
||||||
w = final_width[0]
|
w = final_width[0]
|
||||||
@ -804,7 +776,7 @@ class ImageMergeTileList(IO.ComfyNode):
|
|||||||
device = first_tile.device
|
device = first_tile.device
|
||||||
dtype = first_tile.dtype
|
dtype = first_tile.dtype
|
||||||
|
|
||||||
coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp)
|
coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
|
||||||
|
|
||||||
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
|
canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
|
||||||
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
|
weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
|
||||||
|
|||||||
86
comfy_extras/nodes_resolution.py
Normal file
86
comfy_extras/nodes_resolution.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
import math
|
||||||
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio(str, Enum):
|
||||||
|
SQUARE = "1:1 (Square)"
|
||||||
|
PHOTO_H = "3:2 (Photo)"
|
||||||
|
STANDARD_H = "4:3 (Standard)"
|
||||||
|
WIDESCREEN_H = "16:9 (Widescreen)"
|
||||||
|
ULTRAWIDE_H = "21:9 (Ultrawide)"
|
||||||
|
PHOTO_V = "2:3 (Portrait Photo)"
|
||||||
|
STANDARD_V = "3:4 (Portrait Standard)"
|
||||||
|
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
|
||||||
|
|
||||||
|
|
||||||
|
ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = {
|
||||||
|
AspectRatio.SQUARE: (1, 1),
|
||||||
|
AspectRatio.PHOTO_H: (3, 2),
|
||||||
|
AspectRatio.STANDARD_H: (4, 3),
|
||||||
|
AspectRatio.WIDESCREEN_H: (16, 9),
|
||||||
|
AspectRatio.ULTRAWIDE_H: (21, 9),
|
||||||
|
AspectRatio.PHOTO_V: (2, 3),
|
||||||
|
AspectRatio.STANDARD_V: (3, 4),
|
||||||
|
AspectRatio.WIDESCREEN_V: (9, 16),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ResolutionSelector(io.ComfyNode):
|
||||||
|
"""Calculate width and height from aspect ratio and megapixel target."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ResolutionSelector",
|
||||||
|
display_name="Resolution Selector",
|
||||||
|
category="utils",
|
||||||
|
description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.",
|
||||||
|
inputs=[
|
||||||
|
io.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=AspectRatio,
|
||||||
|
default=AspectRatio.SQUARE,
|
||||||
|
tooltip="The aspect ratio for the output dimensions.",
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"megapixels",
|
||||||
|
default=1.0,
|
||||||
|
min=0.1,
|
||||||
|
max=16.0,
|
||||||
|
step=0.1,
|
||||||
|
tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Int.Output(
|
||||||
|
"width", tooltip="Calculated width in pixels (multiple of 8)."
|
||||||
|
),
|
||||||
|
io.Int.Output(
|
||||||
|
"height", tooltip="Calculated height in pixels (multiple of 8)."
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput:
|
||||||
|
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||||
|
total_pixels = megapixels * 1024 * 1024
|
||||||
|
scale = math.sqrt(total_pixels / (w_ratio * h_ratio))
|
||||||
|
width = round(w_ratio * scale / 8) * 8
|
||||||
|
height = round(h_ratio * scale / 8) * 8
|
||||||
|
return io.NodeOutput(width, height)
|
||||||
|
|
||||||
|
|
||||||
|
class ResolutionExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
ResolutionSelector,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ResolutionExtension:
|
||||||
|
return ResolutionExtension()
|
||||||
3
nodes.py
3
nodes.py
@ -976,7 +976,7 @@ class CLIPLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
|
||||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
|
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ),
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
"device": (["default", "cpu"], {"advanced": True}),
|
"device": (["default", "cpu"], {"advanced": True}),
|
||||||
@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_audio_encoder.py",
|
"nodes_audio_encoder.py",
|
||||||
"nodes_rope.py",
|
"nodes_rope.py",
|
||||||
"nodes_logic.py",
|
"nodes_logic.py",
|
||||||
|
"nodes_resolution.py",
|
||||||
"nodes_nop.py",
|
"nodes_nop.py",
|
||||||
"nodes_kandinsky5.py",
|
"nodes_kandinsky5.py",
|
||||||
"nodes_wanmove.py",
|
"nodes_wanmove.py",
|
||||||
|
|||||||
@ -31,5 +31,4 @@ spandrel
|
|||||||
pydantic~=2.0
|
pydantic~=2.0
|
||||||
pydantic-settings~=2.0
|
pydantic-settings~=2.0
|
||||||
PyOpenGL
|
PyOpenGL
|
||||||
PyOpenGL-accelerate
|
|
||||||
glfw
|
glfw
|
||||||
|
|||||||
112
tests-unit/comfy_test/model_detection_test.py
Normal file
112
tests-unit/comfy_test/model_detection_test.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.model_detection import detect_unet_config, model_config_from_unet_config
|
||||||
|
import comfy.supported_models
|
||||||
|
|
||||||
|
|
||||||
|
def _make_longcat_comfyui_sd():
|
||||||
|
"""Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights."""
|
||||||
|
sd = {}
|
||||||
|
H = 32 # Reduce hidden state dimension to reduce memory usage
|
||||||
|
C_IN = 16
|
||||||
|
C_CTX = 3584
|
||||||
|
|
||||||
|
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
|
||||||
|
sd["img_in.bias"] = torch.empty(H)
|
||||||
|
sd["txt_in.weight"] = torch.empty(H, C_CTX)
|
||||||
|
sd["txt_in.bias"] = torch.empty(H)
|
||||||
|
|
||||||
|
sd["time_in.in_layer.weight"] = torch.empty(H, 256)
|
||||||
|
sd["time_in.in_layer.bias"] = torch.empty(H)
|
||||||
|
sd["time_in.out_layer.weight"] = torch.empty(H, H)
|
||||||
|
sd["time_in.out_layer.bias"] = torch.empty(H)
|
||||||
|
|
||||||
|
sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H)
|
||||||
|
sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H)
|
||||||
|
sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H)
|
||||||
|
sd["final_layer.linear.bias"] = torch.empty(C_IN * 4)
|
||||||
|
|
||||||
|
for i in range(19):
|
||||||
|
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||||
|
sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H)
|
||||||
|
sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H)
|
||||||
|
for i in range(38):
|
||||||
|
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def _make_flux_schnell_comfyui_sd():
|
||||||
|
"""Minimal ComfyUI-format state dict for standard Flux Schnell."""
|
||||||
|
sd = {}
|
||||||
|
H = 32 # Reduce hidden state dimension to reduce memory usage
|
||||||
|
C_IN = 16
|
||||||
|
|
||||||
|
sd["img_in.weight"] = torch.empty(H, C_IN * 4)
|
||||||
|
sd["img_in.bias"] = torch.empty(H)
|
||||||
|
sd["txt_in.weight"] = torch.empty(H, 4096)
|
||||||
|
sd["txt_in.bias"] = torch.empty(H)
|
||||||
|
|
||||||
|
sd["double_blocks.0.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||||
|
sd["double_blocks.0.img_attn.qkv.weight"] = torch.empty(3 * H, H)
|
||||||
|
sd["double_blocks.0.img_mod.lin.weight"] = torch.empty(H, H)
|
||||||
|
|
||||||
|
for i in range(19):
|
||||||
|
sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128)
|
||||||
|
for i in range(38):
|
||||||
|
sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H)
|
||||||
|
|
||||||
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelDetection:
|
||||||
|
"""Verify that first-match model detection selects the correct model
|
||||||
|
based on list ordering and unet_config specificity."""
|
||||||
|
|
||||||
|
def test_longcat_before_schnell_in_models_list(self):
|
||||||
|
"""LongCatImage must appear before FluxSchnell in the models list."""
|
||||||
|
models = comfy.supported_models.models
|
||||||
|
longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage")
|
||||||
|
schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell")
|
||||||
|
assert longcat_idx < schnell_idx, (
|
||||||
|
f"LongCatImage (index {longcat_idx}) must come before "
|
||||||
|
f"FluxSchnell (index {schnell_idx}) in the models list"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_longcat_comfyui_detected_as_longcat(self):
|
||||||
|
sd = _make_longcat_comfyui_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "flux"
|
||||||
|
assert unet_config["context_in_dim"] == 3584
|
||||||
|
assert unet_config["vec_in_dim"] is None
|
||||||
|
assert unet_config["guidance_embed"] is False
|
||||||
|
assert unet_config["txt_ids_dims"] == [1, 2]
|
||||||
|
|
||||||
|
model_config = model_config_from_unet_config(unet_config, sd)
|
||||||
|
assert model_config is not None
|
||||||
|
assert type(model_config).__name__ == "LongCatImage"
|
||||||
|
|
||||||
|
def test_longcat_comfyui_keys_pass_through_unchanged(self):
|
||||||
|
"""Pre-converted weights should not be transformed by process_unet_state_dict."""
|
||||||
|
sd = _make_longcat_comfyui_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
model_config = model_config_from_unet_config(unet_config, sd)
|
||||||
|
|
||||||
|
processed = model_config.process_unet_state_dict(dict(sd))
|
||||||
|
assert "img_in.weight" in processed
|
||||||
|
assert "txt_in.weight" in processed
|
||||||
|
assert "time_in.in_layer.weight" in processed
|
||||||
|
assert "final_layer.linear.weight" in processed
|
||||||
|
|
||||||
|
def test_flux_schnell_comfyui_detected_as_flux_schnell(self):
|
||||||
|
sd = _make_flux_schnell_comfyui_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "flux"
|
||||||
|
assert unet_config["context_in_dim"] == 4096
|
||||||
|
assert unet_config["txt_ids_dims"] == []
|
||||||
|
|
||||||
|
model_config = model_config_from_unet_config(unet_config, sd)
|
||||||
|
assert model_config is not None
|
||||||
|
assert type(model_config).__name__ == "FluxSchnell"
|
||||||
Loading…
Reference in New Issue
Block a user