mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 04:40:15 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
e4448cf48e
@ -104,7 +104,8 @@ class CLIPTextModel_(torch.nn.Module):
|
|||||||
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
|
||||||
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
|
||||||
|
|
||||||
causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(-torch.finfo(x.dtype).max).triu_(1)
|
causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
mask += causal_mask
|
mask += causal_mask
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
|
|||||||
|
|
||||||
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||||
assert dim % 2 == 0
|
assert dim % 2 == 0
|
||||||
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
|
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
else:
|
else:
|
||||||
device = pos.device
|
device = pos.device
|
||||||
|
|||||||
@ -991,6 +991,13 @@ def is_device_mps(device):
|
|||||||
def is_device_cuda(device):
|
def is_device_cuda(device):
|
||||||
return is_device_type(device, 'cuda')
|
return is_device_type(device, 'cuda')
|
||||||
|
|
||||||
|
def is_directml_enabled():
|
||||||
|
global directml_enabled
|
||||||
|
if directml_enabled:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
|
||||||
global directml_enabled
|
global directml_enabled
|
||||||
|
|
||||||
@ -1076,6 +1083,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
if is_ascend_npu():
|
||||||
|
return True
|
||||||
|
|
||||||
props = torch.cuda.get_device_properties(device)
|
props = torch.cuda.get_device_properties(device)
|
||||||
if props.major >= 8:
|
if props.major >= 8:
|
||||||
return True
|
return True
|
||||||
|
|||||||
@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
|
|||||||
if hasattr(m, "prev_comfy_cast_weights"):
|
if hasattr(m, "prev_comfy_cast_weights"):
|
||||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
||||||
del m.prev_comfy_cast_weights
|
del m.prev_comfy_cast_weights
|
||||||
m.weight_function = None
|
|
||||||
m.bias_function = None
|
if hasattr(m, "weight_function"):
|
||||||
|
m.weight_function = []
|
||||||
|
|
||||||
|
if hasattr(m, "bias_function"):
|
||||||
|
m.bias_function = []
|
||||||
|
|
||||||
|
def move_weight_functions(m, device):
|
||||||
|
if device is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
memory = 0
|
||||||
|
if hasattr(m, "weight_function"):
|
||||||
|
for f in m.weight_function:
|
||||||
|
if hasattr(f, "move_to"):
|
||||||
|
memory += f.move_to(device=device)
|
||||||
|
|
||||||
|
if hasattr(m, "bias_function"):
|
||||||
|
for f in m.bias_function:
|
||||||
|
if hasattr(f, "move_to"):
|
||||||
|
memory += f.move_to(device=device)
|
||||||
|
return memory
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, patches):
|
def __init__(self, key, patches):
|
||||||
@ -192,6 +212,7 @@ class ModelPatcher:
|
|||||||
self.backup = {}
|
self.backup = {}
|
||||||
self.object_patches = {}
|
self.object_patches = {}
|
||||||
self.object_patches_backup = {}
|
self.object_patches_backup = {}
|
||||||
|
self.weight_wrapper_patches = {}
|
||||||
self.model_options = {"transformer_options":{}}
|
self.model_options = {"transformer_options":{}}
|
||||||
self.model_size()
|
self.model_size()
|
||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
@ -250,6 +271,7 @@ class ModelPatcher:
|
|||||||
n.patches_uuid = self.patches_uuid
|
n.patches_uuid = self.patches_uuid
|
||||||
|
|
||||||
n.object_patches = self.object_patches.copy()
|
n.object_patches = self.object_patches.copy()
|
||||||
|
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
|
||||||
n.model_options = copy.deepcopy(self.model_options)
|
n.model_options = copy.deepcopy(self.model_options)
|
||||||
n.backup = self.backup
|
n.backup = self.backup
|
||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
@ -402,6 +424,10 @@ class ModelPatcher:
|
|||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
|
def add_weight_wrapper(self, name, function):
|
||||||
|
self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
|
||||||
|
self.patches_uuid = uuid.uuid4()
|
||||||
|
|
||||||
def get_model_object(self, name: str) -> torch.nn.Module:
|
def get_model_object(self, name: str) -> torch.nn.Module:
|
||||||
"""Retrieves a nested attribute from an object using dot notation considering
|
"""Retrieves a nested attribute from an object using dot notation considering
|
||||||
object patches.
|
object patches.
|
||||||
@ -566,6 +592,9 @@ class ModelPatcher:
|
|||||||
|
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
|
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
@ -573,34 +602,42 @@ class ModelPatcher:
|
|||||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||||
continue
|
continue
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
|
||||||
bias_key = "{}.bias".format(n)
|
|
||||||
|
|
||||||
if lowvram_weight:
|
if lowvram_weight:
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
m.weight_function = []
|
||||||
|
m.bias_function = []
|
||||||
|
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(weight_key)
|
self.patch_weight_to_device(weight_key)
|
||||||
else:
|
else:
|
||||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
else:
|
else:
|
||||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
else:
|
else:
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
if m.comfy_cast_weights:
|
wipe_lowvram_weight(m)
|
||||||
wipe_lowvram_weight(m)
|
|
||||||
|
|
||||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||||
mem_counter += module_mem
|
mem_counter += module_mem
|
||||||
load_completely.append((module_mem, n, m, params))
|
load_completely.append((module_mem, n, m, params))
|
||||||
|
|
||||||
|
if weight_key in self.weight_wrapper_patches:
|
||||||
|
m.weight_function.extend(self.weight_wrapper_patches[weight_key])
|
||||||
|
|
||||||
|
if bias_key in self.weight_wrapper_patches:
|
||||||
|
m.bias_function.extend(self.weight_wrapper_patches[bias_key])
|
||||||
|
|
||||||
|
mem_counter += move_weight_functions(m, device_to)
|
||||||
|
|
||||||
load_completely.sort(reverse=True)
|
load_completely.sort(reverse=True)
|
||||||
for x in load_completely:
|
for x in load_completely:
|
||||||
n = x[1]
|
n = x[1]
|
||||||
@ -662,6 +699,7 @@ class ModelPatcher:
|
|||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
if self.model.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
|
move_weight_functions(m, device_to)
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
@ -729,12 +767,13 @@ class ModelPatcher:
|
|||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
if move_weight:
|
if move_weight:
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
|
|||||||
33
comfy/ops.py
33
comfy/ops.py
@ -38,21 +38,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|||||||
bias = None
|
bias = None
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
has_function = s.bias_function is not None
|
has_function = len(s.bias_function) > 0
|
||||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
bias = s.bias_function(bias)
|
for f in s.bias_function:
|
||||||
|
bias = f(bias)
|
||||||
|
|
||||||
has_function = s.weight_function is not None
|
has_function = len(s.weight_function) > 0
|
||||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||||
if has_function:
|
if has_function:
|
||||||
weight = s.weight_function(weight)
|
for f in s.weight_function:
|
||||||
|
weight = f(weight)
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
class CastWeightBiasOp:
|
class CastWeightBiasOp:
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
weight_function = None
|
weight_function = []
|
||||||
bias_function = None
|
bias_function = []
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
@ -64,7 +66,7 @@ class disable_weight_init:
|
|||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -78,7 +80,7 @@ class disable_weight_init:
|
|||||||
return self._conv_forward(input, weight, bias)
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -92,7 +94,7 @@ class disable_weight_init:
|
|||||||
return self._conv_forward(input, weight, bias)
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -106,7 +108,7 @@ class disable_weight_init:
|
|||||||
return self._conv_forward(input, weight, bias)
|
return self._conv_forward(input, weight, bias)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -120,12 +122,11 @@ class disable_weight_init:
|
|||||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
@ -139,7 +140,7 @@ class disable_weight_init:
|
|||||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -160,7 +161,7 @@ class disable_weight_init:
|
|||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -181,7 +182,7 @@ class disable_weight_init:
|
|||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
@ -199,7 +200,7 @@ class disable_weight_init:
|
|||||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
if self.comfy_cast_weights:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
if "out_dtype" in kwargs:
|
if "out_dtype" in kwargs:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user