diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index e643b4414..70d173889 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams): nerf_final_head_type: str # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] - + use_x0: bool class ChromaRadiance(Chroma): """ @@ -159,6 +159,9 @@ class ChromaRadiance(Chroma): self.skip_dit = [] self.lite = False + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -276,6 +279,12 @@ class ChromaRadiance(Chroma): params_dict |= overrides return params.__class__(**params_dict) + def _apply_x0_residual(self, predicted, noisy, timesteps): + + # non zero during training to prevent 0 div + eps = 0.0 + return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps) + def _forward( self, x: Tensor, @@ -316,4 +325,11 @@ class ChromaRadiance(Chroma): transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + out = self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + # If x0 variant → v-pred, just return this instead + if hasattr(self, "__x0__"): + out = self._apply_x0_residual(out, img, timestep) + return out + diff --git a/comfy/lora.py b/comfy/lora.py index e7202ce97..2ed0acb9d 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}): to = diffusers_keys[k] key_lora = k[:-len(".weight")] key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to if isinstance(model, comfy.model_base.Kandinsky5): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c547427..19e6aa954 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 + if "__x0__" in state_dict_keys: # x0 pred + dit_config["use_x0"] = True else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 5b1ccb824..a486c2723 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.patcher_extension import comfy.utils from comfy.comfy_types import UnetWrapperFunction +from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -132,14 +133,17 @@ class LowVramPatch: def __call__(self, weight): return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) -#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 -LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 def low_vram_patch_estimate_vram(model, key): weight, set_func, convert_func = get_key_weight(model, key) if weight is None: return 0 - return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + model_dtype = getattr(model, "manual_cast_dtype", torch.float32) + if model_dtype is None: + model_dtype = weight.dtype + + return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR def get_key_weight(model, key): set_func = None @@ -662,12 +666,18 @@ class ModelPatcher: module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if weight_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) - if bias_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + def check_module_offload_mem(key): + if key in self.patches: + return low_vram_patch_estimate_vram(self.model, key) + model_dtype = getattr(self.model, "manual_cast_dtype", None) + weight, _, _ = get_key_weight(self.model, key) + if model_dtype is None or weight is None: + return 0 + if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)): + return weight.numel() * model_dtype.itemsize + return 0 + module_offload_mem += check_module_offload_mem("{}.weight".format(n)) + module_offload_mem += check_module_offload_mem("{}.bias".format(n)) loading.append((module_offload_mem, module_mem, n, m, params)) return loading @@ -920,7 +930,7 @@ class ModelPatcher: patch_counter += 1 cast_weight = True - if cast_weight: + if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False diff --git a/comfy/ops.py b/comfy/ops.py index b749f5ede..426b784ac 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -22,7 +22,6 @@ import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm -import contextlib import json def run_every_op(): @@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of else: offload_stream = None - if offload_stream is not None: - wf_context = offload_stream - if hasattr(wf_context, "as_context"): - wf_context = wf_context.as_context(offload_stream) - else: - wf_context = contextlib.nullcontext() - non_blocking = comfy.model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 diff --git a/comfy/sd.py b/comfy/sd.py index 754b1703d..a16f2d14f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -127,6 +127,8 @@ class CLIP: self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + #Match torch.float32 hardcode upcast in TE implemention + self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None diff --git a/comfy/utils.py b/comfy/utils.py index 89846bc95..9dc0d76ac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): return None return f.read(length_of_header) +ATTR_UNSET={} + def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], value) + prev = getattr(obj, attrs[-1], ATTR_UNSET) + if value is ATTR_UNSET: + delattr(obj, attrs[-1]) + else: + setattr(obj, attrs[-1], value) return prev def set_attr_param(obj, attr, value): diff --git a/requirements.txt b/requirements.txt index 12a7c1089..11a7ac245 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.33.10 -comfyui-workflow-templates==0.7.51 +comfyui-frontend-package==1.33.13 +comfyui-workflow-templates==0.7.54 comfyui-embedded-docs==0.3.1 torch torchsde