mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-10 19:57:42 +08:00
Merge branch 'master' into jk/node-replace-api
This commit is contained in:
commit
e50ab676ae
@ -679,18 +679,19 @@ class ModelPatcher:
|
|||||||
for key in list(self.pinned):
|
for key in list(self.pinned):
|
||||||
self.unpin_weight(key)
|
self.unpin_weight(key)
|
||||||
|
|
||||||
def _load_list(self, prio_comfy_cast_weights=False):
|
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
|
||||||
loading = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
params = []
|
default = False
|
||||||
skip = False
|
params = { name: param for name, param in m.named_parameters(recurse=False) }
|
||||||
for name, param in m.named_parameters(recurse=False):
|
|
||||||
params.append(name)
|
|
||||||
for name, param in m.named_parameters(recurse=True):
|
for name, param in m.named_parameters(recurse=True):
|
||||||
if name not in params:
|
if name not in params:
|
||||||
skip = True # skip random weights in non leaf modules
|
default = True # default random weights in non leaf modules
|
||||||
break
|
break
|
||||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
if default and default_device is not None:
|
||||||
|
for param in params.values():
|
||||||
|
param.data = param.data.to(device=default_device)
|
||||||
|
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
module_offload_mem = module_mem
|
module_offload_mem = module_mem
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
#with pin and unpin syncrhonization which can be expensive for small weights
|
#with pin and unpin syncrhonization which can be expensive for small weights
|
||||||
#with a high layer rate (e.g. autoregressive LLMs).
|
#with a high layer rate (e.g. autoregressive LLMs).
|
||||||
#prioritize the non-comfy weights (note the order reverse).
|
#prioritize the non-comfy weights (note the order reverse).
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
|
|
||||||
for x in loading:
|
for x in loading:
|
||||||
@ -1579,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
return 0 if vbar is None else vbar.free_memory(memory_to_free)
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
_, _, _, _, m, _ = x
|
_, _, _, _, m, _ = x
|
||||||
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
|
||||||
|
|||||||
@ -355,13 +355,6 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
|
||||||
if not isinstance(theta, list):
|
if not isinstance(theta, list):
|
||||||
theta = [theta]
|
theta = [theta]
|
||||||
@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
|
|||||||
else:
|
else:
|
||||||
cos = cos.unsqueeze(1)
|
cos = cos.unsqueeze(1)
|
||||||
sin = sin.unsqueeze(1)
|
sin = sin.unsqueeze(1)
|
||||||
out.append((cos, sin))
|
sin_split = sin.shape[-1] // 2
|
||||||
|
out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
|
||||||
|
|
||||||
if len(out) == 1:
|
if len(out) == 1:
|
||||||
return out[0]
|
return out[0]
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
def apply_rope(xq, xk, freqs_cis):
|
def apply_rope(xq, xk, freqs_cis):
|
||||||
org_dtype = xq.dtype
|
org_dtype = xq.dtype
|
||||||
cos = freqs_cis[0]
|
cos = freqs_cis[0]
|
||||||
sin = freqs_cis[1]
|
sin = freqs_cis[1]
|
||||||
q_embed = (xq * cos) + (rotate_half(xq) * sin)
|
nsin = freqs_cis[2]
|
||||||
k_embed = (xk * cos) + (rotate_half(xk) * sin)
|
|
||||||
|
q_embed = (xq * cos)
|
||||||
|
q_split = q_embed.shape[-1] // 2
|
||||||
|
q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
|
||||||
|
q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
|
||||||
|
|
||||||
|
k_embed = (xk * cos)
|
||||||
|
k_split = k_embed.shape[-1] // 2
|
||||||
|
k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
|
||||||
|
k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
|
||||||
|
|
||||||
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
return q_embed.to(org_dtype), k_embed.to(org_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user