Merge branch 'comfyanonymous:master' into master

Authored by patientx on 2025-12-09 11:28:55 +03:00, committed by GitHub
commit 53f0cc26e8
8 changed files with 52 additions and 24 deletions

View File

@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
     nerf_final_head_type: str
     # None means use the same dtype as the model.
     nerf_embedder_dtype: Optional[torch.dtype]
+    use_x0: bool

 class ChromaRadiance(Chroma):
     """
@@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
         self.skip_dit = []
         self.lite = False
+        if params.use_x0:
+            self.register_buffer("__x0__", torch.tensor([]))
+
     @property
     def _nerf_final_layer(self) -> nn.Module:
         if self.params.nerf_final_head_type == "linear":
@@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
         params_dict |= overrides
         return params.__class__(**params_dict)
+
+    def _apply_x0_residual(self, predicted, noisy, timesteps):
+        # non zero during training to prevent 0 div
+        eps = 0.0
+        return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
+
     def _forward(
         self,
         x: Tensor,
@@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+        # If x0 variant → v-pred, just return this instead
+        if hasattr(self, "__x0__"):
+            out = self._apply_x0_residual(out, img, timestep)
+
+        return out
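
Note on the ChromaRadiance hunks above: `_apply_x0_residual` converts an x0 (denoised-image) prediction back into the velocity-style output the rest of the pipeline expects. Below is a minimal standalone sketch of the same arithmetic, assuming the usual rectified-flow convention x_t = (1 - t)·x0 + t·noise, so (x_t - x0)/t recovers noise - x0; the helper name, tensor names, and shapes are illustrative and not part of the commit.

import torch

def x0_to_velocity(predicted_x0: torch.Tensor, noisy: torch.Tensor, timesteps: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # Same formula as _apply_x0_residual above; eps would be made non-zero
    # during training to avoid dividing by zero at t == 0.
    return (noisy - predicted_x0) / (timesteps.view(-1, 1, 1, 1) + eps)

# Example: batch of 2 latents, one timestep per sample broadcast over spatial dims.
noisy = torch.randn(2, 4, 8, 8)
predicted_x0 = torch.randn(2, 4, 8, 8)
t = torch.tensor([0.7, 0.3])
print(x0_to_velocity(predicted_x0, noisy, t).shape)  # torch.Size([2, 4, 8, 8])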

View File

@@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
                 to = diffusers_keys[k]
                 key_lora = k[:-len(".weight")]
                 key_map["diffusion_model.{}".format(key_lora)] = to
+                key_map["transformer.{}".format(key_lora)] = to
                 key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

     if isinstance(model, comfy.model_base.Kandinsky5):
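
Note on the model_lora_keys_unet hunk above: LoRA keys using a diffusers-style `transformer.` prefix now resolve to the same internal weights as the existing `diffusion_model.` mapping. A small standalone illustration of the resolution; the key names are made up for the example.

# Hypothetical key resolution mirroring the mapping built above.
to = "diffusion_model.double_blocks.0.img_attn.qkv.weight"  # internal weight name (example)
key_lora = "double_blocks.0.img_attn.qkv"                   # key with ".weight" stripped

key_map = {}
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to             # new in this commit
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

# A LoRA shipping "transformer.double_blocks.0.img_attn.qkv.*" keys
# now matches the same target weight as the "diffusion_model." form.
print(key_map["transformer." + key_lora] == to)  # True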

View File

@@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["nerf_tile_size"] = 512
            dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
            dit_config["nerf_embedder_dtype"] = torch.float32
+           if "__x0__" in state_dict_keys: # x0 pred
+               dit_config["use_x0"] = True
        else:
            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
            dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@@ -35,6 +35,7 @@ import comfy.model_management
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@@ -132,14 +133,17 @@ class LowVramPatch:
     def __call__(self, weight):
         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

-#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
-LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

 def low_vram_patch_estimate_vram(model, key):
     weight, set_func, convert_func = get_key_weight(model, key)
     if weight is None:
         return 0
-    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+    if model_dtype is None:
+        model_dtype = weight.dtype
+    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

 def get_key_weight(model, key):
     set_func = None
@@ -662,12 +666,18 @@ class ModelPatcher:
             module_mem = comfy.model_management.module_size(m)
             module_offload_mem = module_mem
             if hasattr(m, "comfy_cast_weights"):
-                weight_key = "{}.weight".format(n)
-                bias_key = "{}.bias".format(n)
-                if weight_key in self.patches:
-                    module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                if bias_key in self.patches:
-                    module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                def check_module_offload_mem(key):
+                    if key in self.patches:
+                        return low_vram_patch_estimate_vram(self.model, key)
+                    model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                    weight, _, _ = get_key_weight(self.model, key)
+                    if model_dtype is None or weight is None:
+                        return 0
+                    if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                        return weight.numel() * model_dtype.itemsize
+                    return 0
+                module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                module_offload_mem += check_module_offload_mem("{}.bias".format(n))

             loading.append((module_offload_mem, module_mem, n, m, params))
         return loading
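
Note on the two LowVramPatch / ModelPatcher hunks above: the offload estimate now keys off the model's compute dtype (`manual_cast_dtype`) with a factor of 2 instead of a hardcoded fp32 × 3, and modules whose stored dtype differs from the compute dtype (or are quantized) also account for one cast copy. A standalone sketch of the estimate arithmetic under those assumptions; the function name and example sizes are illustrative, not ComfyUI API.

import torch

LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

def estimate_patch_vram(numel: int, manual_cast_dtype, weight_dtype) -> int:
    # Patch math happens in the compute dtype, so scratch memory is roughly
    # numel * itemsize(compute dtype) * factor.
    model_dtype = manual_cast_dtype if manual_cast_dtype is not None else weight_dtype
    return numel * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

# Example: a 1280x1280 weight stored in bf16, computed in fp16.
numel = 1280 * 1280
print(estimate_patch_vram(numel, torch.float16, torch.bfloat16))  # 6553600 bytes (~6.25 MiB)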
@@ -920,7 +930,7 @@ class ModelPatcher:
                        patch_counter += 1
                    cast_weight = True

-            if cast_weight:
+            if cast_weight and hasattr(m, "comfy_cast_weights"):
                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
            m.comfy_patched_weights = False

View File

@@ -22,7 +22,6 @@ import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib
 import json

 def run_every_op():
@@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         offload_stream = None

-    if offload_stream is not None:
-        wf_context = offload_stream
-        if hasattr(wf_context, "as_context"):
-            wf_context = wf_context.as_context(offload_stream)
-    else:
-        wf_context = contextlib.nullcontext()
-
     non_blocking = comfy.model_management.device_supports_non_blocking(device)

     weight_has_function = len(s.weight_function) > 0

View File

@@ -127,6 +127,8 @@ class CLIP:
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        #Match torch.float32 hardcode upcast in TE implemention
+        self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
         self.patcher.is_clip = True
         self.apply_hooks_to_conds = None

View File

@@ -803,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
             return None
         return f.read(length_of_header)

+ATTR_UNSET={}
+
 def set_attr(obj, attr, value):
     attrs = attr.split(".")
     for name in attrs[:-1]:
         obj = getattr(obj, name)
-    prev = getattr(obj, attrs[-1])
-    setattr(obj, attrs[-1], value)
+    prev = getattr(obj, attrs[-1], ATTR_UNSET)
+    if value is ATTR_UNSET:
+        delattr(obj, attrs[-1])
+    else:
+        setattr(obj, attrs[-1], value)
     return prev

 def set_attr_param(obj, attr, value):
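
Note on the set_attr hunk above: `ATTR_UNSET` is a sentinel so `set_attr` can report "the attribute did not exist" and, when handed that sentinel back, delete the attribute instead of setting it, making set/restore a clean round trip. A standalone illustration of the pattern; the `Box` class is only for the example.

ATTR_UNSET = {}

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1], ATTR_UNSET)  # ATTR_UNSET: attribute was absent
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])                 # restoring ATTR_UNSET removes it again
    else:
        setattr(obj, attrs[-1], value)
    return prev

class Box:
    pass

box = Box()
prev = set_attr(box, "flag", True)  # prev is ATTR_UNSET (attribute didn't exist)
set_attr(box, "flag", prev)         # passing it back deletes the attribute
print(hasattr(box, "flag"))         # False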

View File

@@ -1,5 +1,5 @@
-comfyui-frontend-package==1.33.10
-comfyui-workflow-templates==0.7.51
+comfyui-frontend-package==1.33.13
+comfyui-workflow-templates==0.7.54
 comfyui-embedded-docs==0.3.1
 torch
 torchsde