Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2025-12-09 11:28:55 +03:00 committed by GitHub
commit 53f0cc26e8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 52 additions and 24 deletions

View File

@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
nerf_final_head_type: str
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
class ChromaRadiance(Chroma):
"""
@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
self.skip_dit = []
self.lite = False
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
params_dict |= overrides
return params.__class__(**params_dict)
def _apply_x0_residual(self, predicted, noisy, timesteps):
# non zero during training to prevent 0 div
eps = 0.0
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
def _forward(
self,
x: Tensor,
@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
transformer_options,
attn_mask=kwargs.get("attention_mask", None),
)
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
# If x0 variant → v-pred, just return this instead
if hasattr(self, "__x0__"):
out = self._apply_x0_residual(out, img, timestep)
return out

View File

@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
if isinstance(model, comfy.model_base.Kandinsky5):

View File

@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@ -35,6 +35,7 @@ import comfy.model_management
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -132,14 +133,17 @@ class LowVramPatch:
def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
def low_vram_patch_estimate_vram(model, key):
weight, set_func, convert_func = get_key_weight(model, key)
if weight is None:
return 0
return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
if model_dtype is None:
model_dtype = weight.dtype
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
def get_key_weight(model, key):
set_func = None
@ -662,12 +666,18 @@ class ModelPatcher:
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if weight_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
if bias_key in self.patches:
module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
def check_module_offload_mem(key):
if key in self.patches:
return low_vram_patch_estimate_vram(self.model, key)
model_dtype = getattr(self.model, "manual_cast_dtype", None)
weight, _, _ = get_key_weight(self.model, key)
if model_dtype is None or weight is None:
return 0
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
return weight.numel() * model_dtype.itemsize
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
loading.append((module_offload_mem, module_mem, n, m, params))
return loading
@ -920,7 +930,7 @@ class ModelPatcher:
patch_counter += 1
cast_weight = True
if cast_weight:
if cast_weight and hasattr(m, "comfy_cast_weights"):
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False

View File

@ -22,7 +22,6 @@ import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
import comfy.rmsnorm
import contextlib
import json
def run_every_op():
@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else:
offload_stream = None
if offload_stream is not None:
wf_context = offload_stream
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(offload_stream)
else:
wf_context = contextlib.nullcontext()
non_blocking = comfy.model_management.device_supports_non_blocking(device)
weight_has_function = len(s.weight_function) > 0

View File

@ -127,6 +127,8 @@ class CLIP:
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
#Match torch.float32 hardcode upcast in TE implemention
self.patcher.set_model_compute_dtype(torch.float32)
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
self.patcher.is_clip = True
self.apply_hooks_to_conds = None

View File

@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
return None
return f.read(length_of_header)
ATTR_UNSET={}
def set_attr(obj, attr, value):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1])
setattr(obj, attrs[-1], value)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
else:
setattr(obj, attrs[-1], value)
return prev
def set_attr_param(obj, attr, value):

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.51
comfyui-frontend-package==1.33.13
comfyui-workflow-templates==0.7.54
comfyui-embedded-docs==0.3.1
torch
torchsde