Merge branch 'master' of github.com:comfyanonymous/ComfyUI into merge/0.3.76-snapshot
Commit 194ac1f596

@@ -35,7 +35,7 @@ class ChromaRadianceParams(ChromaParams):
    nerf_final_head_type: str
    # None means use the same dtype as the model.
    nerf_embedder_dtype: Optional[torch.dtype]

    use_x0: bool

class ChromaRadiance(Chroma):
    """

@@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
        self.skip_dit = []
        self.lite = False

        if params.use_x0:
            self.register_buffer("__x0__", torch.tensor([]))

    @property
    def _nerf_final_layer(self) -> nn.Module:
        if self.params.nerf_final_head_type == "linear":

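Editor's note: the `register_buffer("__x0__", torch.tensor([]))` call above appears to use an empty, non-parameter buffer purely as a marker. Buffers are included in the module's state dict, so the flag survives serialization and can be re-detected from checkpoint keys (as the `detect_unet_config` hunk further down does). A minimal sketch of that pattern, using a hypothetical `Marked` module rather than the actual ChromaRadiance class:

import torch
from torch import nn


class Marked(nn.Module):
    """Toy module that records a capability flag as an empty buffer."""

    def __init__(self, use_x0: bool = True):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        if use_x0:
            # Empty tensor: adds no real weight data, but the key is
            # persisted by state_dict() and restored by load_state_dict().
            self.register_buffer("__x0__", torch.tensor([]))


m = Marked()
print("__x0__" in m.state_dict().keys())  # True: the marker round-trips with the checkpoint
print(hasattr(m, "__x0__"))               # True: runtime checks can use hasattr, as in _forward
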
@@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
        params_dict |= overrides
        return params.__class__(**params_dict)

    def _apply_x0_residual(self, predicted, noisy, timesteps):

        # non zero during training to prevent 0 div
        eps = 0.0
        return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)

    def _forward(
        self,
        x: Tensor,

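Editor's note: `_apply_x0_residual` reads as a conversion from an x0 (denoised-image) prediction to the flow/velocity-style output the rest of the pipeline expects. Assuming the usual rectified-flow noising convention x_t = (1 - t)·x0 + t·noise with target v = noise - x0 (that convention is my assumption, not stated in the diff), (x_t - x0)/t algebraically equals v, which a quick numeric check confirms:

import torch

def apply_x0_residual(predicted_x0, noisy, timesteps, eps=0.0):
    # Same shape handling as the hunk above: broadcast t over NCHW.
    return (noisy - predicted_x0) / (timesteps.view(-1, 1, 1, 1) + eps)

torch.manual_seed(0)
x0 = torch.randn(2, 3, 8, 8)        # clean image
noise = torch.randn(2, 3, 8, 8)
t = torch.tensor([0.7, 0.3])        # per-sample timestep in (0, 1]

x_t = (1.0 - t.view(-1, 1, 1, 1)) * x0 + t.view(-1, 1, 1, 1) * noise  # assumed noising
v_target = noise - x0                                                  # assumed velocity target

v_from_x0 = apply_x0_residual(x0, x_t, t)
print(torch.allclose(v_from_x0, v_target, atol=1e-5))  # True under these assumptions
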
@@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
            transformer_options,
            attn_mask=kwargs.get("attention_mask", None),
        )
        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]

        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]

        # If x0 variant → v-pred, just return this instead
        if hasattr(self, "__x0__"):
            out = self._apply_x0_residual(out, img, timestep)
        return out

@@ -333,6 +333,7 @@ def model_lora_keys_unet(model, key_map=None):
                to = diffusers_keys[k]
                key_lora = k[:-len(".weight")]
                key_map["diffusion_model.{}".format(key_lora)] = to
                key_map["transformer.{}".format(key_lora)] = to
                key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

    if isinstance(model, model_base.Kandinsky5):

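Editor's note: the three `key_map` assignments above register the same target weight under the aliases different LoRA trainers emit: a `diffusion_model.` prefix, a diffusers-style `transformer.` prefix, and a LyCORIS-style `lycoris_` name with dots flattened to underscores. A hedged sketch of that mapping for one made-up key (the key and target strings are illustrative, not taken from the repo):

def register_aliases(key_map, key, to):
    """Register the aliases a trained LoRA file might use for one weight."""
    key_lora = key[:-len(".weight")] if key.endswith(".weight") else key
    key_map["diffusion_model.{}".format(key_lora)] = to
    key_map["transformer.{}".format(key_lora)] = to
    key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
    return key_map

key_map = register_aliases({}, "double_blocks.0.img_mlp.0.weight",
                           "diffusion_model.double_blocks.0.img_mlp.0.weight")
for k in key_map:
    print(k)
# diffusion_model.double_blocks.0.img_mlp.0
# transformer.double_blocks.0.img_mlp.0
# lycoris_double_blocks_0_img_mlp_0
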
@@ -266,6 +266,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
        dit_config["nerf_tile_size"] = 512
        dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
        dit_config["nerf_embedder_dtype"] = torch.float32
        if "__x0__" in state_dict_keys: # x0 pred
            dit_config["use_x0"] = True
    else:
        dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
        dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

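Editor's note: `detect_unet_config` infers architecture options purely from which keys exist in the checkpoint: the presence of `nerf_final_layer_conv.norm.scale` under the prefix selects the conv head, and the `__x0__` marker buffer switches on `use_x0`. A small sketch of that key-probing style with a fabricated key set (key names other than those quoted above are made up for the example):

import torch

def sketch_detect(state_dict_keys, key_prefix=""):
    dit_config = {}
    dit_config["nerf_embedder_dtype"] = torch.float32
    dit_config["nerf_final_head_type"] = (
        "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
    )
    if "__x0__" in state_dict_keys:  # marker buffer saved by register_buffer
        dit_config["use_x0"] = True
    return dit_config

keys = {"model.diffusion_model.nerf_final_layer_conv.norm.scale", "__x0__"}
print(sketch_detect(keys, key_prefix="model.diffusion_model."))
# {'nerf_embedder_dtype': torch.float32, 'nerf_final_head_type': 'conv', 'use_x0': True}
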
@@ -153,15 +153,18 @@ class LowVramPatch:
        return calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)


# The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2


def low_vram_patch_estimate_vram(model, key):
    weight, set_func, convert_func = get_key_weight(model, key)
    if weight is None:
        return 0
    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
    if model_dtype is None:
        model_dtype = weight.dtype

    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR


def get_key_weight(model, key):

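Editor's note: the estimate above is just numel × itemsize × factor. The replaced version always assumed a float32 upcast with a 3× working-set factor; the new version uses the model's compute dtype (falling back to the weight's own dtype) and a 2× factor. For a 4096×4096 linear weight the two estimates differ by a factor of three when the compute dtype is fp16, as a quick calculation shows (the layer size is a made-up example):

import torch

numel = 4096 * 4096                                  # one hypothetical 4096x4096 weight

old_estimate = numel * torch.float32.itemsize * 3    # always fp32, factor 3
new_fp16 = numel * torch.float16.itemsize * 2        # compute dtype fp16, factor 2
new_fp32 = numel * torch.float32.itemsize * 2        # compute dtype fp32, factor 2

print(old_estimate // (1024 * 1024), "MiB")  # 192 MiB
print(new_fp16 // (1024 * 1024), "MiB")      # 64 MiB
print(new_fp32 // (1024 * 1024), "MiB")      # 128 MiB
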
@@ -767,12 +770,18 @@ class ModelPatcher(ModelManageable, PatchSupport):
            module_mem = model_management.module_size(m)
            module_offload_mem = module_mem
            if hasattr(m, "comfy_cast_weights"):
                weight_key = "{}.weight".format(n)
                bias_key = "{}.bias".format(n)
                if weight_key in self.patches:
                    module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
                if bias_key in self.patches:
                    module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
                def check_module_offload_mem(key):
                    if key in self.patches:
                        return low_vram_patch_estimate_vram(self.model, key)
                    model_dtype = getattr(self.model, "manual_cast_dtype", None)
                    weight, _, _ = get_key_weight(self.model, key)
                    if model_dtype is None or weight is None:
                        return 0
                    if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
                        return weight.numel() * model_dtype.itemsize
                    return 0
                module_offload_mem += check_module_offload_mem("{}.weight".format(n))
                module_offload_mem += check_module_offload_mem("{}.bias".format(n))
            loading.append(LoadingListItem(module_offload_mem, module_mem, n, m, params))
        return loading

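Editor's note: the new `check_module_offload_mem` closure folds two cases into one place: a patched key costs the LowVramPatch estimate, while an unpatched weight only costs a cast buffer (numel × compute-dtype itemsize) when its stored dtype differs from the compute dtype or it is quantized. A standalone sketch of that decision order with stand-in inputs instead of real ModelPatcher state; names like `patches` and `manual_cast_dtype` mirror the hunk, the rest (including the omitted QuantizedTensor branch) is simplified:

import torch

def offload_mem_for_key(key, weight, patches, manual_cast_dtype, patch_factor=2):
    """Per-key offload cost, mirroring the decision order in the hunk above."""
    if weight is None:
        return 0
    if key in patches:
        # Patched key: charge the LowVramPatch working-set estimate.
        dtype = manual_cast_dtype if manual_cast_dtype is not None else weight.dtype
        return weight.numel() * dtype.itemsize * patch_factor
    if manual_cast_dtype is None:
        return 0
    if weight.dtype != manual_cast_dtype:
        # Unpatched weight stored in a different dtype: one cast copy at compute dtype.
        return weight.numel() * manual_cast_dtype.itemsize
    return 0

w = torch.zeros(1024, 1024, dtype=torch.float16)  # weight stored in fp16
print(offload_mem_for_key("blk.weight", w, {}, torch.float16))                  # 0: dtype already matches
print(offload_mem_for_key("blk.weight", w, {}, torch.float32))                  # 4194304: one fp32 cast copy
print(offload_mem_for_key("blk.weight", w, {"blk.weight": []}, torch.float32))  # 8388608: patch estimate
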
@@ -1076,7 +1085,7 @@ class ModelPatcher(ModelManageable, PatchSupport):
                        patch_counter += 1
                    cast_weight = True

            if cast_weight:
            if cast_weight and hasattr(m, "comfy_cast_weights"):
                m.prev_comfy_cast_weights = m.comfy_cast_weights
                m.comfy_cast_weights = True
            m.comfy_patched_weights = False

@@ -119,14 +119,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
        offload_stream = model_management.get_offload_stream(device)
    else:
        offload_stream = None
    if offload_stream is not None:
        wf_context = offload_stream
        if hasattr(wf_context, "as_context"):
            wf_context = wf_context.as_context(offload_stream)
    else:
        wf_context = contextlib.nullcontext()

    # todo: how is wf_context used?
    non_blocking = model_management.device_supports_non_blocking(device)

    weight_has_function = len(s.weight_function) > 0

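Editor's note: the block dropped here followed a common idiom: when an offload stream is available, weight copies run inside a stream context; when it is not, `contextlib.nullcontext()` stands in so the surrounding `with` statement needs no special-casing. A generic sketch of that idiom, independent of ComfyUI's `model_management` helpers (the original wrapped the stream through its own `as_context` method; plain `torch.cuda.stream` is used here instead):

import contextlib
import torch

def copy_context(offload_stream):
    """Return a context manager: the stream if one exists, else a no-op."""
    if offload_stream is not None:
        return torch.cuda.stream(offload_stream)
    return contextlib.nullcontext()

stream = torch.cuda.Stream() if torch.cuda.is_available() else None
with copy_context(stream):
    # non_blocking copies issued here can overlap with compute on the default stream
    pass
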
@@ -138,6 +138,8 @@ class CLIP:

        self.tokenizer: "sd1_clip.SD1Tokenizer" = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.patcher = model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
        # Match the hardcoded torch.float32 upcast in the TE implementation
        self.patcher.set_model_compute_dtype(torch.float32)
        self.patcher.hook_mode = EnumHookMode.MinVram
        self.patcher.is_clip = True
        self.apply_hooks_to_conds = None

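Editor's note: the added `set_model_compute_dtype(torch.float32)` call pins the text encoder's compute dtype to fp32, matching the hard-coded fp32 upcast inside the TE implementation, so that memory estimates and casts agree with what the forward pass actually allocates. As a generic illustration of storage dtype versus compute dtype (not ComfyUI's implementation), weights can stay in fp16 while the matmul itself runs in fp32:

import torch
from torch import nn

layer = nn.Linear(8, 8).half()            # storage dtype: fp16
x = torch.randn(2, 8, dtype=torch.float16)

compute_dtype = torch.float32             # what the encoder actually computes in
out = torch.nn.functional.linear(
    x.to(compute_dtype), layer.weight.to(compute_dtype), layer.bias.to(compute_dtype)
)
print(layer.weight.dtype, out.dtype)      # torch.float16 torch.float32
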
@@ -912,12 +912,17 @@ def safetensors_header(safetensors_path, max_size=100 * 1024 * 1024):
            return None
        return f.read(length_of_header)

# todo: wtf?
ATTR_UNSET={}

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    prev = getattr(obj, attrs[-1], ATTR_UNSET)
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])
    else:
        setattr(obj, attrs[-1], value)
    return prev

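Editor's note: with the `ATTR_UNSET` sentinel, `set_attr` becomes a symmetric swap helper: it returns whatever was there before (or the sentinel if nothing was), and passing the sentinel back as the value deletes the attribute instead of setting it. Because the sentinel is a unique object compared by identity (`is`), it cannot collide with any legitimate attribute value. A small usage sketch with a throwaway object; the `Holder` class is made up for the example, and the helper is repeated verbatim so the snippet runs on its own:

ATTR_UNSET = {}

def set_attr(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1], ATTR_UNSET)
    if value is ATTR_UNSET:
        delattr(obj, attrs[-1])
    else:
        setattr(obj, attrs[-1], value)
    return prev

class Holder:
    pass

h = Holder()
h.inner = Holder()

old = set_attr(h, "inner.scale", 2.0)     # attribute did not exist yet
print(old is ATTR_UNSET)                  # True
old = set_attr(h, "inner.scale", 3.0)     # replace, remember the old value
print(old)                                # 2.0
set_attr(h, "inner.scale", old)           # restore the previous value
set_attr(h, "inner.scale", ATTR_UNSET)    # delete the attribute again
print(hasattr(h.inner, "scale"))          # False
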