diff --git a/comfy/sd.py b/comfy/sd.py
index 03bdb33d5..ad8abae96 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -98,9 +98,12 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}, clip_type_enum=None): # MODIFIED: Added clip_type_enum
         if no_init:
             return
+
+        self.clip_type_enum = clip_type_enum
+
         params = target.params.copy()
         clip = target.clip
         tokenizer = target.tokenizer
@@ -145,6 +148,7 @@ class CLIP:
         n.tokenizer_options = self.tokenizer_options.copy()
         n.use_clip_schedule = self.use_clip_schedule
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
+        n.clip_type_enum = self.clip_type_enum
         return n
 
     def get_ram_usage(self):
@@ -176,12 +180,13 @@ class CLIP:
         all_cond_pooled: list[tuple[torch.Tensor, dict[str]]] = []
         all_hooks = self.patcher.forced_hooks
         if all_hooks is None or not self.use_clip_schedule:
-            # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
+            # if no hooks or shouldn't use clip schedule, do unscheduled encode_from_tokens and perform add_dict
             return_pooled = "unprojected" if unprojected else True
             pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
             cond = pooled_dict.pop("cond")
             # add/update any keys with the provided add_dict
             pooled_dict.update(add_dict)
+            # add hooks stored on clip
             all_cond_pooled.append([cond, pooled_dict])
         else:
             scheduled_keyframes = all_hooks.get_hooks_for_clip_schedule()
@@ -216,8 +221,17 @@ class CLIP:
                 # perform encoding as normal
                 o = self.cond_stage_model.encode_token_weights(tokens)
                 cond, pooled = o[:2]
+
                 pooled_dict = {"pooled_output": pooled}
                 # add clip_start_percent and clip_end_percent in pooled
+                if len(o) > 2 and isinstance(o[2], dict):
+                    pooled_dict.update(o[2])
+
+                if hasattr(self, 'clip_type_enum') and self.clip_type_enum == CLIPType.CHROMA:
+                    if 'attention_mask' in pooled_dict:
+                        logging.debug(f"CLIP type {self.clip_type_enum.name} (scheduled path): Removing 'attention_mask' from conditioning output.")
+                        pooled_dict.pop('attention_mask', None)
+
                 pooled_dict["clip_start_percent"] = t_range[0]
                 pooled_dict["clip_end_percent"] = t_range[1]
                 # add/update any keys with the provided add_dict
@@ -246,10 +260,15 @@ class CLIP:
         cond, pooled = o[:2]
         if return_dict:
             out = {"cond": cond, "pooled_output": pooled}
-            if len(o) > 2:
+            if len(o) > 2 and isinstance(o[2], dict):
                 for k in o[2]:
                     out[k] = o[2][k]
             self.add_hooks_to_dict(out)
+
+            if hasattr(self, 'clip_type_enum') and self.clip_type_enum == CLIPType.CHROMA:
+                if 'attention_mask' in out:
+                    logging.debug(f"CLIP type {self.clip_type_enum.name} (non-scheduled path): Removing 'attention_mask' from conditioning output.")
+                    out.pop('attention_mask', None)
             return out
 
         if return_pooled:
@@ -280,6 +299,7 @@ class CLIP:
     def get_key_patches(self):
         return self.patcher.get_key_patches()
 
+
 class VAE:
     def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -1072,8 +1092,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
         if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
             clip_data[i] = comfy.utils.clip_text_transformers_convert(clip_data[i], "", "")
         else:
-            if "text_projection" in clip_data[i]:
-                clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
+            # Ensure "text_projection" exists and is a tensor before trying to transpose
+            if "text_projection" in clip_data[i] and isinstance(clip_data[i]["text_projection"], torch.Tensor):
+                clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1)
 
     tokenizer_data = {}
     clip_target = EmptyClass()
@@ -1103,7 +1124,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             elif clip_type == CLIPType.LTXV:
                 clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
-            elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
+            elif clip_type == CLIPType.PIXART:
                 clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
             elif clip_type == CLIPType.WAN:
@@ -1114,7 +1135,7 @@
                 clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
                             clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
                 clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
-            else: #CLIPType.MOCHI
+            else: #CLIPType.MOCHI or CLIPType.CHROMA
                 clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
         elif te_model == TEModel.T5_XXL_OLD:
@@ -1164,14 +1185,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
         else:
-            # clip_l
             if clip_type == CLIPType.SD3:
                 clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
                 clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
             elif clip_type == CLIPType.HIDREAM:
+                # Detect
                 clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
                 clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
-            else:
+            else:
                 clip_target.clip = sd1_clip.SD1ClipModel
                 clip_target.tokenizer = sd1_clip.SD1Tokenizer
     elif len(clip_data) == 2:
@@ -1189,7 +1210,6 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
             clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
         elif clip_type == CLIPType.HIDREAM:
-            # Detect
            hidream_dualclip_classes = []
            for hidream_te in clip_data:
                te_model = detect_te_model(hidream_te)
@@ -1199,8 +1219,8 @@
            clip_g = TEModel.CLIP_G in hidream_dualclip_classes
            t5 = TEModel.T5_XXL in hidream_dualclip_classes
            llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
-            # Initialize t5xxl_detect and llama_detect kwargs if needed
+
            t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
            llama_kwargs = llama_detect(clip_data) if llama else {}
 
@@ -1229,7 +1249,8 @@
         parameters += comfy.utils.calculate_parameters(c)
         tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options, clip_type_enum=clip_type)
+
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0: