diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 7903c7690..71f73c64e 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -485,7 +485,7 @@ class WanVAE(nn.Module): iter_ = 1 + (t - 1) // 4 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_conv3d(self.encoder) ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): conv_idx = [0] diff --git a/comfy/model_base.py b/comfy/model_base.py index 970c56e37..a1c690b9b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -925,6 +925,25 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class LongCatImage(Flux): + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + transformer_options = transformer_options.copy() + rope_opts = transformer_options.get("rope_options", {}) + rope_opts = dict(rope_opts) + rope_opts.setdefault("shift_t", 1.0) + rope_opts.setdefault("shift_y", 512.0) + rope_opts.setdefault("shift_x", 512.0) + transformer_options["rope_options"] = rope_opts + return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def encode_adm(self, **kwargs): + return None + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out.pop('guidance', None) + return out + class Flux2(Flux): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b40390746..3faa950ca 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -279,6 +279,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"]) if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model dit_config["txt_ids_dims"] = [1, 2] + if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image + dit_config["txt_ids_dims"] = [1, 2] return dit_config diff --git a/comfy/ops.py b/comfy/ops.py index 98fec1e1d..6ee6075fb 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -167,17 +167,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = to_dequant(x, dtype) if not resident and lowvram_fn is not None: x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) - #FIXME: this is not accurate, we need to be sensitive to the compute dtype x = lowvram_fn(x) - if (isinstance(orig, QuantizedTensor) and - (want_requant and len(fns) == 0 or update_weight)): + if (want_requant and len(fns) == 0 or update_weight): seed = comfy.utils.string_to_seed(s.seed_key) - y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) - if want_requant and len(fns) == 0: - #The layer actually wants our freshly saved QT - x = y - elif update_weight: - y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key)) + if isinstance(orig, QuantizedTensor): + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + else: + y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed) + if want_requant and len(fns) == 0: + x = y if update_weight: orig.copy_(y) for f in fns: @@ -617,7 +615,8 @@ def fp8_linear(self, input): if input.ndim != 2: return None - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True) scale_weight = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32) diff --git a/comfy/sd.py b/comfy/sd.py index de119eb8e..7713d4678 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -60,6 +60,7 @@ import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie import comfy.text_encoders.anima import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image import comfy.model_patcher import comfy.lora @@ -1160,6 +1161,7 @@ class CLIPType(Enum): KANDINSKY5_IMAGE = 23 NEWBIE = 24 FLUX2 = 25 + LONGCAT_IMAGE = 26 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1372,6 +1374,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.LONGCAT_IMAGE: + clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 790d02457..4f63e8327 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -25,6 +25,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.anima import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image from . import supported_models_base from . import latent_formats @@ -1688,6 +1689,37 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +class LongCatImage(supported_models_base.BASE): + unet_config = { + "image_model": "flux", + "guidance_embed": False, + "vec_in_dim": None, + "context_in_dim": 3584, + "txt_ids_dims": [1, 2], + } + + sampling_settings = { + } + + unet_extra_config = {} + latent_format = latent_formats.Flux + + memory_usage_factor = 2.5 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LongCatImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py new file mode 100644 index 000000000..882d80901 --- /dev/null +++ b/comfy/text_encoders/longcat_image.py @@ -0,0 +1,184 @@ +import re +import numbers +import torch +from comfy import sd1_clip +from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel +import logging + +logger = logging.getLogger(__name__) + +QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")] +QUOTE_PATTERN = "|".join( + [ + re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" + re.escape(q2) + for q1, q2 in QUOTE_PAIRS + ] +) +WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + + +def split_quotation(prompt): + matches = WORD_INTERNAL_QUOTE_RE.findall(prompt) + mapping = [] + for i, word_src in enumerate(set(matches)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping.append((word_src, word_tgt)) + + parts = re.split(f"({QUOTE_PATTERN})", prompt) + result = [] + for part in parts: + for word_src, word_tgt in mapping: + part = part.replace(word_tgt, word_src) + if not part: + continue + is_quoted = bool(re.match(QUOTE_PATTERN, part)) + result.append((part, is_quoted)) + return result + + +class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_length = 512 + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + parts = split_quotation(text) + all_tokens = [] + for part_text, is_quoted in parts: + if is_quoted: + for char in part_text: + ids = self.tokenizer(char, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + else: + ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + + if len(all_tokens) > self.max_length: + all_tokens = all_tokens[: self.max_length] + logger.warning(f"Truncated prompt to {self.max_length} tokens") + + output = [(t, 1.0) for t in all_tokens] + # Pad to max length + self.pad_tokens(output, self.max_length - len(output)) + return [output] + + +class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="qwen25_7b", + tokenizer=LongCatImageBaseTokenizer, + ) + self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + skip_template = False + if text.startswith("<|im_start|>"): + skip_template = True + if text.startswith("<|start_header_id|>"): + skip_template = True + if text == "": + text = " " + + base_tok = getattr(self, "qwen25_7b") + if skip_template: + tokens = super().tokenize_with_weights( + text, return_word_ids=return_word_ids, disable_weights=True, **kwargs + ) + else: + prefix_ids = base_tok.tokenizer( + self.longcat_template_prefix, add_special_tokens=False + )["input_ids"] + suffix_ids = base_tok.tokenizer( + self.longcat_template_suffix, add_special_tokens=False + )["input_ids"] + + prompt_tokens = base_tok.tokenize_with_weights( + text, return_word_ids=return_word_ids, **kwargs + ) + prompt_pairs = prompt_tokens[0] + + prefix_pairs = [(t, 1.0) for t in prefix_ids] + suffix_pairs = [(t, 1.0) for t in suffix_ids] + + combined = prefix_pairs + prompt_pairs + suffix_pairs + tokens = {"qwen25_7b": [combined]} + + return tokens + + +class LongCatImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__( + device=device, + dtype=dtype, + name="qwen25_7b", + clip_model=Qwen25_7BVLIModel, + model_options=model_options, + ) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + if template_end == -1: + template_end = 0 + + suffix_start = None + for i in range(len(tok_pairs) - 1, -1, -1): + elem = tok_pairs[i][0] + if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): + if elem == 151645: + suffix_start = i + break + + out = out[:, template_end:] + + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + + if suffix_start is not None: + suffix_len = len(tok_pairs) - suffix_start + if suffix_len > 0 and out.shape[1] > suffix_len: + out = out[:, :-suffix_len] + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, :-suffix_len] + if extra["attention_mask"].sum() == torch.numel( + extra["attention_mask"] + ): + extra.pop("attention_mask") + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class LongCatImageTEModel_(LongCatImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + + return LongCatImageTEModel_ diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py index 6d210b307..2a59a9285 100644 --- a/comfy_extras/nodes_glsl.py +++ b/comfy_extras/nodes_glsl.py @@ -865,14 +865,15 @@ class GLSLShader(io.ComfyNode): cls, image_list: list[torch.Tensor], output_batch: torch.Tensor ) -> dict[str, list]: """Build UI output with input and output images for client-side shader execution.""" - combined_inputs = torch.cat(image_list, dim=0) - input_images_ui = ui.ImageSaveHelper.save_images( - combined_inputs, - filename_prefix="GLSLShader_input", - folder_type=io.FolderType.temp, - cls=None, - compress_level=1, - ) + input_images_ui = [] + for img in image_list: + input_images_ui.extend(ui.ImageSaveHelper.save_images( + img, + filename_prefix="GLSLShader_input", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + )) output_images_ui = ui.ImageSaveHelper.save_images( output_batch, diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 727d7d09d..4c57bb5cb 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -706,8 +706,8 @@ class SplitImageToTileList(IO.ComfyNode): @staticmethod def get_grid_coords(width, height, tile_width, tile_height, overlap): coords = [] - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) + stride_x = round(max(tile_width * 0.25, tile_width - overlap)) + stride_y = round(max(tile_width * 0.25, tile_height - overlap)) y = 0 while y < height: @@ -764,34 +764,6 @@ class ImageMergeTileList(IO.ComfyNode): ], ) - @staticmethod - def get_grid_coords(width, height, tile_width, tile_height, overlap): - coords = [] - stride_x = max(1, tile_width - overlap) - stride_y = max(1, tile_height - overlap) - - y = 0 - while y < height: - x = 0 - y_end = min(y + tile_height, height) - y_start = max(0, y_end - tile_height) - - while x < width: - x_end = min(x + tile_width, width) - x_start = max(0, x_end - tile_width) - - coords.append((x_start, y_start, x_end, y_end)) - - if x_end >= width: - break - x += stride_x - - if y_end >= height: - break - y += stride_y - - return coords - @classmethod def execute(cls, image_list, final_width, final_height, overlap): w = final_width[0] @@ -804,7 +776,7 @@ class ImageMergeTileList(IO.ComfyNode): device = first_tile.device dtype = first_tile.dtype - coords = cls.get_grid_coords(w, h, t_w, t_h, ovlp) + coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp) canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype) weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype) diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py new file mode 100644 index 000000000..520b4067e --- /dev/null +++ b/comfy_extras/nodes_resolution.py @@ -0,0 +1,86 @@ +from __future__ import annotations +import math +from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class AspectRatio(str, Enum): + SQUARE = "1:1 (Square)" + PHOTO_H = "3:2 (Photo)" + STANDARD_H = "4:3 (Standard)" + WIDESCREEN_H = "16:9 (Widescreen)" + ULTRAWIDE_H = "21:9 (Ultrawide)" + PHOTO_V = "2:3 (Portrait Photo)" + STANDARD_V = "3:4 (Portrait Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" + + +ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { + AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_H: (16, 9), + AspectRatio.ULTRAWIDE_H: (21, 9), + AspectRatio.PHOTO_V: (2, 3), + AspectRatio.STANDARD_V: (3, 4), + AspectRatio.WIDESCREEN_V: (9, 16), +} + + +class ResolutionSelector(io.ComfyNode): + """Calculate width and height from aspect ratio and megapixel target.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionSelector", + display_name="Resolution Selector", + category="utils", + description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.", + inputs=[ + io.Combo.Input( + "aspect_ratio", + options=AspectRatio, + default=AspectRatio.SQUARE, + tooltip="The aspect ratio for the output dimensions.", + ), + io.Float.Input( + "megapixels", + default=1.0, + min=0.1, + max=16.0, + step=0.1, + tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.", + ), + ], + outputs=[ + io.Int.Output( + "width", tooltip="Calculated width in pixels (multiple of 8)." + ), + io.Int.Output( + "height", tooltip="Calculated height in pixels (multiple of 8)." + ), + ], + ) + + @classmethod + def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] + total_pixels = megapixels * 1024 * 1024 + scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) + width = round(w_ratio * scale / 8) * 8 + height = round(h_ratio * scale / 8) * 8 + return io.NodeOutput(width, height) + + +class ResolutionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ResolutionSelector, + ] + + +async def comfy_entrypoint() -> ResolutionExtension: + return ResolutionExtension() diff --git a/nodes.py b/nodes.py index 0222ec629..5be9b16f9 100644 --- a/nodes.py +++ b/nodes.py @@ -976,7 +976,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2435,6 +2435,7 @@ async def init_builtin_extra_nodes(): "nodes_audio_encoder.py", "nodes_rope.py", "nodes_logic.py", + "nodes_resolution.py", "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", diff --git a/requirements.txt b/requirements.txt index b5b292980..1b2bd0ae6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -31,5 +31,4 @@ spandrel pydantic~=2.0 pydantic-settings~=2.0 PyOpenGL -PyOpenGL-accelerate glfw diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py new file mode 100644 index 000000000..2551a417b --- /dev/null +++ b/tests-unit/comfy_test/model_detection_test.py @@ -0,0 +1,112 @@ +import torch + +from comfy.model_detection import detect_unet_config, model_config_from_unet_config +import comfy.supported_models + + +def _make_longcat_comfyui_sd(): + """Minimal ComfyUI-format state dict for pre-converted LongCat-Image weights.""" + sd = {} + H = 32 # Reduce hidden state dimension to reduce memory usage + C_IN = 16 + C_CTX = 3584 + + sd["img_in.weight"] = torch.empty(H, C_IN * 4) + sd["img_in.bias"] = torch.empty(H) + sd["txt_in.weight"] = torch.empty(H, C_CTX) + sd["txt_in.bias"] = torch.empty(H) + + sd["time_in.in_layer.weight"] = torch.empty(H, 256) + sd["time_in.in_layer.bias"] = torch.empty(H) + sd["time_in.out_layer.weight"] = torch.empty(H, H) + sd["time_in.out_layer.bias"] = torch.empty(H) + + sd["final_layer.adaLN_modulation.1.weight"] = torch.empty(2 * H, H) + sd["final_layer.adaLN_modulation.1.bias"] = torch.empty(2 * H) + sd["final_layer.linear.weight"] = torch.empty(C_IN * 4, H) + sd["final_layer.linear.bias"] = torch.empty(C_IN * 4) + + for i in range(19): + sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128) + sd[f"double_blocks.{i}.img_attn.qkv.weight"] = torch.empty(3 * H, H) + sd[f"double_blocks.{i}.img_mod.lin.weight"] = torch.empty(H, H) + for i in range(38): + sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H) + + return sd + + +def _make_flux_schnell_comfyui_sd(): + """Minimal ComfyUI-format state dict for standard Flux Schnell.""" + sd = {} + H = 32 # Reduce hidden state dimension to reduce memory usage + C_IN = 16 + + sd["img_in.weight"] = torch.empty(H, C_IN * 4) + sd["img_in.bias"] = torch.empty(H) + sd["txt_in.weight"] = torch.empty(H, 4096) + sd["txt_in.bias"] = torch.empty(H) + + sd["double_blocks.0.img_attn.norm.key_norm.weight"] = torch.empty(128) + sd["double_blocks.0.img_attn.qkv.weight"] = torch.empty(3 * H, H) + sd["double_blocks.0.img_mod.lin.weight"] = torch.empty(H, H) + + for i in range(19): + sd[f"double_blocks.{i}.img_attn.norm.key_norm.weight"] = torch.empty(128) + for i in range(38): + sd[f"single_blocks.{i}.modulation.lin.weight"] = torch.empty(H, H) + + return sd + + +class TestModelDetection: + """Verify that first-match model detection selects the correct model + based on list ordering and unet_config specificity.""" + + def test_longcat_before_schnell_in_models_list(self): + """LongCatImage must appear before FluxSchnell in the models list.""" + models = comfy.supported_models.models + longcat_idx = next(i for i, m in enumerate(models) if m.__name__ == "LongCatImage") + schnell_idx = next(i for i, m in enumerate(models) if m.__name__ == "FluxSchnell") + assert longcat_idx < schnell_idx, ( + f"LongCatImage (index {longcat_idx}) must come before " + f"FluxSchnell (index {schnell_idx}) in the models list" + ) + + def test_longcat_comfyui_detected_as_longcat(self): + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + assert unet_config is not None + assert unet_config["image_model"] == "flux" + assert unet_config["context_in_dim"] == 3584 + assert unet_config["vec_in_dim"] is None + assert unet_config["guidance_embed"] is False + assert unet_config["txt_ids_dims"] == [1, 2] + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "LongCatImage" + + def test_longcat_comfyui_keys_pass_through_unchanged(self): + """Pre-converted weights should not be transformed by process_unet_state_dict.""" + sd = _make_longcat_comfyui_sd() + unet_config = detect_unet_config(sd, "") + model_config = model_config_from_unet_config(unet_config, sd) + + processed = model_config.process_unet_state_dict(dict(sd)) + assert "img_in.weight" in processed + assert "txt_in.weight" in processed + assert "time_in.in_layer.weight" in processed + assert "final_layer.linear.weight" in processed + + def test_flux_schnell_comfyui_detected_as_flux_schnell(self): + sd = _make_flux_schnell_comfyui_sd() + unet_config = detect_unet_config(sd, "") + assert unet_config is not None + assert unet_config["image_model"] == "flux" + assert unet_config["context_in_dim"] == 4096 + assert unet_config["txt_ids_dims"] == [] + + model_config = model_config_from_unet_config(unet_config, sd) + assert model_config is not None + assert type(model_config).__name__ == "FluxSchnell"