diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 6553a160c..7e7b022ca 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -35,7 +35,7 @@ class ChromaRadianceParams(ChromaParams): nerf_final_head_type: str # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] - + use_x0: bool class ChromaRadiance(Chroma): """ @@ -159,6 +159,9 @@ class ChromaRadiance(Chroma): self.skip_dit = [] self.lite = False + if params.use_x0: + self.register_buffer("__x0__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -276,6 +279,12 @@ class ChromaRadiance(Chroma): params_dict |= overrides return params.__class__(**params_dict) + def _apply_x0_residual(self, predicted, noisy, timesteps): + + # non zero during training to prevent 0 div + eps = 0.0 + return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps) + def _forward( self, x: Tensor, @@ -316,4 +325,11 @@ class ChromaRadiance(Chroma): transformer_options, attn_mask=kwargs.get("attention_mask", None), ) - return self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + out = self.forward_nerf(img, img_out, params)[:, :, :h, :w] + + # If x0 variant → v-pred, just return this instead + if hasattr(self, "__x0__"): + out = self._apply_x0_residual(out, img, timestep) + return out + diff --git a/comfy/lora.py b/comfy/lora.py index 7796d88d2..eb9d1d732 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -333,6 +333,7 @@ def model_lora_keys_unet(model, key_map=None): to = diffusers_keys[k] key_lora = k[:-len(".weight")] key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to if isinstance(model, model_base.Kandinsky5): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 1ec8cc24a..adbbcdaf0 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -266,6 +266,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_tile_size"] = 512 dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 + if "__x0__" in state_dict_keys: # x0 pred + dit_config["use_x0"] = True else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 94cfc79ce..41ac17eb5 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -153,15 +153,18 @@ class LowVramPatch: return calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) -# The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 -LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 def low_vram_patch_estimate_vram(model, key): weight, set_func, convert_func = get_key_weight(model, key) if weight is None: return 0 - return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + model_dtype = getattr(model, "manual_cast_dtype", torch.float32) + if model_dtype is None: + model_dtype = weight.dtype + + return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR def get_key_weight(model, key): @@ -705,7 +708,7 @@ class ModelPatcher(ModelManageable, PatchSupport): utils.copy_to_param(self.model, key, out_weight) else: utils.set_attr_param(self.model, key, out_weight) - + if self.gguf.patch_on_device: return # end gguf @@ -767,12 +770,18 @@ class ModelPatcher(ModelManageable, PatchSupport): module_mem = model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): - weight_key = "{}.weight".format(n) - bias_key = "{}.bias".format(n) - if weight_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) - if bias_key in self.patches: - module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + def check_module_offload_mem(key): + if key in self.patches: + return low_vram_patch_estimate_vram(self.model, key) + model_dtype = getattr(self.model, "manual_cast_dtype", None) + weight, _, _ = get_key_weight(self.model, key) + if model_dtype is None or weight is None: + return 0 + if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)): + return weight.numel() * model_dtype.itemsize + return 0 + module_offload_mem += check_module_offload_mem("{}.weight".format(n)) + module_offload_mem += check_module_offload_mem("{}.bias".format(n)) loading.append(LoadingListItem(module_offload_mem, module_mem, n, m, params)) return loading @@ -1076,7 +1085,7 @@ class ModelPatcher(ModelManageable, PatchSupport): patch_counter += 1 cast_weight = True - if cast_weight: + if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights m.comfy_cast_weights = True m.comfy_patched_weights = False diff --git a/comfy/ops.py b/comfy/ops.py index 9ba0ab800..ba1ae74f0 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -119,14 +119,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of offload_stream = model_management.get_offload_stream(device) else: offload_stream = None - if offload_stream is not None: - wf_context = offload_stream - if hasattr(wf_context, "as_context"): - wf_context = wf_context.as_context(offload_stream) - else: - wf_context = contextlib.nullcontext() - # todo: how is wf_context used? non_blocking = model_management.device_supports_non_blocking(device) weight_has_function = len(s.weight_function) > 0 diff --git a/comfy/sd.py b/comfy/sd.py index 1705213dd..86e87354c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -138,6 +138,8 @@ class CLIP: self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + #Match torch.float32 hardcode upcast in TE implemention + self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None diff --git a/comfy/utils.py b/comfy/utils.py index b894ea1bb..15215017c 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -912,13 +912,18 @@ def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024): return None return f.read(length_of_header) +# todo: wtf? +ATTR_UNSET={} def set_attr(obj, attr, value): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1]) - setattr(obj, attrs[-1], value) + prev = getattr(obj, attrs[-1], ATTR_UNSET) + if value is ATTR_UNSET: + delattr(obj, attrs[-1]) + else: + setattr(obj, attrs[-1], value) return prev